Exploding and vanishing gradients

The following code explores how gradients explode or vanish when backpropagating through a neural network. Whether they explode or vanish depends on the activation function you use and on how many layers the network has.

Observations for some activation functions:

  • sigmoid tends to lead to vanishing gradients (a short sketch of why is included after this list).
  • relu sometimes produces a lot of zeros in the gradient.
  • I didn’t see any vanishing gradients with relu, but I did see exploding gradients. The same happened with tanh, which was unexpected; I had expected tanh to show vanishing gradients.
  • With relu, I didn’t see a correspondence between a weight value < 0 in a given weight layer and a gradient of zero for that weight. This was also unexpected.
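
The sigmoid behaviour at least has a simple explanation: the derivative of sigmoid is never larger than 0.25, so backpropagating through many sigmoid layers multiplies the gradient by at most 0.25 per layer. Here is a minimal sketch of that effect, separate from the experiment below, using a plain chain of scalar sigmoids rather than dense layers:

import torch

# Chain 15 scalar sigmoids: y = sigmoid(sigmoid(...sigmoid(x)...)).
# Each chain-rule factor is sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)) <= 0.25,
# so dy/dx can be no larger than 0.25 ** 15 (about 1e-9).
x = torch.tensor(0.5, requires_grad=True)
t = x
for _ in range(15):
    t = torch.sigmoid(t)
t.backward()

print(x.grad)      # a tiny number, well below 1e-9
print(0.25 ** 15)  # upper bound on |dy/dx|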

Here’s the code:

import numpy as np
import torch

def print_gradients(f=torch.Tensor.relu, n_layers=15,
                    layer_size=3, normalise_x=True, print_weights=False):
    """
    A function to print the gradients of intermediate layers at the init stage of 
    a simple dense neural network. No training of the neural network is done. 
    
    All intermediate layers are square with shape (`layer_size`,
    `layer_size`), except the last layer which has shape (`layer_size`, 1). 
    
    The function creates some sample data to calculate a sample loss. The
    loss is a sum of squared errors for a regression problem.
    
    Params:
        f: activation function, like torch.Tensor.relu or torch.Tensor.tanh
        n_layers: how many layers to have in the network
        layer_size: the width of each layer
        normalise_x: set to True to set x to (x - mean(x)) / std(x)
        print_weights: set to True to print the weights of each layer along with its gradient
        
    """
    l = dict()
    n = n_layers
    w = layer_size
    
    # create some sample data
    x = torch.randint(low = -10, high = 10, size=(100,w), dtype = torch.float)
    y = torch.randint(low=0, high=5, size=(100,),dtype = torch.float) 
    
    if normalise_x:  x = (x- x.mean(dim=0)) / x.std(dim=0)

    # create random intermediate weight layers 
    for i in range(n): 
        name = 'w' + str(i)
        size = (w,w) if i < (n-1) else (w,1)
        l[name] = torch.randn(size, dtype = torch.float, requires_grad=True)
    
    # forward pass, loss function, calculate gradients of intermediate layers
    tmp = f(x)
    for i in range(n):
        name = 'w' + str(i)
        if i < (n-1):       tmp = f(tmp.matmul(l[name]))   # hidden layers get the activation
        else:               tmp = tmp.matmul(l[name])      # final layer is linear
    yp = tmp.squeeze(-1)                                    # shape (100,), matching y
    L = (yp - y).pow(2).sum()
    L.backward()

    # print out the gradients 
    for i in range(n): 
        name = 'w' + str(i)
        print("####", name, "####")
        if print_weights: print("weights\n", l[name].detach().numpy().round(4))
        print("grad\n", str(l[name].grad.numpy().round(4)),"\n")

Here is how it is used:

n_layers = 10
f = torch.Tensor.relu
print_gradients(f, n_layers, layer_size=4, print_weights=True)
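
To compare activations, the same call works with other elementwise torch.Tensor methods, for example sigmoid or tanh (the depth and width below are arbitrary choices):

print_gradients(torch.Tensor.sigmoid, n_layers=10, layer_size=4)
print_gradients(torch.Tensor.tanh, n_layers=10, layer_size=4)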

Some example output, from the relu call above:

    #### w0 ####
    weights
     [[-1.86    0.6851 -0.5556 -0.5878]
     [-0.5929  2.1052 -0.4813 -0.7117]
     [-0.3172 -0.0884  1.0429  0.6922]
     [-1.8281 -0.8995  1.0208 -0.9727]]
    grad
     [[    0.      5713.2656 -1374.9153     0.    ]
     [    0.      9508.161  -5190.595   1170.9376]
     [    0.      4955.623  -5194.9365  1775.924 ]
     [    0.      2636.087  -1891.1196   181.3048]] 
    
    #### w1 ####
    weights
     [[ 0.6907  0.1203 -0.4189 -2.7348]
     [ 0.7943  1.1307  1.3316 -0.8833]
     [-0.9451 -0.6082  0.3866  0.8421]
     [ 1.3765  1.2247  0.4677 -0.4686]]
    grad
     [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [-1.5590840e+02  2.6379445e+04 -6.4444326e+03  0.0000000e+00]
     [-6.6746460e+02  6.1963838e+03 -2.4536016e+03  0.0000000e+00]
     [-2.0615299e+01  2.3536490e+02 -8.6139603e+01  0.0000000e+00]] 
    
    #### w2 ####
    weights
     [[ 0.1008  0.1396 -1.4134 -0.1433]
     [ 1.8467 -0.8925  0.0248 -1.4145]
     [-0.2846  0.4181 -1.2936 -0.1676]
     [ 0.3009  2.1765 -1.7931  0.2762]]
    grad
     [[  7417.276   -1930.0452      0.          0.    ]
     [ 11098.887   -6554.655       0.          0.    ]
     [ 13884.539  -13437.387       0.          0.    ]
     [     0.          0.          0.          0.    ]] 
    
    #### w3 ####
    weights
     [[-0.7117 -0.4262 -1.4683  0.0227]
     [-1.0138  1.305  -0.5682 -0.8239]
     [ 0.5405 -0.4138  1.5537 -0.6959]
     [-1.1945  1.0653 -2.0544  1.5107]]
    grad
     [[0.000000e+00 0.000000e+00 0.000000e+00 7.618817e+05]
     [0.000000e+00 0.000000e+00 0.000000e+00 4.573620e+01]
     [0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00]
     [0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00]] 
    
    #### w4 ####
    weights
     [[ 2.4130e-01  9.0000e-04 -6.2140e-01 -1.4420e-01]
     [ 1.5490e+00 -8.5590e-01  9.4570e-01 -9.3770e-01]
     [-3.0020e-01  2.0068e+00  1.1331e+00 -6.4380e-01]
     [ 2.8960e-01  8.1560e-01  1.8700e-02 -1.3460e-01]]
    grad
     [[     0.         0.         0.         0.   ]
     [     0.         0.         0.         0.   ]
     [     0.         0.         0.         0.   ]
     [-11734.712  24611.941  30977.482      0.   ]] 
    
    #### w5 ####
    weights
     [[ 0.8595  0.2263  0.2642 -0.9272]
     [ 0.5051  0.478   2.3984 -0.8763]
     [-2.3839 -1.2242 -1.0006  0.6598]
     [ 0.1359 -0.8366  1.4181  0.0567]]
    grad
     [[ -4956.994   -1142.9338   4243.301       0.    ]
     [-13960.942   -3218.9739  11950.89        0.    ]
     [  -320.1633    -73.82      274.0672      0.    ]
     [     0.          0.          0.          0.    ]] 
    
    #### w6 ####
    weights
     [[-1.6302 -0.3628 -0.5336  0.4115]
     [-0.3475 -1.2729  0.0135 -0.0092]
     [ 1.3637 -0.3827  0.3039 -0.3609]
     [-0.2053 -0.7371  0.0209  1.2956]]
    grad
     [[ 6943.7407     0.     -1444.9719     0.    ]
     [ 4873.084      0.     -1014.0744     0.    ]
     [22692.582      0.     -4722.259      0.    ]
     [    0.         0.         0.         0.    ]] 
    
    #### w7 ####
    weights
     [[ 0.7484  0.1641  0.3732  1.5976]
     [-0.3619  1.863  -1.3552  0.4618]
     [ 0.4905  0.2353 -0.4943  0.2616]
     [-0.004  -0.4753 -0.1268  0.2372]]
    grad
     [[-10202.121    1074.2303   5650.8564  14572.965 ]
     [     0.          0.          0.          0.    ]
     [ -1852.5332    195.0621   1026.1002   2646.205 ]
     [     0.          0.          0.          0.    ]] 
    
    #### w8 ####
    weights
     [[-0.515   0.6908  0.0826  1.5223]
     [-0.3799  0.8091  0.0506 -1.2749]
     [ 0.6146  1.4947  0.8942 -0.1798]
     [ 1.3187 -0.5813  1.2728 -0.3163]]
    grad
     [[ 6438.6147  -893.7505  1732.8018 -3122.6318]
     [ 1589.804   -220.6823   427.8584  -771.0312]
     [ 2178.9844  -302.467    586.4224 -1056.7747]
     [12648.798  -1755.7922  3404.1267 -6134.4775]] 
    
    #### w9 ####
    weights
     [[-2.6363]
     [ 0.366 ]
     [-0.7095]
     [ 1.2786]]
    grad
     [[-5348.286 ]
     [ -621.3683]
     [-7078.1904]
     [-1282.6868]]