Convolutional Neural Network Backward Pass

This tutorial derives the backward pass of a convolutional neural network: the gradients through a convolutional layer and through a max pooling layer.

Backprop through a single layer

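Let's say we have a neural network where one of the middle layers is a convolution \(Z = W\circledast A + b\), where \(\circledast\) is the convolution operator. As is usual in deep learning, this "convolution" slides the kernel over the input without flipping it (strictly speaking, a cross-correlation), here with stride 1 and no padding. Let's start with a batch size of 1, and assume \(A\) has 1 channel with a height and width of 3, and \(W\) is a single filter of shape (2, 2); hence \(Z\) has shape (2, 2):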

\( A = \begin{bmatrix} a_1 & a_2 & a_3\\ a_4 & a_5 & a_6\\ a_7 & a_8 & a_9 \end{bmatrix} \)
\( W = \begin{bmatrix} w_1 & w_2\\ w_3 & w_4 \end{bmatrix} \)
\( b = \begin{bmatrix} b_1 \end{bmatrix} \)

\( Z = \begin{bmatrix} a_1w_1 + a_2w_2 + a_4w_3 + a_5w_4 & a_2w_1 + a_3w_2 + a_5w_3 + a_6w_4 \\ a_4w_1 + a_5w_2 + a_7w_3 + a_8w_4 & a_5w_1 + a_6w_2 + a_8w_3 + a_9w_4 \end{bmatrix} + b_1 \) \( = \begin{bmatrix} z_1 & z_2\\ z_3 & z_4 \end{bmatrix} \)
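To make the indexing concrete, here is a minimal NumPy sketch of this forward pass for a single channel. It assumes stride 1, no padding, and the deep-learning convention that the kernel slides over the input without being flipped; the helper name `conv2d_forward` and the sample values are illustrative only.

```python
import numpy as np


def conv2d_forward(A, W, b):
    """Single-channel, stride-1, unpadded convolution (cross-correlation) plus a scalar bias."""
    k = W.shape[0]
    out_h = A.shape[0] - k + 1
    out_w = A.shape[1] - k + 1
    Z = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Each z is the sum of W multiplied elementwise by the k x k patch of A it covers
            Z[i, j] = np.sum(A[i:i + k, j:j + k] * W) + b
    return Z


A = np.arange(1.0, 10.0).reshape(3, 3)  # a_1..a_9 = 1..9, laid out row by row
W = np.array([[1.0, 2.0], [3.0, 4.0]])  # w_1..w_4 = 1..4
Z = conv2d_forward(A, W, b=0.5)
# z_1 = a_1*w_1 + a_2*w_2 + a_4*w_3 + a_5*w_4 + b_1 = 1 + 4 + 12 + 20 + 0.5 = 37.5
print(Z)  # [[37.5 47.5]
          #  [67.5 77.5]]
```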

Assume we have already calculated the upstream gradient \(\dfrac{\partial L}{\partial Z}\) (let's call it \(Z'\)): \[\dfrac{\partial L}{\partial Z} = Z' = \begin{bmatrix} z'_1 & z'_2\\ z'_3 & z'_4 \end{bmatrix} \] What are \(\dfrac{\partial L}{\partial W}\), \(\dfrac{\partial L}{\partial A}\), and \(\dfrac{\partial L}{\partial b}\)?

\(\dfrac{\partial L}{\partial W}\):

\(\dfrac{\partial Z}{\partial W}\) would be a 4-D tensor (every element of \(Z\) differentiated with respect to every element of \(W\)), so instead let's look at the derivative of each element of \(Z\) with respect to \(W\):

\( \dfrac{\partial z_1}{\partial W} = \begin{bmatrix} a_1 & a_2\\ a_4 & a_5 \end{bmatrix} \)
\( \dfrac{\partial z_2}{\partial W} = \begin{bmatrix} a_2 & a_3\\ a_5 & a_6 \end{bmatrix} \)
\( \dfrac{\partial z_3}{\partial W} = \begin{bmatrix} a_4 & a_5\\ a_7 & a_8 \end{bmatrix} \)
\( \dfrac{\partial z_4}{\partial W} = \begin{bmatrix} a_5 & a_6\\ a_8 & a_9 \end{bmatrix} \)

By the chain rule, the total derivative of the loss with respect to \(W\) sums the contributions of every element of \(Z\):

\(\dfrac{\partial L}{\partial W} = \dfrac{\partial L}{\partial z_1}\dfrac{\partial z_1}{\partial W} + \dfrac{\partial L}{\partial z_2}\dfrac{\partial z_2}{\partial W} \) \( + \dfrac{\partial L}{\partial z_3}\dfrac{\partial z_3}{\partial W} + \dfrac{\partial L}{\partial z_4}\dfrac{\partial z_4}{\partial W} \)

So:

\(\dfrac{\partial L}{\partial W} = z'_1 \begin{bmatrix} a_1 & a_2\\ a_4 & a_5 \end{bmatrix} + z'_2 \begin{bmatrix} a_2 & a_3\\ a_5 & a_6 \end{bmatrix} \) \( + z'_3 \begin{bmatrix} a_4 & a_5\\ a_7 & a_8 \end{bmatrix} + z'_4 \begin{bmatrix} a_5 & a_6\\ a_8 & a_9 \end{bmatrix} = Z' \circledast A \)
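A quick numerical check of \(\dfrac{\partial L}{\partial W} = Z' \circledast A\), as a NumPy sketch (NumPy 1.20+ for `sliding_window_view`). It assumes a surrogate loss \(L = \sum_i z'_i z_i\), chosen only because its upstream gradient is exactly \(Z'\); the random values and helper names are illustrative. The bias is left out because it does not affect \(\dfrac{\partial L}{\partial W}\).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_valid(X, K):
    """Slide K over X (stride 1, no padding, no kernel flip) and take dot products."""
    windows = sliding_window_view(X, K.shape)  # windows[i, j] is the patch under K at (i, j)
    return np.einsum('ijpq,pq->ij', windows, K)


rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
W = rng.standard_normal((2, 2))
Z_grad = rng.standard_normal((2, 2))  # stands in for the upstream gradient Z'

# dL/dW: use Z' as the kernel and slide it over A
dW = conv2d_valid(A, Z_grad)


def loss(W_):
    # Surrogate loss L = sum_i z'_i * z_i, whose gradient with respect to Z is exactly Z'
    return np.sum(Z_grad * conv2d_valid(A, W_))


# Central differences; L is linear in W, so a large step is still exact up to rounding
eps = 1e-3
dW_numeric = np.zeros_like(W)
for idx in np.ndindex(W.shape):
    step = np.zeros_like(W)
    step[idx] = eps
    dW_numeric[idx] = (loss(W + step) - loss(W - step)) / (2 * eps)

print(np.allclose(dW, dW_numeric))  # True
```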

Thus, we simply convolve \(Z'\) over \(A\). The more general cases, where the input has multiple channels, there are multiple filters, or the batch size is greater than one, can be broken down the same way to find the pattern. If input \(A\) has \(C\) channels and \(W\) has \(F\) filters (and thus \(Z\) has \(F\) channels), then to compute \(\frac{\partial L}{\partial W_{cf}}\) (i.e. the gradient of the loss with respect to the \(c^{th}\) channel of the \(f^{th}\) filter), convolve the \(f^{th}\) channel of \(Z'\) over the \(c^{th}\) channel of \(A\), and sum the results over the batch dimension.
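The general rule fits in a single `np.einsum` over sliding windows. This NumPy sketch (NumPy 1.20+) assumes an NCHW layout (batch, channels, height, width), filters stored as (F, C, k, k), stride 1 and no padding; `conv_forward` and `conv_weight_grad` are hypothetical helper names, and the finite-difference check again uses the surrogate loss \(L = \sum Z' \odot Z\) (sum of elementwise products).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv_forward(A, W):
    """A: (N, C, height, width), W: (F, C, k, k) -> Z: (N, F, height-k+1, width-k+1)."""
    windows = sliding_window_view(A, W.shape[2:], axis=(2, 3))  # (N, C, Ho, Wo, k, k)
    return np.einsum('ncijpq,fcpq->nfij', windows, W)


def conv_weight_grad(A, Z_grad, k):
    """dL/dW with shape (F, C, k, k), given dL/dZ with shape (N, F, Ho, Wo)."""
    windows = sliding_window_view(A, (k, k), axis=(2, 3))  # (N, C, Ho, Wo, k, k)
    # Channel f of Z' slides over channel c of A; the result is summed over the batch index n
    return np.einsum('ncijpq,nfij->fcpq', windows, Z_grad)


rng = np.random.default_rng(1)
N, C, F, H, k = 2, 3, 4, 5, 3
A = rng.standard_normal((N, C, H, H))
W = rng.standard_normal((F, C, k, k))
Z_grad = rng.standard_normal((N, F, H - k + 1, H - k + 1))
dW = conv_weight_grad(A, Z_grad, k)


def loss(W_):
    return np.sum(Z_grad * conv_forward(A, W_))  # surrogate loss with dL/dZ = Z_grad


eps = 1e-3  # L is linear in W, so central differences are exact up to rounding
dW_numeric = np.zeros_like(W)
for idx in np.ndindex(W.shape):
    step = np.zeros_like(W)
    step[idx] = eps
    dW_numeric[idx] = (loss(W + step) - loss(W - step)) / (2 * eps)

print(dW.shape, np.allclose(dW, dW_numeric))  # (4, 3, 3, 3) True
```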

\(\dfrac{\partial L}{\partial A}\):

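Start again with the simplest case: \(A\) has 1 channel and there is one filter, as before. Each element of \(Z\) depends only on the four elements of \(A\) under the kernel at that position: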

\( \dfrac{\partial z_1}{\partial A} = \begin{bmatrix} w_1 & w_2 & 0\\ w_3 & w_4 & 0\\ 0 & 0 & 0 \end{bmatrix} \)
\( \dfrac{\partial z_2}{\partial A} = \begin{bmatrix} 0 &w_1 & w_2 \\ 0 & w_3 & w_4\\ 0 & 0 & 0 \end{bmatrix} \)
\( \dfrac{\partial z_3}{\partial A} = \begin{bmatrix} 0 & 0 & 0\\ w_1 & w_2 & 0\\ w_3 & w_4 & 0 \end{bmatrix} \)
\( \dfrac{\partial z_4}{\partial A} = \begin{bmatrix} 0 & 0 & 0\\ 0 &w_1 & w_2 \\ 0 & w_3 & w_4 \end{bmatrix} \)
Applying the chain rule as for \(W\):

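\(\dfrac{\partial L}{\partial A} = z'_1 \begin{bmatrix} w_1 & w_2 & 0\\ w_3 & w_4 & 0\\ 0 & 0 & 0 \end{bmatrix} + z'_2 \begin{bmatrix} 0 &w_1 & w_2 \\ 0 & w_3 & w_4\\ 0 & 0 & 0 \end{bmatrix} \) \( + z'_3 \begin{bmatrix} 0 & 0 & 0\\ w_1 & w_2 & 0\\ w_3 & w_4 & 0 \end{bmatrix}+ z'_4 \begin{bmatrix} 0 & 0 & 0\\ 0 &w_1 & w_2 \\ 0 & w_3 & w_4 \end{bmatrix} \) \( = \begin{bmatrix} z'_1w_1 & z'_1w_2 + z'_2w_1 & z'_2w_2\\ z'_1w_3 + z'_3w_1 & z'_1w_4 + z'_2w_3 + z'_3w_2 + z'_4w_1 & z'_2w_4 + z'_4w_2\\ z'_3w_3 & z'_3w_4 + z'_4w_3 & z'_4w_4 \end{bmatrix} \) \( = W_{180} \circledast Z'_{\text{pad}=k-1} \)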

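So we simply rotate the kernel by 180 degrees and convolve it over \(Z'\) zero-padded by \(k-1\) on every side, where \(k\) is the kernel size. Here \(k = 2\), so one ring of zeros turns the 2x2 \(Z'\) into a 4x4 array, and sliding the flipped 2x2 kernel over it gives a 3x3 result, the same shape as \(A\). See Figure 1 below.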

Figure 1: Sliding the flipped kernel over the padded \(Z'\)
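The flip-and-pad recipe as a NumPy sketch (NumPy 1.20+), checked against the sum of the four shifted-kernel matrices above. The values \(w_1..w_4 = 1..4\) and \(z'_1..z'_4 = 10, 20, 30, 40\) are arbitrary, and `conv2d_valid` / `conv2d_input_grad` are hypothetical helper names.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_valid(X, K):
    """Slide K over X (stride 1, no padding, no kernel flip) and take dot products."""
    return np.einsum('ijpq,pq->ij', sliding_window_view(X, K.shape), K)


def conv2d_input_grad(Z_grad, W):
    """dL/dA: slide the 180-degree-rotated kernel over Z' zero-padded by k-1 on every side."""
    k = W.shape[0]
    W_180 = np.rot90(W, 2)          # same as W[::-1, ::-1]
    Z_pad = np.pad(Z_grad, k - 1)   # constant (zero) padding on all four sides
    return conv2d_valid(Z_pad, W_180)


W = np.array([[1.0, 2.0], [3.0, 4.0]])           # w_1..w_4
Z_grad = np.array([[10.0, 20.0], [30.0, 40.0]])  # z'_1..z'_4
dA = conv2d_input_grad(Z_grad, W)

# Reference: z'_1 dz_1/dA + ... + z'_4 dz_4/dA, each term being W placed over z_i's window
ref = np.zeros((3, 3))
for i in range(2):
    for j in range(2):
        ref[i:i + 2, j:j + 2] += Z_grad[i, j] * W

print(np.allclose(dA, ref))  # True
print(dA.tolist())  # [[10.0, 40.0, 40.0], [60.0, 200.0, 160.0], [90.0, 240.0, 160.0]]
```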

If the batch size is greater than one, this convolution is done independently for each example in the batch and the batch dimension is retained.

If the input has more than one channel (and hence each filter has more than one channel), each channel of the flipped filter is convolved over the padded \(Z'\), giving one output channel per input channel, so the channel dimension is retained (this can be done in one step with broadcasting).

If there are multiple filters (and hence \(Z'\) has multiple channels), the convolutions above are applied separately for each filter over its corresponding channel of \(Z'\), and the results are summed over the filter dimension. For example, if input \(A\) is 1x1x3 (height x width x channels) and there are two filters, each of shape 1x1x3, then \(Z'\) is 1x1x2. \(A' = \dfrac{\partial L}{\partial A}\) is calculated by multiplying the first channel of \(Z'\) by all channels of the first filter, multiplying the second channel of \(Z'\) by all channels of the second filter, and then summing these two terms.
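Batch, channels and filters can all be handled at once. This NumPy sketch (NumPy 1.20+) assumes an NCHW layout (batch, channels, height, width), unlike the height x width x channels example above, with filters stored as (F, C, k, k), stride 1 and no padding. It flips each filter, pads only the spatial axes of \(Z'\), keeps the batch axis, and sums over the filter index. The helper names are hypothetical, and the result is checked by finite differences on the surrogate loss \(L = \sum Z' \odot Z\).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv_forward(A, W):
    """A: (N, C, height, width), W: (F, C, k, k) -> Z: (N, F, height-k+1, width-k+1)."""
    windows = sliding_window_view(A, W.shape[2:], axis=(2, 3))
    return np.einsum('ncijpq,fcpq->nfij', windows, W)


def conv_input_grad(Z_grad, W):
    """dL/dA with shape (N, C, height, width), given dL/dZ with shape (N, F, Ho, Wo)."""
    k = W.shape[2]
    W_180 = W[:, :, ::-1, ::-1]  # rotate every channel of every filter by 180 degrees
    pad = ((0, 0), (0, 0), (k - 1, k - 1), (k - 1, k - 1))  # pad height and width only
    windows = sliding_window_view(np.pad(Z_grad, pad), (k, k), axis=(2, 3))
    # Channel c of flipped filter f slides over channel f of Z' (batch kept), then sum over f
    return np.einsum('nfijpq,fcpq->ncij', windows, W_180)


rng = np.random.default_rng(2)
N, C, F, H, k = 2, 3, 4, 5, 3
A = rng.standard_normal((N, C, H, H))
W = rng.standard_normal((F, C, k, k))
Z_grad = rng.standard_normal((N, F, H - k + 1, H - k + 1))
dA = conv_input_grad(Z_grad, W)


def loss(A_):
    return np.sum(Z_grad * conv_forward(A_, W))  # surrogate loss with dL/dZ = Z_grad


eps = 1e-3  # L is linear in A, so central differences are exact up to rounding
dA_numeric = np.zeros_like(A)
for idx in np.ndindex(A.shape):
    step = np.zeros_like(A)
    step[idx] = eps
    dA_numeric[idx] = (loss(A + step) - loss(A - step)) / (2 * eps)

print(dA.shape, np.allclose(dA, dA_numeric))  # (2, 3, 5, 5) True
```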

\(\dfrac{\partial L}{\partial b}\):

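Because \(b_1\) is added to every element of \(Z\), \(\dfrac{\partial z_i}{\partial b_1} = 1\) for every \(i\), so \(\dfrac{\partial L}{\partial b_1} = z'_1 + z'_2 + z'_3 + z'_4\). In general, simply sum \(Z'\) over all dimensions except the one corresponding to the number of filters (i.e. the result will be 1-D, with a length equal to the number of filters).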
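A minimal sketch of the bias gradient, assuming \(Z'\) is laid out as (batch, filters, height, width); `bias_grad` is a hypothetical helper name and the sample values are arbitrary.

```python
import numpy as np


def bias_grad(Z_grad):
    """dL/db for Z' laid out as (N, F, height, width): one entry per filter."""
    # Each b_f is added to every position of output channel f in every example,
    # so dz/db_f = 1 there and the chain rule reduces to a plain sum.
    return Z_grad.sum(axis=(0, 2, 3))


# Single-filter example from above: dL/db_1 = z'_1 + z'_2 + z'_3 + z'_4
Z_grad = np.array([[[[0.5, -1.0], [2.0, 0.25]]]])  # shape (1, 1, 2, 2)
print(bias_grad(Z_grad))  # [1.75]

# Batched, multi-filter case: the result is 1-D with one entry per filter
rng = np.random.default_rng(3)
print(bias_grad(rng.standard_normal((8, 4, 3, 3))).shape)  # (4,)
```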

Pooling Derivative:

Now we'll take a look at how to calculate the derivative through a pooling layer, specifically max pooling.

Let's take a 3x3 input, \(Z\), and apply a max pooling operation with a 2x2 kernel size and a stride of 1. This will produce a 2x2 output, \(A\) = MaxPool(\(Z\)). Also, we'll assume we have the derivative of loss with respect to \(A\), \(\dfrac{\partial L}{\partial A} = A'\):

\( Z = \begin{bmatrix} 1 & 2 & 3\\ 4 & 10 & 5\\ 11 & 12 & 13 \end{bmatrix} \)
\( A = \begin{bmatrix} 10 & 10\\ 12 & 13 \end{bmatrix} \)
\( A' = \begin{bmatrix} a'_1 & a'_2\\ a'_3 & a'_4 \end{bmatrix} \)

Let's calculate the derivative of each element of \(A\) with respect to \(Z\):

\( \dfrac{\partial a_1}{\partial Z} = \begin{bmatrix} 0 & 0 & 0\\ 0 & 1 & 0\\ 0 & 0 & 0 \end{bmatrix} \)
\( \dfrac{\partial a_2}{\partial Z} = \begin{bmatrix} 0 & 0 & 0\\ 0 & 1 & 0\\ 0 & 0 & 0 \end{bmatrix} \)
\( \dfrac{\partial a_3}{\partial Z} = \begin{bmatrix} 0 & 0 & 0\\ 0 & 0 & 0\\ 0 & 1 & 0 \end{bmatrix} \)
\( \dfrac{\partial a_4}{\partial Z} = \begin{bmatrix} 0 & 0 & 0\\ 0 & 0 & 0\\ 0 & 0 & 1 \end{bmatrix} \)

Now we can calculate the total derivative of loss with respect to \(Z\), \(\dfrac{\partial L}{\partial Z}\):

\(\dfrac{\partial L}{\partial Z} = \dfrac{\partial L}{\partial a_1}\dfrac{\partial a_1}{\partial Z} + \dfrac{\partial L}{\partial a_2}\dfrac{\partial a_2}{\partial Z} \) \( + \dfrac{\partial L}{\partial a_3}\dfrac{\partial a_3}{\partial Z} + \dfrac{\partial L}{\partial a_4}\dfrac{\partial a_4}{\partial Z} \)

\(= a'_1 \begin{bmatrix} 0 & 0 & 0\\ 0 & 1 & 0\\ 0 & 0 & 0 \end{bmatrix} + a'_2 \begin{bmatrix} 0 & 0 & 0 \\ 0 & 1 & 0\\ 0 & 0 & 0 \end{bmatrix} \) \( + a'_3 \begin{bmatrix} 0 & 0 & 0\\ 0 & 0 & 0\\ 0 & 1 & 0 \end{bmatrix}+ a'_4 \begin{bmatrix} 0 & 0 & 0\\ 0 & 0 & 0 \\ 0 & 0 & 1 \end{bmatrix} \) \( = \begin{bmatrix} 0 & 0 & 0\\ 0 & a'_1 + a'_2 & 0 \\ 0 & a'_3 & a'_4 \end{bmatrix} \)

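Thus, to calculate the max pool gradient, do the following:

Initialize \(Z'\) (i.e. \(\dfrac{\partial L}{\partial Z}\)) as all zeros with the same shape as \(Z\). Then, for each element \(a_i\) of \(A\), find the position of the maximum of \(Z\) within the window that produced \(a_i\), and add \(a'_i\) (the element of \(A'\) at the same location as \(a_i\)) to \(Z'\) at that position. Where overlapping windows share a maximum, their gradients accumulate, as with \(a'_1 + a'_2\) at the center of the example above.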
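The recipe above as a plain NumPy sketch for a single 2-D input; `maxpool2d_forward` and `maxpool2d_backward` are hypothetical names. It records the argmax position of each window during the forward pass and reuses it in the backward pass. When a window contains ties, `np.argmax` picks the first maximum in row-major order; other libraries may break ties differently. The upstream gradient comes from the loss \(L = \sum -\log(A)\), so \(A' = -1/A\).

```python
import numpy as np


def maxpool2d_forward(Z, k, stride=1):
    """Max pooling over k x k windows; also returns where in Z each maximum came from."""
    out_h = (Z.shape[0] - k) // stride + 1
    out_w = (Z.shape[1] - k) // stride + 1
    A = np.empty((out_h, out_w))
    positions = {}
    for i in range(out_h):
        for j in range(out_w):
            r0, c0 = i * stride, j * stride
            window = Z[r0:r0 + k, c0:c0 + k]
            p, q = np.unravel_index(np.argmax(window), window.shape)
            A[i, j] = window[p, q]
            positions[(i, j)] = (r0 + p, c0 + q)  # argmax location in Z coordinates
    return A, positions


def maxpool2d_backward(A_grad, positions, Z_shape):
    """Start from zeros and add each a'_i at the argmax of its window (overlaps accumulate)."""
    Z_grad = np.zeros(Z_shape)
    for (i, j), (r, c) in positions.items():
        Z_grad[r, c] += A_grad[i, j]
    return Z_grad


Z = np.array([[1.0, 2, 3], [4, 10, 5], [11, 12, 13]])
A, positions = maxpool2d_forward(Z, k=2, stride=1)
A_grad = -1.0 / A  # dL/dA for L = sum(-log(A))
Z_grad = maxpool2d_backward(A_grad, positions, Z.shape)
print(A.tolist())  # [[10.0, 10.0], [12.0, 13.0]]
# Expected: a'_1 + a'_2 = -0.2 at the center, a'_3 = -1/12 below it, a'_4 = -1/13 at bottom right
print(Z_grad)
```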

Here's a Python script that uses PyTorch's autograd to calculate the gradients automatically. To get concrete values for \(A'\), it uses a somewhat arbitrary loss function: take \(-\log\) of each element of \(A\) and sum the results, so \(A' = -1/A\) elementwise.

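import numpy as np
import torch


# 3x3 input with batch and channel dimensions added: shape (1, 1, 3, 3)
Z = np.array(([1., 2, 3], [4, 10, 5], [11, 12, 13]))[np.newaxis, np.newaxis]
Z = torch.tensor(Z, requires_grad=True)
# 2x2 max pooling with stride 1 -> output shape (1, 1, 2, 2)
A = torch.max_pool2d(Z, kernel_size=2, stride=1)
# A is not a leaf tensor, so ask autograd to keep its gradient dL/dA
A.retain_grad()
# Arbitrary scalar loss: sum of -log over all elements of A, so dL/dA = -1/A
loss = torch.sum(-torch.log(A))
loss.backward()
print('Z:', Z.detach().numpy(), 'A:', A.detach().numpy(), 'dL/dA:', A.grad.numpy(),
      'dL/dZ:', Z.grad.numpy(), sep='\n')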

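This gives the following output. Plugging these values of \(A'\) into the derivation above predicts \(a'_1 + a'_2 = -0.1 - 0.1 = -0.2\) at the center of \(\dfrac{\partial L}{\partial Z}\), \(a'_3 = -1/12 \approx -0.0833\) below it, and \(a'_4 = -1/13 \approx -0.0769\) in the bottom-right corner, which matches PyTorch's result: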

Z:
[[[[ 1.  2.  3.]
   [ 4. 10.  5.]
   [11. 12. 13.]]]]
A:
[[[[10. 10.]
   [12. 13.]]]]
dL/dA:
[[[[-0.1        -0.1       ]
   [-0.08333333 -0.07692308]]]]
dL/dZ:
[[[[ 0.          0.          0.        ]
   [ 0.         -0.2         0.        ]
   [ 0.         -0.08333333 -0.07692308]]]]