19 Feb 2020 · 825 words

The following code explores how gradients explode or vanish when backpropagating through a neural network. Whether they explode or vanish depends on the activation function you use and on how many layers the network has.

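A quick way to see why depth matters: the chain rule multiplies one Jacobian-like factor per layer, so if each layer scales the gradient by a factor consistently above or below 1, the result grows or shrinks exponentially with depth. A minimal sketch of that effect (my own illustration, using plain matrix products rather than a network):

```
import torch

# Repeatedly multiplying by a matrix with per-layer gain != 1 makes a
# vector's norm grow or shrink exponentially -- the same mechanism that
# makes deep-network gradients explode or vanish.
torch.manual_seed(0)
v = torch.randn(4)
for scale in (1.5, 0.5):  # per-layer gain: >1 explodes, <1 vanishes
    u = v.clone()
    for _ in range(10):
        u = scale * (torch.randn(4, 4) / 4 ** 0.5) @ u  # ~unit-gain matrix, scaled
    print(f"gain {scale}: norm after 10 layers = {u.norm().item():.3g}")
```
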
Observations for some activation functions:

• Sometimes with relu you get a lot of zeros in the gradient (see the sketch after this list).
• I didn’t see any vanishing gradients with relu, but I did see exploding gradients. The same was true of tanh, which was unexpected: I had expected tanh to show vanishing gradients.
• With relu, I also expected a weight value < 0 in a given layer to line up with a gradient of zero for that weight, but I didn’t see that correspondence either.

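The zeros are easy to check directly: relu zeroes negative activations, and those zeros flow into the gradients of the upstream weights. Here is a standalone sketch of my own (a plain relu stack, not the function below) that counts exactly-zero gradient entries per layer:

```
import torch

# Count the fraction of exactly-zero entries in each layer's gradient
# of a small random relu network.
torch.manual_seed(0)
x = torch.randn(100, 4)
ws = [torch.randn(4, 4, requires_grad=True) for _ in range(5)]
h = x
for w in ws:
    h = (h @ w).relu()
loss = h.pow(2).sum()
loss.backward()
for i, w in enumerate(ws):
    print(f"w{i}: {(w.grad == 0).float().mean().item():.0%} zeros in grad")
```
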
Here’s the code:

```
import numpy as np
import torch


def print_gradients(f, n_layers, layer_size=3, normalise_x=True, print_weights=False):
    """
    Print the gradients of the intermediate layers of a simple dense neural
    network at the initialisation stage. No training of the network is done.

    All intermediate layers are square with shape (`layer_size`,
    `layer_size`), except the last layer, which has shape (`layer_size`, 1).

    The function creates some sample data to calculate a sample loss. The
    loss is for a regression problem and is the sum of squared errors.

    Params:
        f: activation function, like torch.Tensor.relu or torch.Tensor.tanh
        n_layers: how many layers to have in the network
        layer_size: the size of these layers
        normalise_x: set to True to set x to (x - mean(x)) / std(x)
        print_weights: set to True to print the weights of each layer along
            with its gradient
    """
    l = dict()
    n = n_layers
    w = layer_size
    # attach f to Tensor so it can be called as a method (.f()) below
    torch.Tensor.f = f

    # create some sample data
    x = torch.randint(low=-10, high=10, size=(100, w), dtype=torch.float)
    y = torch.randint(low=0, high=5, size=(100,), dtype=torch.float)

    if normalise_x:
        x = (x - x.mean(dim=0)) / x.std(dim=0)

    # create random intermediate weight layers
    for i in range(n):
        name = 'w' + str(i)
        size = (w, w) if i < (n - 1) else (w, 1)
        l[name] = torch.randn(size, dtype=torch.float, requires_grad=True)

    # forward pass, loss function, gradients of the intermediate layers
    tmp = f(x)
    for i in range(n):
        name = 'w' + str(i)
        if i < (n - 1):
            tmp = tmp.matmul(l[name]).f()
        else:
            tmp = tmp.matmul(l[name])
    yp = tmp.squeeze()  # (100, 1) -> (100,) so (yp - y) doesn't broadcast
    L = (yp - y).pow(2).sum()
    L.backward()

    for i in range(n):
        name = 'w' + str(i)
        print("####", name, "####")
        if print_weights:
            print("weights\n", l[name].detach().numpy().round(4))
        print(l[name].grad.numpy().round(4))
```

Here is how it is used:

```
n_layers = 10
f = torch.Tensor.relu

# layer_size=4 and print_weights=True match the example output below
print_gradients(f=f, n_layers=n_layers, layer_size=4, print_weights=True)
```

Some example output:

```
#### w0 ####
weights
[[-1.86    0.6851 -0.5556 -0.5878]
[-0.5929  2.1052 -0.4813 -0.7117]
[-0.3172 -0.0884  1.0429  0.6922]
[-1.8281 -0.8995  1.0208 -0.9727]]
[[    0.      5713.2656 -1374.9153     0.    ]
[    0.      9508.161  -5190.595   1170.9376]
[    0.      4955.623  -5194.9365  1775.924 ]
[    0.      2636.087  -1891.1196   181.3048]]

#### w1 ####
weights
[[ 0.6907  0.1203 -0.4189 -2.7348]
[ 0.7943  1.1307  1.3316 -0.8833]
[-0.9451 -0.6082  0.3866  0.8421]
[ 1.3765  1.2247  0.4677 -0.4686]]
[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
[-1.5590840e+02  2.6379445e+04 -6.4444326e+03  0.0000000e+00]
[-6.6746460e+02  6.1963838e+03 -2.4536016e+03  0.0000000e+00]
[-2.0615299e+01  2.3536490e+02 -8.6139603e+01  0.0000000e+00]]

#### w2 ####
weights
[[ 0.1008  0.1396 -1.4134 -0.1433]
[ 1.8467 -0.8925  0.0248 -1.4145]
[-0.2846  0.4181 -1.2936 -0.1676]
[ 0.3009  2.1765 -1.7931  0.2762]]
[[  7417.276   -1930.0452      0.          0.    ]
[ 11098.887   -6554.655       0.          0.    ]
[ 13884.539  -13437.387       0.          0.    ]
[     0.          0.          0.          0.    ]]

#### w3 ####
weights
[[-0.7117 -0.4262 -1.4683  0.0227]
[-1.0138  1.305  -0.5682 -0.8239]
[ 0.5405 -0.4138  1.5537 -0.6959]
[-1.1945  1.0653 -2.0544  1.5107]]
[[0.000000e+00 0.000000e+00 0.000000e+00 7.618817e+05]
[0.000000e+00 0.000000e+00 0.000000e+00 4.573620e+01]
[0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00]
[0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00]]

#### w4 ####
weights
[[ 2.4130e-01  9.0000e-04 -6.2140e-01 -1.4420e-01]
[ 1.5490e+00 -8.5590e-01  9.4570e-01 -9.3770e-01]
[-3.0020e-01  2.0068e+00  1.1331e+00 -6.4380e-01]
[ 2.8960e-01  8.1560e-01  1.8700e-02 -1.3460e-01]]
[[     0.         0.         0.         0.   ]
[     0.         0.         0.         0.   ]
[     0.         0.         0.         0.   ]
[-11734.712  24611.941  30977.482      0.   ]]

#### w5 ####
weights
[[ 0.8595  0.2263  0.2642 -0.9272]
[ 0.5051  0.478   2.3984 -0.8763]
[-2.3839 -1.2242 -1.0006  0.6598]
[ 0.1359 -0.8366  1.4181  0.0567]]
[[ -4956.994   -1142.9338   4243.301       0.    ]
[-13960.942   -3218.9739  11950.89        0.    ]
[  -320.1633    -73.82      274.0672      0.    ]
[     0.          0.          0.          0.    ]]

#### w6 ####
weights
[[-1.6302 -0.3628 -0.5336  0.4115]
[-0.3475 -1.2729  0.0135 -0.0092]
[ 1.3637 -0.3827  0.3039 -0.3609]
[-0.2053 -0.7371  0.0209  1.2956]]
[[ 6943.7407     0.     -1444.9719     0.    ]
[ 4873.084      0.     -1014.0744     0.    ]
[22692.582      0.     -4722.259      0.    ]
[    0.         0.         0.         0.    ]]

#### w7 ####
weights
[[ 0.7484  0.1641  0.3732  1.5976]
[-0.3619  1.863  -1.3552  0.4618]
[ 0.4905  0.2353 -0.4943  0.2616]
[-0.004  -0.4753 -0.1268  0.2372]]
[[-10202.121    1074.2303   5650.8564  14572.965 ]
[     0.          0.          0.          0.    ]
[ -1852.5332    195.0621   1026.1002   2646.205 ]
[     0.          0.          0.          0.    ]]

#### w8 ####
weights
[[-0.515   0.6908  0.0826  1.5223]
[-0.3799  0.8091  0.0506 -1.2749]
[ 0.6146  1.4947  0.8942 -0.1798]
[ 1.3187 -0.5813  1.2728 -0.3163]]
[[ 6438.6147  -893.7505  1732.8018 -3122.6318]
[ 1589.804   -220.6823   427.8584  -771.0312]
[ 2178.9844  -302.467    586.4224 -1056.7747]
[12648.798  -1755.7922  3404.1267 -6134.4775]]

#### w9 ####
weights
[[-2.6363]
[ 0.366 ]
[-0.7095]
[ 1.2786]]
```
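
To put numbers on the exploding-gradient observation (and to test the tanh surprise from the list above), here is a follow-up sketch of my own: it rebuilds the same kind of random stack for relu and tanh at several depths and prints the gradient norm of the first layer. This is an extension I added, not part of the original run:

```
import torch

# My own follow-up sketch: gradient norm of the first layer as depth grows,
# for relu vs tanh, using the same kind of random init as above.
torch.manual_seed(0)
x = torch.randn(100, 4)
for f in (torch.relu, torch.tanh):
    for n_layers in (2, 5, 10, 20):
        ws = [torch.randn(4, 4, requires_grad=True) for _ in range(n_layers - 1)]
        ws.append(torch.randn(4, 1, requires_grad=True))  # final (4, 1) layer
        h = f(x)
        for w in ws[:-1]:
            h = f(h @ w)
        yp = (h @ ws[-1]).squeeze()
        loss = yp.pow(2).sum()
        loss.backward()
        print(f"{f.__name__}, {n_layers:2d} layers: "
              f"|grad w0| = {ws[0].grad.norm().item():.3g}")
```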