
Feature Learning in Infinite-Width Neural Networks

Greg Yang Edward J. Hu∗


Microsoft Research AI Microsoft Dynamics AI
[email protected] [email protected]
arXiv:2011.14522v1 [cs.LG] 30 Nov 2020

Abstract
As its width tends to infinity, a deep neural network’s behavior under gradient
descent can become simplified and predictable (e.g. given by the Neural Tangent
Kernel (NTK)), if it is parametrized appropriately (e.g. the NTK parametrization).
However, we show that the standard and NTK parametrizations of a neural network
do not admit infinite-width limits that can learn features, which is crucial for pretraining
and transfer learning such as with BERT. We propose simple modifications
to the standard parametrization to allow for feature learning in the limit. Using
the Tensor Programs technique, we derive explicit formulas for such limits. On
Word2Vec and few-shot learning on Omniglot via MAML, two canonical tasks
that rely crucially on feature learning, we compute these limits exactly. We find
that they outperform both NTK baselines and finite-width networks, with the latter
approaching the infinite-width feature learning performance as width increases.
More generally, we classify a natural space of neural network parametrizations
that generalizes standard, NTK, and Mean Field parametrizations. We show 1) any
parametrization in this space either admits feature learning or has an infinite-width
training dynamics given by kernel gradient descent, but not both; 2) any such
infinite-width limit can be computed using the Tensor Programs technique.
[Figure 1 panels: NTK, Width 64, Width ∞ (Feature Learning); points labeled by type: state or city.]
Figure 1: PCA of Word2Vec embeddings of top US cities and states, for NTK, width-64, and width-∞
feature learning networks (Definition 5.1). NTK embeddings are essentially random, while cities and
states get naturally separated in embedding space as width increases in the feature learning regime.

1 Introduction
The study of infinite-width limits of neural networks, in particular the Neural Tangent Kernel (NTK),
has recently solved many longstanding open problems on the optimization and generalization of
overparametrized neural networks [23]. However, we prove that, in the NTK limit, (last layer) features
learned during pretraining are essentially the same as those from random initialization (Corollary 3.9
and Theorem G.12); this is verified empirically in Word2Vec in Fig. 1. As feature learning (e.g.
Imagenet and BERT) lies at the core of deep learning’s far-ranging impact so far [5, 11, 20], this
insight amounts to a fatal weakness of the NTK theory as a model of neural networks in practice.
We seek to capture feature learning in overparametrized networks by considering other parametriza-
tions and their infinite-width limits. By slightly modifying the standard parametrization (SP), in fact,
we can enable feature learning that is maximal in a sense to be explained shortly. We describe how to
compute this limit exactly (and rigorously) via the Tensor Programs technique developed in [43–46].

∗ Work done partly during the Microsoft AI Residency Program.
[Figure: word2vec pretrained on text8 — word analogy accuracy vs. epoch for feature learning networks of widths 2^6, 2^8, 2^10 (log2(width) = 6, 8, 10) and the NTK/GP baseline.]

Feature Learning Infinite-Width Networks on Real Tasks   We explicitly calculate this limit for the tasks of Word2Vec [28, 29] and few-shot learning on Omniglot via MAML [13],² two standard tasks relying crucially on feature learning. In Word2Vec, an important early instance of large-scale language pretraining, we must learn, in an unsupervised manner, word embeddings so that similar words have close embeddings. Then we test the learned embeddings on the word analogy task, which asks questions of the kind "what is to a queen as a man is to a woman?" In few-shot learning, the model is asked to make predictions given only a handful (e.g. 5) of labeled examples. Metalearning/MAML makes this possible by having the model learn good representations of typical examples that can adapt quickly, via a small number of SGD steps, to new few-shot learning tasks. On both tasks, we find our feature learning infinite-width networks outperform both NTK baselines and finite-width networks, with the latter approaching the infinite-width performance as width increases. The figure above shows this for one of our Word2Vec results. See Section 9 for our other experiments.

abc-Parametrizations   This paper studies a natural class of parametrizations, which we call the abc-parametrization and describe here. Consider an L-hidden-layer perceptron: for weight matrices W^1 ∈ R^{n×d} and W^2, …, W^L ∈ R^{n×n}, and nonlinearity φ : R → R, such a neural network on input ξ ∈ R^d is given by h^1(ξ) = W^1 ξ ∈ R^n, and
    x^l(ξ) = φ(h^l(ξ)) ∈ R^n,   h^{l+1}(ξ) = W^{l+1} x^l(ξ) ∈ R^n,   for l = 1, . . . , L − 1,   (1)
and the network output (also called the logit(s)) is f(ξ) = W^{L+1} x^L(ξ) for W^{L+1} ∈ R^{1×n}. An abc-parametrization is specified by a set of numbers {a_l, b_l}_l ∪ {c} such that
(a) we parametrize each weight as W^l = n^{-a_l} w^l for an actual trainable parameter w^l,
(b) we initialize each w^l_{αβ} ∼ N(0, n^{-2b_l}), and
(c) the SGD learning rate is ηn^{-c} for some width-independent η.³ ⁴

Examples: The NTK parametrization (NTP) [23] has a_1 = 0 and a_l = 1/2 for l ≥ 2; b_l = 0 for all l; c = 0. When depth L = 1, the Mean Field parametrization (MFP) [9, 26, 37, 39] has a_1 = 0, a_2 = 1; b_l = 0 for all l; c = −1. The standard parametrization (SP), available as the default setting in PyTorch [33],⁵ has a_l = 0 for all l; b_1 = 0 and b_l = 1/2 for l ≥ 2; c = 0. However, we shall see that this c is too small (learning rate too large) in SP. We can define abc-parametrizations and generalize our results to arbitrary neural architectures (Appendix C), but we shall focus on MLPs in the main text.
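To make (a)–(c) concrete, here is a minimal sketch of an abc-parametrized MLP, assuming PyTorch; the class name ABCMLP and all argument names are our own illustrative choices, not the paper's code. The width-scaling enters in three places: the multiplier n^{-a_l}, the initialization scale n^{-b_l}, and the SGD learning rate ηn^{-c}.

```python
# A minimal sketch (not the authors' code) of an abc-parametrized MLP in PyTorch.
# The multipliers n^{-a_l}, init scales n^{-b_l}, and learning rate eta * n^{-c}
# follow (a)-(c) above; layers are indexed 1..L+1 as in Eq. (1).
import torch
import torch.nn as nn

class ABCMLP(nn.Module):
    def __init__(self, d, n, L, a, b, phi=torch.relu):
        super().__init__()
        self.a, self.phi, self.n = a, phi, n
        dims = [d] + [n] * L + [1]                      # input -> L hidden layers -> scalar logit
        self.ws = nn.ParameterList(
            [nn.Parameter(torch.randn(dims[l], dims[l - 1]) * n ** (-b[l]))   # w^l ~ N(0, n^{-2 b_l})
             for l in range(1, L + 2)])

    def forward(self, xi):                              # xi: (batch, d)
        h = xi
        for l, w in enumerate(self.ws, start=1):
            W = self.n ** (-self.a[l]) * w              # W^l = n^{-a_l} w^l
            h = h @ W.t()
            if l <= len(self.ws) - 1:                   # nonlinearity on hidden layers only
                h = self.phi(h)
        return h                                        # f(xi)

# Example: NTK parametrization for L = 2 (a_1 = 0, a_l = 1/2 for l >= 2; b_l = 0; c = 0).
L, n, d, eta, c = 2, 1024, 10, 0.1, 0.0
a = {1: 0.0, 2: 0.5, 3: 0.5}
b = {1: 0.0, 2: 0.0, 3: 0.0}
net = ABCMLP(d, n, L, a, b)
opt = torch.optim.SGD(net.parameters(), lr=eta * n ** (-c))   # (c): lr = eta * n^{-c}
```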

Dynamical Dichotomy For any abc-parametrization, if c is too small (i.e. learning rate too large),
SGD can lead to blowup of preactivation and/or logits; we say this parametrization is unstable. In
practice this translates to numerical issues. If c is too large (i.e. learning rate too small), then the
function computed by the network does not change in finite time; we say this parametrization is
trivial. We prove what we call the Dynamical Dichotomy theorem (Corollary 3.9):
Any nontrivial stable abc-parametrization yields a (discrete-time) infinite-width limit.
This limit either 1) allows the embedding x^L(ξ) to evolve nontrivially (Definition 3.5) or
2) is described by kernel gradient descent in function space (Definition 3.7), but not both.
We call the former kind a feature learning limit and the latter a kernel limit. For 1-hidden-layer MLPs,
the former is exemplified by MFP, and the latter, NTP. This dichotomy implies that certain functional
dynamics, such as higher order generalizations of the NTK dynamics, are not valid infinite-width
limits (see Remark 3.12). In addition, the neural network function f (defined in Eq. (1)) in any feature
learning limit must be identically 0 at initialization (see Corollary 3.10).6
² Short for Model Agnostic Meta-Learning.
³ Observe that by changing a_l, b_l while holding a_l + b_l fixed, we effectively give layer l its own learning rate.
⁴ One can further include a set of constants in front of n^{-a_l} and n^{-b_l}, for example powers of the input dimension d, but we shall keep it simple here as we are only concerned with scaling behavior with n.
⁵ This is also known as the "fanin" or "Lecun" initialization; "Kaiming" initialization is the same up to multiplicative constants. The default in TensorFlow [1] uses Glorot initialization, where the variance of an entry scales like 1/(fan-in + fan-out). This causes the first layer preactivation to converge to 0 as n → ∞, and thus yields pathological behavior in the limit.
⁶ We stress this is in the n → ∞ limit, so it does not contradict the feature learning seen in finite-width SP NNs.

Standard Param. Does Not Learn Features   We show that SP (resp. NTP) can only allow O(1/width) (resp. O(1)) learning rate (i.e. c = 1, resp. c = 0), so as to avoid blowup, and yield kernel limits (Section 4). Instead, we propose a parametrization that has Θ(1) max learning rate and admits feature learning maximally: it allows every parameter to be updated maximally (in terms of scaling with width) without leading to blowup (Section 5). We thus call it the Maximal Update Parametrization (abbreviated MUP or µP). It is given by a_1 = −1/2, a_{L+1} = 1/2, and a_l = 0 for all 2 ≤ l ≤ L; b_l = 1/2 for all l; c = 0. In a 1-hidden-layer MLP, this specializes to MFP, up to symmetry (see Eq. (5)). The "feature learning limits" mentioned above in our main experiments are µP limits.

[Figure: "Verifying Max Learning Rate for µP and SP" — validation accuracy vs. log2(lr) (left column) and log2(lr · width) (right column), for the Maximal Update Param (top row) and Standard Param (bottom row), at widths 512, 1024, 2048, 4096, 8192; annotations mark how the max lr shifts with width.]

Figure above: we empirically verify our max learning rate predictions on a relu MLP with 2 hidden layers, trained with square loss on CIFAR10. We plot learning rate vs. accuracy in each subplot. Each curve represents an MLP with a specific width. The right edge of each curve indicates the max learning rate. The diagonal subplots scale the x-axes (log learning rate) in the correct width-scaling for the corresponding parametrizations. We see, indeed, that the max learning rate for SP scales like 1/width but is constant in µP.
Key Theoretical Idea: Tensor Programs   In Section 7 and Appendix G.4, we describe the Tensor Programs technique for deriving (rigorously) the infinite-width training dynamics of any abc-parametrization. The main insight of this approach is:

When width is large, every activation vector has roughly iid coordinates, at any time during training. Using Tensor Programs, we can recursively calculate such coordinate distributions, and consequently understand how the neural network function evolves.

[Diagram: the SGD training progress of an NN (1st forward pass, 1st backward pass, 2nd forward pass, 2nd backward pass, ...) is mapped, by taking the width → ∞ limit via Tensor Programs, to the NNGP and NTK (prior works) and to the Feature Learning Limit (this work).]
The Tensor Programs technique was developed
in a series of papers [43–46] that proved the architectural universality of the Neural Network-Gaussian
Process (NNGP) Correspondence and the Neural Tangent Kernel (NTK) limits and showed how to
compute the corresponding infinite-width kernels. In the Figure above, the NNGP kernel can be
thought of as the “limit” of the first forward pass of a randomly initialized model; the NTK can be
similarly thought of as the “limit” of its first backward pass. The mechanics of calculating such limits
is 1) to write down the relevant neural network computation (e.g. the first forward pass in the NNGP
case) as a principled composition of matrix multiplication and coordinatewise nonlinearities, called
a Tensor Program, and 2) to recursively calculate the distribution of coordinates of each vector via
what’s called the Master Theorem. In this paper, we follow the exact same recipe, where in 1) we just
write down the entire SGD training instead of only the first step. More generally,
To derive the infinite-width limit of any neural computation (e.g. SGD training),
1) express it as a Tensor Program, and 2) mechanically apply the Master Theorem.
For example, we easily recover the (discrete-time) 1-hidden-layer mean field limit (Theorem 6.1).
It readily applies to practically any neural architecture (e.g. ResNet and Transformers)⁷ as well as
many common variants of SGD; however, in this paper, for pedagogical clarity, we only focus on
multilayer perceptrons. The generality of our approach allows us to easily adapt to settings outside the
traditional (CIFAR10-style) supervised classification, such as the Word2Vec and few-shot learning
tasks in this paper, or reinforcement learning and image generation outside of our scope.
⁷ E.g. by extending the example programs of [43, 45], which express only the first forward and backward passes, into the entire training computation.

Our Contributions
1. Formulate a natural space of NN parametrizations (abc-parametrizations).
2. Prove Dynamical Dichotomy: Any nontrivial stable abc-parametrization yields either feature
learning or kernel limits, but not both.
3. Show both NTK and standard parametrizations yield kernel limits and propose the Maximal
Update Parametrization (µP), which admits maximal feature learning in a suitable sense.
4. Use Tensor Programs to derive the infinite-width limit of µP and, more generally, the limit
of any abc-parametrization. We verify our theory using extensive experiments.
5. Show the µP limit outperforms both NNGP/NTK baselines and finite networks on 1)
Word2Vec and 2) Omniglot few-shot learning, trained via first-order MAML.

Tensor Programs Series While this work is self-contained, it is positioned as the 4th paper in the
series, following Yang [43, 45, 46]. We do not extend the Tensor Programs machinery further here,
but instead extract the first major payoff of the foundation laid in the earlier works. In fact, this paper
is the original motivation for this series; for a short history, see Appendix A.

2 Related Works
Comparison with Mean Field Limits   For 1-hidden-layer MLPs, the mean field limit [9, 26, 37, 39] is equivalent to the µP limit modulo the symmetry of Eq. (5) (see Section 3.1). Several works also proposed different versions of mean field frameworks for deeper MLPs [3, 12, 30, 31, 40]. However, they did not consider the typical Gaussian N(0, 1/n) random initialization (or the appropriately rescaled version in their respective parametrizations),⁸ which has a Central-Limit effect as opposed to a Law-of-Large-Numbers effect. For example, [3, 31] can cover the case of N(0, 1/n²), instead of N(0, 1/n), initialization, which in fact causes the function to be stuck at initialization. Of these works, the mean field limit of [12] has the form most similar to what we derive here. There, as we do here, the coordinate distribution of each (pre)activation vector is tracked recursively. The main difference is that, while [12] has an atypical initialization involving ℓ2 regression, we consider the usual Gaussian N(0, 1/n) scheme. Such an (n × n) Gaussian matrix in the middle of the network has a distinctly different effect, more similar to that of a Gaussian matrix in the usual NNGP/NTK calculation,⁹ than the "mean field" matrices considered in [12] and previous works [3, 30, 31, 40], which have an "integral kernel" effect that is the straightforward generalization of matrices to function spaces. Nevertheless, discrete-time versions of the 1-hidden-layer mean field limit and of many of the multilayer limits (such as [12, 31]) can be derived directly by writing the corresponding initialization and training inside a Tensor Program and applying the Master Theorem (Theorem 7.4).

Discrete- vs Continuous-Time Gradient Descent At a high level, there are two natural limits
of neural networks training dynamics: large-width and continuous-time. Most prior works on
infinite-width limits of neural networks also took the continuous-time limit simultaneously, e.g.
[9, 23, 26, 37, 39]. In contrast, here we only take the large width limit, so that gradient descent stays
discrete-time. Then the results of these prior works can be recovered by taking another continuous-
time limit. From a practical perspective, the continuous-time limit is often unnatural, e.g. 1) because
the step size is usually as large as possible to speed up training, 2) because of the task (such as
reinforcement learning), or 3) because of the importance of hyperparameters like batch size that are
hidden away in such limits. On the theory side, taking the continuous-time limit can create issues
with 1) well-posedness and 2) existence and uniqueness of the resulting ODE/PDE. While they can
sometimes be proved to hold, they are artifacts of the continuous-time limit, as the corresponding
questions for the discrete time evolution are trivial, and thus not relevant to the behavior of real
networks.

Technical Assumptions   Earlier works on neural tangent or mean field limits (e.g. [9, 12, 23, 26, 31, 37, 39]) assume various forms of regularity conditions, such as 1) 0th, 1st, and/or 2nd order smoothness on the nonlinearity or other related functions, and 2) support boundedness, subgaussianity, and/or PDF smoothness of the initialization distributions. These are often either unnatural or difficult to check. In our work, the only assumption needed to rigorously obtain the infinite-width limit is that the nonlinearity φ has a polynomially bounded weak 2nd derivative and that the loss function has a continuous derivative w.r.t. the prediction (Assumption G.21). In particular, when we specialize to the 1-hidden-layer case and derive the discrete-time version of the mean field limit, we cover the standard Gaussian initialization; in fact, we can allow any heavy-tailed initialization that can be written as the image of a Gaussian under a pseudo-Lipschitz function, which includes nonsmooth PDFs and singular distributions.¹⁰ This generosity of technical assumptions is due to that of the Tensor Programs Master Theorems proven in [43, 45, 46].

⁸ In fact, empirically we observe such Gaussian random initialization to be crucial to performance compared to the mean-field-style initialization in this literature.
⁹ Actually, it is more similar to the Gaussian matrix in asymmetric message passing [4] in that care must be taken to keep track of the correlation between W and W^⊤.

Training Time   Many prior works (e.g. [2, 22, 26]) derived explicit time dependence of the convergence to the infinite-width limit, so that a larger width can allow the network to stay close to the limit for longer. In this paper, our results only concern training time independent of width, since our primary objective is to investigate the limit itself and its feature learning capabilities. Moreover, recent evidence suggests that, given a fixed computational budget, it is always better to train a larger model for a shorter amount of time [25], which validates the practical relevance of our limit.
Nevertheless, it is possible to prove a quantitative version of the Tensor Programs Master Theorem,
by which one can straightforwardly allow training time to increase with width.

Classification of Parametrizations [8] pointed out that the weights move very little in the NTK
limit, so that linearization approximately holds around the initial parameters, in contrast to the mean
field limit (for 1-hidden-layer networks) where the weights move substantially. For this reason, they
called the former “lazy training” and the latter “active training,” which are classified nonrigorously by
a multiplicative scaling factor of the logit (similar to n−aL+1 in this paper). While these terms are not
formally defined, they intuitively correspond to the kernel and feature learning regimes in our paper.
From a different perspective, [27] observed that the NTK and mean field limit can be thought of as
short and long time-scale regimes of the mean field evolution equations. Neither of the above works
attempted to formally classify natural parametrizations of neural networks. In contrast, [42] studied a
toy class of neural networks in the context of implicit regularization due to the scale α of initialization
(which is closely related to logit multiplier of [8] noted above). They identified the α → ∞ limit (of
the scale α, not of width) with the “kernel regime” and the α → 0 limit with what they call the “rich
regime”. They showed that the former is implicitly minimizing an `2 risk while the latter, an `1 risk.
They claim width allows the toy model to enter the kernel regime more naturally, but as we see in this
work, both kernel and feature learning regimes are admissible in the large width limit of a standard
MLP. Closer to our approach, [16] studied what amounts to a 2-dimensional subspace of the space of
stable abc-parametrizations for L = 1. They proposed a notion of stability which is similar to the
combination of stability and nontriviality in this paper. They characterized when the Neural Tangent
Kernel, suitably generalized to any parametrization and playing a role similar to the feature kernel
in this paper, evolves over time. However, to simplify the proofs, they assumed that the gradients
for the different weight matrices are estimated using different inputs, a very unnatural condition.
In contrast, here our results are for the usual SGD algorithm applied to MLPs of arbitrary depth.
In all of the above works and most of the existing literature, not much attention is paid to the feature learning capabilities of neural networks in the right parametrization, as opposed to our focus here. A notable exception is [10], which showed that the mean field limit, but not the NTK limit, can learn low-dimensional linear structure of the input distribution, resulting in ambient-dimension-independent generalization bounds.

Other Related Works [24] proposed a toy model to study how large learning rate can induce a
neural network to move out of the kernel regime in Ω(log(width)) time. Since our dichotomy result
only concerns training for O(1) time (which, as we argue above, is more practically relevant), there
is no contradiction. [41] also noted that standard parametrization leads to unstable training dynamics.
They then injected constants into the NTK parametrization, such as α/√n instead of 1/√n, and tuned
α in the resulting kernel. [14] empirically observed that wider networks achieve better downstream
performance with linear transfer learning, even though on the original pretraining task there can be
little difference.

¹⁰ We won't expand further here, but it can be derived straightforwardly from the Master Theorem (Theorem 7.4).

3 Feature Learning vs Kernel Behavior
In this section, we give a characterization of training procedures that induce feature learning vs kernel
behavior; we will elaborate on what we mean by these two kinds of behavior below. We first motivate
this discussion by reviewing the well-known tangent kernel and mean field limits of a shallow neural
network.

3.1 Motivating Examples: Neural Tangent Kernel and Mean Field Limits
For simplicity, define a shallow network f(ξ) with input/output dimension 1 by
    f(ξ) = V x(ξ) ∈ R,   x(ξ) = φ(h(ξ)) ∈ R^n,   h(ξ) = U ξ ∈ R^n.   (2)
As a specialization of Eq. (1), we parametrize the weights as V = n^{-a_v} v ∈ R^{1×n} and U = n^{-a_u} u ∈ R^{n×1}, where the width n should be thought of as tending to ∞, and v, u should be thought of as the actual trainable parameters. We sample v_α ∼ N(0, n^{-2b_v}), u_α ∼ N(0, n^{-2b_u}) for α ∈ [n]. The learning rate is ηn^{-c} for some η independent of n.
For example, in the Neural Tangent Parametrization (abbreviated NTP) [23], a_u = b_v = b_u = 0, a_v = 1/2, c = 0. The Mean Field Parametrization (abbreviated MFP) corresponds to a_v = 1, a_u = b_u = b_v = 0, c = −1; however, as will be explained shortly, we will use the equivalent formulation a_u = −1/2, a_v = b_u = b_v = 1/2, c = 0 in this section, so that c = 0 for both NTP and MFP.
We remark that the GP limit, i.e. training only the last layer of an infinitely wide, randomly initialized network, is a special case of the NTK limit where the first layer is not trained. Everything we discuss below about the NTK limit specializes to the GP limit appropriately.
Given an input ξ, the gradient of f can be calculated as
    dx(ξ) = V,   dh(ξ) = dx(ξ) φ′(h(ξ)),   dv(ξ) = n^{-a_v} x(ξ),   du(ξ) = n^{-a_u} dh(ξ) ξ,
where d•(ξ) is shorthand for ∇_• f(ξ) (however, note that later in Section 6, d•(ξ) will stand for n∇_• f(ξ)). For a loss function L : R × R → R, the loss gradient on a pair (ξ, y) is then given by L′(f(ξ), y) [dx(ξ), dh(ξ), dv(ξ), du(ξ)] (where L′ denotes the derivative in the first argument).
Note that one can keep the function f invariant while changing the magnitude of the gradient dv by changing a_v, b_v, holding a_v + b_v constant; likewise for du. Thus, the trajectory of f stays fixed if, for any θ ∈ R, we set a_u ← a_u + θ, a_v ← a_v + θ, b_u ← b_u − θ, b_v ← b_v − θ, c ← c − 2θ (also see Eq. (5)). With θ = −1/2, this explains why the two formulations of MFP above are equivalent. Then, for both NTP and MFP, we will consider the dynamics of f trained under stochastic gradient descent with learning rate η = 1 and batch size 1, where the network is fed the pair (ξ_t, y_t) at time t, starting with t = 0. This simplicity is intended to intuitively illustrate our points below, but we shall state formal results regarding more common settings in Section 3.2.
Notation and Setup   Below, when we say a (random) vector v ∈ R^n has coordinate size O(n^a) (written v = O(n^a)),¹¹ we mean √(‖v‖²/n) = O(n^a) with high probability for large n. Intuitively, this means that each coordinate has a typical fluctuation of O(n^a). Likewise if O(n^a) is replaced with Θ(n^a) or Ω(n^a). See Definition G.2 for a formal definition.
Let f_t, h_t, x_t, U_t, V_t, dx_t, dh_t, dv_t, du_t denote the corresponding objects at time t, with t = 0 corresponding to random initialization. We also abuse notation and write x_t = x_t(ξ_t), i.e. applying the function x_t specifically to the tth input ξ_t; similarly for f_t, h_t, dx_t, dh_t, dv_t, du_t. These symbols will never appear by themselves to denote the corresponding function, so this should cause no confusion. Then SGD effectively updates U and V by
    U_{t+1} = U_t − χ_t n^{-a_u} du_t,   V_{t+1} = V_t − χ_t n^{-a_v} dv_t,
where χ_t := L′(f_t, y_t). Finally, let ∆•_t := •_t − •_0, for all • ∈ {f, h, x, U, V, dx, dh, dv, du}. For example, after 1 SGD update, we have, for any ξ ∈ R,
    ∆h_1(ξ) = h_1(ξ) − h_0(ξ) = −n^{-a_u} χ_0 ξ du_0 = −n^{-2a_u} χ_0 ξ_0 ξ dh_0 = −n^{-2a_u} χ_0 ξ_0 ξ dx_0 φ′(h_0),   (3)
    ∆f_1(ξ) = V_0 ∆x_1(ξ) + ∆V_1 x_1(ξ) = V_0 ∆x_1(ξ) − n^{-a_v} χ_0 dv_0^⊤ x_1(ξ) = V_0 ∆x_1(ξ) − n^{-2a_v} χ_0 x_0^⊤ x_1(ξ).   (4)

¹¹ Contrast this with a common semantics of v = O(n^a) as ‖v‖ = O(n^a).

3.1.1 Key Observations
Let’s list a few characteristics of the NTK and MF limits in the context of the shallow network in
Eq. (2), and then discuss them in the general setting of deep MLP. We will keep our discussion
intuitive to carry across the key ideas.

Feature Evolution   For a generic ξ ∈ R, its embedding vector x_0(ξ) has coordinates of Θ(1) size in both NTP and MFP. However, for any t ≥ 1 independent of n, ∆x_t(ξ) generically has coordinate size Θ(1/√n) in NTP but Θ(1) in MFP.
Example for t = 1: By Eq. (3), we have
    ∆h_1(ξ) = −n^{-2a_u} χ_0 ξ_0 ξ dx_0 φ′(h_0).
Plug in a_u = 0 for NTP. Observe that ξ_0, ξ, χ_0 = Θ(1),¹² so
    ∆h_1(ξ) = Θ(1) · dx_0 φ′(h_0).   (in NTP)
In addition, φ′(h_0) = Θ(1) because h_0 = Θ(1), so
    ∆h_1(ξ) = Θ(1) · dx_0 Θ(1).   (in NTP)
Finally, dx_0 = V_0 = Θ(1/√n) in NTP. Altogether, this implies
    ∆h_1(ξ) = Θ(1/√n)   =⇒   ∆x_1(ξ) ≈ φ′(h_0(ξ)) ∆h_1(ξ) = Θ(1/√n) → 0, as n → ∞.   (in NTP)
On the other hand, in MFP, the only things different are a_u = −1/2 and dx_0 = Θ(1/n), which implies
    ∆h_1(ξ) = Θ(n) · Θ(1/n) Θ(1) = Θ(1)   =⇒   ∆x_1(ξ) = Θ(1).   (in MFP)
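These Θ(·) claims are easy to check numerically. The following sketch (ours, not from the paper) takes one SGD step of the shallow network in Eq. (2) with φ = tanh under NTP and MFP and prints the coordinate size √(‖∆x_1(ξ)‖²/n): it shrinks like 1/√n under NTP but stays roughly constant under MFP.

```python
# Sketch: coordinate size of Delta x_1(xi) after one SGD step in NTP vs MFP (shallow net, Eq. (2)).
import numpy as np

def delta_x1_coord_size(n, a_u, a_v, b_u, b_v, c, xi0=1.0, y0=1.0, xi=0.7, seed=0):
    phi, phip = np.tanh, lambda z: 1.0 - np.tanh(z) ** 2
    rng = np.random.default_rng(seed)
    u = rng.normal(0.0, n ** (-b_u), size=(n, 1))       # trainable parameters
    v = rng.normal(0.0, n ** (-b_v), size=(1, n))
    U, V = n ** (-a_u) * u, n ** (-a_v) * v             # effective weights
    h0 = U * xi0
    x0 = phi(h0)                                        # first forward pass on xi_0
    chi0 = (V @ x0).item() - y0                         # L'(f_0, y_0) for square loss
    dx0 = V.T                                           # dx(xi) = V (as a column)
    dh0 = dx0 * phip(h0)
    du0 = n ** (-a_u) * dh0 * xi0                       # gradient w.r.t. u
    u1 = u - n ** (-c) * chi0 * du0                     # one SGD step with eta = 1
    x1_new = phi(n ** (-a_u) * u1 * xi)                 # x_1(xi) after the step
    x1_old = phi(U * xi)                                # x_0(xi) at initialization
    return float(np.sqrt(np.mean((x1_new - x1_old) ** 2)))

for n in [256, 1024, 4096, 16384]:
    ntp = delta_x1_coord_size(n, a_u=0.0, a_v=0.5, b_u=0.0, b_v=0.0, c=0.0)
    mfp = delta_x1_coord_size(n, a_u=-0.5, a_v=0.5, b_u=0.5, b_v=0.5, c=0.0)
    print(f"n={n:6d}   NTP: {ntp:.5f} (shrinks ~ 1/sqrt(n))   MFP: {mfp:.5f} (stays ~ const)")
```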

Feature Kernel Evolution   Therefore the feature kernel F_t(ξ, ζ) := x_t(ξ)^⊤ x_t(ζ)/n does not change in the NTK limit but it does in the MF limit, i.e. for any fixed t ≥ 1,¹³
    lim_{n→∞} F_t(ξ, ζ) = lim_{n→∞} F_0(ξ, ζ)   in NTP, but
    lim_{n→∞} F_t(ξ, ζ) ≠ lim_{n→∞} F_0(ξ, ζ)   in MFP, in general.
Indeed, regardless of parametrization, we have
    F_t(ξ, ζ) = (1/n) [ x_0(ξ)^⊤ x_0(ζ) + ∆x_t(ξ)^⊤ x_0(ζ) + x_0(ξ)^⊤ ∆x_t(ζ) + ∆x_t(ξ)^⊤ ∆x_t(ζ) ].
In NTP, because ∆x_t(ξ) = Θ(1/√n) as noted above,
    (1/n) ∆x_t(ξ)^⊤ x_0(ζ) = (1/n) Σ_{α=1}^{n} ∆x_t(ξ)_α x_0(ζ)_α = (1/n) Σ_{α=1}^{n} O(n^{-1/2}) = O(n^{-1/2}),
and likewise the other terms involving ∆x_t will vanish as n → ∞. But in MFP, ∆x_t(ξ) = Θ(1) will in general be correlated with x_0(ζ) such that (1/n) ∆x_t(ξ)^⊤ x_0(ζ) = (1/n) Σ_{α=1}^{n} Θ(1) = Θ(1).
It may seem somewhat puzzling how the NTK limit induces change in f without feature or feature
kernel evolution. We give some intuition in Appendix B.

Pretraining and Transfer Learning   The simple fact above about the feature kernel implies that the NTK limit is unable to perform linear transfer learning. By linear transfer learning, we mean the popular style of transfer learning where one discards the pretrained linear classifier layer and trains a new one on top of the features (e.g. x in our example), which are fixed. Indeed, this is a linear problem and thus only depends on the kernel of the features. If this kernel is the same as the kernel at initialization, then the pretraining phase has had no effect on the outcome of this "transfer" learning. In fact, a more sophisticated reasoning shows pretraining in the NTK limit is no better than random initialization for transfer learning even if finetuning is performed on the whole network, not just the classifier layer. This remains true if we replace the linear classifier layer by a new deep neural network. See Remark G.15 and Theorem G.16. The Word2Vec experiment we do in this paper is a linear transfer task.

¹² χ_0 = L′(f_0, y_0) = Θ(1) because f_0 has variance Θ(1).
¹³ Here the limits should be construed as almost sure limits; see Theorem 7.4.

Table 1: We summarize the abc values of SP (standard), NTP (Neural Tangent), MFP (Mean Field, for 1-hidden-layer nets), and µP (Maximal Update, ours). We show the minimal value of c such that the parametrization is stable (Definition G.4). We also list the quantities r, 2a_{L+1} + c, and a_{L+1} + b_{L+1} + r involved in the stability, feature learning, and kernel regime properties of the parametrizations. Here we only focus on scaling with n and ignore dependence on input dimension. Recall the MLP definition: h^1 = W^1 ξ ∈ R^n, x^l = φ(h^l) ∈ R^n, h^{l+1} = W^{l+1} x^l ∈ R^n, f(ξ) = W^{L+1} x^L.

                      | Definition                  | SP (w/ LR 1/n)         | NTP                    | MFP (L = 1)          | µP (ours)
a_l                   | W^l = n^{-a_l} w^l          | 0                      | 0 (l = 1); 1/2 (l ≥ 2) | 0 (l = 1); 1 (l = 2) | −1/2 (l = 1); 0 (2 ≤ l ≤ L); 1/2 (l = L+1)
b_l                   | w^l_{αβ} ∼ N(0, n^{-2b_l})  | 0 (l = 1); 1/2 (l ≥ 2) | 0                      | 0                    | 1/2
c                     | LR = ηn^{-c}                | 1                      | 0                      | −1                   | 0
r                     | Definition 3.2              | 1/2                    | 1/2                    | 0                    | 0
2a_{L+1} + c          |                             | 1                      | 1                      | 1                    | 1
a_{L+1} + b_{L+1} + r |                             | 1                      | 1                      | 1                    | 1
Nontrivial?           |                             | ✓                      | ✓                      | ✓                    | ✓
Stable?               |                             | ✓                      | ✓                      | ✓                    | ✓
Feature Learning?     |                             | —                      | —                      | ✓                    | ✓
Kernel Regime?        |                             | ✓                      | ✓                      | —                    | —

In some other settings, such as some settings of metalearning, like the few-shot learning task in
this paper, the last layer of the pretrained network is not discarded. This is called adaptation. Then
the NTK limit does not automatically trivialize transfer learning. However, as will be seen in our
experiments, the NTK limit still vastly underperforms the feature learning limit, which is exemplified
by the MF limit here.

Kernel Gradient Descent in Function Space   In NTP, as n → ∞, ⟨∇_{U,V} f_0(ξ), ∇_{U,V} f_0(ζ)⟩ converges to some deterministic value K(ξ, ζ) such that K forms a kernel (the NTK). Then, in this limit, if the learning rate is η, the function f evolves according to kernel gradient descent: f_{t+1}(ξ) = f_t(ξ) − ηK(ξ, ξ_t)χ_t. However, this shouldn't be the case for the MF limit. For example, if φ is the identity, then intuitively f_{t+1}(ξ) − f_t(ξ) should be quadratic in η, not linear, because two layers are updated at the same time.
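For concreteness, the kernel gradient descent update just displayed can be written in a few lines; here is a minimal sketch (ours) in which the kernel K is a caller-supplied stand-in (e.g. the NTK) and the loss is assumed to be the square loss.

```python
# Sketch: discrete-time kernel gradient descent in function space,
#   f_{t+1}(xi) = f_t(xi) - eta * K(xi, xi_t) * L'(f_t(xi_t), y_t).
# K is any positive semidefinite kernel supplied by the caller (e.g. the NTK).
import numpy as np

def kernel_gd(K, train_seq, test_points, eta=1.0, f0=None):
    # f is tracked on the train inputs and test points; f0 defaults to the zero function.
    xs = [xi for xi, _ in train_seq] + list(test_points)
    f = {x: (0.0 if f0 is None else f0(x)) for x in xs}
    history = []
    for xi_t, y_t in train_seq:
        chi_t = f[xi_t] - y_t                     # L'(f_t(xi_t), y_t) for square loss
        for x in xs:                              # update every tracked function value
            f[x] = f[x] - eta * K(x, xi_t) * chi_t
        history.append(dict(f))
    return history

# Example with a linear kernel K(x, z) = x * z on scalar inputs:
hist = kernel_gd(lambda x, z: x * z, train_seq=[(1.0, 1.0), (0.5, 0.2)], test_points=[2.0], eta=0.1)
print(hist[-1])
```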

3.2 abc-Parametrizations and Dynamical Dichotomy

In this section, we broaden our scope to the abc-parametrizations of deeper MLPs, defined by Eq. (1),
and their infinite-width limits. In Table 1, we summarize the {al , bl }l ∪ {c} values of various
abc-parametrizations in the literature.
Assumption 3.1. Our main results in this section (and this section only) will assume φ is either tanh
or a smooth version of relu called σ-gelu (see Definition G.1), for sufficiently small σ > 0 (which
means σ-gelu approximates relu arbitrarily well).

Note this assumption is only needed for the classification of abc-parametrizations. For deriving the
infinite-width limits, the much weaker Assumption G.21 suffices. We believe our results here will
hold for generic nonlinearities, but making this precise is outside our scope. (See Remark G.14 for
some discussion).

Symmetries of abc-Parametrizations   As above, we can scale the parameter gradients ∇_{w^l} f arbitrarily while keeping f fixed, if we vary a_l, b_l while fixing a_l + b_l: ∇_{w^l} f is scaled by n^{-θ} if a_l ← a_l + θ, b_l ← b_l − θ. In other words, changing a_l, b_l this way effectively gives w^l a per-layer learning rate. If we apply this gradient with learning rate ηn^{-c}, then the change in W^l is scaled by ηn^{-c-2θ}. Consequently, if c ← c − 2θ, then W^l is not affected by the change in a_l, b_l. In summary,
    ∀θ ∈ R: f_t(ξ) stays fixed for all t and ξ if we set a_l ← a_l + θ, b_l ← b_l − θ (for all l), c ← c − 2θ.   (5)
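Eq. (5) can be sanity-checked numerically. The sketch below (ours) trains the shallow network of Eq. (2) for a few SGD steps, once in NTP and once in the θ = −1/2 shift of NTP, drawing the same underlying Gaussians so that the initial effective weights coincide; the resulting function values agree up to floating point error.

```python
# Sketch: numerical check of the symmetry Eq. (5) on the shallow network of Eq. (2).
# Shifting (a_u, a_v, b_u, b_v, c) by (th, th, -th, -th, -2*th) leaves f_t(xi) unchanged.
import numpy as np

def train_f(a_u, a_v, b_u, b_v, c, n=512, steps=3, xi_test=0.3, seed=0):
    rng = np.random.default_rng(seed)
    # Sample from a FIXED standard Gaussian and rescale, so different (b_u, b_v)
    # correspond to the same underlying randomness (same initial effective weights).
    u = n ** (-b_u) * rng.standard_normal((n, 1))
    v = n ** (-b_v) * rng.standard_normal((1, n))
    phi, phip = np.tanh, lambda z: 1 - np.tanh(z) ** 2
    lr = n ** (-c)                                     # eta = 1
    for t in range(steps):
        xi, y = 1.0 + 0.1 * t, 0.5                     # toy training sequence
        h = n ** (-a_u) * u * xi
        x = phi(h)
        f = ((n ** (-a_v) * v) @ x).item()
        chi = f - y                                    # square-loss derivative
        du = n ** (-a_u) * (n ** (-a_v) * v.T) * phip(h) * xi
        dv = n ** (-a_v) * x.T
        u, v = u - lr * chi * du, v - lr * chi * dv
    return ((n ** (-a_v) * v) @ phi(n ** (-a_u) * u * xi_test)).item()

base  = train_f(a_u=0.0, a_v=0.5, b_u=0.0, b_v=0.0, c=0.0)    # NTP
shift = train_f(a_u=-0.5, a_v=0.0, b_u=0.5, b_v=0.5, c=1.0)   # theta = -1/2 applied to NTP
print(base, shift)   # identical up to floating point error
```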

Stable abc-Parametrizations   We will only consider abc-parametrizations such that, as n → ∞, 1) the preactivations {h^l}_l and activations {x^l}_l have Θ(1) coordinates at initialization, and 2) their coordinates and the logit f(ξ) all stay O(1) throughout the course of SGD.¹⁴ Otherwise, they tend to ∞ with n, eventually going out of floating point range. Indeed, this is an acute and real problem common in modern deep learning, where float16 is necessary to train large models. We call any such parametrization stable (see Definition G.4 for a formal definition). Thus unstable parametrizations are of no practical interest.
It turns out stable abc-parametrizations can be characterized by a set of inequalities on {a_l, b_l}_l ∪ {c} (so that the stable ones form a polyhedron). To present these inequalities succinctly, it's useful to define
Definition 3.2. For any abc-parametrization, we write r for the quantity
    r := min(a_{L+1} + b_{L+1}, 2a_{L+1} + c) + c − 1 + min_{l=1}^{L} [2a_l + I(l = 1)].
For example, in NTP, r = 1/2, while in MFP (when L = 1), r = 0. Intuitively, r is the exponent such that ∆x^L_t(ξ) = Θ(n^{-r}). Thus, to avoid activation blowup, we want r ≥ 0; to perform feature learning, we want r = 0.
Theorem 3.3 (Stability Characterization, c.f. Theorem G.6). An abc-parametrization is stable iff all of the following are true (with intuitions in parentheses):
1. ((pre)activations x^l_0, h^l_0 at initialization are Θ(1) and logits f_0 are O(1))
    a_1 + b_1 = 0;   a_l + b_l = 1/2, ∀l ∈ [2, L];   a_{L+1} + b_{L+1} ≥ 1/2.   (6)
2. (features don't blow up, i.e. ∆x^l_t = O(1) for all l)
    r ≥ 0.   (7)
3. (logits don't blow up during training, i.e. ∆W^{L+1}_t x^L_t, W^{L+1}_0 ∆x^L_t = O(1))
    2a_{L+1} + c ≥ 1;   a_{L+1} + b_{L+1} + r ≥ 1.   (8)

Nontrivial abc-Parametrizations   Among stable abc-parametrizations, there are also those where f does not change throughout training in the infinite-width limit. We say such parametrizations are trivial. Our dichotomy result will only apply to nontrivial stable abc-parametrizations.¹⁵
Nontrivial abc-parametrizations can also be described by a disjunction of equations on {a_l, b_l}_l ∪ {c} (geometrically, they correspond to the union of two faces of the polyhedron of stable abc-parametrizations).
Theorem 3.4. A stable abc-parametrization is nontrivial iff a_{L+1} + b_{L+1} + r = 1 or 2a_{L+1} + c = 1.

Feature Learning   Below, for brevity, we say training routine to mean the package of learning rate ηn^{-c}, training sequence {(ξ_t, y_t)}_{t≥0},¹⁶ and a loss function L(f(ξ), y) that is continuously differentiable in the prediction of the model f(ξ). As above, we use •_t to denote the object • after t steps of SGD.
Definition 3.5 (c.f. Definitions G.9 and G.10). We say an abc-parametrization admits feature learning (resp. evolves the feature kernel) if, as n → ∞, ∆x^L_t(ξ) has Ω(1) coordinates (resp. (1/n)(x^L_t(ξ)^⊤ x^L_t(ζ) − x^L_0(ξ)^⊤ x^L_0(ζ)) = Ω(1)) for some training routine, time t ≥ 1, and input ξ (resp. ξ, ζ).¹⁷ ¹⁸
¹⁴ But they may depend on training time and η; in particular, it's possible that they diverge with time.
¹⁵ In particular, it's possible for the function f to stay fixed with time, but for the features to change.
¹⁶ For simplicity, we only consider batch size 1; it's straightforward to generalize to larger batch sizes.
¹⁷ For the sake of streamlining the main text presentation, we defined feature learning and feature kernel evolution slightly differently than in Definition G.9, but ultimately they are equivalent as a result of our theorems.
¹⁸ We note that this is a rather weak notion of "feature learning", as we only require that the embedding x^L_t(ξ) changes from its initialization for some scenario, rather than, say, for generic scenarios; nor do we speak at all about the "quality" of feature learning, e.g. how it helps downstream tasks. But our proofs (see Appendix G.7) will show that "some scenario" in fact implies much more general scenarios. In addition, we argue that such formal weakness is more than compensated by our experiments, which show that infinite-width limits of feature learning (in the sense defined here) abc-parametrized MLPs outperform finite MLPs and their NTK limits on tasks (namely, Word2Vec and few-shot learning) where feature learning, in the colloquial notion of the phrase, is crucial.
MFP, in the 1-hidden-layer case, is an example of a feature learning parametrization.
Intuitively, feature kernel evolution implies feature learning, but a priori it seems possible that the
latter can occur without the former (akin to some kind of rotation of features). If so, then, e.g. in
terms of linear transfer learning, the pretraining ultimately had no benefit. But, in fact,
Theorem 3.6. A nontrivial stable abc-parametrization admits feature learning iff it evolves the
feature kernel iff r = 0.

Kernel Regime While feature learning here is defined by looking at the embedding of an input ξ,
we can also look at the dynamics of the function represented by the neural network.
Definition 3.7 (c.f. Definition G.11). We say an abc-parametrization is in kernel regime if there exists a positive semidefinite kernel K such that, for any training routine, time t ≥ 0, and input ξ, in the n → ∞ limit,
    f_{t+1}(ξ) = f_t(ξ) − ηK(ξ, ξ_t) L′(f_t(ξ_t), y_t), ∀t ≥ 0.   (9)
In other words, SGD reduces to kernel gradient descent in the large-n limit.
Theorem 3.8. A nontrivial stable abc-parametrization is in kernel regime iff r > 0.

NTP is a typical example of this, where r = 1/2 and K is given by the NTK.
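This classification is mechanical enough to script. The following sketch (ours) computes r from Definition 3.2 and applies Theorems 3.3, 3.4, 3.6 and 3.8 to report stability, nontriviality, and feature learning vs. kernel regime, reproducing the corresponding rows of Table 1 for SP, NTP, and µP.

```python
# Sketch: classify an abc-parametrization (Definition 3.2, Theorems 3.3, 3.4, 3.6, 3.8).
# a, b are dicts mapping layer index 1..L+1 to a_l, b_l; c is the learning-rate exponent.

def classify(a, b, c, L):
    r = min(a[L+1] + b[L+1], 2*a[L+1] + c) + c - 1 \
        + min(2*a[l] + (1 if l == 1 else 0) for l in range(1, L+1))
    stable = (a[1] + b[1] == 0
              and all(a[l] + b[l] == 0.5 for l in range(2, L+1))
              and a[L+1] + b[L+1] >= 0.5
              and r >= 0
              and 2*a[L+1] + c >= 1
              and a[L+1] + b[L+1] + r >= 1)
    nontrivial = stable and (a[L+1] + b[L+1] + r == 1 or 2*a[L+1] + c == 1)
    regime = None
    if nontrivial:
        regime = "feature learning" if r == 0 else "kernel regime"
    return dict(r=r, stable=stable, nontrivial=nontrivial, regime=regime)

L = 3
SP  = ({l: 0.0 for l in range(1, L+2)}, {1: 0.0, **{l: 0.5 for l in range(2, L+2)}}, 1.0)  # LR 1/n
NTP = ({1: 0.0, **{l: 0.5 for l in range(2, L+2)}}, {l: 0.0 for l in range(1, L+2)}, 0.0)
muP = ({1: -0.5, **{l: 0.0 for l in range(2, L+1)}, L+1: 0.5}, {l: 0.5 for l in range(1, L+2)}, 0.0)
for name, (a, b, c) in [("SP (LR 1/n)", SP), ("NTP", NTP), ("muP", muP)]:
    print(name, classify(a, b, c, L))
```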
Dynamical Dichotomy   Since a stable abc-parametrization has either r = 0 or r > 0 by Eq. (7):

Corollary 3.9. A nontrivial stable abc-parametrization either admits feature learning or is in kernel regime, but not both.

We note that we are under Assumption 3.1. For example, if φ is linear, then this dichotomy doesn't hold, as a 1-hidden-layer linear network where only the first layer is trained would both admit feature learning and be in kernel regime.

An interesting consequence of Dynamical Dichotomy is

Corollary 3.10. Any nontrivial stable feature learning abc-parametrization must have lim_{n→∞} f_0(ξ) = 0 for all ξ, where the limit is almost sure.

Theorems 3.6 and 3.8 and Corollary 3.10 are consequences of the more general classification theorem, Theorem G.12, which in addition shows: 1) feature learning in layer l would imply the same for layers l, …, L; 2) in any feature learning parametrization, f_t in the large n limit becomes deterministic, and thus is incompatible with any Bayesian perspective (in contrast to the NNGP limit).

Dynamical Dichotomy in the shallow perceptron case is illustrated by the NTK and MF limits, as presented in Section 3.1, which shows the NTK limit exemplifies Theorem 3.8 while the MF limit exemplifies Theorem 3.6. We present a simplified picture of abc-parametrizations in Fig. 2, but see Fig. 5 for a more geometrically accurate depiction.

[Figure 2 diagram: the space of abc-parametrizations, with regions for unstable or trivial parametrizations and for the kernel regime, and marked points for Maximal Update (ours) / Mean Field when depth = 1, Neural Tangent, Standard with LR = Θ(1), and Standard with LR = Θ(1/width).]
Figure 2: A Caricature of abc-Parametrizations. The nontrivial stable parametrizations form a high dimensional polyhedron. Those on a part of its boundary admit feature learning, while all others are in kernel regime. µP is a vertex in the former, while NTP, the latter. See Fig. 5 for a more geometrically accurate depiction.
Remark 3.11 (Function Space Picture). A kernel regime limit resides solely in the function space
picture, i.e. the evolution of f at any time being solely determined by the function values {lim ft (ζ)}ζ
themselves (as opposed to the internal activations of f as well) along with η, L, and (ξt , yt ). Intuitively,
this cannot be true for the feature learning limit, and therefore, at least informally, Dynamical
Dichotomy is also a dichotomy over the sufficiency of the function space picture for determining
the training evolution: We can construct two settings where {lim ft (ζ)}ζ , η, L, and (ξt , yt ) are the
same but ft+1 are different. 1) The first setting is at t = 0, where lim ft (ζ) = 0 for all input ζ by
Corollary 3.10. Here a typical SGD will change f . 2) In the second setting, suppose φ is relu. Design
a sequence of inputs such that training the MLP on them with very large learning rate will make all
relu neurons saturated in the 0 region. Then f is everywhere 0, and an SGD step will not change that.
Remark 3.12 (Not All Dynamics are Infinite-Width Limits). Accordingly, a nonlinear function space dynamics cannot be a valid infinite-width limit of any abc-parametrization. By nonlinear, we mean f_{t+1}(ξ) − f_t(ξ) is nonlinear in L′(f_t(ξ_t), y_t). For example, any natural higher-order generalization of Eq. (9) (perhaps derived from a Taylor expansion at initialization) is not a valid limit.¹⁹

Pretraining and Transfer Learning   As in the shallow examples, Corollary 3.9 says that any kernel regime parametrization (including NTP) trivializes pretraining and transfer learning²⁰ in the infinite-width limit.
By calculating r for the standard parametrization (SP), we can easily see that it cannot admit feature learning in the sense here without becoming unstable. However, in the next section, we will manually analyze the training dynamics in an SP MLP to give an intuition for why this is the case. In turn, we then propose a simple modification of SP, the Maximal Update Parametrization (MUP or µP), which does admit feature learning and, in fact, does so maximally in a suitable sense. In a pedagogical spirit, we will focus on the key insights and stress the right heuristics without dwelling on formal aspects.

4 Standard Parametrization
In this section, we give intuition for why gradient descent of a neural network in standard parametrization (SP) will lead to logit blowup after 1 step, if the learning rate is ω(1/n), where n is the width. In addition, we will see why, with learning rate O(1/n), SP is in kernel regime. We first consider the simplest example and then state the general result at the end of the section.
To demonstrate the general principle in deep networks, it is necessary to consider the behavior of an n × n matrix in the middle of the network. Thus, the simplest case is a 2-hidden-layer linear MLP, i.e. Eq. (1) with L = 2 and φ = id. The standard parametrization is given by
    a_l = 0 ∀l,   b_1 = 0,   b_l = 1/2 ∀l ≥ 2.   (SP)
We consider 1 step of SGD with learning rate n^{-c} on a single data pair (ξ, y). Then we can without
ambiguity suppress explicit dependence on ξ and write
f = V h̄, h̄ = W h, h = U ξ, (10)
where Uαβ ∼ N (0, 1) and Wαβ , Vαβ ∼ N (0, 1/n) are the trainable parameters (simplifying the
notation in Section 3). As in Section 3, we use •_t to denote the quantity • after t steps of SGD.
Because we only focus on the 1st step of SGD, we lighten notation and write • = •0 .

Initialization Since U, W, V are independently sampled, a standard Central Limit argument would
show that h, h̄, f all have roughly iid Gaussian coordinates of variance Θ(1).

First Gradient   Now let's consider the gradients of f on the data pair (ξ, y), which are given by
    dh̄ = V^⊤,   dh = W^⊤ dh̄,   dV = h̄,   dW = dh̄ h^⊤ = V^⊤ h^⊤,   dU = dh ξ^⊤.   (11)
For simplicity, suppose we only update W by learning rate n^{-c} (and leave U, V unchanged); our conclusion will not change in the general case where we train all layers. Then, with χ denoting the loss derivative L′(f, y), we can write
    W_1 = W − n^{-c} χ dW.
We shall now show that c ≥ 1, or else f_1 blows up with the width n after this SGD step.
¹⁹ It may seem that the Neural Tangent Hierarchy [22], which allows some kind of higher-order dynamics in function space, violates our observation. But their infinite-width limit is identical to NTK in the constant-time t = O(1) regime, which is what Remark 3.12 (and this paper) concerns. Moreover, here we are talking about functional dynamics that doesn't depend on n (because we are already at the n → ∞ limit), whereas their functional dynamics does.
²⁰ Linear and nonlinear; see Theorem G.16.
After First SGD Step   At t = 1, h_1 = h since we did not update U, but
    h̄_1 = W_1 h = h̄ − n^{-c} χ dW h = h̄ − n^{-c} χ · V^⊤ h^⊤ h   (12)
    f_1 = V h̄_1 = f − n^{-c} χ V V^⊤ h^⊤ h.   (13)
Now, as noted above, h has iid Θ(1) coordinates, so h^⊤ h = Θ(n) ∈ R. Similarly, V ∈ R^{1×n} has Gaussian coordinates of variance Θ(1/n), so V V^⊤ = Θ(1) ∈ R. Finally, for a typical loss function L like MSE or cross entropy, χ = L′(f, y) is of order Θ(1) because f fluctuates on the order of Θ(1). Altogether,
    f_1 = f − Θ(n^{1−c}).
Therefore, for f_1 to remain O(1), we must have c ≥ 1, i.e. the learning rate is O(1/n).

Kernel Regime and Lack of Feature Learning   Consequently, the network cannot learn features in the large width limit if we would like the logits to not blow up. Indeed, this version of SGD where only W is updated can be seen to correspond to the limit where
    a_1 = θ, b_1 = −θ, a_2 = 0, b_2 = 1/2, a_3 = θ, b_3 = −θ + 1/2,   θ → ∞.
With c = 1 as derived above, the parametrization is stable and nontrivial, as can be checked from Theorems 3.3 and 3.4. Then we get r = 1/2 > 0, so by Corollary 3.9, this parametrization is in kernel regime and does not admit feature learning. We can also see this directly from Eq. (12): from our calculations above,
    h̄_1 − h̄ = O(n^{1−c}) V^⊤ = O(1) V^⊤,
whose coordinates have size O(n^{-1/2}) since V's coordinates do, so there's no feature learning (at least in the first step). Finally, from Eq. (13), because V V^⊤ → 1 and n^{-c} h^⊤ h = n^{-1} h^⊤ h → ‖ξ‖², we get²¹
    f_1 − f → −χ K(ξ, ξ) := −χ ‖ξ‖²,
i.e. f evolves by kernel gradient descent with the linear kernel. Our derivations here only illustrate the first SGD step, but we can get the same conclusion from all steps of SGD similarly.

²¹ Formally, these are almost sure convergences, but we suppress these details to emphasize intuition.
We summarize the general case below, which follows trivially from Theorem 3.3 and Corollary 3.9.
Theorem 4.1. An L-hidden-layer MLP in standard parametrization (see Eq. (SP) and Table 1) can only allow an SGD learning rate of order O(1/n) if we require lim_{n→∞} E f_t(ξ)² < ∞ for all training routines, times t, and inputs ξ. In this case, it is in kernel regime and does not admit feature learning.
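The Θ(n^{1−c}) blowup is easy to see numerically. The sketch below (ours) builds the 2-hidden-layer linear MLP of Eq. (10), takes one SGD step on W only (as in the derivation above, with square loss), and prints f_1 for growing width: with c = 0 it grows roughly linearly in n, while with c = 1 it stays O(1).

```python
# Sketch: logit blowup in SP after one SGD step on W (2-hidden-layer linear MLP, Eq. (10)).
import numpy as np

def f1_after_one_step(n, c, xi=1.0, y=1.0, seed=0):
    rng = np.random.default_rng(seed)
    U = rng.normal(0, 1, size=(n, 1))                 # U_ab ~ N(0, 1)
    W = rng.normal(0, n ** -0.5, size=(n, n))         # W_ab ~ N(0, 1/n)
    V = rng.normal(0, n ** -0.5, size=(1, n))         # V_ab ~ N(0, 1/n)
    h = U * xi
    hbar = W @ h
    f = (V @ hbar).item()
    chi = f - y                                       # square-loss derivative L'(f, y)
    dW = V.T @ h.T                                    # dW = V^T h^T, Eq. (11)
    W1 = W - n ** (-c) * chi * dW                     # one SGD step with lr = n^{-c}
    return (V @ (W1 @ h)).item()                      # f_1, Eq. (13)

for n in [256, 1024, 2048]:
    print(f"n={n:5d}   c=0: f1 = {f1_after_one_step(n, 0):10.1f}   c=1: f1 = {f1_after_one_step(n, 1): .4f}")
```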

5 Maximal Update Parametrization


As shown in the last section, the standard parametrization does not admit a feature learning infinite-width limit without blowing up the logits. Here we propose simple modifications of the standard parametrization to make this possible while maintaining stability: 1) to enable feature learning, it suffices to divide the logits by √n and use a Θ(1) learning rate, i.e. set a_{L+1} = 1/2, c = 0 on top of Eq. (SP); 2) to allow every layer to perform feature learning, we should furthermore set a_1 = −1/2, b_1 = 1/2. We will see that this essentially means we update each weight matrix as much as possible without blowing up the logits or activations, so we call this the Maximal Update Parametrization (abbreviated MUP or µP).

5.1 Dividing Logits by √n

For example, in the 2-hidden-layer linear MLP example above, the network would compute
    f(ξ) = (1/√n) v h̄(ξ),   h̄(ξ) = W h(ξ),   h(ξ) = U ξ,   (14)
where U_{αβ} ∼ N(0, 1) and W_{αβ}, v_{αβ} ∼ N(0, 1/n) are the trainable parameters. Compared to SP (Eq. (10)), h(ξ) and h̄(ξ) stay the same; only the logit f(ξ) is scaled down. Again, to simplify notation, we abbreviate • = •_0 and suppress explicit dependence on ξ. This has two consequences:
Logits at Initialization Converge to 0 since f has variance Θ(1/n) (compare to the GP limit of
MLP in SP at initialization).

Θ(1) Learning Rate and Feature Learning   Even though f → 0, the loss derivative χ = L′(f, y) stays Θ(1) if y ≠ 0. When we redo the calculation in Eq. (12), we see
    h̄_1 = h̄ − n^{-c-1/2} χ v^⊤ h^⊤ h = h̄ − Θ(n^{-c+1/2}) v^⊤   (15)
    f_1 = f − n^{-c-1} χ v v^⊤ h^⊤ h = f − Θ(n^{-c}).
Because v has coordinates of size Θ(n^{-1/2}), we see that h̄ and f both change by Θ(1) coordinatewise if c = 0 (i.e. the learning rate is Θ(1)). This directly illustrates feature learning after just 1 step of SGD. For general MLPs, we can also check that a_{L+1} = 1/2, c = 0 on top of Eq. (SP) implies r = 0 and thus admits feature learning by Theorem 3.6.

Kernel Behavior or Lack Thereof   The example we have here, where we only train the middle layer in a linear MLP, actually is in kernel regime. This does not violate Corollary 3.9, however, which assumes Assumption 3.1. If, for example, we have tanh nonlinearity, then it is easy to see that the µP SGD dynamics does not have a kernel limit: if it did, then f_1 − f would be linear in the learning rate η. But note that h̄_1 − h̄ is Θ(1) as n → ∞ and linear in η, as can be derived similarly to Eq. (15). Because tanh is bounded, this cannot happen. Contrast this with SP or NTP, where h̄_1 − h̄ is Θ(1/√n) and thus "resides in the linear regime of tanh", allowing perfect scaling with η.
In addition, even in a linear MLP, if we train the middle layer and the last layer, then the dynamics intuitively will become quadratic in the weights, so it will not have a kernel limit. Contrast this with SP or NTP, which suppress these higher-order interactions because the learning rate is small, and a first-order Taylor expansion heuristic holds.

How is this different from standard parametrization with learning rate 1/√n?   As shown above, the logit f blows up like Θ(√n) after 1 step of SGD with learning rate Θ(1/√n) in the standard parametrization, but remains Θ(1) in our parametrization here. The reason these two parametrizations seem similar is that in the 1st step, the weights receive the same updates modulo the loss derivative χ = L′(f, y). Consequently, x^L_1 − x^L and h^L_1 − h^L are Θ(1) coordinatewise in both cases. However, this update makes x^L_1 correlated with W^{L+1}_1, so that W^{L+1}_1 x^L_1 (and f_1) scales like Θ(n^{1−a_{L+1}−b_{L+1}}) due to the Law of Large Numbers. Thus only in our parametrization here (a_{L+1} = b_{L+1} = 1/2) is it Θ(1), while in standard parametrization (a_{L+1} = 0, b_{L+1} = 1/2) it blows up like Θ(√n). Contrast this with the behavior at initialization, where W^{L+1} and x^L are independent and zero-mean, so W^{L+1} x^L scales like Θ(n^{1/2−a_{L+1}−b_{L+1}}) by the Central Limit Theorem.
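A quick numerical illustration of this comparison (ours, again using the 2-hidden-layer linear MLP and updating only W with square loss): the feature update h̄_1 − h̄ has Θ(1) coordinates in both cases, but f_1 blows up like √n only in SP with learning rate 1/√n.

```python
# Sketch: after one SGD step on W, the feature update is Theta(1) coordinatewise in both
# SP with lr 1/sqrt(n) and the sqrt(n)-divided-logit parametrization with lr 1,
# but the logit f_1 blows up like sqrt(n) only in the former.
import numpy as np

def one_step(n, divide_logits, lr, xi=1.0, y=1.0, seed=0):
    rng = np.random.default_rng(seed)
    U = rng.normal(0, 1, size=(n, 1))
    W = rng.normal(0, n ** -0.5, size=(n, n))
    v = rng.normal(0, n ** -0.5, size=(1, n))          # last layer (V in Eq. (10), v in Eq. (14))
    mult = n ** -0.5 if divide_logits else 1.0         # logit multiplier
    h = U * xi
    hbar = W @ h
    f = mult * (v @ hbar).item()
    chi = f - y                                        # square-loss derivative
    dW = mult * v.T @ h.T                              # gradient of f w.r.t. W
    hbar1 = (W - lr * chi * dW) @ h
    f1 = mult * (v @ hbar1).item()
    coord = float(np.sqrt(np.mean((hbar1 - hbar) ** 2)))   # coordinate size of the feature update
    return coord, f1

for n in [256, 1024, 2048]:
    c_sp, f_sp = one_step(n, divide_logits=False, lr=n ** -0.5)
    c_mu, f_mu = one_step(n, divide_logits=True, lr=1.0)
    print(f"n={n:5d}  SP: |d hbar|={c_sp:.3f}, f1={f_sp:8.2f}   divided logits: |d hbar|={c_mu:.3f}, f1={f_mu:.4f}")
```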

5.2 First Layer Parametrization


While this now enables feature learning, the first layer preactivation h effectively stays fixed through-
out training even if we were to train U . For example, if we update U in the linear MLP example
Eq. (14), then by Eq. (11),
U1 = U − n−c χ dU = U − n−c χ dhξ >
h1 = U1 ξ = h − n−c χ dhξ > ξ = h − Θ(n−c )dh
since ξ > ξ, χ = Θ(1). Now dh = W > dh̄ = W > √1n v > has roughly iid Gaussian coordinates, each
of size Θ(1/n), since √1n v > has coordinates of the same size. Therefore, even with c = 0, h changes
by at most O(1/n) coordinatewise, which is dominated by its value at initialization. This O(1/n)
change also induces a O(1/n) change in f , which would be dominated by the Θ(1) change due to
W ’s evolution, as seen in Eq. (15).
We therefore propose to set a1 = −1/2, b1 = 1/2 on top of Section 5.1’s parametrization. This
implies the forward pass of f remains the same but U ’s gradient is scaled up by n, so that h now
changes by Θ(1) coordinatewise. In summary, we define
Definition 5.1. The Maximal Update Parametrization (abbreviated MUP, or µP), in the context of
an L-hidden-layer MLP (Eq. (1)), is given by

−1/2 l = 1
c = 0, bl = 1/2 ∀l, al = 0 2≤l≤L
1/2 l = L + 1.

Notice that µP for a 1-hidden-layer perceptron is equivalent to the mean field parametrization by
Eq. (5). We also describe µP for any architecture in Appendix C.1.

5.3 What is µP Maximal In?

For technical reasons, we adopt Assumption 3.1 again for the formal results of this section.
In an abc-parametrization, the change in weight W = W^l_t for any l ≥ 2 due to learning rate n^{-c} is δW := −n^{-c} · n^{-2a} dh x^⊤, where we abbreviate x = x^{l−1}_t, h = h^l_t, a = a_l. (We use δ to denote the 1-step change, but ∆ to denote the lifetime change.) In the next forward pass, δW contributes δW x̄ = −n^{1−c−2a} (x^⊤ x̄/n) dh, where x̄ is the new activation due to the change in previous layers' weights. In general, x and x̄ are strongly correlated. Then x^⊤ x̄/n → R for some R ≠ 0 by the Law of Large Numbers (as they both have Θ(1) coordinates in a stable parametrization). One can heuristically see that dh has the same size as the last layer weights, which is Θ(n^{−(a_{L+1}+b_{L+1})} + n^{−(2a_{L+1}+c)}) (where the first summand is from W^{L+1}_0 and the other from ∆W^{L+1}_t). Thus, δW x̄ is a vector with Θ(n^{−r_l}) := Θ((n^{−(a_{L+1}+b_{L+1})} + n^{−(2a_{L+1}+c)}) n^{1−c−2a}) coordinates. If r_l > 0, then δW x̄ contributes vanishingly; if r_l < 0, then δW x̄ blows up. For l = 1, we get similar insights after accounting for the finite dimensionality of ξ.
Definition 5.2. For l ∈ [L], we say W^l is updated maximally if ∆W^l_t x^{l−1}_t(ξ) has Θ(1) coordinates for some training routine,²² time t ≥ 1, and input ξ.
Proposition 5.3. In a stable abc-parametrization, for any l ∈ [L], W^l is updated maximally iff
    r_l := min(a_{L+1} + b_{L+1}, 2a_{L+1} + c) + c − 1 + 2a_l + I(l = 1) = 0.

Note that r (Definition 3.2) is the minimum of r_l over all l. In µP, we can calculate that r_l = 0 for all l ∈ [L], so all W^l, l ∈ [L], are updated maximally. Put another way, the final embedding x^L(ξ) will have nonvanishing (nonlinear) contributions from ∆W^l for all l. These contributions cause the logit f(ξ) to change via interactions with W^{L+1}_0 and ∆W^{L+1}_t. If both W^{L+1}_0 and ∆W^{L+1}_t are too small, then the logit is fixed to its initial value, so all of the feature learning would have been useless.²³ It's also possible for one to contribute vanishingly but not the other.²⁴ But both contribute in µP.
Definition 5.4. We say W^{L+1} is updated maximally (resp. initialized maximally) if ∆W^{L+1}_t x^L_t(ξ) = Θ(1) (resp. W^{L+1}_0 ∆x^L_t(ξ) = Θ(1)) for some training routine, time t ≥ 1, and input ξ.

Note Definition 5.4 is similar to Definition 5.2 except that ∆W^{L+1}_t x^L_t(ξ) ∈ R but ∆W^l_t x^{l−1}_t(ξ) ∈ R^n.
Proposition 5.5. In a stable abc-parametrization, W^{L+1} is 1) updated maximally iff 2a_{L+1} + c = 1, and 2) initialized maximally iff a_{L+1} + b_{L+1} + r = 1.

We remark that, by Theorem 3.4, a parametrization is nontrivial iff W^{L+1} is maximally updated or initialized. Using Propositions 5.3 and 5.5 and Theorem 3.3, we can now easily conclude
Theorem 5.6. In µP, W^l is updated maximally for every l ∈ [L + 1], and W^{L+1} is also initialized maximally. µP is the unique stable abc-parametrization with this property.
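Proposition 5.3 and Theorem 5.6 can likewise be checked mechanically; a small sketch (ours) computing the per-layer exponents r_l:

```python
# Sketch: per-layer exponents r_l (Proposition 5.3); r_l = 0 means W^l is updated maximally.
def r_layers(a, b, c, L):
    base = min(a[L+1] + b[L+1], 2*a[L+1] + c) + c - 1
    return {l: base + 2*a[l] + (1 if l == 1 else 0) for l in range(1, L+1)}

L = 3
muP = ({1: -0.5, 2: 0.0, 3: 0.0, 4: 0.5}, {l: 0.5 for l in range(1, 5)}, 0.0)
SP  = ({l: 0.0 for l in range(1, 5)}, {1: 0.0, 2: 0.5, 3: 0.5, 4: 0.5}, 1.0)   # SP with LR 1/n
for name, (a, b, c) in [("muP", muP), ("SP", SP)]:
    print(name, r_layers(a, b, c, L))
    # muP: r_l = 0 for every l; SP: r_1 = 1.5, r_2 = r_3 = 0.5, so no layer is updated maximally.
```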

6 Deriving Feature Learning Infinite-Width Limit: Intuition and Examples


We propose the Tensor Programs technique for deriving the infinite-width limit of any abc-
parametrization. This ultimately just requires the researcher to mechanically apply a set of rules to
the computation graph underlying SGD. However, while operationally simple, this procedure would
seem “too magical” at first. In this section, through a series of examples, we seek to build intuition
for what is being automated by this procedure. Then, in the next section, we formally describe the
Tensor Programs framework.
²² Recall that training routine means a package of learning rate ηn^{-c}, training sequence {(ξ_t, y_t)}_{t≥0}, and a loss function L(f(ξ), y) that is continuously differentiable in the prediction of the model f(ξ).
²³ It is indeed possible to perform feature learning in a trivial parametrization, e.g. b_l = 1/2 ∀l, a_1 = −1/2, a_2 = 100 + 1/2, c = −100 in a 2-hidden-layer MLP.
²⁴ E.g. take a_{L+1} = 100 + 1/2, b_{L+1} = −100 + 1/2; then ∆W^{L+1} is negligible.
Setup and Notation   For pedagogical simplicity, we only consider input dimension d = 1 and learning rate η = 1 here, but generalization to d > 1, η ≠ 1 is straightforward. We consider SGD with a singleton minibatch {(ξ_t, y_t)} at time t = 0, 1, 2, . . ., where ξ_t is the network input and y_t is the label. We write W^l_t for the matrix W^l after t steps of such training. For any network input ξ ∈ R, we write x^l_t(ξ) (resp. h^l_t(ξ), f_t(ξ)) for the activation x^l (resp. preactivation h^l, logits f) of the network after t steps of SGD. We denote the scaled gradient n∇_{x^l_t} f_t(ξ) (resp. n∇_{h^l_t} f_t(ξ)) by dx^l_t(ξ) (resp. dh^l_t(ξ)). For brevity, we abuse notation and use x^l_t (without being applied to ξ) to also denote the vector x^l_t(ξ_t) (applied specifically to ξ_t); likewise for h^l_t, dh^l_t, dx^l_t, f_t. We will not use x^l_t on its own to denote the function ξ ↦ x^l_t(ξ), so this should not cause confusion. The loss function is denoted L and the loss derivative L′(logit, target) is in the first argument. We write χ_t := L′(f_t, y_t).

6.1 1-Hidden-Layer MLP


As mentioned above, for 1 hidden layer, the infinite-width µP limit is the same as the mean field limit of [9, 26, 37, 39]. Nevertheless, we present a slightly different derivation of this limit that is more consistent with the philosophy of Tensor Programs. Such a network on input ξ ∈ R is given by
    f(ξ) = V x(ξ),   x(ξ) = φ(h(ξ)),   h(ξ) = U ξ,    (16)
for U ∈ R^{n×1}, V ∈ R^{1×n} parametrized like U = √n u, V = (1/√n) v and with initialization u_{αβ}, v_{αβ} ∼ N(0, 1/n).^{25} Then U_0 (the initial value of U) has iid N(0, 1) coordinates. It will turn out to be convenient to represent each such coordinate distribution as a random variable Z^{U_0} := N(0, 1). Likewise, let Z^{nV_0} := N(0, 1), independent from Z^{U_0}, represent the coordinate distribution of nV_0 (we use nV_0 instead of V_0 so that the Z random variable is always independent of n). We derive the µP limits of the first forward and backward passes manually before stating the general case. To lighten notation, we suppress the t = 0 subscript (e.g. U = U_0, h = h_0, f = f_0, etc.), as we will spend some time on the first SGD step.

First Forward Pass After random initialization, the preactivation h = h(ξ) (where ξ = ξ_0 ∈ R is the first input) has iid coordinates, each a sample from Z^h := ξ Z^U ∈ R. Naturally, x = x(ξ) has iid coordinates as well, each a sample from Z^x := φ(Z^h). Finally, f = V x = (1/n) ∑_{α=1}^n (nV)_α x_α → f̊ := E Z^{nV} Z^x by the Law of Large Numbers as n → ∞.^{26} In particular, f becomes deterministically 0 in this limit because V and U are independent. For a typical loss function L, the loss derivative χ := L′(f, y) then also becomes deterministic, χ → χ̊ := L′(f̊, y).

First Backward Pass Similarly, dx = nV^⊤ (recall dx_t := n∇_{x_t} f_t) has coordinates distributed like Z^{dx} = Z^{nV}, and dh = dx ⊙ φ′(h) has coordinates distributed like Z^{dh} := Z^{dx} φ′(Z^h) = Z^{nV} φ′(Z^h). Then SGD with learning rate 1 makes the following updates:
    v_1 = v − χ x^⊤/√n   ⟹   V_1 = V − χ x^⊤/n
    u_1 = u − χ ξ dh/√n   ⟹   U_1 = U − χ ξ dh.
Since χ converges to a deterministic limit χ̊, the coordinates of these updates are roughly iid, corresponding to an update of Z random variables:
    Z^{nV_1} = Z^{nV} − χ̊ Z^x,   Z^{U_1} = Z^U − χ̊ ξ Z^{dh}.

Second Forward Pass Thus V_1 and U_1 still have roughly iid coordinates after 1 SGD step. Then, in the second forward pass, h_1 has coordinates
    Z^{h_1} := ξ_1 Z^{U_1} = ξ_1 Z^U − ξ_1 χ̊ ξ Z^{dh} = ξ_1 Z^U − ξ_1 χ̊ ξ Z^{nV} φ′(Z^h),
x_1 has coordinates Z^{x_1} := φ(Z^{h_1}), and the output is
    f_1 = (1/n) ∑_{α=1}^n (nV_1)_α x_{1α} → f̊_1 := E Z^{nV_1} Z^{x_1} = E (Z^{nV} − χ̊ Z^x) Z^{x_1}    (17)
as n → ∞. Then χ_1 := L′(f_1, y_1) → χ̊_1 := L′(f̊_1, y_1) becomes deterministic. The gradient vectors have roughly iid coordinates by a similar logic.
^{25} Again, more generally, we can insert constants in this parametrization, like U = √(n/d) u, but we omit them here for simplicity.
^{26} All convergence in this section will be almost sure, but to focus on the intuition here and less on the formalities, we do not explicitly write this down.

tth Iteration Repeating the above reasoning shows that at any time t (independent of n), we obtain

Theorem 6.1. Consider a 1-hidden-layer MLP in µP (Eq. (16)) and any training routine with learning rate 1. Suppose φ′ is pseudo-Lipschitz.^{27} As n → ∞, for every input ξ, f_t(ξ) converges almost surely to f̊_t(ξ) defined as follows:
    f_t(ξ) →^{a.s.} f̊_t(ξ) := E Z^{nV_t} Z^{x_t(ξ)},   Z^{x_t(ξ)} := φ(Z^{h_t(ξ)}),   Z^{h_t(ξ)} := ξ Z^{U_t},    (18)
    χ̊_t := L′(f̊_t, y_t),   Z^{nV_{t+1}} := Z^{nV_t} − χ̊_t Z^{x_t},   Z^{U_{t+1}} := Z^{U_t} − χ̊_t ξ_t Z^{nV_t} φ′(Z^{h_t}),    (19)
with, as initial conditions, Z^{U_0} and Z^{nV_0} being independent standard Gaussians, where in Eq. (19) we abbreviated f̊_t = f̊_t(ξ_t), x_t = x_t(ξ_t), h_t = h_t(ξ_t).
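To make Eqs. (18) and (19) concrete, the following is a minimal Monte Carlo sketch (our own illustration in Python/NumPy, not code from the paper): we represent Z^{U_t} and Z^{nV_t} by a large sample of iid "coordinates" and replace each expectation by an empirical mean, assuming the squared loss L(f, y) = (f − y)²/2 so that L′(f, y) = f − y. As discussed later in Section 8, the Monte Carlo error compounds with t, so this is only practical for a few steps.

    import numpy as np

    def mup_limit_1hidden(train_seq, phi, dphi, n_samples=1_000_000, seed=0):
        """Monte Carlo sketch of the 1-hidden-layer muP limit (Theorem 6.1).
        train_seq: list of (xi_t, y_t); phi, dphi: activation and its derivative.
        Squared loss assumed, so L'(f, y) = f - y."""
        rng = np.random.default_rng(seed)
        ZU = rng.standard_normal(n_samples)    # samples of Z^{U_0}
        ZnV = rng.standard_normal(n_samples)   # samples of Z^{nV_0}, independent of Z^{U_0}
        history = []
        for xi, y in train_seq:
            Zh = xi * ZU                       # Z^{h_t} = xi_t Z^{U_t}
            Zx = phi(Zh)                       # Z^{x_t} = phi(Z^{h_t})
            f = np.mean(ZnV * Zx)              # f_t -> E Z^{nV_t} Z^{x_t}   (Eq. (18))
            chi = f - y                        # chi_t -> L'(f_t, y_t)
            # Eq. (19); note the old ZnV is used in both updates
            ZnV, ZU = ZnV - chi * Zx, ZU - chi * xi * ZnV * dphi(Zh)
            history.append(f)
        return history

    # e.g. mup_limit_1hidden([(1.0, 1.0), (-1.0, -1.0)], np.tanh, lambda z: 1 - np.tanh(z)**2)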

As aforementioned, this is a discrete time, minibatched version of the mean field limit of [9, 26, 37, 39].^{28} When φ is the identity, it's easy to see that Z^{nV_t} and Z^{U_t} are always (deterministic) linear combinations of Z^{nV_0} and Z^{U_0}, say Z^{nV_t} = A_t Z^{nV_0} + B_t Z^{U_0} and Z^{U_t} = C_t Z^{nV_0} + D_t Z^{U_0}. Then the limit f̊_t depends solely on A_t, B_t, C_t, D_t. By tracking their evolution, we get the following greatly simplified formula for an infinite-width µP linear network.
Corollary 6.2. Consider a 1-hidden-layer linear MLP in µP (Eq. (16)) and any training routine with learning rate 1. As n → ∞, for every input ξ, f_t(ξ) converges almost surely to f̊_t(ξ) defined as follows:
    f̊_t(ξ) = (A_t C_t + B_t D_t) ξ,   χ̊_t = L′(f̊_t, y_t),
    (A_{t+1}, B_{t+1}) = (A_t, B_t) − χ̊_t ξ_t (C_t, D_t),
    (C_{t+1}, D_{t+1}) = (C_t, D_t) − χ̊_t ξ_t (A_t, B_t),
with initial condition A_0 = D_0 = 1, B_0 = C_0 = 0.

This can be easily generalized to larger input and output dimensions (see Appendix D.1). In short, such an infinite-width µP linear network with input dimension d and output dimension d_o is equivalent to a width-(d + d_o) linear network with the same input/output dimensions but a “diagonal”, instead of random, initialization. Our Word2Vec and MAML experiments will crucially rely on this simplifying observation; a minimal sketch of the d = d_o = 1 recursion follows below. We remark that, in contrast to our approach, such an observation would be obscured by the PDE perspective of prior works [9, 26, 37, 39].
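Concretely, Corollary 6.2 reduces the entire infinite-width training dynamics to four scalars, which can be iterated exactly. A minimal sketch (ours, again assuming the squared loss so that L′(f, y) = f − y):

    def mup_limit_1hidden_linear(train_seq):
        """Exact infinite-width muP dynamics of a 1-hidden-layer *linear* MLP
        (Corollary 6.2); squared loss assumed."""
        A, B, C, D = 1.0, 0.0, 0.0, 1.0        # initial condition A0 = D0 = 1, B0 = C0 = 0
        history = []
        for xi, y in train_seq:
            f = (A * C + B * D) * xi           # f_t(xi) = (A_t C_t + B_t D_t) xi
            chi = f - y                        # chi_t = L'(f_t, y_t)
            # simultaneous update: (A,B) -= chi*xi*(C,D) and (C,D) -= chi*xi*(A,B)
            A, B, C, D = A - chi * xi * C, B - chi * xi * D, C - chi * xi * A, D - chi * xi * B
            history.append(f)
        return history

Each step costs O(1) here, in contrast to the Monte Carlo sketch above, which is the kind of simplification the Word2Vec and MAML experiments exploit.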

6.2 2-Hidden-Layer MLP: SGD with Partially Decoupled Backpropagation

A 2-hidden-layer MLP is given by
    f(ξ) = V x̄(ξ),   x̄(ξ) = φ(h̄(ξ)),   h̄(ξ) = W x(ξ),
    x(ξ) = φ(h(ξ)),   h(ξ) = U ξ,
for U ∈ R^{n×1}, W ∈ R^{n×n}, V ∈ R^{1×n} parametrized like U = √n u, W = w, V = (1/√n) v and with initialization u_{αβ}, w_{αβ}, v_{αβ} ∼ N(0, 1/n). The presence of the n × n Gaussian matrix W (“∞ × ∞” as opposed to “∞ × finite” like U or “finite × ∞” like V) is new and has two major effects on the infinite-width training dynamics: 1) a Central Limit effect from the random Gaussian nature of W, and 2) a correlation effect between W and its transpose W^⊤. We isolate the first effect here by analyzing a slightly different version of backpropagation (which has a different limit than normal backpropagation), and then discuss the second effect in the next section. We abuse notation and abbreviate W = W_0.
^{27} This roughly means that φ′ has a polynomially bounded weak derivative; see Definition E.3.
^{28} [9, 26, 37, 39] present the equations in terms of the PDF of the Z random variables. Formally, the PDF limit can be obtained by taking the continuous-time limit of Eqs. (18) and (19) and then applying Fokker–Planck. Note that our derivation, when formalized using the Tensor Programs framework below, does not require the smoothness and support assumptions on the initialization of U, V made in those works: the initialization distribution here can be replaced with any image of Gaussians under pseudo-Lipschitz functions, which includes nonsmooth and singular distributions.
Partially Decoupled Backpropagation In this section, we analyze a version of SGD where the backpropagation weights are partially decoupled from the forward propagation weights. Here, we think of ∆W_t as the trainable weights, initialized at 0, and think of the Gaussian W as untrainable “constants”. The forward pass proceeds normally^{29} with W_t = W + ∆W_t. But we sample and fix an iid copy W̃ of W^⊤ before training, and in the backward pass compute
    dx_t = (W̃ + ∆W_t^⊤) dh̄_t    instead of    dx_t = (W^⊤ + ∆W_t^⊤) dh̄_t = W_t^⊤ dh̄_t.    (20)
In particular, at initialization, we would have dx_0 = W̃ dh̄_0 instead of dx_0 = W^⊤ dh̄_0. Everything else stays the same in the backward pass.^{30} Finally, each weight is still updated by SGD via the usual outer products: with χ_t := L′(f_t, y_t),
    v_{t+1} = v_t − χ_t x̄_t^⊤/√n,   ∆w_{t+1} = ∆w_t − χ_t dh̄_t x_t^⊤/n,   u_{t+1} = u_t − χ_t ξ_t dh_t^⊤/√n.    (21)
Since V = v/√n, W = w, U = √n u per µP, this causes the following changes in the W's:
    V_{t+1} = V_t − χ_t x̄_t^⊤/n,   ∆W_{t+1} = ∆W_t − χ_t dh̄_t x_t^⊤/n,   U_{t+1} = U_t − χ_t ξ_t dh_t^⊤.    (22)
Note here we update ∆w and ∆W instead of w and W.
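For concreteness, here is a finite-width sketch of this partially decoupled SGD in NumPy (our own illustration, directly following Eqs. (20)–(22) and footnotes 29–30; squared loss and learning rate 1 assumed):

    import numpy as np

    def sgd_decoupled_2hidden(train_seq, phi, dphi, n=2048, seed=0):
        """One finite-width run of the partially decoupled SGD of Section 6.2."""
        rng = np.random.default_rng(seed)
        U = rng.standard_normal(n)                       # U = sqrt(n) u,  u_ab ~ N(0, 1/n)
        W0 = rng.standard_normal((n, n)) / np.sqrt(n)    # W = w,          w_ab ~ N(0, 1/n)
        Wtil = rng.standard_normal((n, n)) / np.sqrt(n)  # fixed iid copy of W^T
        V = rng.standard_normal(n) / n                   # V = v / sqrt(n), v_ab ~ N(0, 1/n)
        dW = np.zeros((n, n))                            # trainable Delta W, initialized at 0
        history = []
        for xi, y in train_seq:
            h = U * xi                                   # forward pass (footnote 29)
            x = phi(h)
            hb = (W0 + dW) @ x
            xb = phi(hb)
            f = V @ xb
            chi = f - y                                  # L'(f, y) for squared loss
            dxb = n * V                                  # d x-bar = n V^T (footnote 30)
            dhb = dphi(hb) * dxb
            dx = (Wtil + dW.T) @ dhb                     # Eq. (20): W-tilde replaces W^T
            dh = dphi(h) * dx
            V = V - chi * xb / n                         # Eq. (22)
            dW = dW - chi * np.outer(dhb, x) / n
            U = U - chi * xi * dh
            history.append(f)
        return history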

Why This Decoupled SGD? The reason we discuss this version of SGD is that it isolates the effect of having a Gaussian n × n matrix W̃ in the backward pass, and we can derive its infinite-width limit relatively easily using Central Limit heuristics. In the normal version of SGD, W̃ would equal W^⊤, and its correlation with W creates additional terms in the infinite-width dynamics that are better explained on their own.
Again, we walk through the first few forward and backward passes to gain some intuition for the infinite-width limit, before stating the general case.

First Forward Pass is similar to that in Section 6.1 and follows the usual calculations involved in deriving the NNGP.^{31}

First Backward Pass is similar to that in Section 6.1 and to the calculations involved in deriving the Neural Tangent Kernel, except swapping W^⊤ with W̃ (which at this point has no visible effect, because of the Gradient Independence Phenomenon [45]; but the effect will become clear in the second forward pass).^{32} We end up with ∆W_1 = −χ_0 dh̄_0 x_0^⊤/n, as usual.

Second Forward Pass As usual, we have Z^{h_1} = ξ_1 Z^{U_1} = ξ_1 Z^{U_0} − χ̊_0 ξ_1 ξ_0 Z^{dh_0} and Z^{x_1} = φ(Z^{h_1}), reflecting the coordinate distributions of h_1 and x_1.^{33} Next,
    h̄_1 = W x_1 + ∆W_1 x_1 = W x_1 − χ_0 dh̄_0 (x_0^⊤ x_1 / n).    (23)
On one hand, 1) x_0^⊤ x_1/n → E Z^{x_1} Z^{x_0} by a Law of Large Numbers heuristic. On the other hand, 2) by a Central Limit heuristic, W x_1 should roughly have Gaussian coordinates Z^{W x_1} correlated with Z^{h̄_0} = Z^{W x_0}, with Cov(Z^{W x_1}, Z^{W x_0}) = lim x_0^⊤ x_1/n = E Z^{x_1} Z^{x_0}. However, very importantly, this Central Limit heuristic is correct only because we used W̃ in backprop instead of W^⊤; otherwise, h_1 has a strong correlation with W through dh_0 = φ′(h_0) ⊙ (W^⊤ dh̄_0), and thus so does x_1, so that W x_1 no longer has Gaussian coordinates. This is the “second major effect” referred to in the beginning of this section. See Section 6.3 for how to handle this correlation.
^{29} i.e. f_t = V_t x̄_t, x̄_t = φ(h̄_t), h̄_t = (W + ∆W_t) x_t, x_t = φ(h_t), h_t = U ξ_t.
^{30} i.e. dx̄_t = nV_t^⊤, dh̄_t = φ′(h̄_t) ⊙ dx̄_t, dh_t = φ′(h_t) ⊙ dx_t.
^{31} 1) h_0 is iid Gaussian with coordinates drawn from Z^{h_0} = ξ_0 Z^{U_0}; 2) x_0 has coordinates Z^{x_0} = φ(Z^{h_0}); 3) h̄_0 = W x_0 has roughly iid coordinates drawn from a zero-mean Gaussian Z^{h̄_0} by a Central Limit heuristic, where Z^{h̄_0} is correlated with Z^{h̄_0(ξ)} for any ξ (including ξ = ξ_0) with covariance Cov(Z^{h̄_0}, Z^{h̄_0(ξ)}) = lim_{n→∞} (1/n) x_0^⊤ x_0(ξ) = E Z^{x_0} Z^{x_0(ξ)}; 4) x̄_0 has coordinates Z^{x̄_0} = φ(Z^{h̄_0}); 5) f_0 = (1/n) ∑_{α=1}^n (nV_0)_α x̄_{0α} → f̊_0 := E Z^{nV_0} Z^{x̄_0} by a Law of Large Numbers heuristic.
^{32} 1) dx̄_0 = nV_0^⊤ so Z^{dx̄_0} = Z^{nV_0}; 2) Z^{dh̄_0} = φ′(Z^{h̄_0}) Z^{dx̄_0}; 3) Z^{dx_0} = Z^{W̃ dh̄_0} is Gaussian with covariance Cov(Z^{dx_0}, Z^{dx_0(ξ)}) = lim_{n→∞} (1/n) dh̄_0^⊤ dh̄_0(ξ) = E Z^{dh̄_0} Z^{dh̄_0(ξ)} for any input ξ; 4) Z^{dh_0} = φ′(Z^{h_0}) Z^{dx_0}. Since f converges to a deterministic number f̊_0, we also generically have L′(f, y_0) → χ̊_0 := L′(f̊_0, y_0). Finally, the weights are updated like Eq. (22).
^{33} Recall they abbreviate h_1(ξ_1) and x_1(ξ_1).
In any case, in our scenario here,
    Z^{h̄_1} := Z^{W x_1} − c Z^{dh̄_0},   where c = χ̊_0 E Z^{x_1} Z^{x_0},
is a linear combination of a Gaussian variable and the gradient dh̄_0's coordinate random variable. Finally, Z^{x̄_1} = φ(Z^{h̄_1}) and the logit is f_1 = (1/n) ∑_{α=1}^n (nV_1)_α x̄_{1α} → f̊_1 := E Z^{nV_1} Z^{x̄_1} = E Z^{nV_0} Z^{x̄_1} − χ̊_0 E Z^{x̄_0} Z^{x̄_1}.

Second Backward Pass Everything proceeds just like in the 1-hidden-layer case^{34} except for the computation of
    dx_1 = (W̃ + ∆W_1^⊤) dh̄_1 = W̃ dh̄_1 − χ_0 x_0 (dh̄_0^⊤ dh̄_1 / n).
Like in the computation of h̄_1 in Eq. (23), dh̄_0^⊤ dh̄_1/n → E Z^{dh̄_0} Z^{dh̄_1}, and W̃ dh̄_1 is roughly Gaussian (and correlated with W̃ dh̄_0 in the natural way). But again, for this Gaussian intuition to be correct, it is crucial that we use W̃ here instead of W^⊤, or else dx̄_1 (and thus dh̄_1) is strongly correlated with W^⊤ (through x̄_0 = φ(W x_0) inside n∆V_1 = −χ_0 x̄_0^⊤).
In any case, we have
    Z^{dx_1} = Z^{W̃ dh̄_1} − c Z^{x_0},   where c = χ̊_0 E Z^{dh̄_0} Z^{dh̄_1},
which is a sum of the Gaussian Z^{W̃ dh̄_1} and a multiple of Z^{x_0}. Then the weights are updated according to Eq. (22).

tth Iteration For general t, we always have (true in normal SGD as well)
    ∆W_t = −(1/n) ∑_{s=0}^{t−1} χ_s dh̄_s x_s^⊤
so that in the forward pass
    h̄_t = W x_t + ∆W_t x_t = W x_t − ∑_{s=0}^{t−1} χ_s dh̄_s (x_s^⊤ x_t / n)    (24)
    Z^{h̄_t} := Z^{W x_t} − ∑_{s=0}^{t−1} χ̊_s Z^{dh̄_s} E Z^{x_s} Z^{x_t}.
Here Z^{W x_t} is Gaussian with covariance Cov(Z^{W x_t}, Z^{W x_s}) = E Z^{x_t} Z^{x_s} for any s. This means that Z^{h̄_t} and Z^{h̄_s} are correlated through Z^{W x_t}, Z^{W x_s} (but also through Z^{dh̄_r}, r ≤ min(t, s)). Likewise, in the backward pass,
    dx_t = (W̃ + ∆W_t^⊤) dh̄_t = W̃ dh̄_t − ∑_{s=0}^{t−1} χ_s x_s (dh̄_s^⊤ dh̄_t / n)
    Z^{dx_t} := Z^{W̃ dh̄_t} − ∑_{s=0}^{t−1} χ̊_s Z^{x_s} E Z^{dh̄_s} Z^{dh̄_t}.
Here, Z^{W̃ dh̄_t} is Gaussian with covariance Cov(Z^{W̃ dh̄_t}, Z^{W̃ dh̄_s}) = E Z^{dh̄_t} Z^{dh̄_s} for any s. Thus, Z^{dx_t} and Z^{dx_s} are correlated through Z^{W̃ dh̄_t}, Z^{W̃ dh̄_s} (but also through Z^{x_r}, r ≤ min(t, s)). Again, the Gaussianity of Z^{W x_t} and Z^{W̃ dh̄_t} depends crucially on the fact that we use W̃ instead of W^⊤ in backpropagation.
Other parts of the forward and backward propagations are similar to before. Our reasoning can be formalized via Tensor Programs to prove the following
^{34} dx̄_1 = nV_1^⊤, dh̄_1 = dx̄_1 ⊙ φ′(h̄_1), dh_1 = dx_1 ⊙ φ′(h_1)

Theorem 6.3. Consider a 2-hidden-layer MLP in µP with partially decoupled backpropagation as in Eq. (20) and any training routine with learning rate 1. Suppose φ′ is pseudo-Lipschitz.^{35} As n → ∞, for every input ξ, f_t(ξ) →^{a.s.} f̊_t(ξ), where f̊_t(ξ) is defined as follows:
(forward pass)
    f̊_t(ξ) := E Z^{nV_t} Z^{x̄_t(ξ)},   Z^{x̄_t(ξ)} := φ(Z^{h̄_t(ξ)}),   Z^{x_t(ξ)} := φ(Z^{h_t(ξ)}),   Z^{h_t(ξ)} := ξ Z^{U_t}
    Z^{h̄_t(ξ)} := Z^{W x_t(ξ)} − ∑_{s=0}^{t−1} χ̊_s Z^{dh̄_s} E Z^{x_s} Z^{x_t(ξ)}    (25)
    {Z^{W x_t(ξ)}}_{ξ,t} centered, jointly Gaussian with Cov(Z^{W x_t(ξ)}, Z^{W x_s(ζ)}) = E Z^{x_t(ξ)} Z^{x_s(ζ)}
(backward pass)
    χ̊_t := L′(f̊_t, y_t),   Z^{dx̄_t} := Z^{nV_t},   Z^{dh̄_t} := φ′(Z^{h̄_t}) Z^{dx̄_t},   Z^{dh_t} := φ′(Z^{h_t}) Z^{dx_t}
    Z^{dx_t} := Z^{W̃ dh̄_t} − ∑_{s=0}^{t−1} χ̊_s Z^{x_s} E Z^{dh̄_s} Z^{dh̄_t}    (26)
    {Z^{W̃ dh̄_t}}_t centered, jointly Gaussian with Cov(Z^{W̃ dh̄_t}, Z^{W̃ dh̄_s}) = E Z^{dh̄_t} Z^{dh̄_s}
(U, V updates)
    Z^{nV_{t+1}} := Z^{nV_t} − χ̊_t Z^{x̄_t},   Z^{U_{t+1}} := Z^{U_t} − χ̊_t ξ_t Z^{dh_t}
with Z^{U_0} and Z^{nV_0} being independent standard Gaussians as initial conditions, and by definition, {Z^{W x_t(ξ)}}_{ξ,t}, {Z^{W̃ dh̄_t}}_t, Z^{U_0}, and Z^{nV_0} are mutually independent sets of random variables. Here, if h_t appears without argument, it means h_t(ξ_t); likewise for h̄_t, x_t, x̄_t, dh_t, dh̄_t, dx_t, dx̄_t, f̊_t.

6.3 2-Hidden-Layer MLP: Normal SGD

Finally, we discuss normal SGD for the 2-hidden-layer MLP, i.e. in backprop we compute
    dx_t = W_t^⊤ dh̄_t = (W^⊤ + ∆W_t^⊤) dh̄_t.
The first forward and backward passes are essentially the same as in the last section. However, as mentioned there, in the second forward pass, W x_1 (a part of h̄_1 = W x_1 + ∆W_1 x_1) will no longer be approximately Gaussian because of the correlation between x_1 and W. Let's first get some intuition for why this is before stating the infinite-width limit formally.

Warmup: φ = id First, as a warmup, suppose φ = id. In this case, W x_1 will actually still be Gaussian, but its variance will be different from what is predicted in the previous section. To lighten notation, we write x = x_1 in this section. Then unwinding the definition of x, we have
    x = h + a W^⊤ z
where we abbreviated h = ξ_1 U_0, z = dh̄_0, a = −χ_0 ξ_0 ξ_1. Then W x has coordinates
    (W x)_α = (W h)_α + a (W W^⊤ z)_α.
As derived in the first forward pass in Section 6.2, (W h)_α is approximately Gaussian (particularly because W and U_0 are independent). This is true for (W W^⊤ z)_α as well here because we assumed φ = id, but not true generally. Indeed,
    (W W^⊤ z)_α = ∑_{β,γ} W_{αβ} W_{γβ} z_γ = z_α ∑_β (W_{αβ})² + ∑_β ∑_{γ≠α} W_{αβ} W_{γβ} z_γ.

We will soon see that the derivations of Section 6.2 correspond to ignoring the first term: In the second term, there are n summands of the form ∑_{γ≠α} W_{αβ} W_{γβ} z_γ that are approximately iid with variance ≈ ‖z‖²/n². Thus, the second term itself, by a Central Limit heuristic, should converge to N(0, lim_{n→∞} ‖z‖²/n). On the other hand, the first term z_α ∑_β (W_{αβ})² → z_α by the Law of Large Numbers. Tying it all together, (W x)_α is a linear combination of two Gaussian terms, (W h)_α and ∑_β ∑_{γ≠α} W_{αβ} W_{γβ} z_γ, as well as z_α (which is Gaussian in the case of φ = id, but not generally).
^{35} This roughly means that φ′ has a polynomially bounded weak derivative; see Definition E.3.

Note that, if we did (W W̃ z)_α instead of (W W^⊤ z)_α, as in the last section, then the same analysis would show that the first term is z_α ∑_β W_{αβ} W̃_{βα} → 0, while the second term converges in distribution to the same Gaussian. Thus, the effect of decoupling in Section 6.2 is to kill the copy of z in (W x)_α.
We can summarize our derivation here in terms of Z:
    For φ = id:   Z^{W x} := Z^{W h} + a Z^{W W^⊤ z} = Z^{W h} + a (Ẑ^{W W^⊤ z} + Z^z),    (27)
    where Ẑ^{W W^⊤ z} := N(0, E (Z^z)²).
Note the Central Limit heuristic in the derivation of Ẑ^{W W^⊤ z} also shows that Ẑ^{W W^⊤ z} is jointly Gaussian with Z^{W h}, with Cov(Ẑ^{W W^⊤ z}, Z^{W h}) = E Z^{W^⊤ z} Z^h. So, to put Eq. (27) in a form more suggestive of the general case, we will write
    Z^{W x} = Ẑ^{W x} + a Z^z,   where Ẑ^{W x} = Z^{W h} + a Ẑ^{W W^⊤ z} =^d N(0, E (Z^x)²).    (28)
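A quick numerical sanity check of this calculation (our own sketch, at an arbitrary finite width): going through W^⊤ leaves a copy of z inside W W^⊤ z, while an independent copy W̃ kills it.

    import numpy as np

    rng = np.random.default_rng(0)
    n = 4096
    W = rng.standard_normal((n, n)) / np.sqrt(n)     # W_ab ~ N(0, 1/n)
    Wtil = rng.standard_normal((n, n)) / np.sqrt(n)  # independent copy, as in Section 6.2
    z = rng.standard_normal(n)                       # stand-in for d h-bar_0

    coupled = W @ (W.T @ z)      # (W W^T z)_a  ~ z_a + N(0, |z|^2/n)
    decoupled = W @ (Wtil @ z)   # (W Wtil z)_a ~       N(0, |z|^2/n)

    print(np.corrcoef(coupled, z)[0, 1])    # ~ 0.71 = 1/sqrt(2): the copy of z survives
    print(np.corrcoef(decoupled, z)[0, 1])  # ~ 0: decoupling kills the copy of z
    print(coupled.var(), decoupled.var())   # ~ 2 and ~ 1 respectively (since |z|^2/n ~ 1)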

General φ Unwinding the definition of x, we have
    x = φ(h + a (W^⊤ z) ⊙ φ′(h_0)).    (29)
By Taylor-expanding φ, we can apply a similar (though more tedious) argument as above to derive
    Z^{W x} = Ẑ^{W x} + c Z^z    (30)
where c = a E φ′(Z^{h_1}) φ′(Z^{h_0}) and Ẑ^{W x} =^d N(0, E (Z^x)²). In the case of φ = id, c reduces to a as above, recovering Eq. (28). For general φ, we can immediately see that Z^{W x} is not Gaussian, because Z^z = Z^{dx̄_0} φ′(Z^{h̄_0}) is not. In the Tensor Programs framework formalized in Section 7, c Z^z is denoted Ż^{W x}.
Similarly, the coordinate distribution of dx_1 = W_1^⊤ dh̄_1 will also change in the backward pass.

General t For general t, we obtain dynamical equations in Z identical to those in Theorem 6.3 except that Eq. (25) and Eq. (26) need to be modified. We state the general result below.
Theorem 6.4. Consider a 2-hidden-layer MLP in µP and any training routine with learning rate 1. Suppose φ′ is pseudo-Lipschitz.^{36} As n → ∞, for every input ξ, f_t(ξ) →^{a.s.} f̊_t(ξ), where f̊_t(ξ) is defined the same way as in Theorem 6.3 except that Eq. (25) should be replaced with
    Z^{h̄_t(ξ)} := Ẑ^{W x_t(ξ)} + Ż^{W x_t(ξ)} − ∑_{s=0}^{t−1} χ̊_s Z^{dh̄_s} E Z^{x_s} Z^{x_t(ξ)}
    {Ẑ^{W x_t(ξ)}}_{ξ,t} centered, jointly Gaussian with Cov(Ẑ^{W x_t(ξ)}, Ẑ^{W x_s(ζ)}) = E Z^{x_t(ξ)} Z^{x_s(ζ)}
and Eq. (26) should be replaced with
    Z^{dx_t} := Ẑ^{W^⊤ dh̄_t} + Ż^{W^⊤ dh̄_t} − ∑_{s=0}^{t−1} χ̊_s Z^{x_s} E Z^{dh̄_s} Z^{dh̄_t}
    {Ẑ^{W^⊤ dh̄_t}}_t centered, jointly Gaussian with Cov(Ẑ^{W^⊤ dh̄_t}, Ẑ^{W^⊤ dh̄_s}) = E Z^{dh̄_t} Z^{dh̄_s}.
Like in Theorem 6.3, by definition, {Ẑ^{W x_t(ξ)}}_{ξ,t}, {Ẑ^{W^⊤ dh̄_t}}_t, Z^{U_0}, and Z^{nV_0} are mutually independent sets of random variables.
Here, Ż^{W x_t(ξ)} := ∑_{r=0}^{t−1} θ_r Z^{dh̄_r}, where θ_r is calculated like so: Z^{x_t(ξ)} by definition is constructed as
    Z^{x_t(ξ)} = Φ(Ẑ^{W^⊤ dh̄_0}, . . . , Ẑ^{W^⊤ dh̄_{t−1}}, Z^{U_0})
for some function^{37} Φ : R^{t+1} → R. Then
    θ_r := E ∂Φ(Ẑ^{W^⊤ dh̄_0}, . . . , Ẑ^{W^⊤ dh̄_{t−1}}, Z^{U_0}) / ∂Ẑ^{W^⊤ dh̄_r}.
^{36} This roughly means that φ′ has a polynomially bounded weak derivative; see Definition E.3.
^{37} that may depend on various scalars such as χ̊_s, E Z^{x_s} Z^{x_{s′}(ξ)}, and E Z^{dh̄_s} Z^{dh̄_{s′}}
Likewise, Ż^{W^⊤ dh̄_t} := ∑_{r=0}^{t−1} θ_r Z^{x_r}, where θ_r is calculated as follows: Z^{dh̄_t} by definition is constructed as
    Z^{dh̄_t} = Ψ(Ẑ^{W x_0}, . . . , Ẑ^{W x_{t−1}}, Z^{nV_0})
for some function^{37} Ψ : R^{t+1} → R. Then
    θ_r := E ∂Ψ(Ẑ^{W x_0}, . . . , Ẑ^{W x_{t−1}}, Z^{nV_0}) / ∂Ẑ^{W x_r}.
For example, generalizing Eq. (29), for any input ξ, we have
    Z^{x_1(ξ)} = Φ(Z^{W^⊤ dh̄_0}, Z^{U_0}),   where Φ(z, u) := φ(ξ u − χ̊_0 ξ_0 ξ φ′(ξ_0 u) z).
Then θ_0 = E ∂_z Φ(Z^{W^⊤ dh̄_0}, Z^{U_0}) = −χ̊_0 ξ_0 ξ E φ′(Z^{h_1(ξ)}) φ′(Z^{h_0}), which specializes to c in Eq. (30). Altogether, Ż^{W x_1(ξ)} = −χ̊_0 ξ_0 ξ Z^{dh̄_0} E φ′(Z^{h_1(ξ)}) φ′(Z^{h_0}).
Note that Ẑ^{W x_t} here does not equal Z^{W x_t} in Eq. (25) in general, because the covariance Cov(Ẑ^{W x_t}, Ẑ^{W x_s}) = E Z^{x_t} Z^{x_s} is affected by the presence of Ż^{W x_r} for all r ≤ max(s, t).

6.4 MLP of Arbitrary Depth


The µP limit of deeper MLPs can be derived by similar logic; see Appendices G.3 to G.5 for a rigorous treatment within the Tensor Programs framework, which also covers all stable abc-parametrizations.

What happens in other feature learning parametrizations If we are in the feature learning regime, then any W^l that is not maximally updated (Definition 5.2) will be effectively fixed at its initialized value in the infinite-width limit (i.e. no learning occurs in that layer).

6.5 Summary of Main Intuitions for Deriving the µP Limit


Law of Large Numbers Any vector z has roughly iid coordinates given by Z^z. For any two vectors z, z′ ∈ R^n, (1/n) ∑_{α=1}^n z_α z′_α → E Z^z Z^{z′}.
    1. This is all we needed to derive the 1-hidden-layer dynamics of Section 6.1, since all the matrices there are size-n vectors.
    2. In Sections 6.2 and 6.3, this is also used in calculating the limit of ∆W_t x_t.
Central Limit If the underlying computation graph never involves the transpose W^⊤ of an n × n Gaussian matrix W in a matrix multiplication, then W z is roughly iid Gaussian with coordinate Z^{W z} =^d N(0, E (Z^z)²) (if W_{αβ} ∼ N(0, 1/n)).
    1. This, along with the last intuition, is all we used to derive the 2-hidden-layer decoupled dynamics of Section 6.2, where W is the middle layer weight matrix.
(W, W^⊤) Correlation If W^⊤ is involved, then W z has coordinates distributed like the random variable Ẑ^{W z} + Ż^{W z}, where Ẑ^{W z} is the Gaussian obtained by pretending W is independent from W^⊤, and Ż^{W z} results from the correlation between W and W^⊤. Ż^{W z} is purely a linear combination of Z^{z′} for previously defined vectors z′ such that z depends on W^⊤ z′.
    1. All three intuitions above are needed to derive the 2-hidden-layer dynamics of normal SGD (Section 6.3), where W^⊤ is used in backpropagation.
    2. The calculation of Ż^{W x} is quite intricate, which is why we first discussed decoupled SGD in Section 6.2, which doesn't need the Ż^{W x} calculation, before discussing normal SGD in Section 6.3.

7 Tensor Programs Framework


While the previous section demonstrates the intuition of how to derive the µP limit, it also lays bare 1) the increasing complexity of a manual derivation as training goes on, as well as 2) the mounting uncertainty about whether the intuition still holds after many steps of SGD. This is a perfect call for the Tensor Programs framework, which automates (and makes rigorous) the limit derivation for any “computation graph” — including the computation graph underlying SGD. Here we review this framework (developed in Yang [43, 44, 45, 46]) in the context of the µP limit. Fig. 3 graphically overviews the content of this section.

Figure 3: Graphical overview of the Tensor Programs framework (Setup; the MatMul, Nonlin, and Moment rules; the Master Theorem). For the Master Theorem, we illustrate Theorem 7.4(2) since Theorem 7.4(1) is a corollary of Theorem 7.4(2) for a larger program.
As seen abundantly in Section 6, the computation underlying SGD can be expressed purely via three instructions: matrix multiplication (by a Gaussian matrix, e.g. W_0 x_0), coordinatewise nonlinearities (e.g. φ), and taking a coordinatewise average (e.g. (1/n) ∑_{α=1}^n (nV_1)_α x_{1α}). In deriving the µP SGD limit, we focused mostly on keeping track of R^n vectors (e.g. x̄_t or dh_t), but importantly we also computed scalars f_t and χ_t by (what amounts to) taking a coordinatewise average (e.g. f_1 = (1/n) ∑_{α=1}^n (nV_1)_α x_{1α}). We implicitly compute scalars as well inside ∆W_t x_t. This motivates the following notion of a program, which can be thought of as a low-level symbolic representation of a computation graph common in deep learning (e.g. underlying Tensorflow and Pytorch).
Definition 7.1. A Tensor Program^{38} is a sequence of R^n-vectors and R-scalars inductively generated via one of the following ways from an initial set C of random scalars, an initial set V of random R^n vectors, and a set W of random R^{n×n} matrices (which will be sampled with iid Gaussian entries in Setup 7.2):
MatMul Given W ∈ R^{n×n} and x ∈ R^n, we can generate W x ∈ R^n or W^⊤ x ∈ R^n
^{38} What we refer to as a Tensor Program is the same as NETSOR⊤+ in Yang [46]; we will not talk about other languages (like NETSOR⊤) so this should not cause any confusion.
Nonlin Given φ : R^k × R^l → R, previous scalars θ_1, . . . , θ_l ∈ R and vectors x^1, . . . , x^k ∈ R^n, we can generate a new vector
    φ(x^1, . . . , x^k; θ_1, . . . , θ_l) ∈ R^n
where φ(−; θ_1, . . . , θ_l) applies coordinatewise to each “α-slice” (x^1_α, . . . , x^k_α).
Moment Given the same setup as above, we can also generate a new scalar
    (1/n) ∑_{α=1}^n φ(x^1_α, . . . , x^k_α; θ_1, . . . , θ_l) ∈ R.
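To make Definition 7.1 concrete, here is a minimal finite-n "interpreter" for such programs (a sketch of our own, not software from the paper): a program is a list of MatMul, Nonlin, and Moment instructions acting on named vectors, scalars, and matrices, evaluated with concrete NumPy data.

    import numpy as np
    from dataclasses import dataclass, field
    from typing import Callable, List

    @dataclass
    class MatMul:            # MatMul rule: generate W @ x, or W.T @ x if transpose
        out: str
        W: str
        x: str
        transpose: bool = False

    @dataclass
    class Nonlin:            # Nonlin rule: out_a = phi(x1_a, ..., xk_a; theta_1, ..., theta_l)
        out: str
        phi: Callable
        xs: List[str]
        thetas: List[str] = field(default_factory=list)

    @dataclass
    class Moment:            # Moment rule: out = (1/n) * sum_a phi(x1_a, ..., xk_a; thetas)
        out: str
        phi: Callable
        xs: List[str]
        thetas: List[str] = field(default_factory=list)

    def run(program, vecs, scalars, mats):
        """Execute a Tensor Program at finite width n on concrete NumPy arrays."""
        for ins in program:
            if isinstance(ins, MatMul):
                M = mats[ins.W].T if ins.transpose else mats[ins.W]
                vecs[ins.out] = M @ vecs[ins.x]
            else:
                args = [vecs[v] for v in ins.xs] + [scalars[t] for t in ins.thetas]
                val = ins.phi(*args)         # phi is vectorized, so it acts coordinatewise
                if isinstance(ins, Nonlin):
                    vecs[ins.out] = val
                else:                        # Moment
                    scalars[ins.out] = float(np.mean(val))
        return vecs, scalars

    # The first forward pass of the 1-hidden-layer MLP of Section 6.1, as a program:
    n = 100_000
    rng = np.random.default_rng(0)
    vecs = {"U0": rng.standard_normal(n), "nV0": rng.standard_normal(n)}   # initial set V
    scalars = {"xi0": 0.5}                                                 # initial set C
    program = [
        Nonlin("h0", lambda U0, xi0: xi0 * U0, ["U0"], ["xi0"]),
        Nonlin("x0", np.tanh, ["h0"]),
        Moment("f0", lambda x0, nV0: nV0 * x0, ["x0", "nV0"]),
    ]
    run(program, vecs, scalars, mats={})
    print(scalars["f0"])   # -> 0 as n grows, matching E Z^{nV0} phi(xi0 Z^{U0}) = 0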

Explanation of Definition 7.1 The vectors mentioned in Definition 7.1 are exemplified by h_t, x_t, dh_t, dx_t in Section 6. The scalars mentioned are exemplified by f_t and χ_t as well as e.g. x_s^⊤ x_t/n inside the calculation of h̄_t (Eq. (24)). The θ_i's in the Nonlin and Moment rules may appear cryptic at first. These scalars are not needed in the first forward and backward passes. But in the second forward pass, for example for the 1-hidden-layer MLP (Section 6.1), x_1 = φ(h_1) = φ(ξ_1 U_0 − χ_0 ξ_1 ξ_0 nV_0 ⊙ φ′(h_0)) depends on the scalars χ_0, ξ_0, ξ_1, and can be written in the form of Nonlin as φ̄(U_0, nV_0, h_0; χ_0) for some φ̄ appropriately defined.
The initial set of scalars C is the training sequence {ξ_t, y_t}_t for all three examples of Section 6. In our 2-hidden-layer MLP examples, the initial set of matrices W is {W} (Section 6.3) or {W, W̃} (Section 6.2), i.e. the random R^{n×n} Gaussian matrices. On the other hand, in the 1-hidden-layer MLP example (Section 6.1), W is empty. The initial set of vectors V in all three examples is V = {U_0, nV_0}.^{39,40} Notice how the vectors of these V are sampled with iid standard Gaussian coordinates. We formalize a more general setup for arbitrary Tensor Programs:
Setup 7.2. 1) For each initial W ∈ W, we sample iid W_{αβ} ∼ N(0, σ²_W/n) for some variance σ²_W associated to W, independent of other W′ ∈ W; 2) for some multivariate Gaussian Z^V = {Z^h : h ∈ V} ∈ R^V, we sample the initial set of vectors V like {h_α : h ∈ V} ∼ Z^V iid for each α ∈ [n]; 3) for each initial scalar θ ∈ C, we require θ →^{a.s.} θ̊ for some deterministic θ̊ ∈ R.
In all of our examples, we took σ²_W = 1 for simplicity, but Setup 7.2 allows for other initializations (e.g. a typical initialization for relu networks is σ²_W = 2); additionally, Z^h, h ∈ V, are all standard Gaussians, independent from one another, since U_0, nV_0 are sampled this way; and our initial scalars {ξ_t, y_t}_t are fixed with n, so they are their own limits.^{41}

What Does a Tensor Program Vector Look Like? Recall that we represented the coordinate
distribution of each vector h with a random variable Z h in Section 6 and kept track of how different
Zs are correlated with each other. We also calculated scalar limits like ft → f˚t , χt → χ̊t . These
calculations led to a set of formulas for the µP limit (e.g. Theorems 6.1, 6.3 and 6.4). We can also
construct such Z h and θ̊ for vectors h and scalars θ in any Tensor Program. They intuitively capture
the coordinate distribution of vector h and the deterministic limit of θ. The following definition
formally defines Z h and θ̊, but the connection between Z h (resp. θ̊) and the coordinates of h (resp.
θ) is not made rigorously until Theorem 7.4 later. The ZMatMul rule below perhaps asks for some
discussion, and we shall do so after the definition.
Definition 7.3 (Z h and θ̊). Given a Tensor Program, we recursively define Z h for each vector h and
θ̊ for each scalar θ as follows.
ZInit If h ∈ V, then Z^h is defined as in Setup 7.2. We also set Ẑ^h := Z^h and Ż^h := 0.
ZNonlin+ Given φ : R^k × R^l → R, previous scalars θ_1, . . . , θ_l ∈ R and vectors x^1, . . . , x^k ∈ R^n, we have
    Z^{φ(x^1,...,x^k;θ_1,...,θ_l)} := φ(Z^{x^1}, . . . , Z^{x^k}; θ̊_1, . . . , θ̊_l).
^{39} Here we write nV_0 instead of V_0 because we want all vectors to have Θ(1) coordinates; see Setup 7.2.
^{40} In Section 6 we assumed the input dimension is 1. In general, each column of U_0 would be a separate initial vector. Likewise, if the output dimension is greater than 1, then each row of V_0 would be a separate initial vector.
^{41} Since {ξ_t, y_t}_t are fixed with n, we can WLOG absorb them into any nonlinearities in Nonlin that they are involved in, and set C = ∅. But, in the kernel regime or a nonmaximal feature learning parametrization, we usually have initial scalars, such as n^{−2a_{L+1}−c}, that tend to 0 with n; see Appendix G.4.
ZMoment Given the same setup as above and scalar θ = (1/n) ∑_{α=1}^n φ(x^1_α, . . . , x^k_α; θ_1, . . . , θ_l), then
    θ̊ := E φ(Z^{x^1}, . . . , Z^{x^k}; θ̊_1, . . . , θ̊_l).
Here θ̊_1, . . . , θ̊_l are deterministic, so the expectation is taken over Z^{x^1}, . . . , Z^{x^k}.
ZMatMul Z^{W x} := Ẑ^{W x} + Ż^{W x} for every matrix W (with N(0, σ²_W/n) entries) and vector x, where
ZHat Ẑ^{W x} is a Gaussian variable with zero mean. Let V_W denote the set of all vectors in the program of the form W y for some y. Then {Ẑ^{W y} : W y ∈ V_W} is defined to be jointly Gaussian with zero mean and covariance
    Cov(Ẑ^{W x}, Ẑ^{W y}) := σ²_W E Z^x Z^y,   for any W x, W y ∈ V_W.
Furthermore, {Ẑ^{W y} : W y ∈ V_W} is mutually independent from {Ẑ^v : v ∈ V ∪ ⋃_{W̄≠W} V_{W̄}}, where W̄ ranges over W ∪ {A^⊤ : A ∈ W}.
ZDot We can always unwind Z^x = Φ(· · ·), for some arguments (· · ·) = ({Ẑ^{W^⊤ y^i}}_{i=1}^k, {Ẑ^{z^i}}_{i=1}^j; {θ̊_i}_{i=1}^l), z^i ∉ V_{W^⊤} (where V_{W^⊤} is defined in ZHat), and deterministic function Φ : R^{k+j+l} → R. Define ∂Z^x/∂Ẑ^{W^⊤ y^i} := ∂_i Φ(· · ·). Then we set
    Ż^{W x} := σ²_W ∑_{i=1}^k Z^{y^i} E ∂Z^x/∂Ẑ^{W^⊤ y^i}.    (31)
There is some nuance in this definition, so see Remark E.1 and E.2.

Explanation of Definition 7.3 Nonlin and Moment should appear only natural. However, we pause to digest the meaning of ZMatMul by relating back to our examples in Section 6. First notice that Ż^{W x} = 0 if W^⊤ is not used in the program, so that Z^{W x} = Ẑ^{W x}. This is the case in Section 6.2, where W̃ is used in backprop instead of W^⊤. There (in Eq. (25)), Z^{W x_t} is Gaussian with covariance Cov(Z^{W x_t}, Z^{W x_s}) = E Z^{x_t} Z^{x_s} for any s, consistent with ZHat. In Section 6.3, however, Ż^{W x} ≠ 0 in general. The ZDot rule is a direct generalization of the calculation of Ż in Theorem 6.4.

Ż^{W x_t} and Ż^{W^⊤ dh̄_t} of Section 6.3 for general t will all be nonzero but have no easy expression. Here we seek to convey the complexity of computing them; this is optional reading for the first-time reader. To calculate Ż^{W x_t} (Ż^{W^⊤ dh̄_t} is similar), we need to express Z^{x_t} as a function of purely Ẑ^{W^⊤ dh̄_s}, s < t, and Z^{U_0} = Ẑ^{U_0}. Then we symbolically differentiate Z^{x_t} by Ẑ^{W^⊤ dh̄_s} and take expectation to obtain the coefficient of Z^{dh̄_s} in Ż^{W x_t}. For t = 1 as in the examples in Section 6.3, this task is easy because Ẑ^{W^⊤ dh̄_0} = Ẑ^{dx_0} = Z^{dx_0}. But in general, the calculation can balloon quickly. Indeed, note Z^{x_t} = φ(Z^{h_t}) and
    Z^{h_t} = ξ_t Z^{U_t} = ξ_t Z^{U_0} − ξ_t ∑_{s=0}^{t−1} χ̊_s ξ_s Z^{dh_s} = ξ_t Z^{U_0} − ξ_t ∑_{s=0}^{t−1} χ̊_s ξ_s φ′(Z^{h_s}) Z^{dx_s}.
However, each Z^{dx_s} is a linear combination of Z^{W^⊤ dh̄_s} = Ẑ^{W^⊤ dh̄_s} + Ż^{W^⊤ dh̄_s} and Z^{x_r}, r < s (coming from ∆W_s^⊤ dh̄_s). Each of Ż^{W^⊤ dh̄_s} and Z^{x_r} then needs to be recursively expanded in terms of Ẑ before we can calculate the symbolic partial derivative ∂Z^{x_t}/∂Ẑ^{W^⊤ dh̄_s}.

Master Theorem Finally, we relate the symbolic nature of a Tensor Program given in Definition 7.3 to the analytic limit of its computation, in the following Master Theorem. Pseudo-Lipschitz functions are, roughly speaking, functions whose (weak) derivatives are polynomially bounded. We state the theorem assuming mild regularity conditions (Assumption E.4) that roughly say most nonlinearities in the program should be pseudo-Lipschitz.
Theorem 7.4 (Tensor Program Master Theorem, c.f. Theorem E.15 of [46]). Fix a Tensor Program initialized according to Setup 7.2. Adopt Assumption E.4. Then
Algorithm 1 Compute the infinite-width limit of an NN in any abc-parametrization and any task
1: Write the computation graph underlying training and inference in a Tensor Program (akin to
writing low level PyTorch or Tensorflow code).
2: Calculate Z h for each vector h and θ̊ for each scalar θ in the program, according to Definition 7.3.
3: The logits ft (ξ) of the neural network at any time t should be written as a collection of scalars,
so f˚t (ξ) is calculated in the previous step. For t being inference time, f˚t (ξ) is the output of the
infinite-width network after training.

1. For any fixed k and any pseudo-Lipschitz ψ : R^k → R, as n → ∞,
    (1/n) ∑_{α=1}^n ψ(h^1_α, . . . , h^k_α) →^{a.s.} E ψ(Z^{h^1}, . . . , Z^{h^k}),    (32)
for any vectors h^1, . . . , h^k in the program, where the Z^{h^i} are as defined in Definition 7.3.
2. Any scalar θ in the program tends to θ̊ almost surely, where θ̊ is as defined in Definition 7.3.

Intuitively, Theorem 7.4(1) says that each “coordinate slice” (h^1_α, . . . , h^k_α) can be thought of as an iid copy of (Z^{h^1}, . . . , Z^{h^k}).^{42} This intuition is consistent with our heuristic derivation in Section 6, and Theorem 7.4 underlies the proof of Theorems 6.1, 6.3 and 6.4. Theorem 7.4(2) allows us to directly obtain the function learned at the end of training: For example, for a 1-hidden-layer MLP, it shows that the network's output on any input ξ at time t converges to f̊_t(ξ) given in Theorem 6.1.
Algorithm 1 summarizes how to compute the infinite-width limit of any network in any abc-
parametrization and for any task, using the Tensor Programs framework laid out in this section.
It generalizes the manual derivations of Section 6. We carry out Algorithm 1 for MLPs in all of our
experiments.

Architectural and algorithmic universality Given that Tensor Programs can express the first
forward and backward computation of practically any architecture [43, 45], it should perhaps come
as no surprise that they can also express practically any training and inference procedure — or just
any computation — involving any such architecture. This includes both feature learning and kernel
limits. We leverage this flexibility to derive and compute the µP and kernel limits for metalearning
and Word2Vec; see Section 9.

Extensions We focused on programs whose vectors all have the same dimension n here. But it’s
easy to generalize to the case where vectors have different dimensions, which corresponds to e.g.
when a network’s widths are non-uniform. See [46].

8 Computational Considerations
While the TP framework is very general, computing the feature learning limits analytically is
inherently computationally intensive aside from special cases like the linear 1-hidden-layer MLP
(Corollary 6.2). Here we explain why, so as to motivate our experimental choices below.

No closed-form formula for evaluating the expectations (e.g. in Eq. (32)) involving general nonlinearities except in special cases For example, for a 1-hidden-layer MLP (Section 6.1), after 1 step of SGD, the logit is of the form E (Z_1 + b φ(Z_2)) φ(Z_3 + c Z_1 φ′(Z_2)), where the Z_i denote different (correlated) Gaussians (Eq. (17)). While one can still evaluate this via Monte Carlo, the error will compound quickly with training time. On the other hand, because of the nesting of φ′ inside φ, there is no closed-form formula for this expectation in general.
Notable Exception: If the nonlinearity φ is polynomial, then the expectation is a polynomial moment
of a multivariate Gaussian and can be evaluated analytically, e.g. using Isserlis’ theorem from the
covariance matrix.
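For reference, here is a short implementation of the Isserlis (Wick) recursion for such Gaussian polynomial moments (our own sketch):

    def isserlis(cov, idx):
        """E[X_{i_1} ... X_{i_m}] for a centered jointly Gaussian vector X with covariance
        matrix cov, by Isserlis' theorem: pair the first index with each other index and recurse."""
        idx = list(idx)
        if not idx:
            return 1.0
        if len(idx) % 2 == 1:
            return 0.0                   # odd moments of a centered Gaussian vanish
        first, rest = idx[0], idx[1:]
        return sum(cov[first][j] * isserlis(cov, rest[:k] + rest[k + 1:])
                   for k, j in enumerate(rest))

    # e.g. E[Z_1^2 Z_2^2] = 1 + 2*rho^2 for correlation rho:
    print(isserlis([[1.0, 0.5], [0.5, 1.0]], [0, 0, 1, 1]))   # 1.5

Expanding a degree-Ω(2^t) polynomial integrand into monomials and applying this recursion to each is exactly the super-exponential cost described in the next paragraph.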
^{42} This implies an explicit convergence in distribution (see [46]), but this convergence in distribution is strictly weaker than the formulation in Theorem 7.4, which is in general much more useful.
Even with nonlinear polynomial φ, there is an exponential computational bottleneck As training time t increases, due to the nesting of φ and φ′ in the preactivations, the integrand of the expectation, e.g. E Z^{x̄_t} Z^{nV_t}, will turn out to be a polynomial in Ω(1) Gaussian variables with degree Ω(2^t). The covariance matrix of the Gaussian variables will in general be nontrivial, so evaluating the expectation, e.g. using Isserlis' theorem, requires super-exponential time. This is because we would need to expand the polynomial integrand into monomials, and there would be Ω(2^t) monomials, each of which requires Ω(2^t) time to evaluate using Isserlis' theorem.

n × n Gaussian matrices Both points above apply to 1-hidden-layer MLPs. An additional difficulty with deeper networks is caused by the n × n initial Gaussian matrices W_0^l, 2 ≤ l ≤ L, in the middle of the network. 1) In general, due to the nonlinearities, x_t^{l−1} would be linearly independent from x_s^{l−1} for all s < t. Therefore, in calculating W_t^l x_t^{l−1} = W_0^l x_t^{l−1} + ∆W_t^l x_t^{l−1}, we create a new Gaussian variable Ẑ^{W_0^l x_t^{l−1}} linearly independent from all previous Ẑ^{W_0^l x_s^{l−1}}, s < t. This then requires us to compute and store the covariances between them. Thus, t steps of SGD cost Ω(t²) space and time (not mentioning that the computation of each covariance entry can require exponential time, as discussed above). 2) In addition, due to the interaction between W_t^l in the forward pass and W_t^{l⊤} in the backward pass, there is a nonzero Ż, as demonstrated in Eq. (30). This Ż is generally a linear combination of Ω(t) terms, and the coefficients of this combination require evaluation of some expectations that typically run into the exponential bottleneck discussed above.

Summary From easiest to hardest in terms of µP limit’s computational cost, we have 1) 1-hidden-
layer linear networks; 2) L-hidden-layer linear MLP, L ≥ 2; 3) nonlinear MLP with polynomial
activations; 4) nonlinear MLP with nonpolynomial activations. Nevertheless, 1-hidden-layer linear
networks are more than sufficient to demonstrate feature learning in Word2Vec and few-shot learning
with MAML, as we show below.

9 Experiments
In light of the computational difficulties discussed above, we divide our experiments into two
groups: 1) Verifying our theory; 2) Scaling up to realistic datasets to demonstrate feature learning.
The experiments in group 1 focus on stress-testing our theory in many scenarios to show that it
describes empirical phenomena accurately. They will run into the discussed computational difficulties
(Section 8), so we cannot train the infinite-width µP networks for very long, but nevertheless long
enough to verify the theory. Those in group 2 focus on real datasets (metalearning and Word2Vec)
where feature learning is critical, and demonstrate that the GP and NTK limits are inadequate for
those tasks. Necessarily, we adopt simpler neural architectures for this purpose so we can scale up.

9.1 Verifying the Theory


In Fig. 4, we analytically computed the µP limits derived in Section 6 for quadratic and linear
activations, and verified them against finite width networks.

9.2 Few-Shot Learning on Omniglot via First Order MAML


In few-shot learning, the model is given only a small number of labeled examples before being asked to make predictions on unseen data. Therefore, this tests whether a model contains a good prior that can adapt quickly to the small amount of data at hand.

MAML In Model Agnostic Meta-Learning (MAML), the model performs few-shot learning by
one or more SGD steps on the given training data; this is called adaptation. In a pretraining (also
called meta-training) phase, MAML learns a good initialization of the model parameters for this
adaptation. The training objective is to minimize the loss on a random task’s test set after the model
has adapted to its training set. More precisely, the basic First Order MAML at training time goes as follows: with f_θ denoting the model with parameters θ, and with step sizes ε, η, we do

1. At each time point, sample a few-shot task T
2. From T, sample a training set D
3. Adapt θ′ ← θ − ε ∇_θ L_D(f_θ), where L_D(f_θ) is the loss of f_θ over D

Figure 4: Empirical Simulation Agrees with Theory. We analytically compute the infinite-width
µP limit for the three kinds of networks (depth 1, depth 2 decoupled, depth 2) described in Section 6,
with either quadratic φ(x) = x2 or linear φ(x) = x activation. The training set is random ξt ∈
{±1}, yt ∈ {±1}, so that the deviation of finite width from infinite width losses are accentuated. We
compare against finite width µP networks with width 1024 or 4096. For each width, we randomly
initialize with 100 different seeds and aggregate the loss curves. The mean across these seeds is
plotted as solid curves, and the standard deviation represented by the shade. As discussed in Section 8,
nonlinear activation functions and higher depth face computational difficulties exponential with
training time. Thus here we only train for a few steps. We observe that the quadratic network
converges slower to the limit with width. This is expected since the tail of Z xt is fatter for a quadratic
activation than a linear activation.

4. Sample a test set D′ from T
5. Update θ ← θ − η ∇_{θ′} L_{D′}(f_{θ′}), where L_{D′}(f_{θ′}) is the loss of f_{θ′} over D′
6. Repeat

In practice, we batch the tasks, just like batches in SGD, so that we accumulate all the gradients from
Step 5 and update θ only at the end of the batch.
During meta-test time, we are tested on random unseen few-shot tasks, where each task T provides a training set D and a test set D′ as during meta-training. We adapt to D as in Step 3 above (or more generally we can take multiple gradient steps to adapt better) to obtain adapted parameters θ′. Finally, we calculate the accuracy of θ′ on the test set D′. We average this accuracy over many tasks T, which we report as the meta-test accuracy.
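In code, the pretraining loop (Steps 1–5, with task batching) looks roughly as follows. This is a generic sketch of our own: theta is a parameter vector, and sample_task and grad_loss are hypothetical stand-ins for the task sampler and the gradient of L_D(f_θ).

    import numpy as np

    def first_order_maml(theta, sample_task, grad_loss, eps, eta, n_iters, task_batch=32):
        """First Order MAML pretraining loop (Steps 1-5 above), batching tasks as described.
        sample_task() -> (D_train, D_test); grad_loss(theta, D) -> gradient of L_D(f_theta)."""
        for _ in range(n_iters):
            meta_grad = np.zeros_like(theta)
            for _ in range(task_batch):
                D_train, D_test = sample_task()                         # Steps 1, 2, 4
                theta_adapt = theta - eps * grad_loss(theta, D_train)   # Step 3: adapt
                # Step 5, first order: gradient taken at the adapted parameters,
                # treating theta_adapt as a constant w.r.t. theta (no Hessian term)
                meta_grad += grad_loss(theta_adapt, D_test)
            theta = theta - eta * meta_grad / task_batch                # update at end of batch
        return theta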

First Order vs Second Order MAML Notice in Step 5, we take the gradient of LD0 (fθ0 ) with
respect to the adapted parameters θ0 . In Second Order MAML, we would instead take the gradient
against the unadapted parameters θ, which would involve the Hessian ∇θ ∇θ LD (fθ ). Second Order
MAML generally achieves performance slightly better than First Order MAML, but at the cost of
significantly slower updates [32]. In order to scale up, we will focus on First Order MAML, hereafter
referred to as just MAML.

Few-Shot Learning Terminologies An N -way classification task asks the model to predict a class
from N possibilities. A K-shot classification task provides K input/output pairs per class, for a total
of N K training points for N -way classification.

Omniglot Omniglot is a standard few-shot learning benchmark. It consists of 20 instances of 1623


characters from 50 different alphabets, each handwritten by a different person. We test our models on
1-shot 5-way classification: We draw 5 random characters, along with 1 training instance and 1 test
instance for each character. After the model adapts to the training instances, it’s asked to predict the
character of the test instances (choosing among the 5 characters).

Table 2: Omniglot Meta-Test Accuracies after Pretraining with First Order MAML.
              φ = relu    |                 φ = identity (number = log2 width)
            GP      NTK   |    1      3      5      7      9      11     13     µP    GP/NTK
          47.60   47.82   |  55.34  64.54  66.21  66.31  66.43  66.36  66.41  66.42   41.68
          ±.02    ±.04    |  ±1.24  ±0.70  ±.15   ±.16   ±.23   ±.22   ±.18   ±.19    ±.09

Models Our main model is the µP limit of a 1-hidden-layer linear MLP. We compare against: 1)
finite width versions of the same;43 2) the NNGP and NTK limits of the same; 3) the NNGP and
NTK limits of a 1-hidden-layer relu MLP. Note 2) is equivalent to a 0-hidden-layer perceptron,
because the NNGP and NTK there are both linear kernels. In addition, the infinite-width SP limit of a
1-hidden-layer network is the same as the NNGP limit. Both 2) and 3) are equivalent to linear models
with fixed (not learned) features, so MAML’s adaptation only applies to the linear weights. On the
other hand, the µP limit and the finite µP networks will learn new representations of the data over
time that can quickly adapt to new tasks.44

Hyperparameters We use (task) batch size 32 and adaptation step size 0.4 (ε in Step 3). We also clip the gradient in Step 5 if the gradient has norm ≥ 0.5.^{45} For each model, we tune its weight initialization variances and the meta learning rate (η in Step 5). During meta-test time, we take 20 gradient steps during adaptation (i.e. we loop Step 3 above 20 times to obtain θ′). See Appendix D.1 for more details.
Findings Our results are summarized in the figure to the right (Omniglot 1-shot 5-way meta-test accuracy vs. log2(width), for finite-width linear µP, the linear µP limit, linear NTK/GP, relu GP, and relu NTK) and Table 2, where curves indicate means and shades indicate standard deviations. There are three key takeaways: 1) the feature learning µP limit significantly outperforms the kernel limits; 2) the benefit of feature learning dominates the benefit of having nonlinearities; 3) as width increases, the finite µP networks approach the performance of the µP limit from below.
9.3 Word2Vec

Word2Vec [28, 29] is an early example of large-scale pretraining and transfer learning in natural
language processing, where one learns a feature vector h(ξ) for every word ξ based on the principle
of distributional semantics. For simplicity, we focus on a specific scheme of Word2Vec using context as a bag-of-words (CBOW), negative example sampling, and a Sigmoid loss function.

Word2Vec Pretraining Consider training on a corpus with vocabulary V. At each time step, we sample a sentence from the corpus and choose a word i ∈ V. This word's context J ⊆ V is a window of words around it in the sentence, thought of as a bag of words. Let ξ^i ∈ R^{|V|} be the one-hot vector corresponding to word i. We pass the averaged context ξ^J := (1/|J|) ∑_{j∈J} ξ^j through a 1-hidden-layer MLP with hidden size n and identity activation:
    f(ξ^J) = V h(ξ^J) ∈ R^{|V|},   h(ξ^J) = U ξ^J ∈ R^n,    (33)
where V ∈ R^{|V|×n}, U ∈ R^{n×|V|} factor as V = n^{−a_v} v, U = n^{−a_u} u with initialization v_α ∼ N(0, n^{−2b_v}), u_α ∼ N(0, n^{−2b_u}), where {a_v, b_v, a_u, b_u} specify the parametrization of the network. After each forward pass, we sample a target word τ from V: with probability p, we take τ = i; with probability 1 − p, we sample τ uniformly from V \ {i}. Following [28, 29], we take p = 1/21 ≈ 4.76%.
^{43} Because we will tune initialization variances, our results also represent finite-width SP networks.
^{44} Note that the transfer learning comment in Section 3.1 does not apply directly to the few-shot setting here, because the readout weights of the network carry over from the pretraining phase. Nevertheless, we will see a large performance gap between the kernel limits (2, 3) and the µP limit.
^{45} One can write down gradient clipping easily in a Tensor Program, so its infinite-width limit can be computed straightforwardly via Theorem 7.4; see Appendix D.
Table 3: Test Accuracies on Word Analogy after Pretraining with CBOW Word2Vec.
number = log2 width
Dataset 6 8 10 µP GP/NTK
text8 33.35 41.58 42.56 43.31 0.0
fil9 44.39 54.24 55.69 56.45 0.0

The loss is then calculated with the Sigmoid function σ(·):
    L(f(ξ^J), ξ^τ) = { log(1 − σ(f(ξ^J)^⊤ ξ^τ))   if τ = i
                     { log σ(f(ξ^J)^⊤ ξ^τ)        if τ ≠ i    (34)

Then v and u are updated via SGD as usual (causing V and U to update). Conventionally, h(ξ) ∈ Rn
is taken as the Word2Vec embedding for a word ξ after many iterations of forward-backward updates.
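A single pretraining step in the plain-SGD special case looks as follows (a sketch of our own; we drop the n^{−a} width-scaling factors of Eq. (33), which only rescale the per-layer learning rates, and assume the context indices are distinct):

    import numpy as np

    def cbow_step(U, V, context_ids, center_id, lr, rng, p=1/21):
        """One CBOW step with a single sampled target, following Eqs. (33)-(34).
        U: (n, |V|) input embedding matrix, V: (|V|, n) output matrix."""
        J = list(context_ids)
        h = U[:, J].mean(axis=1)                      # h(xi^J) = U xi^J
        if rng.random() < p:                          # target tau = i with probability p
            tau, positive = center_id, True
        else:                                         # else a uniform word != i
            tau = int(rng.integers(V.shape[0]))
            while tau == center_id:
                tau = int(rng.integers(V.shape[0]))
            positive = False
        s = V[tau] @ h                                # f(xi^J)^T xi^tau
        sig = 1.0 / (1.0 + np.exp(-s))
        g = -sig if positive else 1.0 - sig           # dL/ds for the two cases of Eq. (34)
        grad_V_tau = g * h
        grad_h = g * V[tau]
        V[tau] -= lr * grad_V_tau                     # SGD on the tau-th output row
        U[:, J] -= lr * grad_h[:, None] / len(J)      # SGD on the context columns
        return U, V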

Word Analogy Evaluation We evaluate the word embeddings h(ξ) with the word analogy task. This task asks questions of the kind: what is to 'queen' as 'man' is to 'woman'? (The answer is 'king'.) The Word2Vec model answers this question by computing
    argmax_i h(ξ^i)^⊤ (h(ξ^{'man'}) − h(ξ^{'woman'}) + h(ξ^{'queen'}))    (35)
where i ranges over V \ {'man', 'woman', 'queen'}. If the argmax here is i = 'king', then the model answers correctly; otherwise, it's incorrect. The accuracy score is the percentage of such questions answered correctly.
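The evaluation itself is a single nearest-neighbor query over the embedding matrix. A sketch (ours), where h is an (|V|, n) array whose rows are the word embeddings and idx is a hypothetical word-to-index map:

    import numpy as np

    def analogy(h, a, b, c, exclude):
        """Answer 'what is to c as a is to b?' via Eq. (35):
        argmax_i h_i . (h_a - h_b + h_c), with i restricted outside `exclude`."""
        scores = h @ (h[a] - h[b] + h[c])
        scores[list(exclude)] = -np.inf
        return int(np.argmax(scores))

    # e.g. analogy(h, idx['man'], idx['woman'], idx['queen'],
    #              {idx['man'], idx['woman'], idx['queen']})  ->  idx['king'], ideally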

Dataset We train the models on text8,46 a clean dataset consisting of the first 100 million charac-
ters of a 2006 Wikipedia dump. The dataset has been featured in the original Word2Vec codebase and
the Hutter Prize. text8 contains the first 100 million characters of fil9, a larger dataset obtained by
filtering the first 1 billion characters in the aforementioned Wikipedia dump. We space-separate the
datasets into tokens and keep ones that appear no less than 5 times in the entire dataset for text8 and
10 times for fil9. The resulting datasets have 71,291 and 142,276 unique vocabulary items.

Models Our main model is the µP limit of Eq. (33). We compare against the baselines of 1) finite-
width versions of the same, and 2) the NTK and GP limits of Eq. (33). As shown in Corollary 3.9, the
features of the NTK limit are fixed at initialization as n → ∞ (and so are those of the GP limit, by
definition), so its answer to Eq. (35) is uniformly selected from the whole vocabulary.^{47} Its accuracy is thus 1/(|V| − 3). Since |V| is 71,291 for text8 and 142,276 for fil9, this number is practically 0.
We compute the µP limit according to Algorithm 1, but we relate more implementation details in
Appendix D.2.
Findings We show our results in Table 3 and the figure to the right (word analogy accuracy vs. training epoch for Word2Vec pretrained on text8 and fil9, at log2(width) ∈ {6, 8, 10}). As expected, the infinite-width and finite-width µP networks significantly outperform the NTK limit. In addition, we observe the finite-width µP networks converge to the performance of the µP limit from below, as width increases.
^{46} http://mattmahoney.net/dc/textdata.html
^{47} There is some nuance here because h(ξ)^⊤ h(ξ̄) is actually Θ(√n) instead of Θ(n) because ξ, ξ̄ are one-hot, but the conclusion is the same; see Appendix D.2.
10 Conclusion
In this paper, we presented a framework, based on the notion of abc-parametrizations and Tensor Pro-
grams technique, that unifies the Neural Tangent Kernel (NTK) and Mean Field limits of large width
neural networks (NNs). In the Dynamical Dichotomy theorem, we classified the abc-parametrizations
into feature learning and kernel regimes. We identified the lack of feature learning as a fatal weakness
of NTK as a model of real NNs. In fact, we showed the standard parametrization suffers from the
same problem. As a solution, we proposed the Maximal Update Parametrization (µP) and derived its
infinite-width limit, which admits feature learning. Through experiments on Word2Vec and few-shot
learning, we demonstrated that µP is a good model for feature learning behavior in neural networks.
More generally, this paper showcased the power of the Tensor Programs technique: Any computation expressible in a Tensor Program has an “infinite-width” limit we can derive. Because of the universality of Tensor Programs for expressing deep learning computation [43, 45], this technique systematically solves the mathematical problem of taking infinite-width limits, which has been dealt with haphazardly in prior literature. Its immense flexibility means that the theory of reinforcement learning, self-supervised learning, deep generative models, etc., with overparametrized neural networks in the feature learning regime is now ripe for the picking.

Acknowledgements
In alphabetical order, we thank Zeyuan Allen-Zhu, Francis Bach, Yasaman Bahri, Lenaic Chizat,
Jeremy Cohen, Yarin Gal, Quanquan Gu, Bobby He, Di He, Jiaoyang Huang, Arthur Jacot, Jaehoon
Lee, Jason Lee, Zhiyuan Li, Etai Littwin, Yiping Lu, Song Mei, Roman Novak, Vinay Rao, Michael
Santacroce, Sam Schoenholz, Lisa Schut, Jascha Sohl-Dickstein, Alessandro Sordoni, Denny Wu,
Huishuai Zhang, and Pengchuan Zhang for discussion and feedback.

References
[1] Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu
Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, et al. Tensorflow: A system for
large-scale machine learning. In 12th {USENIX} Symposium on Operating Systems Design and
Implementation ({OSDI} 16), pages 265–283, 2016.
[2] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A Convergence Theory for Deep Learning
via Over-Parameterization. arXiv:1811.03962 [cs, math, stat], November 2018. URL http:
//arxiv.org/abs/1811.03962.
[3] Dyego Araújo, Roberto I. Oliveira, and Daniel Yukimura. A mean-field limit for certain deep
neural networks. arXiv:1906.00193 [cond-mat, stat], June 2019. URL http://arxiv.org/
abs/1906.00193.
[4] Mohsen Bayati and Andrea Montanari. The dynamics of message passing on dense graphs,
with applications to compressed sensing. IEEE Transactions on Information Theory, 57(2):
764–785, February 2011. ISSN 0018-9448, 1557-9654. doi: 10.1109/TIT.2010.2094817. URL
http://arxiv.org/abs/1001.3448.
[5] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhari-
wal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agar-
wal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh,
Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler,
Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish,
Alec Radford, Ilya Sutskever, and Dario Amodei. Language Models are Few-Shot Learners.
arXiv:2005.14165 [cs], July 2020. URL http://arxiv.org/abs/2005.14165.
[6] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhari-
wal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agar-
wal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh,
Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler,
Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish,
Alec Radford, Ilya Sutskever, and Dario Amodei. Language Models are Few-Shot Learners.
arXiv:2005.14165 [cs], July 2020. URL http://arxiv.org/abs/2005.14165.

[7] Minmin Chen, Jeffrey Pennington, and Samuel Schoenholz. Dynamical Isometry and a Mean
Field Theory of RNNs: Gating Enables Signal Propagation in Recurrent Neural Networks.
In Proceedings of the 35th International Conference on Machine Learning, volume 80 of
Proceedings of Machine Learning Research, pages 873–882, Stockholmsmässan, Stockholm
Sweden, July 2018. PMLR. URL http://proceedings.mlr.press/v80/chen18i.html.
[8] Lenaic Chizat and Francis Bach. A Note on Lazy Training in Supervised Differentiable
Programming. page 19.
[9] Lenaic Chizat and Francis Bach. On the Global Convergence of Gradient Descent for Over-
parameterized Models using Optimal Transport. arXiv:1805.09545 [cs, math, stat], May 2018.
URL http://arxiv.org/abs/1805.09545.
[10] Lenaic Chizat and Francis Bach. Implicit Bias of Gradient Descent for Wide Two-layer Neural
Networks Trained with the Logistic Loss. arXiv:2002.04486 [cs, math, stat], June 2020. URL
http://arxiv.org/abs/2002.04486.
[11] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of
Deep Bidirectional Transformers for Language Understanding. arXiv:1810.04805 [cs], May
2019. URL http://arxiv.org/abs/1810.04805. arXiv: 1810.04805 version: 2.
[12] Cong Fang, Jason D. Lee, Pengkun Yang, and Tong Zhang. Modeling from Features: a Mean-
field Framework for Over-parameterized Deep Neural Networks. arXiv:2007.01452 [cs, math,
stat], July 2020. URL http://arxiv.org/abs/2007.01452. arXiv: 2007.01452.
[13] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-Agnostic Meta-Learning for Fast
Adaptation of Deep Networks. arXiv:1703.03400 [cs], July 2017. URL http://arxiv.org/
abs/1703.03400.
[14] Dar Gilboa and Guy Gur-Ari. Wider Networks Learn Better Features. September 2019. URL
https://arxiv.org/abs/1909.11572v1.
[15] Dar Gilboa, Bo Chang, Minmin Chen, Greg Yang, Samuel S. Schoenholz, Ed H. Chi, and
Jeffrey Pennington. Dynamical Isometry and a Mean Field Theory of LSTMs and GRUs.
arXiv:1901.08987 [cs, stat], January 2019. URL http://arxiv.org/abs/1901.08987.
[16] Eugene A. Golikov. Dynamically Stable Infinite-Width Limits of Neural Classifiers.
arXiv:2006.06574 [cs, stat], October 2020. URL http://arxiv.org/abs/2006.06574.
[17] Boris Hanin. Which Neural Net Architectures Give Rise To Exploding and Vanishing Gradients?
January 2018. URL https://arxiv.org/abs/1801.03744.
[18] Boris Hanin and David Rolnick. How to Start Training: The Effect of Initialization and
Architecture. arXiv:1803.01719 [cs, stat], March 2018. URL http://arxiv.org/abs/1803.
01719.
[19] Soufiane Hayou, Arnaud Doucet, and Judith Rousseau. On the Selection of Initialization and
Activation Function for Deep Neural Networks. arXiv:1805.08266 [cs, stat], May 2018. URL
http://arxiv.org/abs/1805.08266.
[20] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep Residual Learn-
ing for Image Recognition. pages 770–778, 2016. URL https://www.cv-
foundation.org/openaccess/content_cvpr_2016/html/He_Deep_Residual_
Learning_CVPR_2016_paper.html.
[21] Dan Hendrycks and Kevin Gimpel. Gaussian Error Linear Units (GELUs). arXiv:1606.08415
[cs], July 2020. URL http://arxiv.org/abs/1606.08415.
[22] Jiaoyang Huang and Horng-Tzer Yau. Dynamics of deep neural networks and neural tangent
hierarchy, 2019.
[23] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural Tangent Kernel: Convergence
and Generalization in Neural Networks. arXiv:1806.07572 [cs, math, stat], June 2018. URL
http://arxiv.org/abs/1806.07572.

[24] Aitor Lewkowycz, Yasaman Bahri, Ethan Dyer, Jascha Sohl-Dickstein, and Guy Gur-Ari. The
large learning rate phase of deep learning: the catapult mechanism. arXiv:2003.02218 [cs, stat],
March 2020. URL http://arxiv.org/abs/2003.02218.
[25] Zhuohan Li, Eric Wallace, Sheng Shen, Kevin Lin, Kurt Keutzer, Dan Klein, and Joseph E.
Gonzalez. Train Large, Then Compress: Rethinking Model Size for Efficient Training and
Inference of Transformers. arXiv:2002.11794 [cs], June 2020. URL http://arxiv.org/
abs/2002.11794.
[26] Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the landscape
of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):
E7665–E7671, August 2018. ISSN 0027-8424, 1091-6490. doi: 10.1073/pnas.1806579115.
URL https://www.pnas.org/content/115/33/E7665.
[27] Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Mean-field theory of two-layers
neural networks: dimension-free bounds and kernel limit. arXiv:1902.06015 [cond-mat, stat],
February 2019. URL http://arxiv.org/abs/1902.06015.
[28] Tomas Mikolov, Kai Chen, Greg Corrado, and Jeffrey Dean. Efficient Estimation of Word
Representations in Vector Space. arXiv:1301.3781 [cs], September 2013. URL http://arxiv.
org/abs/1301.3781.
[29] Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg Corrado, and Jeffrey Dean. Distributed
Representations of Words and Phrases and their Compositionality. arXiv:1310.4546 [cs, stat],
October 2013. URL http://arxiv.org/abs/1310.4546.
[30] Phan-Minh Nguyen. Mean Field Limit of the Learning Dynamics of Multilayer Neural Networks.
arXiv:1902.02880 [cond-mat, stat], February 2019. URL http://arxiv.org/abs/1902.
02880.
[31] Phan-Minh Nguyen and Huy Tuan Pham. A Rigorous Framework for the Mean Field Limit
of Multilayer Neural Networks. arXiv:2001.11443 [cond-mat, stat], January 2020. URL
http://arxiv.org/abs/2001.11443.
[32] Alex Nichol, Joshua Achiam, and John Schulman. On First-Order Meta-Learning Algorithms.
March 2018. URL https://arxiv.org/abs/1803.02999v3.
[33] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory
Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmai-
son, Andreas Kopf, Edward Yang, Zachary DeVito, Martin Raison, Alykhan Tejani,
Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. Py-
torch: An imperative style, high-performance deep learning library. In H. Wallach,
H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Ad-
vances in Neural Information Processing Systems 32, pages 8024–8035. Curran Associates,
Inc., 2019. URL http://papers.neurips.cc/paper/9015-pytorch-an-imperative-
style-high-performance-deep-learning-library.pdf.
[34] Jeffrey Pennington, Samuel Schoenholz, and Surya Ganguli. Resurrecting the sigmoid in deep
learning through dynamical isometry: theory and practice. In I. Guyon, U. V. Luxburg,
S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances
in Neural Information Processing Systems 30, pages 4788–4798. Curran Associates, Inc.,
2017. URL http://papers.nips.cc/paper/7064-resurrecting-the-sigmoid-in-
deep-learning-through-dynamical-isometry-theory-and-practice.pdf.
[35] George Philipp and Jaime G. Carbonell. The Nonlinearity Coefficient - Predicting Overfitting
in Deep Neural Networks. arXiv:1806.00179 [cs, stat], May 2018. URL http://arxiv.org/
abs/1806.00179.
[36] Ben Poole, Subhaneil Lahiri, Maithreyi Raghu, Jascha Sohl-Dickstein, and Surya Ganguli.
Exponential expressivity in deep neural networks through transient chaos. In Advances In
Neural Information Processing Systems, pages 3360–3368, 2016.
[37] Grant M. Rotskoff and Eric Vanden-Eijnden. Neural Networks as Interacting Particle Systems:
Asymptotic Convexity of the Loss Landscape and Universal Scaling of the Approximation Error.
arXiv:1805.00915 [cond-mat, stat], May 2018. URL http://arxiv.org/abs/1805.00915.

[38] Samuel S. Schoenholz, Justin Gilmer, Surya Ganguli, and Jascha Sohl-Dickstein. Deep Infor-
mation Propagation. 2017. URL https://openreview.net/pdf?id=H1W1UN9gg.
[39] Justin Sirignano and Konstantinos Spiliopoulos. Mean Field Analysis of Neural Networks.
arXiv:1805.01053 [math], May 2018. URL http://arxiv.org/abs/1805.01053.
[40] Justin Sirignano and Konstantinos Spiliopoulos. Mean Field Analysis of Deep Neural Networks.
arXiv:1903.04440 [math, stat], February 2020. URL http://arxiv.org/abs/1903.04440.
[41] Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, and Jaehoon Lee. On the infinite
width limit of neural networks with a standard parameterization. arXiv:2001.07301 [cs, stat],
January 2020. URL http://arxiv.org/abs/2001.07301.
[42] Blake Woodworth, Suriya Gunasekar, Jason D. Lee, Edward Moroshko, Pedro Savarese, Itay
Golan, Daniel Soudry, and Nathan Srebro. Kernel and Rich Regimes in Overparametrized
Models. arXiv:2002.09277 [cs, stat], July 2020. URL http://arxiv.org/abs/2002.09277.
[43] Greg Yang. Tensor Programs I: Wide Feedforward or Recurrent Neural Networks of Any Archi-
tecture are Gaussian Processes. arXiv:1910.12478 [cond-mat, physics:math-ph], December
2019. URL http://arxiv.org/abs/1910.12478.
[44] Greg Yang. Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process
Behavior, Gradient Independence, and Neural Tangent Kernel Derivation. arXiv:1902.04760
[cond-mat, physics:math-ph, stat], February 2019. URL http://arxiv.org/abs/1902.
04760.
[45] Greg Yang. Tensor Programs II: Neural Tangent Kernel for Any Architecture. arXiv:2006.14548
[cond-mat, stat], August 2020. URL http://arxiv.org/abs/2006.14548.
[46] Greg Yang. Tensor Programs III: Neural Matrix Laws. arXiv:2009.10685 [cs, math], September
2020. URL http://arxiv.org/abs/2009.10685.
[47] Greg Yang and Hadi Salman. A fine-grained spectral perspective on neural networks, 2019.
[48] Greg Yang and Sam S. Schoenholz. Deep mean field theory: Layerwise variance and width
variation as methods to control gradient explosion, 2018. URL https://openreview.net/
forum?id=rJGY8GbR-.
[49] Greg Yang and Samuel S. Schoenholz. Mean Field Residual Network: On the Edge of Chaos.
In Advances in neural information processing systems, 2017.
[50] Greg Yang, Jeffrey Pennington, Vinay Rao, Jascha Sohl-Dickstein, and Samuel S. Schoenholz.
A Mean Field Theory of Batch Normalization. arXiv:1902.08129 [cond-mat], February 2019.
URL http://arxiv.org/abs/1902.08129.

A A Short Origin Story of the Tensor Programs Paper Series
The Tensor Programs framework was initially proposed in [44] in February 2019, and was mainly
applied to extend the NNGP and NTK limits to arbitrary architectures (and to make rigorous the
signal propagation literature [7, 15, 17–19, 34–36, 38, 47–50]). While NNGP and NTK amount to
taking limits of neural networks at initialization, it was soon, in April 2019, realized that Tensor
Programs could 1) also trivially take limits of the entire training procedure of neural networks (which
is the main theoretical idea of this paper), and 2) calculate the feature learning limit. However, at
that point, it also became clear that [44] was not written accessibly, and its formulation of Tensor
Programs was cumbersome to use. A question had to be asked: Should the feature learning paper
be written immediately on such an unwieldy foundation, or should significant effort be devoted
to fixing this foundation first? Eventually, a decision was made in favor of the latter. The Tensor
Programs series was created as a way to re-organize and re-present the Tensor Programs machinery in
a user-friendly way to the machine learning audience (the first 3 papers [43, 45, 46] of the series),
before extracting payoffs from this foundation (starting from this paper).

B Further Discussions on the Shallow NTK and MF Examples


How does the Function Change? If the NTK limit does not allow features to evolve, then how
does learning occur? To answer this question, note
∆ft (ξ) = V0 ∆xt (ξ) + ∆Vt x0 (ξ) + ∆Vt ∆xt (ξ).
In short, then, the evolution of ft (ξ) in the NTK limit is predominantly due to V0 ∆xt (ξ) and
∆Vt x0 (ξ) only, while in the MF limit, ∆Vt ∆xt (ξ) also contributes nontrivially.
Example: For t = 1, ∆f1(ξ) = V0 ∆x1(ξ) + n^{−2av} x0⊤ x0(ξ) + n^{−2av} x0⊤ ∆x1(ξ). In NTP, av = 1/2,
so the term n^{−2av} x0⊤ x0(ξ) = Θ(1) for generic ξ, ξ0. On the other hand, n^{−2av} x0⊤ ∆x1(ξ) =
O(1/√n) because ∆x1(ξ) = O(1/√n) as noted above. Likewise,
    V0 ∆x1(ξ) ≈ V0 [φ′(h0(ξ)) ⊙ ∆h1(ξ)] = C Σ_{α=1}^{n} V0α φ′(h0(ξ)α) V0α φ′(h0α) = C Σ_{α=1}^{n} (V0α)² φ′(h0(ξ)α) φ′(h0α),
where C = χ0 ξ0 ξ = Θ(1). Now (V0α)² = Θ(1/n) and is almost surely positive. On the other hand,
φ′(h0(ξ)α) φ′(h0α) = Θ(1) and should have a nonzero expectation over random initialization (for
example, if φ is relu then this is obvious). Therefore, the sum above should amount to V0 ∆x1(ξ) ≈
Θ(1). In summary, in the NTK limit, ∆f1(ξ) = Θ(1) due to the interactions between V0 and ∆x1(ξ)
and between ∆V1 and x0(ξ), but there is only vanishing interaction between ∆V1 and ∆x1(ξ).
The case for general t, again, can be derived easily using Tensor Programs.

C abc-Parametrization for General Neural Architectures


We can straightforwardly generalize abc-parametrizations to an arbitrary neural architecture. Each
parameter tensor W would get its own aW and bW , such that W = n−aW w and w is the actual
trainable parameter with initialization wαβ ∼ N (0, n−2bW ). The learning rate is still ηn−c for some
fixed η.

C.1 Maximal Update Parametrization


MLP with Biases Suppose in Eq. (1), for each l ∈ [L], we have hl (ξ) = W l xl−1 (ξ) + bl instead,
for bias bl ∈ Rn . Then in µP, the bias bl should have abl = −1/2 and bbl = 1/2. We can also have
bias bL+1 in the logits f (ξ) = W L+1 xL (ξ) + bL+1 . Then we set abL+1 = bbL+1 = 0.

General Neural Architectures More generally, µP can be defined easily for any neural architecture
whose forward pass can be written down as a Tensor Program (e.g. ResNet or Transformer; see
[43] for explicit programs). The learning rate is always independent of width, i.e. c = 0. For any
parameter tensor W , bW is always 1/2, and aW can be defined as follows: If W is not an output
weight matrix, then aW should be set to −1 + 21 pW , where pW = limn→∞ logn #(W ) is a) 0 if both

34
sides of W are fixed w.r.t. n; b) 1 if W is a vector (e.g. bias) or with one side being fixed dimensional
(e.g. W 1 ); and c) 2 if W is a matrix with both sides scaling like n (e.g. weights in the middle of an
MLP). If W is an output weight matrix (and thus the output dimension is fixed w.r.t. n), then aW
should be 21 . If W is an output bias, then aW should be 0.
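To make this assignment concrete, here is a minimal Python sketch (ours, not the paper's released code) that returns (aW, bW) for a non-output or output weight tensor following the rule above; for simplicity it treats a side as "scaling like n" exactly when it equals n, and it does not handle output biases.

```python
def mup_a_b(shape, n, is_output_weight=False):
    """Sketch of the muP assignment above for one parameter tensor W of the given
    shape at width n; the learning-rate exponent is c = 0 for every tensor.
    Illustrative only -- names and signature are ours."""
    b_W = 0.5                                  # b_W is always 1/2 in muP
    if is_output_weight:                       # output dimension fixed w.r.t. n
        return 0.5, b_W                        # a_W = 1/2
    p_W = sum(1 for s in shape if s == n)      # p_W = lim_n log_n #(W)
    return -1.0 + 0.5 * p_W, b_W               # a_W = -1 + p_W / 2

n = 1024
print(mup_a_b((n, 10), n))                     # input layer W^1:  (-0.5, 0.5)
print(mup_a_b((n, n), n))                      # hidden layer:     ( 0.0, 0.5)
print(mup_a_b((n,), n))                        # hidden bias:      (-0.5, 0.5)
print(mup_a_b((10, n), n, is_output_weight=True))  # output layer: ( 0.5, 0.5)
```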

Optimality Properties One can formalize, in this general context, the notion of stability and the
notions of a parameter tensor being updated maximally and (a set of readout weights) being initialized
maximally. Then one can show that µP is the unique stable abc-parametrization such that all of its
parameter tensors are updated maximally and all of its readout weights are initialized maximally.

D Experimental Details
The main models in our experiments are all 1-hidden-layer linear MLPs with input dimension d and
output dimension do . In our experiments, we will consider more advanced forms, but, as warmup, a
basic version of such a network is given by
f(ξ) = V x(ξ),   x(ξ) = φ(h(ξ)),   h(ξ) = U ξ,   (36)
for U ∈ Rn×d, V ∈ Rdo×n, parametrized like U = √n u, V = (1/√n) v and with initialization
uαβ, vαβ ∼ N(0, 1/n). In this case, Corollary 6.2 generalizes to
Theorem D.1. Consider a 1-hidden-layer linear MLP in µP (Eq. (36)) and any training routine with
learning rate η. As n → ∞, for every input ξ ∈ Rd , ft (ξ) ∈ Rdo converges almost surely to f˚t (ξ)
defined as follows:
f̊t(ξ) = (At Ct + Bt Dt) ξ ∈ Rdo,
χ̊t = L′(f̊t(ξt), yt) ∈ Rdo,
(At+1, Bt+1) = (At, Bt) − η χ̊t ⊗ (Ct ξt, Dt ξt),
(Ct+1, Dt+1) = (Ct, Dt) − η (At⊤ χ̊t, Bt⊤ χ̊t) ⊗ ξt,

where ⊗ denotes outer product (u ⊗ v = uv⊤), with initial condition

A0 = Ido ∈ Rdo×do,   D0 = Id ∈ Rd×d,   B0 = 0 ∈ Rdo×d,   C0 = 0 ∈ Rdo×d.

While we will not use this theorem, we intend it to give an idea of the mathematical process underneath
our implementations, which we discuss now.
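As a minimal illustration of this process, the following NumPy sketch (ours, with made-up data and squared loss, so that L′(f, y) = f − y) iterates the recursion of Theorem D.1:

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_out, eta, T = 3, 2, 0.1, 8

xis = rng.normal(size=(T, d))            # made-up training inputs xi_t
ys  = rng.normal(size=(T, d_out))        # made-up targets y_t

# Initial condition of Theorem D.1 (C_t has the same shape as B_t).
A, B = np.eye(d_out), np.zeros((d_out, d))
C, D = np.zeros((d_out, d)), np.eye(d)

def f_limit(xi):
    """Infinite-width prediction f_t(xi) = (A_t C_t + B_t D_t) xi."""
    return (A @ C + B @ D) @ xi

for t in range(T):
    xi, y = xis[t], ys[t]
    chi = f_limit(xi) - y                # L'(f_t(xi_t), y_t) for squared loss
    A1 = A - eta * np.outer(chi, C @ xi)
    B1 = B - eta * np.outer(chi, D @ xi)
    C1 = C - eta * np.outer(A.T @ chi, xi)
    D1 = D - eta * np.outer(B.T @ chi, xi)
    A, B, C, D = A1, B1, C1, D1          # all updates use time-t quantities

print(f_limit(xis[0]))                   # the evolved infinite-width function
```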

D.1 Few-Shot Learning on Omniglot via MAML


D.1.1 Linear 1-Hidden-Layer µP Network
We consider a linear 1-hidden-layer MLP with bias, input dimension d, output dimension do , given
by
f(ξ) = V h(ξ) ∈ Rdo,   h(ξ) = U ξ + B ∈ Rn,
where ξ ∈ Rd. Following µP, we factor U = √n u ∈ Rn×d, V = (1/√n) v ∈ Rdo×n, B = α√n β ∈ Rn,
where u, v, β are the trainable parameters. We initialize uαβ ∼ N(0, σu²/n), vαβ ∼ N(0, σv²/n),
β = 0 ∈ Rn. We can cancel the factors of √n and rewrite
f(ξ) = v h(ξ) ∈ Rdo,   h(ξ) = u ξ + b ∈ Rn,
where b = αβ. We will also consider gradient clipping with threshold g and weight decay with
coefficient γ. So in summary, the hyperparameters are
σu , σv (init. std.), α (bias multiplier), η (LR), g (grad. clip), γ (weight decay).

As in Corollary 6.2, it’s easy to see that each column of ut at any time t is always a linear combination
of the columns of u0 and the rows of v0 such that the coefficients of these linear combinations converge
deterministically in the n → ∞ limit; likewise for bt and the rows of vt . To track the evolution of f ,
it suffices to track these coefficients. Therefore, for implementation, we reparametrize as follows:

Coefficient matrix and vector   Let µ1, . . . , µd, ν1, . . . , νdo ∈ Rn be standard Gaussian vectors
such that the columns of u0 will be initialized as σu µ1/√n, . . . , σu µd/√n and the rows of v0 will
be initialized as σv ν1/√n, . . . , σv νdo/√n. Write µ = (µ1, . . . , µd) ∈ Rn×d, ν = (ν1, . . . , νdo) ∈ Rn×do. Define coefficient matrices
u⊤ ∈ Rd×(d+do),   v ∈ Rdo×(d+do),
such that at any time, (u, v⊤) ∈ Rn×(d+do) is (1/√n)(µ, ν)(u, v⊤) in the infinite-width limit. We
initialize
[u⊤; v] ← [σu I, 0; 0, σv I],
i.e. a “diagonal” initialization. Likewise, define coefficient vector b ∈ Rd+do, initialized at 0,
such that, at any time, b is approximately distributed as (1/√n)(µ, ν)b. To track the evolution of the
infinite-width network, we will track the evolution of u, v, b.
In general, we use bold to denote the coefficients (in µ, ν) of a tensor (e.g. b for coefficients of b). We
also use capital letters to denote the batched version (e.g. H for batched version of h). Algorithms 2
and 3 below summarize the SGD training of the finite- and the infinite-width networks. Note that
aside from initialization and the hidden size (n vs d + do ), the algorithms are essentially identical.

Algorithm 2 SGD Training of Finite-Width Linear µP 1-Hidden-Layer Network
Input: Hyperparameters n, σu, σv, α, η, g, γ.
1: Initialize uαβ ∼ N(0, σu²/n)
2: Initialize vαβ ∼ N(0, σv²/n)
3: Initialize b ← 0
4: for each batch of inputs Ξ ∈ RB×d and labels Y ∈ RB×do do
5:   // Forward Pass
6:   H ← Ξu⊤ + b ∈ RB×n
7:   f(Ξ) ← Hv⊤ ∈ RB×do
8:   // Backward Pass
9:   χ ← L′(f(Ξ), Y) ∈ RB×do
10:  du ← −v⊤χ⊤Ξ ∈ Rn×d
11:  dv ← −χ⊤H ∈ Rdo×n
12:  db ← −α² 1⊤χ v ∈ Rn
13:  // Gradient Clipping
14:  G ← √(‖du‖²_F + ‖dv‖²_F + ‖db/α‖²)
15:  ρ ← min(1, g/G)
16:  du ← ρ du
17:  dv ← ρ dv
18:  db ← ρ db
19:  // Gradient Step w/ Weight Decay
20:  u += η du − ηγ u ∈ Rn×d
21:  v += η dv − ηγ v ∈ Rdo×n
22:  b += η db − ηγ b ∈ Rn
23: end for

Algorithm 3 SGD Training of Infinite-Width Linear µP 1-Hidden-Layer Network
Input: Hyperparameters σu, σv, α, η, g, γ.
1: Initialize u⊤ ← (σu I, 0)
2: Initialize v ← (0, σv I)
3: Initialize b ← 0
4: for each batch of inputs Ξ ∈ RB×d and labels Y ∈ RB×do do
5:   // Forward Pass
6:   H ← Ξu⊤ + b ∈ RB×(d+do)
7:   f(Ξ) ← Hv⊤ ∈ RB×do
8:   // Backward Pass
9:   χ ← L′(f(Ξ), Y) ∈ RB×do
10:  du ← −v⊤χ⊤Ξ ∈ R(d+do)×d
11:  dv ← −χ⊤H ∈ Rdo×(d+do)
12:  db ← −α² 1⊤χ v ∈ Rd+do
13:  // Gradient Clipping
14:  G ← √(‖du‖²_F + ‖dv‖²_F + ‖db/α‖²)
15:  ρ ← min(1, g/G)
16:  du ← ρ du
17:  dv ← ρ dv
18:  db ← ρ db
19:  // Gradient Step w/ Weight Decay
20:  u += η du − ηγ u ∈ R(d+do)×d
21:  v += η dv − ηγ v ∈ Rdo×(d+do)
22:  b += η db − ηγ b ∈ Rd+do
23: end for

The algorithms for MAML can then be obtained by a straightforward modification of these algorithms.
(Note that in MAML, we do not clip gradients during adaptation, but rather clip the gradient against
the validation loss of the task; we also disable weight decay by setting the coefficient γ to 0.)
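To illustrate that Algorithms 2 and 3 differ only in initialization and hidden size, the following self-contained NumPy sketch (ours, not the paper's code) runs a few SGD steps of both with squared loss, omitting gradient clipping and weight decay for brevity; at large width n the two sets of predictions agree up to O(1/√n).

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_out, bs, n = 3, 2, 4, 1 << 14              # input dim, output dim, batch size, width
sig_u, sig_v, alpha, eta, T = 1.0, 0.5, 1.0, 0.1, 5

def train(u, v, b, batches):
    """SGD steps of Algorithm 2/3; the same code serves both, only the shapes
    and initialization of (u, v, b) differ.  Squared loss, so chi = f - Y."""
    for Xi, Y in batches:
        H = Xi @ u.T + b                        # forward pass
        f = H @ v.T
        chi = f - Y                             # backward pass
        du = -v.T @ chi.T @ Xi
        dv = -chi.T @ H
        db = -alpha**2 * chi.sum(axis=0) @ v
        u, v, b = u + eta * du, v + eta * dv, b + eta * db
    return u, v, b

# Finite width (Algorithm 2): random Gaussian initialization at width n.
u0 = rng.normal(0, sig_u / np.sqrt(n), size=(n, d))
v0 = rng.normal(0, sig_v / np.sqrt(n), size=(d_out, n))
b0 = np.zeros(n)
# Infinite width (Algorithm 3): "diagonal" initialization in coefficient space.
uc = np.vstack([sig_u * np.eye(d), np.zeros((d_out, d))])       # shape (d+d_out, d)
vc = np.hstack([np.zeros((d_out, d)), sig_v * np.eye(d_out)])   # shape (d_out, d+d_out)
bc = np.zeros(d + d_out)

batches = [(rng.normal(size=(bs, d)), rng.normal(size=(bs, d_out))) for _ in range(T)]
uf, vf, bf = train(u0, v0, b0, batches)
ug, vg, bg = train(uc, vc, bc, batches)
Xi_test = rng.normal(size=(bs, d))
f_fin = (Xi_test @ uf.T + bf) @ vf.T
f_inf = (Xi_test @ ug.T + bg) @ vg.T
print(np.max(np.abs(f_fin - f_inf)))            # small, shrinking like 1/sqrt(n)
```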

Hyperparameter Sweep We sweep σu , σv , η and α with the following grid for finite width and
µP networks.

• σu : [0.5, 1, 2, 4, 8],
• σv : [2^{−5}, 2^{−4}, 2^{−3}, 2^{−2}, 2^{−1}],
• η : [0.025, 0.05, 0.1, 0.2, 0.4],
• α : [0.25, 0.5, 1, 2, 4]
Algorithm 4 MAML Training of Kernel Model with Kernel K
Input: Kernel K, adaptation step size , meta learning rate η, batch size B, gradient clip g
1: Initialize Q = {}
2: while True do
3: Draw a batch of tasks
4: for each task in batch do
5: // Adaptation
6: Sample training set D
7: for each input/label pair (ξi , yi ) ∈ D do
8: χi ← L0 (fQ (ξi ), yi )
9: end for
10: for each input/label pair (ξi , yi ) ∈ D do
11: Q.push((ξi , −χi ))
12: end for
13: // Calculate Test Set Gradient
14: Sample test set D̂
15: for each input/label pair (ξˆi , ŷi ) ∈ D̂ do
16: χ̂i ← L0 (fQ (ξˆi ), ŷi )
17: end for
18: for each input/label pair (ξi , yi ) ∈ D do
19: Q.pop((ξi , −χi ))
20: end for
21:   // Gradient Clip
22:   G ← √( Σ_{(ξ̂i,ŷi)∈D̂} Σ_{(ξ̂j,ŷj)∈D̂} χ̂i χ̂j K(ξ̂i, ξ̂j) )
23:   ρ ← min(1, g/G)
24: // Gradient Update
25: for each input/label pair (ξˆi , ŷi ) ∈ D̂ do
26: Q.push((ξˆi , −ρη χ̂i ))
27: end for
28: end for
29: end while

We are interested in 1-shot, 5-way learning with Omniglot. This means that each task provides
5 training samples, each corresponding to one of the 5 labels of the task. Each hyperparameter
combination above is used to train for 100 epochs over 3 random seeds, where each epoch consists of
100 batches of 32 tasks. We average the validation accuracy across the last 10 epochs and document
the best hyperparameters in Table 4, along with the test accuracy from a 15-seed rerun48 for better
benchmarking. For NTK and GP, we additionally tune the initialization σb for biases, which is set to
0 for both finite and µP networks for simplicity.

D.1.2 NNGP and NTK for Relu Networks


Consider a kernel K, which in our case will be the NNGP or NTK of a 1-hidden-layer relu network.
WLOG, it is induced by an embedding Φ such that K(ξ, ζ) = ⟨Φ(ξ), Φ(ζ)⟩, where ⟨·, ·⟩ is the inner
product in the embedding space; we do not care about the details of Φ or ⟨·, ·⟩ as eventually our
algorithm only depends on K.
In our setting, we will train a linear layer W on top of Φ via MAML, f(ξ) := ⟨W, Φ(ξ)⟩. One can see
easily that W is always a linear combination of Φ(ζ) for various ζ from the training set we’ve seen
so far. Thus, to track W, it suffices to keep an array Q of pairs (ζ, q) such that W = Σ_{(ζ,q)∈Q} q Φ(ζ)
at all times. Let fQ be the function with W given by Q. Then
fQ(ξ) = Σ_{(ζ,qζ)∈Q} qζ K(ζ, ξ).

48: After excluding outliers at least one standard deviation away from the mean.

Table 4: Best hyperparameters for the MAML experiment.
log2 Width/Limit σu σv σb η α Val. Acc. (%) Test Acc. (%)
1 0.5 0.5 - 0.05 2 46.72 ± 4.30 55.34 ± 1.24
3 0.5 0.25 - 0.1 1 65.30 ± .27 64.54 ± .70
5 1 0.125 - 0.4 0.5 68.74 ± .18 66.21 ± .15
7 1 0.125 - 0.1 1 69.03 ± .04 66.31 ± .16
9 1 0.03125 - 0.1 1 69.32 ± .07 66.43 ± .23
11 1 0.03125 - 0.1 1 69.27 ± .11 66.36 ± .22
13 1 0.03125 - 0.1 1 69.27 ± .14 66.41 ± .18
µP 1 0.03125 - 0.1 1 69.26 ± .13 66.42 ± .19
NTK 0.25 1 1 0.05 1 47.47 ± .13 47.82 ± .04
GP 1 0.25 1 0.05 1 38.92 ± .15 47.60 ± .02

In our case, the number of possible inputs is too large to instantiate a value q for every ζ, so we
gradually grow a dynamic array Q, which we model as a stack. Then MAML can be implemented as
in Algorithm 4.
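As a concrete illustration of this bookkeeping (ours, not the paper's code), the sketch below tracks W implicitly via the stack Q and performs one functional-gradient adaptation step in the spirit of Algorithm 4's inner loop; the kernel here is a stand-in RBF rather than the NNGP/NTK used in the experiments, the output is scalar, squared loss is assumed, and the explicit step size eps is our own convention.

```python
import numpy as np

def K(x, z, gamma=1.0):
    """Stand-in RBF kernel for illustration; any PSD kernel fits this scheme."""
    return float(np.exp(-gamma * np.sum((x - z) ** 2)))

class KernelModel:
    """Tracks W = sum_{(zeta,q) in Q} q Phi(zeta) implicitly through the stack Q,
    so that f_Q(xi) = sum_{(zeta,q) in Q} q K(zeta, xi)."""
    def __init__(self):
        self.Q = []                               # stack of (zeta, q) pairs

    def predict(self, xi):
        return sum(q * K(zeta, xi) for zeta, q in self.Q)

    def push(self, zeta, q):
        self.Q.append((zeta, q))

    def pop_last(self, k):
        del self.Q[-k:]                           # undo the last k pushes

# One adaptation step on a made-up 2-point task (scalar output, squared loss).
eps = 0.5
model = KernelModel()
D = [(np.array([0.0, 1.0]), 1.0), (np.array([1.0, 0.0]), -1.0)]
chis = [model.predict(xi) - y for xi, y in D]     # L'(f_Q(xi_i), y_i)
for (xi, _), chi in zip(D, chis):
    model.push(xi, -eps * chi)                    # functional gradient step
print([round(model.predict(xi), 3) for xi, _ in D])
model.pop_last(len(D))                            # discard adaptation afterwards
```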

Hyperparameter Sweep We sweep σu , σv , σb and η with the following grid for GP and NTK.
• σu : [0.25, 0.5, 1, 2, 4],
• σv : [0.25, 0.5, 1, 2, 4],
• σb : [0.25, 0.5, 1, 2, 4],
• η : [0.05, 0.1, 0.2, 0.4, 0.8]
Each hyperparameter combination above is used to train for 5 epochs (the first epoch is almost always
the best) over 3 random seeds, where each epoch consists of 100 batches of 32 tasks. We take the
validation accuracy among all epochs and document the best hyperparameters in Table 4, along with
the test accuracy from a 15-seed rerun.

D.2 Word2Vec Experimental Details


D.2.1 µP Limit
We shall derive the training algorithm for µP Word2Vec. First, we introduce the notation for word
embeddings. We denote Φi := h(ξi). If ξi is a one-hot vector with the ith element set to 1, Φi is
essentially the ith column of the weight matrix U. We also define the following shorthands for the
context embedding: ΦJ := Ej∈J Φj = h(ξJ). Similarly, V⊤ξτ describes a row in V; we can define
Φτ̂ := ĥ(ξτ) = V⊤ξτ and rewrite the loss function:

L(f(ξJ), ξτ) = { log(1 − σ(ΦJ⊤ Φτ̂))   if τ = i,
               { log σ(ΦJ⊤ Φτ̂)        if τ ≠ i.        (37)

Consequently, the backward pass becomes:


( η τ̂ J > τ̂
j 1 η ∂L |J| Φ (1 − σ(Φ Φ )) τ =i
∆Φ = ∆ΦJ = = η τ̂ J > τ̂ (38)
|J| |J| ∂ΦJ − |J| Φ σ(Φ Φ ) τ 6= i.

Following µP, we initialize Uαβ ∼ N(0, σu n−1) and Vαβ ∼ N(0, σv n−1), where n is the width of
the finite network. (Here the explicit multipliers of √n in U and 1/√n in V cancel out because the
network is linear.) The tunable hyperparameters are the initialization std σu and σv, learning rate η
and weight decay ratio γ. Rather than tuning the hyperparameters extensively for each width, we pick
some reasonable values and use them for all of our experiments. Specifically, we have σu = σv = 1,
η = 0.05 and γ = 0.001.
Again, using Corollary 6.2, we can train the µP limit in the coefficient space of u⊤ ∈ R|V|×2|V|, v ∈
R|V|×2|V|, with the same “diagonal” initialization:
[u⊤; v] ← [σu I, 0; 0, σv I].
We can adopt the embedding notation and represent a row of u with the embedding coefficient vector
Φ• and a column of v with Φ•ˆ . This is computationally equivalent to training with a hidden size of
2|V| and with embeddings initialized as rows (or columns) of one-hot vectors. The full algorithm is
described in Algorithm 2 and Algorithm 3; in this case, we remove biases and use weight decay with
coefficient γ = 0.001. After training, rows of the weight matrix u (resp. coefficient matrix u), i.e.
Φ• (resp. Φ• ), are taken as the word vectors.

D.2.2 NTK Limit


In the NTK parametrization, V and U in Eq. (33) factor as V = (1/√n) v and U = u, and the learning
rate is Θ(1). Each column U•i of U is equal to h(ξi). At any fixed time t, it is easy to see via Tensor
Programs that
ht(ξi) = h0(ξi) + Σ_{j∈V} O(1/√n) vj + Ocoord(1/n)
where vj denotes the jth row of v at initialization, and where Ocoord(1/n) means a vector that is
O(1/n) coordinatewise. Recall that U = u and v are initialized with iid standard Gaussian entries.
Because ξi is one-hot, this in particular implies h0(ξi) has standard Gaussian entries, and h0(ξi) is
independent from h0(ξj) for i ≠ j. Then for any i ≠ j,
(1/√n) ht(ξi)⊤ ht(ξj) − (1/√n) h0(ξi)⊤ h0(ξj) →a.s. 0,   (1/√n) h0(ξi)⊤ h0(ξj) →d N(0, 1)
by the Law of Large Numbers (or more formally, Theorem 7.4) and the Central Limit Theorem. In other
words, (1/√n) h0(ξi)⊤ h0(ξj) is distributed completely randomly, with no regard to the semantic simi-
larities of i and j. Likewise, the inner product in Eq. (35) is random, and the argmax is a uniform
sample.49 Therefore, in the NTK limit, Word2Vec gives random answers and achieves an accuracy of
1/(|V| − 3).
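A quick numerical check of the last display (our own illustration): for one-hot inputs, h0(ξi) and h0(ξj) are independent standard Gaussian columns, so their normalized inner product is essentially N(0, 1) noise.

```python
import numpy as np

rng = np.random.default_rng(0)
n, trials = 4096, 2000
# Under NTK parametrization, h_0(xi^i) is the i-th column of U, with iid N(0,1)
# entries, independent across distinct one-hot inputs i != j.
samples = [rng.normal(size=n) @ rng.normal(size=n) / np.sqrt(n) for _ in range(trials)]
print(np.mean(samples), np.var(samples))   # approx 0 and 1: the similarities are pure noise
```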

E Nuances of the Master Theorem


Remark E.1 (Partial derivative). The partial derivative in ZDot should be interpreted as follows. By a
simple inductive argument, Z^x for every vector x in the program is defined uniquely as a deterministic
function ϕ(Ẑ^{x^1}, . . . , Ẑ^{x^k}) of some x^1, . . . , x^k in V or introduced by MatMul (notationally, we are
suppressing the possible dependence on limit scalars θ̊1, . . . , θ̊l). For instance, if in a program we
have A ∈ W, v ∈ V, y = Av, x = A⊤y, then Z^x = Ẑ^x + Ẑ^v, so ϕ is given by ϕ(a, b) = a + b.
Then
∂Z^x/∂Ẑ^{x^i} := ∂i ϕ(Ẑ^{x^1}, . . . , Ẑ^{x^k}),   and   ∂Z^x/∂Ẑ^z := 0 for any z ∉ {x^1, . . . , x^k}.
Note this definition depends on the precise way the program is written, not just on the underlying
mathematics. For example, if y, z ∈ V and x = φ(W(y + z)), then Z^x = φ(Ẑ^{W(y+z)}), so that
∂Z^x/∂Ẑ^{Wy} = ∂Z^x/∂Ẑ^{Wz} = 0. If instead we have x = φ(Wy + Wz), then Z^x = φ(Ẑ^{Wy} + Ẑ^{Wz}),
so that ∂Z^x/∂Ẑ^{W(y+z)} = 0. However, in both cases, Ż^{W⊤x} = (Z^y + Z^z) E φ′(Ẑ^{W(y+z)}).
Remark E.2 (Partial derivative expectation). The quantity E ∂Z^x/∂Ẑ^{W⊤y} is well defined if Z^x is
differentiable in Ẑ^{W⊤y}. However, even if this is not the case, e.g. if x = θ(W⊤y) where θ is the
Heaviside step function, we can still define this expectation by leveraging Stein’s lemma:
In ZDot, suppose {W⊤y^i}_{i=1}^k are all elements of V_{W⊤} introduced before x. Define the matrix
C ∈ Rk×k by Cij := E Z^{y^i} Z^{y^j} and define the vector b ∈ Rk by bi := E Ẑ^{W⊤y^i} Z^x. If a = C⁺b
(where C⁺ denotes the pseudoinverse of C), then in ZDot we may set
σ_W² E ∂Z^x/∂Ẑ^{W⊤y^i} = ai.   (39)
49: Here the randomness comes from initialization: the argmax is different for different random initializations, but it is fixed throughout training in the large-width limit.

This definition agrees with the partial derivative expectation by Stein’s lemma when the latter is well
defined. Theorem 7.4 holds with this broader definition of partial derivative expectation.
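To make Eq. (39) concrete, here is a small Monte Carlo sketch (ours) of the k = 1 Heaviside example above, assuming the usual Tensor Programs rule that Ẑ^{W⊤y} is Gaussian with variance σ_W² E(Z^y)²; the pseudoinverse formula reduces to a = b/C, and it matches the Gaussian-density value that Stein's lemma formally predicts.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 1_000_000
sigma_W, sigma_y = 1.3, 0.7        # illustrative scales, not from the paper

# k = 1 case: Z^x = theta(Zhat^{W^T y}) with theta the Heaviside step function.
Zy   = sigma_y * rng.normal(size=N)                 # samples of Z^y
Zhat = sigma_W * sigma_y * rng.normal(size=N)       # Zhat^{W^T y} ~ N(0, sigma_W^2 E (Z^y)^2)
Zx   = (Zhat > 0).astype(float)                     # Z^x = theta(Zhat^{W^T y})

C = np.mean(Zy * Zy)               # C_11 = E Z^y Z^y
b = np.mean(Zhat * Zx)             # b_1  = E Zhat^{W^T y} Z^x
a = b / C                          # a = C^+ b  (C is a 1x1 matrix here)

# Eq. (39) defines sigma_W^2 E[dZ^x / dZhat^{W^T y}] := a_1; for the Heaviside
# step, Stein's lemma gives the closed form sigma_W / (sigma_y * sqrt(2 pi)).
print(a, sigma_W / (sigma_y * np.sqrt(2 * np.pi)))
```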

Pseudo-Lipschitz functions are, roughly speaking, functions whose weak derivatives are polyno-
mially bounded.
Definition E.3. A function f : Rk → R is called pseudo-Lipschitz of degree d if |f(x) − f(y)| ≤
C ‖x − y‖ (1 + Σ_{i=1}^k |xi|^d + |yi|^d) for some C. We say f is pseudo-Lipschitz if it is so for any
degree.
Here are some basic properties of pseudo-Lipschitz functions:
• The norm ‖·‖ in Definition E.3 can be any norm equivalent to the ℓ2 norm, e.g. ℓp norms, p ≥ 1.
Similarly, Σ_{i=1}^k |xi|^d + |yi|^d can be replaced by ‖x‖_p^d + ‖y‖_p^d, for any p ≥ 1.
• A pseudo-Lipschitz function is polynomially bounded.
• A composition of pseudo-Lipschitz functions of degrees d1 and d2 is pseudo-Lipschitz of
degree d1 + d2 .
• A pseudo-Lipschitz function is Lipschitz on any compact set.
We adopt the following assumption for the Master Theorem Theorem 7.4.
Assumption E.4. Suppose
1. If a function φ(; −) : R0+l → R with only parameter arguments is used in Moment, then φ
is continuous in those arguments.
2. Any other function φ(−; −) : Rk+l → R with parameters (where k > 0) used in Nonlin or
Moment is pseudo-Lipschitz in all of its arguments (both inputs and parameters).
Statement 1 in Assumption E.4 essentially says that if we have scalars θ1 , . . . , θl in the program,
then we can produce a new scalar by applying a continuous function (a weaker restriction than a
pseudo-Lipschitz function) to them. Indeed, if θ1 , . . . , θl converge almost surely, then this new scalar
does too. In our setting, statement 1 is used to allow any loss function whose derivative is continuous.
Other versions of the Master Theorem can be found in [46], for example, versions where we do
not assume any smoothness condition at all on the nonlinearities beyond that they be polynomially
bounded, in exchange for assuming what’s called a rank stability condition. This rank stability should
be generically true, but checking it rigorously is subtle, so we are content with the pseudo-Lipschitz
condition in this paper.

F A Rough Sketch of the Geometry of abc-Parametrizations


By the results of Section 3.2, the stable abc-parametrizations form a polyhedron defined by the
inequalities of Theorem 3.3. We call the polyhedron obtained by quotienting Eq. (5) the stable
polyhedron. In this section, we remark on some geometric properties of this polyhedron.
First, observe that the stable polyhedron is unbounded (thus, we say polyhedron instead of polytope).
Indeed, given any stable parametrization, for any l, we can set al ← al + θ, bl ← bl − θ for any
θ ≥ 0 to obtain another stable parametrization. This corresponds to decreasing the layer-l learning rate,
so that as θ → ∞, W^l is not trained.
Second, by Theorem 3.4, the nontrivial parametrizations reside in two facets of the stable polyhedron.
These facets are unbounded for the same reason as above.
Next, we show that NTP (as well as µP) is a vertex on the intersection of these two facets, and NTP
and µP are connected by an edge.
Definition F.1. Consider a stable abc-parametrization of the MLP in Eq. (1). We say the body of the
MLP is uniformly updated if, for some training routine, time t ≥ 1, and input ξ, ∆Wtl xlt (ξ) = Θ(n−r )
for all l simultaneously, where r is as defined in Definition 3.2.
In the results of this section below, we assume Assumption G.21.
Proposition F.2. In a stable abc-parametrization, the MLP body is uniformly updated iff rl = r for
all l ∈ [L], where rl is as defined in Proposition 5.3.

[Figure 5 appears here: a 2D projection of the boundary of the uniform stable polyhedron. Its facets and edges carry algebraic labels (b2 = 0; a2 + b2 = 1/2; a2 = 1/2; a2 + b2 + a1 = 1/2; a1 = 0; a1 = −1/2; a2 + a1 = 0; a2 + b2 + 2a1 = 0; −a1 = a2 = b2) and verbal labels (Kernel Regime, f0 is a nonzero GP, r > 0; W^{L+1} updated maximally; W^{L+1} initialized maximally; NTP; Body NTK limit; µP; Feature Learning, r = 0; ∆W^{L+1} dominates W0^{L+1}; W0^{L+1} dominates ∆W^{L+1}), with a legend distinguishing trivial from nontrivial parametrizations.]
Figure 5: 2D Projection of the Boundary of the Uniform Stable Polyhedron (Equivalently, the
Boundary of the Stable Polyhedron for L = 1). Here, we label each facet and edge of the graph
with orange text to indicate the corresponding defining algebraic condition in the L = 1 case (as part
of the stable polyhedron, assuming c = 0 and b1 = −a1 ), and with black text to indicate the verbal
interpretation valid for all L (as part of the uniform stable polyhedron). We obtain the caricature
in the introduction by taking the nontrivial part of the graph and quotienting the two facets by their
respective points at infinity. Explanation of some captions: GP limit means the training dynamics
amounts to training only the last layer in the infinite-width limit, starting from a nonzero initial GP.
Body NTK limit means NTK dynamics except the last layer does not contribute to the NT kernel.

Theorem F.3. In NTP, the MLP body is updated uniformly and W L+1 is both initialized and updated
maximally. Furthermore, at initialization, f0 converges in distribution50 to a Gaussian Process with
nonzero kernel. NTP is the unique (modulo Eq. (5)) stable abc-parametrization with both of these
properties.
Theorem F.4. For any r ∈ [0, 1/2], there is a unique (modulo Eq. (5)) stable abc-parametrization
with 1) that value of r and the property that 2) the MLP body is updated uniformly and W L+1 is both
initialized and updated maximally. We call this parametrization the Uniform Parametrization with
r-value r, denoted UPr. Its abc values are
al = −(1/2)·I(l = 1) + r   ∀l ∈ [L],   aL+1 = 1/2;   bl = 1/2 − r;   c = 0.

In particular, UP0 is µP and UP1/2 is NTP. For r > 1/2, such a uniform parametrization is not
stable because W0 would need to be Θ(nr−1 ), which would cause the initial GP to blow up. Thus,
geometrically, UPr , r ∈ [0, 1/2], form an edge of the stable polyhedron.
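For concreteness, a tiny sketch (ours) that writes out the abc values of UPr and recovers µP at r = 0 and NTP at r = 1/2:

```python
def uniform_parametrization(r, L):
    """abc values of UP_r for an L-hidden-layer MLP (sketch of Theorem F.4):
    UP_0 is muP and UP_{1/2} is NTP."""
    a = {l: (-0.5 if l == 1 else 0.0) + r for l in range(1, L + 1)}
    a[L + 1] = 0.5
    b = {l: 0.5 - r for l in range(1, L + 2)}
    c = 0.0
    return a, b, c

print(uniform_parametrization(0.0, 2))   # muP for L = 2
print(uniform_parametrization(0.5, 2))   # NTP for L = 2
```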
We can define the uniform stable polyhedron to be the subset of the stable polyhedron corresponding to
parametrizations which update the MLP body uniformly. This is isomorphic to the stable polyhedron
when L = 1. Since stable abc-parametrizations with L = 1 have only 3 degrees of freedom, say
a1, a2, b2, while we fix c = 0 (via Eq. (5)) and b1 = −a1, we can visualize the corresponding stable
polyhedron in 3D. However, the nontrivial parametrizations only reside in the boundary of this
polyhedron. Because of its unbounded nature, we can project its boundary in 2D and visualize it.
This is done in Fig. 5.

G Proofs of Main Results


G.1 Rigorous Statements of Main Results

Applicable Nonlinearities For technical reasons, in our main results we restrict our attention to
the canonical examples of nonlinearities: tanh and relu — or rather, a smooth version of relu called
gelu [21] common in transformer models [6]. More precisely,
Definition G.1. Define σ-gelu to be the function x ↦ (1/2) x erf(σ⁻¹x) + (σ/(2√π)) e^{−σ⁻²x²} + x/2.

50: As is conventional in the machine learning literature, the convergence in distribution we mean here is over finite dimensional marginals, i.e. (f0(ξ1), . . . , f0(ξk)) →d (f̊0(ξ1), . . . , f̊0(ξk)) where f̊0 is the limit GP.

σ-gelu is a smooth approximation of relu and is the integral of (1/2)(erf(σ⁻¹x) + 1) that is 0 at −∞. The
larger σ is, the smoother σ-gelu is. As σ → 0, σ-gelu converges to relu. We believe our results will
hold for generic nonlinearities, but making this precise is outside our scope here. (See Remark G.14
for some discussion).
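For reference, a minimal implementation of Definition G.1 (ours), together with a numerical check that σ-gelu approaches relu as σ → 0:

```python
import math

def sigma_gelu(x, sigma):
    """sigma-gelu from Definition G.1; its derivative is (erf(x/sigma) + 1)/2."""
    return (0.5 * x * math.erf(x / sigma)
            + sigma * math.exp(-(x / sigma) ** 2) / (2 * math.sqrt(math.pi))
            + 0.5 * x)

xs = [-3.0, -1.0, -0.1, 0.0, 0.1, 1.0, 3.0]
for s in (1.0, 0.1, 0.01):                       # smaller sigma -> closer to relu
    print(s, max(abs(sigma_gelu(x, s) - max(x, 0.0)) for x in xs))
```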

Notations and Terminologies


Definition G.2 (Big-O Notation). Given a sequence of scalar random variables c = {cn ∈ R}∞_{n=1},
we write c = Θ(n−a) if there exist constants A, B such that An−a ≤ |c| ≤ Bn−a for sufficiently
large n, almost surely51. Given a sequence of random vectors x = {xn ∈ Rn}∞_{n=1}, we say x has
coordinates of size Θ(n−a) and write x = Θ(n−a) to mean the scalar random variable sequence
{√(‖xn‖²/n)}n is Θ(n−a). Similarly for the notations O(n−a), Ω(n−a). We use the notations
Θξ(n−a), Oξ(n−a), Ωξ(n−a) if the hidden constants A, B are allowed to depend on some object ξ.
For brevity, we will often abuse notation and say c itself is a random variable or x itself is a random
vector.
Most often, the vector x will have “approximately iid” coordinates, so the notation x = Θ(n−a ) can
be interpreted intuitively to say x has coordinates of “standard deviation” Θ(n−a ), which justifies
the name.
Definition G.3. An abc-parametrization is a joint parametrization of an MLP and the learning rate
specified by the numbers {al , bl }l ∪ {c} as in Eq. (1). Below we will often say abc-parametrization
of an MLP for short, even though the parametrization affects the learning rate as well. A training
routine is a combination of learning rate ηn−c , training sequence {(ξt , yt )}t≥0 , and a loss function
L(f (ξ), y) that is continuously differentiable in the prediction of the model f (ξ).

Main Results We will mainly focus on stable parametrizations, defined below, which intuitively
means 1) the preactivations {hl }l and activations {xl }l have Θ(1) coordinates at initialization, and
2) their coordinates and the logit f (ξ) all stay O(1) (i.e. bounded independent of n) throughout the
course of SGD.52 Otherwise, they tend to ∞ with n, eventually going out of floating point range.
Indeed, this is an acute and real problem common in modern deep learning, where float16 is necessary
to train large models.
Definition G.4 (Stability). We say an abc-parametrization of an L-hidden layer MLP is stable if
1. For every nonzero input ξ ∈ X ,
hl0 (ξ), xl0 (ξ) = Θξ (1), ∀l ∈ [L], and E f0 (ξ)2 = Oξ (1), (40)
where the expectation is taken over the random initialization.
2. For any training routine, any time t ≥ 0, l ∈ [L], ξ ∈ X , we have
∆hlt (ξ), ∆xlt (ξ) = O∗ (1), ∀l ∈ [L], and ft (ξ) = O∗ (1),
where the hidden constant inside O can depend on the training routine, t, ξ, and the initial
function values f0 (X ).53
Recall from the main text,
Definition G.5. For any abc-parametrization, we write r for the quantity
r := min(aL+1 + bL+1, 2aL+1 + c) + c − 1 + min_{l=1}^{L} [2al + I(l = 1)].
For example, in NTP, r = 1/2, while in µP, r = 0. Intuitively, r is the exponent such that
∆xt^L(ξ) = Θξ(n−r). Thus, to avoid activation blowup, we want r ≥ 0; to perform feature learning,
we want r = 0.
51: Here almost surely means for almost every instantiation of c1, c2, . . ., i.e. it is with regard to the product probability space generated by all of {cn}∞_{n=1}. In this paper, this probability space will be generated by random initializations of a neural network at every width n. Very importantly, note the order of the qualifiers: we are saying for almost every instantiation of c1, c2, . . ., for large enough n, An−a ≤ |c| ≤ Bn−a.
52: But they may depend on training time and η; in particular, it’s possible that they diverge with time.
53: For e.g. the NTK limit, f0 is a GP, so that we should expect the bounds on ∆ht^l(ξ), ∆xt^l(ξ) to depend on f0.

Theorem G.6 (Stability Characterization). Suppose φ is tanh or σ-gelu for sufficiently small σ. An
abc-parametrization is stable iff all of the following are true (with intuitions in parentheses):

1. ((pre)activations at initialization are Θ(1) and logits are O(1))


a1 + b1 = 0; al + bl = 1/2, ∀l ∈ [2, L]; aL+1 + bL+1 ≥ 1/2. (41)

2. (features don’t blowup, i.e. ∆xlt = O(1) for all l)


r ≥ 0. (42)

3. (logits don’t blow up during training, i.e. ∆Wt^{L+1} xt^L, W0^{L+1} ∆xt^L = O(1))
2aL+1 + c ≥ 1;   aL+1 + bL+1 + r ≥ 1.   (43)

Here, r is as defined in Definition G.5.

In Eq. (43), ∆Wt^{L+1} turns out to be Θ(n^{−(2aL+1+c)}) and is correlated with xt^L = Θ(1) such that
their product behaves according to the Law of Large Numbers; the first inequality says this should not
blow up. Similarly, W0^{L+1} = Θ(n^{−(aL+1+bL+1)}) and it turns out ∆xt^L = Θ(n^{−r}), and they will
interact via the Law of Large Numbers, so the second inequality says their product shouldn’t blow up.
Our main results concern nontrivial parametrizations:
Definition G.7 (Nontriviality). We say an abc-parametrization of an L-hidden layer MLP is trivial if
a.s.
for every training routine, ft (ξ) − f0 (ξ) −−→ 0 for any time t ≥ 1 and input ξ ∈ X (i.e. the function
does not evolve in the infinite-width limit). We say the parametrization is nontrivial otherwise.
Theorem G.8 (Nontriviality Characterization). Suppose φ is tanh or σ-gelu for sufficiently small σ.
A stable abc-parametrization is nontrivial iff aL+1 + bL+1 + r = 1 or 2aL+1 + c = 1.
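The conditions of Definition G.5, Theorem G.6, and Theorem G.8 are easy to check mechanically. Here is a small sketch (ours) that computes r and tests stability and nontriviality, verifying r = 1/2 for NTP and r = 0 for µP:

```python
def r_value(a, b, c, L):
    """r from Definition G.5; a and b are dicts indexed by layer 1..L+1."""
    return (min(a[L + 1] + b[L + 1], 2 * a[L + 1] + c) + c - 1
            + min(2 * a[l] + (1 if l == 1 else 0) for l in range(1, L + 1)))

def is_stable(a, b, c, L):
    """Checks Eqs. (41)-(43) of Theorem G.6."""
    r = r_value(a, b, c, L)
    init_ok = (a[1] + b[1] == 0
               and all(a[l] + b[l] == 0.5 for l in range(2, L + 1))
               and a[L + 1] + b[L + 1] >= 0.5)
    return init_ok and r >= 0 and 2 * a[L + 1] + c >= 1 and a[L + 1] + b[L + 1] + r >= 1

def is_nontrivial(a, b, c, L):
    """Theorem G.8: stable, and a_{L+1}+b_{L+1}+r = 1 or 2a_{L+1}+c = 1."""
    r = r_value(a, b, c, L)
    return is_stable(a, b, c, L) and (a[L + 1] + b[L + 1] + r == 1 or 2 * a[L + 1] + c == 1)

L, c = 2, 0.0
ntp = ({1: 0.0, 2: 0.5, 3: 0.5}, {1: 0.0, 2: 0.0, 3: 0.0})   # NTP for L = 2
mup = ({1: -0.5, 2: 0.0, 3: 0.5}, {1: 0.5, 2: 0.5, 3: 0.5})  # muP for L = 2
for name, (a, b) in [("NTP", ntp), ("muP", mup)]:
    print(name, r_value(a, b, c, L), is_stable(a, b, c, L), is_nontrivial(a, b, c, L))
# Expected: NTP 0.5 True True ; muP 0.0 True True
```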
Definition G.9 (Feature Learning). We say an abc-parametrization of an L-hidden layer MLP admits
feature learning in the lth layer if there exists some training routine such that
∆xlt (ξ) = Ω∗ (1) (44)
for some t ≥ 0, ξ ∈ X , where the hidden constant inside Ω can depend on the training routine, t, ξ,
and the initial function values f0 (X ). We say the parametrization admits feature learning if it does
so in any layer.
We say the parametrization fixes the lth layer features if for all training routine,
a.s.
k∆xlt (ξ)k2 /n −−→ 0
for all t ≥ 0, ξ ∈ X . We say the parametrization fixes all features if it does so in every layer.
We make similar definitions as above replacing feature with prefeature and xl with hl .
Note that the probabilistic nature of Ω∗ (1) means that no feature learning does not imply fixing all
features (because ∆xlt (ξ) can just fluctuate wildly between 0 and infinity), but we will see that in the
context of nontrivial stable abc-parametrizations, this is true.
A somewhat stronger notion of feature learning is that the feature kernel evolves. This is, for example,
essential for linear transfer learning such as in self-supervised learning of image data.
Definition G.10 (Feature Kernel Evolution). We say an abc-parametrization of an L-hidden layer
MLP evolves the lth layer feature kernel if there exists some training routine such that
xlt (ξ)> xlt (ζ)/n − xl0 (ξ)> xl0 (ζ)/n = Ω∗ (1)
for some t ≥ 0, ξ, ζ ∈ X , where the hidden constant inside Ω can depend on the training routine, t,
ξ, ζ, and the initial function values f0 (X ). We say the parametrization evolves feature kernels if it
does so in any layer.
We say the parametrization fixes the lth layer feature kernel if for all training routine,
a.s.
xlt (ξ)> xlt (ζ)/n − xl0 (ξ)> xl0 (ζ)/n −−→ 0, as n → ∞,
for all t ≥ 0, ξ, ζ ∈ X . We say the parametrization fixes all feature kernels if it does so in every layer.
We make similar definitions as above replacing feature with prefeature and xl with hl .

Intuitively, for a stable parametrization, feature kernel evolution should imply feature learning (one
can see the contrapositive easily). In fact, we shall see below they are equivalent notions.
On the other hand, from the NTK example, we know certain limits can be described entirely through
kernel gradient descent with some kernel. Appropriately, we make the following definition.
Definition G.11 (Kernel Regime). We say an abc-parametrization of an L-hidden layer MLP is in
kernel regime if there exists a positive semidefinite kernel K : X 2 → R such that for every training
routine, the MLP function evolves under kernel gradient descent, i.e. there exist random variables
f̊t(ξ) for each time t ≥ 0 and input ξ ∈ X such that, as n → ∞,54
{ft(ξ)}_{t≤T, ξ∈X} →d {f̊t(ξ)}_{t≤T, ξ∈X},   ∀T ≥ 1,
where →d denotes convergence in distribution, and
f̊_{t+1}(ξ) = f̊t(ξ) − η K(ξ, ξt) L′(f̊t(ξt), yt),   ∀t ≥ 0.   (45)

Observe that, in kernel regime, f˚t (ξ) is deterministic conditioned on f˚0 (ξ), as evident inductively
from Eq. (45). For example, in the NTK limit, {f˚0 (ξ) : ξ ∈ X } is a nontrivial Gaussian Process (GP),
but the function evolution conditioned on this GP is deterministic.
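As a minimal illustration of Eq. (45), the following sketch (ours) runs kernel gradient descent for squared loss with a stand-in RBF kernel, taking f̊0 = 0 for simplicity (in the actual kernel regime f̊0 is a GP draw):

```python
import numpy as np

rng = np.random.default_rng(0)
eta, T = 0.3, 40
xis = rng.normal(size=(T, 2))                 # training inputs xi_t
ys = np.sin(xis[:, 0])                        # training targets y_t

def K(x, z):
    """Stand-in positive semidefinite kernel (RBF) for illustration."""
    return np.exp(-np.sum((x - z) ** 2))

xi_test = np.array([0.3, -0.2])
f_test = 0.0                                  # taking f_0 = 0 for simplicity
f_train = np.zeros(T)                         # f_t evaluated at each training input
for t in range(T):
    chi = f_train[t] - ys[t]                  # L'(f_t(xi_t), y_t) for squared loss
    # Eq. (45): f_{t+1}(xi) = f_t(xi) - eta * K(xi, xi_t) * chi, for every xi
    f_test -= eta * K(xi_test, xis[t]) * chi
    f_train -= eta * np.array([K(xi, xis[t]) for xi in xis]) * chi

print(f_test)                                 # deterministic given f_0, as in kernel regime
```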
All of the concepts defined above are related to each other by the following theorem.
Theorem G.12 (Classification of abc-Parametrizations). Suppose φ is tanh or σ-gelu for sufficiently
small σ. Consider a nontrivial stable abc-parametrization of an L-hidden layer MLP. Then

1. The following are equivalent to r = 0


(a) feature learning
(b) feature learning in the Lth layer
(c) feature kernels evolution
(d) feature kernel evolution in the Lth layer
(e) prefeature learning
(f) prefeature learning in the Lth layer
(g) prefeature kernels evolution
(h) prefeature kernel evolution in the Lth layer
2. The following are equivalent to r > 0
(a) kernel regime
(b) fixes all features
(c) fixes features in the Lth layer
(d) fixes all feature kernels
(e) fixes feature kernel in the Lth layer
(f) fixes all prefeatures
(g) fixes prefeatures in the Lth layer
(h) fixes all prefeature kernels
(i) fixes prefeature kernel in the Lth layer
3. If there is feature learning or feature kernel evolution or prefeature learning or prefeature
kernel evolution in layer l, then there is feature learning and feature kernel evolution and
prefeature learning and prefeature kernel evolution in layers l, . . . , L.
4. If r = 0, then for all ξ ∈ X, f0(ξ) →a.s. 0 and ft(ξ) →a.s. f̊t(ξ) for some deterministic f̊t(ξ).
However, the converse is not true.
5. If r > 0, aL+1 + bL+1 + r > 1 and 2aL+1 + c = 1, then we have the Neural Network-
Gaussian Process limit.
54: Here, because we want to avoid topological issues arising for convergence in distribution of infinite sequences, we only require convergence in distribution jointly in all ξ ∈ X and time t below some cutoff T, for every finite T.

In particular, Statement 4 implies that feature learning, at least in our context, is incompatible with
Bayesian, distributional perspectives of neural network limits, such as the NNGP limit.
The characterization above then trivially implies the following dichotomy.
Corollary G.13 (Dynamical Dichotomy). For φ being tanh or σ-gelu for sufficiently small σ, a
nontrivial stable parametrization of an L-hidden layer MLP either admits feature learning or is in
kernel regime, but not both.
Remark G.14. The dependence on φ being tanh or σ-gelu for sufficiently small σ is only needed to
explicitly construct a training routine that leads to feature learning for r = 0. We expect this should
be true for generic φ, but this problem is surprisingly subtle, and is outside the scope of this paper.
Remark G.15. The equivalence between kernel regime and fixed feature kernel implies that linear
transfer learning is trivialized in any kernel regime limit. This is where the classifier layer of the
pretrained network is discarded and a new one (potentially outputting to a new output space) is trained
on top of the body of the pretrained network. But we can in fact say more: any nonlinear transfer
learning, where we replace the classifier layer with a neural network instead of a linear layer, is
trivialized as well. In addition, linear or nonlinear transfer learning has no effect even if we finetune
the entire network, instead of just the new classification network. The intuitive reason for this is
that, as discussed in Appendix B, the effect of ∆xL (ξ) on the output of the MLP is solely through
the interaction with W0L+1 . If W L+1 , W L+2 , . . . , are sampled anew, then this effect vanishes. We
formalize this below.

Theorem G.16 (Kernel Regime Limit Trivializes Transfer Learning). Suppose f is an L-hidden-layer
MLP55 in a stable kernel regime parametrization. Let A and B be two training routines.56
For any T, t ≥ 0,57 we define a network58 gT ;t as follows. Train f on A for T steps to obtain fT . Then
discard W L+1 in fT and extend the body of fT into an M -hidden-layer MLP g, where M ≥ L.59
Parametrize and initialize the new weights of g according to any stable abc-parametrization that
extends the parametrization of f . Train g on B for t steps to obtain gT ;t .
Then

1. (Finetuning the whole network) As n → ∞, for any ξ ∈ X and T, t ≥ 0,


a.s.
gT ;t (ξ) − g0;t (ξ) −−→ 0.

2. (Training only the classifier) The above is true even if we define gT ;t by only training the
new weights W L+1 , . . . , W M in g.

The Organization for the Proof of Our Main Results Above


Definition G.17. Below, we will abbreviate abc-parametrization of an L-layer MLP to just
parametrization. We will call parametrizations satisfying the conditions of Theorem G.6 pseudostable
while we try to prove Theorem G.6 (which, in this terminology, says stability and pseudostability are
equivalent).

We first characterize stability at initialization and prove Eq. (40) holds iff Eq. (41) (Appendix G.2).
Then, we describe the Tensor Program encoding the SGD of an MLP, assuming its parametrization
is pseudostable. The Master Theorem then naturally lets us calculate its infinite-width limit. We
then divide into the case of r > 0 and r = 0. In the former case, we show the infinite-width limit is
described by kernel gradient descent as in Eq. (45). In the latter case, we construct a training routine
where feature learning occurs and where the limit is not given by kernel gradient descent for any
kernel. Finally, in Appendix G.8, we combine all of our analyses to prove the main results in this
section.
55: the “pretrained network”
56: the “pretraining dataset” and the “finetuning dataset”
57: the “pretraining time” and “finetuning time”
58: the “finetuned network”
59: If M = L, then this is linear transfer learning where we replace just the last layer of f; otherwise, it’s nonlinear transfer learning.

G.2 Stability at Initialization
In this section, we characterize stability at initialization, which will form a foundation for our later
results.
Theorem G.18. Assume φ is not zero almost everywhere. For any parametrization, Eq. (40) holds iff
Eq. (41) holds, i.e. the following are equivalent

1. For every nonzero input ξ ∈ X ,


hl0 (ξ), xl0 (ξ) = Θξ (1), ∀l ∈ [L], and E f0 (ξ)2 = Oξ (1),
where the expectation is taken over the random initialization.
2. a1 + b1 = 0; al + bl = 1/2, ∀l ∈ [2, L]; aL+1 + bL+1 ≥ 1/2.

Proof. Fix an input ξ ≠ 0. Here, because we focus on initialization, we will suppress the time 0
subscript and ξ dependence of h^l, x^l to mean t = 0, applied to ξ.
Obviously, h^1 = W^1 ξ is a Gaussian vector with N(0, n^{−(a1+b1)} ‖ξ‖²) coordinates, so h^1 = Θξ(1)
iff a1 + b1 = 0. Assume a1 + b1 = 0. By the Law of Large Numbers, (1/n)‖x^1‖² →a.s. E φ(Z^{h^1})², where
Z^{h^1} = N(0, ‖ξ‖²). Since φ is not almost everywhere zero and ξ ≠ 0, this expectation is nonzero, so
that x^1 = Θξ(1).
We construct the following Tensor Program: the lone initial vector is h^1, the initial matrices are
Ŵ^l, 2 ≤ l ≤ L, and the initial scalars are θl := n^{1/2−(al+bl)}. We sample h^1_α ∼ N(0, ‖ξ‖²) and
Ŵ^l_{αβ} ∼ N(0, 1/n). Mathematically, we will represent W^l = θl Ŵ^l. The program is then given by
x^l = φ(h^l), ∀l ∈ [L],   ĥ^l = Ŵ^l x^{l−1},   h^l = θl ĥ^l, ∀l ∈ [2, L],
where we used Nonlin, MatMul, and Nonlin (with parameter θl).
l l l−1
Suppose al + bl = 1/2 (i.e. θl = 1) for all 2 ≤ l ≤ L. Then, Z h = Z ĥ = N (0, E φ(Z h )2 )
l
for each l ≤ L. Because φ is not everywhere zero, this inductively implies E(Z h )2 > 0 (and so
l a.s. l a.s.
also E(Z x )2 > 0) for all l ≤ L. By the Master Theorem, n1 khl k2 −−→ E(Z h )2 and n1 kxl k2 −−→
l
E(Z x )2 so this implies hl , xl = Θξ (1) for all l ≤ L as desired.
Conversely, suppose m is the smallest l ≥ 2 such that al + bl 6= 1/2. Then by the above reasoning,
ĥm = Θξ (1) so hm = Θξ (n1/2−(al +bl ) ) is either blowing up to ∞ or shrinking to 0 with n. This
shows that hl , xl = Θξ (1) for all l ≤ L iff a1 + b1 = 0 and al + bl = 1/2 for all 2 ≤ l ≤ L.
Finally, if a1 + b1 = 0 and al + bl = 1/2 for all 2 ≤ l ≤ L, then we see E f0 (ξ)2 =
L
(n1/2−(aL+1 +bL+1 ) )2 E kZ x k2 /n. For large n, this is Θξ ((n1/2−(aL+1 +bL+1 ) )2 ) and is Oξ (1) iff
aL+1 + bL+1 ≥ 1/2.
Definition G.19. We say a parametrization is initialization-stable if it satisfies Eq. (40) (or equiva-
lently, Eq. (41)).

G.3 Program Setup


In the next section, we construct the Tensor Program that encodes the training of an L-hidden layer
MLP under an abc-parametrization. Here we first describe the initial matrices, vectors, and scalars of
the program, along with necessary notations.
We first remark on a simplification we will make to streamline the proof.

The Size of W0^{L+1} vs ∆Wt^{L+1}   By construction, W0^{L+1} = Θ(n^{−(aL+1+bL+1)}). If xt^L(ξ) = Θ(1)
as in a stable parametrization, then ∆Wt^{L+1} = Θ(n^{−(2aL+1+c)}). Therefore, if aL+1 + bL+1 ≤
2aL+1 + c, then W0^{L+1} is at least as large as ∆Wt^{L+1}, so that Wt^{L+1} will stay the same order (in
terms of n) for all t. If the reverse inequality is true, then W0^{L+1} is smaller than Wt^{L+1} for t ≥ 1.
This in particular implies that the gradients at time 0 are smaller than the gradients at subsequent times.
For example, we can take aL+1 + bL+1 → ∞ while fixing 2aL+1 + c, in which case W0^{L+1} = 0 and
the weight gradients at initialization are all 0 except for that of W^{L+1}. One can thus think of this as a
“lag” in the training dynamics for 1 step.
Assumption G.20. For clarity of the proof, we will assume aL+1 + bL+1 ≤ 2aL+1 + c, i.e. WtL+1
stays the same order for all t. The case of aL+1 + bL+1 > 2aL+1 + c, corresponding to a 1-step “lag”
as explained above, can be dealt with similarly. We will remark whenever this requires some subtlety.
For the construction of the program and the application of the Master Theorem, we will also assume
the following for the rest of this paper.
Assumption G.21. φ0 is pseudo-Lipschitz and not almost everywhere zero.

Initial Matrices, Vectors, Scalars We will assume the parametrization is initialization-stable. For
ease of presentation, we also assume the input dimension d = 1.

1. Initial matrices: W02 , . . . , W0L , sampled like (W0l )αβ ∼ N (0, 1/n).
2. Initial vectors: input layer matrix W0^1 ∈ Rn×1 and normalized output layer matrix Ŵ0^{L+1} :=
W0^{L+1} n^{aL+1+bL+1} ∈ R1×n, sampled like (W0^1)α, (Ŵ0^{L+1})α ∼ N(0, 1).
3. Initial scalars: We define the following scalars (where we explain the intuition in parentheses).
The reader can skip this part on a first read but come back when referred to.
(a) (n times the scale of coordinates of ∆Wt^l) For l ≥ 2, define
θ_{W^l} := n^{−(aL+1+bL+1+c−1+2al)}
(b) (scale of coordinates of ∆Wt^1 and ∆ht^1) Define
θ1 = θ_{W^1} := n^{−(aL+1+bL+1+c+2a1)}
(c) (scale of coordinates of ∆Wt^{L+1})
θ_{L+1} = θ_{W^{L+1}} := n^{−2aL+1−c}
(d) (scale of ∆ht^l and ∆xt^l) For l ∈ [L], define
θ_{h^l} = θ_{x^l} = θl := max_{m≤l} θ_{W^m} = max(θ_{W^l}, θ_{l−1})   (46)
       = n^{−(aL+1+bL+1+c−1+min_{m=1}^{l}[2am+I(m=1)])}
Note that θL = n^{−r} with r defined in Definition G.5.
(e) (scale of Wt^{L+1})
θf := n^{−(aL+1+bL+1)}
(f) (convenience scalars)
θ_{x^{l−1}/h^l} = θ_{x^{l−1}}/θ_{h^l}
θ_{W^l/h^l} = θ_{W^l}/θ_{h^l}
θ_{W^l x^{l−1}/h^l} = θ_{W^l} θ_{x^{l−1}}/θ_{h^l}
θ_{L+1/f} = θ_{L+1}/θf
θ′_{L+1} = n θ_{L+1} = n^{1−2aL+1−c}
θ′_{Lf} = n θL θf = n^{1−(r+aL+1+bL+1)}
(g) Depending on the value of aL+1 + bL+1, we will also construct the values of f at
initialization as initial scalars. See Appendix G.4.1 for an explanation.
By our assumption that aL+1 + bL+1 ≤ 2aL+1 + c, the pseudostability inequalities of Theorem G.6
imply all of these θs either converge to 0 or stay constant at 1. This means that, assuming appropriate
regularity conditions on the nonlinearities and rank stability, we can apply the Master Theorem (if θ
blows up to ∞ then we can’t do that).

Notations We use := to more clearly denote assignment happening in the program, as opposed to
mathematical equality. To clearly demonstrate the application of Nonlin, we will also freely introduce
function symbols Ψ to put things into Nonlin form.

Preview of Names for Vectors   In the program, for each z ∈ {x^l, h^l}_l, we will construct vectors
δzt(ξ) to mathematically represent θz^{−1}(zt(ξ) − z_{t−1}(ξ)) (intuition: the change in z scaled to have Θ(1)
coordinates). Similarly, for w ∈ {W^{L+1}, W^1}, we will construct δwt to mathematically represent
θw^{−1}(wt − w_{t−1}) (intuition: the change in w scaled to have Θ(1) coordinates). Then, mathematically,
zt(ξ) = z_{t−1}(ξ) + θz δzt(ξ),   wt = w_{t−1} + θw δwt.
We will also construct dz to mathematically represent θf^{−1}∇z f (intuition: the gradient ∇z f scaled to
have Θ(1) coordinates). For weight changes, we have the following identity
Wt^l − W_{t−1}^l = −η n^{−c} χ_{t−1} n^{−2al} θf dh_{t−1}^l x_{t−1}^{l−1⊤} = −η χ_{t−1} θ_{W^l} (1/n) dh_{t−1}^l x_{t−1}^{l−1⊤},   ∀l ∈ [2, L],   (47)
and for l = 1,
Wt^l − W_{t−1}^l = −η n^{−c} χ_{t−1} n^{−2al} θf dh_{t−1}^l ξ_{t−1}^⊤ = −η χ_{t−1} θ_{W^l} dh_{t−1}^l ξ_{t−1}^⊤.   (48)

G.4 Program Construction


Here we construct the Tensor Program encoding the SGD of an MLP. We separately describe the first
forward and backward passes followed by the later forward and backward passes.

G.4.1 First Forward Pass


For every ξ ∈ X , we compute h10 (ξ) := W01 ξ ∈ Rn via Nonlin (as Ψ(W01 ; ξ), where Ψ is multiplica-
tion by ξ), and we construct the following vectors via Nonlin and MatMul

x0^l(ξ) := φ(h0^l(ξ)) ∈ Rn,   h0^{l+1}(ξ) := W0^{l+1} x0^l(ξ) ∈ Rn,   for l = 1, . . . , L − 1,   (49)

Function Output   The first output is f0(ξ) = W0^{L+1} x0^L(ξ), but we will define f0(ξ) in the program
slightly differently.

Case when aL+1 + bL+1 > 1/2   Then f0(ξ) →a.s. 0 for all ξ ∈ X. In the program, we will
construct f0(ξ) as an initial scalar mathematically defined by W0^{L+1} x0^L(ξ).60,61

Case when aL+1 + bL+1 = 1/2   If aL+1 + bL+1 = 1/2, then f0(ξ) converges to a nontrivial
Gaussian via CLT [43], so we will condition on f0(ξ) for all ξ ∈ X. Given values g(ξ) ∈ R for all
ξ ∈ X, let E be the event that f0(ξ) = (1/√n) Ŵ0^{L+1} x0^L(ξ) equals g(ξ) for all ξ ∈ X. The distribution
of Ŵ0^{L+1} conditioned on E is given by
Ŵ0^{L+1} |_E =d √n X⁺ g + Π W̃0^{L+1},
where W̃0^{L+1} is an iid copy of Ŵ0^{L+1}, g ∈ R^X is the vector of {g(ξ) : ξ ∈ X}, X ∈ R^{X×n} has
x0^L(ξ) as rows, and Π is the orthogonal projection onto the orthogonal complement of the space
spanned by {x0^L(ξ) : ξ ∈ X}. Here X⁺ denotes the pseudo-inverse of X.
By standard formulas for the pseudo-inverse and orthogonal projection, we can write X⁺ =
(1/n) X⊤ (XX⊤/n)⁺ and Π = I − (1/n) X⊤ (XX⊤/n)⁺ X.
Let Σ := XX⊤/n and γ := X W̃0^{L+1}/n. Then Π W̃0^{L+1} = W̃0^{L+1} − X⊤ Σ⁺ γ, and √n X⁺ g =
(1/√n) X⊤ Σ⁺ g.
By the Master Theorem, γ →a.s. 0 because W̃0^{L+1} is independent from X, and Σ →a.s. Σ̊ for some
PSD matrix Σ̊. At this point in the program, all scalars we used (like ξ) are constant with n and
can be absorbed into nonlinearities. By the rank stability property of any program without scalars
[46], the rank of Σ is fixed for large enough n, almost surely, so Σ⁺ →a.s. Σ̊⁺ by the continuity of the
pseudo-inverse on fixed-rank matrices.
60: It is completely OK to define an initial scalar using randomness from other parts of the program, as long as this scalar converges almost surely to a deterministic limit.
61: We cannot define it using a Moment instruction because, intuitively, the mechanism of this convergence is through CLT, not Law of Large Numbers.
We will now replace $\widehat W_0^{L+1}$ in the program with
$$\widehat W_E^{L+1} \stackrel{\mathrm{def}}{=} X^\top\Big(\Sigma^+\frac{g}{\sqrt n}\Big) + \widetilde W_0^{L+1} - X^\top\Sigma^+\gamma$$
constructed using Nonlin, where $\Sigma^+\frac{g}{\sqrt n}$ and $\Sigma^+\gamma$ are finite dimensional and formally considered (collections of) scalars involved as coefficients for linear combinations of rows of $X$. Since $\Sigma^+\frac{g}{\sqrt n}, \Sigma^+\gamma \xrightarrow{\mathrm{a.s.}} 0$, we have $Z^{\widehat W_E^{L+1}} = Z^{\widetilde W_0^{L+1}}$. Intuitively, this means that, even after conditioning on $f_0 = g$, the conditional distribution of $\widetilde W_0^{L+1}$ is practically the same as the original distribution.
We can then proceed exactly as in the case when $a_{L+1} + b_{L+1} > 1/2$, with $\widehat W_E^{L+1}$ taking the role of $\widetilde W_0^{L+1}$. The program then encodes the evolution of $f$ conditioned on $f_0(\xi) = g(\xi)$, $\forall\xi\in\mathcal{X}$.$^{62}$
Assumption G.22. For the above reason, we will assume aL+1 + bL+1 > 1/2, and remark whenever
the case aL+1 + bL+1 = 1/2 involves subtleties.
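As a sanity check on the conditioning identity above, here is a minimal numpy sketch (not part of the construction; the width $n$, the number of inputs $k$, and the variable names are illustrative assumptions). It verifies that $\sqrt n\, X^+ g + \Pi\widetilde W$ satisfies the conditioning event exactly and differs from $\widetilde W$ by only $O(1/\sqrt n)$ per coordinate, which is the content of $Z^{\widehat W_E^{L+1}} = Z^{\widetilde W_0^{L+1}}$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 4096, 3                              # width and number of inputs |X| (illustrative)
X = rng.standard_normal((k, n))             # rows play the role of x_0^L(xi)
g = rng.standard_normal(k)                  # target values g(xi)

Sigma = X @ X.T / n                         # Sigma = X X^T / n  (invertible a.s. here)
Xplus = X.T @ np.linalg.inv(Sigma) / n      # X^+ = (1/n) X^T (X X^T / n)^+
w_tilde = rng.standard_normal(n)            # iid copy of the (rescaled) last-layer weights
Pi_w = w_tilde - X.T @ np.linalg.solve(Sigma, X @ w_tilde) / n   # Pi w_tilde without forming Pi
w_E = np.sqrt(n) * Xplus @ g + Pi_w         # conditioned sample

print(np.allclose(X @ w_E / np.sqrt(n), g))   # the conditioning event holds exactly
print(np.abs(w_E - w_tilde).mean())           # O(1/sqrt(n)) per coordinate
```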

G.4.2 First Backward Pass


Next, we write the backward pass

$$dx_0^L(\xi) := \widehat W_0^{L+1}, \qquad dh_0^l(\xi) := dx_0^l(\xi)\odot\phi'(h_0^l(\xi)), \qquad dx_0^{l-1}(\xi) := W_0^{l\top}dh_0^l(\xi)$$
where, recall, $dz$ mathematically equals $\theta_f^{-1}\nabla_z f$.


For $\xi = \xi_0$ and its label $y_0$, we define the first loss derivative as
$$\chi_0 := \mathcal{L}'(f_0(\xi_0), y_0) \xrightarrow{\mathrm{a.s.}} \mathring\chi_0 = \mathcal{L}'(0, y_0)$$
where the convergence holds because $\mathcal{L}'$ is continuous by assumption.
We also define
$$\delta W_1^{L+1} := -\eta\chi_0 x_0^L(\xi_0)$$
to represent the (normalized) change in $W^{L+1}$ due to the first gradient step.

G.4.3 tth Forward Pass, t ≥ 1


Overview We iteratively define $\delta z_t(\xi)$ to mathematically represent $\theta_z^{-1}(z_t(\xi) - z_{t-1}(\xi))$, for $z \in \{x^l, h^l\}_l$. Then we eventually set
$$z_t(\xi) := z_0(\xi) + \theta_z\delta z_1(\xi) + \cdots + \theta_z\delta z_t(\xi).$$
Likewise, we will define $\delta W_t^{L+1}$ so that $W_t^{L+1} = \theta_f\widehat W_0^{L+1} + \theta_{L+1}(\delta W_1^{L+1} + \cdots + \delta W_t^{L+1})$. In the program, we will not directly use $W_t^{L+1}$ but instead use
$$\widehat W_t^{L+1} := \widehat W_0^{L+1} + \theta_{L+1/f}(\delta W_1^{L+1} + \cdots + \delta W_t^{L+1}) \qquad (50)$$
where $\theta_{L+1/f} = \theta_{L+1}/\theta_f$. Mathematically, $\widehat W_t^{L+1} = \theta_f^{-1}W_t^{L+1}$.

Recall we shorthand zt = zt (ξt ) for all z ∈ {xl , hl , dxl , dhl }l ∪ {f, χ}.

$^{62}$ Formally, we can also have $\{g(\xi): \xi \in \mathcal{X}\}$ as initial scalars, but since they are fixed with $n$, they can be absorbed into the Nonlin that defines $\widehat W_E^{L+1}$.

The Construction of (Pre)Activations We start with $h = h^1$: By Eq. (48), we have
$$\delta h_t(\xi) := -\eta\chi_{t-1}\,\xi_{t-1}^\top\xi\, dh_{t-1} = \Psi(dh_{t-1};\ \xi_{t-1}^\top\xi,\ \eta\chi_{t-1}).$$
(Notationally, recall we freely introduce function symbols $\Psi$ to clarify the way we apply Nonlin.)
For higher layers, if $h = h^l$, $x = x^{l-1}$, and $W = W^l$, then $h = Wx$. By Eq. (47), we have, mathematically,
$$\theta_h\delta h_t(\xi) = \theta_x W_{t-1}\delta x_t(\xi) + (W_t - W_{t-1})x_t(\xi) = \theta_x\Big(W_0\delta x_t(\xi) + \sum_{s=1}^{t-1}(W_s - W_{s-1})\delta x_t(\xi)\Big) + (W_t - W_{t-1})x_t(\xi)$$
$$= \theta_x\Big(W_0\delta x_t(\xi) - \eta\theta_W\sum_{s=1}^{t-1}\chi_{s-1}\frac{x_{s-1}^\top\delta x_t(\xi)}{n}\,dh_{s-1}\Big) - \eta\chi_{t-1}\theta_W\frac{x_{t-1}^\top x_t(\xi)}{n}\,dh_{t-1}.$$
Recall $\theta_{x/h} = \theta_h^{-1}\theta_x$, $\theta_{W/h} = \theta_h^{-1}\theta_W$, $\theta_{Wx/h} = \theta_h^{-1}\theta_W\theta_x$. With $c_s$ denoting $\frac{x_s^\top\delta x_t(\xi)}{n}$, we construct
$$\delta h_t(\xi) := \theta_{x/h}W_0\delta x_t(\xi) - \eta\theta_{Wx/h}\sum_{s=1}^{t-1}\chi_{s-1}c_{s-1}\,dh_{s-1} - \eta\chi_{t-1}\theta_{W/h}c_{t-1}\,dh_{t-1} = \Psi\big(W_0\delta x_t(\xi), dh_0, \dots, dh_{t-1};\ \eta, \theta_{x/h}, \theta_{Wx/h}, \theta_{W/h}, \{c_s, \chi_s\}_{s=0}^{t-1}\big).$$
If $x = x^l$, $h = h^l$, then $x = \phi(h)$, and (using $\theta_x = \theta_h$, Eq. (46)),
$$\delta x_t(\xi) := \theta_h^{-1}\big(\phi(h_{t-1}(\xi) + \theta_h\delta h_t(\xi)) - \phi(h_{t-1}(\xi))\big) = \Psi(h_{t-1}(\xi), \delta h_t(\xi); \theta_h) \qquad (51)$$
where $\Psi$ is precisely the difference quotient for the function $\phi$.$^{63}$

The Function Outputs We do not construct $f_t(\xi)$ directly, but rather through scalars $\delta f_t(\xi) = f_t(\xi) - f_{t-1}(\xi)$, so that
$$f_t(\xi) := f_0(\xi) + \delta f_1(\xi) + \cdots + \delta f_t(\xi).$$
Mathematically, $\delta f_t(\xi) = \theta_{L+1}\delta W_t^{L+1}x_t^L(\xi) + W_{t-1}^{L+1}\theta_L\delta x_t^L(\xi)$, but we shall write it slightly differently in the program:
$$\delta f_t(\xi) := \theta'_{L+1}\,\frac{\delta W_t^{L+1}x_t^L(\xi)}{n} + \theta'_{Lf}\,\frac{\widehat W_{t-1}^{L+1}\delta x_t^L(\xi)}{n}$$
where $\theta'_{L+1} = n\theta_{L+1}$, $\theta'_{Lf} = n\theta_L\theta_f$ and $\widehat W_{t-1}^{L+1}$ is constructed in Eq. (50).

G.4.4 tth Backward Pass, t ≥ 1


In the last layer, we construct
$$dx_t^L(\xi) := \widehat W_t^{L+1}.$$
For each $l = L, \dots, 1$ for $dh^l$ and $l = L, \dots, 2$ for $dx^{l-1}$, we also calculate
$$dh_t^l(\xi) := dx_t^l(\xi)\odot\phi'(h_t^l(\xi))$$
$$dx_t^{l-1}(\xi) := W_0^{l\top}dh_t^l(\xi) - \eta\theta_{W^l}\sum_{s=0}^{t-1}\chi_s c_s x_s^{l-1} = \Psi\big(W_0^{l\top}dh_t^l(\xi), x_0^{l-1}, \dots, x_{t-1}^{l-1};\ \eta\theta_{W^l}, \{\chi_s, c_s\}_{s=0}^{t-1}\big)$$
where $c_s = \frac{dh_s^{l\top}dh_t^l(\xi)}{n}$. For $\xi = \xi_t$ and its label $y_t$, we define$^{64}$
$$\chi_t := \mathcal{L}'(f_t(\xi_t), y_t).$$
Finally, we compute the (normalized) change in $W^{L+1}$ after this SGD update:
$$\delta W_{t+1}^{L+1} := -\eta\chi_t x_t^L(\xi_t).$$
$^{63}$ The pseudo-Lipschitzness of $\phi'$ assumed in Assumption G.21 implies that $\Psi$ here is pseudo-Lipschitz, so that we can ultimately apply our Master Theorem.
$^{64}$ Here we use Moment with the function $\phi(; f_t(\xi_t)) = \mathcal{L}'(f_t(\xi_t), y_t)$ with no input and one parameter (we absorb $y_t$ into $\phi$ since it does not change with $n$). The continuity of $\mathcal{L}'$ in its first argument satisfies Assumption E.4(1), so the Master Theorem can apply.

G.5 The Infinite-Width Limit
In this section, we describe the Z random variables (Definition 7.3) corresponding to the vectors
of the program constructed above. According to the Master Theorem, each such vector z will have
roughly iid coordinates distributed like Z z in the large n limit.
Let θ̊• denote the limit of any θ• in Appendix G.3. If pseudostability holds, then θ̊• is either 0 or 1,
as one can easily verify. We can construct the Z random variables for each vector in the program, as
follows.
1. For the first forward and backward passes, we have
$$Z^{h_0^1(\xi)} = \xi Z^{W_0^1},\qquad Z^{x_0^l(\xi)} = \phi(Z^{h_0^l(\xi)}),\qquad Z^{h_0^{l+1}(\xi)} = Z^{W_0^{l+1}x_0^l(\xi)},$$
$$Z^{dx_0^L(\xi)} = Z^{\widehat W_0^{L+1}},\qquad Z^{dh_0^l(\xi)} = Z^{dx_0^l(\xi)}\phi'(Z^{h_0^l(\xi)}),\qquad Z^{dx_0^{l-1}(\xi)} = Z^{W_0^{l\top}dh_0^l(\xi)}.$$
2. For $z \in \{x^l, h^l\}_l$, we have
$$Z^{z_t(\xi)} = Z^{z_0(\xi)} + \mathring\theta_z Z^{\delta z_1(\xi)} + \cdots + \mathring\theta_z Z^{\delta z_t(\xi)} \qquad (52)$$
3. For $l \in [L]$, $x = x^l$, $h = h^l$, we have $Z^{\delta x_t(\xi)} = \Psi(Z^{h_{t-1}(\xi)}, Z^{\delta h_t(\xi)}; \mathring\theta_h)$ where $\Psi$ is as in Eq. (51). If $\mathring\theta_h = 0$ (e.g. if $r > 0$), then
$$Z^{\delta x_t(\xi)} = \phi'(Z^{h_{t-1}(\xi)})Z^{\delta h_t(\xi)}. \qquad (53)$$
Otherwise, $\mathring\theta_h = 1$, and
$$Z^{\delta x_t(\xi)} = \phi(Z^{h_t(\xi)}) - \phi(Z^{h_{t-1}(\xi)}). \qquad (54)$$
4. For $h = h^1$, we have
$$Z^{\delta h_t(\xi)} = -\eta\mathring\chi_{t-1}\,\xi_{t-1}^\top\xi\, Z^{dh_{t-1}}.$$
5. For $l \ge 2$, $h = h^l$, $x = x^{l-1}$, $W = W^l$, we have
$$Z^{\delta h_t(\xi)} = \mathring\theta_{x/h}Z^{W_0\delta x_t(\xi)} - \eta\mathring\theta_{Wx/h}\sum_{s=0}^{t-2}\mathring\chi_s Z^{dh_s}\,\mathbb{E}\,Z^{x_s}Z^{x_t(\xi)} - \eta\mathring\chi_{t-1}\mathring\theta_{W/h}Z^{dh_{t-1}}\,\mathbb{E}\,Z^{x_{t-1}}Z^{x_t(\xi)} \qquad (55)$$
where at least one of $\mathring\theta_{x/h}$ and $\mathring\theta_{W/h}$ equals 1. As usual, here we have the ZHat-ZDot decomposition of $Z^{W_0\delta x_t(\xi)}$:
$$Z^{W_0\delta x_t(\xi)} = \hat Z^{W_0\delta x_t(\xi)} + \dot Z^{W_0\delta x_t(\xi)} = \hat Z^{W_0\delta x_t(\xi)} + \sum_{s=0}^{t-1}Z^{dh_s}\,\mathbb{E}\,\frac{\partial Z^{\delta x_t(\xi)}}{\partial\hat Z^{W_0^\top dh_s}}.$$
6. For the last layer weight,
$$Z^{\delta W_t^{L+1}} = -\eta\mathring\chi_{t-1}Z^{x_{t-1}^L} \qquad (56)$$
and
$$Z^{\widehat W_t^{L+1}} = Z^{\widehat W_0^{L+1}} + \mathring\theta_{L+1/f}\big(Z^{\delta W_1^{L+1}} + \cdots + Z^{\delta W_t^{L+1}}\big) \qquad (57)$$
7. The output deltas have limits
$$\delta\mathring f_t(\xi) = \mathring\theta'_{L+1}\,\mathbb{E}\,Z^{\delta W_t^{L+1}}Z^{x_t^L(\xi)} + \mathring\theta'_{Lf}\,\mathbb{E}\,Z^{\widehat W_{t-1}^{L+1}}Z^{\delta x_t^L(\xi)} \qquad (58)$$
and
$$\mathring f_t(\xi) = \delta\mathring f_1(\xi) + \cdots + \delta\mathring f_t(\xi).$$
8. For gradients:
$$Z^{dx_t^L(\xi)} = Z^{\widehat W_t^{L+1}},\qquad Z^{dh_t^l(\xi)} = Z^{dx_t^l(\xi)}\phi'(Z^{h_t^l(\xi)}),\qquad Z^{dx_t^{l-1}(\xi)} = Z^{W_0^{l\top}dh_t^l(\xi)} - \eta\mathring\theta_{W^l}\sum_{s=0}^{t-1}\mathring\chi_s Z^{x_s^{l-1}}\,\mathbb{E}\,Z^{dh_s^l}Z^{dh_t^l(\xi)}.$$

9. Loss derivative
$$\mathring\chi_t = \mathcal{L}'(\mathring f_t(\xi_t), y_t).$$
The following fact follows from the results of [45] (or can be verified by straightforward calculation) and will be useful for us.
Proposition G.23. $\dot Z^{dx_0^l(\xi)} = 0$ and $Z^{dx_0^l(\xi)} = \hat Z^{dx_0^l(\xi)}$ for any $\xi \in \mathcal{X}$.
If the parametrization is pseudostable, then all the $\theta_\bullet$ converge to 0 or 1, so Setup 7.2 is satisfied. Therefore, the Master Theorem applies and says that, for any collection of vectors $v^1, \dots, v^k$ such that $Z^{v^i}$ is defined above, we have
$$\frac{1}{n}\sum_{\alpha=1}^n\psi(v_\alpha^1, \dots, v_\alpha^k) \xrightarrow{\mathrm{a.s.}} \mathbb{E}\,\psi(Z^{v^1}, \dots, Z^{v^k})$$
for any pseudo-Lipschitz $\psi$. In addition,$^{65}$
$$\delta f_t(\xi) \xrightarrow{\mathrm{a.s.}} \delta\mathring f_t(\xi),\quad f_t(\xi) \xrightarrow{\mathrm{a.s.}} \mathring f_t(\xi),\quad \chi_t \xrightarrow{\mathrm{a.s.}} \mathring\chi_t,\quad \forall\xi\in\mathcal{X},\ t\ge 1.$$
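As a quick illustration of this kind of almost-sure convergence (a sketch, not from the paper; the width, input, test function, and quadrature order are arbitrary choices), take the first-forward-pass vector $h_0^1(\xi) = W_0^1\xi$ from item 1 and $\psi = \phi^2$ with $\phi = \tanh$:

```python
import numpy as np

rng = np.random.default_rng(0)
n, xi = 2 ** 16, 1.3                       # width and scalar input (illustrative)
W1 = rng.standard_normal((n, 1))           # rescaled first-layer entries, ~ N(0, 1)
h = (W1 * xi).ravel()                      # h_0^1(xi) = W_0^1 xi, one coordinate per row

psi = lambda v: np.tanh(v) ** 2            # a pseudo-Lipschitz test function
empirical = psi(h).mean()                  # (1/n) sum_alpha psi(h_alpha)

# E psi(Z^{h_0^1(xi)}) with Z^{h_0^1(xi)} = xi * Z^{W_0^1} ~ N(0, xi^2), by quadrature:
z, w = np.polynomial.hermite_e.hermegauss(80)    # nodes/weights for the N(0,1) weight
limit = (w * psi(xi * z)).sum() / w.sum()

print(empirical, limit)                    # agree up to O(1/sqrt(n)) fluctuations
```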
We now describe some immediate consequences of this.

G.5.1 Some Immediate Results


Proposition G.24. A pseudostable parametrization is trivial if $2a_{L+1} + c > 1$ and $a_{L+1} + b_{L+1} + r > 1$.
Proof. In this case, $\theta'_{L+1}, \theta'_{Lf}, \theta'_{L,L+1} \to 0$, and $\delta\mathring f_t(\xi) = 0$ for all $t$ and $\xi \in \mathcal{X}$ by Eq. (58).
Proposition G.25. A pseudostable parametrization is stable.
Proof. For a pseudostable parametrization, all of the $\theta$s converge to 1 or 0, and all of the $Z^{\delta h_t^l(\xi)}, Z^{\delta x_t^l(\xi)}$ have well-defined (finite) limits, which implies $\Delta h_t^l(\xi), \Delta x_t^l(\xi) = O^*(1)$, $\forall l \in [L]$, and $f_t(\xi) = O^*(1)$.
Proposition G.26. Consider a pseudostable parametrization. If $r > 0$, then it fixes all (pre)features and all (pre)feature kernels. In addition, $\Delta W_t^{L+1}\Delta x_t^L(\xi) \xrightarrow{\mathrm{a.s.}} 0$.
Proof. If $r > 0$, then $\theta_l \to 0$ for all $l \in [L]$, so that for all $z \in \{x^l, h^l\}_l$, $\Delta z_t(\xi) = z_t(\xi) - z_0(\xi) = \theta_z\delta z_1(\xi) + \cdots + \theta_z\delta z_t(\xi)$ has $\|\Delta z_t(\xi)\|^2/n \xrightarrow{\mathrm{a.s.}} 0$ by Eq. (52) and the Master Theorem, i.e. all features are fixed. Similarly, for any pair $\xi, \bar\xi \in \mathcal{X}$, $z_t(\xi)^\top z_t(\bar\xi)/n - z_0(\xi)^\top z_0(\bar\xi)/n \xrightarrow{\mathrm{a.s.}} 0$, so all feature kernels are fixed. Finally, $r > 0$ implies $\theta'_{L,L+1} \to 0$, which means $\Delta W_t^{L+1}\Delta x_t^L(\xi) \xrightarrow{\mathrm{a.s.}} 0$ by the Master Theorem.
Proposition G.27. An initialization-stable parametrization with $r < 0$ is not stable.
Proof. If $r < 0$, then there is some $\ell \in [L]$ such that $\theta_L \ge \cdots \ge \theta_\ell > 1 \ge \theta_{\ell-1} \ge \cdots \ge \theta_1$. For $h = h^\ell$, $x = x^{\ell-1}$, $W = W^\ell$, we would have $\theta_{x/h} = \theta_{\ell-1}/\theta_\ell \to 0$, $\theta_{W/h} = 1$, and $\theta_{Wx/h} = \theta_{W/h}\theta_{\ell-1} \to 0$. The Tensor Program up to the definition of $\delta h_1(\xi_0)$ satisfies the conditions of the Master Theorem. Therefore, $\|\delta h_1(\xi_0)\|^2/n \xrightarrow{\mathrm{a.s.}} \mathbb{E}(Z^{\delta h_1(\xi_0)})^2 = \mathbb{E}\big(\eta\mathring\chi_0 Z^{dh_0}\,\mathbb{E}\,Z^{x_0}Z^{x_1(\xi_0)}\big)^2$. If $\xi_0 \ne 0$, then $\mathbb{E}(Z^{dh_0})^2 > 0$. If $\eta$ is in addition sufficiently small but nonzero, then $\mathbb{E}\,Z^{x_0}Z^{x_1(\xi_0)} \approx \mathbb{E}(Z^{x_0})^2 > 0$. Therefore, under these conditions, and with a training sequence that has $\mathring\chi_0 \ne 0$, we have $\mathbb{E}\big(\eta\mathring\chi_0 Z^{dh_0}\,\mathbb{E}\,Z^{x_0}Z^{x_1(\xi_0)}\big)^2 > 0$, so that $\delta h_1(\xi_0) = \Theta_{\xi_0}(1)$. However, $\Delta h_1(\xi_0) = \theta_h\delta h_1(\xi_0)$ and $\theta_h = \theta_\ell \to \infty$. Hence $\Delta h_1(\xi_0) \ne O_{\xi_0}(1)$, as desired.

G.6 r > 0 Implies Kernel Regime


In this section, we analyze the case when r > 0. Our main result is deriving the corresponding
infinite-width kernel gradient descent dynamics (Theorem G.31). Nothing here depends on φ being
tanh or σ-gelu.
$^{65}$ Again, if $a_{L+1} + b_{L+1} = 1/2$, remember we are conditioning on $f_0(\xi)$, $\xi \in \mathcal{X}$.

Preliminary Derivations If $r > 0$, then $\mathring\theta_l = \mathring\theta_{W^l} = 0$ for all $l \in [L]$, so that we have
$$Z^{h_t^l(\xi)} = Z^{h_0^l(\xi)},\quad Z^{x_t^l(\xi)} = Z^{x_0^l(\xi)},\quad Z^{dh_t^l(\xi)} = Z^{dh_0^l(\xi)},\quad Z^{dx_t^l(\xi)} = Z^{dx_0^l(\xi)},\quad Z^{\widehat W_t^{L+1}} = Z^{\widehat W_0^{L+1}}$$
for all $t$ and $\xi \in \mathcal{X}$. Let $\ell \in [L]$ be the unique $\ell$ such that $1 = \theta_L/\theta_L = \cdots = \theta_\ell/\theta_L > \theta_{\ell-1}/\theta_L \ge \cdots \ge \theta_1/\theta_L$. Then for $l \ge \ell + 1$ and shorthand $h = h^l$, $x = x^{l-1}$, $W = W^l$, we have $\mathring\theta_{x/h} = 1$, $\mathring\theta_{Wx/h} = 0$ and, by Eq. (55),
$$Z^{\delta h_t(\xi)} = Z^{W_0\delta x_t(\xi)} - \eta\mathring\chi_{t-1}\mathring\theta_{W/h}Z^{dh_{t-1}}\,\mathbb{E}\,Z^{x_{t-1}}Z^{x_t(\xi)} = Z^{W_0\delta x_t(\xi)} - \eta\mathring\chi_{t-1}\mathring\theta_{W/h}Z^{dh_0(\xi_{t-1})}\,\mathbb{E}\,Z^{x_0(\xi_{t-1})}Z^{x_0(\xi)} \qquad (59)$$
where $\mathring\theta_{W/h}$ can be either 0 or 1. For $l = \ell$, because $\theta_h = \theta_l = \max_{m\le l}\theta_{W^m} = \max(\theta_{W^l}, \theta_{l-1}) = \max(\theta_{W^l}, \theta_x)$, so that $\mathring\theta_{x/h} = \mathring\theta_{Wx/h} = 0$ and $\mathring\theta_{W/h} = 1$, we also have
$$Z^{\delta h_t(\xi)} = -\eta\mathring\chi_{t-1}Z^{dh_{t-1}}\,\mathbb{E}\,Z^{x_{t-1}}Z^{x_t(\xi)} = -\eta\mathring\chi_{t-1}Z^{dh_0(\xi_{t-1})}\,\mathbb{E}\,Z^{x_0(\xi_{t-1})}Z^{x_0(\xi)}. \qquad (60)$$
Finally, for all $l \in [L]$, we have, by Eq. (53),
$$Z^{\delta x_t(\xi)} = \phi'(Z^{h_{t-1}(\xi)})Z^{\delta h_t(\xi)} = \phi'(Z^{h_0(\xi)})Z^{\delta h_t(\xi)}.$$
Definition G.28. For $1 \le m \le l$ and $\xi, \zeta \in \mathcal{X}$, define
$$\Sigma^{ml}(\xi,\zeta) \stackrel{\mathrm{def}}{=} \mathbb{E}\,Z^{x_0^m(\xi)}Z^{x_0^m(\zeta)} \times \mathbb{E}\,\phi'(Z^{h_0^{m+1}(\xi)})\phi'(Z^{h_0^{m+1}(\zeta)}) \times \cdots \times \mathbb{E}\,\phi'(Z^{h_0^l(\xi)})\phi'(Z^{h_0^l(\zeta)}).$$
We also define
$$\Sigma^{0l}(\xi,\zeta) \stackrel{\mathrm{def}}{=} \xi^\top\zeta \times \mathbb{E}\,\phi'(Z^{h_0^{1}(\xi)})\phi'(Z^{h_0^{1}(\zeta)}) \times \cdots \times \mathbb{E}\,\phi'(Z^{h_0^l(\xi)})\phi'(Z^{h_0^l(\zeta)}).$$
For example,
$$\Sigma^{ll}(\xi,\zeta) = \mathbb{E}\,Z^{x_0^l(\xi)}Z^{x_0^l(\zeta)}, \qquad \Sigma^{l,l+1}(\xi,\zeta) = \mathbb{E}\,Z^{x_0^l(\xi)}Z^{x_0^l(\zeta)}\;\mathbb{E}\,\phi'(Z^{h_0^{l+1}(\xi)})\phi'(Z^{h_0^{l+1}(\zeta)}),$$
and so on.
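At large but finite width, the expectations in Definition G.28 can be estimated by coordinate averages, which is exactly the kind of convergence the Master Theorem provides. The sketch below (not from the paper) does this for a scalar-input MLP; the width, depth, inputs, $\phi = \tanh$, and the $\mathcal{N}(0, 1/n)$ hidden-layer initialization scale are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, L = 2048, 3                              # width and depth (illustrative)
xi, zeta = 1.0, -0.7                        # two scalar inputs
phi, dphi = np.tanh, lambda z: 1 - np.tanh(z) ** 2

W1 = rng.standard_normal((n, 1))            # h^1(xi) has coordinates ~ N(0, xi^2)
Ws = [rng.standard_normal((n, n)) / np.sqrt(n) for _ in range(L - 1)]   # N(0, 1/n) entries

def features(inp):
    h = W1 * inp
    hs, xs = [h], [phi(h)]
    for W in Ws:
        h = W @ xs[-1]
        hs.append(h)
        xs.append(phi(h))
    return hs, xs                           # h^1..h^L, x^1..x^L

hs_xi, xs_xi = features(xi)
hs_ze, xs_ze = features(zeta)

def Sigma_ml(m, l):
    """Monte Carlo estimate of Sigma^{ml}(xi, zeta) at finite width n."""
    out = (xs_xi[m - 1] * xs_ze[m - 1]).mean()                        # E Z^{x^m(xi)} Z^{x^m(zeta)}
    for k in range(m + 1, l + 1):
        out *= (dphi(hs_xi[k - 1]) * dphi(hs_ze[k - 1])).mean()       # E phi'(Z^{h^k(xi)}) phi'(Z^{h^k(zeta)})
    return out

print(Sigma_ml(L, L), Sigma_ml(1, L))       # e.g. Sigma^{LL} and Sigma^{1L}
```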

Notation For brevity, below we will shorthand $\vartheta_m = \theta_{W^m/h^m}$. We write $Z^x \equiv Z^y \bmod \hat Z^{W\bullet}$ if $Z^x - Z^y$ is a linear combination of $\hat Z^{Wu}$ for various vectors $u$.
Lemma G.29. For any input $\xi$, any $l \ge \ell$, at any time $t$,
$$Z^{\delta h_t^l(\xi)} \equiv -\eta\mathring\chi_{t-1}Z^{dh_0^l(\xi_{t-1})}\sum_{m=\ell-1}^{l-1}\mathring\vartheta_{m+1}\Sigma^{m,l-1}(\xi_{t-1},\xi) \bmod \hat Z^{W_0\bullet}. \qquad (61)$$

Proof. We proceed by induction.
Base Case $l = \ell$: this is given by Eq. (60).
Induction: Assume Eq. (61) holds for $l - 1$, and we shall prove it for $l$.
To alleviate notation, we write $x = x_t^{l-1}$, $\bar x = x_{t-1}^{l-1}$, $x_0 = x_0^{l-1}$, $h = h_t^{l-1}$, $\bar h = h_{t-1}^{l-1}$, $h_0 = h_0^{l-1}$, $\bar\xi = \xi_{t-1}$, $W = W_0^l$, i.e. we use $\bar\bullet$ to denote time $t-1$ in contrast to $\bullet$ for time $t$, and we suppress the layer index. In contrast, we will write $h_0^l$, $h_t^l$, and $\xi$ for their usual meanings.
First, note that $Z^{\delta x(\xi)} = \phi'(Z^{\bar h(\xi)})Z^{\delta h(\xi)}$ by Eq. (53). Because $Z^{\bar h(\xi)} = Z^{h_0(\xi)}$, and, by the induction hypothesis, $Z^{\delta h(\xi)}$ is a scalar multiple of $Z^{dh_0(\bar\xi)} = Z^{dx_0(\bar\xi)}\phi'(Z^{h_0(\bar\xi)})$, $Z^{\delta x(\xi)}$ is symbolically solely a function of $Z^{h_0(\xi)}, Z^{h_0(\bar\xi)}, Z^{dx_0(\bar\xi)}$, all of which are equal to their $\hat Z$ versions (with the last due to Proposition G.23). Among these, only $Z^{dx_0(\bar\xi)} = Z^{W^\top dh_0^l(\bar\xi)}$ is constructed from matrix multiplication with $W_0^\top$. Thus,
$$\dot Z^{W_0^l\delta x(\xi)} = Z^{dh_0^l(\bar\xi)}\,\mathbb{E}\,\frac{\partial Z^{\delta x(\xi)}}{\partial Z^{dx_0(\bar\xi)}} = Z^{dh_0^l(\bar\xi)}\,\mathbb{E}\,\phi'(Z^{h_0(\xi)})\frac{\partial Z^{\delta h(\xi)}}{\partial Z^{dx_0(\bar\xi)}}. \qquad (62)$$
By the induction hypothesis,
$$\frac{\partial Z^{\delta h(\xi)}}{\partial Z^{dx_0(\bar\xi)}} = -\eta\mathring\chi_{t-1}\phi'(Z^{h_0(\bar\xi)})\sum_{m=\ell-1}^{l-2}\mathring\vartheta_{m+1}\Sigma^{m,l-2}(\bar\xi,\xi).$$
Therefore,
$$\mathbb{E}\,\phi'(Z^{h_0(\xi)})\frac{\partial Z^{\delta h(\xi)}}{\partial Z^{dx_0(\bar\xi)}} = -\eta\mathring\chi_{t-1}\,\mathbb{E}\big[\phi'(Z^{h_0(\xi)})\phi'(Z^{h_0(\bar\xi)})\big]\sum_{m=\ell-1}^{l-2}\mathring\vartheta_{m+1}\Sigma^{m,l-2}(\bar\xi,\xi).$$
By definition of $\Sigma^{ml}$, this equals
$$\mathbb{E}\,\phi'(Z^{h_0(\xi)})\frac{\partial Z^{\delta h(\xi)}}{\partial Z^{dx_0(\bar\xi)}} = -\eta\mathring\chi_{t-1}\sum_{m=\ell-1}^{l-2}\mathring\vartheta_{m+1}\Sigma^{m,l-1}(\bar\xi,\xi).$$
Plugging this back into Eq. (62), we get
$$\dot Z^{W_0^l\delta x(\xi)} = -\eta\mathring\chi_{t-1}Z^{dh_0^l(\bar\xi)}\sum_{m=\ell-1}^{l-2}\mathring\vartheta_{m+1}\Sigma^{m,l-1}(\bar\xi,\xi). \qquad (63)$$
Finally, by Eq. (59),
$$Z^{\delta h_t^l(\xi)} = \dot Z^{W_0^l\delta x(\xi)} - \eta\mathring\chi_{t-1}\mathring\vartheta_l Z^{dh_0^l(\bar\xi)}\,\mathbb{E}\,Z^{x_0(\bar\xi)}Z^{x_0(\xi)} = \dot Z^{W_0^l\delta x(\xi)} - \eta\mathring\chi_{t-1}\mathring\vartheta_l Z^{dh_0^l(\bar\xi)}\Sigma^{l-1,l-1}(\bar\xi,\xi).$$
Together with Eq. (63), this completes the induction.
Lemma G.30. Assume pseudostability, $r > 0$, and $a_{L+1} + b_{L+1} \le 2a_{L+1} + c$. If $\mathring\theta_{L+1/f} = 1$ then $\mathring\theta'_{Lf} = 0$.
Proof. $a_{L+1} + b_{L+1} \le 2a_{L+1} + c$ iff $\theta_{L+1} \le \theta_f$. So $\mathring\theta_{L+1/f} = 1$ implies $\theta_{L+1} = \theta_f$. By pseudostability, $n\theta_{L+1} \le 1$. Since $\theta_L = n^{-r}$, we have $\theta'_{Lf} = n\cdot n^{-r}\cdot\theta_f = n^{-r}\cdot n\theta_{L+1} \le n^{-r} \to 0$ since $r > 0$. Therefore $\mathring\theta'_{Lf} = 0$.
Theorem G.31. Consider a pseudostable parametrization. At any time $t$, for any input $\xi \in \mathcal{X}$, we have
$$\delta\mathring f_t(\xi) = -\eta\mathring\chi_{t-1}\Sigma(\xi_{t-1}, \xi),$$
where the kernel $\Sigma$ is defined for any $\xi, \zeta \in \mathcal{X}$ by
$$\Sigma(\zeta,\xi) \stackrel{\mathrm{def}}{=} \mathring\theta'_{L+1}\Sigma^{LL}(\zeta,\xi) + \mathring\theta'_{Lf}\sum_{m=\ell-1}^{L-1}\mathring\vartheta_{m+1}\Sigma^{mL}(\zeta,\xi).$$
Observe that in the NTK parametrization, $\ell = 1$, and $\mathring\theta'_{L+1} = \mathring\theta'_{Lf} = \mathring\vartheta_{m+1} = 1$ for all $m$, so $\Sigma = \sum_{m=0}^{L}\Sigma^{mL}$ is precisely the NTK (for an MLP without biases).

Proof. By Eqs. (57) and (58),
$$\delta\mathring f_t(\xi) = \mathring\theta'_{L+1}\,\mathbb{E}\,Z^{\delta W_t^{L+1}}Z^{x_t^L(\xi)} + \mathring\theta'_{Lf}\,\mathbb{E}\,Z^{\widehat W_{t-1}^{L+1}}Z^{\delta x_t^L(\xi)},$$
$$Z^{\widehat W_t^{L+1}} = Z^{\widehat W_0^{L+1}} + \mathring\theta_{L+1/f}\big(Z^{\delta W_1^{L+1}} + \cdots + Z^{\delta W_t^{L+1}}\big).$$
Now by Lemma G.30, either $\mathring\theta_{L+1/f} = 0$ or $\mathring\theta'_{Lf} = 0$. In both cases, $(Z^{\delta W_1^{L+1}} + \cdots + Z^{\delta W_t^{L+1}})$ contributes 0 to $\delta\mathring f_t(\xi)$. So we can replace $Z^{\widehat W_{t-1}^{L+1}}$ with $Z^{\widehat W_0^{L+1}}$ above, and write
$$\delta\mathring f_t(\xi) = \mathring\theta'_{L+1}\,\mathbb{E}\,Z^{\delta W_t^{L+1}}Z^{x_t^L(\xi)} + \mathring\theta'_{Lf}\,\mathbb{E}\,Z^{\widehat W_0^{L+1}}Z^{\delta x_t^L(\xi)}.$$
If Eq. (61) is true for $l = L$, then
$$\mathbb{E}\,Z^{\widehat W_0^{L+1}}Z^{\delta x_t^L(\xi)} = -\eta\mathring\chi_{t-1}\,\mathbb{E}\,Z^{\widehat W_0^{L+1}}Z^{dh_0^L(\xi_{t-1})}\phi'(Z^{h_0^L(\xi)})\sum_{m=\ell-1}^{L-1}\mathring\vartheta_{m+1}\Sigma^{m,L-1}(\xi_{t-1},\xi)$$
where the contributions from $\hat Z^{W_0^L\bullet}$ in $Z^{\delta x_t^L(\xi)}$ vanish as they are independent from $Z^{\widehat W_0^{L+1}}$. Since $Z^{dh_0^L(\xi)} = Z^{\widehat W_0^{L+1}}\phi'(Z^{h_0^L(\xi)})$, we continue
$$\mathbb{E}\,Z^{\widehat W_0^{L+1}}Z^{\delta x_t^L(\xi)} = -\eta\mathring\chi_{t-1}\,\mathbb{E}\,\big(Z^{\widehat W_0^{L+1}}\big)^2\phi'(Z^{h_0^L(\xi_{t-1})})\phi'(Z^{h_0^L(\xi)})\sum_{m=\ell-1}^{L-1}\mathring\vartheta_{m+1}\Sigma^{m,L-1}(\xi_{t-1},\xi) = -\eta\mathring\chi_{t-1}\sum_{m=\ell-1}^{L-1}\mathring\vartheta_{m+1}\Sigma^{mL}(\xi_{t-1},\xi).$$
Similarly, by Eq. (56),
$$\mathbb{E}\,Z^{\delta W_t^{L+1}}Z^{x_t^L(\xi)} = -\eta\mathring\chi_{t-1}\,\mathbb{E}\,Z^{x_{t-1}^L(\xi_{t-1})}Z^{x_t^L(\xi)} = -\eta\mathring\chi_{t-1}\,\mathbb{E}\,Z^{x_0^L(\xi_{t-1})}Z^{x_0^L(\xi)} = -\eta\mathring\chi_{t-1}\Sigma^{LL}(\xi_{t-1},\xi).$$
Altogether, these prove the desired claim.
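Theorem G.31 says that, when $r > 0$, the infinite-width training dynamics is plain kernel gradient descent. The sketch below (not from the paper) spells out that update rule; the placeholder RBF kernel standing in for $\Sigma$, the squared loss (so $\chi_t = f_t - y_t$), the simplification $\mathring f_0 = 0$, and the training sequence are all illustrative assumptions.

```python
import numpy as np

def kernel_gd(Sigma, train_seq, eta):
    """Infinite-width SGD of Theorem G.31: returns (centers, coeffs) so that
    f_t(xi) = sum_s coeffs[s] * Sigma(centers[s], xi)."""
    centers, coeffs = [], []
    for xi_t, y_t in train_seq:
        f_t = sum(a * Sigma(z, xi_t) for z, a in zip(centers, coeffs))   # current prediction
        chi_t = f_t - y_t                     # chi_t = L'(f_t(xi_t), y_t) for squared loss
        centers.append(xi_t)
        coeffs.append(-eta * chi_t)           # delta f_{t+1}(xi) = -eta * chi_t * Sigma(xi_t, xi)
    return centers, coeffs

# Illustrative usage with a placeholder RBF kernel standing in for Sigma of Theorem G.31:
Sigma = lambda z, x: np.exp(-0.5 * (z - x) ** 2)
train_seq = [(0.0, 1.0), (1.0, -1.0), (0.5, 0.0)] * 10
centers, coeffs = kernel_gd(Sigma, train_seq, eta=0.5)
f = lambda xi: sum(a * Sigma(z, xi) for z, a in zip(centers, coeffs))
print(f(0.0), f(1.0))
```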
Corollary G.32. A pseudostable parametrization with $r > 0$ is nontrivial iff $a_{L+1} + b_{L+1} + r = 1$ or $2a_{L+1} + c = 1$.
Proof. The kernel $\Sigma$ in Theorem G.31 is nonzero iff $\mathring\theta'_{L+1}$ or $\mathring\theta'_{Lf}$ is 1, which is equivalent to saying $a_{L+1} + b_{L+1} + r = 1$ or $2a_{L+1} + c = 1$.
Corollary G.33. An initialization-stable parametrization with $r > 0$ but $a_{L+1} + b_{L+1} + r < 1$ or $2a_{L+1} + c < 1$ is not stable.
Proof. If $a_{L+1} + b_{L+1} + r < 1$ or $2a_{L+1} + c < 1$, then $\theta'_{L+1} \to \infty$ or $\theta'_{Lf} \to \infty$. Clearly, from the definition, $\Sigma^{mL}(\xi,\xi) > 0$ for any $\xi \ne 0$ and $m \in [0, L]$. All of our reasoning leading up to Theorem G.31 applied at $t = 1$ holds, so Theorem G.31 (along with the Master Theorem) implies $|\delta f_t(\xi)| \xrightarrow{\mathrm{a.s.}} \infty$.
Corollary G.34. If $a_{L+1} + b_{L+1} + r > 1$ and $2a_{L+1} + c = 1$, then for all $\xi \in \mathcal{X}$, $f_0(\xi) \xrightarrow{\mathrm{a.s.}} 0$ and $\delta\mathring f_t(\xi) = -\eta\mathring\chi_{t-1}\Sigma^{LL}(\xi_{t-1}, \xi)$, i.e. we have the Neural Network-Gaussian Process (NNGP) limit.
Conventionally, the NNGP limit is associated with only training the last layer and nothing else. This result says that the same limit can be achieved if we train the body of the network slightly, so that $\Delta x_t^L$ does not interact with $W_0^{L+1}$ enough (embodied in the inequality $a_{L+1} + b_{L+1} + r > 1$) to cause changes in $f_t$.
Proof. The premise implies $\mathring\theta'_{L+1} = 1$ and $\mathring\theta'_{Lf} = 0$, and the rest follows from Theorem G.31.
Remark G.35. We have assumed for simplicity of the proof that aL+1 + bL+1 ≤ 2aL+1 + c. If this
is not the case, then we can easily see Corollary G.34 applies anyway.

G.7 r = 0 Implies Feature Learning


In this section, we assume r = 0 and show any such pseudostable parametrization 1) admits
(pre)feature learning and (pre)feature kernel evolution, and 2) is not in kernel regime (Theorem G.50).
The overarching logic goes like this.

1. The Master Theorem shows that the specific entry $\frac{1}{n}\|x_1^L(\xi_0)\|^2$ of the feature kernel converges to $\mathbb{E}(Z^{x_1^L(\xi_0)})^2$. If the learning rate $\eta = 0$, then $x_1^L(\xi_0) = x_0^L$ and $\mathbb{E}(Z^{x_1^L(\xi_0)})^2 = \mathbb{E}(Z^{x_0^L})^2$. We hope to say that as $\eta$ increases, $\mathbb{E}(Z^{x_1^L(\xi_0)})^2$ moves away from $\mathbb{E}(Z^{x_0^L})^2$, which would imply feature kernel evolution in layer $L$. To do so, we compute $\partial_\eta^2\,\mathbb{E}(Z^{x_1^L(\xi_0)})^2$ evaluated at $\eta = 0$ and show it is nonzero (it turns out $\partial_\eta$ vanishes, so the next best thing is $\partial_\eta^2$). This then also implies feature learning in layer $L$. Analogous results for prefeatures and for other layers can be derived similarly.
2. If the parametrization is in the kernel regime with kernel $K$, the first step of SGD in the large width limit would look like $\mathring f_1(\xi) - \mathring f_0(\xi) = -\eta\mathring\chi_0 K(\xi, \xi_0)$; in particular, $\mathring f_1(\xi) - \mathring f_0(\xi)$ is linear in $\eta$. To show that a pseudostable parametrization with $r = 0$ is not in the kernel regime, we will show $\partial_\eta^3(\mathring f_1(\xi) - \mathring f_0(\xi)) = \partial_\eta^3\mathring f_1(\xi)$ is nonzero. (It turns out $\partial_\eta^2$ vanishes, so the next best thing is $\partial_\eta^3$.)

To calculate these η derivatives, we will derive recurrence relations involving quantities defined below
(see Lemma G.37 and Theorem G.40).

Setup and Notation First, write
$$Z_t^l \stackrel{\mathrm{def}}{=} Z^{h_t^l(\xi_0)},\qquad \hat Z_t^l \stackrel{\mathrm{def}}{=} \hat Z^{W^l x_t^{l-1}(\xi_0)},\qquad \dot Z_0^l \stackrel{\mathrm{def}}{=} Z^{dx_0^l}.$$
Note that $\dot Z_0^l$ is a centered Gaussian independent from $\hat Z_t^l, Z_t^l$. Then we define
$$\gamma^l(\eta) \stackrel{\mathrm{def}}{=} \mathbb{E}\,\phi(Z_0^l)\phi(Z_1^l),\quad \gamma_{11}^l(\eta) \stackrel{\mathrm{def}}{=} \mathbb{E}\,\phi'(Z_0^l)\phi'(Z_1^l),\quad \gamma_{02}^l(\eta) \stackrel{\mathrm{def}}{=} \mathbb{E}\,\phi(Z_0^l)\phi''(Z_1^l),$$
$$\gamma_{20}^l(\eta) \stackrel{\mathrm{def}}{=} \mathbb{E}\,\phi''(Z_0^l)\phi(Z_1^l),\quad \lambda^l(\eta) \stackrel{\mathrm{def}}{=} \mathbb{E}\,\phi(Z_1^l)^2,$$
where the dependence on $\eta$ is from $Z_1^l$. Naturally, since $\phi$ and $\phi'$ are not almost everywhere zero, we have $\gamma^l(0), \lambda^l(0), \gamma_{11}^l(0) > 0$. Note at $\eta = 0$, we have $Z_1^l = Z_0^l$, so $\gamma^l(0) = \lambda^l(0) = \mathbb{E}\,\phi(Z_0^l)^2$. Observe that $(\hat Z_1^{l+1}, \hat Z_0^{l+1})$ is jointly Gaussian with mean zero and covariance
$$\Gamma^l(\eta) \stackrel{\mathrm{def}}{=} \begin{pmatrix}\lambda^l(\eta) & \gamma^l(\eta)\\ \gamma^l(\eta) & \lambda^l(0)\end{pmatrix}. \qquad (64)$$
WLOG, for simplicity of notation, we assume we choose a training routine such that $\mathring\chi_0 = 1$. We assume $\xi_0 \ne 0$.
Since $r = 0$, WLOG we can suppose for some $\ell \in [L]$, we have $\theta_L = \cdots = \theta_\ell = 1 > \theta_{\ell-1} \ge \cdots \ge \theta_1$.
Lemma G.36. With the setup above, we have
$$Z_0^{\ell-1} = Z_1^{\ell-1},\ \dots,\ Z_0^1 = Z_1^1,$$
and
$$Z_1^l = \hat Z_1^l + \eta\beta^l\dot Z_0^l\phi'(Z_0^l),\quad \forall l\in[\ell, L],$$
where $\beta^l$ is defined recursively by
$$\beta^l = \beta^l(\eta) \stackrel{\mathrm{def}}{=} -\gamma^{l-1}(\eta) + \beta^{l-1}(\eta)\gamma_{11}^{l-1}(\eta),\qquad \beta^{\ell-1}(\eta) \stackrel{\mathrm{def}}{=} 0.$$
Additionally, $\beta^l(0) < 0$ for all $l \ge \ell$.
Proof. Straightforward calculation using Moment and ZDot. Here, $-\gamma^{l-1}(\eta)$ comes from $\Delta W_1^l x_1^{l-1}(\xi_0)$ and $\beta^{l-1}(\eta)\gamma_{11}^{l-1}(\eta)$ comes from $\dot Z^{h_1^l(\xi_0)}$. Since $\gamma^l(0), \gamma_{11}^l(0) > 0$ for all $l$, the recurrence on $\beta^l$ implies that $\beta^l(0) < 0$ for all $l \ge \ell$.

G.7.1 Deriving Recurrence Relations on $\partial_\eta\lambda^l$, $\partial_\eta\gamma^l$, $\partial_\eta^2\lambda^l$, $\partial_\eta^2\gamma^l$

Below, we derive the recurrence relations required for our main result. They depend on the following constants:
$$\kappa_1^l \stackrel{\mathrm{def}}{=} \mathbb{E}\big[(\phi^2)''(Z_0^l)\big],\qquad \kappa_2^l \stackrel{\mathrm{def}}{=} \mathbb{E}\big[(\phi^2)''(Z_0^l)\phi'(Z_0^l)^2\big],\qquad \kappa_3^l \stackrel{\mathrm{def}}{=} \mathbb{E}\big[\phi(Z_0^l)\phi''(Z_0^l)\phi'(Z_0^l)^2\big].$$

Lemma G.37. With the setup above, we have, for all $l \in [L]$,
$$\partial_\eta\lambda^l = \frac{1}{2}\kappa_1^l\,\partial_\eta\lambda^{l-1}, \qquad (65)$$
$$\partial_\eta\gamma^l = \frac{1}{2}\gamma_{02}^l\,\partial_\eta\lambda^{l-1} + \gamma_{11}^l\,\partial_\eta\gamma^{l-1}.$$
Proof. We first derive the recurrence on $\partial_\eta\lambda^l$. By Lemma G.38 below, we have
$$\partial_\eta\lambda^l = 2\,\mathbb{E}\,\phi(Z_1^l)\partial_\eta\phi(Z_1^l) + \frac{1}{2}\,\mathbb{E}(\phi^2)''(Z_1^l)\,\partial_\eta\lambda^{l-1}.$$
Since
$$\partial_\eta\phi(Z_1^l) = \phi'(Z_1^l)\big(\beta^l\dot Z_0^l\phi'(Z_0^l) + \eta\dot Z_0^l\phi'(Z_0^l)\partial_\eta\beta^l\big), \qquad (66)$$
we compute
$$\mathbb{E}\,\phi(Z_1^l)\partial_\eta\phi(Z_1^l) = \mathbb{E}\,\phi(Z_1^l)\phi'(Z_1^l)\big(\beta^l\dot Z_0^l\phi'(Z_0^l) + \eta\dot Z_0^l\phi'(Z_0^l)\partial_\eta\beta^l\big) = 0$$
because $\dot Z_0^l$ is independent from everything else in the first expectation. This directly implies the result for $\partial_\eta\lambda^l$.
For $\partial_\eta\gamma^l$, let $\Sigma = \Sigma(\eta) \stackrel{\mathrm{def}}{=} \begin{pmatrix}\gamma_{02}^l & \gamma_{11}^l\\ \gamma_{11}^l & \gamma_{20}^l\end{pmatrix}$. With $\Gamma^{l-1}$ as in Eq. (64), we have
$$\partial_\eta\gamma^l = \mathbb{E}\,\phi(Z_0^l)\partial_\eta\phi(Z_1^l) + \frac{1}{2}\langle\Sigma, \partial_\eta\Gamma^{l-1}\rangle.$$
By the same reasoning as in Eq. (65), the first term of this sum is zero. Since $\partial_\eta\Gamma^{l-1}(\eta) = \begin{pmatrix}\partial_\eta\lambda^{l-1}(\eta) & \partial_\eta\gamma^{l-1}(\eta)\\ \partial_\eta\gamma^{l-1}(\eta) & 0\end{pmatrix}$, we have
$$\partial_\eta\gamma^l = \frac{1}{2}\langle\Sigma, \partial_\eta\Gamma^{l-1}\rangle = \frac{1}{2}\gamma_{02}^l\,\partial_\eta\lambda^{l-1} + \gamma_{11}^l\,\partial_\eta\gamma^{l-1}.$$

Lemma G.38. Consider a twice continuously differentiable $f$ and Gaussian vector $Z \sim \mathcal{N}(0, \Sigma)$ such that $f$ and $\Sigma$ both depend on a parameter $\eta$. Then
$$\partial_\eta\,\mathbb{E}\,f(Z) = \mathbb{E}\,\partial_\eta f(Z) + \frac{1}{2}\langle\mathbb{E}\,\nabla^2 f(Z), \partial_\eta\Sigma\rangle,$$
where $\nabla^2$ denotes the Hessian w.r.t. $z$, and $\langle\cdot,\cdot\rangle$ denotes the trace inner product of matrices.
Proof. Let $p(z)$ denote the PDF of $Z$. We have
$$\partial_\eta\,\mathbb{E}\,f(Z) = \partial_\eta\int f(z)p(z)\,dz = \int \partial_\eta f(z)\,p(z)\,dz + \int f(z)\,\partial_\eta p(z)\,dz.$$
The first integral is $\mathbb{E}\,\partial_\eta f(Z)$. The second integral can be rewritten using integration by parts as $\frac{1}{2}\langle\mathbb{E}\,\nabla^2 f(Z), \partial_\eta\Sigma\rangle$ (e.g. see Lemma F.18 of [50]).
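Lemma G.38 can be checked numerically in one dimension with finite differences; the particular choices $f(z;\eta) = \sin(z+\eta)$ and $\Sigma(\eta) = 1 + \eta^2$ below are arbitrary illustrative assumptions.

```python
import numpy as np

# 1-D check of Lemma G.38: d/deta E f(Z) = E d_eta f(Z) + (1/2) E f''(Z) * d_eta Sigma.
f     = lambda z, eta: np.sin(z + eta)
df    = lambda z, eta: np.cos(z + eta)          # partial_eta f
d2f   = lambda z, eta: -np.sin(z + eta)         # partial_z^2 f (the "Hessian" in 1-D)
sig2  = lambda eta: 1.0 + eta ** 2
dsig2 = lambda eta: 2.0 * eta                   # partial_eta Sigma

nodes, weights = np.polynomial.hermite_e.hermegauss(120)   # quadrature for N(0,1)
def expect(h, eta):
    """E h(Z, eta) for Z ~ N(0, sig2(eta))."""
    s = np.sqrt(sig2(eta))
    return (weights * h(s * nodes, eta)).sum() / weights.sum()

eta, eps = 0.3, 1e-5
lhs = (expect(f, eta + eps) - expect(f, eta - eps)) / (2 * eps)      # d/deta E f(Z)
rhs = expect(df, eta) + 0.5 * expect(d2f, eta) * dsig2(eta)          # Lemma G.38
print(lhs, rhs)                                                      # should agree closely
```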

We then easily have:
Theorem G.39. For all $l \in [L]$, $\partial_\eta\gamma^l(0) = \partial_\eta\lambda^l(0) = 0$.
Proof. For $l < \ell$, we obviously have $\partial_\eta\gamma^l(\eta) = \partial_\eta\lambda^l(\eta) = 0$ for all $\eta$. Then this follows from Lemma G.37 and a simple induction.
Unfortunately, this means that the first $\eta$ derivative doesn't give us what we need. So we try the second derivative, which will turn out to work.

Theorem G.40. For all $l < \ell$, $\partial_\eta^2\lambda^l(0) = \partial_\eta^2\gamma^l(0) = 0$, and for all $l \ge \ell$,
$$\partial_\eta^2\lambda^l(0) = C\kappa_2^l + \frac{1}{2}\kappa_1^l\,\partial_\eta^2\lambda^{l-1}(0),$$
$$\partial_\eta^2\gamma^l(0) = C\kappa_3^l + \frac{1}{2}\gamma_{02}^l(0)\,\partial_\eta^2\lambda^{l-1}(0) + \gamma_{11}^l(0)\,\partial_\eta^2\gamma^{l-1}(0),$$
where $C = 2(\beta^l(0))^2\,\mathbb{E}(\dot Z_0^l)^2 > 0$.
Proof. We start with the $\partial_\eta^2\lambda^l(0)$ recurrence. For $l \ge \ell$, $\partial_\eta^2\lambda^l$ is a sum of 3 terms, representing 1) 2 derivatives in the integrand, 2) 2 derivatives in the Gaussian variance, and 3) 1 derivative each. When evaluated at $\eta = 0$, only the first two terms survive because $\partial_\eta\lambda^{l-1}(0) = 0$ by Theorem G.39:
$$\partial_\eta^2\lambda^l(0) = \mathbb{E}\,\partial_\eta^2\phi^2(Z_1^l)\big|_{\eta=0} + \frac{1}{2}\,\mathbb{E}(\phi^2)''(Z_0^l)\,\partial_\eta^2\lambda^{l-1}(0).$$
Now
$$\mathbb{E}\,\partial_\eta^2\phi^2(Z_1^l) = 2\partial_\eta\Big(\mathbb{E}\,\phi(Z_1^l)\phi'(Z_1^l)\big(\beta^l\dot Z_0^l\phi'(Z_0^l) + \eta\dot Z_0^l\phi'(Z_0^l)\partial_\eta\beta^l\big)\Big) = 2\,\mathbb{E}(\phi^2)''(Z_1^l)\big(\beta^l\dot Z_0^l\phi'(Z_0^l) + \eta\dot Z_0^l\phi'(Z_0^l)\partial_\eta\beta^l\big)^2 + \cdots$$
where other terms appear in this sum but they vanish because $\dot Z_0^l$ appears unpaired in the expectation. Thus,
$$\mathbb{E}\,\partial_\eta^2\phi^2(Z_1^l)\big|_{\eta=0} = 2(\beta^l(0))^2\,\mathbb{E}(\dot Z_0^l)^2\,\mathbb{E}(\phi^2)''(Z_0^l)\phi'(Z_0^l)^2.$$
Plugging this back in, we get the recurrence on $\partial_\eta^2\lambda^l(0)$.
The $\partial_\eta^2\gamma^l(0)$ recurrence is derived similarly.

The following result will be useful for showing $\partial_\eta^3\mathring f_1(\xi_0) \ne 0$.
Theorem G.41. Define
$$\dot\kappa_3^l \stackrel{\mathrm{def}}{=} \mathbb{E}\big[\phi'''(Z_0^l)\phi'(Z_0^l)^3\big],\qquad \gamma_{13}^l \stackrel{\mathrm{def}}{=} \mathbb{E}\,\phi'(Z_0^l)\phi'''(Z_0^l),\qquad \gamma_{22}^l \stackrel{\mathrm{def}}{=} \mathbb{E}\,\phi''(Z_0^l)^2.$$
Then for all $l \ge \ell$,
$$\partial_\eta^2\gamma_{11}^l(0) = C\dot\kappa_3^l + \frac{1}{2}\gamma_{13}^l\,\partial_\eta^2\lambda^{l-1}(0) + \gamma_{22}^l\,\partial_\eta^2\gamma^{l-1}(0),$$
where $C = 2(\beta^l(0))^2\,\mathbb{E}(\dot Z_0^l)^2 > 0$.
Proof. Similar to the proof of Theorem G.40.

The following result will be useful for showing prefeature kernel evolution.
Theorem G.42. For all $l \ge \ell$,
$$\partial_\eta^2\,\mathbb{E}(Z_1^l)^2\big|_{\eta=0} = 2C + \gamma_{11}^l(0)\,\partial_\eta^2\lambda^{l-1}(0),$$
where $C = 2(\beta^l(0))^2\,\mathbb{E}(\dot Z_0^l)^2 > 0$.
Proof. Similar to the proof of Theorem G.40.

G.7.2 Applications to σ-Gelu


The following proposition regarding $\sigma$-gelu is easy to verify.
Proposition G.43. Let $\phi$ be $\sigma$-gelu. For any centered Gaussian $Z \in \mathbb{R}$ with nonzero variance,
$$\mathbb{E}(\phi^2)''(Z),\ \mathbb{E}(\phi^2)''(Z)\phi'(Z)^2,\ \mathbb{E}\,\phi(Z)\phi''(Z)\phi'(Z)^2,\ \mathbb{E}\,\phi(Z)\phi''(Z),\ \mathbb{E}\,\phi''(Z)^2 > 0,$$
and they converge to 0 as $\sigma \to 0$. Also,
$$\mathbb{E}\,\phi'''(Z)\phi'(Z)^3,\ \mathbb{E}\,\phi'(Z)\phi'''(Z) < 0,$$
and they converge to $-\infty$ as $\sigma \to 0$.

This particularly implies that $\kappa_1^l, \kappa_2^l, \kappa_3^l, \gamma_{02}^l(0), \gamma_{22}^l > 0$ and converge to 0 with small $\sigma$, but $\dot\kappa_3^l, \gamma_{13}^l < 0$ and diverge to $-\infty$ with small $\sigma$.
Theorem G.44. Consider a pseudostable parametrization with $r = 0$. If $\phi$ is $\sigma$-gelu, then for all $l \ge \ell$, $\partial_\eta^2\gamma^l(0), \partial_\eta^2\lambda^l(0) > 0$, and they converge to 0 as $\sigma \to 0$.
Proof. We always have $(\beta^l(0))^2, \mathbb{E}(\dot Z_0^l)^2 > 0$. By Proposition G.43, $\kappa_1^l, \kappa_2^l > 0$ as well. Thus, by Theorem G.40, $\partial_\eta^2\lambda^l(0) > 0$ for all $l \ge \ell$. By Proposition G.43, $\kappa_3^l, \gamma_{02}^l(0) > 0$, so by Theorem G.40, $\partial_\eta^2\gamma^l(0) > 0$ for all $l \ge \ell$ as well. As $\sigma \to 0$, $\kappa_1^l, \kappa_2^l, \kappa_3^l, \gamma_{02}^l(0) \to 0$, so $\partial_\eta^2\lambda^l(0), \partial_\eta^2\gamma^l(0) \to 0$.

Theorem G.45. Consider a pseudostable parametrization with $r = 0$. Suppose $a_{L+1} + b_{L+1} + r = 1$ or $2a_{L+1} + c = 1$. If $\phi$ is $\sigma$-gelu for sufficiently small $\sigma$, then $\partial_\eta^3\mathring f_1(\xi_0) \ne 0$.
Proof. We have $\mathring f_1(\xi_0) = \mathring\theta'_{L+1}\,\mathbb{E}\,Z^{\delta W_1^{L+1}}Z^{x_1^L(\xi_0)} + \mathring\theta'_{Lf}\,\mathbb{E}\,Z^{\widehat W_0^{L+1}}Z^{\delta x_1^L(\xi_0)}$, where at least one of $\mathring\theta'_{Lf}$ and $\mathring\theta'_{L+1}$ is 1 because $a_{L+1} + b_{L+1} + r = 1$ or $2a_{L+1} + c = 1$. We have
$$\mathbb{E}\,Z^{\delta W_1^{L+1}}Z^{x_1^L(\xi_0)} = -\eta\,\mathbb{E}\,Z^{x_0^L}Z^{x_1^L(\xi_0)},$$
$$\mathbb{E}\,Z^{\widehat W_0^{L+1}}Z^{x_1^L(\xi_0)} = \mathbb{E}\,Z^{\widehat W_0^{L+1}}\phi\Big(Z^{h_0^L} - \eta Z^{\widehat W_0^{L+1}}\phi'(Z^{h_0^L})\,\mathbb{E}\,Z^{x_0^{L-1}}Z^{x_1^{L-1}(\xi_0)}\Big) = -\eta\,\mathbb{E}\,\phi'(Z^{h_1^L(\xi_0)})\phi'(Z^{h_0^L})\,\mathbb{E}\,Z^{x_0^{L-1}}Z^{x_1^{L-1}(\xi_0)}$$
where we used Stein's Lemma for the last equality. Thus
$$\partial_\eta^3\mathring f_1(\xi_0) = -\Big(\mathring\theta'_{L+1}\,\partial_\eta^2\gamma^L(0) + \mathring\theta'_{Lf}\,\partial_\eta^2(\gamma_{11}^L\gamma^{L-1})(0)\Big).$$
Below we will show that for small $\sigma$, $\partial_\eta^2\gamma^L(0)$ is small and positive and $\partial_\eta^2(\gamma_{11}^L\gamma^{L-1})(0)$ is large and negative, so $\partial_\eta^3\mathring f_1(\xi_0)$ cannot be 0 no matter the values of $\mathring\theta'_{L+1}$ and $\mathring\theta'_{Lf}$.
Claim: For sufficiently small $\sigma$, $\partial_\eta^2\gamma_{11}^L(0) < 0$. It converges to $-\infty$ as $\sigma \to 0$.
Proof: By Theorem G.41, $\partial_\eta^2\gamma_{11}^l(0) = C\dot\kappa_3^l + \frac{1}{2}\gamma_{13}^l\,\partial_\eta^2\lambda^{l-1}(0) + \gamma_{22}^l\,\partial_\eta^2\gamma^{l-1}(0)$. Note $\partial_\eta^2\lambda^{l-1}(0) \ge 0$ by Theorem G.44. Also, by Proposition G.43, $\dot\kappa_3^l, \gamma_{13}^l < 0$, $\gamma_{22}^l > 0$, and as $\sigma \to 0$, $\dot\kappa_3^l, \gamma_{13}^l \to -\infty$, $\gamma_{22}^l \to 0$ (as well as $\partial_\eta^2\gamma^{l-1}(0), \partial_\eta^2\lambda^{l-1}(0) \to 0$ by Theorem G.44). One can see that $C$ converges to a positive constant as $\sigma \to 0$ as well. Therefore, for small enough $\sigma$, $\partial_\eta^2\gamma_{11}^l(0) < 0$, and as $\sigma \to 0$, $\partial_\eta^2\gamma_{11}^L(0) \to -\infty$.
Claim: For sufficiently small $\sigma$, $\partial_\eta^2(\gamma_{11}^L\gamma^{L-1})(0) < 0$. It converges to $-\infty$ as $\sigma \to 0$.
Proof: Observe $\partial_\eta^2(\gamma_{11}^L\gamma^{L-1})(0) = \partial_\eta^2\gamma_{11}^L(0)\,\gamma^{L-1}(0) + \gamma_{11}^L(0)\,\partial_\eta^2\gamma^{L-1}(0)$ because $\partial_\eta\gamma^{L-1}(0) = 0$ by Theorem G.39. So the above claim and Theorem G.44 yield the desired results.
Finishing the main proof: Therefore, if $\mathring\theta'_{L+1} = 1$ but $\mathring\theta'_{Lf} = 0$, then $-\partial_\eta^3\mathring f_1(\xi_0) > 0$ because $\partial_\eta^2\gamma^L(0) > 0$; if $\mathring\theta'_{L+1} = 0$ but $\mathring\theta'_{Lf} = 1$, then $-\partial_\eta^3\mathring f_1(\xi_0) < 0$ for small $\sigma$ because $\partial_\eta^2(\gamma_{11}^L\gamma^{L-1})(0) < 0$; if $\mathring\theta'_{L+1} = \mathring\theta'_{Lf} = 1$, then $-\partial_\eta^3\mathring f_1(\xi_0) < 0$ for small $\sigma$ because $\partial_\eta^2(\gamma_{11}^L\gamma^{L-1})(0) \to -\infty$ while $\partial_\eta^2\gamma^L(0) \to 0$ as $\sigma \to 0$.

G.7.3 Applications to Tanh


The following property of tanh is easy to verify.
Proposition G.46. Let $\phi = \tanh$. For any centered Gaussian $Z \in \mathbb{R}$ with nonzero variance,
$$\mathbb{E}(\phi^2)''(Z),\ \mathbb{E}(\phi^2)''(Z)\phi'(Z)^2,\ \mathbb{E}\,\phi''(Z)^2 > 0,$$
and
$$\mathbb{E}\,\phi(Z)\phi''(Z)\phi'(Z)^2,\ \mathbb{E}\,\phi(Z)\phi''(Z),\ \mathbb{E}\,\phi'''(Z)\phi'(Z)^3,\ \mathbb{E}\,\phi'(Z)\phi'''(Z) < 0.$$
In particular, this means
$$\kappa_1^l, \kappa_2^l, \gamma_{22}^l > 0,\qquad \kappa_3^l, \gamma_{02}^l(0), \dot\kappa_3^l, \gamma_{13}^l < 0.$$
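The sign claims of Proposition G.46 can be confirmed numerically by Gaussian quadrature; the sketch below (not from the paper) does so for a few illustrative variances, using the closed-form derivatives of tanh.

```python
import numpy as np

# Check the signs claimed in Proposition G.46 for phi = tanh at a few variances.
t   = np.tanh
d1  = lambda z: 1 - t(z) ** 2                                 # phi'
d2  = lambda z: -2 * t(z) * (1 - t(z) ** 2)                   # phi''
d3  = lambda z: -2 * (1 - t(z) ** 2) * (1 - 3 * t(z) ** 2)    # phi'''
sq2 = lambda z: 2 * (d1(z) ** 2 + t(z) * d2(z))               # (phi^2)''

z, w = np.polynomial.hermite_e.hermegauss(100)                # nodes/weights for N(0,1)
E = lambda h, var: (w * h(np.sqrt(var) * z)).sum() / w.sum()  # E h(Z), Z ~ N(0, var)

for var in (0.25, 1.0, 4.0):
    pos = [E(sq2, var), E(lambda u: sq2(u) * d1(u) ** 2, var), E(lambda u: d2(u) ** 2, var)]
    neg = [E(lambda u: t(u) * d2(u) * d1(u) ** 2, var), E(lambda u: t(u) * d2(u), var),
           E(lambda u: d3(u) * d1(u) ** 3, var), E(lambda u: d1(u) * d3(u), var)]
    print(var, all(v > 0 for v in pos), all(v < 0 for v in neg))   # expect True True
```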
Theorem G.47. Consider a pseudostable parametrization with $r = 0$. If $\phi$ is tanh, then for all $l \ge \ell$, $\partial_\eta^2\gamma^l(0) < 0$, $\partial_\eta^2\lambda^l(0) > 0$.
Proof. Similar to the proof of Theorem G.44, except that here $\kappa_3^l, \gamma_{02}^l(0) < 0$, making $\partial_\eta^2\gamma^l(0) < 0$.
Theorem G.48. Consider a pseudostable parametrization with $r = 0$. Suppose $a_{L+1} + b_{L+1} + r = 1$ or $2a_{L+1} + c = 1$. If $\phi$ is tanh, then $\partial_\eta^3\mathring f_1(\xi_0) \ne 0$.
Proof. Similar to the proof of Theorem G.45, except in the expression
$$\partial_\eta^3\mathring f_1(\xi_0) = -\Big(\mathring\theta'_{L+1}\,\partial_\eta^2\gamma^L(0) + \mathring\theta'_{Lf}\,\partial_\eta^2(\gamma_{11}^L\gamma^{L-1})(0)\Big),$$
$\partial_\eta^2\gamma^L(0)$ and $\partial_\eta^2(\gamma_{11}^L\gamma^{L-1})(0)$ are both negative. The former is because of Theorem G.47. The latter is because $\partial_\eta^2\gamma^{L-1}(0) \le 0$ for the same reason, and $\partial_\eta^2\gamma_{11}^L(0) < 0$ since $\dot\kappa_3^l, \gamma_{13}^l < 0$, $\gamma_{22}^l > 0$ by Proposition G.46.

G.7.4 Main Results


Proposition G.49. Suppose φ is tanh or σ-gelu for sufficiently small σ. A pseudostable parametriza-
tion with r = 0 is nontrivial iff aL+1 + bL+1 = 1 or 2aL+1 + c = 1.

Proof. If aL+1 + bL+1 + r = 1 or 2aL+1 + c = 1, then Theorem G.45 and Theorem G.48 show that
the parametrization is nontrivial. Otherwise, it is trivial by Proposition G.24.
Theorem G.50. Suppose φ is tanh or σ-gelu for sufficiently small σ. For any nontrivial pseudostable
parametrization with r = 0, the following are true of the parametrization:

1. not in kernel regime


2. feature learning
3. feature learning in the Lth layer
4. feature kernels evolution
5. feature kernel evolution in the Lth layer
6. prefeature learning
7. prefeature learning in the Lth layer
8. prefeature kernels evolution
9. prefeature kernel evolution in the Lth layer
10. if there is feature learning or feature kernel evolution or prefeature learning or prefeature
kernel evolution in layer l, then there is feature learning and feature kernel evolution and
prefeature learning and prefeature kernel evolution in layers l, . . . , L.

Proof. The parametrization cannot be in kernel regime since $\partial_\eta^3\mathring f_1(\xi_0) \ne 0$ by Theorem G.48 or Theorem G.45. By Theorem G.44 or Theorem G.47, $\partial_\eta^2\lambda^l(0) > 0$ for all $l \ge \ell$, so the feature kernel evolves in layers $\ell, \dots, L$ for some normalized learning rate $\eta > 0$. This implies feature learning in layers $\ell, \dots, L$, since $Z^{x_1^L(\xi_0)} - Z^{x_0^L} \ne 0$ in this case. This then implies $Z^{h_1^L(\xi_0)} - Z^{h_0^L} \ne 0$, so we have prefeature learning in layers $\ell, \dots, L$. Prefeature kernel evolution in layers $\ell, \dots, L$ is implied by Theorem G.42. Finally, the last statement follows clearly from our logic above.

Corollary G.51. Suppose $\phi$ is tanh or $\sigma$-gelu for sufficiently small $\sigma$. Consider any initialization-stable parametrization with $r = 0$. If $a_{L+1} + b_{L+1} < 1$ or $2a_{L+1} + c < 1$, then the parametrization is not stable.
Proof. First suppose $a_{L+1} + b_{L+1} < 1$ and $2a_{L+1} + c \ge 1$. Then $\theta'_{Lf} = n^{1-(a_{L+1}+b_{L+1})} \to \infty$ but $\theta'_{L+1} \le 1$. As in the proof of Theorem G.45, there is some $\eta \ne 0$ such that $\mathbb{E}\,Z^{\widehat W_0^{L+1}}Z^{\delta x_1^L(\xi_0)} = R$ for some $R \ne 0$. Therefore, by the Master Theorem, $\frac{1}{n}\widehat W_0^{L+1}\delta x_1^L(\xi_0) \xrightarrow{\mathrm{a.s.}} R \implies |W_0^{L+1}\Delta x_1^L(\xi_0)| = \Theta(n^{1-(a_{L+1}+b_{L+1})}) \to \infty$. This dominates $\Delta W_1^{L+1}x_1^L(\xi_0)$, which by similar reasoning is $O(1)$. So $f_1(\xi_0)$ diverges and the parametrization is not stable.
Now suppose $a_{L+1} + b_{L+1} \ge 1$ and $2a_{L+1} + c < 1$. This violates our simplifying assumption that $a_{L+1} + b_{L+1} \le 2a_{L+1} + c$, but it's easy to see that $\frac{1}{n}\delta W_1^{L+1}x_1^L(\xi_0) \xrightarrow{\mathrm{a.s.}} -\eta\mathring\chi_0\,\mathbb{E}\,Z^{x_0^L}Z^{x_1^L(\xi_0)}$. For $\eta$ small enough, this is close to $-\eta\mathring\chi_0\,\mathbb{E}(Z^{x_0^L})^2$ and thus is nonzero. Then $|\Delta W_1^{L+1}x_1^L(\xi_0)| = \Theta(n^{1-(2a_{L+1}+c)}) \to \infty$. This dominates $W_0^{L+1}\Delta x_1^L(\xi_0) = O(1)$, so $f_1(\xi_0)$ diverges. Therefore, the parametrization is not stable.
Finally, suppose both $a_{L+1} + b_{L+1}, 2a_{L+1} + c < 1$. If $a_{L+1} + b_{L+1} \ne 2a_{L+1} + c$, then one of $\Delta W_1^{L+1}x_1^L(\xi_0)$ and $W_0^{L+1}\Delta x_1^L(\xi_0)$ dominates the other as above, leading to divergence. If $a_{L+1} + b_{L+1} = 2a_{L+1} + c$, then in the case of $\sigma$-gelu with small $\sigma$, $W_0^{L+1}\Delta x_1^L(\xi_0)$ will dominate $\Delta W_1^{L+1}x_1^L(\xi_0)$, as in Theorem G.45; and in the case of tanh, both have the same sign, as in Theorem G.48. In either case, $f_1(\xi_0)$ diverges, so the parametrization is not stable.

G.8 Putting Everything Together


Finally, in this section we tie all of our insights above to prove our main theorems.
Theorem G.52. Suppose φ is tanh or σ-gelu for sufficiently small σ. A parametrization is stable iff
it is pseudostable.

Proof. The “if” direction is given by Proposition G.25. We now show that when any (in)equality of
pseudostability is violated, the parametrization is not stable.
First, if Eq. (41) is not satisfied, then Theorem G.18 shows lack of stability.
Second, if Eq. (41) is satisfied but r < 0, then Proposition G.27 shows lack of stability.
Finally, if Eq. (41) is satisfied and r ≥ 0 but aL+1 + bL+1 < 1 or 2aL+1 + c < 1, then Corollary G.51
or Corollary G.33 shows lack of stability.

Given this result, we will now just say “stable” instead of “pseudostable” from here on.
Theorem G.8 (Nontriviality Characterization). Suppose φ is tanh or σ-gelu for sufficiently small σ.
A stable abc-parametrization is nontrivial iff aL+1 + bL+1 + r = 1 or 2aL+1 + c = 1.

Proof. The case of r = 0 and the case of r > 0 are resp. given by Proposition G.49 and
Corollary G.32.
Theorem G.12 (Classification of abc-Parametrizations). Suppose φ is tanh or σ-gelu for sufficiently
small σ. Consider a nontrivial stable abc-parametrization of an L-hidden layer MLP. Then

1. The following are equivalent to r = 0


(a) feature learning
(b) feature learning in the Lth layer
(c) feature kernels evolution
(d) feature kernel evolution in the Lth layer
(e) prefeature learning
(f) prefeature learning in the Lth layer
(g) prefeature kernels evolution
(h) prefeature kernel evolution in the Lth layer

2. The following are equivalent to r > 0
(a) kernel regime
(b) fixes all features
(c) fixes features in the Lth layer
(d) fixes all feature kernels
(e) fixes feature kernel in the Lth layer
(f) fixes all prefeatures
(g) fixes prefeatures in the Lth layer
(h) fixes all prefeature kernels
(i) fixes prefeature kernel in the Lth layer
3. If there is feature learning or feature kernel evolution or prefeature learning or prefeature
kernel evolution in layer l, then there is feature learning and feature kernel evolution and
prefeature learning and prefeature kernel evolution in layers l, . . . , L.
4. If $r = 0$, then for all $\xi \in \mathcal{X}$, $f_0(\xi) \xrightarrow{\mathrm{a.s.}} 0$ and $f_t(\xi) \xrightarrow{\mathrm{a.s.}} \mathring f_t(\xi)$ for some deterministic $\mathring f_t(\xi)$. However, the converse is not true.
5. If r > 0, aL+1 + bL+1 + r > 1 and 2aL+1 + c = 1, then we have the Neural Network-
Gaussian Process limit.

Proof. A nontrivial stable parametrization has either r = 0 or r > 0. By Theorem G.50,


Proposition G.26, and Theorem G.31, r = 0 implies all of the statements in (1) and r > 0 im-
plies all of the statements in (2). Consequently, if feature learning happens, then clearly r cannot be
positive, so r must be 0. Likewise, all of the statements in (1) imply r = 0. Symmetrically, all of the
statements in (2) about fixing features imply r > 0. Finally, if the parametrization is in kernel regime,
then by Theorem G.50(1), r cannot be 0, so r > 0. This proves (1) and (2).
If the premise of (3) holds, then by the above, r = 0, so the conclusion follows from Theorem G.50.
This proves (3).
If $r = 0$, then nontriviality means $a_{L+1} + b_{L+1} \ge 1$. This implies $f_0(\xi) \xrightarrow{\mathrm{a.s.}} 0$ for all $\xi \in \mathcal{X}$ (more precisely, $f_0(\xi)$ has standard deviation $\Theta(n^{1/2-(a_{L+1}+b_{L+1})}) \to 0$ by the Central Limit Theorem). The program describes the unconditional SGD trajectory of $f$ (as opposed to the case when $a_{L+1} + b_{L+1} = 1/2$), so $f_t(\xi) \xrightarrow{\mathrm{a.s.}} \mathring f_t(\xi)$ does not depend on $f_0$. The converse is not true, for example because of Corollary G.34. This proves (4).
(5) follows from Corollary G.34 (which actually allows much more general $\phi$).

Proofs of Theorems 6.1, 6.3 and 6.4 For any finite subset $\mathcal{X}$ of the input space $\mathbb{R}^d$ (where $d = 1$ here), we can write out the SGD computation as a Tensor Program like in Appendix G.4. Then the Master Theorem implies the convergence $f_t(\xi) \xrightarrow{\mathrm{a.s.}} \mathring f_t(\xi)$ for every $\xi \in \mathcal{X}$. Let $\mathcal{X}_1 \subseteq \cdots \subseteq \mathcal{X}_k \subseteq \cdots$ be an infinite chain of finite subsets of $\mathbb{R}^d$ such that $\bigcup_k\mathcal{X}_k$ is a dense subset of $\mathbb{R}^d$. Then the convergence $f_t(\xi) \xrightarrow{\mathrm{a.s.}} \mathring f_t(\xi)$ holds for every $\xi \in \bigcup_k\mathcal{X}_k$ (because we have almost sure convergence). Finally, we apply a continuity argument to get this convergence for all of $\mathbb{R}^d$: Because $\phi'$ and thus $\phi$ are pseudo-Lipschitz, they are locally Lipschitz (i.e. Lipschitz on any compact set). In addition, the operator norms of $W^l$ are almost surely bounded by standard matrix operator norm bounds. Thus one can see that the Tensor Program is locally Lipschitz in $\xi$. Consequently, $\mathring f_t(\xi)$ is continuous in $\xi$. This allows us to pass from $\bigcup_k\mathcal{X}_k$ to $\mathbb{R}^d$.

Proofs of Propositions 5.3, 5.5 and F.2 and Theorems F.3 and F.4 follow by dividing into cases
of r > 0 and r = 0 and easy modification of the reasoning in Appendices G.6 and G.7.

Proof of Theorem G.16 follows from straightforward calculations. The basic outline of the calculations is: 1) During pretraining, $f$'s change is purely due to a) the interaction between $\Delta W^l$, $l \le L$, and $W_0^{L+1}$, and b) the interaction between $x^L$ and $\Delta W^{L+1}$. 2) When $W^{L+1}$ is re-initialized in $g$, these interactions are killed. The pretrained $\Delta W^l$, $l \le L$, will cause $x^M$ to differ by $\Theta(1/\sqrt n)$ coordinatewise compared to if $\Delta W^l$, $l \le L$, are all reset to 0, but this difference is uncorrelated with the last layer weights $W^{M+1}$ of $g$, so their interaction is subleading in $n$, i.e. in the infinite-width limit,
$$g_{T;t}(\xi) - g_{0;t}(\xi) \xrightarrow{\mathrm{a.s.}} 0,$$
whether all of $g$ or just the new weights are trained during finetuning.
