Neural Network Training Dynamics
1: A Toy Model: Linear Regression
Why toy models
NNs are too complex to prove things about. However, we can notice phenomena and try to replicate them in toy models such as linear regression.
We can then use the toy models to make predictions that we can test empirically. In this way, simple models that offer many testable predictions are more useful than complex models.
Convex quadratic objective equation
\[ w_* \in \arg\min_w \frac{1}{2} w^TAw + b^Tw + c \]
Rigid transformations in SGD
SGD is invariant to rigid transformations, meaning its dynamics and solutions don’t change under rotations and translations of the weights.
Spectral decomposition
Any symmetric matrix \(A\) has a full set of eigenvectors, all eigenvalues are real, and eigenvectors can be taken to be orthogonal.
Basically, we can rewrite \(A\) as:
\[ A = QDQ^T \]
Where \(Q\) is orthogonal, and \(D\) is a diagonal matrix containing the eigenvalues of \(A\).
Analysing closed-form dynamics
- We use the spectral decomposition to rotate into the eigenbasis of \(A\), where the quadratic decouples across coordinates.
- Thus, we can analyse each axis independently.
- We rewrite the per-axis updates as \(w_j \leftarrow w_j - \alpha(a_j w_j + b_j)\), where \(a_j\) is the \(j\)-th eigenvalue of \(A\).
- We handle the cases of \(a_j\) separately:
- If \(a_j > 0\), the update converges provided the step size is small enough (\(0 < \alpha < 2/a_j\)); see the sketch after this list.
- If \(a_j = 0, b_j \neq 0\), the objective is linear along that axis, so \(w_j\) keeps moving at a constant rate and never converges.
- If \(a_j = 0, b_j = 0\), \(w_j\) is never updated.
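A minimal numerical sketch of this analysis (the matrix \(A\), the vector \(b\) and the step size are made up for illustration): rotating into the eigenbasis of \(A\) decouples gradient descent into independent per-axis recurrences.

```python
import jax.numpy as jnp

# Made-up convex quadratic: J(w) = 1/2 w^T A w + b^T w  (A symmetric PSD).
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])
alpha = 0.1  # step size; stable here since alpha < 2 / (largest eigenvalue)

# Spectral decomposition A = Q D Q^T.
eigvals, Q = jnp.linalg.eigh(A)

# Gradient descent in the original coordinates.
w = jnp.zeros(2)
for _ in range(100):
    w = w - alpha * (A @ w + b)

# The same dynamics, axis by axis, in the rotated coordinates v = Q^T w.
v = jnp.zeros(2)
b_rot = Q.T @ b
for _ in range(100):
    v = v - alpha * (eigvals * v + b_rot)  # independent scalar updates

print(jnp.allclose(w, Q @ v, atol=1e-4))  # True: the two views agree
```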
2: Taylor Approximations
Hessian, in terms of Taylor approximation
The matrix of second derivatives of the cost function wrt. the weights; it is the coefficient of the quadratic term in a second-order Taylor approximation of the cost.
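Concretely, using standard notation, the second-order Taylor approximation of the cost \(\mathcal{J}\) around a point \(\mathbf{w}_0\) is:
\[ \mathcal{J}(\mathbf{w}) \approx \mathcal{J}(\mathbf{w}_0) + \nabla \mathcal{J}(\mathbf{w}_0)^T (\mathbf{w} - \mathbf{w}_0) + \frac{1}{2} (\mathbf{w} - \mathbf{w}_0)^T \mathbf{H} (\mathbf{w} - \mathbf{w}_0), \qquad [\mathbf{H}]_{ij} = \frac{\partial^2 \mathcal{J}}{\partial w_i \partial w_j} \bigg|_{\mathbf{w}_0} \]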
Jacobian notation
The Jacobian of \(y\) wrt. \(w\) is: \[ [J_{yw}]_{ij} = \frac{\partial y_i}{\partial w_j} \]
Vector-Jacobian product (VJP)
The goal of the VJP is to compute \(J^Tv\), where \(J\) is the Jacobian of a function \(f\) at a point \(x\), and \(v\) is a vector in the output space of \(f\).
\[ \begin{align} x &\in \mathbb{R}^{n} \\ y &\in \mathbb{R}^{m} \\ v &\in \mathbb{R}^{m} \\ J &\in \mathbb{R}^{m \times n} \\ J^T &\in \mathbb{R}^{n \times m} \\ J^Tv &\in \mathbb{R}^{n} \end{align} \]
Gradient in terms of VJP
The gradient is the derivative of the (scalar) loss function wrt. the weights. Since the loss has a single output, its Jacobian is a single row, so the gradient is exactly the vector-Jacobian product wrt. the weights with \(v = [1]\).
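A small JAX sketch of both facts (the function \(f\) and the loss below are made up for illustration): `jax.vjp` gives \(\mathbf{J}^T\mathbf{v}\) with the expected shape, and for a scalar function the VJP with \(v = [1]\) recovers the gradient.

```python
import jax
import jax.numpy as jnp

# Made-up vector-valued function f: R^3 -> R^2.
def f(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])   # n = 3
v = jnp.array([0.5, -1.0])       # m = 2: v lives in the output space

y, vjp_fn = jax.vjp(f, x)        # forward pass; y = f(x)
(jtv,) = vjp_fn(v)               # J^T v
print(jtv.shape)                 # (3,): back in the input space

# For a scalar loss, the VJP with v = [1] is exactly the gradient.
def loss(w):
    return jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 0.5])
loss_val, vjp_loss = jax.vjp(loss, w)
(grad_via_vjp,) = vjp_loss(jnp.ones_like(loss_val))
print(jnp.allclose(grad_via_vjp, jax.grad(loss)(w)))  # True
```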
Gateaux derivative
This is the directional derivative of \(f\) in the direction of \(\Delta w\).
\[ \mathcal{R}_{\Delta \mathbf{w}} f(\mathbf{w}) = \lim_{h \to 0} \frac{f(\mathbf{w} + h \Delta \mathbf{w}) - f(\mathbf{w})}{h} \]
Gateaux derivative in terms of JVP
\[ \mathcal{R}_{\Delta \mathbf{w}} f(\mathbf{w}) = \mathbf{J}_{\mathbf{y}\mathbf{w}} \Delta \mathbf{w} \]
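A quick check (the function, point and direction below are made up) that `jax.jvp` computes this directional derivative, compared against a finite-difference estimate of the limit above:

```python
import jax
import jax.numpy as jnp

# Made-up function, point, and direction Delta w.
def f(w):
    return jnp.stack([jnp.tanh(w[0]) + w[1] ** 2, w[0] * w[1]])

w = jnp.array([0.3, -1.2])
dw = jnp.array([1.0, 2.0])

# JVP: J_{yw} @ dw, computed without materialising the Jacobian.
_, jvp_out = jax.jvp(f, (w,), (dw,))

# Finite-difference approximation of the Gateaux derivative.
h = 1e-3
fd = (f(w + h * dw) - f(w)) / h

print(jnp.allclose(jvp_out, fd, atol=1e-2))  # True (up to discretisation error)
```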
Why JVP+VJP instead of computing J?
The difference between the JVP and VJP is only a transpose of \(J\). The reason we don’t just compute \(J\) and transpose it is that materialising \(J\) can be prohibitively expensive when the inputs and outputs of \(f\) are high dimensional. The JVP and VJP are ways to compute these products cheaply, without ever forming \(J\).
Jacobian and Hessian landscape interpretation
Say we have a function \(f\), its Jacobian \(\mathbf{J}\), and its Hessian \(\mathbf{H}\).
- For a direction \(\mathbf{v}\), if \(\mathbf{Jv} > 0\), then \(f\) increases along \(\mathbf{v}\): the direction points uphill.
- For a direction \(\mathbf{v}\), if \(\mathbf{v}^T\mathbf{Hv} > 0\), then the direction curves upwards.
Convex functions
“Bowl shaped” functions. A function is convex if the line segment connecting any two points on its graph lies on or above the graph. Formally:
\[ \forall x_1, x_2,\ \lambda \in [0, 1]: \quad \lambda f(x_1) + (1-\lambda) f(x_2) \geq f(\lambda x_1 + (1-\lambda)x_2) \]
Convex functions in terms of Hessian
A twice-differentiable function is convex iff its Hessian is positive semi-definite everywhere, i.e. \(\mathbf{v}^T \mathbf{H} \mathbf{v} \geq 0\) for all \(\mathbf{v}\), or equivalently all eigenvalues of \(\mathbf{H}\) are \(\geq 0\).
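A small check of this characterisation on a made-up convex function (log-sum-exp plus a quadratic), using `jax.hessian` to form the Hessian explicitly:

```python
import jax
import jax.numpy as jnp

# Made-up convex function: log-sum-exp plus a quadratic.
def f(x):
    return jax.scipy.special.logsumexp(x) + 0.5 * jnp.sum(x ** 2)

x = jnp.array([0.1, -2.0, 1.5])
H = jax.hessian(f)(x)
eigvals = jnp.linalg.eigvalsh(H)

print(eigvals)                         # all eigenvalues are positive
print(bool(jnp.all(eigvals >= 0.0)))   # True, consistent with convexity
```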
Stationary point definition
Any \(\mathbf{w}_*\) such that \(\nabla \mathcal{J}(\mathbf{w}_*) = \mathbf{0}\).
Categorising stationary points with Hessians
Assuming \(\mathbf{w}_*\) is a stationary point:
- If \(\mathbf{H}\) is positive definite, then \(\mathbf{w}_*\) is a local minimum.
- If \(\mathbf{H}\) is negative definite, then \(\mathbf{w}_*\) is a local maximum.
- If \(\mathbf{H}\) has both positive and negative eigenvalues, then \(\mathbf{w}_*\) is a saddle point.
- If \(\mathbf{H}\) has zero-valued eigenvalues (and is otherwise semi-definite), we can’t say whether it’s a minimum/maximum/saddle point: the behaviour depends on higher-order derivatives. See the worked example below.
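A worked example of the classification (the function is made up): \(f(x, y) = x^2 - y^2\) is stationary at the origin, and its Hessian there has one positive and one negative eigenvalue, so the origin is a saddle point.

```python
import jax
import jax.numpy as jnp

# Made-up function with a saddle point at the origin.
def f(w):
    x, y = w
    return x ** 2 - y ** 2

w_star = jnp.zeros(2)
print(jax.grad(f)(w_star))                           # [0. 0.]: stationary
print(jnp.linalg.eigvalsh(jax.hessian(f)(w_star)))   # [-2. 2.]: a saddle point
```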
Stable stationary points
In dynamical systems terminology, saddle points are unstable, and local optima are stable.
Eigenspectrum of the Hessian in practice
The eigenspectrum (the set of eigenvalues) typically:
- has many zero eigenvalues, as networks are over-parameterised;
- at initialisation, has many large positive and negative eigenvalues;
- loses the large negative eigenvalues quickly (although I’m unsure why this isn’t true for large positive eigenvalues);
- during training, shows large positive eigenvalues, small positive and negative eigenvalues, and many zero eigenvalues (a toy inspection is sketched below).
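A sketch of how one might inspect this on a tiny made-up network (the architecture, data and sizes are arbitrary; real networks are far too large to form the Hessian explicitly, which is why the Hessian-vector products below matter):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

# Tiny made-up regression problem and one-hidden-layer tanh MLP (32 parameters).
X = jax.random.normal(k1, (20, 3))
t = jax.random.normal(k2, (20,))
w = jax.random.normal(k3, (3 * 8 + 8,)) * 0.1  # flat parameter vector

def predict(w, X):
    W1 = w[:24].reshape(3, 8)
    w2 = w[24:]
    return jnp.tanh(X @ W1) @ w2

def loss(w):
    return 0.5 * jnp.mean((predict(w, X) - t) ** 2)

H = jax.hessian(loss)(w)            # only feasible because w is tiny
eigvals = jnp.linalg.eigvalsh(H)    # sorted ascending
print(eigvals[:3], eigvals[-3:])    # smallest and largest eigenvalues of the spectrum
```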
Hessian-vector products
The product of the Hessian \(\mathbf{H} \in \mathbb{R}^{m \times n \times n}\) with a vector \(\mathbf{v} \in \mathbb{R}^n\), \(\mathbf{Hv} \in \mathbb{R}^{m \times n}\).
Calculating the HVP with JVP and VJP
\[ \begin{align} f &: \mathbb{R}^n \to \mathbb{R}^m \\ \mathbf{y} &= f(\mathbf{x}) \\ \mathbf{J_{yx}} &\in \mathbb{R}^{m \times n} \\ \mathbf{H_{yx}} &\in \mathbb{R}^{m \times n \times n} \\ \mathbf{H_{yx}v} &\in \mathbb{R}^{m \times n} \\ \\ \mathbf{H_{yx}} &= \mathbf{J_{J_{yx}x}} \\ \mathbf{H_{yx}v} &= \mathbf{J_{J_{yx}x}}\mathbf{v} \end{align} \]
Hence, to calculate the HVP for a scalar cost you first obtain the Jacobian (i.e. the gradient) via a VJP over \(f\) with \(\mathbf{v} = [1]\), then calculate the JVP of that Jacobian function with your given \(\mathbf{v}\) value (see the JAX sketch below).
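In JAX, one standard way to write this for a scalar-valued cost is forward-over-reverse: `jax.grad` (a VJP with \(v = [1]\)) gives the Jacobian/gradient function, and `jax.jvp` applies its Jacobian to \(\mathbf{v}\). The cost function here is made up.

```python
import jax
import jax.numpy as jnp

# Made-up scalar cost.
def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

def hvp(f, w, v):
    # grad(f) is the Jacobian of a scalar cost (a VJP with v = [1]);
    # the JVP of grad(f) in direction v is then the Hessian-vector product.
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

w = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, -1.0])
print(jnp.allclose(hvp(loss, w, v), jax.hessian(loss)(w) @ v, atol=1e-5))  # True
```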
ReLU and Hessian
We often can’t make use of the Hessian of real neural networks, as the ReLU isn’t twice differentiable: its second derivative is zero everywhere except at the kink, where it is undefined.
Gauss-Newton Hessian
The Hessian of the cost function we would get if we linearised the neural network around the current weights, while keeping the loss function \(\mathcal{L}\) exact (not linearised).
If we decompose a cost function \(\mathcal{J}\) into: \[ \begin{align} \mathbf{y} &= \mathcal{L}(\mathbf{z}, \mathbf{t}) \\ \mathbf{z} &= f(\mathbf{x}, \mathbf{w}) \\ \end{align} \]
Then we can decompose the Hessian into:
\[ \nabla^2_{\mathbf{w}} \mathcal{J} = \mathbf{J_{zw}}^T \mathbf{H_{yz}} \mathbf{J_{zw}} + \sum_a \frac{\partial \mathcal{L}}{\partial z_a} \nabla^2_{\mathbf{w}}[f(\mathbf{x}, \mathbf{w})]_a \]
If we only take the first term, we get the Gauss-Newton Hessian:
\[ \mathbf{G} = \mathbf{J_{zw}}^T \mathbf{H_{yz}} \mathbf{J_{zw}} \]
Why use the Gauss-Newton Hessian?
- We only require the first derivative of the network, which means the network doesn’t have to be twice-differentiable, and we are free to use e.g. ReLUs.
- As long as the loss function is convex in the network output \(\mathbf{z}\), \(\mathbf{G}\) is guaranteed to be positive semi-definite (a Gauss-Newton-vector product is sketched below).
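A sketch of a Gauss-Newton-vector product under these definitions (the network \(f\), loss, and data are all made up): push \(\mathbf{v}\) through \(\mathbf{J_{zw}}\) with a JVP, multiply by the loss Hessian \(\mathbf{H_{yz}}\), then pull back through \(\mathbf{J_{zw}}^T\) with a VJP. Note that only first derivatives of the (ReLU) network are needed.

```python
import jax
import jax.numpy as jnp

# Made-up data, ReLU network z = f(x, w), and convex loss L(z, t).
x = jnp.array([1.0, 0.5])
t = jnp.array([0.3, 0.7, 0.0])
w = jnp.arange(6.0).reshape(2, 3) * 0.1

def f(w):
    return jax.nn.relu(x @ w)             # network output z, shape (3,)

def loss(z):
    return 0.5 * jnp.sum((z - t) ** 2)    # convex in z

def gnvp(w, v):
    z, jzw_v = jax.jvp(f, (w,), (v,))     # J_{zw} v
    H_yz = jax.hessian(loss)(z)           # Hessian of the loss wrt z
    _, vjp_fn = jax.vjp(f, w)
    (gv,) = vjp_fn(H_yz @ jzw_v)          # J_{zw}^T H_{yz} J_{zw} v
    return gv

v = jnp.ones_like(w)
print(gnvp(w, v).shape)                   # (2, 3): same shape as w
```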
Response function
Say we’ve found some optimal solution \(\mathbf{w}_* = \arg\min_w \mathcal{J}(\mathbf{w}; \theta)\) given hyperparameters \(\mathbf{\theta}\). Say we want to see how the hyperparameters influence the optimal solution.
We can reparameterise the problem as a response function which we then go on to analyse: \[ \mathbf{w_*} = r(\mathbf{\theta}) \]
The Implicit Function Theorem states that we can reparameterise this way, under some regularity conditions (in particular, \(\nabla^2_{\mathbf{w}} \mathcal{J}\) must be invertible at \(\mathbf{w}_*\)).
Response Jacobian
\[ \mathbf{J_{w_*\theta}} = - [\nabla^2_{\mathbf{w}} \mathcal{J}(\mathbf{w}; \mathbf{\theta})]^{-1} [\nabla^2_{\mathbf{w \theta}} \mathcal{J}(\mathbf{w}; \mathbf{\theta})] \]
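A sanity check of this formula on a made-up ridge-regression objective, where the hyperparameter \(\theta\) is the weight-decay strength and the response function has a closed form we can differentiate directly:

```python
import jax
import jax.numpy as jnp

# Made-up ridge regression: J(w; theta) = 0.5 ||Xw - t||^2 + 0.5 theta ||w||^2.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (10, 3))
t = jnp.arange(10.0)

def J(w, theta):
    return 0.5 * jnp.sum((X @ w - t) ** 2) + 0.5 * theta * jnp.sum(w ** 2)

def solve(theta):
    # Closed-form response function r(theta) for this objective.
    return jnp.linalg.solve(X.T @ X + theta * jnp.eye(3), X.T @ t)

theta = 0.5
w_star = solve(theta)

# Response Jacobian via the implicit function theorem.
H_ww = jax.hessian(J, argnums=0)(w_star, theta)                      # d2J / dw dw
H_wt = jax.jacfwd(jax.grad(J, argnums=0), argnums=1)(w_star, theta)  # d2J / dw dtheta
J_resp = -jnp.linalg.solve(H_ww, H_wt)

# Compare against differentiating the closed-form solution directly.
print(jnp.allclose(J_resp, jax.jacfwd(solve)(theta), atol=1e-3))  # True
```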
2.1: Forward and backward mode autodiff
Some revision on forward mode and reverse mode autodiff, as I was getting confused in lecture 2.
Forward mode
Forward mode with evaluation traces
As you go through a function, record all intermediate values – including the inputs and outputs. Then, while you’re doing this, you can also compute the derivatives. For example, if your trace consists of \(x_1, x_2, i_1, i_2, y_1\), and you wanted to compute \(\frac{dy_1}{dx_1}\), then while computing the variables you’d also compute \(\frac{dx_1}{dx_1}, \frac{dx_2}{dx_1}, \frac{di_1}{dx_1}, \ldots\), ending with \(\frac{dy_1}{dx_1}\).
Forward mode with dual numbers
We replace our values with dual numbers of the form \(a + b\epsilon\), where addition and multiplication work as expected, and \(\epsilon^2 = 0\). This means we can calculate the primals and tangents at once.
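A minimal sketch of the idea in plain Python (an illustration of dual-number arithmetic, not how JAX implements forward mode): each value carries a primal part \(a\) and a tangent part \(b\), and arithmetic propagates both at once.

```python
class Dual:
    """A number a + b*eps with eps**2 = 0; b carries the derivative."""
    def __init__(self, a, b=0.0):
        self.a, self.b = a, b

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.a + other.a, self.b + other.b)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # (a1 + b1 eps)(a2 + b2 eps) = a1 a2 + (a1 b2 + b1 a2) eps
        return Dual(self.a * other.a, self.a * other.b + self.b * other.a)

def f(x):
    return x * x * x + x * 2   # f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2

out = f(Dual(3.0, 1.0))        # seed the tangent of x with 1
print(out.a, out.b)            # 33.0 29.0  (f(3) and f'(3))
```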
Forward mode complexity
Forward mode is linear in the number of inputs, as we’re calculating the differential for each of them separately. One forward pass (although they can be done “in parallel”) needs to be done for each input parameter.
Forward mode for JVPs
When we want to compute \(\mathbf{Jv}\), the \(\mathbf{v}\) is applied to the input space of the function. When we do forward mode autodiff parameter-by-parameter, this is equivalent to taking \(\mathbf{v}\) to be a standard basis vector, e.g. \([1, 0, 0]^T\) to pick out the first of three parameters. Using a general \(\mathbf{v}\) generalises this to any direction in input space.
Reverse mode
We first compute the evaluation trace, i.e. the values of all intermediate variables, including inputs and outputs. For example, \(x_1, x_2, i_1, i_2, y_1\). Then, we calculate \(\frac{dy_1}{di_2}\), which is possible now that we know the values of \(y_1\) and \(i_2\). We then apply the chain rule progressively to eventually get what we care about: \(\frac{dy_1}{dx_1}\) and \(\frac{dy_1}{dx_2}\).
Reverse mode complexity
We need to do one forward pass, then a backward pass for every output variable.
Tying back to `jax`

`jax.jvp` conceptually computes \(\mathbf{Jv}\): the Jacobian applied to a vector \(\mathbf{v}\) in the function’s input space (without ever materialising \(\mathbf{J}\)). Therefore, it uses forward mode autodiff. It must take \(\mathbf{v}\) when initially called, as this is required to transform the input space st. we only have “one input variable”.

`jax.vjp` conceptually computes \(\mathbf{J}^T\mathbf{v}\): a transformation applied to the function’s output space. Therefore, it uses reverse mode autodiff. When initially called, it doesn’t take \(\mathbf{v}\). I presume this is because the initial `vjp` call performs the forward pass, calculating all intermediate values. Then, `vjp` returns a function that we can call with \(\mathbf{v}\), thus actually performing the backwards pass.

N.B.: The `jvp` and `vjp` functions are intentionally cheap, and as a result, their type signature is st. they always return a single vector (one JVP or VJP), never a full matrix. If you want to e.g. compute the full Jacobian or Hessian, this will require multiple `jvp` or `vjp` calls.
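A short sketch of the two call signatures described above, on a made-up function \(f: \mathbb{R}^3 \to \mathbb{R}^2\). Note how `jax.jvp` takes the tangent \(\mathbf{v}\) up front, while `jax.vjp` first runs the forward pass and returns a function that is later called with \(\mathbf{v}\); repeated calls with basis vectors recover the full Jacobian.

```python
import jax
import jax.numpy as jnp

# Made-up function f: R^3 -> R^2.
def f(x):
    return jnp.stack([x[0] * x[1], x[1] + jnp.exp(x[2])])

x = jnp.array([1.0, 2.0, 0.5])

# Forward mode: v lives in the input space and is supplied immediately.
v_in = jnp.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(f, (x,), (v_in,))   # jv = J v, shape (2,)

# Reverse mode: the first call runs the forward pass; v is supplied later.
v_out = jnp.array([1.0, -1.0])
_, vjp_fn = jax.vjp(f, x)
(vTJ,) = vjp_fn(v_out)              # J^T v, shape (3,)

# Multiple JVP calls with basis vectors recover the full Jacobian column-by-column.
cols = [jax.jvp(f, (x,), (e,))[1] for e in jnp.eye(3)]
print(jnp.allclose(jnp.stack(cols, axis=1), jax.jacfwd(f)(x)))  # True
```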