Neural Network Training Dynamics

1: A Toy Model: Linear Regression

Why toy models

Neural networks (NNs) are too complex to prove much about directly. However, we can notice phenomena empirically and try to replicate them in toy models such as linear regression.

We can then use the toy models to make predictions that we can test empirically. In this way, simple models that offer many testable predictions are more valuable than complex models.

Convex quadratic objective equation

\[ w_* \in \arg\min_w \frac{1}{2} w^TAw + b^Tw + c \]

Here \(A\) is symmetric positive semi-definite, which is what makes the objective convex.
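A minimal NumPy sketch of this objective, assuming \(A\) is symmetric positive definite so the minimiser is unique. Setting the gradient \(Aw + b\) to zero gives \(w_* = -A^{-1}b\). The matrix, vector, step size \(\alpha = 0.1\) and iteration count are arbitrary illustrative choices; gradient descent is run only to confirm it reaches the same point.

```python
import numpy as np

# Arbitrary symmetric positive-definite A (assumed), plus b and c.
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])
b = np.array([-1.0, 4.0])
c = 0.5

def objective(w):
    """Convex quadratic 1/2 w^T A w + b^T w + c."""
    return 0.5 * w @ A @ w + b @ w + c

def gradient(w):
    # For symmetric A, the gradient is A w + b.
    return A @ w + b

# Closed-form minimiser: solve A w = -b rather than forming A^{-1}.
w_star = np.linalg.solve(A, -b)

# Gradient descent from the origin with a fixed step size.
alpha = 0.1
w = np.zeros(2)
for _ in range(500):
    w = w - alpha * gradient(w)

print(w_star, w)
```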

Rigid transformations in SGD

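SGD is invariant to rigid transformations (rotations and translations): if we rotate and/or translate the parameter space, the SGD iterates are transformed in exactly the same way, so the dynamics are unchanged. This is what lets us choose a convenient coordinate system for the analysis.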
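A small numerical check of this invariance, sketched with full-batch gradient descent on the quadratic objective (an assumption for simplicity; the same step-by-step argument applies to stochastic gradients). If we change coordinates to \(u = Rw + s\) with a rotation \(R\) and translation \(s\), then \(g(u) = f(R^T(u - s))\) has gradient \(R\,\nabla f(w)\), so each step in \(u\)-space is the rotated version of the step in \(w\)-space. All numeric values are illustrative.

```python
import numpy as np

A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([-1.0, 4.0])

def grad_f(w):
    # Gradient of the quadratic 1/2 w^T A w + b^T w.
    return A @ w + b

# A rigid transformation u = R w + s: rotation by theta, then translation by s.
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
s = np.array([2.0, -3.0])

def grad_g(u):
    # g(u) = f(R^T (u - s)), so by the chain rule grad g(u) = R grad f(R^T (u - s)).
    return R @ grad_f(R.T @ (u - s))

alpha = 0.1
w = np.array([1.0, 1.0])
u = R @ w + s  # the same starting point, in the transformed coordinates
max_gap = 0.0
for _ in range(50):
    w = w - alpha * grad_f(w)
    u = u - alpha * grad_g(u)
    # The iterates should stay related by the same rigid transformation.
    max_gap = max(max_gap, np.max(np.abs(u - (R @ w + s))))

print(max_gap)
```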

Spectral decomposition

Any real symmetric matrix \(A\) has a full set of eigenvectors, all its eigenvalues are real, and its eigenvectors can be chosen to be orthogonal.

That is, we can rewrite \(A\) as:

\[ A = QDQ^T \]

where \(Q\) is orthogonal with the eigenvectors of \(A\) as its columns, and \(D\) is a diagonal matrix containing the corresponding eigenvalues of \(A\).
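For example, NumPy's np.linalg.eigh computes this decomposition for a symmetric matrix. The random matrix below is only an illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 4))
A = (M + M.T) / 2  # symmetrise an arbitrary matrix

# eigh is NumPy's routine for symmetric (Hermitian) matrices: it returns real
# eigenvalues in ascending order and orthonormal eigenvectors as the columns of Q.
eigvals, Q = np.linalg.eigh(A)
D = np.diag(eigvals)

print(np.allclose(Q @ D @ Q.T, A))      # A = Q D Q^T
print(np.allclose(Q.T @ Q, np.eye(4)))  # Q is orthogonal
```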

Analysing closed-form dynamics
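A sketch of one such closed-form analysis, assuming full-batch gradient descent with a fixed step size \(\alpha\) on the quadratic objective. Since the gradient is \(Aw + b\) and \(Aw_* = -b\), the error obeys \(w_{t+1} - w_* = (I - \alpha A)(w_t - w_*)\). Rotating into the eigenbasis, \(\tilde{v}_t = Q^T(w_t - w_*)\), the coordinates decouple: \(\tilde{v}_{t,i} = (1 - \alpha\lambda_i)^t\,\tilde{v}_{0,i}\). So each eigendirection evolves independently, and all converge iff \(|1 - \alpha\lambda_i| < 1\) for every \(i\). The code checks the closed form against simulation, with arbitrary example values.

```python
import numpy as np

A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([-1.0, 4.0])
w_star = np.linalg.solve(A, -b)
eigvals, Q = np.linalg.eigh(A)

alpha, steps = 0.1, 30
w0 = np.array([1.0, 1.0])

# Simulate gradient descent on 1/2 w^T A w + b^T w.
w = w0.copy()
for _ in range(steps):
    w = w - alpha * (A @ w + b)

# Closed form: in the eigenbasis, each coordinate of the error decays independently.
v0 = Q.T @ (w0 - w_star)
v_t = (1 - alpha * eigvals) ** steps * v0
w_closed = w_star + Q @ v_t

print(np.allclose(w, w_closed))
```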

2: Taylor Approximations

Hessian, in terms of Taylor approximation

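The Hessian \(\mathbf{H} = \nabla^2 \mathcal{J}(\mathbf{w})\) is the matrix of second derivatives of the cost function wrt. the weights. It is the curvature term in the second-order Taylor approximation:

\[ \mathcal{J}(\mathbf{w} + \Delta\mathbf{w}) \approx \mathcal{J}(\mathbf{w}) + \nabla \mathcal{J}(\mathbf{w})^T \Delta\mathbf{w} + \frac{1}{2} \Delta\mathbf{w}^T \mathbf{H} \Delta\mathbf{w} \]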
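A quick JAX check of the second-order Taylor approximation, using an arbitrary smooth cost chosen for illustration: jax.hessian gives \(\mathbf{H}\), and adding the quadratic term should shrink the error of the first-order approximation for a small step \(\Delta\mathbf{w}\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for a cleaner comparison

def cost(w):
    # An arbitrary smooth, non-quadratic cost used only for illustration.
    return jnp.sum(jnp.sin(w)) + w[0] * w[1] ** 2

w = jnp.array([0.3, -1.2])
dw = 1e-2 * jnp.array([1.0, 2.0])

g = jax.grad(cost)(w)
H = jax.hessian(cost)(w)  # matrix of second derivatives, shape (2, 2)

exact = cost(w + dw)
first_order = cost(w) + g @ dw
second_order = first_order + 0.5 * dw @ H @ dw

err1 = abs(exact - first_order)
err2 = abs(exact - second_order)
print(err1, err2)
```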

Jacobian notation

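The Jacobian of \(\mathbf{y}\) wrt. \(\mathbf{w}\) is the matrix of first derivatives: \[ [\mathbf{J}_{\mathbf{y}\mathbf{w}}]_{ij} = \frac{\partial y_i}{\partial w_j} \]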
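For a concrete instance, jax.jacfwd builds this matrix (shape \(m \times n\)) for an illustrative function \(f: \mathbb{R}^3 \to \mathbb{R}^2\); the hand-derived entries \(\partial y_i / \partial w_j\) should match.

```python
import jax
import jax.numpy as jnp

def f(w):
    # Illustrative map from R^3 to R^2.
    return jnp.array([w[0] * w[1], jnp.sin(w[2]) + w[0]])

w = jnp.array([1.0, 2.0, 0.5])
J = jax.jacfwd(f)(w)  # shape (m, n) = (2, 3); J[i, j] = d y_i / d w_j

# Hand-derived Jacobian for comparison.
J_manual = jnp.array([[w[1], w[0], 0.0],
                      [1.0, 0.0, jnp.cos(w[2])]])
print(J.shape, jnp.allclose(J, J_manual))
```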

Vector-Jacobian product (VJP)

The goal of the VJP is to compute \(J^Tv\), where \(J\) is the Jacobian of a function \(f\) at a point \(x\), and \(v\) is a vector.

\[ \begin{align} x \in \mathbb{R}^{n} \\ y \in \mathbb{R}^{m} \\ v \in \mathbb{R}^{m} \\ \\ J \in \mathbb{R}^{m \times n} \\ J^T \in \mathbb{R}^{n \times m} \\ J^Tv \in \mathbb{R}^{n} \\ \end{align} \]
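In JAX, jax.vjp returns the output together with a function that maps \(v\) to \(J^Tv\). A sketch with an illustrative \(f: \mathbb{R}^3 \to \mathbb{R}^2\), checked against an explicitly formed Jacobian:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^3 (n = 3) to R^2 (m = 2).
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0]])

x = jnp.array([1.0, 2.0, 0.5])
v = jnp.array([3.0, -1.0])  # lives in the output space R^m

y, f_vjp = jax.vjp(f, x)
(jtv,) = f_vjp(v)  # J^T v, lives in the input space R^n

# Check against explicitly forming J (only feasible because f is tiny).
J = jax.jacrev(f)(x)
print(jtv.shape, jnp.allclose(jtv, J.T @ v))
```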

Gradient in terms of VJP

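The gradient is the derivative of the (scalar) loss function wrt. the weights. Since the loss has a single output, its Jacobian wrt. the weights is a \(1 \times n\) matrix, and the gradient is its transpose: the vector-Jacobian product wrt. the weights with \(v = [1]\).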
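A sketch checking this in JAX for an illustrative single-example squared-error loss. JAX expects the cotangent to have the same shape as the output, so for a scalar loss the \([1]\) is passed as a scalar one.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Illustrative scalar loss: squared error of a linear model on one example.
    x, t = jnp.array([1.0, -2.0, 0.5]), 0.3
    return 0.5 * (w @ x - t) ** 2

w = jnp.array([0.1, 0.2, -0.3])

value, loss_vjp = jax.vjp(loss, w)
# The loss is a scalar, so "v = [1]" is a cotangent of ones with the output's shape ().
(grad_via_vjp,) = loss_vjp(jnp.ones_like(value))

print(jnp.allclose(grad_via_vjp, jax.grad(loss)(w)))
```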

Gateaux derivative

This is the directional derivative of \(f\) in the direction of \(\Delta w\).

\[ \mathcal{R}_{\Delta \mathbf{w}} f(\mathbf{w}) = \lim_{h \to 0} \frac{f(\mathbf{w} + h \Delta \mathbf{w}) - f(\mathbf{w})}{h} \]
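A NumPy sketch of the limit, for an illustrative \(f: \mathbb{R}^2 \to \mathbb{R}\) whose directional derivative \(\nabla f(\mathbf{w})^T \Delta\mathbf{w}\) we can derive by hand: the difference quotient approaches it as \(h\) shrinks.

```python
import numpy as np

def f(w):
    # Illustrative scalar function R^2 -> R.
    return np.sin(w[0]) * w[1] ** 2

w = np.array([0.4, 1.5])
dw = np.array([1.0, -2.0])

# Hand-derived directional derivative: grad f(w) . dw
grad = np.array([np.cos(w[0]) * w[1] ** 2, 2 * np.sin(w[0]) * w[1]])
exact = grad @ dw

errors = []
for h in [1e-1, 1e-3, 1e-5]:
    approx = (f(w + h * dw) - f(w)) / h  # difference quotient from the definition
    errors.append(abs(approx - exact))

print(errors)
```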

Gateaux derivative in terms of JVP

\[ \mathcal{R}_{\Delta \mathbf{w}} f(\mathbf{w}) = \mathbf{J}_{\mathbf{y}\mathbf{w}} \Delta \mathbf{w} \]
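The same quantity via jax.jvp, which returns \(f(\mathbf{w})\) and \(\mathbf{J}\Delta\mathbf{w}\) in one forward pass. The function is illustrative; the explicit Jacobian is formed only to check the result.

```python
import jax
import jax.numpy as jnp

def f(w):
    # Illustrative map R^2 -> R^2.
    return jnp.array([jnp.sin(w[0]) * w[1] ** 2, w[0] * w[1]])

w = jnp.array([0.4, 1.5])
dw = jnp.array([1.0, -2.0])

# jax.jvp returns (f(w), J dw) without materialising J.
y, directional = jax.jvp(f, (w,), (dw,))

J = jax.jacfwd(f)(w)  # explicit Jacobian, for checking only
print(jnp.allclose(directional, J @ dw))
```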

Why JVP+VJP instead of computing J?

The difference between JVP and VJP is only a transpose of \(J\). We don't just compute \(J\) explicitly and multiply by it (transposed if needed), because \(J\) has \(m \times n\) entries, which is prohibitively expensive to form and store when the inputs and outputs of \(f\) are high-dimensional. The JVP and VJP functions compute the product directly, at a small constant multiple of the cost of evaluating \(f\).

Jacobian and Hessian landscape interpretation

Say we have a function \(f\), its Jacobian \(\mathbf{J}\), and its Hessian \(\mathbf{H}\).

Convex functions

“Bowl shaped” functions. A function is convex if the line segment connecting any two points on its graph lies on or above the graph. Formally:

\[ \forall x_1, x_2, \lambda \in [0, 1]. \lambda f(x_1) + (1-\lambda) f(x_2) \geq f(\lambda x_1 + (1-\lambda)x_2) \]
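Random sampling can't prove convexity, but it can find counterexamples to the inequality. A minimal NumPy sketch, using two illustrative functions: a squared norm (convex) and a sum of sines (not convex).

```python
import numpy as np

def violates_convexity(f, dim, trials=10_000, seed=0):
    """Search random (x1, x2, lam) for a violation of the convexity inequality.

    Random sampling can only ever find counterexamples; it cannot prove convexity.
    """
    rng = np.random.default_rng(seed)
    for _ in range(trials):
        x1, x2 = rng.normal(size=dim), rng.normal(size=dim)
        lam = rng.uniform()
        lhs = lam * f(x1) + (1 - lam) * f(x2)
        rhs = f(lam * x1 + (1 - lam) * x2)
        if lhs < rhs - 1e-12:  # small tolerance for floating-point error
            return True
    return False

def squared_norm(x):
    return np.sum(x ** 2)  # convex

def wavy(x):
    return np.sum(np.sin(3 * x))  # not convex

print(violates_convexity(squared_norm, 2), violates_convexity(wavy, 2))
```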

Convex functions in terms of Hessian

A twice-differentiable function is convex iff its Hessian is positive semi-definite everywhere, i.e. \(\mathbf{v}^T \mathbf{H} \mathbf{v} \geq 0\) for all \(\mathbf{v}\); equivalently, all eigenvalues of \(\mathbf{H}\) are \(\geq 0\).
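A sketch of the Hessian criterion in JAX, evaluated at random points only (so it is evidence, not a proof). jax.nn.logsumexp is convex, with one zero eigenvalue along the all-ones direction, so its smallest eigenvalue should be about 0; a sum of sines has negative curvature wherever some \(\sin(w_i) > 0\).

```python
import jax
import jax.numpy as jnp

def min_hessian_eigenvalue(f, w):
    # Hessians of scalar functions are symmetric, so eigvalsh applies.
    return jnp.linalg.eigvalsh(jax.hessian(f)(w)).min()

def log_sum_exp(w):
    return jax.nn.logsumexp(w)  # convex

def wavy(w):
    return jnp.sum(jnp.sin(w))  # not convex

key = jax.random.PRNGKey(0)
points = jax.random.normal(key, (20, 3))

# Smallest Hessian eigenvalue over the sampled points, for each function.
lse_min = min(float(min_hessian_eigenvalue(log_sum_exp, p)) for p in points)
wavy_min = min(float(min_hessian_eigenvalue(wavy, p)) for p in points)
print(lse_min, wavy_min)
```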

Stationary point definition

Any \(\mathbf{w}_*\) such that \(\nabla \mathcal{J}(\mathbf{w}_*) = \mathbf{0}\).

Categorising stationary points with Hessians

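Assuming \(\mathbf{w}_*\) is a stationary point, look at the eigenvalues of the Hessian \(\mathbf{H} = \nabla^2 \mathcal{J}(\mathbf{w}_*)\): if they are all positive, \(\mathbf{w}_*\) is a local minimum; if they are all negative, a local maximum; if there are both positive and negative eigenvalues, a saddle point. Otherwise (some eigenvalues are zero and the rest share a sign), the second-order test is inconclusive.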
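The second-order test as a small JAX function, assuming the caller has already checked that the gradient vanishes. The tolerance for treating an eigenvalue as zero (tol) is an arbitrary choice, and the four illustrative test functions all have a stationary point at the origin.

```python
import jax
import jax.numpy as jnp

def classify_stationary_point(f, w, tol=1e-6):
    """Second-order test at a point w assumed to satisfy grad f(w) = 0."""
    eigvals = jnp.linalg.eigvalsh(jax.hessian(f)(w))
    has_pos = jnp.any(eigvals > tol)
    has_neg = jnp.any(eigvals < -tol)
    if has_pos and has_neg:
        return "saddle point"
    if jnp.all(eigvals > tol):
        return "local minimum"
    if jnp.all(eigvals < -tol):
        return "local maximum"
    return "inconclusive"  # (near-)zero eigenvalues, no sign conflict

origin = jnp.zeros(2)

def bowl(w):
    return w[0] ** 2 + w[1] ** 2

def cap(w):
    return -(w[0] ** 2) - w[1] ** 2

def saddle(w):
    return w[0] ** 2 - w[1] ** 2

def flat(w):
    return w[0] ** 2 + w[1] ** 4  # Hessian at 0 is diag(2, 0)

for name, f in [("bowl", bowl), ("cap", cap), ("saddle", saddle), ("flat", flat)]:
    print(name, classify_stationary_point(f, origin))
```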

Stable stationary points

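In dynamical systems terminology (treating gradient descent as the dynamical system), local minima are stable, while saddle points and local maxima are unstable: a generic small perturbation away from them grows rather than decays.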
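A NumPy sketch of (in)stability under gradient descent, assuming step size 0.1 and two illustrative functions: starting a tiny distance from the origin, the iterates of \(x^2 + y^2\) (minimum) shrink back, while those of \(x^2 - y^2\) (saddle) are pushed away along the \(y\) direction, where the curvature is negative.

```python
import numpy as np

def gd(grad, w0, alpha=0.1, steps=100):
    w = np.array(w0, dtype=float)
    for _ in range(steps):
        w = w - alpha * grad(w)
    return w

# Both functions have a stationary point at the origin.
def grad_bowl(w):
    return np.array([2 * w[0], 2 * w[1]])   # f = x^2 + y^2: minimum

def grad_saddle(w):
    return np.array([2 * w[0], -2 * w[1]])  # f = x^2 - y^2: saddle

start = [1e-3, 1e-3]  # a tiny perturbation away from the stationary point
bowl_dist = np.linalg.norm(gd(grad_bowl, start))      # shrinks towards 0
saddle_dist = np.linalg.norm(gd(grad_saddle, start))  # grows along the y direction
print(bowl_dist, saddle_dist)
```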

Eigenspectrum of the Hessian in practice

The eigenspectrum (the set of eigenvalues) typically:

Hessian-vector products

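The product of the Hessian \(\mathbf{H} \in \mathbb{R}^{m \times n \times n}\) of a function \(f: \mathbb{R}^n \to \mathbb{R}^m\) with a vector \(\mathbf{v} \in \mathbb{R}^n\), giving \(\mathbf{Hv} \in \mathbb{R}^{m \times n}\). For a scalar cost (\(m = 1\)) this is just a vector in \(\mathbb{R}^n\).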

Calculating HVP with JVP and VJP

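\[ \begin{align} f &: \mathbb{R}^n \to \mathbb{R}^m \\ \mathbf{y} &= f(\mathbf{x}) \\ \mathbf{J_{yx}} &\in \mathbb{R}^{m \times n} \\ \mathbf{H_{yx}} &\in \mathbb{R}^{m \times n \times n} \\ \mathbf{H_{yx}v} &\in \mathbb{R}^{m \times n} \\ \\ \mathbf{H_{yx}} &= \mathbf{J_{J_{yx}x}} \\ \mathbf{H_{yx}v} &= \mathbf{J_{J_{yx}x}v} \\ \end{align} \]

Hence, for a scalar \(f\) (\(m = 1\)), one way to calculate the HVP is forward-over-reverse: first compute the gradient (the transposed Jacobian) via a VJP over \(f\) with \(\mathbf{v} = [1]\), then take the JVP of that gradient function with your given \(\mathbf{v}\). Equivalently (reverse-over-forward), take the JVP of \(f\) with your given \(\mathbf{v}\), which gives the scalar \(\nabla f(\mathbf{x})^T \mathbf{v}\), then compute its gradient via a VJP with \([1]\).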
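A JAX sketch of two ways to compose these for an illustrative scalar function: forward-over-reverse (a JVP of the gradient) and reverse-over-forward (the gradient of a JVP). jax.hessian is used only to check the results on this tiny example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function R^3 -> R.
    return jnp.sum(jnp.sin(x)) * x[0] + x[1] ** 2 * x[2]

def hvp_forward_over_reverse(f, x, v):
    # Gradient via reverse mode (a VJP with [1]), then a JVP of the gradient in direction v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hvp_reverse_over_forward(f, x, v):
    # Directional derivative grad f(z) . v via a JVP, then its gradient via reverse mode.
    def directional(z):
        return jax.jvp(f, (z,), (v,))[1]
    return jax.grad(directional)(x)

x = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, -1.0])

H = jax.hessian(f)(x)  # full Hessian, only for checking on this tiny example
print(jnp.allclose(hvp_forward_over_reverse(f, x, v), H @ v, atol=1e-5),
      jnp.allclose(hvp_reverse_over_forward(f, x, v), H @ v, atol=1e-5))
```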

ReLU and Hessian

We often can't calculate a meaningful Hessian for real neural networks, as the ReLU isn't twice differentiable: its second derivative is undefined at 0 and zero everywhere else.
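A quick check in JAX: differentiating jax.nn.relu twice gives zero at every point tried. At the kink \(x = 0\), JAX returns a conventional value instead of failing, even though the derivative is mathematically undefined there.

```python
import jax

relu_second_derivative = jax.grad(jax.grad(jax.nn.relu))

values = {x: float(relu_second_derivative(x)) for x in [-2.0, 0.0, 3.0]}
print(values)
```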

Gauss-Newton Hessian

The Hessian of the cost function we get if we linearise the neural network around the current weights, but do not linearise the loss function.

If we decompose a cost function \(\mathcal{J}\) into: \[ \begin{align} \mathbf{y} &= \mathcal{L}(\mathbf{z}, \mathbf{t}) \\ \mathbf{z} &= f(\mathbf{x}, \mathbf{w}) \\ \end{align} \]

Then we can decompose the Hessian into:

\[ \nabla^2_{\mathbf{w}} \mathcal{J}_{\mathbf{t},\mathbf{x}} = \mathbf{J_{zw}}^T \mathbf{H_{yz}} \mathbf{J_{zw}} + \sum_a \frac{\partial \mathcal{L}}{\partial z_a} \nabla^2_{\mathbf{w}}[f(\mathbf{x}, \mathbf{w})]_a \]

If we only take the first term, we get the Gauss-Newton Hessian:

\[ \mathbf{G} = \mathbf{J_{zw}}^T \mathbf{H_{yz}} \mathbf{J_{zw}} \]

Why use Gauss-Newton Hessian?
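A JAX sketch verifying the decomposition numerically, with an illustrative two-layer tanh network (weights packed into one flat vector) and a squared-error loss, for which \(\mathbf{H_{yz}} = \mathbf{I}\). Because \(\mathbf{H_{yz}}\) is PSD for a convex loss, \(\mathbf{G}\) is PSD too, which the last check confirms; the full Hessian has no such guarantee.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x = jnp.array([0.5, -1.0])       # single input (illustrative)
t = jnp.array([1.0, 0.0, -1.0])  # target

def network(w):
    # Tiny two-layer tanh network with its weights packed into a flat vector w.
    W1 = w[:8].reshape(4, 2)
    W2 = w[8:].reshape(3, 4)
    return W2 @ jnp.tanh(W1 @ x)  # z = f(x, w)

def loss(z):
    # Squared-error loss L(z, t); its Hessian wrt z is the identity.
    return 0.5 * jnp.sum((z - t) ** 2)

def cost(w):
    return loss(network(w))

w = jax.random.normal(jax.random.PRNGKey(0), (20,))
z = network(w)

J_zw = jax.jacfwd(network)(w)  # shape (3, 20)
H_yz = jax.hessian(loss)(z)    # shape (3, 3)
G = J_zw.T @ H_yz @ J_zw       # Gauss-Newton Hessian, shape (20, 20)

# Second term: sum_a dL/dz_a * Hessian of network output a wrt w.
dL_dz = jax.grad(loss)(z)
H_out = jax.hessian(network)(w)  # shape (3, 20, 20)
second_term = jnp.einsum("a,aij->ij", dL_dz, H_out)

H_full = jax.hessian(cost)(w)
print(jnp.allclose(H_full, G + second_term))
print(jnp.linalg.eigvalsh(G).min() >= -1e-10)
```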

Response function

Say we’ve found some optimal solution \(\mathbf{w}_* = \arg\min_w \mathcal{J}(\mathbf{w}; \theta)\) given hyperparameters \(\mathbf{\theta}\). Say we want to see how the hyperparameters influence the optimal solution.

We can reparameterise the problem as a response function which we then go on to analyse: \[ \mathbf{w_*} = r(\mathbf{\theta}) \]

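The Implicit Function Theorem states that we can (locally) reparameterise this way, given some conditions: \(\mathcal{J}\) is twice continuously differentiable, \(\mathbf{w}_*\) is a stationary point (\(\nabla_{\mathbf{w}} \mathcal{J}(\mathbf{w}_*; \mathbf{\theta}) = \mathbf{0}\)), and the Hessian \(\nabla^2_{\mathbf{w}} \mathcal{J}\) is invertible there.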

Response Jacobian

\[ \mathbf{J_{w_*\theta}} = - [\nabla^2_{\mathbf{w}} \mathcal{J}(\mathbf{w}; \mathbf{\theta})]^{-1} [\nabla^2_{\mathbf{w \theta}} \mathcal{J}(\mathbf{w}; \mathbf{\theta})] \]
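A JAX sketch for ridge regression, where the response function is known in closed form, so the formula above can be checked against differentiating straight through the solve. The data and the choice of the L2 penalty strength as the single hyperparameter \(\theta\) are illustrative assumptions; double precision is enabled to keep the comparison tight.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Illustrative ridge-regression data; theta is the L2 penalty strength.
X = jax.random.normal(jax.random.PRNGKey(0), (10, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (10,))

def cost(w, theta):
    return 0.5 * jnp.sum((X @ w - y) ** 2) + 0.5 * theta * jnp.sum(w ** 2)

def response(theta):
    # r(theta): the exact minimiser, solving (X^T X + theta I) w = X^T y.
    return jnp.linalg.solve(X.T @ X + theta * jnp.eye(3), X.T @ y)

theta = 0.5
w_star = response(theta)

H_ww = jax.hessian(cost, argnums=0)(w_star, theta)                          # (3, 3)
H_wtheta = jax.jacfwd(jax.grad(cost, argnums=0), argnums=1)(w_star, theta)  # (3,)

# Implicit function theorem: J_{w* theta} = -H_ww^{-1} H_wtheta.
J_response = -jnp.linalg.solve(H_ww, H_wtheta)

# Compare with differentiating straight through the closed-form solve.
print(jnp.allclose(J_response, jax.jacfwd(response)(theta)))
```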

2.1: Forward and backward mode autodiff

Some revision on forward mode and reverse mode autodiff, as I was getting confused in lecture 2.

Forward mode

Forward mode with evaluation traces

As you go through a function, record all intermediate values – including the inputs and outputs. Then, while you’re doing this, you can also compute the derivatives. For example, if your trace consists of \(x_1, x_2, i_1, i_2, y_1\), and you wanted to compute \(\frac{dy_1}{dx_1}\), then while computing the variables you’d also compute \(\frac{dx_1}{dx_1}, \frac{dx_2}{dx_1}, ...\).
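The trace \(x_1, x_2, i_1, i_2, y_1\) made concrete for an illustrative function \(y_1 = \sin(x_1)\,x_2 + x_1\), carrying the derivative wrt. \(x_1\) alongside each value:

```python
import math

def f_with_forward_trace(x1, x2):
    """Evaluate y1 = sin(x1) * x2 + x1, carrying d(.)/dx1 alongside each value."""
    # Seed: dx1/dx1 = 1, dx2/dx1 = 0.
    dx1, dx2 = 1.0, 0.0
    i1, di1 = math.sin(x1), math.cos(x1) * dx1  # i1 = sin(x1)
    i2, di2 = i1 * x2, di1 * x2 + i1 * dx2      # i2 = i1 * x2 (product rule)
    y1, dy1 = i2 + x1, di2 + dx1                # y1 = i2 + x1
    return y1, dy1

y1, dy1_dx1 = f_with_forward_trace(0.5, 2.0)
print(dy1_dx1, 2.0 * math.cos(0.5) + 1.0)  # should match the hand-derived derivative
```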

Forward mode with dual numbers

We replace our values with dual numbers of the form \(a + b\epsilon\), where addition and multiplication work as expected, and \(\epsilon^2 = 0\). This means we can calculate the primals and tangents at once.
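A minimal dual-number sketch in Python, supporting only addition, multiplication and a sin function (enough for the illustrative \(y_1 = \sin(x_1)\,x_2 + x_1\)). Seeding \(x_1\) with tangent 1 gives \(\frac{dy_1}{dx_1}\) in the tangent part of the output.

```python
import math
from dataclasses import dataclass

@dataclass
class Dual:
    primal: float   # a
    tangent: float  # b, the coefficient of epsilon

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps, since eps^2 = 0
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__

def sin(x: Dual) -> Dual:
    # sin(a + b eps) = sin(a) + cos(a) b eps
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)

def f(x1, x2):
    return sin(x1) * x2 + x1

# Seed x1 with tangent 1 to get dy/dx1; x2 gets tangent 0.
out = f(Dual(0.5, 1.0), Dual(2.0, 0.0))
print(out.primal, out.tangent)
```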

Forward mode complexity

Forward mode is linear in the number of inputs, as we’re calculating the differential for each of them separately. One forward pass (although they can be done “in parallel”) needs to be done for each input parameter.

Forward mode for JVPs

When we want to do \(\mathbf{Jv}\), the \(\mathbf{v}\) is being applied to the input space of the function. When we do forward mode autodiff, we’re doing this parameter-by-parameter, which is equivalent to multiplying the input space by e.g. \([1, 0, 0]\) to get the first of three parameters. Using \(\mathbf{v}\) generalises this to be any direction in input space.
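A JAX sketch of this correspondence for an illustrative \(f: \mathbb{R}^3 \to \mathbb{R}^2\): JVPs with the standard basis vectors recover the Jacobian's columns one input at a time, and a general \(\mathbf{v}\) gives \(\mathbf{Jv}\) in a single pass.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map R^3 -> R^2.
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0]])

x = jnp.array([1.0, 2.0, 0.5])

# Each JVP with a standard basis vector e_j gives column j of the Jacobian,
# i.e. forward mode "one input parameter at a time".
columns = [jax.jvp(f, (x,), (e,))[1] for e in jnp.eye(3)]
J_by_columns = jnp.stack(columns, axis=1)

# A general direction v mixes the columns: J v in one forward pass.
v = jnp.array([0.5, -1.0, 2.0])
_, Jv = jax.jvp(f, (x,), (v,))

print(jnp.allclose(J_by_columns, jax.jacfwd(f)(x)), jnp.allclose(Jv, J_by_columns @ v))
```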

Reverse mode

We first compute the evaluation trace, i.e. the values of all intermediate variables, including inputs and outputs. For example, \(x_1, x_2, i_1, i_2, y_1\). Then, we calculate \(\frac{dy_1}{di_2}\), which is possible now that we know the values of \(y_1\) and \(i_2\). We then apply the chain rule progressively to eventually get what we care about: \(\frac{dy_1}{dx_1}\) and \(\frac{dy_1}{dx_2}\).
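The same illustrative function \(y_1 = \sin(x_1)\,x_2 + x_1\) in reverse mode, written by hand: one forward pass records the trace, and one backward pass of adjoints gives both \(\frac{dy_1}{dx_1}\) and \(\frac{dy_1}{dx_2}\).

```python
import math

def f_with_reverse_pass(x1, x2):
    """y1 = sin(x1) * x2 + x1: forward pass stores the trace, backward pass applies the chain rule."""
    # Forward pass: record every intermediate value.
    i1 = math.sin(x1)
    i2 = i1 * x2
    y1 = i2 + x1

    # Backward pass: adjoints dy1/d(.) from the output back to the inputs.
    y1_bar = 1.0
    i2_bar = y1_bar                  # y1 = i2 + x1
    x1_bar = y1_bar                  # direct path x1 -> y1
    i1_bar = i2_bar * x2             # i2 = i1 * x2
    x2_bar = i2_bar * i1
    x1_bar += i1_bar * math.cos(x1)  # i1 = sin(x1)
    return y1, (x1_bar, x2_bar)

y1, (dy1_dx1, dy1_dx2) = f_with_reverse_pass(0.5, 2.0)
print(dy1_dx1, dy1_dx2)
```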

Reverse mode complexity

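We need to do one forward pass to compute the evaluation trace, then one backward pass for every output variable. Each backward pass gives the derivatives of that output wrt. all inputs at once, which is why reverse mode is cheap for a scalar loss with many weights.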

Tying back to jax

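jax.jvp computes \(\mathbf{Jv}\) directly, without ever forming \(\mathbf{J}\): the tangent \(\mathbf{v}\) is pushed through the function alongside the primal values, a transformation on its input space. Therefore, it uses forward mode autodiff. It must take \(\mathbf{v}\) when initially called, because the tangent is propagated in the same pass as the primals; fixing a direction in input space is what lets us effectively have “one input variable”.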

jax.vjp computes \(\mathbf{J}^T\mathbf{v}\), again without forming \(\mathbf{J}\), by pulling \(\mathbf{v}\) back from the output space. Therefore, it uses reverse mode autodiff. When initially called, it doesn't take \(\mathbf{v}\): the initial call runs the forward pass, evaluating \(f\) and storing the intermediate values. It then returns the output together with a function that we can call with \(\mathbf{v}\), which actually performs the backward pass.
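A sketch of the two-stage call pattern, with an illustrative \(f: \mathbb{R}^3 \to \mathbb{R}^2\): the function returned by jax.vjp can be called with several \(\mathbf{v}\) while reusing one forward pass. Using the standard basis vectors recovers the Jacobian row by row.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map R^3 -> R^2.
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0]])

x = jnp.array([1.0, 2.0, 0.5])

# First call: runs the forward pass and returns f(x) plus a function for the backward pass.
y, f_vjp = jax.vjp(f, x)

# The returned function can be called repeatedly with different v, reusing the forward pass.
row0 = f_vjp(jnp.array([1.0, 0.0]))[0]  # e_0^T J: first row of the Jacobian
row1 = f_vjp(jnp.array([0.0, 1.0]))[0]  # e_1^T J: second row

print(jnp.allclose(jnp.stack([row0, row1]), jax.jacrev(f)(x)))
```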

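N.B.: The jvp and vjp functions are cheap because each call computes a single Jacobian-vector or vector-Jacobian product, never the full matrix: jvp returns something shaped like the output of \(f\), and vjp something shaped like its input. If you want e.g. the full Jacobian or Hessian, this requires multiple calls: one jvp per input dimension, or one vjp per output dimension. jax.jacfwd and jax.jacrev do exactly this, batching the calls with vmap.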
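For example, a full Hessian can be assembled from one HVP per input dimension. A sketch with an illustrative cubic function, batching the HVPs with jax.vmap and comparing against jax.hessian:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function R^3 -> R.
    return jnp.sum(x ** 3) + x[0] * x[1] * x[2]

def hvp(x, v):
    # One Hessian-vector product (forward-over-reverse).
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, -2.0, 0.5])
basis = jnp.eye(3)

# Row j is H e_j; stacking the rows gives H because the Hessian is symmetric.
H = jax.vmap(lambda e: hvp(x, e))(basis)

print(jnp.allclose(H, jax.hessian(f)(x)))
```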