Notes for mu-Transfer Paper

Introduction

The paper Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer [1] introduces a practical way to parameterize a deep neural network so that hyperparameters transfer zero-shot across width and depth and the network performs maximal feature learning [2]. The paper shows mathematically that the network output and the weight updates do not blow up (independent of the network width) if the maximal update parameterization is used. Though the paper is nicely written, with many details in the appendix, I found that it still omits a lot of essential computational details. These are my paper reading notes: I try to fill in those derivations in this blog for future reference. Make sure you read the paper before reading this blog; otherwise it won't make much sense. Hopefully it helps you understand the paper better.

Background: Matrix Differentiation

Linear Case

For matrices $A \in \Reals^{m \times n}$ and $B \in \Reals^{n \times k}$, define $C = AB \in \Reals^{m \times k}$. Suppose we know the gradient of the loss $L$ with respect to the output $C$, written as $\frac{\partial L}{\partial C} \in \Reals^{k \times m}$ (note the transposed shape convention, used throughout these notes). Then the gradient with respect to $B$ is

$$\frac{\partial L}{\partial B} = \frac{\partial L}{\partial C}\, A \tag{0.1}$$

The gradient with respect to $A$ is

$$\frac{\partial L}{\partial A} = B\, \frac{\partial L}{\partial C} \tag{0.2}$$

Proof of Eq. (0.1).

$$\begin{aligned}\frac{\partial C_{ij}}{\partial B_{mn}} &= \frac{\partial \left(\sum_k A_{ik}B_{kj}\right)}{\partial B_{mn}} \\ &= \sum_k A_{ik} \delta_{km} \delta_{jn} \\ &= A_{im} \delta_{jn} \end{aligned} \tag{0.3}$$

$$\begin{aligned}\frac{\partial L}{\partial B_{mn}} &= \sum_{ij} \frac{\partial L}{\partial C_{ij}}\frac{\partial C_{ij}}{\partial B_{mn}} \\ &= \sum_{ij} \frac{\partial L}{\partial C_{ij}} A_{im} \delta_{jn} \\ &= \sum_{i} \frac{\partial L}{\partial C_{in}} A_{im} \end{aligned} \tag{0.4}$$

So Eq (0.4) is equivalent to

$$\frac{\partial L}{\partial B} = \frac{\partial L}{\partial C}\, A$$

q.e.d.

Proof of Eq. (0.2)

$$C^T = B^T A^T \tag{0.5}$$

Apply Eq. (0.1) to Eq. (0.5), treating $B^T$ as the left matrix and $A^T$ as the right matrix:

$$\frac{\partial L}{\partial A^T} = \frac{\partial L}{\partial C^T}\, B^T \tag{0.6}$$

Transpose both sides

$$\frac{\partial L}{\partial A} = B\, \frac{\partial L}{\partial C} \tag{0.7}$$

q.e.d.

(Or check the appendix for an alternative proof.)
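
To sanity-check Eq. (0.1) and Eq. (0.2) (and the transposed shape convention), here is a minimal NumPy finite-difference sketch; the toy loss $L = \sum_{ij} M_{ij} C_{ij}$ and the small shapes are illustrative choices of mine, not anything from the paper:

```python
# A minimal numerical check of Eq. (0.1) and Eq. (0.2), using NumPy and central
# finite differences. Note the blog's convention: dL/dC is stored in the
# *transposed* shape (k x m), and so are dL/dA and dL/dB.
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 3, 4, 5
A = rng.standard_normal((m, n))
B = rng.standard_normal((n, k))
M = rng.standard_normal((m, k))          # fixed weights defining the toy loss

def loss(A, B):
    # L = sum_ij M_ij * C_ij  with  C = A B,  so  dL/dC_ij = M_ij.
    return np.sum(M * (A @ B))

dLdC = M.T                               # (k x m), transposed convention
dLdB = dLdC @ A                          # Eq. (0.1), shape (k x n)
dLdA = B @ dLdC                          # Eq. (0.2), shape (n x m)

def numeric_grad(f, X, eps=1e-6):
    # Central finite differences, returned in the same shape as X.
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X); E[idx] = eps
        G[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return G

num_dB = numeric_grad(lambda B_: loss(A, B_), B)   # shape (n x k)
num_dA = numeric_grad(lambda A_: loss(A_, B), A)   # shape (m x n)

print(np.allclose(dLdB, num_dB.T))       # True: Eq. (0.1) matches
print(np.allclose(dLdA, num_dA.T))       # True: Eq. (0.2) matches
```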

Non-Linear Case

For matrices $A \in \Reals^{m \times n}$ and $B \in \Reals^{n \times k}$, define $C = \sigma(AB) \in \Reals^{m \times k}$, where $\sigma(\cdot)$ is an element-wise non-linear function. Suppose we know the gradient of the loss with respect to the output $C$, $\frac{\partial L}{\partial C} \in \Reals^{k \times m}$. Then the gradient with respect to $B$ is

$$\frac{\partial L}{\partial B} = \left(\frac{\partial L}{\partial C} \cdot {\sigma^{\prime}}^{T}\right) A \tag{0.8}$$

The gradient with respect to $A$ is

$$\frac{\partial L}{\partial A} = B \left(\frac{\partial L}{\partial C} \cdot {\sigma^{\prime}}^{T}\right) \tag{0.9}$$

Proof of Eq. (0.8).

$$\begin{aligned}\frac{\partial C_{ij}}{\partial B_{mn}} &= \frac{\partial\, \sigma\!\left(\sum_k A_{ik}B_{kj}\right)}{\partial B_{mn}} \\ &= \sigma^{\prime}\!\left(\sum_k A_{ik}B_{kj}\right) \frac{\partial \left(\sum_k A_{ik}B_{kj}\right)}{\partial B_{mn}} \\ &= \sigma^{\prime}\!\left(\sum_k A_{ik}B_{kj}\right) \sum_k A_{ik} \delta_{km} \delta_{jn} \\ &= \sigma^{\prime}_{ij} A_{im} \delta_{jn} \end{aligned} \tag{0.10}$$

$$\begin{aligned}\frac{\partial L}{\partial B_{mn}} &= \sum_{ij} \frac{\partial L}{\partial C_{ij}}\frac{\partial C_{ij}}{\partial B_{mn}} \\ &= \sum_{ij} \frac{\partial L}{\partial C_{ij}} \sigma^{\prime}_{ij} A_{im} \delta_{jn} \\ &= \sum_{i} \frac{\partial L}{\partial C_{in}} \sigma^{\prime}_{in} A_{im} \end{aligned} \tag{0.11}$$

So Eq (0.11) is equivalent to

$$\frac{\partial L}{\partial B} = \left(\frac{\partial L}{\partial C} \cdot {\sigma^{\prime}}^{T}\right) A$$

where $\cdot$ is element-wise multiplication and $\sigma^{\prime}$ denotes $\sigma^{\prime}(AB)$.

q.e.d.

The proof of Eq. (0.9) is similar to the proof of Eq. (0.8).
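
The analogous check for the non-linear case, Eq. (0.8) and Eq. (0.9); choosing $\sigma = \tanh$ and the same toy loss is again just an assumption for illustration:

```python
# Finite-difference check of Eq. (0.8) and Eq. (0.9) with sigma = tanh.
# The "*" below is element-wise multiplication (the notes' "\cdot"), and
# gradients are again kept in the transposed shapes.
import numpy as np

rng = np.random.default_rng(1)
m, n, k = 3, 4, 5
A = rng.standard_normal((m, n))
B = rng.standard_normal((n, k))
M = rng.standard_normal((m, k))

def loss(A, B):
    return np.sum(M * np.tanh(A @ B))    # C = sigma(A B), L = sum_ij M_ij C_ij

sig_prime = 1.0 - np.tanh(A @ B) ** 2    # sigma'(AB), shape (m x k)
dLdC = M.T                               # (k x m)
dLdB = (dLdC * sig_prime.T) @ A          # Eq. (0.8)
dLdA = B @ (dLdC * sig_prime.T)          # Eq. (0.9)

def numeric_grad(f, X, eps=1e-6):
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X); E[idx] = eps
        G[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return G

print(np.allclose(dLdB, numeric_grad(lambda B_: loss(A, B_), B).T))  # True
print(np.allclose(dLdA, numeric_grad(lambda A_: loss(A_, B), A).T))  # True
```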

Forward Pass

Let's assume we have a simple MLP with three weight matrices: an input layer, a hidden layer, and an output layer.

The input is a column vector $x_0 \in \Reals^{v \times 1}$, where $v$ is a finite dimension, e.g. the vocabulary dimension.

The input layer maps the finite dimension $v$ to the width $d$ (which is taken to infinity) via the matrix $W_1 \in \Reals^{d \times v}$.

Assume all the non-linear activation functions $\sigma(\cdot)$ are applied point-wise, so e.g. $\sigma(W_1 x_0) \in \Reals^{d \times 1}$.

The output $x_1$ is defined as

$$x_1 = \sigma(W_1 x_0) \tag{1}$$

Similarly, the inter-layer weight $W_2 \in \Reals^{d \times d}$ transforms the input $x_1$ to the output $x_2$ by

$$x_2 = \sigma(W_2 x_1) \tag{2}$$

Lastly, the output layer matrix $W_3 \in \Reals^{v \times d}$ converts the width $d$ back to the finite dimension $v$.

$$x_3 = \sigma(W_3 x_2) \tag{3}$$

Add the MSE loss layer as

$$L = \Vert x_3 - y \Vert^2 \tag{4}$$
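
Here is a direct NumPy transcription of this forward pass as a sanity check; the width $d$, the non-linearity, and the initialization scales (chosen to match the coordinate sizes derived below) are placeholder assumptions:

```python
# A direct NumPy transcription of the forward pass, Eqs. (1)-(4). The width d,
# the finite dimension v, and the Gaussian initialization scales are placeholder
# choices just to make the snippet runnable.
import numpy as np

rng = np.random.default_rng(0)
v, d = 10, 1024
sigma = np.tanh                                 # any point-wise non-linearity

x0 = rng.standard_normal((v, 1))                # input, finite dimension v
y  = rng.standard_normal((v, 1))                # target

W1 = rng.standard_normal((d, v))                # input layer,  Theta(1) coordinates
W2 = rng.standard_normal((d, d)) / np.sqrt(d)   # hidden layer, Theta(1/sqrt(d))
W3 = rng.standard_normal((v, d)) / d            # output layer, Theta(1/d)

x1 = sigma(W1 @ x0)                             # Eq. (1)
x2 = sigma(W2 @ x1)                             # Eq. (2)
x3 = sigma(W3 @ x2)                             # Eq. (3)
L  = np.sum((x3 - y) ** 2)                      # Eq. (4)
print(x1.shape, x2.shape, x3.shape, L)
```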

Backward Pass

Starting from the loss $L$, we calculate $\frac{\partial L}{\partial x_3} \in \Reals^{1 \times v}$ as:

$$\frac{\partial L}{\partial x_3} = 2 (x_3 - y)^T \tag{5}$$

Following the three desiderata stated by Yang et al. [1], i.e.

  1. Every (pre)activation vector in the network should have $\Theta(1)$-sized coordinates.
  2. The neural network output should be $O(1)$.
  3. All parameters should be updated as much as possible (in terms of scaling in width) without leading to divergence.

Since $x_0, x_3 \in \Reals^{v \times 1}$ and $x_1, x_2 \in \Reals^{d \times 1}$ all have $\Theta(1)$-sized coordinates, we can see that $\frac{\partial L}{\partial x_3}$ has coordinate size $\Theta(1)$.

We know from Yang et al. [1] that the output layer weight $W_3$ should have coordinate size $\Theta(1/d)$ to make $x_3$ of size $\Theta(1)$. This is because $x_2$ and $W_3$ become correlated during training and $x_2$ has coordinates of size $\Theta(1)$; by the law of large numbers, each coordinate of $W_3 x_2$ is a sum of $d$ correlated $\Theta(1/d) \cdot \Theta(1)$ terms and hence of size $\Theta(1)$.

Output Layer

Calculate $\frac{\partial L}{\partial W_3}$ according to Eq. (0.9):

$$\begin{aligned} \frac{\partial L}{\partial W_3} &= x_2 \left(\frac{\partial L}{\partial x_3} \cdot {\sigma^{\prime}}^{T}(W_3 x_2)\right) \\ &= 2\, x_2 \left((x_3 - y)^T \cdot {\sigma^{\prime}}^{T}(W_3 x_2)\right) \end{aligned} \tag{6}$$

where $\cdot$ is point-wise multiplication. We define the vector $c_2 = 2 (x_3 - y) \cdot \sigma^{\prime}(W_3 x_2) \in \Reals^{v \times 1}$. It is easy to see that $c_2$ has coordinate size $\Theta(1)$, and $\frac{\partial L}{\partial W_3}$ simplifies to

$$\frac{\partial L}{\partial W_3} = x_2 c_2^T \tag{7}$$

According to the three desiderata, we want $(W_3 - \eta \Delta W_3) x_2$ to have coordinate size $\Theta(1)$, which implies both $W_3 x_2$ and $\eta \Delta W_3 x_2$ have coordinate size $\Theta(1)$. So $W_3$ has to be initialized with coordinate size $\Theta(1/d)$. As for the update, $\eta \Delta W_3 x_2 = \eta\, c_2\, x_2^T x_2$. For SGD, we need to set $\eta = 1/d$, so that $x_2^T x_2 / d$ has coordinate size $\Theta(1)$; note that the $x_2^T$ inside $\Delta W_3$ is the $x_2$ from the previous step, so $x_2^T x_2$ sums $d$ correlated $\Theta(1)$ terms and is $\Theta(d)$. For Adam, the normalized update has $\Theta(1)$ coordinates, so the learning rate also needs a constant factor of $1/d$ to cancel the $\Theta(d)$ from $x_2^T x_2$ and stay independent of $d$.
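
A quick numerical sketch of the law-of-large-numbers scaling used above; modeling the correlation between the previous-step and current-step $x_2$ by a shared component is my own simplification for illustration:

```python
# A width sweep illustrating why the output-layer SGD learning rate needs the
# 1/d factor: when the previous-step x2 and the current x2 are correlated
# (modeled here as equal up to a small perturbation -- an assumption for
# illustration), the inner product x2_prev^T x2_curr grows like Theta(d).
import numpy as np

rng = np.random.default_rng(0)
for d in [256, 1024, 4096, 16384]:
    x2_prev = rng.standard_normal((d, 1))                   # Theta(1) coordinates
    x2_curr = x2_prev + 0.1 * rng.standard_normal((d, 1))   # correlated with prev
    inner = float(x2_prev.T @ x2_curr)
    print(f"d={d:6d}  x2^T x2 = {inner:10.1f}   x2^T x2 / d = {inner / d:6.3f}")
# The raw inner product scales like d, while x2^T x2 / d stays O(1),
# so eta = Theta(1/d) keeps the update Delta(W3) x2 = c2 (x2^T x2) at Theta(1).
```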

Calculate $\frac{\partial L}{\partial x_2} \in \Reals^{1 \times d}$ according to Eq. (0.8):

$$\begin{aligned}\frac{\partial L}{\partial x_2} &= \left(\frac{\partial L}{\partial x_3} \cdot {\sigma^{\prime}}^{T}(W_3 x_2)\right) W_3 \\ &= c_2^T W_3 \end{aligned} \tag{8}$$

Since $v$ is finite and $W_3$ has coordinate size $\Theta(1/d)$, $\frac{\partial L}{\partial x_2}$ has coordinate size $\Theta(1/d)$.

Hidden Layer

Calculate $\frac{\partial L}{\partial W_2}$ according to Eq. (0.9):

$$\begin{aligned} \frac{\partial L}{\partial W_2} &= x_1 \left(\frac{\partial L}{\partial x_2} \cdot {\sigma^{\prime}}^{T}(W_2 x_1)\right) \\ &= x_1 \left((c_2^T W_3) \cdot {\sigma^{\prime}}^{T}(W_2 x_1)\right) \end{aligned} \tag{9}$$

We define the vector $c_1 = (c_2^T W_3)^T \cdot \sigma^{\prime}(W_2 x_1) \in \Reals^{d \times 1}$, which has coordinate size $\Theta(1/d)$.

According to the three desiderata, we want $(W_2 - \eta \Delta W_2) x_1$ to have coordinate size $\Theta(1)$, which implies both $W_2 x_1$ and $\eta \Delta W_2 x_1$ have coordinate size $\Theta(1)$. So $W_2$ has to be initialized with coordinate size $\Theta(1/\sqrt{d})$ (a central-limit scaling, since the freshly initialized $W_2$ is independent of $x_1$). As for the update, $\eta \Delta W_2 x_1 = \eta\, c_1\, x_1^T x_1$. For SGD, we can set $\eta = 1$: $c_1\, x_1^T x_1$ has coordinate size $\Theta(1)$ because $c_1$ is $\Theta(1/d)$ while $x_1^T x_1$ is $\Theta(d)$ (here the $x_1^T$ inside $\Delta W_2$ is the $x_1$ from the previous step, so the two are correlated). For Adam, the $\Theta(1/d)$ factor from $c_1$ in the gradient is canceled out by normalization, so the learning rate needs a constant factor of $1/d$ to counter the $\Theta(d)$ from $x_1^T x_1$.
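
In contrast, a freshly initialized $W_2$ is independent of $x_1$, so the central-limit scaling applies; here is a small width sweep illustrating why the $\Theta(1/\sqrt{d})$ initialization keeps $W_2 x_1$ at $\Theta(1)$ (an illustrative sketch, not the paper's experiment):

```python
# Unlike the correlated case above, the freshly initialized W2 is independent of
# x1, so a central-limit argument applies: with i.i.d. entries of standard
# deviation 1/sqrt(d), the coordinates of W2 x1 stay Theta(1) as the width grows.
import numpy as np

rng = np.random.default_rng(0)
for d in [256, 1024, 4096]:
    x1 = rng.standard_normal((d, 1))                   # Theta(1) coordinates
    W2 = rng.standard_normal((d, d)) / np.sqrt(d)      # Theta(1/sqrt(d)) coords
    h = W2 @ x1
    print(f"d={d:5d}  mean |(W2 x1)_i| = {np.mean(np.abs(h)):.3f}")
# The typical coordinate size of W2 x1 is roughly constant in d.
```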

Calculate $\frac{\partial L}{\partial x_1}$ according to Eq. (0.8):

$$\begin{aligned}\frac{\partial L}{\partial x_1} &= \left(\frac{\partial L}{\partial x_2} \cdot {\sigma^{\prime}}^{T}(W_2 x_1)\right) W_2 \\ &= c_1^T W_2 \end{aligned} \tag{10}$$

Input Layer

Calculate $\frac{\partial L}{\partial W_1}$ according to Eq. (0.9):

$$\begin{aligned} \frac{\partial L}{\partial W_1} &= x_0 \left(\frac{\partial L}{\partial x_1} \cdot {\sigma^{\prime}}^{T}(W_1 x_0)\right) \\ &= x_0 \left(c_1^T W_2 \cdot {\sigma^{\prime}}^{T}(W_1 x_0)\right) \end{aligned} \tag{11}$$

According to the three desiderata, we want $(W_1 - \eta \Delta W_1) x_0$ to have coordinate size $\Theta(1)$, which implies both $W_1 x_0$ and $\eta \Delta W_1 x_0$ have coordinate size $\Theta(1)$. So $W_1$ has to be initialized with coordinate size $\Theta(1)$, because the input dimension $v$ is finite. As for the update, $\eta \Delta W_1 x_0 = \eta \left(W_2^T c_1 \cdot \sigma^{\prime}(W_1 x_0)\right) x_0^T x_0$. Here $x_0^T x_0$ is of order $\Theta(1)$ because the dimension $v$ is finite (and $x_0^T$ is the $x_0$ from the previous step). To counter the extra $\Theta(1/d)$ carried by the $c_1$ term, for SGD we need to set $\eta = d$. For Adam, the learning rate is a constant $1$, because the $\Theta(1/d)$ factor from $c_1$ in the gradient is canceled out by normalization.

To see why ${x_2^{t-1}}^T$ at step $t-1$ correlates with $x_2^t$ at step $t$, calculate $x_2^t$ by the forward pass Eq. (2) and note that the weight matrix is $W_2^t = W_2^{t-1} - \eta\, c_1^{t-1} {x_1^{t-1}}^T$ according to Eq. (9).

$$\begin{aligned} x_2^t &= \sigma(W_2^{t} x_1^t) \\ &= \sigma\!\left(\left(W_2^{t-1} - \eta\, c_1^{t-1} {x_1^{t-1}}^T\right) x_1^t\right) \end{aligned}$$

And by the forward pass Eq. (2),

$$x_2^{t-1} = \sigma(W_2^{t-1} x_1^{t-1})$$

It is clear that both $x_2^t$ and $x_2^{t-1}$ depend on the same term $x_1^{t-1}$, so they are positively correlated.

Self-attention Layer

From the example above, we can see that the gradient with respect to a layer output has coordinate size $\Theta(1/d)$, because of the $\Theta(1/d)$ term passed down from the output layer in the backward pass. So $\frac{\partial L}{\partial O}$ has coordinate size $\Theta(1/d)$.

Ignoring multiple heads, assume the input is $x \in \Reals^{s \times d}$, where $s$ is the finite sequence dimension, and the matrices $K, Q, V \in \Reals^{d \times d}$. The self-attention layer is

$$O = \sigma\!\left(\frac{x Q K^T x^T}{d}\right) x V$$

Note that here we use $1/d$ (rather than the usual $1/\sqrt{d}$) to scale the attention score; as shown later, this is necessary.
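
A small numerical illustration of why $1/d$ is the right scale once $q$ and $k$ become correlated during training; modeling the correlation with a shared component is my own simplification:

```python
# Why attention logits are divided by d in muP: once q and k are correlated
# (modeled here by giving them a shared component -- an illustrative assumption),
# q.k grows like Theta(d), so q.k / sqrt(d) diverges with width while q.k / d
# stays O(1).
import numpy as np

rng = np.random.default_rng(0)
for d in [256, 1024, 4096, 16384]:
    shared = rng.standard_normal(d)
    q = shared + 0.5 * rng.standard_normal(d)   # Theta(1) coordinates, correlated
    k = shared + 0.5 * rng.standard_normal(d)
    logit = float(q @ k)
    print(f"d={d:6d}  q.k/sqrt(d) = {logit/np.sqrt(d):8.1f}   q.k/d = {logit/d:5.2f}")
# q.k/sqrt(d) keeps growing with d; q.k/d converges to a constant.
```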

Calculate $\frac{\partial L}{\partial \sigma}$ according to Eq. (0.2):

$$\frac{\partial L}{\partial \sigma} = x V\, \frac{\partial L}{\partial O} \tag{12}$$

Calculate $\frac{\partial L}{\partial (xQ)}$ according to Eq. (0.9):

$$\frac{\partial L}{\partial (xQ)} = \frac{K^T x^T}{d} \left(\frac{\partial L}{\partial \sigma} \cdot {\sigma^{\prime}}^T\right) \tag{13}$$

Calculate $\frac{\partial L}{\partial Q}$ according to Eq. (0.1) and substitute Eq. (12) and Eq. (13):

$$\begin{aligned} \frac{\partial L}{\partial Q} &= \frac{\partial L}{\partial (xQ)}\, x \\ &= \frac{K^T x^T}{d} \left(\frac{\partial L}{\partial \sigma} \cdot {\sigma^{\prime}}^T\right) x \\ &= \frac{K^T x^T}{d} \left(x V \frac{\partial L}{\partial O} \cdot {\sigma^{\prime}}^T\right) x \end{aligned} \tag{14}$$

According to the three desiderata, we want $x (Q - \eta \Delta Q)$ to have coordinate size $\Theta(1)$, which implies both $x Q$ and $\eta\, x \Delta Q$ have coordinate size $\Theta(1)$. So $Q$ has to be initialized with coordinate size $\Theta(1/\sqrt{d})$. As for the update, $\eta\, x \Delta Q = \eta\, x x^T \left({\sigma^{\prime}} \cdot \left(\tfrac{\partial L}{\partial O}\right)^{T} V^T x^T\right) \frac{x K}{d}$. The $V^T x^T x K$ factor is of order $\Theta(1)$ because the sequence dimension $s$ is finite, $\frac{x x^T}{d}$ is of order $\Theta(1)$ by the law of large numbers, and the $\frac{\partial L}{\partial O}$ term is of order $\Theta(1/d)$ as discussed before. So this is the same case as the hidden layer: for SGD, we need to set $\eta = 1$, because the $\frac{\partial L}{\partial O}$ term has order $\Theta(1/d)$ (note that the $x^T$ inside $\Delta Q$ is the $x$ from the previous step, so $x x^T$ follows the law-of-large-numbers scaling). For Adam, $\eta$ is set to a constant $1/d$, because the $\Theta(1/d)$ term in the gradient is canceled out by normalization.

The $\frac{\partial L}{\partial K}$ and $\frac{\partial L}{\partial V}$ cases are similar to the $\frac{\partial L}{\partial Q}$ case, so we can use the hidden-layer parameterization rules for them as well.

Layernorm and Bias

The layernorm layer is expressed mathematically as follows:

$$\mathrm{ln}(x) = \frac{x - \mu}{\sigma} \cdot G + B$$

where $G \in \Reals^{d \times 1}$ is the gain parameter used to re-scale the standardized summed inputs, and $B \in \Reals^{d \times 1}$ is the bias parameter, like any other bias term. The mean is $\mu = \frac{1}{d}\sum_{i=1}^{d} x_i$ and the standard deviation is $\sigma = \sqrt{\frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2}$. According to desideratum 1, the input $x$ has coordinates of magnitude $\Theta(1)$, independent of the width $d$. Both $\mu$ and $\sigma$ are of size $\Theta(1)$, and so is the standardized input $\bar{x} = \frac{x - \mu}{\sigma}$.

Calculate $\frac{\partial L}{\partial G}$:

$$\begin{aligned} \frac{\partial L}{\partial G} &= \frac{\partial L}{\partial \mathrm{ln}} \cdot \frac{\partial\, \mathrm{ln}}{\partial G} \\ &= \frac{\partial L}{\partial \mathrm{ln}} \cdot \frac{x - \mu}{\sigma} \\ &= \frac{\partial L}{\partial \mathrm{ln}} \cdot \bar{x} \end{aligned}$$

Calculate $\frac{\partial L}{\partial B}$:

$$\begin{aligned} \frac{\partial L}{\partial B} &= \frac{\partial L}{\partial \mathrm{ln}} \cdot \frac{\partial\, \mathrm{ln}}{\partial B} \\ &= \frac{\partial L}{\partial \mathrm{ln}} \cdot 1 \\ &= \frac{\partial L}{\partial \mathrm{ln}} \end{aligned}$$

We can view the bias term $B$ as the parameters of an input layer with constant input $1 \in \Reals^{1}$. All of our conclusions for the input layer apply here. Let's focus on the $\frac{x - \mu}{\sigma} \cdot G$ term.

According to the three desiderata, we want $\bar{x} \cdot (G - \eta \Delta G)$ to have coordinate size $\Theta(1)$, which implies both $\bar{x} \cdot G$ and $\eta\, \bar{x} \cdot \Delta G$ have coordinate size $\Theta(1)$. So $G$ has to be initialized with coordinate size $\Theta(1)$. As for the update, $\eta\, \bar{x} \cdot \Delta G = \eta\, \bar{x} \cdot \frac{\partial L}{\partial \mathrm{ln}} \cdot \bar{x}$. We know $\bar{x} \cdot \bar{x}$ has order $\Theta(1)$. As shown previously, $\frac{\partial L}{\partial \mathrm{ln}}$ is of order $\Theta(1/d)$ because of the term passed down from the readout layer. To counter it, for SGD we need to set $\eta = d$. For Adam, the learning rate is a constant $1$, because the $\Theta(1/d)$ factor from $\frac{\partial L}{\partial \mathrm{ln}}$ in the gradient is canceled out by normalization. We can see this is exactly the same case as the input layer.
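
To double-check the two layernorm gradient formulas used above, here is a small finite-difference test; the toy loss $L = \sum_i M_i\, \mathrm{ln}(x)_i$ (so that $\frac{\partial L}{\partial \mathrm{ln}} = M$) is an assumption made only for the check:

```python
# Finite-difference check of dL/dG = dL/dln * x_bar and dL/dB = dL/dln.
# The loss L = sum(M * ln(x)) is an illustrative choice so that dL/dln = M is
# known in closed form.
import numpy as np

rng = np.random.default_rng(0)
d = 64
x = rng.standard_normal(d)
G = rng.standard_normal(d)
B = rng.standard_normal(d)
M = rng.standard_normal(d)                       # plays the role of dL/dln

def layernorm(x, G, B):
    mu, sd = x.mean(), x.std()                   # population std, as in the text
    return (x - mu) / sd * G + B

def loss(G, B):
    return np.sum(M * layernorm(x, G, B))

x_bar = (x - x.mean()) / x.std()
dLdG = M * x_bar                                 # closed-form gradient
dLdB = M

def numeric_grad(f, X, eps=1e-6):
    g = np.zeros_like(X)
    for i in range(X.size):
        E = np.zeros_like(X); E[i] = eps
        g[i] = (f(X + E) - f(X - E)) / (2 * eps)
    return g

print(np.allclose(dLdG, numeric_grad(lambda G_: loss(G_, B), G)))  # True
print(np.allclose(dLdB, numeric_grad(lambda B_: loss(G, B_), B)))  # True
```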

Intuition about Lemma J.1

In the paper [1], the $ABC$ parameterization is not unique: different parameterizations can be converted from one to another according to Lemma J.1.

Lemma J.1. Let $f_t(x)$ denote the neural network function after $t$ steps of training (using any fixed sequence of batches), evaluated on input $x$. Consider a parameter tensor $W$ with learning rate $C$, initialized as $W \sim N(0, B^2)$, and with a multiplier $A$. Then for any $\theta > 0$, $f_t(x)$ stays fixed for all $t$ and $x$ if we set

  • when the optimizer is SGD,
    $A \gets A\theta, \quad B \gets B/\theta, \quad C \gets C/\theta^2$
  • when the optimizer is Adam,
    $A \gets A\theta, \quad B \gets B/\theta, \quad C \gets C/\theta$

To see why it is the case, let's write down the forward pass equation as

$$f_0(x) = \sigma(\alpha W_0 x)$$

where $\alpha$ is the multiplier constant. Note that $\alpha$ is set to the multiplier $A$, $W$ is initialized as a Gaussian random variable $\sim N(0, B^2)$, and the learning rate $\eta$ is set to $C$. After one step, it becomes:

$$\begin{aligned} f_1(x) &= \sigma(\alpha W_1 x) \\ &= \sigma(\alpha W_0 x - \eta\, \alpha \Delta W x) \end{aligned} \tag{15}$$

Using the MSE loss and following the same derivation as in Eq. (6),

$$\begin{aligned} \Delta W = \frac{\partial L}{\partial W} &= \alpha\, x \left(\frac{\partial L}{\partial f} \cdot {\sigma^{\prime}}^{T}(\alpha W x)\right) \\ &= 2\, \textcolor{red}{\alpha}\, x \left((f_0(x) - y)^T \cdot {\sigma^{\prime}}^{T}(\textcolor{green}{\alpha W x})\right) \end{aligned} \tag{16}$$

We can see that the $\textcolor{green}{\alpha W x}$ term is invariant under the $A \gets A\theta, B \gets B/\theta$ change, and that $\Delta W$ is scaled up by a factor of $\theta$ because of the extra $\textcolor{red}{\alpha}$ term. However, when using the Adam optimizer, this factor of $\theta$ is canceled out by the second-moment normalization.

Substituting Eq. (16) into Eq. (15), we notice the first term $\alpha W_0 x$ is invariant under the $A \gets A\theta, B \gets B/\theta$ change. The second term $\eta\, \textcolor{red}{\alpha}\, \textcolor{green}{\Delta W} x$ is invariant under $C \gets C/\theta^2$ for the SGD optimizer, because both $\textcolor{red}{\alpha}$ and $\textcolor{green}{\Delta W}$ carry a factor of $\theta$. For the Adam optimizer, $\textcolor{green}{\Delta W}$ carries a factor of $1$ due to the normalization, so we need $C \gets C/\theta$ to make it invariant.

We have seen the single-layer case. What about multiple layers? Let's calculate the gradient with respect to the lower layer's output (i.e., the current layer's input), as in Eq. (8).

$$\begin{aligned} \frac{\partial L}{\partial x} &= \alpha \left(\frac{\partial L}{\partial f} \cdot {\sigma^{\prime}}^{T}(\alpha W x)\right) W \\ &= 2\, \textcolor{red}{\alpha}\, (f_0(x) - y)^T \cdot {\sigma^{\prime}}^{T}(\textcolor{green}{\alpha W x})\, \textcolor{red}{W} \end{aligned} \tag{17}$$

It is easy to see that $\textcolor{red}{\alpha W}$ is invariant under $A \gets A\theta, B \gets B/\theta$, so the gradient passed down in the backward step is invariant under Lemma J.1. The same invariance argument for the single layer above applies to multi-layer neural networks recursively, by replacing $\frac{\partial L}{\partial f}$ with $\frac{\partial L}{\partial x}$ in Eq. (16) and Eq. (17).
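
To make Lemma J.1 concrete, here is a small SGD sketch of the single-layer model above: training with $(A, B, C)$ and with $(A\theta, B/\theta, C/\theta^2)$ from the same base Gaussian draw gives identical outputs at every step. The tiny shapes, $\theta = 7$, and the tanh non-linearity are illustrative choices:

```python
# A numerical sketch of Lemma J.1 for SGD on the single-layer model
# f(x) = sigma(alpha * W @ x) with MSE loss: rescaling (A, B, C) to
# (A*theta, B/theta, C/theta^2) leaves f_t(x) unchanged for every step t.
# The tiny shapes and the shared base Gaussian draw Z are illustrative choices.
import numpy as np

rng = np.random.default_rng(0)
v = 5
Z = rng.standard_normal((v, v))          # shared base draw, so W = B * Z ~ N(0, B^2)
x = rng.standard_normal((v, 1))
y = rng.standard_normal((v, 1))

def train(A, B, C, steps=5):
    alpha, W, eta = A, B * Z, C
    outs = []
    for _ in range(steps):
        z = alpha * (W @ x)
        f = np.tanh(z)
        # dL/dW for L = ||f - y||^2, same derivation as Eq. (16)
        grad = alpha * (2 * (f - y) * (1 - np.tanh(z) ** 2)) @ x.T
        W = W - eta * grad
        outs.append(np.tanh(alpha * (W @ x)))
    return outs

theta = 7.0
base     = train(A=1.0,         B=0.5,        C=0.01)
rescaled = train(A=1.0 * theta, B=0.5 / theta, C=0.01 / theta**2)
print(all(np.allclose(a, b) for a, b in zip(base, rescaled)))   # True
```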

References

  1. Yang, Greg, et al. "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer." arXiv preprint arXiv:2203.03466 (2022).
  2. Yang, Greg, and Edward J. Hu. "Feature learning in infinite-width neural networks." arXiv preprint arXiv:2011.14522 (2020).

Appendix

Alternative proof of Eq. (0.2)

$$\begin{aligned}\frac{\partial C_{ij}}{\partial A_{mn}} &= \frac{\partial \left(\sum_k A_{ik}B_{kj}\right)}{\partial A_{mn}} \\ &= \sum_k B_{kj} \delta_{im} \delta_{kn} \\ &= B_{nj} \delta_{im} \end{aligned} \tag{0.12}$$

$$\begin{aligned}\frac{\partial L}{\partial A_{mn}} &= \sum_{ij} \frac{\partial L}{\partial C_{ij}}\frac{\partial C_{ij}}{\partial A_{mn}} \\ &= \sum_{ij} \frac{\partial L}{\partial C_{ij}} B_{nj} \delta_{im} \\ &= \sum_{j} \frac{\partial L}{\partial C_{mj}} B_{nj} \end{aligned} \tag{0.13}$$

So Eq. (0.13) is equivalent to

$$\frac{\partial L}{\partial A} = B\, \frac{\partial L}{\partial C}$$

q.e.d.

Written on July 3, 2022