Disentangling trainability and generalization in deep learning

2019·Arxiv

Abstract

A longstanding goal in the theory of deep learning is to characterize the conditions under which a given neural network architecture will be trainable, and if so, how well it might generalize to unseen data. In this work, we provide such a characterization in the limit of very wide and very deep networks, for which the analysis simplifies considerably. For wide networks, the trajectory under gradient descent is governed by the Neural Tangent Kernel (NTK), and for deep networks the NTK itself maintains only weak data dependence. By analyzing the spectrum of the NTK, we formulate necessary conditions for trainability and generalization across a range of architectures, including Fully Connected Networks (FCNs) and Convolutional Neural Networks (CNNs). We identify large regions of hyperparameter space for which networks can memorize the training set but completely fail to generalize. We find that CNNs without global average pooling behave almost identically to FCNs, but that CNNs with pooling have markedly different and often better generalization performance. These theoretical results are corroborated experimentally on CIFAR10 for a variety of network architectures and we include a colab notebook that reproduces the essential results of the paper.

1. Introduction

Machine learning models based on deep neural networks have attained state-of-the-art performance across a dizzying array of tasks including vision (Cubuk et al., 2019), speech recognition (Park et al., 2019), machine translation (Bahdanau et al., 2014), chemical property prediction (Gilmer et al., 2017), diagnosing medical conditions (Raghu et al., 2019), and playing games (Silver et al., 2018). Historically, the rampant success of deep learning models has lacked a sturdy theoretical foundation: architectures, hyperparameters, and learning algorithms are often selected by brute force search (Bergstra & Bengio, 2012) and heuristics (Glorot & Bengio, 2010). Recently, significant theoretical progress has been made on several fronts that have shown promise in making neural network design more systematic. In particular, in the infinite width (or channel) limit, the distribution of functions induced by neural networks with random weights and biases has been precisely characterized before, during, and after training.

The study of infinite networks dates back to seminal work by Neal (1994), who showed that the distribution of functions given by single hidden-layer networks with random weights and biases in the infinite-width limit is a Gaussian Process (GP). Recently, there has been renewed interest in studying random, infinite networks, starting with concurrent work on “conjugate kernels” (Daniely et al., 2016; Daniely, 2017) and “mean-field theory” (Poole et al., 2016; Schoenholz et al., 2017). Among numerous contributions, the pair of papers by Daniely et al. argued that the empirical covariance matrix of pre-activations becomes deterministic in the infinite-width limit and called this the conjugate kernel of the network. Meanwhile, from a mean-field perspective, the latter two papers studied the properties of these limiting kernels. In particular, the spectrum of the conjugate kernel of wide, fully-connected networks approaches a well-defined and data-independent limit when the depth exceeds a certain scale, $\xi$. Networks with tanh-nonlinearities (among other bounded activations) exhibit a phase transition between two limiting spectral distributions of the conjugate kernel as a function of their hyperparameters, with $\xi$ diverging at the transition. It was additionally hypothesized that networks were un-trainable when the conjugate kernel was sufficiently close to its limit.

Since then this analysis has been extended to include a wide range of architectures such as convolutions (Xiao et al., 2018), recurrent networks (Chen et al., 2018; Gilboa et al., 2019), networks with residual connections (Yang & Schoenholz, 2017), networks with quantized activations (Blumenfeld et al., 2019), the spectrum of the Fisher (Karakida et al., 2018), a range of activation functions (Hayou et al., 2018), and batch normalization (Yang et al., 2019). In each case, it was observed that the spectra of the kernels correlated strongly with whether or not the architectures were trainable. While these papers studied the properties of the conjugate kernels, especially the spectrum in the large-depth limit, a branch of concurrent work took a Bayesian perspective: that many networks converge to Gaussian Processes as their width becomes large (Lee et al., 2018; Matthews et al., 2018; Novak et al., 2019b; Garriga-Alonso et al., 2018; Yang, 2019). In this case, the conjugate kernel was referred to as the Neural Network Gaussian Process (NNGP) kernel, which is used to train neural networks in a fully Bayesian fashion. As such, the NNGP kernel characterizes the performance of the corresponding NNGP.

Together this work offered a significant advance to our understanding of wide neural networks; however, this theoretical progress was limited to networks at initialization or after Bayesian posterior estimation and provided no link to gradient descent. Moreover, there was some preliminary evidence that suggested the situation might be more nuanced than the qualitative link between the NNGP spectrum and trainability might suggest. For example, Philipp et al. (2017) showed that deep tanh FCNs could be trained after the kernel reached its large-depth, data-independent, limit but that these networks did not generalize to unseen data.

Recently, significant theoretical clarity has been reached regarding the relationship between the GP prior and the distribution following gradient descent. In particular, Jacot et al. (2018) along with followup work (Lee et al., 2019; Chizat et al., 2019) showed that the distribution of functions induced by gradient descent for infinite-width networks is a Gaussian Process with a particular compositional kernel known as the Neural Tangent Kernel (NTK). In addition to characterizing the distribution over functions following gradient descent in the wide network limit, the learning dynamics can be solved analytically throughout optimization.

In this paper, we leverage these developments and revisit the relationship between architecture, hyperparameters, trainability, and generalization in the large-depth limit for a variety of neural networks. In particular, we make the following contributions:

• Trainability. We compute the large-depth asymptotics of several quantities related to trainability, including the largest/smallest eigenvalue of the NTK, $\lambda_{\max/\min}$, and the condition number $\kappa = \lambda_{\max}/\lambda_{\min}$; see Table 1.

• Generalization. We characterize the mean predictor $P(\Theta)$, which is intimately related to the prediction of wide neural networks on the test set following gradient descent training, and therefore to the model’s ability to generalize. In particular, we argue that networks fail to generalize if the mean predictor becomes data-independent.

Table 1. Evolution of the NTK spectra and $P(\Theta^{(l)})$ as a function of depth $l$. The NTKs of FCN and CNN without pooling (CNN-F) are essentially the same, and the scaling of $\lambda^{(l)}_{\max}$, $\lambda^{(l)}_{\text{bulk}}$, $\kappa^{(l)}$, and $P(\Theta^{(l)})$ for these networks is written in black. Corrections to these quantities due to the addition of an average pooling layer (CNN-P) with window size $d$ are written in blue.

• We show that the ordered and chaotic phases identified in Poole et al. (2016) lead to markedly different limiting spectra of the NTK. In the ordered phase the trainability of neural networks degrades at large depths, but their ability to generalize persists. By contrast, in the chaotic phase we show that trainability improves with depth, but generalization degrades and neural networks behave like hash functions.

A corollary of these differences in the spectra is that, as a function of depth, the optimal learning rates ought to decay exponentially in the chaotic phase, decay linearly on the order-to-chaos transition line, and remain roughly constant in the ordered phase.

• We examine the differences in the above quantities for fully-connected networks (FCNs) and convolutional networks (CNNs) with and without pooling and precisely characterize the effect of pooling on the interplay between trainability, generalization, and depth.

In each case, we provide empirical evidence to support our theoretical conclusions. Together these results provide a complete, analytically tractable, and dataset-independent theory for learning in very deep and wide networks. Philosophically, we find that trainability and generalization are distinct notions that are, at least in this case, at odds with one another. Indeed, good conditioning of the NTK (which is a necessary condition for training) seems necessarily to lead to poor generalization performance. It will be interesting to see whether these results carry over in shallower and narrower networks. The tractable nature of the wide and deep regime leads us to conclude that these models will be an interesting testbed to investigate various theories of generalization in deep learning.

2. Related Work

Recent work (Jacot et al., 2018; Du et al., 2018b; Allen-Zhu et al., 2018; Du et al., 2018a; Zou et al., 2018) and many others proved global convergence of over-parameterized deep networks by showing that the NTK essentially remains constant over the course of training. However, in a different scaling limit the NTK changes over the course of training, and global convergence is much more difficult to obtain and is known for neural networks with one hidden layer (Mei et al., 2018; Chizat & Bach, 2018; Sirignano & Spiliopoulos, 2018; Rotskoff & Vanden-Eijnden, 2018). Therefore, understanding the training and generalization properties in this scaling limit remains a very challenging open question.

Two other excellent recent works (Hayou et al., 2019; Jacot et al., 2019) also study the dynamics of $\Theta^{(l)}(x, x')$ for FCNs (and deconvolutions in Jacot et al. (2019)) as a function of depth and of the variances of the weights and biases. Hayou et al. (2019) investigate the role of activation functions (smooth vs. non-smooth) and skip connections. Jacot et al. (2019) demonstrate that batch normalization helps remove the “ordered phase” (as in Yang et al. (2019)) and that a layer-dependent learning rate allows every layer in a network to contribute to learning.

3. Background

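We summarize recent developments in the study of wide random networks. We will keep our discussion relatively informal; see e.g. (Novak et al., 2019b) for a more rigorous version of these arguments. To simplify this discussion and as a warm-up for the main text, we will consider the case of FCNs. Consider a fully-connected network of depth $L$ where each layer has a width $N^{(l)}$ and an activation function $\phi : \mathbb{R} \to \mathbb{R}$. In the main text we will restrict our discussion to $\phi = \text{erf}$ or tanh for clarity; however, we include results for a range of architectures, including with and without skip connections and layer normalization, in the supplementary material (see Sec. C). We find that the high level picture described here applies to a wide range of architectural components, though important specifics, such as the phase diagram, can vary substantially. For simplicity, we will take the width of the hidden layers to infinity sequentially: $N^{(1)} \to \infty, \dots, N^{(L-1)} \to \infty$. The network is parameterized by weights and biases that we take to be randomly initialized with $W^{(l)}_{ij}, b^{(l)}_i \sim \mathcal{N}(0, 1)$, along with hyperparameters $\sigma_w$ and $\sigma_b$ that set the scale of the weights and biases respectively. Letting the $i$-th pre-activation in the $l$-th layer due to an input $x$ be given by $z^{(l)}_i(x)$, the network is then described by the recursion, for each layer $l$,

$$z^{(l+1)}_i(x) = \frac{\sigma_w}{\sqrt{N^{(l)}}} \sum_{j=1}^{N^{(l)}} W^{(l+1)}_{ij}\, \phi\big(z^{(l)}_j(x)\big) + \sigma_b\, b^{(l+1)}_i. \tag{1}$$

Notice that as $N^{(l)} \to \infty$, the sum ends up being over a large number of random variables and we can invoke the central limit theorem to conclude that the $\{z^{(l+1)}_i\}_i$ are i.i.d. Gaussian with zero mean. Given a dataset of $m$ points, the distribution over pre-activations can therefore be described completely by the covariance matrix, i.e. the NNGP kernel, between neurons in different inputs, $K^{(l)}(x, x') = \mathbb{E}[z^{(l)}_i(x)\, z^{(l)}_i(x')]$. Inspecting Equation 1, we see that $K^{(l+1)}$ can be computed in terms of $K^{(l)}$ as

$$K^{(l+1)}(x, x') = \sigma_w^2\, \mathbb{E}_{z^{(l)}_i \sim \mathcal{N}(0, K^{(l)})}\big[\phi(z^{(l)}_i(x))\, \phi(z^{(l)}_i(x'))\big] + \sigma_b^2 \equiv \sigma_w^2\, T(K^{(l)})(x, x') + \sigma_b^2, \tag{2}$$

where $T$ is an appropriately defined operator from the space of positive semi-definite matrices to itself.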

Equation 2 describes a dynamical system on positive semi-definite matrices $K$. It was shown in Poole et al. (2016) that fixed points, $K^*(x, x')$, of these dynamics exist such that $\lim_{l \to \infty} K^{(l)}(x, x') = K^*(x, x') = q^*[\delta_{x,x'} + c^*(1 - \delta_{x,x'})]$ with $q^*$ and $c^*$ independent of the inputs $x$ and $x'$. The values of $q^*$ and $c^*$ are determined by the hyperparameters, $\sigma_w$ and $\sigma_b$. However, Equation 2 admits multiple fixed points (e.g. $c^* = 0, 1$) and the stability of these fixed points plays a significant role in determining the properties of the network. Generically, there are large regions of the $(\sigma_w, \sigma_b)$ plane in which the fixed-point structure is constant, punctuated by curves, called phase transitions, where the structure changes; see Fig 5 for tanh-networks.
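The fixed-point behaviour of Equation 2 can be checked numerically. The sketch below is an illustration rather than the paper's code: it assumes $\phi = \text{erf}$, for which the Gaussian expectation has the standard closed form $\frac{2}{\pi}\arcsin\big(2K_{ab}/\sqrt{(1+2K_{aa})(1+2K_{bb})}\big)$, and it uses random inputs and two hyperparameter settings (`sigma_w2 = 1.0` and `3.0` with `sigma_b2 = 0.1`) chosen only for demonstration. The function names are hypothetical.

```python
import numpy as np

def erf_T(K):
    """T(K)[a, b] = E[erf(u) erf(v)] for (u, v) ~ N(0, [[K_aa, K_ab], [K_ab, K_bb]])."""
    d = 1.0 + 2.0 * np.diag(K)
    ratio = 2.0 * K / np.sqrt(np.outer(d, d))
    return (2.0 / np.pi) * np.arcsin(np.clip(ratio, -1.0, 1.0))

def nngp_at_depth(K0, sigma_w2, sigma_b2, depth):
    """Apply K <- sigma_w^2 * T(K) + sigma_b^2 (Equation 2) `depth` times."""
    K = K0.copy()
    for _ in range(depth):
        K = sigma_w2 * erf_T(K) + sigma_b2
    return K

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 16))      # illustrative random inputs (assumption)
K0 = X @ X.T / X.shape[1]         # input Gram matrix used as the starting kernel

# Two illustrative settings: the first ends with all entries equal (c* = 1),
# the second with off-diagonal entries strictly below the diagonal (c* < 1).
K_ordered = nngp_at_depth(K0, sigma_w2=1.0, sigma_b2=0.1, depth=200)
K_chaotic = nngp_at_depth(K0, sigma_w2=3.0, sigma_b2=0.1, depth=200)
print(np.round(K_ordered, 4))
print(np.round(K_chaotic, 4))
```

In both runs the deep kernel no longer depends on which inputs were used, which is the data-independence of $K^*$ described above.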

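The rate at which $K^{(l)}(x, x')$ approaches or departs $K^*(x, x')$ can be determined by expanding Equation 2 about its fixed point, $\delta K^{(l)}(x, x') = K^{(l)}(x, x') - K^*(x, x')$, to find

$$\delta K^{(l+1)}(x, x') \approx \sigma_w^2\, \dot{T}(K^*(x, x'))\, \delta K^{(l)}(x, x')$$

with $\dot{T}(K) = \mathbb{E}_{(z_1, z_2) \sim \mathcal{N}(0, K)}[\dot\phi(z_1)\dot\phi(z_2)]$, where $\dot\phi$ is the derivative of $\phi$. This expansion naturally exhibits exponential convergence to, or divergence from, the fixed-point as $\delta K^{(l)}(x, x') \sim \chi(x, x')^l$ where $\chi(x, x') = \sigma_w^2 \dot{T}(K^*(x, x'))$. Since $K^*(x, x')$ does not depend on $x$ or $x'$ it follows that $\chi(x, x')$ will take on a single value, $\chi_{c^*}$, whenever $x \neq x'$. If $\chi_{c^*} < 1$ then this fixed point is stable, but if $\chi_{c^*} > 1$ then the fixed point is unstable and, as discussed above, the system will converge to a different fixed point. If $\chi_{c^*} = 1$ then the hyperparameters lie at a phase transition and convergence is non-exponential. As was shown in Poole et al. (2016), there is always a fixed-point at $c^* = 1$ whose stability is determined by $\chi_1$. When this fixed point is stable ($\chi_1 < 1$) the network is in the so-called ordered phase, since any pair of inputs will converge to identical outputs. The line defined by $\chi_1 = 1$ is the order-to-chaos transition separating the ordered phase from the “chaotic” phase (where $\chi_1 > 1$). Note that $\chi_{c^*}$ can be used to define a depth-scale, $\xi_{c^*} = -1/\log(\chi_{c^*})$, that describes the number of layers over which $K^{(l)}$ approaches $K^*$.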
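To make the phase classification concrete, the following sketch computes $q^*$, the stability coefficient of the $c^* = 1$ fixed point, and a depth scale of the form $-1/\log\chi$. It assumes an erf activation, for which $\mathbb{E}[\text{erf}(z)^2] = \frac{2}{\pi}\arcsin\frac{2q}{1+2q}$ and $\mathbb{E}[\text{erf}'(z)^2] = \frac{4}{\pi}(1+4q)^{-1/2}$ for $z \sim \mathcal{N}(0, q)$. The function names and the hyperparameter values are illustrative assumptions.

```python
import numpy as np

def q_star(sigma_w2, sigma_b2, iters=500):
    """Fixed point of q <- sigma_w^2 * E[erf(z)^2] + sigma_b^2 with z ~ N(0, q)."""
    q = 1.0
    for _ in range(iters):
        q = sigma_w2 * (2.0 / np.pi) * np.arcsin(2.0 * q / (1.0 + 2.0 * q)) + sigma_b2
    return q

def chi_1(sigma_w2, sigma_b2):
    """chi_1 = sigma_w^2 * E[erf'(z)^2] evaluated at z ~ N(0, q*)."""
    q = q_star(sigma_w2, sigma_b2)
    return sigma_w2 * (4.0 / np.pi) / np.sqrt(1.0 + 4.0 * q)

def depth_scale(chi):
    """Depth scale xi = -1 / log(chi), meaningful for 0 < chi < 1."""
    return -1.0 / np.log(chi)

for sigma_w2 in (1.0, 3.0):               # illustrative values with sigma_b^2 = 0.1
    chi = chi_1(sigma_w2, 0.1)
    phase = "ordered" if chi < 1.0 else "chaotic"
    print(f"sigma_w^2 = {sigma_w2}: chi_1 = {chi:.3f} ({phase})")
```

Values of `chi_1` below one place the hyperparameters in the ordered phase and values above one in the chaotic phase; `depth_scale` then converts a contraction rate into a number of layers.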

This provides a precise characterization of the NNGP kernel at large depths. As discussed above, recent work (Jacot et al., 2018; Lee et al., 2019; Chizat et al., 2019) has connected the prior described by the NNGP with the result of gradient descent training using a quantity called the NTK. To construct the NTK, suppose we enumerate all the parameters in the fully-connected network described above by $\theta_\alpha$. The finite width NTK is defined by $\hat\Theta(x, x') = J(x) J(x')^T$ where $J_{i\alpha}(x) = \partial_{\theta_\alpha} z^{L}_i(x)$ is the Jacobian evaluated at a point $x$. The main result in Jacot et al. (2018) was to show that in the infinite-width limit, the NTK converges to a deterministic kernel $\Theta$ and remains constant over the course of training. As such, at a time $t$ during gradient descent training with an MSE loss, the expected outputs of an infinitely wide network, $\mu_t(x) = \mathbb{E}[z^{L}_i(x)]$, evolve as

$$\mu_t(X_{\text{train}}) = \big(\mathrm{Id} - e^{-\eta \Theta_{\text{train,train}} t}\big) Y_{\text{train}} \tag{5}$$

$$\mu_t(X_{\text{test}}) = \Theta_{\text{test,train}} \Theta_{\text{train,train}}^{-1} \big(\mathrm{Id} - e^{-\eta \Theta_{\text{train,train}} t}\big) Y_{\text{train}} \tag{6}$$

for train and test points respectively; see Section 2 in Lee et al. (2019). Here $\Theta_{\text{test,train}}$ denotes the NTK between the test inputs $X_{\text{test}}$ and training inputs $X_{\text{train}}$, and $\Theta_{\text{train,train}}$ is defined similarly. Since $\hat\Theta$ converges to $\Theta$ as the network’s width approaches infinity, the gradient flow dynamics of real networks also converge to the dynamics described by Equation 5 and Equation 6 (Jacot et al., 2018; Lee et al., 2019; Chizat et al., 2019; Yang, 2019; Arora et al., 2019; Huang & Yau, 2019). As the training time, $t$, tends to infinity we note that these equations reduce to $\mu(X_{\text{train}}) = Y_{\text{train}}$ and $\mu(X_{\text{test}}) = \Theta_{\text{test,train}} \Theta_{\text{train,train}}^{-1} Y_{\text{train}}$. Consequently we call

$$P(\Theta) \equiv \Theta_{\text{test,train}} \Theta_{\text{train,train}}^{-1} \tag{7}$$

the “mean predictor”. We can also compute the mean predictor of the NNGP kernel, $P(K)$, which analogously can be used to find the mean of the posterior after Bayesian inference. We will discuss the connection between the mean predictor and generalization in the next section.
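As an illustration of the linearized dynamics and the mean predictor, the sketch below evaluates the train and test means at a finite time through an eigendecomposition of the train-train kernel and compares the long-time limit against $\Theta_{\text{test,train}}\Theta_{\text{train,train}}^{-1} Y_{\text{train}}$. It is a sketch under stated assumptions: a Gaussian kernel stands in for the NTK, the data are random, and the names `mean_predictions` and `stand_in_kernel` are hypothetical.

```python
import numpy as np

def mean_predictions(theta_tt, theta_st, y_train, eta, t):
    """Mean train/test outputs at time t, computed in the eigenbasis of theta_tt."""
    evals, evecs = np.linalg.eigh(theta_tt)
    y_eig = evecs.T @ y_train                     # labels in the eigenbasis
    decay = 1.0 - np.exp(-eta * evals * t)        # (Id - exp(-eta * Theta * t)), per mode
    mu_train = evecs @ (decay * y_eig)
    mu_test = theta_st @ (evecs @ (decay / evals * y_eig))
    return mu_train, mu_test

def stand_in_kernel(A, B):
    """Gaussian kernel used only as a positive-definite placeholder for an NTK."""
    sq = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq)

rng = np.random.default_rng(0)
X_train, X_test = rng.normal(size=(8, 3)), rng.normal(size=(4, 3))
y_train = rng.normal(size=8)
theta_tt = stand_in_kernel(X_train, X_train)
theta_st = stand_in_kernel(X_test, X_train)
eta = 1.0 / np.linalg.eigvalsh(theta_tt)[-1]      # illustrative learning rate

mu_train, mu_test = mean_predictions(theta_tt, theta_st, y_train, eta, t=1e6)
mean_predictor = theta_st @ np.linalg.inv(theta_tt)   # the matrix P(Theta)
print(np.max(np.abs(mu_test - mean_predictor @ y_train)))
```

At long times the training mean reproduces the labels and the test mean is the mean predictor applied to the labels, so the test-time behaviour of the infinitely wide network is determined by that single matrix.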

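In addition to showing that the NTK describes networks during gradient descent, Jacot et al. (2018) showed that the NTK could be computed in closed form in terms of $T$, $\dot{T}$, and the NNGP as,

$$\Theta^{(l+1)}(x, x') = K^{(l+1)}(x, x') + \sigma_w^2\, \dot{T}(K^{(l)}(x, x'))\, \Theta^{(l)}(x, x'), \tag{8}$$

where $\Theta^{(l)}$ is the NTK for the pre-activations at layer-$l$.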
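The layer-wise recursion for the NTK can be run numerically alongside the NNGP recursion. The sketch below is illustrative only: it assumes an erf activation (with the closed forms for $T$ and $\dot{T}$ given in the docstrings), inputs normalized so that every diagonal kernel entry equals $q^*$, and the base case $\Theta^{(1)} = K^{(1)}$. The function names and hyperparameter values are hypothetical choices.

```python
import numpy as np

def erf_T(K):
    """E[erf(u) erf(v)] for (u, v) ~ N(0, [[K_aa, K_ab], [K_ab, K_bb]])."""
    d = 1.0 + 2.0 * np.diag(K)
    return (2.0 / np.pi) * np.arcsin(np.clip(2.0 * K / np.sqrt(np.outer(d, d)), -1.0, 1.0))

def erf_T_dot(K):
    """E[erf'(u) erf'(v)] under the same Gaussian."""
    d = 1.0 + 2.0 * np.diag(K)
    return (4.0 / np.pi) / np.sqrt(np.outer(d, d) - 4.0 * K ** 2)

def q_star(sigma_w2, sigma_b2, iters=500):
    """Fixed point of the diagonal kernel entry."""
    q = 1.0
    for _ in range(iters):
        q = sigma_w2 * (2.0 / np.pi) * np.arcsin(2.0 * q / (1.0 + 2.0 * q)) + sigma_b2
    return q

def ntk_condition_number(C, sigma_w2, sigma_b2, depth):
    """Condition number of Theta^(depth) for inputs with correlation matrix C."""
    K = q_star(sigma_w2, sigma_b2) * C            # normalization: K_aa = q*
    theta = K.copy()                              # assumed base case Theta^(1) = K^(1)
    for _ in range(depth - 1):
        K_next = sigma_w2 * erf_T(K) + sigma_b2             # NNGP recursion
        theta = K_next + sigma_w2 * erf_T_dot(K) * theta    # NTK recursion
        K = K_next
    evals = np.linalg.eigvalsh(theta)
    return evals[-1] / evals[0]

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 16))
X /= np.linalg.norm(X, axis=1, keepdims=True)
C = X @ X.T                                       # illustrative correlation matrix

for depth in (10, 20, 40):
    print(depth,
          ntk_condition_number(C, 1.0, 0.1, depth),   # a setting with chi_1 < 1
          ntk_condition_number(C, 3.0, 0.1, depth))   # a setting with chi_1 > 1
```

With these illustrative settings the condition number grows rapidly with depth when $\chi_1 < 1$ and approaches one when $\chi_1 > 1$.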

4. Metrics for Trainability and Generalization at Large Depth

We begin by discussing the interplay between the conditioning of $\Theta_{\text{train,train}}$ and the trainability of wide networks. We can write Equation 5 in terms of the spectrum of $\Theta_{\text{train,train}}$. To do this we write the eigendecomposition of $\Theta_{\text{train,train}}$ as $\Theta_{\text{train,train}} = U^T D U$ with $D$ a diagonal matrix of eigenvalues and $U$ a unitary matrix. In this case Equation 5 can be written as,

$$\tilde\mu_t(X_{\text{train}})_i = \big(1 - e^{-\eta \lambda_i t}\big)\, \tilde{Y}_{\text{train},i} \tag{9}$$

where $\lambda_i$ are the eigenvalues of $\Theta_{\text{train,train}}$ and $\tilde\mu_t(X_{\text{train}}) = U \mu_t(X_{\text{train}})$, $\tilde{Y}_{\text{train}} = U Y_{\text{train}}$ are the mean prediction and the labels respectively written in the eigenbasis of $\Theta_{\text{train,train}}$. If we order the eigenvalues such that $\lambda_0 \ge \dots \ge \lambda_{m-1}$ then it has been hypothesized in e.g. Lee et al. (2019) that the maximum feasible learning rate scales as $\eta \sim 2/\lambda_0$, as we verify empirically in Section 7. Plugging this scaling for $\eta$ into Equation 9 we see that the mode associated with the smallest eigenvalue will converge exponentially at a rate given by $1/\kappa$, where $\kappa = \lambda_0/\lambda_{m-1}$ is the condition number. It follows that if the condition number of the NTK associated with a neural network diverges then it will become untrainable, and so we use $\kappa$ as a metric for trainability.
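The step from the learning-rate scaling to the $1/\kappa$ convergence rate can be written out explicitly. The short derivation below is a sketch that assumes the learning rate is set exactly to $\eta = 2/\lambda_{\max}$, writing $\lambda_{\max}$ and $\lambda_{\min}$ for the largest and smallest eigenvalues of the train-train NTK; the tolerance $\epsilon$ is introduced here for illustration.

```latex
% Slowest mode of the eigenbasis dynamics under the assumed rate \eta = 2/\lambda_{\max}
\begin{aligned}
1 - e^{-\eta \lambda_{\min} t}
  &= 1 - e^{-2 (\lambda_{\min}/\lambda_{\max})\, t}
   = 1 - e^{-2t/\kappa},
  \qquad \kappa = \frac{\lambda_{\max}}{\lambda_{\min}}, \\
e^{-2 t_\epsilon/\kappa} = \epsilon
  \quad &\Longrightarrow \quad
  t_\epsilon = \frac{\kappa}{2} \log\frac{1}{\epsilon}.
\end{aligned}
```

The time needed to fit the slowest mode to tolerance $\epsilon$ therefore grows linearly in $\kappa$, which is why a diverging condition number makes training impractical.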

We will see that at large depths, the spectrum of $\Theta_{\text{train,train}}$ typically features a single large eigenvalue, $\lambda_{\max}$, and then a gap that is large compared with the rest of the spectrum. We therefore will often refer to a typical eigenvalue in the bulk as $\lambda_{\text{bulk}}$ and approximate the condition number as $\kappa = \lambda_{\max}/\lambda_{\text{bulk}}$.

We now turn our attention to generalization. At large depths, we will see that $\Theta_{\text{test,train}}$ and $\Theta_{\text{train,train}}$ converge to their fixed points independent of the data distribution. Consequently it is often the case that $P(\Theta^*)$ will be data-independent and the network will fail to generalize. In this case, by symmetry, it is necessarily true that $P(\Theta^*)$ will be a constant matrix. Contracting this matrix with a vector of labels $Y_{\text{train}}$ that have been standardized to have zero mean, it will follow that $P(\Theta^*) Y_{\text{train}} = 0$ and the network will output zero in expectation on all test points. Clearly, in this setting the network will not be able to generalize. At large, but finite, depths the generalization performance of the network can be quantified by considering the rate at which $P(\Theta^{(l)}) Y_{\text{train}}$ decays to zero. There are cases, however, where despite the data-independence of $\Theta^*$, $\lim_{l \to \infty} P(\Theta^{(l)}) Y_{\text{train}}$ remains nonzero and the network can continue to generalize even in the asymptotic limit. In either case, we will show that precisely characterizing $P(\Theta^{(l)}) Y_{\text{train}}$ allows us to understand exactly where networks can, and cannot, generalize.

Our goal is therefore to characterize the evolution of the two metrics $\kappa^{(l)}$ and $P(\Theta^{(l)})$ in $l$. We follow the methodology outlined in Schoenholz et al. (2017); Xiao et al. (2018) to explore the spectrum of the NTK as a function of depth. We will use this to make precise predictions relating trainability and generalization to the hyperparameters $(\sigma_w, \sigma_b, l)$. Our main results are summarized in Table 1, which describes the evolution of $\lambda^{(l)}_{\max}$ (the largest eigenvalue of $\Theta^{(l)}$), $\lambda^{(l)}_{\text{bulk}}$ (the remaining eigenvalues), $\kappa^{(l)}$, and $P(\Theta^{(l)})$ as a function of depth for three different network configurations (the ordered phase, the chaotic phase, and the phase transition). We study the dependence on: the size of the training set, $m$; the choices of architecture including fully-connected networks (FCN), convolutional networks with flattening (CNN-F), and convolutions with pooling (CNN-P); and the size, $d$, of the window in the pooling layer (which we always take to be the penultimate layer).

Before discussing the methodology it is useful to first give a qualitative overview of the phenomenology. We find identical phenomenology between FCNs and CNN-F architectures. In the ordered phase, $\Theta^{(l)} \to p^* \mathbf{1}\mathbf{1}^T$, $\lambda^{(l)}_{\max} \to m p^*$, and $\lambda^{(l)}_{\text{bulk}} = O(l \chi_1^l)$. At large depths, since $\chi_1 < 1$, it follows that $\kappa^{(l)} \gtrsim m p^* / (l \chi_1^l)$ and so the condition number diverges exponentially quickly. Thus, in the ordered phase we expect networks not to be trainable (or, specifically, the time they take to learn will grow exponentially in their depth). Here $P(\Theta^{(l)})$ converges to a data-dependent constant independent of depth; thus, in the ordered phase networks fail to train but can generalize indefinitely.

By contrast, in the chaotic phase we see that there is no gap between $\lambda^{(l)}_{\max}$ and $\lambda^{(l)}_{\text{bulk}}$, and networks become perfectly conditioned and are trainable everywhere. However, in this regime we see that the mean predictor scales as $l (\chi_{c^*}/\chi_1)^l$. Since in the chaotic phase $\chi_{c^*} < 1$ and $\chi_1 > 1$, it follows that $P(\Theta^{(l)}) \to 0$ over a depth $\xi_* = -1/\log(\chi_{c^*}/\chi_1)$. Thus, in the chaotic phase, networks fail to generalize at a finite depth but remain trainable indefinitely. Finally, introducing pooling modestly augments the depth over which networks can generalize in the chaotic phase but reduces the depth in the ordered phase. We will explore all of these predictions in detail in Section 7.

5. A Toy Example: RBF Kernel

To provide more intuition about our analysis, we present a toy example using RBF kernels, which already exhibits some of the core observations for deep neural networks. Consider a Gaussian process along with the RBF kernel given by,

$$K_h(x, x') = \exp\left(-\frac{\|x - x'\|_2^2}{h}\right),$$

where $x, x' \in X_{\text{train}}$ along with a bandwidth $h > 0$. Note that $K_h(x, x) = 1$ for all $h$ and $x$. Consider the following two cases.

If the bandwidth is given by $h = 2^l$ and $l \to \infty$, then $K_h(x, x') \approx 1 - 2^{-l}\|x - x'\|_2^2$, which converges to 1 exponentially fast. Thus, the largest eigenvalue of $K_h$ is $\lambda_{\max} \approx |X_{\text{train}}|$ and the bulk is of order $\lambda_{\text{bulk}} \approx 2^{-l}$. Thus the condition number $\kappa \gtrsim 2^l$, which diverges with $l$. We will see that $\Theta^{(l)}$ in the Ordered Phase behaves qualitatively similarly to this setting.

On the other hand, if the bandwidth is given by $h = 1/l$ and $l \to \infty$, then the off-diagonals $K_h(x, x') = \exp(-l\|x - x'\|_2^2) \to 0$. For large $l$, $K_h$ is very close to the identity matrix and its condition number is almost 1. In the Chaotic Phase, $\Theta^{(l)}$ is qualitatively similar to $K_h$.
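The two regimes of the toy example are easy to reproduce numerically. The sketch below assumes the kernel form $\exp(-\|x - x'\|^2/h)$, uses an exponentially growing bandwidth `2**l` for the wide case (a choice made here for illustration) and `1/l` for the narrow case, and draws a handful of random points; the helper names are hypothetical.

```python
import numpy as np

def rbf_kernel(X, h):
    """K_h(x, x') = exp(-||x - x'||^2 / h) for all pairs of rows of X."""
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / h)

def condition_number(K):
    """Ratio of largest to smallest eigenvalue of a symmetric positive-definite matrix."""
    evals = np.linalg.eigvalsh(K)      # eigenvalues in ascending order
    return evals[-1] / evals[0]

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 8))            # 6 illustrative points in 8 dimensions

# Wide bandwidth: the kernel approaches the all-ones matrix and kappa blows up.
kappa_wide = {l: condition_number(rbf_kernel(X, 2.0 ** l)) for l in (5, 10, 15)}
# Narrow bandwidth: the kernel approaches the identity and kappa approaches 1.
kappa_narrow = {l: condition_number(rbf_kernel(X, 1.0 / l)) for l in (5, 10, 15)}

print(kappa_wide)
print(kappa_narrow)
```

The wide-bandwidth kernel has one eigenvalue near the number of points and a collapsing bulk, while the narrow-bandwidth kernel is close to the identity.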

6. Large-Depth Asymptotics of the NNGP and NTK

We now give a brief derivation of the results in Table 1. Details can be found in Secs. B and D of the appendix. To simplify notation we will discuss fully-connected networks and then extend the results to CNNs with pooling (CNN-P) and without pooling (CNN-F).

As in Sec. 3, we will be concerned with the fixed points of $\Theta$ as well as the linearization of Equation 8 about its fixed point. Recall that the fixed point structure is invariant within a phase, so it suffices to consider the ordered phase, the chaotic phase, and the critical line separately. In cases where a stable fixed point exists, we will describe how $\Theta$ converges to the fixed point. We will see that in the chaotic phase and on the critical line, $\Theta$ has no stable fixed point, and in that case we will describe its divergence. As above, in each case the fixed points of $\Theta$ have a simple structure with $\Theta^* = p^*\big((1 - c^*)\mathrm{Id} + c^* \mathbf{1}\mathbf{1}^T\big)$.

To simplify the forthcoming analysis, without loss of generality, we assume the inputs are normalized to have variance $q^*$. As such, we can treat $T$ and $\dot{T}$, restricted to $\{K^{(l)}\}_l$, as point-wise functions. To see this, note that with this normalization $K^{(l)}(x, x) = q^*$ for all $l$ and $x$. It follows that both $T(K^{(l)})(x, x')$ and $\dot{T}(K^{(l)})(x, x')$ depend only on $K^{(l)}(x, x')$.

Since all of the off-diagonal elements approach the same fixed point at the same rate, we use $q^{(l)}_{ab}$ and $p^{(l)}_{ab}$ to denote any off-diagonal entry of $K^{(l)}$ and $\Theta^{(l)}$ respectively. We will similarly use $q^*_{ab}$ and $p^*_{ab}$ to denote the limits, $\lim_{l \to \infty} q^{(l)}_{ab} = q^*_{ab}$ and $\lim_{l \to \infty} p^{(l)}_{ab} = p^*_{ab}$. Finally, although the diagonal entries of $K^{(l)}$ are all $q^*$, the diagonal entries of $\Theta^{(l)}$ can vary and we denote them $p^{(l)}$.

Figure 1. Condition number and mean predictor of NTKs and their rate of convergence for FCN, CNN-F and CNN-P. (a) In the chaotic phase, $\kappa^{(l)}$ converges to 1 for all architectures. (b) We plot $\chi_1^l \kappa^{(l)}$, confirming that $\kappa^{(l)}$ explodes with rate $1/(l\chi_1^l)$ in the ordered phase. In (c) and (d), the solid lines are $\kappa^{(l)}$ and dashed lines are the ratio between first and second eigenvalues. We see that, on the order-to-chaos transition, these two numbers converge to $\frac{m+2}{2}$ and $\frac{dm+2}{2}$ (horizontal lines) for FCN/CNN-F and CNN-P respectively, where m = 12 or 20 is the batch size and d = 36 is the spatial dimension. (e) In the chaotic phase, the mean predictor decays to zero exponentially fast. (f) In the ordered phase the mean predictor converges to a data dependent value.

In what follows, we split the discussion into three sections according to the values of $\chi_1$, recalling that in Poole et al. (2016); Schoenholz et al. (2017) it was shown that $\chi_1$ controls the fixed point structure. In each section, we analyze the evolution of (1) the entries of $\Theta^{(l)}$, i.e. $p^{(l)}$ and $p^{(l)}_{ab}$, (2) the spectrum $\lambda^{(l)}_{\max}$ and $\lambda^{(l)}_{\text{bulk}}$, (3) the trainability and generalization metrics $\kappa^{(l)}$ and $P(\Theta^{(l)})$, and finally (4) discuss the impact on finite width networks.

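6.1. The Chaotic Phase $\chi_1 > 1$

The chaotic phase is so-named because it has a stable fixed-point $c^* < 1$; as such similar inputs become increasingly uncorrelated as they pass through the network. Our first result is to show that (see Sec. B.1),

$$q^{(l)}_{ab} = q^*_{ab} + O(\chi_{c^*}^l), \qquad p^{(l)} = q^* \frac{\chi_1^l - 1}{\chi_1 - 1}, \qquad p^{(l)}_{ab} = p^*_{ab} + O(l \chi_{c^*}^l),$$

where

$$p^*_{ab} = \frac{q^*_{ab}}{1 - \chi_{c^*}} \quad \text{and} \quad \chi_{c^*} = \sigma_w^2 \dot{T}(q^*_{ab}).$$

Note that $\chi_{c^*}$ controls the convergence of the $q^{(l)}_{ab}$ and is always less than 1 in the chaotic phase (Poole et al., 2016; Schoenholz et al., 2017; Xiao et al., 2018). Since $\chi_1 > 1$, $p^{(l)}$ diverges with rate $\chi_1^l$ while $p^{(l)}_{ab}$ remains finite. It follows that $\chi_1^{-l} \Theta^{(l)} \to \frac{q^*}{\chi_1 - 1} \mathrm{Id}$ as $l \to \infty$. Thus, in the chaotic phase, the spectrum of the NTK for very deep networks approaches a diverging constant multiplying the identity. This implies

$$\lambda^{(l)}_{\max},\ \lambda^{(l)}_{\text{bulk}} \sim \chi_1^l \quad \text{and} \quad \kappa^{(l)} \to 1.$$

Figure 1a plots the evolution of $\kappa^{(l)}$ in this phase, confirming $\kappa^{(l)} \to 1$ for all three different architectures (FCN, CNN-F and CNN-P).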

We now describe the asymptotic behavior of the mean predictor. Since $\Theta_{\text{test,train}}$ has no diagonal elements, it follows that it remains finite at large depths and so $P(\Theta^{(l)}) Y_{\text{train}} \to 0$. It follows that in the chaotic phase, the predictions of asymptotically deep neural networks on unseen test points will converge to zero exponentially quickly (see Sec. D.1),

$$P(\Theta^{(l)}) Y_{\text{train}} = O\big(l (\chi_{c^*}/\chi_1)^l\big).$$

Neglecting the relatively slowly varying polynomial term, this implies that we expect chaotic networks to fail to generalize when their depth is much larger than a scale set by $\xi_* = -1/\log(\chi_{c^*}/\chi_1)$. We confirm this scaling in Fig 1e.

We confirm these predictions for finite-width neural network training using SGD as well as gradient-flow on infinite networks in the experimental results; see Fig 2.

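6.2. The Ordered Phase $\chi_1 = \sigma_w^2 \dot{T}(q^*) < 1$

The ordered phase is defined by the stability of the $c^* = 1$ fixed point. Here disparate inputs will end up converging to the same output at the end of the network. We show in Sec. B.2 that elements of the NNGP kernel and NTK have asymptotic dynamics given by,

$$q^{(l)}_{ab} = q^* + O(\chi_1^l), \qquad p^{(l)} = p^* + O(\chi_1^l), \qquad p^{(l)}_{ab} = p^* + O(l \chi_1^l), \tag{14}$$

where $p^* = q^*/(1 - \chi_1)$. Here all of the entries of $\Theta^{(l)}$ converge to the same value, $p^*$, and the limiting kernel has the form $\Theta^* = p^* \mathbf{1}_m \mathbf{1}_m^T$, where $\mathbf{1}_m$ is the all-ones vector of dimension $m$ (typically $m$ will correspond to the number of datapoints in the training set). The NNGP kernel has the same structure with $p^*$ replaced by $q^*$. Consequently both the NNGP kernel and the NTK are highly singular and feature a single non-zero eigenvalue, $\lambda_{\max} = m p^*$, with eigenvector $\mathbf{1}_m$.

For large-but-finite depths, $\Theta^{(l)}$ has (approximately) two eigenspaces: the first eigenspace corresponds to finite-depth corrections to $\lambda_{\max}$,

$$\lambda^{(l)}_{\max} \approx (m - 1)\, p^{(l)}_{ab} + p^{(l)} = m p^* + O(l \chi_1^l).$$

The second eigenspace comes from lifting the degenerate zero-modes; it has dimension $(m - 1)$ with eigenvalues that scale like $\lambda^{(l)}_{\text{bulk}} = O(p^{(l)} - p^{(l)}_{ab}) = O(l \chi_1^l)$. It follows that $\kappa^{(l)} \gtrsim (l \chi_1^l)^{-1}$ and so the condition number explodes exponentially quickly. We confirm the presence of the $1/l$ correction term in $\kappa^{(l)}$ by plotting $\chi_1^l \kappa^{(l)}$ against $l$ in Figure 1b. Neglecting this correction, we expect networks in the ordered phase to become untrainable when their depth exceeds a scale given by $\xi_1 = -1/\log \chi_1$.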
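The two eigenspaces can be read off from an idealized kernel in which every diagonal entry equals $p$ and every off-diagonal entry equals $p_{ab}$. The worked equation below is a sketch under that uniformity assumption (the true finite-depth kernel carries data-dependent corrections); $\mathbf{1}_m$ is the all-ones vector and $v$ is any vector orthogonal to it.

```latex
\begin{aligned}
\Theta &= (p - p_{ab})\,\mathrm{Id} + p_{ab}\,\mathbf{1}_m \mathbf{1}_m^T, \\
\Theta\, \mathbf{1}_m &= \big(p + (m-1)\,p_{ab}\big)\, \mathbf{1}_m
  && \Rightarrow\ \lambda_{\max} = p + (m-1)\,p_{ab}, \\
\Theta\, v &= (p - p_{ab})\, v \quad \text{for } \mathbf{1}_m^T v = 0
  && \Rightarrow\ \lambda_{\text{bulk}} = p - p_{ab} \ \ \text{(multiplicity } m-1\text{)}, \\
\kappa &= \frac{\lambda_{\max}}{\lambda_{\text{bulk}}} = \frac{p + (m-1)\,p_{ab}}{p - p_{ab}} .
\end{aligned}
```

As the diagonal and off-diagonal entries approach a common value the denominator vanishes, which is the mechanism behind the diverging condition number in this phase.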

We now turn our discussion to the mean predictor. Equation 14 shows that we can write the finite-depth corrections to the NTK as $\Theta^{(l)} = p^* \mathbf{1}\mathbf{1}^T + A^{(l)}\, l \chi_1^l$. Here $A^{(l)}$ is the data-dependent piece that lifts the zero eigenvalues. In the appendix we show that $A^{(l)}$ converges to a limit $A$ as $l \to \infty$; see Lemma 2. In Sec. D.3 we show that despite the singular nature of $\Theta^*$, the mean has a well-defined limit as,

$$\lim_{l \to \infty} P(\Theta^{(l)})\, Y_{\text{train}} = \big(A_{\text{test,train}} A_{\text{train,train}}^{-1} + \hat{A}\big)\, Y_{\text{train}},$$

where $\hat{A}$ is some correction term. Thus, the mean predictor remains well-behaved and data dependent even in the infinite-depth limit, and we suspect that networks in the ordered phase should be able to generalize whenever they can be trained. We confirm the asymptotic data-dependence of the mean predictor in Fig 1f.

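6.3. The Critical Line $\chi_1 = \sigma_w^2 \dot{T}(q^*) = 1$

On the critical line the $c^* = 1$ fixed point is marginally stable and the dynamics become power-law. Here, both the diagonal and the off-diagonal elements of $\Theta^{(l)}$ diverge linearly in the depth with $\frac{1}{l}\Theta^{(l)} \to \frac{q^*}{3}\big(\mathbf{1}\mathbf{1}^T + 2\,\mathrm{Id}\big)$. The condition number $\kappa^{(l)}$ converges to a finite value and the network is always trainable. However, the mean predictor decays as $l^{-1}$. In particular we show in Sec. B.3,

$$q^{(l)}_{ab} = q^* + O(l^{-1}), \qquad p^{(l)} = l q^*, \qquad p^{(l)}_{ab} = \tfrac{1}{3}\, l q^* + O(1).$$

For large $l$ it follows that $\Theta^{(l)}$ essentially has two eigenspaces: one has dimension one and the other has dimension $(m - 1)$, with

$$\lambda^{(l)}_{\max} = \frac{m + 2}{3}\, q^* l + O(1), \qquad \lambda^{(l)}_{\text{bulk}} = \frac{2}{3}\, q^* l + O(1).$$

It follows that the condition number $\kappa^{(l)} = \frac{m+2}{2} + m\, O(l^{-1})$. Unlike in the chaotic and ordered phases, here $\kappa^{(l)}$ converges with rate $O(l^{-1})$. Figure 1c confirms $\kappa^{(l)} \to \frac{m+2}{2}$ for both FCN and CNN-F (the global average pooling in CNN-P introduces a correction term that we will discuss below). A similar calculation gives $P(\Theta^{(l)}) Y_{\text{train}} = O(l^{-1})$ on the critical line.

In summary, $\kappa^{(l)}$ converges to a finite number and the network ought to be trainable for arbitrary depth, but the mean predictor decays as a power law. Decay as $l^{-1}$ is much slower than exponential and is slow on the scale of neural networks. This explains why critically initialized networks with thousands of layers could still generalize (Xiao et al., 2018).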
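The finite limit of the condition number on the critical line can be checked with a scalar version of the kernel recursions. The sketch below is illustrative and rests on several assumptions: an erf activation, a uniform kernel with a single diagonal value and a single off-diagonal value, a hand-picked $q^* = 0.5$ from which $\sigma_w^2$ and $\sigma_b^2$ are solved so that the stability coefficient of the $c^* = 1$ fixed point equals one, and the limit $(m+2)/2$ that follows from a diagonal growing like $l q^*$ and an off-diagonal growing like $l q^*/3$. All function names are hypothetical.

```python
import math

def critical_hyperparameters(q_star):
    """Pick (sigma_w^2, sigma_b^2) so that q_star is a fixed point with chi_1 = 1 (erf)."""
    sigma_w2 = (math.pi / 4.0) * math.sqrt(1.0 + 4.0 * q_star)
    sigma_b2 = q_star - sigma_w2 * (2.0 / math.pi) * math.asin(2.0 * q_star / (1.0 + 2.0 * q_star))
    return sigma_w2, sigma_b2

def critical_condition_number(m, depth, q_star=0.5, c0=0.5):
    """kappa of an m x m kernel with uniform diagonal p and uniform off-diagonal p_ab."""
    sigma_w2, sigma_b2 = critical_hyperparameters(q_star)
    d = 1.0 + 2.0 * q_star
    q_ab = c0 * q_star            # off-diagonal NNGP entry at the first layer
    p, p_ab = q_star, q_ab        # assumed base case: NTK equals NNGP at the first layer
    for _ in range(depth - 1):
        chi_ab = sigma_w2 * (4.0 / math.pi) / math.sqrt(d * d - 4.0 * q_ab * q_ab)
        q_ab = sigma_w2 * (2.0 / math.pi) * math.asin(2.0 * q_ab / d) + sigma_b2
        p_ab = q_ab + chi_ab * p_ab   # NTK recursion, off-diagonal entry
        p = q_star + p                # NTK recursion on the diagonal with chi_1 = 1
    lam_max = p + (m - 1) * p_ab      # eigenvalue along the all-ones direction
    lam_bulk = p - p_ab               # eigenvalue of the remaining m - 1 directions
    return lam_max / lam_bulk

m = 5
for depth in (10, 100, 1000, 10000):
    print(depth, critical_condition_number(m, depth), (m + 2) / 2)
```

The printed values approach $(m+2)/2$ slowly, consistent with power-law rather than exponential convergence.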

6.4. The Effect of Convolutions

The above theory can be extended to CNNs. We will provide an informal description here, with details in Sec. F. For input images of size (m, k, k, 3) the NTK and NNGP kernels will have shape (m, k, k, m, k, k) and will contain information about the covariance between each pair of pixels in each image. For convenience we will let $d = k^2$. In the large depth setting, deviations of both kernels from their fixed point decompose via Fourier transform in the spatial dimensions as,

$$\delta\Theta^{(l)}_{\text{CNN}}(q) \approx \rho_q^l\, \delta\Theta^{(l)}_{\text{FCN}}, \qquad \delta K^{(l)}_{\text{CNN}}(q) \approx \rho_q^l\, \delta K^{(l)}_{\text{FCN}},$$

where $q$ denotes the Fourier mode, with $q = 0$ being the zero-frequency (uniform) mode, and $\rho_q$ are eigenvalues of a certain convolution operator. Here $\delta\Theta^{(l)}_{\text{FCN}}$ and $\delta K^{(l)}_{\text{FCN}}$ are the deviations from the fixed-point for the fully-connected networks described above. We show that $\rho_{q=0} = 1$ and $|\rho_{q \neq 0}| < 1$, which implies that asymptotically the nonuniform modes become subleading: at large depths different pixels evolve identically to FCNs.

In Sec. F.2 we discuss the differences that arise when one combines a CNN with a flattening layer compared with an average pooling layer at the readout. In the case of flattening, the pixel-pixel correlations are discarded and $\Theta_{\text{CNN-F}} \approx \Theta_{\text{FCN}}$. The plots in the first row of Figure 1 confirm that the $\kappa^{(l)}$ of $\Theta_{\text{FCN}}$ and of $\Theta_{\text{CNN-F}}$ evolve almost identically in all phases. Note that this clarifies an empirical observation in Xiao et al. (2018) (Figure 3 of Xiao et al. (2018)) that test performance of critically initialized CNNs degrades towards that of FCNs as depth increases. This is because (i) in the large width limit, the prediction of neural networks is characterized by the NTK and (ii) the NTKs of the two models are almost identical for large depth. However, when CNNs are combined with global average pooling a correction to the spectrum of the NTK (NNGP) emerges owing to pixel-pixel correlations; this alters the dynamics of $\kappa^{(l)}$ and $P(\Theta^{(l)})$. In particular, we find that global average pooling increases $\kappa^{(l)}$ by a factor of $d$ in the ordered phase and on the critical line; see Table 1 for the exact correction as well as Figure 1d for experimental evidence of this correction.

6.5. Dropout, Relu and Skip-connection

Adding a dropout layer at the penultimate layer has a similar effect to adding a diagonal regularization term to the NTK, which significantly improves the conditioning of the NTK in the ordered phase. In particular, adding a single dropout layer can cause $\kappa^{(l)}$ to converge to a finite value rather than diverge exponentially; see Figure 4 and Sec. E.

For critically initialized ReLU networks (i.e., He’s initialization (He et al., 2015)), the entries of the NTK also diverge linearly; the resulting behavior of $\kappa^{(l)}$ and $P(\Theta^{(l)})$ is given in Table 2 and Figure 3. In addition, adding skip-connections makes all entries of the NTK diverge exponentially, resulting in exploding gradients. However, we find that skip connections do not alter the dynamics of $\kappa^{(l)}$. Finally, layer normalization could help address the issue of exploding gradients; see Sec. C.

7. Experiments

Evolution of $\kappa^{(l)}$ (Figure 1). We randomly sample inputs with shape (m, k, k, 3) and compute the exact NTK with activation function Erf using the Neural Tangents library (Novak et al., 2019a). We see excellent agreement between the theoretical calculation of $\kappa^{(l)}$ (summarized in Table 1) and the experimental results in Figure 1.

Maximum Learning Rates (Figure 2 (c)). In practice, given a set of hyper-parameters of a network, knowing the range of feasible learning rates is extremely valuable. As discussed above, in the infinite width setting, Equation 5 implies the maximal convergent learning rate is given by $\eta = 2/\lambda_{\max}$. From our theoretical results above, varying the hyperparameters of our network allows us to vary $\lambda_{\max}$ over a wide range and test this hypothesis. This is shown for depth 10 networks varying $\sigma_w^2$. We see that networks become untrainable when $\eta$ exceeds $2/\lambda_{\max}$, as predicted.
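The learning-rate threshold can be illustrated with discrete-time gradient descent on an MSE loss with a fixed kernel, where the training residual evolves as $r_{t+1} = (\mathrm{Id} - \eta\Theta)\, r_t$. The sketch below assumes this discrete update (the discretized counterpart of the continuous-time dynamics) and uses a random positive-definite matrix as a stand-in for an NTK; it is not the experiment reported in the paper, and the names are hypothetical.

```python
import numpy as np

def residual_norm(theta, y, eta, steps):
    """Norm of the training residual after `steps` updates r <- (Id - eta * theta) r, r_0 = y."""
    r = y.copy()
    for _ in range(steps):
        r = r - eta * (theta @ r)
    return np.linalg.norm(r)

rng = np.random.default_rng(0)
A = rng.normal(size=(6, 6))
theta = A @ A.T / 6 + 0.1 * np.eye(6)     # stand-in positive-definite kernel (not an NTK)
y = rng.normal(size=6)
lam_max = np.linalg.eigvalsh(theta)[-1]

for scale in (0.5, 1.0, 1.9, 2.1):
    eta = scale / lam_max
    print(f"eta = {scale:.1f} / lambda_max: residual norm "
          f"{residual_norm(theta, y, eta, 500):.3e}")
```

Each eigenmode is multiplied by $1 - \eta\lambda_i$ per step, so every mode contracts only when $\eta < 2/\lambda_{\max}$; just above that value the top mode grows geometrically.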

Trainability vs Generalization (Figure 2 (a,b)). We conduct an experiment training finite-width CNN-F networks with 1k training samples from CIFAR-10 with different configurations. We train each network using SGD with batch size b = 256 and learning rate . We see in Figure 2 (a) that deep in the chaotic phase we see that all configurations reach perfect training accuracy, but the network completely fails to generalize in the sense test accuracy is around 10%. As expected, in the ordered phase we see that although the training accuracy degrades generalization improves. As expected we see that the depth-scales and control trainability in the ordered phase and generalization in the chaotic phase respectively. We also conduct extra experiments for FCN with more training points (16k); see Figure 6.

CNN-P vs. CNN-F: spatial correction (Figure 2 (d-f)). We compute the test accuracy using the analytic equations for gradient flow, Equation 6, which corresponds to the test accuracy of an ensemble of gradient descent trained neural networks in the limit of infinite width. As above, we use 1k training points and consider a grid of configurations for . We plot the test performance of CNN-P and CNN-F and the performance difference in Fig 2 (d-f). As expected, we see that the performance of both CNN-P and CNN-F is captured by in the ordered phase and by in the chaotic phase. We see that the test performance difference between CNN-P and CNN-F exhibits a region in the ordered phase (a blue strip) where CNN-F outperforms CNN-P by a large margin. This performance difference is due to the correction term d as predicted by the -row of Table 1. We also conduct extra experiments densely varying Together these results provide an extremely stringent test of our theory.
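A kernel-regression prediction of this kind can be sketched as follows. Equation 6 is not reproduced in this excerpt, so the code assumes the standard closed form for gradient flow on squared loss with a zero-mean initial output, f_t(x) = Θ(x, X) Θ(X, X)^{-1} (I − e^{−η t Θ(X, X)}) Y, which reduces to kernel regression as t → ∞. The function name, the arguments, and the RBF kernel used as a stand-in for an NTK are hypothetical.

```python
import numpy as np


def ntk_mean_prediction(theta_train, theta_test_train, y_train, lr=1.0, t=np.inf):
    """Mean prediction at time t under the assumed gradient-flow closed form.

    theta_train must be symmetric positive definite and y_train one-dimensional.
    """
    lam, vec = np.linalg.eigh(theta_train)
    # (I - exp(-lr * t * Theta)) Theta^{-1}, evaluated in the eigenbasis of Theta
    gain = -np.expm1(-lr * t * lam) / lam
    return theta_test_train @ (vec @ (gain * (vec.T @ y_train)))


rng = np.random.default_rng(0)
x = rng.standard_normal((12, 4))
sq_dist = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
theta = np.exp(-sq_dist / 8.0) + 1e-3 * np.eye(12)   # stand-in kernel, not an NTK
y = np.sign(rng.standard_normal(12))

# At t = inf the predictor interpolates the training labels.
print(np.abs(ntk_mean_prediction(theta, theta, y) - y).max())
```

Test accuracy for a classification task would then be obtained by comparing the sign (or argmax) of the predictions on held-out inputs with the held-out labels.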

Figure 2. Trainability and generalization are captured by (a,b) The training and test accuracy of CNN-F trained with SGD. The network is untrainable above the green line because is too large and is ungeneralizable above the orange line because is too small. (c) The accuracy vs learning rate for FCNs trained with SGD sweeping over the weight variance. (d,e) The test accuracy of CNN-P and CNN-F using kernel regression. (f) The difference in accuracy between CNN-P and CNN-F networks.

8. Conclusion and Future Work

In this work, we identify several quantities (, , and ) related to the spectrum of the NTK that control trainability and generalization of deep networks. We offer a precise characterization of these quantities and provide substantial experimental evidence supporting their role in predicting the training and generalization performance of deep neural networks. Future work might extend our framework to other architectures (for example, residual networks with batch-norm or attention architectures). Understanding the role of the nonuniform Fourier modes in the NTK in determining the test performance of CNNs is also an important research direction.

In practice, the correspondence between the NTK and neural networks is often broken due to, e.g., insufficient width, using a large learning rate, or changing the parameterization. Our theory does not directly apply to this setting. As such, developing an understanding of training and generalization away from the NTK regime remains an important research direction.

Acknowledgements

We thank Jascha Sohl-Dickstein, Greg Yang, Ben Adlam, Jaehoon Lee, Roman Novak and Yasaman Bahri for useful discussions and feedback. We also thank anonymous reviewers for feedback that helped improve the manuscript.

References

Allen-Zhu, Z., Li, Y., and Song, Z. A convergence theory for deep learning via over-parameterization. arXiv preprint arXiv:1811.03962, 2018.

Arora, S., Du, S. S., Hu, W., Li, Z., Salakhutdinov, R., and Wang, R. On exact computation with an infinitely wide neural net. arXiv preprint arXiv:1904.11955, 2019.

Bahdanau, D., Cho, K., and Bengio, Y. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.

Bergstra, J. and Bengio, Y. Random search for hyperparameter optimization. Journal of Machine Learning Research, 13(Feb):281–305, 2012.

Blumenfeld, Y., Gilboa, D., and Soudry, D. A mean field theory of quantized deep networks: The quantization-depth trade-off. arXiv preprint arXiv:1906.00771, 2019.

Chen, M., Pennington, J., and Schoenholz, S. Dynamical isometry and a mean field theory of RNNs: Gating enables signal propagation in recurrent neural networks. In International Conference on Machine Learning, 2018.

Chizat, L. and Bach, F. On the global convergence of gradient descent for over-parameterized models using optimal transport. In Advances in neural information processing systems, pp. 3040–3050, 2018.

Chizat, L., Oyallon, E., and Bach, F. On lazy training in differentiable programming. 2019.

Cubuk, E. D., Zoph, B., Mane, D., Vasudevan, V., and Le, Q. V. Autoaugment: Learning augmentation strategies from data. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2019.

Daniely, A. SGD learns the conjugate kernel class of the network. In Advances in Neural Information Processing Systems 30. 2017.

Daniely, A., Frostig, R., and Singer, Y. Toward deeper understanding of neural networks: The power of initialization and a dual view on expressivity. In Advances In Neural Information Processing Systems, 2016.

Du, S. S., Lee, J. D., Li, H., Wang, L., and Zhai, X. Gradient descent finds global minima of deep neural networks. arXiv preprint arXiv:1811.03804, 2018a.

Du, S. S., Zhai, X., Poczos, B., and Singh, A. Gradient descent provably optimizes over-parameterized neural networks, 2018b.

Garriga-Alonso, A., Rasmussen, C. E., and Aitchison, L. Deep convolutional networks as shallow gaussian processes, 2018.

Gilboa, D., Chang, B., Chen, M., Yang, G., Schoenholz, S. S., Chi, E. H., and Pennington, J. Dynamical isometry and a mean field theory of LSTMs and GRUs. CoRR, abs/1901.08987, 2019. URL http://arxiv.org/abs/1901.08987.

Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., and Dahl, G. E. Neural message passing for quantum chemistry. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, ICML’17, pp. 1263–1272. JMLR.org, 2017. URL http://dl.acm.org/citation.cfm?id=3305381.3305512.

Glorot, X. and Bengio, Y. Understanding the difficulty of training deep feedforward neural networks. In International Conference on Artificial Intelligence and Statistics, pp. 249–256, 2010.

Hayou, S., Doucet, A., and Rousseau, J. On the selection of initialization and activation function for deep neural networks. arXiv preprint arXiv:1805.08266, 2018.

Hayou, S., Doucet, A., and Rousseau, J. Mean-field behaviour of neural tangent kernel for deep neural networks, 2019.

He, K., Zhang, X., Ren, S., and Sun, J. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. CoRR, abs/1502.01852, 2015. URL http://arxiv.org/abs/1502.01852.

Huang, J. and Yau, H.-T. Dynamics of deep neural networks and neural tangent hierarchy. arXiv preprint arXiv:1909.08156, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems 31. 2018.

Jacot, A., Gabriel, F., and Hongler, C. Freeze and chaos for dnns: an ntk view of batch normalization, checkerboard and boundary effects, 2019.

Karakida, R., Akaho, S., and Amari, S.-i. Universal statistics of fisher information in deep neural networks: mean field approach. arXiv preprint arXiv:1806.01316, 2018.

Lee, J., Bahri, Y., Novak, R., Schoenholz, S., Pennington, J., and Sohl-Dickstein, J. Deep neural networks as gaussian processes. In International Conference on Learning Representations, 2018.

Lee, J., Xiao, L., Schoenholz, S. S., Bahri, Y., Sohl-Dickstein, J., and Pennington, J. Wide neural networks of any depth evolve as linear models under gradient descent. arXiv preprint arXiv:1902.06720, 2019.

Matthews, A., Hron, J., Rowland, M., Turner, R. E., and Ghahramani, Z. Gaussian process behaviour in wide deep neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=H1-nGgWC-.

Mei, S., Montanari, A., and Nguyen, P.-M. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33): E7665–E7671, 2018.

Neal, R. M. Priors for infinite networks (tech. rep. no. crg-tr-94-1). University of Toronto, 1994.

Novak, R., Lee, L. X. J., Sohl-Dickstein, J., and Schoenholz, S. S. Neural tangents: Fast and easy infinite neural networks in python, 2019a. URL http://github.com/google/neural-tangents.

Novak, R., Xiao, L., Lee, J., Bahri, Y., Yang, G., Hron, J., Abolafia, D. A., Pennington, J., and Sohl-Dickstein, J. Bayesian deep convolutional networks with many channels are gaussian processes. In International Conference on Learning Representations, 2019b.

Park, D. S., Chan, W., Zhang, Y., Chiu, C.-C., Zoph, B., Cubuk, E. D., and Le, Q. V. Specaugment: A simple data augmentation method for automatic speech recognition. arXiv preprint arXiv:1904.08779, 2019.

Pennington, J., Schoenholz, S. S., and Ganguli, S. The emergence of spectral universality in deep networks. arXiv preprint arXiv:1802.09979, 2018.

Philipp, G., Song, D., and Carbonell, J. G. The exploding gradient problem demystified-definition, prevalence, impact, origin, tradeoffs, and solutions. arXiv preprint arXiv:1712.05577, 2017.

Poole, B., Lahiri, S., Raghu, M., Sohl-Dickstein, J., and Ganguli, S. Exponential expressivity in deep neural networks through transient chaos. In Advances In Neural Information Processing Systems, pp. 3360–3368, 2016.

Raghu, M., Zhang, C., Kleinberg, J., and Bengio, S. Transfusion: Understanding transfer learning with applications to medical imaging. arXiv preprint arXiv:1902.07208, 2019.

Rotskoff, G. M. and Vanden-Eijnden, E. Neural networks as interacting particle systems: Asymptotic convexity of the loss landscape and universal scaling of the approximation error. arXiv preprint arXiv:1805.00915, 2018.

Schoenholz, S. S., Gilmer, J., Ganguli, S., and Sohl-Dickstein, J. Deep information propagation. International Conference on Learning Representations, 2017.

Silver, D., Hubert, T., Schrittwieser, J., Antonoglou, I., Lai, M., Guez, A., Lanctot, M., Sifre, L., Kumaran, D., Graepel, T., Lillicrap, T., Simonyan, K., and Hassabis, D. A general reinforcement learning algorithm that masters chess, shogi, and go through self-play. Science, 362(6419):1140–1144, 2018. ISSN 0036-8075. doi: 10.1126/science.aar6404. URL https://science.sciencemag.org/content/362/6419/1140.

Sirignano, J. and Spiliopoulos, K. Mean field analysis of neural networks. arXiv preprint arXiv:1805.01053, 2018.

Xiao, L., Bahri, Y., Sohl-Dickstein, J., Schoenholz, S., and Pennington, J. Dynamical isometry and a mean field theory of CNNs: How to train 10,000-layer vanilla convolutional neural networks. In International Conference on Machine Learning, 2018.

Yang, G. Scaling limits of wide neural networks with weight sharing: Gaussian process behavior, gradient independence, and neural tangent kernel derivation. arXiv preprint arXiv:1902.04760, 2019.

Yang, G. and Schoenholz, S. Mean field residual networks: On the edge of chaos. In Advances in Neural Information Processing Systems. 2017.

Yang, G., Pennington, J., Rao, V., Sohl-Dickstein, J., and Schoenholz, S. S. A mean field theory of batch normalization. arXiv preprint arXiv:1902.08129, 2019.

Zou, D., Cao, Y., Zhou, D., and Gu, Q. Stochastic gradient descent optimizes over-parameterized deep relu networks. arXiv preprint arXiv:1811.08888, 2018.

A. Related Work

Recent work (Jacot et al., 2018; Du et al., 2018b; Allen-Zhu et al., 2018; Du et al., 2018a; Zou et al., 2018) proved global convergence of over-parameterized deep networks by showing that the NTK essentially remains constant over the course of training. However, in a different scaling limit the NTK changes over the course of training, and global convergence is much more difficult to obtain and is known for neural networks with one hidden layer (Mei et al., 2018; Chizat & Bach, 2018; Sirignano & Spiliopoulos, 2018; Rotskoff & Vanden-Eijnden, 2018). Therefore, understanding the training and generalization properties in this scaling limit remains a very challenging open question.

Two excellent concurrent works (Hayou et al., 2019; Jacot et al., 2019) also study the dynamics of for FCNs (and deconvolutions in Jacot et al. (2019)) as a function of depth and variances of the weights and biases. Hayou et al. (2019) investigate the role of activation functions (smooth vs. non-smooth) and skip connections. Jacot et al. (2019) demonstrate that batch normalization helps remove the “ordered phase” (as in Yang et al. (2019)) and that a layer-dependent learning rate allows every layer in a network to contribute to learning. As opposed to these contributions, here we focus our effort on understanding trainability and generalization in this context. We also provide a theory for a wider range of architectures than these other efforts.

B. Signal propagation of NNGP and NTK

In this section, we assume that the activation function has a continuous third derivative. Recall that the recursive formulas for NNGP and the NTK are given by

where

Note that we have normalized each input to have variance and the diagonals of are equal to for all l. The off-diagonal terms of and are denoted by and , resp. and the diagonal terms are and , resp. The above equations can be simplified to

In what follows, we compute the evolution of and the spectrum and condition numbers of will use and to denote the maximum eigenvalues, the bulk eigenvalues and the condition number of

B.1. Chaotic Phase

B.1.1. CORRECTION OF THE OFF-DIAGONAL/DIAGONAL

The diagonal terms are relatively simple to compute. Equation 24 gives

i.e.

In the chaotic phase, , i.e. diverges exponentially quickly.

NTK of FC/CNN-F, CNN-P

Table 2. Evolution of the NTK/NNGP spectrum and as a function of depth l. The NTKs of FCN and CNN without pooling (CNN-F) are essentially the same and the scaling of for these networks is written in black. Corrections to these quantities due to the addition of an average pooling layer (CNN-P) with window size d are written in blue.

Now we compute the off-diagonal terms. Since in the chaotic phase, exists and is finite. Indeed, letting in equation 23, we have

which gives

To compute the finite depth correction, let

Applying Taylor’s expansion to the first equation of 23 gives

That is

Similarly, applying Taylor’s expansion to the second equation of 23 gives

Lemma 1. There exists a finite number

We want to emphasize that the limits are data-dependent, which was verified in Fig. 1e and 1f empirically.

Equation 37 gives

Summing over all l implies

We consider the spectrum of K and in this phase. For , we have (with and

where

The NNGP has two different eigenvalues: of order 1 and the size of the dataset. For large l, since the spectral norm of , the spectrum and condition number of

For

Thus is essentially a diverging constant multiplying the identity and

B.2. Ordered Phase

B.2.1. THE CORRECTION OF THE DIAGONAL/OFF-DIAGONAL

Similarly, in the ordered phase we have the following.

Lemma 2. There exists

Therefore the following limits exist

Since the proof is almost identical to Lemma 1, we omit the details.

For

which implies

which implies

B.3. The critical line.

Figure 3. Condition numbers of NNGP and their rate of convergence. In the chaotic phase, converges to a constant (see Table 2) for FCN, CNN-F (a) and CNN-P (b). However, it diverges exponentially in the ordered phase (c) and linearly on the critical line (d). For the critical Relu network, diverges quadratically (e) while converges to a fixed number with rate (Equation 92) and we plot the value of of the NTK in (f).

B.3.1. CORRECTION OF THE DIAGONALS/OFF-DIAGONALS.

We have on the critical line. Equation 24 implies , i.e. the diagonal terms diverge linearly. To capture the linear divergence of

We need to expand the first equation of 23 to the second order

Here we assume T has a continuous third derivative (for which it is sufficient that the activation has a continuous third derivative). The above equation implies

Then

Plugging Equation 73 into the above equation gives

B.3.2. THE SPECTRUM OF NNGP AND NTK

C. NNGP and NTK of Relu networks.

C.1. Critical Relu.

We only consider the critical initialization (i.e. He’s initialization (He et al., 2015)) , which preserves the norm of an input from layer to layer. We also normalize the inputs to have unit variance, i.e. that

This implies

which gives . Using the equations in Appendix C of (Lee et al., 2019) gives

and taking the derivative w.r.t.

Thus

This is enough to conclude (similar to the above calculation)

and

Recall that the diagonals of , resp. Therefore the spectrum and the condition numbers

C.2. Residual Relu

We consider the following “continuum” residual network

where t denotes the ‘depth’ and dt > 0 is sufficiently small and W and b are the weights and biases. We also set ). The NNGP and NTK have the following form

Taking the limit

Using the fact that (i.e. the inputs have unit variance), we can compute the diagonal terms Letting and applying the above fractional Taylor expansion to

Ignoring the higher-order term and setting

Applying this estimate to Equation 97 gives

Thus the limiting condition number of the NTK is m/3 + 1. This is the same as the above non-residual Relu case although the entries of blow up exponentially with t.

C.3. Residual Relu + Layer Norm

As we saw above, all the entries of and of a residual Relu network blow up exponentially, and so do its gradients. In what follows, we show that normalization could help to avoid this issue. We consider the following “continuum” residual network with “layer norm”

We also set ). The normalization term makes sure has unit norm and removes the exponential factor in both NNGP and NTK. To see this, note that

Taking the limit

Using the fact that (i.e. the inputs have unit variance) and the mapping 2T is norm preserving, we see that because

This implies (note that and we assume the initial value .) The off-diagonal terms can be computed similarly and

Thus the condition number of the NTK is m/3 + 1. This is the same as the non-residual Relu case discussed above.

D. Asymptotic of P(Θ(l))

To keep the notation simple, we denote test, train, train, train. Recall that

We split our calculation into three parts.

D.1. Chaotic phase

In this case the diagonal diverges exponentially and the off-diagonals converge to a bounded constant . We further assume the input labels are centered in the sense that contains the same number of positive (+1) and negative (-1) labels. We expand about its “fixed point”

In the last equation, we have used the fact that is balanced. Therefore

Remark 1. Without centering the labels and normalizing each input in to have the same variance, we will get a decay for instead of

D.2. Critical line

Note that in this phase, both the diagonals and the off-diagonals diverge linearly. In this case

Here we use to denote the all ‘1’ (column) vector with length equal to the number of training points in and is defined similarly. Note that the constant matrix B is invertible. By Equation 77

D.3. Ordered Phase

In the ordered phase, we have that , a symmetric matrix, represents the data-dependent piece of . To simplify the notation, in the calculation below we will replace by . We also assume is invertible. To compute the mean predictor, , asymptotically we begin by computing

where we have set

Note that there is no divergence in as and the limit is well-defined. The term is independent from the input data.

We therefore see that even in the infinite-depth limit the mean predictor retains its data-dependence and we expect these networks to be able to generalize indefinitely.
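The contrast between the chaotic and ordered cases can be illustrated with toy kernels. The sketch below is not the paper's calculation, and every structural choice in it is an assumption. The chaotic-like kernel has bounded data-dependent off-diagonal entries and a large diagonal term added on the training block only. The ordered-like kernel is a large data-independent constant added to every entry, plus a fixed data-dependent matrix. The data-dependent piece is an RBF kernel used as a stand-in, and all function names are hypothetical.

```python
import numpy as np


def mean_prediction(theta_train, theta_test_train, y_train):
    """Kernel-regression mean predictor Theta_test,train Theta_train^{-1} y."""
    return theta_test_train @ np.linalg.solve(theta_train, y_train)


def toy_chaotic(a_train, a_test_train, y_train, diag):
    """Bounded off-diagonals with a large diagonal on the training block."""
    m = a_train.shape[0]
    return mean_prediction(a_train + diag * np.eye(m), a_test_train, y_train)


def toy_ordered(a_train, a_test_train, y_train, const):
    """Large data-independent constant plus a fixed data-dependent piece."""
    return mean_prediction(a_train + const, a_test_train + const, y_train)


def rbf(u, v):
    """Stand-in for the data-dependent piece of the kernel."""
    return np.exp(-((u[:, None, :] - v[None, :, :]) ** 2).sum(-1) / 10.0)


rng = np.random.default_rng(0)
x_train = rng.standard_normal((20, 5))
x_test = rng.standard_normal((6, 5))
a_train = rbf(x_train, x_train) + 1e-3 * np.eye(20)
a_test = rbf(x_test, x_train)
y = np.array([1.0, -1.0] * 10)   # balanced labels

print(toy_chaotic(a_train, a_test, y, 1e4))   # shrinks toward zero as diag grows
print(toy_ordered(a_train, a_test, y, 1e6))   # settles to a data-dependent limit
```

In this toy setting the chaotic-like predictions decay roughly like the inverse of the diagonal term, while the ordered-like predictions stabilize as the constant grows and still vary from one test point to another.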

E. Dropout

In this section, we investigate the effect of adding a dropout layer to the penultimate layer. Let random variables

and for the output layer,

where and are iid Gaussians N(0, 1). Since no dropout is applied in the first L layers, the NNGP kernel and can be computed using Equation 20 and Equation 8. Let and denote the NNGP and NTK of the (L + 1)-th layer. Note that when . We will compute the correction induced by

implies that the NNGP kernel (Schoenholz et al., 2017) is


Now we compute the NTK