Vanilla Neural Network Backward Pass

This tutorial will show how to perform the backward pass of a vanilla neural network.

Backprop through a single layer

Let's say we have a neural network where one of the middle layers is \(z = aW + b\). Let's start with a batch size of 1, assume \(a\) has 3 features, and \(W\) has shape (3, 2); hence \(z\) has shape (1, 2): \[ a = \begin{bmatrix} a_1 & a_2 & a_3 \end{bmatrix} \] \[ W = \begin{bmatrix} w_1 & w_2\\ w_3 & w_4\\ w_5 & w_6 \end{bmatrix} \] \[ b = \begin{bmatrix} b_1 & b_2 \end{bmatrix} \] \[ z = aW + b = \begin{bmatrix} z_1 & z_2 \end{bmatrix} \] Assume we have already calculated the upstream gradient \(\dfrac{\partial L}{\partial z}\) (let's call it \(z'\)): \[\dfrac{\partial L}{\partial z} = z' = \begin{bmatrix} z'_1 & z'_2 \end{bmatrix} \] What are \(\dfrac{\partial L}{\partial W}\), \(\dfrac{\partial L}{\partial b}\), and \(\dfrac{\partial L}{\partial a}\)?

\(\dfrac{\partial L}{\partial W}\):

\(\dfrac{\partial z}{\partial W}\) would be a 3d tensor, so instead let's look at the derivative of each element of \(z\) with respect to \(W\): \[ \dfrac{\partial z_1}{\partial W} = \begin{bmatrix} a_1 & 0\\ a_2 & 0\\ a_3 & 0 \end{bmatrix} \qquad \dfrac{\partial z_2}{\partial W} = \begin{bmatrix} 0 & a_1\\ 0 & a_2\\ 0 & a_3 \end{bmatrix} \] The total derivative of loss with respect to \(W\) is: \[\dfrac{\partial L}{\partial W} = \dfrac{\partial L}{\partial z_1}\dfrac{\partial z_1}{\partial W} + \dfrac{\partial L}{\partial z_2}\dfrac{\partial z_2}{\partial W} \] So:

\(\dfrac{\partial L}{\partial W} = z'_1 \begin{bmatrix} a_1 & 0\\ a_2 & 0\\ a_3 & 0 \end{bmatrix} + z'_2 \begin{bmatrix} 0 & a_1\\ 0 & a_2\\ 0 & a_3 \end{bmatrix} \) \( = \begin{bmatrix} a_1 \\ a_2 \\ a_3 \end{bmatrix} \begin{bmatrix} z'_1 & z'_2 \end{bmatrix} = a^\top z' \)

If the batch size is greater than 1, the per-example contributions to \(\dfrac{\partial L}{\partial W}\) are automatically summed over the batch by the matrix multiplication \(a^\top z'\) (the first subscript denotes the example number in the batch): \[ a = \begin{bmatrix} a_{11} & a_{12} & a_{13}\\ a_{21} & a_{22} & a_{23} \end{bmatrix} \qquad z' = \begin{bmatrix} z'_{11} & z'_{12} \\ z'_{21} & z'_{22} \end{bmatrix} \]

\( a^\top z' = \begin{bmatrix} a_{11} & a_{21} \\ a_{12} & a_{22} \\ a_{13} & a_{23} \end{bmatrix} \begin{bmatrix} z'_{11} & z'_{12} \\ z'_{21} & z'_{22} \end{bmatrix} \) \( = \begin{bmatrix} a_{11}z'_{11} + a_{21}z'_{21}& a_{11}z'_{12} + a_{21}z'_{22} \\ a_{12}z'_{11} + a_{22}z'_{21}& a_{12}z'_{12} + a_{22}z'_{22} \\ a_{13}z'_{11} + a_{23}z'_{21}& a_{13}z'_{12} + a_{23}z'_{22} \end{bmatrix} \)

NOTE: If \(a\) and \(z'\) happen to be square, a shape check alone won't catch a missing or misplaced transpose, so double-check that you are computing \(a^\top z'\) and not some other product of the same shape.
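As a sanity check, here is a minimal NumPy sketch (the variable names `a`, `dz`, etc. are just illustrative) comparing the analytic gradient \(a^\top z'\) against a finite-difference estimate. We use the scalar "loss" \(L = \sum_{ij} z_{ij} z'_{ij}\) with \(z'\) held fixed, so that \(\partial L/\partial z = z'\) exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3))   # batch of 2 examples, 3 features
W = rng.standard_normal((3, 2))
b = rng.standard_normal(2)
dz = rng.standard_normal((2, 2))  # upstream gradient dL/dz (z' in the text)

# Analytic gradient: dL/dW = a^T z'
dW = a.T @ dz

# Central finite-difference check, perturbing one weight at a time
eps = 1e-6
dW_num = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        Wp = W.copy(); Wp[i, j] += eps
        Wm = W.copy(); Wm[i, j] -= eps
        Lp = np.sum((a @ Wp + b) * dz)  # loss with W[i,j] nudged up
        Lm = np.sum((a @ Wm + b) * dz)  # loss with W[i,j] nudged down
        dW_num[i, j] = (Lp - Lm) / (2 * eps)

print(np.allclose(dW, dW_num, atol=1e-5))  # True
```

Note that `dW` has the same shape as `W`, as a parameter gradient must.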

\(\dfrac{\partial L}{\partial b}\):

\(\dfrac{\partial L}{\partial b}\) is very straightforward if we follow our method above: \[\dfrac{\partial z_1}{\partial b} = \begin{bmatrix} 1 & 0 \end{bmatrix} \qquad \dfrac{\partial z_2}{\partial b} = \begin{bmatrix} 0 & 1 \end{bmatrix} \]

\(\dfrac{\partial L}{\partial b} = \dfrac{\partial L}{\partial z_1}\dfrac{\partial z_1}{\partial b} + \dfrac{\partial L}{\partial z_2}\dfrac{\partial z_2}{\partial b} \) \( = z'_1 \begin{bmatrix} 1 & 0 \end{bmatrix} + z'_2 \begin{bmatrix} 0 & 1 \end{bmatrix} = \begin{bmatrix} z'_1 & z'_2 \end{bmatrix} \)

What if the batch size is greater than 1? We'll use a batch size of 2 as we did for \(\dfrac{\partial L}{\partial W}\): \[ z = \begin{bmatrix} z_{11} & z_{12} \\ z_{21} & z_{22} \end{bmatrix} \] Then

\(\dfrac{\partial z_{11}}{\partial b} = \begin{bmatrix} 1 & 0 \end{bmatrix} \)
\( \dfrac{\partial z_{12}}{\partial b} = \begin{bmatrix} 0 & 1 \end{bmatrix} \)
\( \dfrac{\partial z_{21}}{\partial b} = \begin{bmatrix} 1 & 0 \end{bmatrix} \)
\( \dfrac{\partial z_{22}}{\partial b} = \begin{bmatrix} 0 & 1 \end{bmatrix} \)

\(\dfrac{\partial L}{\partial b} = \dfrac{\partial L}{\partial z_{11}}\dfrac{\partial z_{11}}{\partial b} + \dfrac{\partial L}{\partial z_{12}}\dfrac{\partial z_{12}}{\partial b} \) \( + \dfrac{\partial L}{\partial z_{21}}\dfrac{\partial z_{21}}{\partial b} + \dfrac{\partial L}{\partial z_{22}}\dfrac{\partial z_{22}}{\partial b}\)

\( = z'_{11} \begin{bmatrix} 1 & 0 \end{bmatrix} + z'_{12} \begin{bmatrix} 0 & 1 \end{bmatrix} + z'_{21} \begin{bmatrix} 1 & 0 \end{bmatrix} \) \( + z'_{22} \begin{bmatrix} 0 & 1 \end{bmatrix} = \begin{bmatrix} z'_{11} + z'_{21} & z'_{12} + z'_{22} \end{bmatrix} \)

Thus, \(\dfrac{\partial L}{\partial b}\) is the sum of \(\dfrac{\partial L}{\partial z}\) over the batch dimension.
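In code this is a single sum over the batch axis (a minimal NumPy sketch; `dz` stands for the upstream gradient \(z'\)):

```python
import numpy as np

dz = np.array([[1.0, 2.0],
               [3.0, 4.0]])  # dL/dz for a batch of 2 examples

# dL/db: sum the upstream gradient over the batch dimension (axis 0)
db = dz.sum(axis=0)
print(db)  # [4. 6.]
```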

\(\dfrac{\partial L}{\partial a}\):

Continuing with our method of taking the derivative of one element of \(z\) at a time, this time with respect to \(a\):

\(\dfrac{\partial z_1}{\partial a} = \begin{bmatrix} w_1 & w_3 & w_5 \end{bmatrix} \)
\( \dfrac{\partial z_2}{\partial a} = \begin{bmatrix} w_2 & w_4 & w_6 \end{bmatrix} \)

\(\dfrac{\partial L}{\partial a} = \dfrac{\partial L}{\partial z_{1}}\dfrac{\partial z_{1}}{\partial a} + \dfrac{\partial L}{\partial z_{2}}\dfrac{\partial z_{2}}{\partial a} \) \( = z'_{1} \begin{bmatrix} w_1 & w_3 & w_5 \end{bmatrix} + z'_{2} \begin{bmatrix} w_2 & w_4 & w_6 \end{bmatrix} \) \( = z'W^\top \)

Writing this product out confirms it:

\( z'W^\top = \begin{bmatrix} z'_1 & z'_2 \end{bmatrix} \begin{bmatrix} w_1 & w_3 & w_5\\ w_2 & w_4 & w_6\\ \end{bmatrix} \) \( = \begin{bmatrix} a'_1 & a'_2 & a'_3 \end{bmatrix} = \dfrac{\partial L}{\partial a} \)

If the batch size is greater than 1, then this same calculation will automatically be carried out for each example by the matrix multiplication, and the output will have multiple rows:

\( z'W^\top = \begin{bmatrix} z'_{11} & z'_{12}\\ z'_{21} & z'_{22} \end{bmatrix} \begin{bmatrix} w_1 & w_3 & w_5\\ w_2 & w_4 & w_6\\ \end{bmatrix} \) \( = \begin{bmatrix} a'_{11} & a'_{12} & a'_{13}\\ a'_{21} & a'_{22} & a'_{23} \end{bmatrix} = \dfrac{\partial L}{\partial a} \)
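The same calculation as a NumPy sketch (illustrative names; `dz` is the upstream gradient \(z'\)). Note the shapes: \(z'\) is (batch, 2) and \(W^\top\) is (2, 3), so the result is (batch, 3), matching \(a\):

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.standard_normal((3, 2))
dz = rng.standard_normal((2, 2))  # upstream gradient, batch of 2

# dL/da = z' W^T : one row per example, one column per input feature
da = dz @ W.T
print(da.shape)  # (2, 3)
```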

Activation Function Derivative:

Now that we've covered how to compute the gradient through the linear layer, let's see how to perform backprop through the non-linear activation function. For this example we have an output from a linear layer with a batch size of two called \(z\), and we put that through a ReLU activation and get \(a\):

\( z = \begin{bmatrix} z_{11} & z_{12} \\ z_{21} & z_{22} \end{bmatrix} \)
\( a = \begin{bmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix} = ReLU(z)\)

Let's assume we already have \(\dfrac{\partial L}{\partial a}\):

\[\dfrac{\partial L}{\partial a} = \begin{bmatrix} a'_{11} & a'_{12} \\ a'_{21} & a'_{22} \end{bmatrix} \]

What is \(\dfrac{\partial L}{\partial z}\)? Since \(a\) and \(z\) are both matrices, \(\dfrac{\partial a}{\partial z}\) would be a 4d tensor. So, as we've done before, let's take the derivative of a single element of \(a\) with respect to \(z\) to find the pattern:

\( \dfrac{\partial a_{11}}{\partial z} = \begin{bmatrix} ReLU\_d(z_{11}) & 0 \\ 0 & 0 \end{bmatrix} \)
\( \dfrac{\partial a_{12}}{\partial z} = \begin{bmatrix} 0 & ReLU\_d(z_{12})\\ 0 & 0 \end{bmatrix} \)
\( \dfrac{\partial a_{21}}{\partial z} = \begin{bmatrix} 0 & 0 \\ ReLU\_d(z_{21}) & 0 \end{bmatrix} \)
\( \dfrac{\partial a_{22}}{\partial z} = \begin{bmatrix} 0 & 0\\ 0 & ReLU\_d(z_{22}) \end{bmatrix} \)

Where \(ReLU\_d\) is the ReLU derivative: \(1\) if \(z>0\), \(0\) otherwise. Thus

\( \dfrac{\partial L}{\partial z} = \dfrac{\partial L}{\partial a_{11}}\dfrac{\partial a_{11}}{\partial z} + \dfrac{\partial L}{\partial a_{12}}\dfrac{\partial a_{12}}{\partial z} \) \( + \dfrac{\partial L}{\partial a_{21}}\dfrac{\partial a_{21}}{\partial z} + \dfrac{\partial L}{\partial a_{22}}\dfrac{\partial a_{22}}{\partial z} \)

\[ = \begin{bmatrix} a'_{11}*ReLU\_d(z_{11}) & a'_{12}*ReLU\_d(z_{12}) \\ a'_{21}*ReLU\_d(z_{21}) & a'_{22}*ReLU\_d(z_{22}) \end{bmatrix} \] \[ = a' \circ ReLU\_d(z) \] Where \(\circ\) is element-wise multiplication!
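This element-wise rule is one line of NumPy: the boolean mask `z > 0` is exactly \(ReLU\_d(z)\) (a sketch with made-up values; `da` is the upstream gradient \(a'\)):

```python
import numpy as np

z = np.array([[ 1.0, -2.0],
              [-0.5,  3.0]])
da = np.array([[10.0, 20.0],
               [30.0, 40.0]])  # upstream gradient dL/da

# ReLU'(z) is 1 where z > 0 and 0 elsewhere; multiply element-wise
dz = da * (z > 0)
print(dz)  # [[10.  0.]
           #  [ 0. 40.]]
```

Gradients are passed through wherever the forward pass let the value through, and zeroed wherever ReLU clipped it.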