Deep learning has shown remarkable success across a variety of machine learning tasks. At the same time, our theoretical understanding of deep learning methods remains limited. In particular, the interplay between training dynamics, properties of the learned network, and generalization remains a largely open problem.
In this work we take a step toward addressing these questions. We present a dynamical mechanism that allows deep networks trained using SGD to find flat minima and achieve superior performance. Our theoretical predictions agree well with empirical results in a variety of deep learning settings. In many cases we are able to predict the regime of learning rates where optimal performance is achieved. Figure 1 summarizes our main results. This work builds on several existing results, which we now review.
1.1. Large learning rate SGD improves generalization
SGD training with large initial learning rates often leads to improved performance over training with small initial learning rates (see Li et al. (2019); Leclerc & Madry (2020); Xie et al. (2020); Frankle et al. (2020); Jastrzebski et al. (2020) for recent discussions). It has been suggested that one of the mechanisms underlying the benefit of large learning rates is that noise from stochastic gradient descent leads to flat minima, and that flat minima generalize better than sharp minima (Hochreiter & Schmidhuber, 1997; Keskar et al., 2016; Smith & Le, 2018; Jiang et al., 2020; Park et al., 2019) (though see Dinh et al. (2017) for discussion of some caveats). According to this suggestion, training with a large learning rate (or with a small batch size) can improve performance because it leads to more stochasticity during training (Mandt et al., 2017; Smith et al., 2017; Smith & Le, 2018; Smith et al., 2018).
We will develop a connection between large learning rate and flatness of minima in models trained via SGD. Unlike the relationship explored in most previous work though, this connection is not driven by SGD noise, but arises solely as a result of training with a large initial learning rate, and holds even for full batch gradient descent.
Figure 1. A summary of our main results. (a) A visualization of gradient descent dynamics derived in our theoretical setup. A 2D slice of parameter space is shown, where lighter color indicates higher loss and dots represent points visited during optimization. Initially, the loss grows rapidly while local curvature decreases. Once curvature is sufficiently low, gradient descent converges to a flat minimum. We call this the catapult effect. See Figures 2 and S1 for more details. (b) Confirmation of our theoretical predictions in a practical deep learning setting. The line shows the test accuracy of a Wide ResNet trained on CIFAR-10 as a function of learning rate, with each model trained for a fixed number of steps. Dashed lines show our predictions for the boundaries of the large learning rate regime (the catapult phase), where we expect optimal performance to occur. Maximal performance is achieved between the dashed lines, confirming our predictions. See Section 3 for details.
1.2. The existing theory of infinite width networks is insufficient to describe large learning rates
A recent body of work has investigated the gradient descent dynamics of deep networks in the limit of infinite width (Daniely, 2017; Jacot et al., 2018; Lee et al., 2019; Du et al., 2019; Zou et al., 2018; Allen-Zhu et al., 2019; Li & Liang, 2018; Chizat et al., 2019; Mei et al., 2018; Rotskoff & Vanden-Eijnden, 2018; Sirignano & Spiliopoulos, 2018; Woodworth et al., 2019; Naveh et al.). Of particular relevance is the work by Jacot et al. (2018) showing that gradient flow in the space of functions is governed by a dynamical quantity called the Neural Tangent Kernel (NTK) which is fixed at its initial value in this limit. Lee et al. (2019) showed this result is equivalent to training the linearization of a model around its initialization in parameter space. Finally, moving away from the strict limit of infinite width by working perturbatively, Dyer & Gur-Ari (2020); Huang & Yau (2019) introduced an approach to computing the finite-width corrections to network evolution.
Despite this progress, it seems these results are insufficient to capture the full dynamics of deep networks, as well as their superior performance, in regimes applicable to practice. Prior work has focused on comparisons between various infinite-width kernels associated with deep networks and their finite-width, SGD-trained counterparts (Lee et al., 2018; Novak et al., 2019; Arora et al., 2019). Specific findings vary depending on precise choices for architecture and hyperparameters. However, dramatic performance gaps are consistently observed between non-linear CNNs and their limiting kernels, implying that the theory is not sufficient to explain the performance of deep networks in this realistic setup. Furthermore, some hyperparameter settings in finite-width models have no known analogue in the infinite width limit, and it is these settings that often lead to optimal performance.
In particular, finite width networks are often trained with large learning rates that would cause divergence for infinite width linearized models. Further, these large learning rates cause finite width networks to converge to flat minima. For infinite width linearized models, trained with MSE loss, all minima have the same curvature, and the notion of flat minima does not apply. We argue that the reduction in curvature during optimization, and support for learning rates that are infeasible for infinite width linearized models, may thus partially explain performance gaps observed between linear and non-linear models.
1.3. Our contribution: three learning rate regimes
In this work, we identify a dynamical mechanism which enables finite-width networks to stably access large learning rates. We show that this mechanism causes training to converge to flatter minima and is associated with improved generalization. We further show that this same mechanism can describe the behavior of infinite width networks, if training time is increased along with network width.
This new mechanism enables a characterization of gradient descent training in terms of three learning rate regimes, or phases: the lazy phase, the catapult phase, and the divergent phase. In Section 2 we analytically derive the behavior in these three learning rate regimes for one hidden layer linear networks with large but finite width, trained with MSE loss. We confirm experimentally in Section 3 that these phases also apply to deep nonlinear fully-connected, convolutional, and residual architectures. In Section 4 we study additional predictions of the analytic solution.
We now summarize all three phases, using $\eta$ to indicate the learning rate, and $\lambda_0$ to indicate the initial curvature (defined precisely in Section 2.1). The phase is determined by the curvature at initialization and by the learning rate, despite the fact that the curvature may change significantly during training. Based on the experimental evidence we expect the behavior described below to apply in typical deep learning settings, when training sufficiently wide networks using SGD.
Lazy phase: $\eta < \eta_{\text{crit}} \equiv 2/\lambda_0$. For sufficiently small learning rate, the curvature $\lambda_t$ at training step $t$ remains constant during the initial part of training. The model behaves (loosely) as a model linearized about its initial parameters (Lee et al., 2019); this becomes exact in the infinite width limit, where these dynamics are sometimes called lazy training (Jacot et al., 2018; Lee et al., 2019; Du et al., 2019; Li & Liang, 2018; Zou et al., 2018; Allen-Zhu et al., 2019; Chizat et al., 2019; Dyer & Gur-Ari, 2020). For a discussion of trainability and the connection to the NTK in the lazy phase see Xiao et al. (2019).
Catapult phase: $\eta_{\text{crit}} < \eta < \eta_{\max}$. In this phase, the curvature at initialization is too high for training to converge to a nearby point, and the linear approximation quickly breaks down. Optimization begins with a period of exponential growth in the loss, coupled with a rapid decrease in curvature, until curvature stabilizes at a value $\lambda_{\text{final}} < 2/\eta$. Once the curvature drops below $2/\eta$, training converges, ultimately reaching a minimum that is flatter than those found in the lazy phase. This initial period lasts for a number of training steps of order $\log(n)$, where $n$ is the network width, and is therefore quite short for realistic networks (often lasting less than a single epoch). Optimal performance is often achieved when the initial learning rate is in this range. The gradient descent dynamics in this phase are visualized in SM Figure S1 and in Figure 1.
The maximum learning rate is approximately given by $\eta_{\max} \approx c/\lambda_0$, where $c$ is an architecture-dependent constant. Empirically, we find that this constant depends strongly on the non-linearity but only weakly on other aspects of the architecture. For networks with ReLU non-linearity we find empirically that $c \approx 12$. For the theoretical model, we show that $c = 4$.
Divergent phase: $\eta > \eta_{\max}$. When the learning rate is above the maximum learning rate of the model, the loss diverges and the model does not train.
We now present our main theoretical result, an analysis of gradient descent dynamics for a neural network with large but finite width.
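Given a network function $f(x)$ with model parameters $\theta$, and a training set $\{(x_\alpha, y_\alpha)\}_{\alpha=1}^{m}$, the MSE loss is
$$L = \frac{1}{2m}\sum_{\alpha=1}^{m}\left(f(x_\alpha) - y_\alpha\right)^2. \quad (1)$$
The NTK is defined by
$$\Theta(x, x') = \frac{1}{m}\sum_{\mu}\frac{\partial f(x)}{\partial \theta_\mu}\,\frac{\partial f(x')}{\partial \theta_\mu}. \quad (2)$$
We denote by $\lambda_t$ the maximum eigenvalue of the kernel evaluated on the training set at training step $t$; in particular, $\lambda_0$ is its value at initialization. In large width models, $\lambda_t$ provides a local measure of the loss landscape curvature that is similar to the top eigenvalue of the Hessian (Dyer & Gur-Ari, 2020).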
In this section, we will consider a network with one hidden layer and linear activations, where the network function $f$ is given by
$$f(x) = n^{-1/2}\, v^T u\, x. \quad (3)$$
Here $n$ is the width (number of neurons in the hidden layer), $v \in \mathbb{R}^n$ and $u \in \mathbb{R}^{n \times d}$ are the model parameters (collectively denoted $\theta$), and $x \in \mathbb{R}^d$ is the training input. At initialization, the weights are drawn from $\mathcal{N}(0, 1)$.
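As a concrete reference for equations (2) and (3), the following pure-Python sketch evaluates the model and its kernel on a small training set and estimates the top eigenvalue by power iteration. It assumes the hidden weights are stored as a list of $n$ rows of length $d$; the helper names (`init_params`, `forward`, `ntk`, `top_eigenvalue`) are hypothetical, and power iteration is just one convenient way to get the top eigenvalue, not necessarily what was used for the experiments. For this model $\partial f/\partial v = n^{-1/2} u x$ and $\partial f/\partial u = n^{-1/2} v x^T$, which gives the closed form used in `ntk`.

```python
import math
import random


def init_params(n, d, rng):
    """Draw u (n rows of length d) and v (length n) with i.i.d. N(0, 1) entries."""
    u = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return u, v


def dot(p, q):
    return sum(a * b for a, b in zip(p, q))


def forward(u, v, x):
    """Network function (3): f(x) = n^{-1/2} v^T u x."""
    h = [dot(row, x) for row in u]  # hidden layer h = u x
    return dot(v, h) / math.sqrt(len(v))


def ntk(u, v, xs):
    """Kernel (2) on the inputs xs.

    With df/dv = n^{-1/2} u x and df/du = n^{-1/2} v x^T:
    Theta_ab = (|v|^2 x_a.x_b + (u x_a).(u x_b)) / (n m).
    """
    n, m = len(v), len(xs)
    v2 = dot(v, v)
    hs = [[dot(row, x) for row in u] for x in xs]
    return [[(v2 * dot(xa, xb) + dot(ha, hb)) / (n * m)
             for xb, hb in zip(xs, hs)]
            for xa, ha in zip(xs, hs)]


def top_eigenvalue(mat, iters=500):
    """Largest eigenvalue of a symmetric positive semi-definite matrix (power iteration)."""
    vec = [1.0] * len(mat)
    lam = 0.0
    for _ in range(iters):
        w = [dot(row, vec) for row in mat]
        lam = math.sqrt(dot(w, w))
        if lam == 0.0:
            break
        vec = [wi / lam for wi in w]
    return lam


if __name__ == "__main__":
    rng = random.Random(0)
    u, v = init_params(n=2000, d=1, rng=rng)
    # Warmup setting of Section 2.1: a single input x = 1, so Theta is the scalar lambda.
    lam0 = top_eigenvalue(ntk(u, v, [[1.0]]))
    print(f"lambda_0 = {lam0:.3f}")  # close to 2 for N(0, 1) weights at large n
```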
2.1. Warmup: a simplified model
Before analyzing the dynamics of the model, we analyze a simpler setting which captures the most important aspects of the full solution. Consider a dataset with 1D inputs, and with a single training sample $x = 1$ with label $y = 0$. The network function evaluated on this input is then $f = n^{-1/2} v^T u$, with $u, v \in \mathbb{R}^n$, and the loss is $L = f^2/2$. The gradient descent equations at training step $t$ are
$$u_{t+1} = u_t - \eta\, n^{-1/2} f_t\, v_t, \qquad v_{t+1} = v_t - \eta\, n^{-1/2} f_t\, u_t. \quad (4)$$
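Next, consider the update equations in function space. These can be written in terms of the Neural Tangent Kernel. For this model, the kernel evaluated on the training set is a scalar which is equal to $\lambda$, its top eigenvalue, and is given by
$$\lambda_t = n^{-1}\left(|u_t|^2 + |v_t|^2\right). \quad (5)$$
At initialization, both $f$ and $\lambda$ scale as $n^0$ with width. The following update equations for $f$ and $\lambda$ at step $t$ can be derived from (4):
$$f_{t+1} = \left(1 - \eta\lambda_t + \frac{\eta^2 f_t^2}{n}\right) f_t, \quad (6)$$
$$\lambda_{t+1} = \lambda_t + \frac{\eta f_t^2}{n}\left(\eta\lambda_t - 4\right). \quad (7)$$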
It is important to note that these are the exact update equations for this model, and that no higher-order terms were neglected. We now analyze these dynamical equations assuming the width $n$ is large. Two learning rates that will be important in the analysis are $\eta_{\text{crit}} \equiv 2/\lambda_0$ and $\eta_{\max} \equiv 4/\lambda_0$. In terms of the notation introduced above, the architecture-dependent constant that determines the maximum learning rate in this model is $c = 4$.
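Because (6) and (7) are exact, they can be checked against a literal parameter-space gradient descent step (4). The sketch below is a hypothetical pure-Python check (the function names are ours, not from the paper); it also computes $\eta_{\text{crit}} = 2/\lambda_0$ and $\eta_{\max} = 4/\lambda_0$ from a sampled initialization.

```python
import math
import random


def f_and_lambda(u, v):
    """Warmup model: f = n^{-1/2} v.u and lambda = (|u|^2 + |v|^2) / n, eq. (5)."""
    n = len(u)
    f = sum(a * b for a, b in zip(u, v)) / math.sqrt(n)
    lam = (sum(a * a for a in u) + sum(b * b for b in v)) / n
    return f, lam


def gd_step(u, v, eta):
    """One gradient descent step (4) on the loss L = f^2 / 2."""
    n = len(u)
    f, _ = f_and_lambda(u, v)
    s = eta * f / math.sqrt(n)
    # Both updates use the parameters from step t.
    return [a - s * b for a, b in zip(u, v)], [b - s * a for a, b in zip(u, v)]


def function_space_step(f, lam, eta, n):
    """Closed-form updates (6) and (7)."""
    f_next = (1.0 - eta * lam + eta**2 * f**2 / n) * f
    lam_next = lam + (eta * f**2 / n) * (eta * lam - 4.0)
    return f_next, lam_next


def critical_learning_rates(lam0):
    """eta_crit = 2 / lambda_0 and eta_max = 4 / lambda_0 for this model (c = 4)."""
    return 2.0 / lam0, 4.0 / lam0


if __name__ == "__main__":
    rng = random.Random(1)
    n = 500
    u = [rng.gauss(0.0, 1.0) for _ in range(n)]
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    f, lam = f_and_lambda(u, v)
    eta_crit, eta_max = critical_learning_rates(lam)
    eta = 0.5 * (eta_crit + eta_max)  # a learning rate inside the catapult range
    for _ in range(5):
        f_pred, lam_pred = function_space_step(f, lam, eta, n)
        u, v = gd_step(u, v, eta)
        f, lam = f_and_lambda(u, v)
        print(f"|f - f_pred| = {abs(f - f_pred):.1e}, |lambda - lambda_pred| = {abs(lam - lam_pred):.1e}")
```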
2.1.1. LAZY PHASE
Taking the strict infinite width limit, equations (6) and (7) become
$$f_{t+1} = \left(1 - \eta\lambda_t\right) f_t, \qquad \lambda_{t+1} = \lambda_t. \quad (8)$$
When $\eta < \eta_{\text{crit}}$, $\lambda_t$ remains constant throughout training. This is a special case of NTK dynamics, where the kernel is constant and the network evolves as a linear model (Lee et al., 2019). The function and the loss both shrink to zero because the multiplicative factor obeys $|1 - \eta\lambda_t| < 1$. This convergence happens in $O(n^0)$ steps.
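The lazy phase can be seen by iterating the exact recursion (6) and (7) directly. This sketch starts from assumed typical initial values $f_0 = 1$ and $\lambda_0 = 2$ (the expected value of (5) for $\mathcal{N}(0,1)$ weights) rather than sampling parameters; `iterate_warmup` is a hypothetical helper name.

```python
def iterate_warmup(f, lam, eta, n, steps):
    """Iterate the exact warmup recursion (6)-(7).

    Returns the per-step losses L_t = f_t^2 / 2 and kernel values lambda_t.
    """
    losses, lams = [0.5 * f * f], [lam]
    for _ in range(steps):
        # Both right-hand sides use the values from step t.
        f, lam = (
            (1.0 - eta * lam + eta**2 * f**2 / n) * f,
            lam + (eta * f**2 / n) * (eta * lam - 4.0),
        )
        losses.append(0.5 * f * f)
        lams.append(lam)
    return losses, lams


if __name__ == "__main__":
    n, lam0 = 1000, 2.0
    eta = 0.8  # eta * lambda_0 = 1.6 < 2, i.e. eta < eta_crit = 1
    losses, lams = iterate_warmup(1.0, lam0, eta, n, steps=100)
    print(f"final loss {losses[-1]:.1e}, lambda changed by {lams[-1] - lam0:.1e}")
```

At this width the kernel moves only by an amount of order $1/n$, while the loss decays geometrically with factor $(1 - \eta\lambda_0)^2$.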
2.1.2. CATAPULT PHASE
When $\eta_{\text{crit}} < \eta < \eta_{\max}$, the loss diverges in the infinite width limit. Indeed, from (8) we see that the kernel is constant in the limit, while $f$ receives multiplicative updates where $|1 - \eta\lambda_t| > 1$. This is the well known instability of gradient descent dynamics for linear models with MSE loss. However, the underlying model is not linear in its parameters, and finite width contributions turn out to be important. We therefore relax the infinite width limit and analyze equations (6, 7) for large but finite width, $n \gg 1$.
First, note that $\eta\lambda_0 < 4$ by assumption, and therefore the (additive) kernel updates are negative for all $t$. During early training, $|f_t|$ grows (as in the infinite width limit) while $\lambda_t$ remains constant up to small $O(n^{-1})$ updates. After $t \sim \log(n)$ steps, $|f_t|$ grows to order $n^{1/2}$. At this point, the kernel updates are no longer negligible because $\eta f_t^2/n$ is of order $n^0$. The kernel $\lambda_t$ receives negative, non-negligible updates while both $|f_t|$ and the loss continue to grow (for now, we ignore the term in (6) with an explicit $1/n$ dependence). This continues until the kernel is sufficiently small that the condition $\eta\lambda_t < 2$ is met. We call this curvature-reduction effect the catapult effect. Beyond this point, $|f_t|$ shrinks, and the loss converges to a global minimum. The number of steps until optimization converges grows with the width as $\log(n)$.
It remains to show that the term in (6) with an explicit $n^{-1}$ dependence does not affect these conclusions. Once $|f_t|$ grows to order $n^{1/2}$, this term is no longer negligible and can cause the multiplicative factor in front of $f_t$ to become smaller than 1 in absolute value, causing $|f_t|$ to start shrinking. However, once $|f_t|$ shrinks sufficiently this term again becomes negligible. Therefore, the loss will not converge to zero unless the curvature eventually drops below $2/\eta$. Conversely, notice that this term cannot cause $|f_t|$ to diverge for learning rates below $\eta_{\max}$. Indeed, if this were to happen then equation (7) would drive $\lambda_t$ to negative values, leading to a contradiction. This completes the analysis in this phase.
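Iterating the same recursion at a learning rate in the catapult range shows the stages described above: growth of $|f_t|$ at almost constant $\lambda_t$, a sharp drop of $\lambda_t$ once $f_t^2/n$ becomes of order one, and convergence once $\eta\lambda_t < 2$. Again a sketch with assumed values $f_0 = 1$, $\lambda_0 = 2$, $n = 10^4$ and $\eta\lambda_0 = 2.5$; `catapult_summary` is a hypothetical name.

```python
def catapult_summary(n, eta, lam0=2.0, f0=1.0, steps=300):
    """Run the exact warmup recursion (6)-(7) and summarize the trajectory."""
    f, lam = f0, lam0
    losses, lams = [0.5 * f * f], [lam]
    for _ in range(steps):
        f, lam = (
            (1.0 - eta * lam + eta**2 * f**2 / n) * f,
            lam + (eta * f**2 / n) * (eta * lam - 4.0),
        )
        losses.append(0.5 * f * f)
        lams.append(lam)
    peak_step = max(range(len(losses)), key=losses.__getitem__)
    return {
        "initial_loss": losses[0],
        "peak_loss": losses[peak_step],
        "peak_step": peak_step,
        "final_loss": losses[-1],
        "final_lambda": lams[-1],
    }


if __name__ == "__main__":
    eta = 1.25  # eta * lambda_0 = 2.5, between eta_crit = 1 and eta_max = 2
    s = catapult_summary(n=10_000, eta=eta)
    print(s)
    print("final lambda below 2/eta:", s["final_lambda"] < 2.0 / eta)
```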
Let us make a few comments about the catapult phase.
It is important for the analysis that we take a modified large width limit, in which the number of training steps grows like $\log(n)$ as $n$ becomes large. This is different from the large width limit commonly studied in the literature, in which the number of steps is kept fixed as the width is taken large. When using this modified limit, the analysis above holds even in the limit $n \to \infty$. Note as well that the catapult effect takes place over $\log(n)$ steps, and for practical networks will occur within the first 100 steps or so of training.
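The $\log(n)$ duration can be checked with the same recursion: the step at which the loss peaks should shift by a roughly constant amount each time $n$ is multiplied by a fixed factor. A sketch, assuming $f_0 = 1$, $\lambda_0 = 2$ and $\eta\lambda_0 = 2.5$, with `peak_step` a hypothetical helper.

```python
import math


def peak_step(n, eta=1.25, lam0=2.0, f0=1.0, steps=500):
    """Step at which the loss f_t^2 / 2 peaks under the warmup recursion (6)-(7)."""
    f, lam = f0, lam0
    best_step, best_loss = 0, 0.5 * f * f
    for t in range(1, steps + 1):
        f, lam = (
            (1.0 - eta * lam + eta**2 * f**2 / n) * f,
            lam + (eta * f**2 / n) * (eta * lam - 4.0),
        )
        if 0.5 * f * f > best_loss:
            best_step, best_loss = t, 0.5 * f * f
    return best_step


if __name__ == "__main__":
    for n in (10**2, 10**4, 10**6, 10**8):
        # Early growth is |f_{t+1}| ~ |1 - eta * lambda_0| |f_t| = 1.5 |f_t|, so reaching
        # f^2 ~ n takes about log(n) / (2 log 1.5) steps.
        print(n, peak_step(n), round(math.log(n) / (2 * math.log(1.5)), 1))
```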
In the catapult phase, the kernel at the end of training is smaller by an amount of order $n^0$ compared with its value at initialization. The kernel provides a local measure of the loss curvature. Therefore, the minima that SGD finds in the catapult phase are flatter than those it finds in the lazy phase. Contrast this situation, in which the kernel receives non-negligible updates, with the conclusions of Jacot et al. (2018) where the kernel is constant throughout training. The difference is due to the large learning rate, which leads to a breakdown of the linearized approximation even at large width.
Figure 2 illustrates the dynamics in the catapult phase. For learning rates $\eta_{\text{crit}} < \eta < \eta_{\max}$ we observe the catapult effect: the loss goes up before converging to zero. The curvature exhibits the expected sharp transitions as a function of the learning rate: it is constant in the lazy phase, decreases in the catapult phase, and diverges for $\eta > \eta_{\max}$.
Figure 2. Empirical results for the gradient descent dynamics of the warmup model, for which $\eta_{\text{crit}} = 2/\lambda_0$. (a) Training loss for different learning rates. (b) Maximum NTK eigenvalue as a function of time. For $\eta > \eta_{\text{crit}}$, $\lambda_t$ decreases rapidly to a fixed value. (c) Maximum NTK eigenvalue at a fixed late time, as a function of the learning rate. The shaded area indicates learning rates for which training diverges empirically. The results are presented as a function of $t \cdot \eta$ (rather than $t$) for convenience.
2.1.3. DIVERGENT PHASE
Completing the analysis of this model, when $\eta > \eta_{\max}$ the loss diverges because the kernel receives positive updates, accelerating the rate of growth of the function. Therefore, $\eta_{\max} = 4/\lambda_0$ is the maximum learning rate of the model.
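Putting the three cases together, a learning rate sweep over the recursion (6) and (7) reproduces the qualitative picture of Figure 2(c): $\lambda$ is unchanged below $\eta_{\text{crit}}$, decreases between $\eta_{\text{crit}}$ and $\eta_{\max}$, and training blows up above $\eta_{\max}$. The classification thresholds (a 5% change in $\lambda$, a loss above $10^{12}$) are arbitrary choices for this sketch, as are $f_0 = 1$, $\lambda_0 = 2$ and $n = 10^4$.

```python
import math


def classify_phase(eta, n=10_000, lam0=2.0, f0=1.0, steps=500):
    """Label a learning rate as 'lazy', 'catapult' or 'divergent' for the warmup model."""
    f, lam = f0, lam0
    for _ in range(steps):
        f, lam = (
            (1.0 - eta * lam + eta**2 * f**2 / n) * f,
            lam + (eta * f**2 / n) * (eta * lam - 4.0),
        )
        loss = 0.5 * f * f
        if not math.isfinite(loss) or loss > 1e12:
            return "divergent"
    if abs(lam - lam0) < 0.05 * lam0:
        return "lazy"
    return "catapult"


if __name__ == "__main__":
    # With lambda_0 = 2: eta_crit = 1 and eta_max = 2.
    for eta in (0.5, 0.9, 1.1, 1.25, 1.5, 2.5):
        print(f"eta = {eta:4.2f}: {classify_phase(eta)}")
```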
2.2. Full model
We now turn to analyzing the model presented at the beginning of this section, with d-dimensional inputs and m training samples with general labels. The full analysis is presented in SM Section D.1; here we summarize the argument. The conclusions are essentially the same as those of the warmup model.
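We introduce the notation $f_\alpha \equiv f(x_\alpha)$ for the function evaluated on a training sample, $\tilde f_\alpha \equiv f_\alpha - y_\alpha$ for the error, and $\Theta_{\alpha\beta} \equiv \Theta(x_\alpha, x_\beta)$ for the kernel elements. We will treat $\tilde f$ and $y$ evaluated on the training set as vectors in $\mathbb{R}^m$, whose elements are $\tilde f_\alpha$ and $y_\alpha$. Consider the following update equation for the error, which can be derived from the update equations for the $u, v$ parameters:
$$\tilde f_{t+1} = \left(\mathbb{1} - \eta\Theta_t\right)\tilde f_t + \frac{\eta^2}{n^{3/2}}\left(\zeta_t^T u_t^T v_t\right) X\zeta_t, \qquad \zeta_t \equiv \frac{1}{m}\sum_{\alpha=1}^{m}\tilde f_{t,\alpha}\, x_\alpha, \quad (9)$$
where $X \in \mathbb{R}^{m \times d}$ is the matrix whose rows are the training inputs $x_\alpha^T$. Note that this is the exact update equation for this model; no higher-order terms were neglected.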
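Equation (9) can be checked numerically in the same way as (6) and (7): take one full-batch gradient descent step on the parameters of model (3) and compare the new error vector with the right-hand side of (9). The sketch below does this in pure Python with small, arbitrary sizes; all function names are hypothetical. For $m = 1$, $d = 1$, $x = 1$ and $y = 0$ we have $\zeta_t = f_t$, and (9) reduces to (6).

```python
import math
import random


def dot(p, q):
    return sum(a * b for a, b in zip(p, q))


def matvec(u, x):
    return [dot(row, x) for row in u]


def errors(u, v, xs, ys):
    """Error vector f(x_a) - y_a for model (3)."""
    n = len(v)
    return [dot(v, matvec(u, x)) / math.sqrt(n) - y for x, y in zip(xs, ys)]


def zeta(err, xs):
    """zeta = (1/m) sum_a err_a x_a, a vector in R^d."""
    m, d = len(xs), len(xs[0])
    return [sum(e * x[j] for e, x in zip(err, xs)) / m for j in range(d)]


def gd_step(u, v, xs, ys, eta):
    """Full-batch GD on loss (1): dL/du = n^{-1/2} v zeta^T, dL/dv = n^{-1/2} u zeta."""
    s = eta / math.sqrt(len(v))
    z = zeta(errors(u, v, xs, ys), xs)
    uz = matvec(u, z)
    u_new = [[uij - s * vi * zj for uij, zj in zip(row, z)] for row, vi in zip(u, v)]
    v_new = [vi - s * uzi for vi, uzi in zip(v, uz)]
    return u_new, v_new


def rhs_eq9(u, v, xs, ys, eta):
    """Right-hand side of (9): (1 - eta Theta) err + eta^2 n^{-3/2} (zeta^T u^T v) X zeta."""
    n, m = len(v), len(xs)
    err = errors(u, v, xs, ys)
    z = zeta(err, xs)
    v2 = dot(v, v)
    hs = [matvec(u, x) for x in xs]
    theta = [[(v2 * dot(xa, xb) + dot(ha, hb)) / (n * m) for xb, hb in zip(xs, hs)]
             for xa, ha in zip(xs, hs)]
    coef = eta**2 / n**1.5 * dot(v, matvec(u, z))
    return [err[a] - eta * dot(theta[a], err) + coef * dot(xs[a], z) for a in range(m)]


if __name__ == "__main__":
    rng = random.Random(0)
    n, d, m, eta = 60, 3, 4, 0.7
    u = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    v = [rng.gauss(0, 1) for _ in range(n)]
    xs = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(m)]
    ys = [rng.gauss(0, 1) for _ in range(m)]
    predicted = rhs_eq9(u, v, xs, ys, eta)
    u1, v1 = gd_step(u, v, xs, ys, eta)
    actual = errors(u1, v1, xs, ys)
    print(max(abs(p - a) for p, a in zip(predicted, actual)))  # rounding-level difference
```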
We again take the modified large width limit $n \to \infty$ with $t \sim \log(n)$, allowing the number of steps to scale logarithmically in the width. At initialization, $\tilde f_0$, $\Theta_0$, and $\lambda_0$ are all of order $n^0$. We now analyze the gradient descent dynamics as a function of the learning rate.
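Let $\lambda_t$ be the maximum eigenvalue of the kernel at step $t$, and again define $\eta_{\text{crit}} \equiv 2/\lambda_0$ and $\eta_{\max} \equiv 4/\lambda_0$. When $\eta < \eta_{\text{crit}}$, the error $\tilde f_t$ shrinks to zero in $O(n^0)$ time while the kernel receives $O(n^{-1})$ corrections. Therefore, in the limit the kernel remains constant until convergence. This is a special case of the NTK result (Jacot et al., 2018), and the model evolves as a linear model.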
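Next, suppose that $\eta_{\text{crit}} < \eta < \eta_{\max}$. Early during training $\tilde f_t$ grows, with the fastest growth taking place along the direction of the top kernel eigenvector, $e_{\max}$. During this part of training the kernel receives $O(n^{-1})$ updates, and so $e_{\max}$ does not change much. As a result, $\tilde f_t$ becomes aligned with $e_{\max}$. In addition, $f_t$ becomes close to $\tilde f_t$, because $|\tilde f_t|$ grows while the label is constant. We therefore consider the following approximate update equations for $\tilde f_t$,
$$\tilde f_{t+1} \approx \left(1 - \eta\lambda_t + \frac{\eta^2\|\zeta_t\|^2}{n}\right)\tilde f_t, \quad (10)$$
and for the maximum eigenvalue $\lambda_t = e_{\max}^T\Theta_t e_{\max}$, which can be approximated by
$$\lambda_{t+1} \approx \lambda_t + \frac{\eta\|\zeta_t\|^2}{n}\left(\eta\lambda_t - 4\right). \quad (11)$$
We note in passing the similarity between these equations and (6), (7). We see that once $\tilde f_t$ and $\zeta_t$ become of order $n^{1/2}$, $\lambda_t$ receives non-negligible negative corrections of order $n^0$. This evolution continues until $\eta\lambda_t < 2$, after which the error converges to zero. Finally, if $\eta > \eta_{\max}$, the error grows while $\lambda_t$ receives positive updates, and the loss diverges. This concludes the discussion of the theoretical model; further details can be found in Section 4 and in SM Section D.1.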
In this section we test the extent to which the behavior of our theoretical model describes the dynamics of deep networks in practical settings. The theoretical results of Section 2, describing distinct learning rate phases, are not guaranteed to hold beyond the model analyzed there. We treat these results as predictions to be tested empirically, including the values of the learning rates that separate the three phases.
In a variety of deep learning settings, we find clear evidence of the different phases predicted by the model. The experiments all use MSE loss, sufficiently wide networks, and SGD. Parameters such as network architecture, choice of non-linearity, weight parameterization, and regularization do not significantly affect this conclusion.
In terms of the learning rates that determine the location of the transitions, the only modification needed to obtain good agreement with experiment is to replace the theoretical maximum learning rate, $\eta_{\max} = 4/\lambda_0$, with a 1-parameter function $\eta_{\max} = c/\lambda_0$, where $c$ is an architecture-dependent constant. We find that $c \approx 12$ for all networks that use ReLU non-linearity, and it seems this parameter depends only weakly on other details of the architecture. We find the level of agreement with the experiments surprising, given that our theoretical model involves a shallow network without non-linearities.
Building on the observed correlation between lower curvature and generalization performance (Keskar et al., 2016; Jiang et al., 2020), we conjecture that optimal performance occurs in the large learning rate (catapult) phase, where the loss converges to a flatter minimum. For a fixed computational budget, we find that this conjecture holds in all cases we tried. Even when comparing models trained with different learning rates for a fixed amount of physical time $t \cdot \eta$, we find that performance of models trained in the catapult phase either matches or exceeds that of models trained in the lazy phase.
3.1. Early time curvature dynamics
Our theoretical model makes detailed predictions for the gradient descent evolution of $\lambda_t$, the top eigenvalue of the NTK. Here we test these predictions against empirical results in a variety of deep learning models (see the Supplement for additional experimental results).
Figure 3 shows $\lambda_t$ during the early part of training for two deep learning settings. The results are compared against the theoretical predictions of a phase transition at $\eta_{\text{crit}} = 2/\lambda_0$, and a maximum learning rate of $\eta_{\max} \approx 12/\lambda_0$. Here $\lambda_0$ is the top eigenvalue of the empirical NTK at initialization.
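Applying these predictions only requires $\lambda_0$. As an illustrative sketch (not the setup used for Figure 3), the code below computes it for a one hidden layer ReLU network $f(x) = n^{-1/2} v^T \mathrm{relu}(ux)$ with $\mathcal{N}(0,1)$ weights, a parameterization we assume here for concreteness, using the analytic Jacobian and the kernel normalization of (2). It then returns $\eta_{\text{crit}} = 2/\lambda_0$ and the empirical ReLU estimate $\eta_{\max} \approx 12/\lambda_0$. Names and sizes are hypothetical.

```python
import math
import random


def relu(z):
    return z if z > 0.0 else 0.0


def forward(u, v, x):
    """f(x) = n^{-1/2} v . relu(u x)."""
    pre = [sum(a * b for a, b in zip(row, x)) for row in u]
    return sum(vi * relu(p) for vi, p in zip(v, pre)) / math.sqrt(len(v))


def jacobian(u, v, x):
    """Gradient of f(x) flattened as [df/dv_i for i] + [df/du_ij for i, j] (row-major)."""
    s = 1.0 / math.sqrt(len(v))
    pre = [sum(a * b for a, b in zip(row, x)) for row in u]
    grad_v = [s * relu(p) for p in pre]
    grad_u = [s * vi * (1.0 if p > 0.0 else 0.0) * xj for vi, p in zip(v, pre) for xj in x]
    return grad_v + grad_u


def empirical_ntk(u, v, xs):
    """Theta_ab = (1/m) J(x_a) . J(x_b), matching the normalization of (2)."""
    jac = [jacobian(u, v, x) for x in xs]
    m = len(xs)
    return [[sum(p * q for p, q in zip(ja, jb)) / m for jb in jac] for ja in jac]


def top_eigenvalue(mat, iters=1000):
    """Power iteration for a symmetric positive semi-definite matrix."""
    vec, lam = [1.0] * len(mat), 0.0
    for _ in range(iters):
        w = [sum(a * b for a, b in zip(row, vec)) for row in mat]
        lam = math.sqrt(sum(wi * wi for wi in w))
        vec = [wi / lam for wi in w]
    return lam


def phase_boundaries(lam0, c=12.0):
    """eta_crit = 2 / lambda_0; eta_max ~ c / lambda_0 (c ~ 12 is the empirical ReLU value)."""
    return 2.0 / lam0, c / lam0


if __name__ == "__main__":
    rng = random.Random(0)
    n, d, m = 256, 4, 8
    u = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    v = [rng.gauss(0, 1) for _ in range(n)]
    xs = [[rng.gauss(0, 1) / math.sqrt(d) for _ in range(d)] for _ in range(m)]
    lam0 = top_eigenvalue(empirical_ntk(u, v, xs))
    eta_crit, eta_max = phase_boundaries(lam0)
    print(f"lambda_0 = {lam0:.3f}, eta_crit = {eta_crit:.3f}, eta_max ~ {eta_max:.3f}")
```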
For learning rates $\eta < \eta_{\text{crit}}$, we find that $\lambda_t$ is independent of the learning rate and constant throughout training, as expected in the lazy phase. For $\eta_{\text{crit}} < \eta < \eta_{\max}$ we find that $\lambda_t$ decreases during training to below $2/\eta$, matching the predicted behavior in the catapult phase (note that in the Wide ResNet example, $\lambda_t$ initially increases before reaching its stable value).
The large learning rate behavior predicted by the model appears to persist up to the maximum learning rate, which is larger in these experiments than in the theoretical model. In these and other experiments involving ReLU networks, we find that $\eta_{\max} \approx 12/\lambda_0$ is a good predictor of the maximum learning rate (in SM C.4 we discuss other non-linearities). We conjecture that this is the typical maximum learning rate of networks with ReLU non-linearities.
Figure 3 also shows the loss initially increasing before converging in the catapult phase, confirming another prediction of the model. This transient behavior is very short, taking less than 10 steps to complete.
Figure 3. Early time dynamics. (a,b,c) A 3 hidden layer fully-connected network with ReLU non-linearity trained on MNIST. (d,e,f) Wide ResNet 28-10 trained on CIFAR-10. Both networks are trained with vanilla SGD; for more experimental details see SM Section A. (a,d) Early time dynamics of the training loss for learning rates in the linear and catapult phases. (b,e) Early time dynamics of the curvature for learning rates in the linear and catapult phases. (c,f) $\lambda_t$ measured at a fixed early training step (chosen separately for the FC network and the WRN), as a function of learning rate, compared with theoretical predictions for the locations of phase transitions. Training diverges for learning rates in the shaded region.
3.2. Generalization performance
We now consider the performance of trained models in the different phases discussed in this work. Keskar et al. (2016) observed a correlation between the flatness of a minimum found by SGD and the generalization performance (see Jiang et al. (2020) for additional empirical confirmation of this correlation). In this work, we showed that the minima SGD finds are flatter in the catapult phase, as measured by the top kernel eigenvalue. Our measure of flatness differs from that of Keskar et al. (2016), but we expect that these measures are correlated.
We therefore conjecture that optimal performance is often obtained for learning rates above $\eta_{\text{crit}}$ and below the maximum learning rate.
In this section we test this conjecture empirically. We find that performance in the large learning rate range always matches or exceeds the performance when $\eta < \eta_{\text{crit}}$. For a fixed compute budget, we find that the best performance is always found in the catapult phase.
Figure 4 shows the accuracy as a function of the learning rate for a fully-connected ReLU network trained on a subset of MNIST. We find that the optimal performance is achieved above $\eta_{\text{crit}}$ and close to $\eta_{\max} \approx 12/\lambda_0$, the expected maximum learning rate.
Figure 4. Final accuracy versus learning rate for a fully-connected 1 hidden layer ReLU network, trained on 512 samples of MNIST with full-batch gradient descent until training accuracy reaches 1 or 700k physical steps (see SM Section A for details). We used a subset of samples to accentuate the performance difference between phases. The optimal performance is obtained when the learning rate is above $\eta_{\text{crit}}$, and close to $\eta_{\max}$.
Next, Figure 5 shows the performance of a convolutional network and a Wide ResNet (WRN) trained on CIFAR-10. The experimental setup, which we now describe, was chosen to ensure a fair comparison of the performance across different learning rates. Each network is trained with a different initial learning rate, followed by a decay at a fixed physical time $t \cdot \eta$ to the same final learning rate. This schedule is introduced in order to ensure that all experiments have the same level of SGD noise toward the end of training.
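The fixed physical time protocol amounts to a simple learning rate schedule: with physical time defined as $t \cdot \eta$, a run with initial learning rate $\eta$ takes $T/\eta$ steps before the decay, after which every run continues at the same final learning rate. A minimal sketch; `physical_time_schedule` and its arguments are hypothetical, and the number of steps after the decay is left as a free parameter.

```python
def physical_time_schedule(eta, physical_time, final_lr, extra_steps):
    """Per-step learning rates: about physical_time / eta steps at eta, then extra_steps at final_lr.

    Rounding T / eta to the nearest integer keeps the physical time of the first
    stage, sum(lr) = steps * eta, as close as possible to T.
    """
    steps = max(1, int(round(physical_time / eta)))
    return [eta] * steps + [final_lr] * extra_steps


if __name__ == "__main__":
    for eta in (0.05, 0.1, 0.2):
        sched = physical_time_schedule(eta, physical_time=10.0, final_lr=0.01, extra_steps=100)
        first_stage = sched[: len(sched) - 100]
        print(eta, len(first_stage), round(sum(first_stage), 6))
```

Smaller initial learning rates therefore get proportionally more steps, which removes the bias toward large learning rates that a fixed step budget introduces.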
We present results using two different stopping conditions. In Figures 5a and 5c, all models were trained for a fixed number of training steps. We find a significant performance gap between small and large learning rates, with the optimal learning rate above $\eta_{\text{crit}}$ and close to $\eta_{\max} \approx 12/\lambda_0$. Beyond this learning rate, performance drops sharply.
The fixed compute stopping condition, while of practical interest, biases the results in favor of large learning rates. Indeed, in the limit of small learning rate, training for a fixed number of steps will keep the model close to initialization. To control for this, in Figures 5b and 5d models were trained for the same amount of physical time $t \cdot \eta$. For the CNN of Figure 5b, decaying the learning rate does not have a significant effect on performance, and we observe that performance is flat up to the maximum learning rate; in this case there is no correlation between our measure of curvature and generalization performance. Figure 5d shows the analogous experiment for WRN. When decaying the learning rate toward the end of training to control for SGD noise, we find that optimal performance is achieved above $\eta_{\text{crit}}$. In all these cases, $\eta_{\max} \approx 12/\lambda_0$ is a good predictor of the maximal learning rate, despite significant differences in the architectures. Notice that by tuning the learning rate to the catapult phase, we are able to achieve performance using MSE loss, and without momentum, that is competitive with the best reported results for this architecture.
Figure 5. Test accuracy vs learning rate for (a,b) a CNN trained on CIFAR-10 using SGD with batch size 256 and regularization, and (c,d) WRN28-10 trained on CIFAR-10 using SGD with batch size 1024, regularization, and data augmentation; see SM Section A for details. (a,c) have a fixed compute budget: (a) 437k steps and (c) 12k steps. (b,d) have been evolved for a fixed amount of physical time: (b) was evolved for $475/\eta$ steps with learning rate $\eta$ (purple) and then for 50k more steps at a fixed final learning rate (red); (d) was evolved for a fixed physical time with learning rate $\eta$ (purple) and then for 4800 more steps at learning rate 0.035 (red). In all cases, optimal performance is achieved above $\eta_{\text{crit}}$ and close to the expected maximum learning rate, in agreement with our predictions.
In SM B.1, we present additional results for WRN on CIFAR-100, with similar conclusions as those for WRN on CIFAR-10.
So far we have focused on the generalization performance and curvature of the large learning rate phase. Here we investigate additional predictions made by our model.
4.1. Restoration of linear dynamics
One striking prediction of the model is that after a period of excursion, the logit differences settle back to O(1) values, the NTK stops changing, and evolution is again well approximated by a linear model with constant kernel at large width.
We speculate that the return to linearity and constancy of the kernel may hold asymptotically in width for more general models for a range of learning rates above $\eta_{\text{crit}}$. We test this by evolving the model for order $\log(n)$ steps until the catapult effect is over, linearizing the model, and comparing the evolution of the two models beyond this point. Figure 6 shows an example of this. At fixed width, the accuracy of the linear and non-linear networks match for a range of learning rates above the transition up to $\eta_{\max}$. We present additional evidence for this asymptotic linearization behavior in the Supplement.
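For the warmup model this linearization test has a closed form. Training the model linearized at step $t_0$ keeps the kernel fixed at $\lambda_{t_0}$, so its function value evolves as $f^{\text{lin}}_{t_0+k} = (1 - \eta\lambda_{t_0})^k f_{t_0}$. The sketch below compares this with the exact recursion (6) and (7), assuming $f_0 = 1$, $\lambda_0 = 2$, $n = 10^4$ and $\eta\lambda_0 = 2.5$: linearizing at $t_0 = 0$ diverges, while linearizing ten steps after the loss peak tracks the nonlinear model. The helper names and the choice of ten steps are hypothetical.

```python
def nonlinear_trajectory(n, eta, lam0=2.0, f0=1.0, steps=80):
    """Exact warmup recursion (6)-(7): lists of f_t and lambda_t."""
    fs, lams = [f0], [lam0]
    for _ in range(steps):
        f, lam = fs[-1], lams[-1]
        fs.append((1.0 - eta * lam + eta**2 * f**2 / n) * f)
        lams.append(lam + (eta * f**2 / n) * (eta * lam - 4.0))
    return fs, lams


def linearized_trajectory(f_start, lam_start, eta, steps):
    """Model linearized at t0: the kernel is frozen, so f_{k+1} = (1 - eta * lambda_t0) f_k."""
    fs = [f_start]
    for _ in range(steps):
        fs.append((1.0 - eta * lam_start) * fs[-1])
    return fs


if __name__ == "__main__":
    n, eta = 10_000, 1.25
    fs, lams = nonlinear_trajectory(n, eta)
    peak = max(range(len(fs)), key=lambda t: abs(fs[t]))
    t0 = peak + 10  # well after the catapult is over
    lin0 = linearized_trajectory(fs[0], lams[0], eta, steps=30)
    lin_t0 = linearized_trajectory(fs[t0], lams[t0], eta, steps=30)
    print("linearized at step 0, |f| after 30 steps:", abs(lin0[-1]))
    gap = max(abs(a - b) for a, b in zip(lin_t0, fs[t0:t0 + 31])) / abs(fs[t0])
    print(f"linearized at step {t0}: max relative gap {gap:.1e}")
```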
Figure 6. Evidence for linear dynamics after the catapult effect is over. Here we show the same model as in Figure 4, with the addition of a model linearized at step 0 and another linearized at step 10. We observe that the model linearized after 10 steps tracks the non-linear performance in the catapult phase up to $\eta_{\max}$.
4.2. Non-perturbative phase transition
The large width analysis of the small learning rate phase has been the subject of much work. In this phase, at infinite width, the network map evolves as a linear random features model, $f^{\text{lin}}_{t+1} = f^{\text{lin}}_t - \eta\Theta_0\left(f^{\text{lin}}_t - y\right)$, where $f^{\text{lin}}$ is the function of the linearized model. At large but finite width, corrections to this linear evolution can be systematically incorporated via a perturbative expansion (Taylor expansion) around infinite width (Dyer & Gur-Ari, 2020; Huang & Yau, 2019):
$$f_t = f^{\text{lin}}_t + \frac{1}{n} f^{(1)}_t + \frac{1}{n^2} f^{(2)}_t + \cdots \quad (12)$$
The evolution equations (10) and (11) of the solvable model are an example of this. At large width and in the small learning rate phase, the $O(n^{-1})$ terms are suppressed for all times. In contrast, the leading order dynamics of $\tilde f_t$ diverge when $\eta > \eta_{\text{crit}}$, and so the true evolution cannot be described by the linear model. Indeed, the logits grow to $O(n^{1/2})$ and thus all terms in (10) and (11) are of the same order. Similarly, the growth observed empirically in the catapult phase for more general models cannot be described by truncating the series (12) at any order, because the terms all become comparable.
In this work we took a step toward understanding the role of large learning rates in deep learning. We presented a dynamical mechanism that allows deep networks to be trained at larger learning rates than those accessible to their linear counterparts. For MSE loss, linear model training diverges when the learning rate is above the critical value $\eta_{\text{crit}} = 2/\lambda_0$, where $\lambda_0$ is the curvature at initialization. We showed that deep networks can train for larger learning rates by navigating to an area of the landscape that has sufficiently low curvature. Perhaps counterintuitively, training in this regime involves an initial period during which the loss increases before converging to its final, small value. We call this the catapult effect.
5.1. A tractable model illustrating catapult dynamics
These observations are made concrete in our theoretical model, where we fully analyze the gradient descent dynamics as a function of the learning rate. The analysis involves a modified large width limit, in which both the width and training time are taken to be large. Sweeping the learning rate from small to large, and working in the limit $n \to \infty$, we find sharp transitions from a lazy phase where linearized model training is stable, to a catapult phase in which only the full model converges, and finally to a divergent phase in which training is unstable. These transitions have the hallmarks of phase transitions that commonly appear in physical systems such as ferromagnets or water, as one changes parameters such as temperature. In particular, these transitions are non-perturbative: a Taylor series expansion of the linearized model that takes into account finite width corrections is not sufficient to describe the behavior beyond the critical learning rate.
We derive the learning rates at which these transitions occur as a function of the curvature at initialization. We then treat these theoretical results as predictions, to be tested beyond the regime where they are guaranteed to hold, and find good quantitative agreement with empirical results across a variety of realistic deep learning settings.
We find it striking that a relatively simple theoretical model can correctly predict the behavior of realistic deep learning models. In particular, we conjecture that the maximum learning rate is typically a simple function of the curvature at initialization, with a single parameter that seems to depend only on the non-linearity. For ReLU networks, we conjecture that the maximum learning rate is approximately $\eta_{\max} \approx 12/\lambda_0$, which we confirm in many cases.
5.2. Reducing misalignment of activations and gradients
The catapult dynamics for the simplified model in Section 2.1 reduce curvature by shrinking the component of the first layer weights u which is orthogonal to the second layer weights v, and shrinking the component of the second layer weights v which is orthogonal to the first layer weights u. We can rewrite the simplified model in terms of a hidden layer h = ux, where $f = n^{-1/2} v^\top h$. The gradient with respect to this hidden layer is $\partial L/\partial h = n^{-1/2}(f - y)\, v$. These hidden layer gradients thus point in the same direction as v, while the hidden activations h point in the same direction as u. An alternative interpretation of the catapult dynamics is then that they reduce the components of h and $\partial L/\partial h$ which are orthogonal to each other. The catapult dynamics thus serve, in this simplified model, to reduce the misalignment between feedforward activations h and backpropagated gradients $\partial L/\partial h$. We hypothesize that this reduction of misalignment between activations and gradients may be a feature of large learning rates and catapult dynamics in deep, as well as shallow, networks. We further hypothesize that it may play a directly beneficial role in generalization, for instance by making the model output less sensitive to orthogonal, out-of-distribution, perturbations of activations.
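In the warmup model this statement can be checked directly. The sketch below again assumes $f = n^{-1/2} v^\top u\, x$ and $L = \tfrac12(f-y)^2$ (an assumption for illustration; the function names, width and seed are hypothetical). It verifies that $\partial L/\partial h$ is parallel to $v$, then compares the component of $u$ orthogonal to $v$, and of $v$ orthogonal to $u$, before and after gradient descent. Since $\partial L/\partial h \to 0$ as the error converges, directions are measured against $v$ rather than against the vanishing gradient itself.

```python
import numpy as np

def orth_norm(a, b):
    """Norm of the component of vector a orthogonal to vector b."""
    return float(np.linalg.norm(a - (a @ b) / (b @ b) * b))

def misalignment(c, n=4000, x=1.0, y=0.0, steps=500, seed=0):
    """Orthogonal components of u w.r.t. v (and v w.r.t. u) before/after GD at eta = c / lambda_0.

    Assumed model form (illustrative): f = n^{-1/2} v^T h with hidden layer h = u x.
    """
    rng = np.random.default_rng(seed)
    u, v = rng.standard_normal(n), rng.standard_normal(n)
    eta = c / (x**2 * (u @ u + v @ v) / n)

    h = u * x                                   # hidden activations h = u x
    f = v @ h / np.sqrt(n)                      # f = n^{-1/2} v^T h
    dL_dh = (f - y) * v / np.sqrt(n)            # backpropagated gradient
    assert orth_norm(dL_dh, v) < 1e-9           # dL/dh is parallel to v
    before = (orth_norm(u, v), orth_norm(v, u))

    for _ in range(steps):
        g = eta * x * ((v @ u) * x / np.sqrt(n) - y) / np.sqrt(n)
        u, v = u - g * v, v - g * u
    after = (orth_norm(u, v), orth_norm(v, u))
    return before, after

for c in (1.0, 3.0):                            # lazy vs catapult learning rate
    before, after = misalignment(c)
    print(f"eta * lambda_0 = {c}: |u_perp| {before[0]:.1f} -> {after[0]:.1f}, "
          f"|v_perp| {before[1]:.1f} -> {after[1]:.1f}")
```

In the lazy run the orthogonal components should barely move, while in the catapult run they should shrink along with the curvature.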
5.3. Catapult dynamics often improve generalization
Our results shed light on the regularizing effect of training at large learning rates. The effect presented here is independent of the regularizing effect of stochastic gradient noise, which has been studied extensively. Building on previous works, we noted the observed correlation between flatness and generalization performance. Based on these observations, we expect the optimal performance to often occur for learning rates larger than $2/\lambda_0$, where the linearized model is unstable. Observing this effect required controlling for several confounding factors that affect the comparison of performance between different learning rates. Under a fair comparison, and also for a fixed compute budget, we find that this expectation holds in practice.
5.4. Beyond infinite linear models
One outcome of our work is to address the performance gap between ordinary neural networks, and linear models inspired by the theory of wide networks. Optimal performance is often obtained at large learning rates which are inaccessible to linearized models. In such cases, we expect the performance gap to persist even at arbitrarily large widths. We hope our work can further improve the understanding of deep learning methods.
5.5. Other open questions
There are several remaining open questions. While the model predicts a maximum learning rate of $4/\lambda_0$, for models with ReLU activations we find that the maximum learning rate is consistently higher. This may be due to a separate dynamical curvature-reduction mechanism that relies on ReLU. In addition, we do not explore the degree to which our results extend to softmax classification. While we expect qualitatively similar behavior there, the non-constant Hessian of the softmax cross entropy makes controlled experiments more challenging. Similarly, behavior for other optimizers such as SGD with momentum may differ. For example, the maximum learning rate when training a linear model is larger for gradient descent with momentum than for vanilla gradient descent, and therefore the transition to the catapult phase (if it exists) will occur at a higher learning rate. We leave these questions to future work.
The authors would like to thank Kyle Aitken, Dar Gilboa, Justin Gilmer, Boris Hanin, Tengyu Ma, Andrea Montanari, and Behnam Neyshabur for useful discussions. We would also like to thank Jaehoon Lee for early discussions about empirical properties of the lazy phase.
Allen-Zhu, Z., Li, Y., and Song, Z. A convergence theory for deep learning via over-parameterization. In Chaudhuri, K. and Salakhutdinov, R. (eds.), Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pp. 242–252, Long Beach, California, USA, 09–15 Jun 2019. PMLR.
Arora, S., Du, S. S., Hu, W., Li, Z., Salakhutdinov, R. R., and Wang, R. On exact computation with an infinitely wide neural net. In Advances in Neural Information Processing Systems, pp. 8139–8148, 2019.
Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., and Wanderman-Milne, S. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
Chizat, L., Oyallon, E., and Bach, F. On lazy training in differentiable programming. In Wallach, H., Larochelle, H., Beygelzimer, A., d'Alché-Buc, F., Fox, E., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 32, pp. 2933–2943. Curran Associates, Inc., 2019. URL http://papers.nips.cc/paper/8559-on-lazy-training-in-differentiable-programming.pdf.
Daniely, A. SGD learns the conjugate kernel class of the network. In Advances in Neural Information Processing Systems, pp. 2422–2430, 2017.
Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 1019–1028. JMLR. org, 2017.
Du, S. S., Lee, J. D., Li, H., Wang, L., and Zhai, X. Gradient descent finds global minima of deep neural networks. In Proceedings of the 36th International Conference on Machine Learning, ICML 2019, 9-15 June 2019, Long Beach, California, USA, pp. 1675–1685, 2019. URL http://proceedings.mlr.press/v97/du19c.html.
Dyer, E. and Gur-Ari, G. Asymptotics of wide networks from Feynman diagrams. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=S1gFvANKDS.
Frankle, J., Schwab, D. J., and Morcos, A. S. The early phase of neural network training. arXiv preprint arXiv:2002.10365, 2020.
Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.
Huang, J. and Yau, H.-T. Dynamics of Deep Neural Networks and Neural Tangent Hierarchy. arXiv e-prints, art. arXiv:1909.08156, Sep 2019.
Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Bengio, S., Wallach, H., Larochelle, H., Grauman, K., Cesa-Bianchi, N., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 31, pp. 8571–8580. Curran Associates, Inc., 2018.
Jastrzebski, S., Szymczak, M., Fort, S., Arpit, D., Tabor, J., Cho, K., and Geras, K. The break-even point on optimization trajectories of deep neural networks. arXiv preprint arXiv:2002.09572, 2020.
Jiang, Y., Neyshabur, B., Krishnan, D., Mobahi, H., and Bengio, S. Fantastic generalization measures and where to find them. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=SJgIPJBFvH.
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. CoRR, abs/1609.04836, 2016. URL http://arxiv.org/abs/1609.04836.
Leclerc, G. and Madry, A. The two regimes of deep network training, 2020.
Lee, J., Bahri, Y., Novak, R., Schoenholz, S., Pennington, J., and Sohl-dickstein, J. Deep neural networks as Gaussian processes. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=B1EA-M-0Z.
Lee, J., Xiao, L., Schoenholz, S., Bahri, Y., Novak, R., Sohl-Dickstein, J., and Pennington, J. Wide neural networks of any depth evolve as linear models under gradient descent. In Wallach, H., Larochelle, H., Beygelzimer, A., d'Alché-Buc, F., Fox, E., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 32, pp. 8570–8581. Curran Associates, Inc., 2019. URL http://papers.nips.cc/paper/9063-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent.pdf.
Li, Y. and Liang, Y. Learning overparameterized neural networks via stochastic gradient descent on structured data. In Advances in Neural Information Processing Systems, pp. 8157–8166, 2018.
Li, Y., Wei, C., and Ma, T. Towards explaining the regularization effect of initial large learning rate in training neural networks. In Wallach, H., Larochelle, H., Beygelzimer, A., d'Alché-Buc, F., Fox, E., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 32, pp. 11669–11680. Curran Associates, Inc., 2019.
Mandt, S., Hoffman, M. D., and Blei, D. M. Stochastic gradient descent as approximate Bayesian inference. The Journal of Machine Learning Research, 18(1):4873–4907, 2017.
May, R. M. Simple mathematical models with very complicated dynamics. Nature, 261(5560):459–467, 1976.
Mei, S., Montanari, A., and Nguyen, P.-M. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018. doi: 10.1073/pnas.1806579115.
Naveh, Ben-David, Sompolinsky, and Ringel. to be published.
Novak, R., Xiao, L., Bahri, Y., Lee, J., Yang, G., Abolafia, D. A., Pennington, J., and Sohl-dickstein, J. Bayesian deep convolutional networks with many channels are Gaussian processes. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=B1g30j0qF7.
Novak, R., Xiao, L., Hron, J., Lee, J., Alemi, A. A., Sohl-Dickstein, J., and Schoenholz, S. S. Neural tangents: Fast and easy infinite neural networks in python. In International Conference on Learning Representations, 2020. URL https://github.com/google/neural-tangents.
Park, D. S., Sohl-Dickstein, J., Le, Q. V., and Smith, S. L. The effect of network width on stochastic gradient descent and generalization: an empirical study. CoRR, abs/1905.03776, 2019. URL http://arxiv.org/abs/1905.03776.
Rotskoff, G. and Vanden-Eijnden, E. Parameters as interacting particles: long time convergence and asymptotic error scaling of neural networks. In Advances in neural information processing systems, pp. 7146–7155, 2018.
Sirignano, J. and Spiliopoulos, K. Mean field analysis of neural networks. arXiv preprint arXiv:1805.01053, 2018.
Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=BJij4yg0Z.
Smith, S. L., Kindermans, P.-J., Ying, C., and Le, Q. V. Don’t Decay the Learning Rate, Increase the Batch Size. arXiv e-prints, art. arXiv:1711.00489, Nov 2017.
Smith, S. L., Duckworth, D., Rezchikov, S., Le, Q. V., and Sohl-Dickstein, J. Stochastic natural gradient descent draws posterior samples in function space. arXiv preprint arXiv:1806.09597, 2018.
Woodworth, B., Gunasekar, S., Lee, J., Soudry, D., and Srebro, N. Kernel and deep regimes in overparametrized models. arXiv preprint arXiv:1906.05827, 2019.
Xiao, L., Pennington, J., and Schoenholz, S. S. Disentangling trainability and generalization in deep learning, 2019.
Xie, Z., Sato, I., and Sugiyama, M. A diffusion theory for deep learning dynamics: Stochastic gradient descent escapes from sharp minima exponentially fast. arXiv preprint arXiv:2002.03495, 2020.
Zagoruyko, S. and Komodakis, N. Wide residual networks. CoRR, abs/1605.07146, 2016. URL http://arxiv.org/abs/1605.07146.
Zou, D., Cao, Y., Zhou, D., and Gu, Q. Stochastic gradient descent optimizes over-parameterized deep relu networks. arXiv preprint arXiv:1811.08888, 2018.
We are using JAX (Bradbury et al., 2018) and the Neural Tangents Library for our experiments (Novak et al., 2020).
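All the models have been trained with Mean Squared Error normalized as $L = \frac{1}{2k|B|}\sum_{(x,y)\in B}\sum_{i=1}^{k}\left(f_i(x) - y_i\right)^2$, where B is a batch of samples, k is the number of classes and $y_i$ are one-hot targets.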
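In a similar way, we have normalized the NTK as $\Theta_{(x,i),(x',j)} = \frac{1}{k|B|}\sum_{\alpha}\frac{\partial f_i(x)}{\partial\theta_\alpha}\frac{\partial f_j(x')}{\partial\theta_\alpha}$, so that the eigenvalues of the NTK are the same as the non-zero eigenvalues of the Fisher information: $F_{\alpha\beta} = \frac{1}{k|B|}\sum_{x\in B}\sum_{i=1}^{k}\frac{\partial f_i(x)}{\partial\theta_\alpha}\frac{\partial f_i(x)}{\partial\theta_\beta}$.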
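The equality of the two spectra is the statement that $J J^\top$ and $J^\top J$ share their non-zero eigenvalues, where $J$ is the Jacobian of the $k|B|$ outputs with respect to the parameters. The NumPy check below assumes the $1/(k|B|)$ normalization for a batch $B$ and $k$ classes, and uses a random matrix as a hypothetical stand-in for $J$ (no network is involved). For this MSE loss, $J^\top J/(k|B|)$ is also the Gauss-Newton matrix, since the Hessian of the loss with respect to the outputs is $I/(k|B|)$.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, k, num_params = 8, 3, 50
# Hypothetical stand-in for the Jacobian d f_i(x) / d theta, one row per (x, i) pair.
J = rng.standard_normal((batch * k, num_params))

ntk = J @ J.T / (k * batch)       # (k|B| x k|B|) normalized NTK
fisher = J.T @ J / (k * batch)    # (P x P) Fisher / Gauss-Newton matrix for the MSE loss

ntk_eigs = np.sort(np.linalg.eigvalsh(ntk))[::-1]
fisher_eigs = np.sort(np.linalg.eigvalsh(fisher))[::-1]

# The k|B| = 24 NTK eigenvalues match the 24 non-zero Fisher eigenvalues;
# the remaining P - k|B| = 26 Fisher eigenvalues are numerically zero.
print(np.allclose(ntk_eigs, fisher_eigs[: batch * k]))
print(np.allclose(fisher_eigs[batch * k:], 0.0))
```

Working with the smaller of the two matrices (here the NTK) is what makes measuring the top eigenvalue on a batch of a few hundred samples cheap.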
In our experiments we measure the top eigenvalue of the NTK using Lanczos’ algorithm. We construct the NTK on a small batch of data, typically several hundred samples, compute the top eigenvalue, and then average over batches. In this work, we do not focus on precision aspects such as fluctuations in the top eigenvalue across batches.
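A minimal sketch of this measurement, assuming only a function that multiplies a vector by the NTK: the Lanczos iteration builds a small tridiagonal matrix whose top eigenvalue approximates that of the full operator. This NumPy version (with full reorthogonalization, adequate for a few dozen iterations) is illustrative and is not the code used for the experiments reported here; the helper name, the number of iterations, and the random stand-in Jacobians used as "batches" are assumptions. For larger problems, `scipy.sparse.linalg.eigsh` applied to a `scipy.sparse.linalg.LinearOperator` provides an implicitly restarted Lanczos solver.

```python
import numpy as np

def lanczos_top_eigenvalue(matvec, dim, num_iters=30, seed=0):
    """Estimate the largest eigenvalue of a symmetric operator given only matvec."""
    rng = np.random.default_rng(seed)
    q = rng.standard_normal(dim)
    basis = [q / np.linalg.norm(q)]
    alphas, betas = [], []
    q_prev, beta = np.zeros(dim), 0.0
    num_iters = min(num_iters, dim)
    for j in range(num_iters):
        w = matvec(basis[-1]) - beta * q_prev    # three-term Lanczos recurrence
        alpha = basis[-1] @ w
        alphas.append(alpha)
        w = w - alpha * basis[-1]
        for b in basis:                          # full reorthogonalization
            w = w - (b @ w) * b
        beta = np.linalg.norm(w)
        if j == num_iters - 1 or beta < 1e-10:
            break
        betas.append(beta)
        q_prev = basis[-1]
        basis.append(w / beta)
    T = np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)   # tridiagonal matrix
    return np.linalg.eigvalsh(T)[-1]

# Average the estimate over several hypothetical batches, never forming the NTK explicitly.
rng = np.random.default_rng(1)
k, batch, num_params = 3, 8, 50
estimates = []
for _ in range(4):
    J = rng.standard_normal((k * batch, num_params))          # stand-in Jacobian for one batch
    ntk_matvec = lambda q, J=J: J @ (J.T @ q) / (k * batch)   # NTK-vector product
    estimates.append(lanczos_top_eigenvalue(ntk_matvec, dim=k * batch))
top_eigenvalue = float(np.mean(estimates))
```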
All experiments that compare different learning rates use the same seed for the weights at initialization, and we consider only one such initialization (unless otherwise stated), although we have not seen much variance in the phenomena described. We let $\sigma_w$ and $\sigma_b$ denote the constant (width-independent) coefficients of the standard deviation of the weight and bias initializations, respectively.
Here we describe experimental settings specific to a figure.
Figure 3a,3b,3c. Fully connected, three hidden layers w = 2048, ReLU non-linearity trained using SGD (no momentum) on MNIST. Batch size= 512, using NTK normalization,
Figures 3d,3e,3f. Wide ResNet 28-18 trained on CIFAR10 with SGD (no momentum). Batch size of 128, LeCun initialization with
Figures 4,6 Fully connected network with one hidden layer and ReLU non-linearity trained on 512 samples of MNIST with SGD (no momentum). Batch size of 512, NTK initialization with
Figures 5a,5b. The convolutional network has the following architecture: ConvReLU
MaxPool((2,2), ’VALID’)
MaxPool((2,2), ’VALID’)
Flatten()
Dense(n) denotes a fully-connected layer with output dimension n. Conv_SAME(n) and Conv_VALID(n) denote convolutional layers with ’SAME’ or ’VALID’ padding and n filters, respectively; all convolutional layers use (3, 3) filters. MaxPool((2,2), ’VALID’) performs max pooling with ’VALID’ padding and a (2,2) window size. LeCun initialization is used, with the standard deviation of the weights and biases drawn as
,
Figures 1, 5c, 5d. Wide ResNet on CIFAR10 using SGD (no momentum). Training on v3-8 TPUs with a total batch size of 1024 (and per device batch size of 128). They all use L2 regularization = 0.0005, LeCun initialization with
There is also data augmentation: we use flip, crop and mixup. With softmax classification, these models can reach a test accuracy of 0.965 if one uses cosine decay, so we do not observe a big drop in performance due to using MSE. Furthermore, we are using JAX’s implementation of Batch Norm, which does not keep track of training batch statistics for test mode evaluation. We have not tuned the learning rate or the L2 regularization parameter.
Figures S2,S3. Wide ResNet on CIFAR100 using SGD (no momentum). Same setting as figure 5c, 5d except for the different dataset, different L2 regularization = 0.000025 and label smoothing (we have subtracted 0.01 from the target one-hot labels).
Figure S7. Two hidden layer, ReLU network for one data point x = 1, y = 1.
Figure S10. Fully connected network with two hidden layers and tanh non-linearity trained on MNIST with SGD (no momentum). Batch size of 512, LeCun initialization with
Figure S8a. Two-hidden layer fully connected network trained on MNIST with batch size 512, NTK normalization with . Trained using both momenta
and vanilla SGD for three different non-linearities: tanh, ReLU and identity (no non-linearity). The learning rate for each non-linearity was chosen to correspond to
Rest of SM figures. Small modifications of experiments in previous figures, specified in captions.
Figure S1. Visualization of training dynamics in all three phases. In the lazy phase, the network is approximately linear in its parameters, and converges exponentially to a global minimum. In the catapult phase, the loss initially grows, while the weight norm and curvature decrease. Once the curvature is low enough, optimization converges. In the divergent phase, both the loss and parameter magnitudes diverge. (a)-(d) Loss surface and training dynamics visualized in a 2d linear subspace. The network has a single hidden layer with width n = 500, linear activations, and is trained with MSE loss on a single 1D sample x = 1 with label y = 0. The parameter subspace is defined by $u = [\mathrm{dim1}]\, e_1 + [\mathrm{dim2}]\, e_2$ and $v = [\mathrm{dim1}]\, e_1 - [\mathrm{dim2}]\, e_2$, where $e_1, e_2$ are orthonormal vectors, u, v are the weight vectors, and [dim1], [dim2] are the coordinates in the subspace. If initialized in this 2d subspace, u and v remain in the subspace throughout training, and so training dynamics can be fully visualized with a two dimensional plot. (e) Visualization of the loss surface and training dynamics in terms of a nonlinear reparameterization, providing interpretable properties: the x-axis is the correlation between the weight vectors, and the y-axis is the curvature $\lambda$. The trajectory shown is identical to that in (c), and in Figure 1.
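A minimal NumPy check of why such a slice is closed under gradient descent. It assumes the slice is parameterized as $u = d_1 e_1 + d_2 e_2$, $v = d_1 e_1 - d_2 e_2$ with orthonormal $e_1, e_2$ (our assumption for illustration), and the warmup model $f = n^{-1/2} v^\top u\, x$ with $L = \tfrac12(f-y)^2$, using the caption's $n = 500$, $x = 1$, $y = 0$. Under these assumptions one gradient step maps $(d_1, d_2) \mapsto ((1-g)\,d_1, (1+g)\,d_2)$ with $g = \eta x (f - y)/\sqrt{n}$, so the full $n$-dimensional run can be replayed exactly in two coordinates. The starting coordinates and learning rate are hypothetical, chosen so that $\lambda_0 = 1$ and $\eta\lambda_0 = 3$ (catapult phase).

```python
import numpy as np

def simulate_slice(d1, d2, eta, n=500, x=1.0, y=0.0, steps=100, seed=0):
    """GD on the assumed model f = v.u x / sqrt(n), started on u = d1 e1 + d2 e2, v = d1 e1 - d2 e2."""
    rng = np.random.default_rng(seed)
    E, _ = np.linalg.qr(rng.standard_normal((n, 2)))   # orthonormal columns e1, e2
    e1, e2 = E[:, 0], E[:, 1]
    u, v = d1 * e1 + d2 * e2, d1 * e1 - d2 * e2
    traj = [(d1, d2)]
    for _ in range(steps):
        # One gradient descent step on the full n-dimensional weights.
        g = eta * x * ((v @ u) * x / np.sqrt(n) - y) / np.sqrt(n)
        u, v = u - g * v, v - g * u
        # The same step written in the two slice coordinates (v.u = d1^2 - d2^2).
        g2 = eta * x * ((d1**2 - d2**2) * x / np.sqrt(n) - y) / np.sqrt(n)
        d1, d2 = (1 - g2) * d1, (1 + g2) * d2
        # The weights never leave the slice.
        assert np.allclose(u, d1 * e1 + d2 * e2) and np.allclose(v, d1 * e1 - d2 * e2)
        traj.append((d1, d2))
    return traj

n, eta = 500, 3.0
traj = simulate_slice(np.sqrt(136.0), np.sqrt(114.0), eta, n=n, steps=100)
lam = [2 * (a**2 + b**2) / n for a, b in traj]      # curvature lambda_t (x = 1)
f = [(a**2 - b**2) / np.sqrt(n) for a, b in traj]   # network output f_t (x = 1)
```

Plotting `traj` gives a two dimensional picture of the run; `lam` should first stay near 1 while `f` oscillates with growing amplitude, then drop below $2/\eta$ as `f` converges.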
B.1. CIFAR-100 performance
We can also repeat the performance experiments for CIFAR-100 with the same Wide ResNet 28-10 setup. In this case, using MSE and SGD we need to evolve the system for longer times, which requires a smaller L2 regularization. We did not tune it, but found that L2 = 0.000025 works. With only one learning rate decay we can get within 3% of the performance of Zagoruyko & Komodakis (2016), who used softmax classification and two learning rate decays. However, evolution for a longer time is needed: we found that different learning rates converge at
physical epochs. Similar to the main text experiments, we observe that if we decay after evolving for the same amount of physical epochs, larger learning rates do better. See figure S2.
B.2. Different learning rates converge at the same physical time
We can also plot the test accuracy versus physical time for different learning rates. For vanilla SGD, the performance curves of different learning rates lie almost on top of each other when plotted against physical time, which is why we consider the fair comparison between learning rates to be at the same physical time.
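As a concrete illustration of comparing runs at equal physical time, the sketch below assumes that physical time is the number of steps multiplied by the learning rate. It builds a per-step learning-rate list that trains at a given rate until a fixed physical time and then for a fixed number of extra steps at learning rate 0.01, mirroring the protocol described for Figure S5b (2400 physical steps followed by 4800 steps at 0.01). The function name and the example learning rates are hypothetical.

```python
def physical_time_schedule(lr, physical_time, extra_steps, decayed_lr=0.01):
    """Per-step learning rates: `lr` until step * lr reaches `physical_time`,
    then `extra_steps` more steps at `decayed_lr` (assumed protocol)."""
    initial_steps = int(round(physical_time / lr))
    return [lr] * initial_steps + [decayed_lr] * extra_steps

# Hypothetical learning rates: larger rates need proportionally fewer steps
# to reach the same physical time before the decay.
for lr in (0.1, 0.2, 0.4):
    schedule = physical_time_schedule(lr, physical_time=2400, extra_steps=4800)
    print(lr, len(schedule) - 4800, "steps before the decay")
```

Under this convention a fixed physical time costs fewer steps at larger learning rates, which is why the same comparison also favors large learning rates at a fixed compute budget.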
We have picked a subset of the learning rates from the WRN28-10 CIFAR100 experiment of SM B.1. In Figure S3, we see that even if the curves are slightly different, they converge to roughly the same accuracy. The only curve which is slightly different is which is a rather high learning rate (close to
B.3. Comparison of learning rates for different regularization for WRN28-10 on CIFAR10
Even though in the main section we have considered a model with fixed L2 regularization, we can study the effect without L2 regularization or with a different value. In these two examples, we consider the same setup as Figures 5c, 5d.
Without L2 regularization, we see that the larger learning rate does better even in the absence of learning rate decay, although training takes a really long time. In our experience, comparing this setup with the state of the art, the absence of L2 regularization makes the experiment take longer before convergence but does not influence performance much.
Figure S2. Test accuracy vs learning rate for WRN28-10 and CIFAR100 with vanilla SGD, L2 regularization, data augmentation, label smoothing and batch size 1024. The critical learning rate is $\eta_{\rm crit} = 2/\lambda_0$. (a) Evolved for 38400 steps. (b) Evolved for 96000 physical steps at learning rate $\eta$ (blue) and then evolved for 7200 more steps at learning rate 0.01 (red).
Figure S3. Test accuracy vs physical time for different learning rates in the WRN CIFAR100 experiment of the previous section B.1.
In the presence of L2 regularization, we picked the particular value L2 = 0.0005. To make sure that our conclusion does not depend on the choice of L2, the only hyperparameter other than $\eta$, we have also considered a larger L2 value, and found that the optimal performance in physical time is also peaked in the catapult phase, although the difference here is smaller.
B.4. Training accuracy plots
The training accuracies of the previous experiments are shown in figure S6.
Figure S4. WRN28-10 on CIFAR10 without L2 regularization. Same setup as Figure 5d but evolved for longer times.
Figure S5. Test accuracies for a larger CIFAR10 experiment like that of the main section. (a) WRN CIFAR-10 7200 steps as in figure 5c. (b) WRN CIFAR10 2400 physical steps and then 4800 more steps at learning rate 0.01 as in figure 5d.
Figure S6. Training accuracies for the performance experiments. Smaller learning rates have higher training accuracy when compared in physical time. However, they still perform worse for a fixed number of steps. (a) WRN CIFAR-10 12000 steps as in Figure 5c. (b) WRN CIFAR10 3360 physical steps as in Figure 5d. (c) WRN CIFAR100 38400 steps as in Figure S2a. (d) WRN CIFAR100 96000 physical steps as in Figure S2b.
C.1. ReLU activations for the simple model
In the main text we have been using ReLU non-linearities. Compared with the simple model with no non-linearities, ReLU networks remain trainable beyond $\eta = 4/\lambda_0$: these networks generically seem to train well up to $\eta \approx 12/\lambda_0$. This is a generic feature of deep ReLU networks and can already be observed for the model of Section 2 with a target y = 1, two hidden layers and a ReLU non-linearity, f = u·ReLU(w·ReLU(v)), as shown in Figure S7. In this single-sample setting, for learning rates beyond this range the loss does not diverge, but the neurons die and the network ends up computing the trivial function f = 0. For deep networks with more than one hidden layer and multiple samples, as discussed in the main text, we observe that the loss diverges beyond $\eta \approx 12/\lambda_0$.
Figure S7. Simple model with ReLU non-linearity. (b) is evaluated at physical time 100.
C.2. Momenta
The choice of optimizer also affects these dynamics. If we consider a similar setup with momentum $\gamma$, we first expect a linear model to converge in a broader range, $\eta < 2(1+\gamma)/\lambda_0$. For smooth non-linearities, we observe that for $\eta < 2(1+\gamma)/\lambda_0$ the curvature $\lambda_t$ stays approximately constant. However, this is not true for ReLU, see Figure S8a. In fact, for ReLU networks, we observe that there is a small learning rate, roughly
, below which the time dynamics of
is similar (but non-constant). However, for
, there are strong time dynamics; we illustrate this in Figure S8b with a three-hidden-layer ReLU network.
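The broader stability range for a linear model with momentum can be checked on a single eigendirection of the kernel. Assuming heavy-ball momentum, $w_{t+1} = w_t - \eta\lambda w_t + \gamma(w_t - w_{t-1})$ on the quadratic $L = \tfrac12\lambda w^2$, the iteration is stable for $0 < \eta\lambda < 2(1+\gamma)$ when $0 \le \gamma < 1$, compared with $\eta\lambda < 2$ without momentum. The sketch below (helper name, step count and tolerances are ours) checks this numerically just below and above each threshold.

```python
def heavy_ball_converges(eta_lam, gamma, steps=2000):
    """Run heavy-ball GD on L = lam * w**2 / 2 and report whether w converges to 0.

    Only the product eta * lam matters, so it is passed as a single number.
    """
    w_prev, w = 1.0, 1.0
    for _ in range(steps):
        w_prev, w = w, w - eta_lam * w + gamma * (w - w_prev)
        if abs(w) > 1e12:                 # treat runaway growth as divergence
            return False
    return abs(w) < 1e-6

# Vanilla GD: threshold 2.  Momentum gamma = 0.9: threshold 2 * (1 + 0.9) = 3.8.
print(heavy_ball_converges(1.9, 0.0), heavy_ball_converges(2.1, 0.0))
print(heavy_ball_converges(3.7, 0.9), heavy_ball_converges(3.9, 0.9))
```

Since a linear model decomposes into such eigendirections, the top kernel eigenvalue alone sets the linear stability limit in both cases.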
C.3. Effect of regularization on early time dynamics
We do not expect regularization to affect the early time dynamics, but because of the strong rearrangement that goes on in the first steps, it could potentially have a non-trivial effect; among other things, with regularization the Hessian spectrum necessarily decays over time. We see that the dynamics that drives the rearrangement is roughly the same, even if the maximum eigenvalue at early times decreases slowly.
C.4. Tanh activations
We observe that for tanh activations, the maximum learning rate is closer to the simple model expectation $\eta_{\max} = 4/\lambda_0$, see Figure S10.
C.5. WRN NTK Normalization
As illustrated in the text in Figures 3b, 3c, we also see this behaviour for NTK normalization. For completeness we include the WRN model with NTK normalization. From the linearized intuition, we expect the phases to also be determined by the quantity $\eta\lambda_0$, independently of the normalization. Figure S11 has the same setup as Figure 3.
Figure S8. (a) Evolution of the normalized curvature FC connected networks evolved with momenta (same networks with SGD with dashed line for reference) evolved for
. We observe that ReLU networks evolved with momentum do not have a constant kernel in the naive ‘lazy’ phase. (b)
Same setup as the FC network of figure 3 with momenta
: fully connected, three hidden layers w = 2048, ReLU non-linearity.
is slightly different due to variations at initialization.
Figure S9. Same WRN as figure 3d,f with regularization= 0.0005. Dynamics in physical steps of the
at physical time 25
Figure S10. Maximum NTK eigenvalue at early times for a 2 hidden layer fully connected network with tanh non-linearity trained on MNIST, with
. (a) Early time dynamics of the curvature for learning rates in the linear and catapult phase. (b)
measured at
D.1. Full model analysis
Here we provide additional details on the theoretical analysis of the full model in Section 2.2. The gradient descent update equations are
and
The update equations for the error and kernel evaluated on training set inputs are
Where . We now consider the dynamics of the kernel projected onto the
direction, which is given by
Let us now analyze the phase structure of (S3) and (S5). For now, we neglect the last term on the right-hand side of (S3) (at initialization this term is of order $1/n$ and is negligible at large width). Let $\lambda_0$ be the maximal eigenvalue of the kernel at initialization, and let $e_0$ be the corresponding eigenvector. Notice that $f - y$ projected onto the top eigenvector evolves as
$$e_0^\top(f_{t+1} - y) = (1 - \eta\lambda_0)\, e_0^\top(f_t - y) + O(1/n).$$
Lazy phase. When $\eta < 2/\lambda_0$, we see that the component of $f - y$ along the top eigenvector shrinks during training. The kernel updates are of order $1/n$, while convergence happens in a number of steps of order 1 (independent of n). Therefore the kernel does not change by much during training. This is a special case of the NTK result (Jacot et al., 2018). Effectively, the model evolves as a linear model in this phase.
Catapult phase. When $2/\lambda_0 < \eta < 4/\lambda_0$, $f - y$ grows exponentially fast, and it grows fastest in the direction of the top eigenvector. Therefore, the vector $f - y$ becomes aligned with the top eigenvector after a number of steps that is of order $\log n$. Also, f itself grows quickly while the label is constant, and so we find that $f - y \approx f$ after a similar number of steps. When these approximations hold, notice that
. From equation (S5) we can then derive an approximate equation for the evolution of the top NTK eigenvalue.
While $|f - y|$ grows exponentially fast, so will the updates to the top eigenvalue. When $|f - y|^2$ becomes of order $n$, the updates to the top eigenvalue become of order 1 (and negative), causing $\lambda_t$ to decrease by a non-negligible amount. This will continue until $\lambda_t < 2/\eta$, at which point $f - y$ will start converging to zero. Eventually, after a number of steps of order log(n), gradient descent will converge to a global minimum that has a lower curvature than the curvature at initialization.
The justification for dropping the order $1/n$ term in (S6) was explained in the warmup model: while this term may affect the details of the dynamics, eventually the maximum kernel eigenvalue must drop below $2/\eta$ for the component of the error along the top eigenvector (and therefore the loss) to converge to zero.
Divergent phase. When $\eta > 4/\lambda_0$, both $|f - y|$ and $\lambda_t$ will grow, and optimization will diverge. Therefore, $\eta_{\max} = 4/\lambda_0$ is the maximum learning rate for this model.
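The same three phases can be checked numerically for a multi-sample version of the model. The NumPy sketch below assumes the full model takes the form $f(x) = n^{-1/2} v^\top u\, x$ with $x \in \mathbb{R}^d$, $u \in \mathbb{R}^{n\times d}$, $v \in \mathbb{R}^n$, trained with full-batch gradient descent on $L = \frac{1}{2m}\sum_{a}(f(x_a) - y_a)^2$, and uses the kernel normalization $\Theta = J J^\top/m$ so that the linearized model is stable exactly when $\eta < 2/\lambda_0$. The random data, sizes, seed and helper names are illustrative assumptions, not the setup of any figure.

```python
import numpy as np

def top_ntk_eigenvalue(u, v, X):
    """Largest eigenvalue of Theta = (|v|^2 X X^T + X u^T u X^T) / (n m), i.e. J J^T / m."""
    n, m = v.shape[0], X.shape[0]
    Z = X @ u.T                                      # row a is (u x_a)^T, shape (m, n)
    theta = ((v @ v) * (X @ X.T) + Z @ Z.T) / (n * m)
    return float(np.linalg.eigvalsh(theta)[-1])

def train_full_model(c, n=2000, d=20, m=4, steps=1000, seed=0):
    """Full-batch GD at eta = c / lambda_0; returns (lambda_0, lambda_final, final_loss)."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((m, d)) / np.sqrt(d)     # hypothetical inputs, one per row
    Y = rng.standard_normal(m)                       # hypothetical labels
    u, v = rng.standard_normal((n, d)), rng.standard_normal(n)
    lam0 = top_ntk_eigenvalue(u, v, X)
    eta = c / lam0
    for _ in range(steps):
        err = X @ u.T @ v / np.sqrt(n) - Y           # f(x_a) - y_a, shape (m,)
        if np.max(np.abs(err)) > 1e6:
            return lam0, np.inf, np.inf              # divergent phase
        grad_u = np.outer(v, err @ X) / (m * np.sqrt(n))   # dL/du, shape (n, d)
        grad_v = u @ (X.T @ err) / (m * np.sqrt(n))        # dL/dv, shape (n,)
        u, v = u - eta * grad_u, v - eta * grad_v
    err = X @ u.T @ v / np.sqrt(n) - Y
    return lam0, top_ntk_eigenvalue(u, v, X), float(0.5 * np.mean(err**2))

for c in (1.0, 3.0, 5.0):                            # c = eta * lambda_0
    lam0, lam, loss = train_full_model(c)
    print(f"eta * lambda_0 = {c}: lambda_final / lambda_0 = {lam / lam0:.3f}, loss = {loss:.2e}")
```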
Here we consider the gradient descent dynamics of the model analyzed in Section 2, for learning rates that are close to the critical point $\eta_{\rm crit} = 2/\lambda_0$. The analysis reveals that the gradient descent dynamics of the model are qualitatively different above and below this point. For example, the loss decreases monotonically during training when $\eta < \eta_{\rm crit}$, but not when $\eta > \eta_{\rm crit}$. In this section we show that the transition from small to large learning rate becomes sharp once we take the modified large width limit, in the following sense: certain functions of the learning rate become non-analytic at $\eta = \eta_{\rm crit}$ in this limit. This sharp transition bears close resemblance to phase transitions of the kind found in physical systems, such as the transition between the liquid and gaseous phases of water. In particular, our case involves a dynamical system, where the dynamics are governed by the gradient descent equations. These dynamics undergo a phase transition as a function of the learning rate, an external parameter. We point to the logistic map (May, 1976) as a well-known example of a dynamical system that undergoes phase transitions as a function of an external parameter.
E.1. Non-perturbative dynamics
A phase transition is a drastic change in a system’s behavior incurred under a small change in external parameters. Mathematically, it is a non-analyticity in some property of the system as a function of these parameters. For example, consider the property $\lambda_{\rm final}(\eta)$, the curvature of the model at the end of training as a function of the learning rate. In the modified large width limit, $\lambda_{\rm final}$ is constant for $\eta < \eta_{\rm crit}$, but not for $\eta > \eta_{\rm crit}$. Therefore, this function is not analytic at $\eta = \eta_{\rm crit}$. Notice that this statement is true in the limit but not necessarily at finite width, where the final curvature may be an analytic function of the learning rate even at $\eta = \eta_{\rm crit}$. It is well known in physics that phase transitions only occur in a limit where the number of dynamical variables (in this case the number of model parameters) is taken to infinity. One immediate consequence of the non-analyticity at $\eta_{\rm crit}$ is that the large learning rate phase is inaccessible from the small learning rate phase via a perturbative expansion. In other words, we cannot describe all properties of the model for some $\eta > \eta_{\rm crit}$ by doing a Taylor expansion around a point $\eta_0 < \eta_{\rm crit}$ and keeping a finite number of terms.
Dyer & Gur-Ari (2020); Huang & Yau (2019) developed a formalism that allows one to compute finite-width corrections to various properties of deep networks, using a perturbative expansion around the infinite width limit. We have argued that the usual infinite width approximation to the training dynamics is not valid for learning rates above $\eta_{\rm crit}$, and that a full analysis must account for large finite-width effects. One may have hoped that including the perturbative finite-width corrections discussed in Dyer & Gur-Ari (2020); Huang & Yau (2019) would allow us to regain analytic control over the dynamics. The results presented here suggest that this is not the case: for $\eta > \eta_{\rm crit}$, we expect that the perturbative expansion will not provide a good approximation to the gradient descent dynamics at any finite order in inverse width.
E.2. Critical exponents
When the external parameters are close to a phase transition, one often finds that the dynamical properties of the system obey power law behavior. The exponents of these power laws (called critical exponents) are of interest because they are often found to be universal, in the sense that the same set of exponents is often found to describe the phase transitions of completely different physical systems.
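Here we consider $t_{\rm conv}$, the number of steps until convergence, as a function of the learning rate. We will now show that $t_{\rm conv}$ exhibits power-law behavior when $\eta$ is close to $\eta_{\rm crit}$. For simplicity we consider the warmup model studied in Section 2. First, suppose that we are below the transition, setting $\eta\lambda_0 = 2 - \epsilon$ for some small $\epsilon > 0$. From the update equation,
$$f_{t+1} - y = (1 - \eta\lambda_t)(f_t - y) + \frac{\eta^2 x^2}{n}(f_t - y)^2 f_t,$$
we see that $|f_t - y| \approx (1 - \epsilon)^t |f_0 - y|$ will converge to some fixed small value $\delta$ after time $t \approx \epsilon^{-1}\log(|f_0 - y|/\delta)$. Here we assumed that $\lambda_t$ is constant in t, which is true as long as $\epsilon$ is independent of n (namely we fix $\epsilon$ and then take n large). Therefore, the convergence time below the transition scales as $t_{\rm conv} \sim \epsilon^{-1}$, and the critical exponent is -1.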
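The $\epsilon^{-1}$ scaling below the transition can be checked on the warmup model, again assuming the form $f = n^{-1/2}v^\top u\,x$ with $L=\tfrac12(f-y)^2$. At finite width the curvature drifts during training by an amount that, by our estimate, is of order $1/(n\epsilon)$, so the sketch uses moderate values of $\epsilon$, for which $\epsilon\, t_{\rm conv}$ should be roughly constant. The width 16,000 matches the experiment of Figure S12; the tolerance $\delta$, the seed and the helper name are illustrative.

```python
import numpy as np

def convergence_time(eps, n=16000, x=1.0, y=0.0, delta=1e-6, max_steps=100000, seed=0):
    """Steps until |f - y| < delta for the assumed warmup model at eta * lambda_0 = 2 - eps."""
    rng = np.random.default_rng(seed)
    u, v = rng.standard_normal(n), rng.standard_normal(n)
    eta = (2.0 - eps) / (x**2 * (u @ u + v @ v) / n)
    for t in range(max_steps):
        err = (v @ u) * x / np.sqrt(n) - y
        if abs(err) < delta:
            return t
        g = eta * x * err / np.sqrt(n)
        u, v = u - g * v, v - g * u
    return max_steps

for eps in (0.2, 0.1, 0.05):
    t = convergence_time(eps)
    print(f"eps = {eps}: t_conv = {t}, eps * t_conv = {eps * t:.2f}")
```

Halving $\epsilon$ should roughly double $t_{\rm conv}$, which is the critical exponent of -1 derived above.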
Next, suppose that $\eta\lambda_0 = 2 + \epsilon$. Now the update equation reads $f_{t+1} - y \approx -(1 + \epsilon)(f_t - y)$. This approximation holds early during training, when the curvature updates are small. Initially, $|f_t - y|$ will grow until it is of order $\sqrt{n}$, at which point the updates to $\lambda_t$ become of order 1. This will happen in time $t \approx \log(n)/(2\epsilon)$. Following this, the optimizer will converge. At this point $\eta\lambda_t$ is no longer tuned to be close to the transition, and so the convergence time measured from this point on will not be sensitive to $\epsilon$. Therefore, for small $\epsilon$ the convergence time will be dominated by the early part of training, namely $t_{\rm conv} \sim \epsilon^{-1}$ (with a prefactor of order $\log n$). The critical exponent is again -1. Figure S12 shows an empirical verification of this behavior.
Figure S12. The convergence time diverges when the learning rate is close to the critical value $\eta_{\rm crit} = 2/\lambda_0$, indicated by the solid green line. The measured exponents (shown in parentheses) are close to the predicted value of -1. The experiment uses the warmup model of Section 2 with width 16,000.
Here we present some more detailed evidence for the re-emergence of linear dynamics in the catapult phase. Figure S13 shows results for models trained on subsets of MNIST at various learning rates. In Figure S13a we see that for a one-hidden-layer fully connected model trained on 512 MNIST images, the performance of the full non-linear model and of the model linearized after 10 steps track each other closely. Models evolve as linear models when the NTK is constant. In Figure S13c we give evidence that as networks become wider, the change in the kernel decreases.
Figure S13. Evidence for a return of linear dynamics after . (a,b) Show the same model as in figure 4 with the addition of linearized models at step 0 and 10. We observe that the linearized model after 10 steps tracks the non-linear performance in the ‘catapult’ phase up to
(c) The change in the NTK between
steps decreases as the width increases. Here we consider 2-class MNIST with 100 samples per class.