A simple CNN with Pytorch

Find the code for this blog post here: https://github.com/puzzler10/simple_pytorch_cnn

We will build a classifier on CIFAR10 to predict the class of each image, using PyTorch along the way.

This is basically following along with the official Pytorch tutorial except I add rough notes to explain things as I go. There were a lot of things I didn’t find straightforward, so hopefully this piece can help someone else out there. If you’re reading this, I recommend having both this article and the Pytorch tutorial open.

Start off with initialisations

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt 
import pandas as pd 
%matplotlib inline 
path_data = './data/'

Getting the data

Pytorch provides a package called torchvision that is a useful utility for getting common datasets. Using this package we can download the CIFAR10 train and test sets easily and save them to a folder. The training set is about 270MB. If you’ve already downloaded it once, you don’t have to redownload it.

train = torchvision.datasets.CIFAR10(root=path_data, train=True, download=True) 
test = torchvision.datasets.CIFAR10(root=path_data, train=False, download=True) 
Files already downloaded and verified
Files already downloaded and verified

Let’s look at train. What we get from this is a class called CIFAR10.


Let’s inspect this object. There are a few useful things you can do with this class:

  • train.classes: view the output classes as strings
  • train.targets: numericalised output classes (0-9)
  • train.class_to_idx: mapping between train.classes and train.targets
  • train.data: has the raw image data as a numpy.ndarray of integers in the range [0,255], with each image in the order (H x W x C). This has shape (50000, 32, 32, 3). Each image is only converted to a PIL Image when you index the dataset.

As always train.__dict__ lets you see everything at once.

Now it’s time to transform the data. Use torchvision.transforms for this. Some basic transforms:

  • transforms.ToTensor(): converts PIL/Numpy to Tensor format. It converts a PIL Image or numpy.ndarray with range [0,255] and shape (H x W x C) to a torch.FloatTensor of shape (C x H x W) and range [0.0, 1.0]. So this operation also rescales your data. It’s not a simple “ndarray –> tensor” operation.

  • transforms.Normalize(): normalises each channel of the input Tensor. The formula is this: input[channel] = (input[channel] - mean[channel]) / std[channel]. You have to pass in two parameters: a sequence of means for each channel, and a sequence of standard deviations for each channel. In practice you see this called as transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) for the CIFAR10 example, rather than transforms.Normalize((127.5,127.5,127.5), (some_std_here)) because it is put after transforms.ToTensor() and that rescales to 0-1.

  • transforms.Compose(): the function that lets you chain together different transforms.
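To make the rescaling concrete, here is a rough sketch of the arithmetic that ToTensor() followed by Normalize() with mean 0.5 and std 0.5 performs on a single pixel value. It is plain Python rather than torchvision, and the two helper names are made up for illustration.

```python
def to_tensor_scale(pixel):
    """Mimic the rescaling part of ToTensor(): [0, 255] -> [0.0, 1.0]."""
    return pixel / 255.0


def normalize(value, mean=0.5, std=0.5):
    """Mimic Normalize() for one channel: (value - mean) / std."""
    return (value - mean) / std


# darkest, middle and brightest possible pixel values
for pixel in (0, 127.5, 255):
    scaled = to_tensor_scale(pixel)
    print(pixel, "->", scaled, "->", normalize(scaled))
```

So the darkest pixel ends up at -1, the brightest at 1, and mid-grey at 0, which is why the mean and std are given on the 0-1 scale rather than the 0-255 scale.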

Let’s build the transform now:

cifar_transform = transforms.Compose([
    transforms.ToTensor(),  # PIL Image -> FloatTensor in [0, 1]
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))  # [0, 1] -> [-1, 1]
])

Most examples specify a transform when calling a dataset (like torchvision.datasets.CIFAR10) using the transform parameter. But I think you can also just add it to the transform and transforms attribute of the train or test object. The transform doesn’t get called at this point anyway: it only runs when an item is fetched, for example when you index the dataset or iterate over a DataLoader. The difference with transforms is you need to run it through the torchvision.datasets.vision.StandardTransform class to get the exact same behaviour. So do this:

train.transform = cifar_transform
test.transform = cifar_transform
train.transforms = torchvision.datasets.vision.StandardTransform(cifar_transform)
test.transforms = torchvision.datasets.vision.StandardTransform(cifar_transform)

and it should be fine. Now use train.transform or train.transforms to see if it worked:

Compose(
    ToTensor()
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
)

StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )

Note train.data remains unscaled after the transform. Transforms are applied on the fly each time an item is fetched (for example by the DataLoader), never to train.data itself.

Datasets and DataLoaders

There are two types of Dataset in Pytorch.

The first type is called a map-style dataset and is a class that implements __len__() and __getitem__(). You can access individual points of one of these datasets with square brackets (e.g. data[3]) and it’s the type of dataset for most common needs. Then there’s the iterable-style dataset that implements __iter__() and is used for streaming-type things. The Dataset class is a map-style dataset and the IterableDataset class is an iterable-style dataset. In this case CIFAR10 is a map-style dataset.
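As a rough illustration of the two protocols, here are two toy classes in plain Python. They deliberately don’t subclass torch.utils.data.Dataset or IterableDataset (in real code you would), and the class names and the squares data are invented for the example.

```python
class SquaresMapStyle:
    """Map-style: supports len() and indexing with square brackets."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx, idx * idx  # an (input, target) pair


class SquaresIterableStyle:
    """Iterable-style: can only be looped over, like a stream."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield i, i * i


map_data = SquaresMapStyle(5)
print(len(map_data), map_data[3])  # random access works

stream = SquaresIterableStyle(5)
print(list(stream))                # sequential access only
```

The map-style version is what lets a DataLoader shuffle: it can ask for any index in any order. The iterable-style version has no length and no indexing, so it can only be consumed in the order it produces items.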

The DataLoader class combines with the Dataset class and helps you iterate over a dataset. You can specify how many data points to take in a batch, to shuffle them or not, implement sampling strategies, use multiprocessing for loading data, and so on. Without using a DataLoader you’d have a lot more overhead to get things working: implement your own for-loops with indices, shuffle batches yourself and so on.
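To get a feel for the bookkeeping a DataLoader saves you, here is a minimal hand-rolled version of shuffling and batching in plain Python. The function name and the toy dataset are hypothetical, and it ignores everything else a DataLoader does (collating into tensors, worker processes, samplers).

```python
import random


def manual_batches(dataset, batch_size, shuffle=True, seed=0):
    """Yield lists of samples, batch_size at a time."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


toy_dataset = [(x, x % 2) for x in range(10)]  # (input, label) pairs
batches = list(manual_batches(toy_dataset, batch_size=4))
print([len(b) for b in batches])  # the last batch is smaller
```

With 10 samples and a batch size of 4 you get batches of 4, 4 and 2. The same arithmetic applies to CIFAR10: 50000 training images in batches of 4 is 12500 batches per epoch.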

We make a loader for both our train and test set. The tutorial sets shuffle=False for the test set; there’s no need to shuffle the data when you are just evaluating it.

trainloader = torch.utils.data.DataLoader(train, batch_size=4,
                                          shuffle=True, num_workers=2)
# evaluate on the held-out test set, not on train
testloader = torch.utils.data.DataLoader(test, batch_size=4,
                                          shuffle=False, num_workers=2)

You can get some data by converting trainloader to an iterator and then calling next on it. This gives us a list of length 2: it has both the training data and the labels, or in common maths terms, (X, y).

Let’s look at a batch in raw form.

train_iter = iter(trainloader)
images, labels = next(train_iter)  # use the built-in next(); iterators have no .next() method
tensor([[[0.2941, 0.2627, 0.3255,  ..., 0.7804, 0.7647, 0.7490],
         [0.4118, 0.3569, 0.4039,  ..., 0.8039, 0.7882, 0.7725],
         [0.7804, 0.7412, 0.7490,  ..., 0.8039, 0.7882, 0.7725],
         [0.6627, 0.6314, 0.6235,  ..., 0.6863, 0.7647, 0.8118],
         [0.6627, 0.6471, 0.6549,  ..., 0.7725, 0.8118, 0.8431],
         [0.6471, 0.6392, 0.6471,  ..., 0.7882, 0.8039, 0.8275]],

        [[0.3647, 0.3333, 0.3961,  ..., 0.8510, 0.8353, 0.8196],
         [0.4824, 0.4275, 0.4745,  ..., 0.8745, 0.8588, 0.8431],
         [0.8510, 0.8118, 0.8196,  ..., 0.8745, 0.8588, 0.8431],
         [0.7647, 0.7333, 0.7255,  ..., 0.8275, 0.8902, 0.9137],
         [0.7647, 0.7490, 0.7569,  ..., 0.8824, 0.9137, 0.9216],
         [0.7490, 0.7412, 0.7490,  ..., 0.8745, 0.8824, 0.8980]],

        [[0.2314, 0.2000, 0.2627,  ..., 0.6863, 0.6706, 0.6549],
         [0.3412, 0.2863, 0.3333,  ..., 0.7098, 0.6941, 0.6784],
         [0.6941, 0.6549, 0.6627,  ..., 0.7098, 0.6941, 0.6784],
         [0.6235, 0.5922, 0.5922,  ..., 0.7882, 0.8118, 0.7882],
         [0.6235, 0.6078, 0.6078,  ..., 0.8118, 0.8039, 0.7725],
         [0.6078, 0.6078, 0.6157,  ..., 0.7490, 0.7412, 0.7255]]])

The batch has shape torch.Size([4, 3, 32, 32]), since we set the batch size to 4. It’s also been rescaled to be between -1 and 1, or in other words: all the transforms in cifar_transform have been executed now.

Plotting images

It’d be useful to us to try and plot what’s in images as actual images. So let’s do that.

The images array is a Tensor and is arranged in the order (B x C x H x W), where B is batch size, C is channels, H height and W width. If we want to use image plotting methods from matplotlib like imshow, we need each image to look like (H x W x C). Another problem is that imshow needs values between 0 and 1, and currently our image values are between -1 and 1.

Useful for this is the function torchvision.utils.make_grid(). What this does is take a bunch of separate images and squish them together into a ‘film-strip’ style image with axes in order of (C x H x W) with some amount of padding between each image. So we’ll do this to merge our images, reshape the axes with np.transpose() into an imshow compatible format, and then we can plot them.

def plot_images(images, labels): 
    # normalize=True below shifts [-1,1] to [0,1]
    img_grid = torchvision.utils.make_grid(images, nrow=4, normalize=True)
    # reorder (C x H x W) -> (H x W x C) for imshow
    np_img = img_grid.numpy().transpose(1,2,0)  
    plt.imshow(np_img)
    plt.show()

# invert the class -> index mapping so we can look up names from labels
d_class2idx = train.class_to_idx
d_idx2class = dict(zip(d_class2idx.values(),d_class2idx.keys()))

images, labels = next(train_iter)
plot_images(images, labels)
print(' '.join('%5s' % d_idx2class[int(labels[j])] for j in range(len(images))))
 ship   dog  deer  bird


Making the CNN

import torch.nn as nn 
import torch.nn.functional as F 

We’re going to define a class Net that has the CNN. This class will inherit from nn.Module and have two methods: an __init__() method and a forward() method. Here’s the architecture (except ours is on CIFAR, not MNIST):

Neural network components

It looks like all layers run only for a batch of samples and not for a single point. The input to a nn.Conv2d layer for example will be something of shape (nSamples x nChannels x Height x Width), or (S x C x H x W).

If you want to put a single sample through, you can use input.unsqueeze(0) to add a fake batch dimension to it so that it will work properly.

Convolutional layers

The first type of layer we have is a 2D convolutional layer, implemented using nn.Conv2d(). This has three compulsory parameters:

  • in_channels: basically the number of channels in your input data. For CIFAR we have 3 channels (RGB), for a dataset like MNIST we would only have one.
  • out_channels: the number of convolutional filters you’d like to have in this layer. The network will learn the weights for all of these.
  • kernel_size: the size of the square filter. Set to 3 to use 3x3 conv filters, 2 to use 2x2 conv filters etc.

There are also a bunch of other parameters you can set: stride, padding, dilation and so forth. This link has a good description of these parameters and how they affect the results.

MaxPooling layers

These are called nn.MaxPool2d(). The compulsory parameter is kernel_size and it sets the size of the square window over which the maxpool operator is applied. Like before you can set the stride and other parameters.

Linear layers

A simple linear layer of the form y = XW + b.

Parameters: in_features (neurons coming into the layer), out_features (neurons going out of the layer) and bias, which you should set to True almost always.

Activations

We will use ReLU activations in the network. We can find that in F.relu and it is simple to apply. We’ll use the forward method to take layers we define in __init__ and stitch them together with F.relu as the activation function.

Reshaping

If x is a Tensor, we use x.view to reshape it. For example, if x is a 16x1 tensor, x.view(4,4) reshapes it to a 4x4 tensor. You can write -1 to infer the dimension on that axis, based on the number of elements in x and the shape of the other axes. For example, x.view(2,-1) returns a Tensor of shape 2x8. Only one axis can be inferred.

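The view function doesn’t copy the data. The tensor returned by view shares its data with the original tensor, so if you change one, the other changes. np.reshape behaves similarly when it can: it returns a view where possible and falls back to a copy otherwise, whereas view raises an error if the data can’t be shared.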
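Here is a small sketch of both points, the -1 inference and the shared data. It uses a 1-D tensor of 16 elements built with torch.arange instead of a 16x1 tensor, which is my own choice to keep the example short.

```python
import torch

x = torch.arange(16)   # 16 elements: 0, 1, ..., 15
a = x.view(4, 4)
b = x.view(2, -1)      # -1 is inferred as 8, since 16 / 2 = 8
print(a.shape, b.shape)

a[0, 0] = 100          # write through the reshaped tensor...
print(x[0].item())     # ...and the original sees the change
```

Because nothing is copied, view is cheap even on large tensors, but it also means an in-place edit on the reshaped tensor silently changes the original.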

class Net(nn.Module): 
    def __init__(self): 
        super().__init__()  # must run before layers are assigned as attributes
        self.conv1 = nn.Conv2d(3, 6, 5)
        # we use the maxpool multiple times, but define it once
        self.pool = nn.MaxPool2d(2,2)
        # in_channels = 6 because self.conv1 output 6 channel
        self.conv2 = nn.Conv2d(6,16,5) 
        # 5*5 comes from the dimension of the last convnet layer
        self.fc1 = nn.Linear(16*5*5, 120) 
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x): 
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # no activation on final layer 
        return x

net = Net()
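The 16*5*5 in fc1 is easier to trust once you trace the image size through the layers. Below is a rough sketch of that arithmetic in plain Python. The helper conv_out is my own name; it applies the usual output-size formula for convolution and pooling layers, assuming the defaults used in Net (no padding, dilation 1).

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output side length of a conv or pooling layer (rounded down)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


side = 32                                  # CIFAR10 images are 32x32
side = conv_out(side, kernel=5)            # conv1: 32 -> 28
side = conv_out(side, kernel=2, stride=2)  # pool:  28 -> 14
side = conv_out(side, kernel=5)            # conv2: 14 -> 10
side = conv_out(side, kernel=2, stride=2)  # pool:  10 -> 5
flat_features = 16 * side * side           # conv2 outputs 16 channels
print(side, flat_features)
```

The final feature map is 5x5 with 16 channels, so each image flattens to 400 values, which is the in_features of fc1 and the size used in x.view(-1, 16*5*5).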

Loss function and optimisation

We will use a cross entropy loss, found in the function nn.CrossEntropyLoss(). This function expects raw logits as the final layer of the neural network, which is why we didn’t have a softmax final layer. The function also has a weight parameter which would be useful if we had some unbalanced classes, because it lets you give errors on the rare class a larger weight in the loss.
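To see why raw logits are enough, here is a minimal sketch of cross entropy for a single sample in plain Python: the softmax is folded into the loss as a log-sum-exp. The function name and the example logits are made up, and this ignores the batch handling of nn.CrossEntropyLoss, which by default averages the per-sample losses.

```python
import math


def cross_entropy(logits, target):
    """Cross entropy for one sample: -log(softmax(logits)[target])."""
    shift = max(logits)  # subtract the max for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]


logits = [2.0, 0.5, -1.0]
print(cross_entropy(logits, target=0))  # true class has the top logit: small loss
print(cross_entropy(logits, target=2))  # true class has the lowest logit: large loss
```

If you applied a softmax in the network and then passed the result to this kind of loss, the softmax would effectively be applied twice.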

Optimisation is done with stochastic gradient descent, or optim.SGD. The first argument is the parameters of the neural network: i.e. you are giving the optimiser something to optimise. Second argument is the learning rate, and third argument is an option to set a momentum parameter and hence use momentum in the optimisation. Other options include dampening for momentum, l2 weight decay and an option for Nesterov momentum.
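As a sketch of what the momentum argument does, here is the update on a single number in plain Python. It follows my understanding of the rule documented for optim.SGD when dampening is 0 and Nesterov is off: keep a running velocity of gradients and step along that. The function name and the toy objective are assumptions for illustration.

```python
def sgd_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9):
    """One SGD-with-momentum update for a single scalar parameter."""
    velocity = momentum * velocity + grad
    param = param - lr * velocity
    return param, velocity


# minimise f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
w, v = 0.0, 0.0
for step in range(200):
    grad = 2 * (w - 3)
    w, v = sgd_momentum_step(w, grad, v, lr=0.05)
print(w)  # should end up close to 3
```

With momentum=0 this reduces to plain gradient descent, param = param - lr * grad.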

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

Note that nn.CrossEntropyLoss() returns a callable object, which we’ve called criterion, so when you see criterion later on remember you can call it like a function.

Training

Let’s go through how to train the network.

A reminder: we’d defined trainloader like this:

trainloader = torch.utils.data.DataLoader(train, batch_size=4, 
                                          shuffle=True, num_workers=2)

If we iterate through trainloader we get tuples with (data, labels), so we’ll have to unpack it. The variable data refers to the image data and it’ll come in batches of 4 at each iteration, as a Tensor of size (4, 3, 32, 32). labels will be a 1d Tensor.

Next we zero the gradient with optimizer.zero_grad(). Gradients aren’t reset to zero after a backprop step, so if we don’t do this, they’ll accumulate and won’t be correct.
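The accumulation is easy to see on a tiny example. This sketch uses a single scalar parameter rather than the network, and resets the gradient by hand with w.grad.zero_(); optimizer.zero_grad() does the equivalent for every parameter the optimiser manages.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(w * 3).backward()
first = w.grad.item()        # d(3w)/dw = 3

(w * 3).backward()
accumulated = w.grad.item()  # the second backward added to the first

w.grad.zero_()               # reset the gradient by hand
(w * 3).backward()
after_reset = w.grad.item()

print(first, accumulated, after_reset)
```

The middle value is double the true gradient, which is exactly the error you would feed to optimizer.step() if you forgot to zero the gradients each iteration.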

Then comes the forward pass. We created an instance of our Net module earlier, and called it net. We can put an image through the network directly with net(inputs), which is the same as the forward pass.

Loss is easy: just put criterion(outputs, labels), and you’ll get a tensor back. Backpropagate with loss.backward(), and rely on the autograd functionality of Pytorch to get gradients for your weights with respect to the loss (no analytical calculations of derivatives required!)

Update weights with optimizer.step(). This uses the learning rate and the algorithm that you seeded optim.SGD with and updates the parameters of the network (that you also gave to optim.SGD)

Finally, we’ll want to keep track of the average loss. This will let us see if our network is learning quickly enough. We’ll print out diagnostics every so often.

running_loss = 0 
printfreq = 1000
for epoch in range(2):
    for i, data in enumerate(trainloader):
        inputs, labels = data
        optimizer.zero_grad()  # reset gradients from the previous step
        outputs = net(inputs)  # forward pass 
        loss = criterion(outputs, labels)
        loss.backward()  # backpropagate to get gradients
        optimizer.step()  # update the weights

        running_loss += loss.item()
        if i % printfreq == printfreq-1:  
            print(epoch, i+1, running_loss / printfreq)
            running_loss = 0 
0 1000 2.261470085859299
0 2000 1.9996697949171067
0 3000 1.8479941136837006
0 4000 1.735244540154934
0 5000 1.6975953910946846
0 6000 1.6573741003870963
0 7000 1.5628319020867347
0 8000 1.5547611890733242
0 9000 1.5162546570003033
0 10000 1.5124652062356472
0 11000 1.4926373654603957
0 12000 1.4613762157261372
1 1000 2.169555962353945
1 2000 1.4098796974420547
1 3000 1.4137214373350144
1 4000 1.3870477629005908
1 5000 1.351047681376338
1 6000 1.3671476569324732
1 7000 1.329134451970458
1 8000 1.32600463321805
1 9000 1.3019814370572567
1 10000 1.3028037322461605
1 11000 1.309966327637434
1 12000 1.2809565272629262
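One number in the log stands out: the first line of epoch 1 jumps to 2.17. A likely explanation is that running_loss is only reset when a line is printed, so the last 500 batches of epoch 0 (12500 batches per epoch, printing every 1000) get carried into the first print of epoch 1. Here is a rough back-of-the-envelope check, assuming those 500 leftover batches had about the same average loss as the last printed line of epoch 0.

```python
batches_per_epoch = 50000 // 4            # batch_size=4 gives 12500 batches
printfreq = 1000
leftover = batches_per_epoch % printfreq  # batches after the last print of epoch 0

# figures copied from the log above
last_epoch0_avg = 1.4613762157261372      # line "0 12000"
first_epoch1_print = 2.169555962353945    # line "1 1000"

# assumption: leftover batches had roughly the last printed average loss
carried_over = leftover * last_epoch0_avg
implied_avg = (first_epoch1_print * printfreq - carried_over) / printfreq
print(leftover, round(implied_avg, 3))
```

The implied average for the first 1000 batches of epoch 1 lands between the figures printed just before and just after it, so the spike looks like a logging artefact rather than the network getting worse. Resetting running_loss at the start of each epoch would avoid it.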

Saving, loading, and state_dict

Models can take a long time to train, so it’s worth saving them once training is done to avoid retraining them. You’ll also need a way to load them back.

When saving a model, we want to save the state_dict of the network (net.state_dict()), which holds weights for all the layers. This doesn’t save any of the optimiser information, so if we want to save that, we can save optimizer.state_dict() too.

Let’s have a look in the state_dict of our net that we trained:

for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

We can see the bias and weights are saved, each in the correct shape of the layer.
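The shapes also tell you how big the network is. As a quick sketch, this adds up the number of parameters from the sizes printed above (copied by hand into a dict, so no PyTorch is needed to run it).

```python
from math import prod

# shapes copied from the state_dict printout above
shapes = {
    "conv1.weight": (6, 3, 5, 5), "conv1.bias": (6,),
    "conv2.weight": (16, 6, 5, 5), "conv2.bias": (16,),
    "fc1.weight": (120, 400), "fc1.bias": (120,),
    "fc2.weight": (84, 120), "fc2.bias": (84,),
    "fc3.weight": (10, 84), "fc3.bias": (10,),
}

total_params = sum(prod(shape) for shape in shapes.values())
print(total_params)
```

Most of the parameters sit in fc1 (120 x 400 weights), not in the convolutional layers.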

Let’s look at the state_dict of the optimiser object too:

print(optimizer.state_dict().keys())
print(optimizer.state_dict()['param_groups'])
dict_keys(['state', 'param_groups'])
[{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [5019782816, 5019782960, 5019782600, 5019782672, 5019782888, 5019781736, 5019782456, 5019779720, 5019781592, 5019779792]}]

There’s more, but it’s big, so I won’t print it.

Saving and loading is done with torch.save, torch.load, and net.load_state_dict. Saving an object will pickle it. It seems to be a PyTorch convention to save the weights with a .pt or a .pth file extension.
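If you want to resume training later, one common pattern is to save the model and optimiser state_dicts together in one dictionary. Here is a minimal sketch of that idea. It uses a small nn.Linear as a stand-in for the CNN and an in-memory buffer instead of a file so it runs anywhere; the dictionary keys and variable names are my own choices, not a PyTorch convention.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

small_model = nn.Linear(4, 2)  # stand-in for the CNN
small_optimizer = optim.SGD(small_model.parameters(), lr=0.001, momentum=0.9)

checkpoint = {
    "model_state": small_model.state_dict(),
    "optimizer_state": small_optimizer.state_dict(),
}
buffer = io.BytesIO()          # a file path such as './checkpoint.pth' works too
torch.save(checkpoint, buffer)

# later: rebuild the objects, then load the saved state into them
buffer.seek(0)
restored = torch.load(buffer)
new_model = nn.Linear(4, 2)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
new_model.load_state_dict(restored["model_state"])
new_optimizer.load_state_dict(restored["optimizer_state"])

print(torch.equal(small_model.weight, new_model.weight))
```

Note that a state_dict only holds values, so you still have to construct the model and optimiser yourself before loading into them.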

fname = './models/CIFAR10_cnn.pth'
torch.save(net.state_dict(), fname)
loaded_dict = torch.load(fname)
net.load_state_dict(loaded_dict)  # copy the saved weights into the network
<All keys matched successfully>

Some layers like Dropout or Batchnorm won’t work properly if you don’t call net.eval() after loading.

net.eval()
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

There is much more to saving and loading than this. Check out this guide for more information. Some examples: transfer learning, model checkpointing, transferring from CPU to GPU and vice versa.

Model testing

It’s time to see how our trained net does on the test set. Will it have learned anything?

# Reload net if needed
fname = './models/CIFAR10_cnn.pth'
loaded_dict = torch.load(fname)
net = Net()
net.load_state_dict(loaded_dict)
<All keys matched successfully>

As a sanity check, let’s first take some images from our test set and plot them with their ground-truth labels:

dataiter = iter(testloader)
images, labels = next(dataiter)
plot_images(images, labels)
print(' '.join('%5s' % d_idx2class[int(labels[j])] for j in range(len(images))))
 frog truck truck  deer


Looks good. Now let’s run the images through our net and see what we get.

outputs = net(images)
tensor([[-2.4268, -2.2433,  0.7154,  1.9097,  2.9379,  1.5184,  1.5930,  0.7512,
         -3.0894, -3.7157],
        [ 2.6312,  2.0198, -0.4813, -1.5693, -2.4001, -2.0169, -4.2537,  0.2893,
          2.6629,  3.0607],
        [ 1.0008,  0.0441, -1.0073, -0.1218, -0.4955, -1.1580, -1.8499, -0.0223,
          1.5421,  1.3126],
        [-1.0473, -2.0484,  0.8610,  0.7433,  4.1755,  0.5730,  2.4846, -0.3962,
         -2.6474, -3.2877]], grad_fn=<AddmmBackward>)

These are logits for each of the ten classes. Since the highest logit will be the predicted class, we can generate labels easily from the logits.

preds = outputs.argmax(dim=1)
print(' '.join('%5s' % d_idx2class[int(preds[j])]for j in range(len(images))))
 deer truck  ship  deer


It’s got some right, not all. For a better evaluation of performance, we’ll have to look at performance across the entire test set.

Most of the code follows similar patterns to the training loop above. I’ll comment on the things I find interesting.

A useful function is torch.max(). This returns a namedtuple with the standard max values along an axis, but somewhat usefully also the argmax values along that axis, too. This is good for us because we don’t really care about the max value, but more its argmax, since that corresponds to the label.

We can do element-wise comparison with == on PyTorch tensors (and NumPy arrays too). We see this in the line using predicted == labels below, which will return a vector filled with True/False values. If predicted and labels were lists, by comparison, we’d just get a single True value if all the elements were the same, or False if any were different.

Note the code is inside the torch.no_grad() context manager, which turns off gradient tracking for the variables. It’s claimed that this reduces memory usage, and increases computation speed. I wasn’t sure, so I did a rudimentary speed test. I didn’t track the memory usage, but there is definitely a speed benefit.

  • with no_grad(): 50.8 s
  • without no_grad(): 55.6 s

Given it’s one line, it’s probably worth the effort to do.

total = 0  # keeps track of how many images we have processed 
correct = 0  # keeps track of how many correct images our net predicts
with torch.no_grad():
    for i, data in enumerate(testloader): 
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size()[0]
        correct += (predicted == labels).sum().item()
print("Accuracy: ", correct/total)
Accuracy:  0.56358

We are tracking at about 56% accuracy.

Prediction analysis

We’re going to want to know how our model does on different classes. It’s unlikely its predictions for every class will be similarly accurate.

class_correct = list(0 for i in range(10))  # Holds how many correct images for the class
class_total = list(0 for i in range(10))  # Holds total images for the class 

with torch.no_grad(): 
    for i, data in enumerate(testloader): 
        images, labels = data 
        outputs = net(images) 
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)
        for j in range(len(labels)):  # works even if the last batch is smaller than 4
            label = labels[j]
            class_correct[label] += c[j].item()
            class_total[label] += 1
for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        d_idx2class[i], 100 * class_correct[i] / class_total[i]))
Accuracy of airplane : 69 %
Accuracy of automobile : 69 %
Accuracy of  bird : 35 %
Accuracy of   cat : 44 %
Accuracy of  deer : 60 %
Accuracy of   dog : 33 %
Accuracy of  frog : 57 %
Accuracy of horse : 73 %
Accuracy of  ship : 73 %
Accuracy of truck : 46 %

You can see significant differences in the accuracy of different classes. For example, our network is bad at predicting birds, but better at predicting horses.