# Neural Network Training Dynamics

## 1: A Toy Model: Linear Regression

### Why toy models

NNs are too complex to prove things about. However, we can notice phenomena and try to replicate them in toy models such as linear regression.

We can then use the toy models to make predictions that we can test empirically. In this way, simple models that offer many predictions are more useful than complex models.

### Convex quadratic objective equation

\[ w_* \in \arg\min_w \frac{1}{2} w^TAw + b^Tw + c \]

where \(A\) is symmetric positive semi-definite, which is what makes the objective convex.

### Rigid transformations in SGD

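(Stochastic) gradient descent is invariant to rigid transformations (rotations and translations) of the weight space: if we re-express the problem in rotated or translated coordinates, the iterates are just the correspondingly rotated or translated iterates, so the dynamics and solutions don't change. This is what lets us rotate into a convenient basis when analysing the dynamics.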
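A minimal numerical check of this claim, assuming full-batch gradient descent (no minibatch noise) and an arbitrary 2D instance of the quadratic objective above: we run gradient descent in the original coordinates and in rotated-and-translated coordinates \(w = Qu + s\), then map the second trajectory back. The values of \(A\), \(b\), \(c\), the rotation angle, the shift \(s\) and the step size are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Illustrative convex quadratic: 1/2 w^T A w + b^T w + c.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])
c = 0.5

def objective(w):
    return 0.5 * w @ A @ w + b @ w + c

# Rigid transformation w = Q u + s: Q is a rotation, s a translation.
angle = 0.7
Q = jnp.array([[jnp.cos(angle), -jnp.sin(angle)],
               [jnp.sin(angle), jnp.cos(angle)]])
s = jnp.array([2.0, -3.0])

def objective_rigid(u):
    # The same objective, expressed in the transformed coordinates.
    return objective(Q @ u + s)

def gradient_descent(f, x0, alpha=0.1, steps=50):
    """Full-batch gradient descent; returns all iterates stacked row-wise."""
    grad_f = jax.grad(f)
    xs = [x0]
    for _ in range(steps):
        xs.append(xs[-1] - alpha * grad_f(xs[-1]))
    return jnp.stack(xs)

w0 = jnp.array([1.0, 1.0])
u0 = Q.T @ (w0 - s)  # the same starting point in transformed coordinates

ws = gradient_descent(objective, w0)
us = gradient_descent(objective_rigid, u0)

# Mapping the transformed trajectory back (row-wise Q u + s) recovers the original.
print(jnp.allclose(ws, us @ Q.T + s))
```

The check works because the gradient of the transformed objective is \(Q^T\) times the original gradient, and \(QQ^T = I\) for a rotation.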

### Spectral decomposition

Any real symmetric matrix \(A\) has a full set of eigenvectors, all its eigenvalues are real, and the eigenvectors can be taken to be orthonormal.

Equivalently, we can rewrite \(A\) as:

\[ A = QDQ^T \]

Where \(Q\) is orthogonal, and \(D\) is a diagonal matrix containing the eigenvalues of \(A\).
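The decomposition can be checked numerically with `jnp.linalg.eigh`, which is specialised to symmetric matrices. A sketch, with an arbitrary example matrix:

```python
import jax.numpy as jnp

# An arbitrary real symmetric matrix (illustrative values).
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])

# eigh returns real eigenvalues (ascending) and orthonormal eigenvectors
# as the columns of Q.
eigvals, Q = jnp.linalg.eigh(A)
D = jnp.diag(eigvals)

print(jnp.allclose(Q @ D @ Q.T, A, atol=1e-5))       # A = Q D Q^T
print(jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1e-5))  # Q is orthogonal
```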

### Analysing closed-form dynamics

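- We use the spectral decomposition \(A = QDQ^T\) to rotate to the eigenbasis of \(A\), i.e. a basis whose axes are the (orthogonal) eigenvectors, in which \(A\) is diagonal.
- Thus, we can analyse each axis independently.
- In this basis, the per-axis update is \(w_j \leftarrow w_j - \alpha(a_jw_j + b_j)\), where \(a_j\) is the \(j\)-th eigenvalue of \(A\) and \(b_j\) is the \(j\)-th component of \(Q^Tb\).
- We handle the cases of \(a_j\) separately (convexity means \(a_j \geq 0\)):
  - If \(a_j > 0\), \(w_j\) converges to \(-b_j/a_j\) iff the step size is small enough: the distance to the optimum is multiplied by \((1 - \alpha a_j)\) at each step, so we need \(0 < \alpha < 2/a_j\).
  - If \(a_j = 0, b_j \neq 0\), the objective is linear along this axis and unbounded below, so \(w_j\) changes by \(-\alpha b_j\) at every step and never converges.
  - If \(a_j = 0, b_j = 0\), the direction is irrelevant (the objective is flat along it), and \(w_j\) is never updated.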
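A sketch of this analysis, assuming an illustrative positive-definite \(A\) (so every \(a_j > 0\)) and an arbitrary step size: it runs gradient descent in the original basis, then compares it against the per-axis closed form in the eigenbasis, where the distance to \(-b_j/a_j\) shrinks by \((1 - \alpha a_j)\) each step.

```python
import jax.numpy as jnp

# Illustrative positive-definite A and vector b; alpha is an arbitrary step size.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])
alpha, steps = 0.1, 30

# Rotate into the eigenbasis: A = Q diag(a) Q^T, with b~ = Q^T b.
a, Q = jnp.linalg.eigh(A)
b_t = Q.T @ b

# Gradient descent in the original basis: w <- w - alpha (A w + b).
w = jnp.zeros(2)
for _ in range(steps):
    w = w - alpha * (A @ w + b)

# Closed form per axis (every a_j > 0 here): the distance to the optimum
# -b~_j / a_j is multiplied by (1 - alpha a_j) at each step.
w_t0 = Q.T @ jnp.zeros(2)
w_t_star = -b_t / a
w_t = w_t_star + (1 - alpha * a) ** steps * (w_t0 - w_t_star)

print(jnp.allclose(Q @ w_t, w, atol=1e-5))  # both routes agree
print(jnp.all(jnp.abs(1 - alpha * a) < 1))  # stable: 0 < alpha < 2 / a_j on every axis
```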

## 2: Taylor Approximations

### Hessian, in terms of Taylor approximation

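The Hessian \(\mathbf{H} = \nabla^2 \mathcal{J}(\mathbf{w})\) is the matrix of second derivatives of the cost function wrt. the weights. It supplies the quadratic term of the second-order Taylor approximation:

\[ \mathcal{J}(\mathbf{w} + \Delta \mathbf{w}) \approx \mathcal{J}(\mathbf{w}) + \nabla \mathcal{J}(\mathbf{w})^T \Delta \mathbf{w} + \frac{1}{2} \Delta \mathbf{w}^T \mathbf{H} \Delta \mathbf{w} \]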
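To see the Hessian's role in the Taylor approximation numerically, a sketch on an arbitrary smooth cost (the function, the point and the step are illustrative assumptions): adding the Hessian term should shrink the approximation error from roughly \(|\Delta \mathbf{w}|^2\) to roughly \(|\Delta \mathbf{w}|^3\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def cost(w):
    # An arbitrary smooth, non-quadratic cost (illustrative only).
    return jnp.sum(jax.nn.softplus(w)) + 0.5 * w[0] * w[1] ** 2

w = jnp.array([0.3, -0.5])
dw = jnp.array([1e-3, 2e-3])

g = jax.grad(cost)(w)
H = jax.hessian(cost)(w)

first_order = cost(w) + g @ dw
second_order = first_order + 0.5 * dw @ H @ dw

err1 = abs(cost(w + dw) - first_order)   # shrinks like |dw|^2
err2 = abs(cost(w + dw) - second_order)  # shrinks like |dw|^3
print(err1, err2)
```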

### Jacobian notation

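The Jacobian of \(\mathbf{y}\) wrt. \(\mathbf{w}\) is the matrix of first partial derivatives:

\[ [\mathbf{J}_{\mathbf{y}\mathbf{w}}]_{ij} = \frac{\partial y_i}{\partial w_j} \]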
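In JAX, `jax.jacobian` returns this matrix with shape \((m, n)\), so entry \([i, j]\) is \(\partial y_i / \partial w_j\). The function below is an arbitrary example:

```python
import jax
import jax.numpy as jnp

def f(w):
    # Illustrative map from R^3 to R^2.
    return jnp.array([w[0] * w[1], jnp.sin(w[2])])

w = jnp.array([1.0, 2.0, 0.5])
J = jax.jacobian(f)(w)  # J[i, j] = d y_i / d w_j

print(J.shape)  # (2, 3)
print(J)
```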

### Vector-Jacobian product (VJP)

The goal of the VJP is to compute \(J^Tv\), where \(J\) is the Jacobian of a function \(f\) at a point \(x\), and \(v\) is a vector.

\[ \begin{align} x \in \mathbb{R}^{n} \\ y \in \mathbb{R}^{m} \\ v \in \mathbb{R}^{m} \\ \\ J \in \mathbb{R}^{m \times n} \\ J^T \in \mathbb{R}^{n \times m} \\ J^Tv \in \mathbb{R}^{n} \\ \end{align} \]
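The shapes above can be checked with `jax.vjp`, which returns \(f(x)\) together with a function that maps \(v\) to \(J^Tv\). A sketch with an arbitrary \(f: \mathbb{R}^3 \to \mathbb{R}^2\); the explicit Jacobian is only built to verify the result:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map f: R^3 -> R^2 (n = 3, m = 2).
    return jnp.array([x[0] ** 2 + x[1], x[1] * x[2]])

x = jnp.array([1.0, 2.0, 3.0])  # x in R^n
v = jnp.array([0.5, -1.0])      # v in R^m

y, f_vjp = jax.vjp(f, x)        # y = f(x) in R^m
(jtv,) = f_vjp(v)               # J^T v in R^n, returned as a tuple (one entry per input)

J = jax.jacobian(f)(x)          # J in R^{m x n}, materialised here only as a check
print(jtv.shape)                # (3,)
print(jnp.allclose(jtv, J.T @ v))
```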

#### Gradient in terms of VJP

The gradient is the derivative of the (scalar) loss function wrt. the weights. Since the loss has a single output, its Jacobian is a \(1 \times n\) row vector, so the gradient is the vector-Jacobian product wrt. the weights with \(v = [1]\).
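A quick sketch of this equivalence, with an arbitrary example loss: the VJP with a scalar cotangent of 1 should match `jax.grad`.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Illustrative scalar loss, so m = 1 and J is a 1 x n row vector.
    return jnp.sum(w ** 2) + jnp.prod(w)

w = jnp.array([1.0, 2.0, 3.0])

_, loss_vjp = jax.vjp(loss, w)
(grad_via_vjp,) = loss_vjp(jnp.ones(()))  # v = [1], shaped like the scalar output

print(jnp.allclose(grad_via_vjp, jax.grad(loss)(w)))
```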

### Gateaux derivative

This is the *directional derivative* of \(f\) in the direction of \(\Delta \mathbf{w}\):

\[ \mathcal{R}_{\Delta \mathbf{w}} f(\mathbf{w}) = \lim_{h \to 0} \frac{f(\mathbf{w} + h \Delta \mathbf{w}) - f(\mathbf{w})}{h} \]

#### Gateaux derivative in terms of JVP

\[ \mathcal{R}_{\Delta \mathbf{w}} f(\mathbf{w}) = \mathbf{J}_{\mathbf{y}\mathbf{w}} \Delta \mathbf{w} \]
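A sketch comparing `jax.jvp` with the limit definition, approximated by a small finite \(h\) (the function, point, direction and \(h\) are illustrative):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(w):
    # Illustrative map from R^2 to R^2.
    return jnp.array([jnp.sin(w[0]) * w[1], w[0] ** 2 + jnp.exp(w[1])])

w = jnp.array([0.4, -0.2])
dw = jnp.array([1.0, 2.0])

# Forward mode: J_yw @ dw, without forming J_yw.
_, directional = jax.jvp(f, (w,), (dw,))

# The limit definition, approximated with a small finite h.
h = 1e-6
finite_diff = (f(w + h * dw) - f(w)) / h

print(jnp.allclose(directional, finite_diff, atol=1e-4))
```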

### Why JVP+VJP instead of computing J?

The difference between JVP and VJP is only a transpose of \(J\). The reason we don't just compute \(J\) and transpose it is that this can be prohibitively expensive when the inputs and outputs of \(f\) are high dimensional: for \(f: \mathbb{R}^n \to \mathbb{R}^m\), \(J\) has \(mn\) entries. The VJP and JVP compute the product directly, at a cost of a small constant multiple of one evaluation of \(f\), without ever forming \(J\).

### Jacobian and Hessian landscape interpretation

Say we have a scalar-valued function \(f\), its Jacobian \(\mathbf{J}\) (a row vector, i.e. the transposed gradient), and its Hessian \(\mathbf{H}\).

- For a direction \(\mathbf{v}\), if \(\mathbf{Jv} > 0\), then the direction points upwards.
- For a direction \(\mathbf{v}\), if \(\mathbf{v}^T\mathbf{Hv} > 0\), then the direction curves upwards.

### Convex functions

“Bowl shaped” functions. A function is convex if the line segment connecting any two points on its graph lies on or above the graph. Formally:

\[ \forall x_1, x_2, \lambda \in [0, 1]. \lambda f(x_1) + (1-\lambda) f(x_2) \geq f(\lambda x_1 + (1-\lambda)x_2) \]

#### Convex functions in terms of Hessian

A twice-differentiable function is convex iff its Hessian is positive semi-definite everywhere, i.e. \(\mathbf{v}^T \mathbf{H} \mathbf{v} \geq 0\) for all \(\mathbf{v}\), or equivalently all eigenvalues of \(\mathbf{H}\) are \(\geq 0\).

### Stationary point definition

Any \(\mathbf{w}_*\) such that \(\nabla \mathcal{J}(\mathbf{w}_*) = \mathbf{0}\).

### Categorising stationary points with Hessians

Assuming \(\mathbf{w}_*\) is a stationary point:

- If \(\mathbf{H}\) is positive definite, then \(\mathbf{w}_*\) is a local minimum.
- If \(\mathbf{H}\) is negative definite, then \(\mathbf{w}_*\) is a local maximum.
- If \(\mathbf{H}\) has both positive and negative eigenvalues, then \(\mathbf{w}_*\) is a saddle point (even if some eigenvalues are zero).
- Otherwise, if \(\mathbf{H}\) has zero-valued eigenvalues, then we can’t say whether it’s a minimum, maximum or saddle point. The answer depends on higher-order derivatives.
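The rules above can be turned into a small second-order test. A sketch, assuming the point passed in is already stationary and using a hypothetical tolerance `tol` to decide when an eigenvalue counts as zero:

```python
import jax
import jax.numpy as jnp

def classify(f, w_star, tol=1e-6):
    """Second-order test at a stationary point w_star of a scalar function f."""
    eigs = jnp.linalg.eigvalsh(jax.hessian(f)(w_star))
    has_pos = bool(jnp.any(eigs > tol))
    has_neg = bool(jnp.any(eigs < -tol))
    has_zero = bool(jnp.any(jnp.abs(eigs) <= tol))
    if has_pos and has_neg:
        return "saddle point"
    if has_zero:
        return "inconclusive (depends on higher-order derivatives)"
    return "local minimum" if has_pos else "local maximum"

origin = jnp.zeros(2)
print(classify(lambda w: w[0] ** 2 + w[1] ** 2, origin))   # local minimum
print(classify(lambda w: -w[0] ** 2 - w[1] ** 2, origin))  # local maximum
print(classify(lambda w: w[0] ** 2 - w[1] ** 2, origin))   # saddle point
print(classify(lambda w: w[0] ** 2 + w[1] ** 4, origin))   # inconclusive (...)
```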

### Stable stationary points

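In dynamical systems terminology, viewing gradient descent as a dynamical system, local minima are stable fixed points, while saddle points and local maxima are unstable: a small perturbation along a negative-curvature direction grows.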

### Eigenspectrum of the Hessian in practice

The eigenspectrum (the set of eigenvalues) typically:

- has many zero eigenvalues, as networks are over-parameterised;
- has many large positive and negative eigenvalues at initialisation;
- loses its large negative eigenvalues quickly during training (although I’m unsure why this isn’t true for large positive eigenvalues).

During training, we expect to see large positive eigenvalues, small positive and negative eigenvalues, and many zero eigenvalues.

### Hessian-vector products

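For \(f: \mathbb{R}^n \to \mathbb{R}^m\), the Hessian-vector product (HVP) is the product of the Hessian \(\mathbf{H} \in \mathbb{R}^{m \times n \times n}\) with a vector \(\mathbf{v} \in \mathbb{R}^n\), giving \(\mathbf{Hv} \in \mathbb{R}^{m \times n}\). For a scalar loss (\(m = 1\)), this is the usual \(n \times n\) Hessian times a vector.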

### Calculating HVP with JVP and VJP

\[ \begin{align} f &: \mathbb{R}^n \to \mathbb{R}^m \\ \mathbf{y} &= f(\mathbf{x}) \\ \mathbf{J_{yx}} &\in \mathbb{R}^{m \times n} \\ \mathbf{H_{yx}} &\in \mathbb{R}^{m \times n \times n} \\ \mathbf{H_{yx}v} &\in \mathbb{R}^{m \times n} \\ \\ \mathbf{H_{yx}} &= \mathbf{J_{J_{yx}x}} \\ \mathbf{H_{yx}v} &= \mathbf{J_{J_{yx}x}v} \\ \end{align} \]

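Hence, for a scalar \(f\) (\(m = 1\)), to calculate the HVP you first compute the gradient \(\mathbf{J_{yx}}^T\) as a VJP (reverse mode) with \(\mathbf{v} = [1]\), then take a JVP (forward mode) of that gradient function with your given \(\mathbf{v}\). This is known as forward-over-reverse.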
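A sketch of forward-over-reverse in JAX, assuming a scalar \(f\): `jax.grad` supplies the inner reverse-mode step and `jax.jvp` the outer forward-mode step. The explicit Hessian is only formed to check the answer on this small, arbitrary example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function, so H_yx is an n x n matrix.
    return jnp.sum(jnp.sin(x) * x[::-1])

def hvp(f, x, v):
    # Inner: the gradient (a VJP with v = [1]) via reverse mode.
    # Outer: a JVP of that gradient function in direction v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, 0.0, -1.0])

print(jnp.allclose(hvp(f, x, v), jax.hessian(f)(x) @ v, atol=1e-5))
```

This never materialises \(\mathbf{H}\), so its cost is a small constant multiple of one gradient evaluation.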

### ReLU and Hessian

We often can’t calculate the Hessian of real neural networks, as the ReLU isn’t twice differentiable: its second derivative is zero everywhere except at 0, where it is undefined.
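A quick check with `jax.nn.relu`, evaluated away from \(x = 0\), where the derivative is undefined mathematically:

```python
import jax

relu_grad = jax.grad(jax.nn.relu)
relu_grad2 = jax.grad(relu_grad)

print(relu_grad(2.0), relu_grad2(2.0))    # first derivative 1, second derivative 0
print(relu_grad(-2.0), relu_grad2(-2.0))  # first derivative 0, second derivative 0
```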

### Gauss-Newton Hessian

The Hessian of the cost function if we linearise the neural network around the current weights – but do not linearise the cost function.

If we decompose a cost function \(\mathcal{J}\) into: \[ \begin{align} \mathbf{y} &= \mathcal{L}(\mathbf{z}, \mathbf{t}) \\ \mathbf{z} &= f(\mathbf{x}, \mathbf{w}) \\ \end{align} \]

Then we can decompose the Hessian into:

\[ \nabla^2_{\mathbf{w}} \mathcal{J}_{\mathbf{t},\mathbf{x}} = \mathbf{J_{zw}}^T \mathbf{H_{yz}} \mathbf{J_{zw}} + \sum_a \frac{\partial \mathcal{L}}{\partial z_a} \nabla^2_{\mathbf{w}}[f(\mathbf{x}, \mathbf{w})]_a \]

If we only take the first term, we get the Gauss-Newton Hessian:

\[ \mathbf{G} = \mathbf{J_{zw}}^T \mathbf{H_{yz}} \mathbf{J_{zw}} \]

#### Why use Gauss-Newton Hessian?

- We only require the first derivative of the network, which means the network doesn’t have to be twice-differentiable, and we are free to use e.g. ReLUs.
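- As long as the loss \(\mathcal{L}\) is convex in \(\mathbf{z}\) (so \(\mathbf{H_{yz}}\) is positive semi-definite), \(\mathbf{G}\) is guaranteed to be positive semi-definite, since \(\mathbf{v}^T\mathbf{G}\mathbf{v} = (\mathbf{J_{zw}v})^T \mathbf{H_{yz}} (\mathbf{J_{zw}v}) \geq 0\).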
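A sketch of a Gauss-Newton-vector product that never forms \(\mathbf{G}\): a JVP through the network gives \(\mathbf{J_{zw}v}\), an HVP of the loss in \(\mathbf{z}\) applies \(\mathbf{H_{yz}}\), and a VJP through the network applies \(\mathbf{J_{zw}}^T\). The tiny tanh network `net`, its flat weight packing and the softmax cross-entropy loss are hypothetical choices for illustration; the explicit \(\mathbf{G}\) is only built to check the result.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def net(x, w):
    # Hypothetical tiny network z = f(x, w); w is a flat vector of 12 weights.
    W1 = w[:6].reshape(3, 2)
    W2 = w[6:].reshape(2, 3)
    return W2 @ jnp.tanh(W1 @ x)

def loss(z, t):
    # Softmax cross-entropy: convex in z, so H_yz is positive semi-definite.
    return -jnp.sum(t * jax.nn.log_softmax(z))

def ggn_vp(x, t, w, v):
    """G v = J_zw^T H_yz J_zw v, computed without forming G."""
    net_w = lambda w_: net(x, w_)
    z, Jv = jax.jvp(net_w, (w,), (v,))                               # J_zw v
    _, HJv = jax.jvp(jax.grad(lambda z_: loss(z_, t)), (z,), (Jv,))  # H_yz (J_zw v)
    _, net_vjp = jax.vjp(net_w, w)
    return net_vjp(HJv)[0]                                           # J_zw^T (H_yz J_zw v)

key_w, key_v = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (12,))
v = jax.random.normal(key_v, (12,))
x = jnp.array([0.5, -1.0])
t = jnp.array([1.0, 0.0])  # one-hot target

# Explicit G, only feasible because this example is tiny.
J = jax.jacobian(net, argnums=1)(x, w)  # (2, 12)
H = jax.hessian(loss)(net(x, w), t)     # (2, 2)
G = J.T @ H @ J                         # (12, 12)

print(jnp.allclose(ggn_vp(x, t, w, v), G @ v))
print(jnp.linalg.eigvalsh(G).min() >= -1e-10)  # PSD up to rounding error
```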

### Response function

Say we’ve found some optimal solution \(\mathbf{w}_* = \arg\min_w \mathcal{J}(\mathbf{w}; \theta)\) given hyperparameters \(\mathbf{\theta}\). Say we want to see how the hyperparameters influence the optimal solution.

We can reparameterise the problem as a *response function*, which we then go on to analyse:

\[ \mathbf{w}_* = r(\mathbf{\theta}) \]

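The Implicit Function Theorem states that we can reparameterise this way locally, given some regularity conditions: e.g. \(\mathcal{J}\) is twice continuously differentiable and the Hessian \(\nabla^2_{\mathbf{w}} \mathcal{J}\) is invertible at \(\mathbf{w}_*\).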

#### Response Jacobian

\[ \mathbf{J_{w_*\theta}} = - [\nabla^2_{\mathbf{w}} \mathcal{J}(\mathbf{w}; \mathbf{\theta})]^{-1} [\nabla^2_{\mathbf{w \theta}} \mathcal{J}(\mathbf{w}; \mathbf{\theta})] \]
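A sketch checking this formula on ridge regression, where the response function has a closed form. Assumptions for illustration: the cost is \(\frac{1}{2}\|X\mathbf{w} - \mathbf{y}\|^2 + \frac{\theta}{2}\|\mathbf{w}\|^2\) with a single scalar hyperparameter \(\theta\), and the data are arbitrary. Both second derivatives in the formula are evaluated at \(\mathbf{w}_*\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Illustrative data for ridge regression with a scalar hyperparameter theta.
X = jnp.array([[1.0, 2.0], [3.0, 1.0], [0.5, -1.0]])
y = jnp.array([1.0, 0.0, 2.0])

def cost(w, theta):
    return 0.5 * jnp.sum((X @ w - y) ** 2) + 0.5 * theta * jnp.sum(w ** 2)

def response(theta):
    # r(theta): closed-form ridge solution, available because the cost is quadratic.
    return jnp.linalg.solve(X.T @ X + theta * jnp.eye(2), X.T @ y)

theta = 0.3
w_star = response(theta)

# Implicit-function-theorem formula, evaluated at w_*.
H_ww = jax.hessian(cost, argnums=0)(w_star, theta)                            # (2, 2)
H_wtheta = jax.jacobian(jax.grad(cost, argnums=0), argnums=1)(w_star, theta)  # (2,)
J_ift = -jnp.linalg.solve(H_ww, H_wtheta)

# Direct differentiation of the closed-form response function.
J_direct = jax.jacobian(response)(theta)

print(jnp.allclose(J_ift, J_direct))
```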

## 2.1: Forward and backward mode autodiff

Some revision on forward mode and reverse mode autodiff, as I was getting confused in lecture 2.

### Forward mode

#### Forward mode with evaluation traces

As you evaluate a function, record all intermediate values, including the inputs and outputs, and compute their derivatives at the same time. For example, if your trace consists of \(x_1, x_2, i_1, i_2, y_1\), and you wanted to compute \(\frac{dy_1}{dx_1}\), then alongside each variable you'd also compute its derivative wrt. \(x_1\): \(\frac{dx_1}{dx_1}, \frac{dx_2}{dx_1}, ...\).

#### Forward mode with dual numbers

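We replace our values with *dual numbers* of the form \(a + b\epsilon\), where \(\epsilon^2 = 0\). Addition and multiplication follow the usual rules, e.g. \((a + b\epsilon)(c + d\epsilon) = ac + (ad + bc)\epsilon\), so the \(\epsilon\) coefficient obeys the product rule. This means we can calculate the primals (\(a\)) and tangents (\(b\)) at once.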
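A minimal pure-Python sketch of dual numbers, supporting only `+` and `*` (the `Dual` class is hypothetical, not a library type). Seeding \(x_1\)'s tangent with 1 and \(x_2\)'s with 0 computes \(\frac{dy}{dx_1}\) alongside the value of \(y\) in a single pass:

```python
class Dual:
    """a + b*eps with eps**2 == 0: `primal` holds a, `tangent` holds b."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps, since eps^2 = 0.
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def f(x1, x2):
    # Example function: y = x1 * x2 + 3 * x1, so dy/dx1 = x2 + 3.
    return x1 * x2 + 3 * x1

y = f(Dual(2.0, 1.0), Dual(5.0, 0.0))
print(y.primal, y.tangent)  # 16.0 8.0
```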

#### Forward mode complexity

Computing the full Jacobian with forward mode is linear in the number of inputs, as we calculate the derivative wrt. each input separately: one forward pass (although they can be done “in parallel”) is needed per input parameter.

#### Forward mode for JVPs

When we want to compute \(\mathbf{Jv}\), the \(\mathbf{v}\) is applied to the *input space* of the function. When we do forward mode autodiff one parameter at a time, this is equivalent to seeding the input tangent with a standard basis vector, e.g. \([1, 0, 0]\) to get the derivatives wrt. the first of three parameters. Using \(\mathbf{v}\) generalises this to *any* direction in input space.
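A sketch with `jax.jvp` on an arbitrary \(f: \mathbb{R}^3 \to \mathbb{R}^2\): seeding with a basis vector recovers one column of the Jacobian, and seeding with a general \(\mathbf{v}\) gives \(\mathbf{Jv}\) in the same single pass.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^3 to R^2.
    return jnp.array([x[0] * x[1], x[1] + jnp.exp(x[2])])

x = jnp.array([1.0, 2.0, 0.0])
J = jax.jacfwd(f)(x)  # full Jacobian, only for comparison

# Seeding with the basis vector e_1 gives the first column of J.
e1 = jnp.array([1.0, 0.0, 0.0])
_, col1 = jax.jvp(f, (x,), (e1,))

# Any direction v gives J v in a single forward pass.
v = jnp.array([0.5, -1.0, 2.0])
_, Jv = jax.jvp(f, (x,), (v,))

print(jnp.allclose(col1, J[:, 0]), jnp.allclose(Jv, J @ v))
```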

### Reverse mode

We first compute the evaluation trace, i.e. the values of all intermediate variables, including inputs and outputs. For example, \(x_1, x_2, i_1, i_2, y_1\). Then, we calculate \(\frac{dy_1}{di_2}\), which is possible now that we know the values of \(y_1\) and \(i_2\). We then apply the chain rule progressively, working backwards through the trace, to eventually get what we care about: \(\frac{dy_1}{dx_1}\) and \(\frac{dy_1}{dx_2}\).

#### Reverse mode complexity

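Computing the full Jacobian takes one forward pass, then one backward pass per output variable. So for a scalar loss, a single backward pass gives the derivative wrt. every weight, which is why reverse mode is used to train networks.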
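A sketch of why this matters, assuming an arbitrary scalar loss of 1000 parameters: one VJP gives every partial derivative, whereas forward mode needs one JVP per parameter (only the first three are computed here, as a check).

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Illustrative scalar loss of n = 1000 parameters (a single output).
    return jnp.sum(jnp.sin(w) * w)

n = 1000
w = jnp.linspace(-1.0, 1.0, n)

# Reverse mode: one forward pass plus one backward pass gives all n partials.
_, loss_vjp = jax.vjp(loss, w)
(grad_rev,) = loss_vjp(jnp.ones(()))

# Forward mode needs one jvp per input; we only compute the first three.
forward_partials = jnp.array([
    jax.jvp(loss, (w,), (jnp.zeros(n).at[j].set(1.0),))[1] for j in range(3)
])

print(grad_rev.shape)  # (1000,)
print(jnp.allclose(forward_partials, grad_rev[:3]))
```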

### Tying back to `jax`

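`jax.jvp` computes \(\mathbf{Jv}\) directly, pushing the tangent \(\mathbf{v}\) through the function alongside the primal values, so it acts on the *input space* and never materialises \(\mathbf{J}\). Therefore, it uses forward mode autodiff. It must take \(\mathbf{v}\) when initially called (`jax.jvp(f, primals, tangents)`), as the tangent is propagated during the single forward pass: seeding with \(\mathbf{v}\) reduces the problem to one with effectively “one input variable”.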

`jax.vjp` computes \(\mathbf{J}^T\mathbf{v}\), pulling a cotangent \(\mathbf{v}\) back from the *output space*, again without materialising \(\mathbf{J}\). Therefore, it uses reverse mode autodiff. When initially called (`jax.vjp(f, *primals)`), it doesn't take \(\mathbf{v}\). This is because the initial `vjp` call performs the forward pass, computing and storing all intermediate values. It then returns the outputs together with a function that we can call with \(\mathbf{v}\), which actually performs the backward pass.
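The difference in call signatures, sketched on an arbitrary \(f\): `jax.jvp` takes the tangent up front, while `jax.vjp` returns a function that can be called with several cotangents without re-running the forward pass.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary example map from R^3 to R^3.
    return jnp.tanh(x) * jnp.sum(x)

x = jnp.array([0.1, 0.2, 0.3])

# jvp: primals and tangents are both passed in the single call.
v_in = jnp.array([1.0, 0.0, 0.0])
y, Jv = jax.jvp(f, (x,), (v_in,))

# vjp: this call runs the forward pass and stores what the backward pass needs.
y2, f_vjp = jax.vjp(f, x)
# Each call to f_vjp then runs only a backward pass, with a new cotangent.
(JTv_a,) = f_vjp(jnp.array([1.0, 0.0, 0.0]))
(JTv_b,) = f_vjp(jnp.array([0.0, 1.0, 0.0]))

print(jnp.allclose(y, y2))
```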

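N.B.: The `jvp` and `vjp` functions are intentionally cheap: each call computes a single product (\(\mathbf{Jv}\) or \(\mathbf{J}^T\mathbf{v}\)), shaped like the function's output or input respectively, rather than a full matrix. If you want to e.g. compute the full Jacobian or Hessian, then this will require multiple `jvp` or `vjp` calls: one per input direction or one per output direction. `jax.jacfwd`, `jax.jacrev` and `jax.hessian` package these calls up for you.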
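A sketch of building the full Jacobian from multiple calls, batched with `jax.vmap`: pushing forward each input basis vector gives the columns, and pulling back each output basis vector gives the rows. For a scalar function, `jax.hessian` composes forward and reverse mode as `jacfwd(jacrev(g))`, which the last line checks. The function \(f\) is an arbitrary example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^3 to R^2.
    return jnp.array([x[0] * x[1], x[1] + jnp.exp(x[2])])

x = jnp.array([1.0, 2.0, 0.0])

# One jvp per input direction, batched with vmap: the columns of J.
push = lambda v: jax.jvp(f, (x,), (v,))[1]
J_fwd = jax.vmap(push, out_axes=1)(jnp.eye(3))        # (2, 3)

# One vjp per output direction: the rows of J.
_, f_vjp = jax.vjp(f, x)
J_rev = jax.vmap(lambda u: f_vjp(u)[0])(jnp.eye(2))   # (2, 3)

print(jnp.allclose(J_fwd, jax.jacfwd(f)(x)), jnp.allclose(J_rev, jax.jacrev(f)(x)))

# For a scalar function, the Hessian is the forward-mode Jacobian of the gradient.
g = lambda x: jnp.sum(f(x) ** 2)
print(jnp.allclose(jax.hessian(g)(x), jax.jacfwd(jax.jacrev(g))(x)))
```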