Lee et al. (2019) Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
definitions
- The training dataset is defined as $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{N}$. The collection of all inputs is denoted $\mathcal{X}$ and the labels $\mathcal{Y}$.
- The vector $\theta$ represents the collection of all trainable network parameters (weights and biases) concatenated together. The parameters at training time $t$ are $\theta_t$, and the initial parameters are $\theta_0$.
- The output (logits) of the neural network for an input $x$ at time $t$ is denoted $f_t(x) = f(x; \theta_t)$.
- The empirical Neural Tangent Kernel at time $t$ is an evolving matrix defined as $\hat\Theta_t(x, x') = \nabla_\theta f_t(x)\, \nabla_\theta f_t(x')^\top$.
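To make the definition concrete, here is a minimal numpy sketch (not the paper's code; the two-layer ReLU architecture and all names are illustrative) that computes the empirical NTK as the Gram matrix of per-example parameter gradients:

```python
import numpy as np

def grad_f(params, x):
    """Gradient of the scalar network output w.r.t. all parameters, flattened.

    The network is f(x) = w2 . relu(W1 x / sqrt(d)) / sqrt(n): standard
    normal weights with explicit 1/sqrt(fan_in) scaling (NTK parameterization).
    """
    W1, w2 = params
    n, d = W1.shape
    z = W1 @ x / np.sqrt(d)
    h = np.maximum(0.0, z)
    dW2 = h / np.sqrt(n)                      # df/dw2
    dz = (w2 / np.sqrt(n)) * (z > 0)          # df/dz through the ReLU
    dW1 = np.outer(dz, x) / np.sqrt(d)        # df/dW1
    return np.concatenate([dW1.ravel(), dW2])

def empirical_ntk(params, X):
    """Theta_hat[i, j] = <grad_theta f(x_i), grad_theta f(x_j)>."""
    J = np.stack([grad_f(params, x) for x in X])  # (N, num_params) Jacobian
    return J @ J.T

rng = np.random.default_rng(0)
n, d = 512, 3
params = (rng.standard_normal((n, d)), rng.standard_normal(n))
X = rng.standard_normal((4, d))
Theta = empirical_ntk(params, X)
print(Theta.shape)  # (4, 4); symmetric and positive semi-definite
```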
setup
- The model is a feed-forward neural network with hidden layers of width $n$. The weights are drawn from a standard normal distribution and scaled using the “NTK parameterization” (i.e., each layer’s contribution is scaled by a factor of $1/\sqrt{n}$).
- The network is optimized using Mean Squared Error (MSE) loss, defined as $\mathcal{L} = \frac{1}{2}\|f_t(\mathcal{X}) - \mathcal{Y}\|_2^2$. Under continuous-time gradient descent (gradient flow) with a learning rate $\eta$, the parameters evolve according to
$$\dot\theta_t = -\eta\, \nabla_\theta f_t(\mathcal{X})^\top \big(f_t(\mathcal{X}) - \mathcal{Y}\big).$$
By applying the chain rule, the evolution of the network’s predictions can be described exactly in terms of the tangent kernel:
$$\dot f_t(\mathcal{X}) = -\eta\, \hat\Theta_t(\mathcal{X}, \mathcal{X})\,\big(f_t(\mathcal{X}) - \mathcal{Y}\big).$$
We can define a simplified, linearized version of the network using a first-order Taylor expansion around its initial parameters:
$$f^{\mathrm{lin}}_t(x) = f_0(x) + \nabla_\theta f_0(x)\big|_{\theta=\theta_0}\, \omega_t, \qquad \omega_t \equiv \theta_t - \theta_0.$$
Because $\nabla_\theta f_0(x)$ is constant in time, this linearized model’s dynamics rely on the fixed initial NTK $\hat\Theta_0$.
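Because the kernel is fixed, the linearized model’s MSE dynamics on the training set solve in closed form: $f^{\mathrm{lin}}_t(\mathcal{X}) = \mathcal{Y} + e^{-\eta \hat\Theta_0 t}\,(f_0(\mathcal{X}) - \mathcal{Y})$. A small numpy sketch of this formula (the stand-in kernel, outputs, and labels below are synthetic, purely for illustration):

```python
import numpy as np

def linearized_predictions(Theta0, f0, Y, eta, t):
    """Closed-form train-set predictions of the linearized model under
    gradient flow on MSE: f_lin(t) = Y + expm(-eta * t * Theta0) @ (f0 - Y)."""
    lam, U = np.linalg.eigh(Theta0)            # Theta0 is symmetric PSD
    decay = (U * np.exp(-eta * t * lam)) @ U.T  # matrix exponential via eigh
    return Y + decay @ (f0 - Y)

# Synthetic stand-in for the fixed empirical NTK at initialization.
rng = np.random.default_rng(1)
Q, _ = np.linalg.qr(rng.standard_normal((5, 5)))
Theta0 = Q @ np.diag(np.linspace(0.5, 2.0, 5)) @ Q.T
f0 = rng.standard_normal(5)   # outputs at initialization
Y = rng.standard_normal(5)    # labels
preds = linearized_predictions(Theta0, f0, Y, eta=0.1, t=200.0)
# All kernel eigenvalues are positive, so predictions converge to the labels.
print(np.max(np.abs(preds - Y)) < 1e-3)  # True
```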
Thm. 2.1 (informal)
For a network with identically sized hidden layers of width $n$, trained with gradient descent at a learning rate $\eta < \eta_{\text{critical}} = 2\,(\lambda_{\min} + \lambda_{\max})^{-1}$ (where $\lambda_{\min}, \lambda_{\max}$ are the smallest and largest eigenvalues of the analytic NTK $\Theta$; this is stronger than the condition $\eta < 2/\lambda_{\max}$ for the maximum stable learning rate of linear models), the network’s behavior converges to that of its linearized model as $n \to \infty$.
Specifically, with probability arbitrarily close to 1 over the random initialization, the maximum difference between the real network’s output and the linearized network’s output over all time is bounded:
$$\sup_{t} \big\| f_t(x) - f^{\mathrm{lin}}_t(x) \big\|_2 = O(n^{-1/2}) \quad \text{as } n \to \infty.$$
Similarly, the (relative) change in the weights, $\sup_t \|\theta_t - \theta_0\|_2 / \sqrt{n}$, and the shift in the empirical kernel, $\sup_t \|\hat\Theta_t - \hat\Theta_0\|_F$, are also bounded by $O(n^{-1/2})$.
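As a rough numerical check of the theorem (not the paper’s experiment; the tiny two-layer ReLU network, dataset, and hyperparameters are all illustrative choices), one can run gradient descent on a network and on its exact linearized dynamics side by side and watch the sup-over-time gap shrink with width:

```python
import numpy as np

def forward_and_jac(W1, w2, X):
    """Outputs and Jacobian of a two-layer ReLU net in NTK parameterization."""
    n, d = W1.shape
    Z = X @ W1.T / np.sqrt(d)                 # (N, n) pre-activations
    H = np.maximum(0.0, Z)
    f = H @ w2 / np.sqrt(n)                   # (N,) scalar outputs
    dW2 = H / np.sqrt(n)                                  # df/dw2
    dZ = (Z > 0) * (w2 / np.sqrt(n))                      # df/dZ
    dW1 = dZ[:, :, None] * X[:, None, :] / np.sqrt(d)     # df/dW1
    return f, np.concatenate([dW1.reshape(len(X), -1), dW2], axis=1)

def sup_gap(width, X, Y, eta=0.1, steps=300, seed=0):
    """sup over steps of |f_t - f_t^lin| on the train set, same GD schedule."""
    rng = np.random.default_rng(seed)
    W1 = rng.standard_normal((width, X.shape[1]))
    w2 = rng.standard_normal(width)
    theta = np.concatenate([W1.ravel(), w2])
    f0, J0 = forward_and_jac(W1, w2, X)
    Theta0 = J0 @ J0.T                        # empirical NTK at init (fixed)
    f_lin, gap = f0.copy(), 0.0
    for _ in range(steps):
        f, J = forward_and_jac(W1, w2, X)
        gap = max(gap, np.max(np.abs(f - f_lin)))
        theta = theta - eta * J.T @ (f - Y)   # GD step on the real network
        W1 = theta[:W1.size].reshape(W1.shape)
        w2 = theta[W1.size:]
        f_lin = f_lin - eta * Theta0 @ (f_lin - Y)  # exact linearized dynamics
    return gap

rng = np.random.default_rng(42)
X = rng.standard_normal((8, 4))
Y = rng.standard_normal(8)
gaps = {n: sup_gap(n, X, Y) for n in (64, 4096)}
print(gaps)  # the gap at width 4096 should be markedly smaller than at 64
```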
intuition.
As the network width becomes massive, we enter the “lazy regime,” where the updates to individual weights during training become vanishingly small. Intuitively, with millions or even billions of parameters, no single weight has to change much for the output of the network to change significantly: even though each weight barely moves, the microscopic individual updates collectively combine (or “conspire”) to produce a significant, finite change in the network’s final output.
The strategy for the proof of Theorem 2.1 is as follows. First, we need to show that the empirical NTK stays approximately constant ($\|\hat\Theta_t - \hat\Theta_0\|_F = O(n^{-1/2})$). Because the individual weights need only travel a tiny distance ($\|\theta_t - \theta_0\|_2/\sqrt{n} = O(n^{-1/2})$) from initialization in order to fit the data, the gradient of the output with respect to the weights barely changes. This guarantees that the empirical tangent kernel stays effectively constant throughout the entire training process.
The next and final step is to show that $\sup_t \|f_t(x) - f^{\mathrm{lin}}_t(x)\|_2 = O(n^{-1/2})$. Recall that
$$\dot f_t(\mathcal{X}) = -\eta\, \hat\Theta_t\,\big(f_t(\mathcal{X}) - \mathcal{Y}\big) \qquad \text{and} \qquad \dot f^{\mathrm{lin}}_t(\mathcal{X}) = -\eta\, \hat\Theta_0\,\big(f^{\mathrm{lin}}_t(\mathcal{X}) - \mathcal{Y}\big).$$
The actual network’s training trajectory can be viewed as the linearized network’s trajectory plus a fluctuation caused by the slight drift of the NTK, $\hat\Theta_t - \hat\Theta_0$. By applying (the integral form of) Grönwall’s Inequality, we can strictly bound this compounding error integral and prove that the two models stay aligned.
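For reference, the constant-coefficient integral form of Grönwall’s Inequality (stated here from standard sources, not quoted from the paper) reads:

```latex
\text{If } u(t) \;\le\; \alpha \;+\; \int_0^t \beta\, u(s)\, \mathrm{d}s
\quad \text{for all } t \ge 0 \text{ with } \beta \ge 0,
\qquad \text{then} \qquad
u(t) \;\le\; \alpha\, e^{\beta t}.
```

Roughly, $u(t)$ plays the role of $\|f_t(\mathcal{X}) - f^{\mathrm{lin}}_t(\mathcal{X})\|_2$, and the constant $\alpha$ collects the perturbation due to the $O(n^{-1/2})$ kernel drift, so the exponential factor multiplies a quantity that vanishes as $n \to \infty$.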
disclaimer: this note was mostly transcribed by Gemini