PyTorch: training a classifier

Posted by hunainrasheed on Tue, 22 Feb 2022 14:24:29 +0100

Training a Classifier

We have already learned how to define a neural network, compute the loss and update the network's weights.

What about data?

Generally, when you have to deal with image, text, audio or video data, you can use standard Python packages that load the data into a NumPy array, and then convert this array into a torch.*Tensor.

  • Image: Pillow, OpenCV
  • Audio: SciPy, librosa
  • Text: raw Python or Cython-based loading, or NLTK and SpaCy
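As a minimal sketch of the "load into NumPy, then convert to a tensor" step, the snippet below assumes an image has already been loaded as an 8-bit array in height x width x channels (HWC) order; the random array is a hypothetical stand-in for a real image.

```python
import numpy as np
import torch

# Hypothetical 32x32 RGB image as an image loader might return it: HWC, uint8.
img_np = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)

# torch.from_numpy shares memory with the NumPy array (no copy is made).
img_t = torch.from_numpy(img_np)
img_np[0, 0, 0] = 7  # the change is visible through the tensor as well

# PyTorch convolution layers expect channels first (CHW), so permute the axes.
img_chw = img_t.permute(2, 0, 1)

print(img_t.shape)    # torch.Size([32, 32, 3])
print(img_chw.shape)  # torch.Size([3, 32, 32])
```

In practice torchvision's ToTensor transform performs this reordering (and the scaling to [0, 1]) for you.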

Specifically for vision, there is a package called torchvision that includes data loaders for common datasets such as ImageNet, CIFAR10 and MNIST, and data transformers for images (torchvision.datasets and …).

These provide great convenience and avoid writing boilerplate code.

The CIFAR10 dataset has ten classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' and 'truck'. The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels.

A batch of images is stored as a 4D tensor of shape (B, C, H, W):

  • B: batch size
  • C: number of channels
  • H: height
  • W: width
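A quick way to inspect this layout, assuming a dummy batch of zeros shaped like one CIFAR10 mini-batch (4 RGB images of 32x32 pixels) in place of real data:

```python
import torch

# Dummy batch: B=4 images, C=3 channels, H=32, W=32.
images = torch.zeros(4, 3, 32, 32)
B, C, H, W = images.shape
print(B, C, H, W)   # 4 3 32 32

# Indexing the first (batch) dimension selects a single image of shape (C, H, W).
first = images[0]
print(first.shape)  # torch.Size([3, 32, 32])
```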

Training an image classifier

  • 1. Load and normalize the CIFAR10 training and test datasets using torchvision
  • 2. Define a convolutional neural network
  • 3. Define a loss function
  • 4. Train the network on the training data
  • 5. Test the network on the test data

1. Load and normalize CIFAR10

The torchvision library includes datasets, models and image transformations for computer vision. It is PyTorch's computer-vision library and contains the following modules:

  • torchvision.datasets: functions for loading data and interfaces to common datasets
  • torchvision.models: common model architectures (including pretrained models), such as AlexNet, VGG, ResNet, etc.
  • torchvision.transforms: common image transformations, such as cropping, rotation, etc.
  • torchvision.utils: other useful utilities
import torch
import torchvision
import torchvision.transforms as transforms

The outputs of torchvision datasets are PILImage images with values in the range [0, 1]. We transform them into tensors normalized to the range [-1, 1].

torchvision.transforms.Compose chains several image transformations together. Common transformations include:

  • ToTensor: converts a PIL image or NumPy array (H x W x C) with values in 0-255 into a float tensor (C x H x W) with values in [0, 1]
  • Normalize: normalizes a tensor image with a per-channel mean and standard deviation
  • CenterCrop: crops the center of the image
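The arithmetic behind ToTensor and Normalize can be checked with plain tensors. This sketch does not call torchvision; it reproduces by hand the division by 255 and the per-channel formula output = (input - mean) / std that Normalize applies, using hypothetical pixel values.

```python
import torch

# Hypothetical 8-bit pixel values for one channel, as stored in an image file.
pixels = torch.tensor([0, 64, 128, 255], dtype=torch.uint8)

# ToTensor-style scaling: dividing by 255 maps 0..255 onto [0, 1].
scaled = pixels.float() / 255.0

# Normalize computes (input - mean) / std; with mean = std = 0.5,
# the range [0, 1] is mapped onto [-1, 1].
mean, std = 0.5, 0.5
normalized = (scaled - mean) / std
print(normalized)

# For a 3-channel image, the per-channel mean/std are broadcast over H and W.
image = torch.rand(3, 32, 32)  # values in [0, 1)
mean3 = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std3 = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
image_norm = (image - mean3) / std3
```

This is also why the imshow helper later computes img / 2 + 0.5: it is the inverse of (x - 0.5) / 0.5.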

PIL (Python Imaging Library) is a third-party image processing library for Python.

An important interface for reading data in PyTorch is torch.utils.data.DataLoader, defined in the dataloader.py script; it is used almost every time a model is trained with PyTorch. It packages the output of a user-defined data-reading interface, or of one of PyTorch's existing datasets, into tensors grouped according to the batch size, which can then be fed directly to the model. This interface therefore acts as the link between data loading and model training.
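A minimal sketch of the batching described above, assuming a hypothetical toy dataset of 10 random "images" wrapped in TensorDataset instead of CIFAR10 (num_workers is left at its default of 0):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 "images" of shape 3x32x32 with labels 0..9.
images = torch.randn(10, 3, 32, 32)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# DataLoader groups samples into batches; the last batch may be smaller.
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batch_sizes = []
for batch_images, batch_labels in loader:
    # Each batch is stacked into one tensor of shape (B, C, H, W).
    batch_sizes.append(batch_images.shape[0])
    print(batch_images.shape, batch_labels.tolist())

print(batch_sizes)  # [4, 4, 2]
first_labels = next(iter(loader))[1].tolist()  # [0, 1, 2, 3] since shuffle=False
```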

import ssl
# Disable HTTPS certificate verification so that the dataset download does not fail
# on machines with certificate problems (insecure; use only as a workaround)
ssl._create_default_https_context = ssl._create_unverified_context

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# transforms.Normalize(mean, std): one mean and one std per channel,
# matching the 3 channels of the 3x32x32 images
batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# DataLoader is a data iterator that wraps the dataset. num_workers is the number of
# worker processes used to load data; shuffle=True reshuffles the data at every epoch

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Files already downloaded and verified
Files already downloaded and verified

Let us show some of the training images, for fun.

Iteration is one of Python's most powerful features and a way to access the elements of a collection. Strings, lists and tuples can all be used to create iterators. An iterator visits the elements of the collection starting from the first one until all elements have been accessed. Two built-in functions are involved:

  • iter() creates an iterator
  • next() returns the next item of the iterator.
lst = [1, 2, 3, 4]
it = iter(lst)  # create an iterator over the list
for x in it:
    print(x, end=' ')
1 2 3 4
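The for loop above calls next() implicitly. Calling it directly, as the next(dataiter) call on the DataLoader iterator does below, looks like this (plain Python, no PyTorch needed):

```python
it = iter([1, 2, 3])

first = next(it)  # 1: next() returns the next item
rest = list(it)   # [2, 3]: consumes the remaining items

# Once the iterator is exhausted, next() raises StopIteration...
try:
    next(it)
    exhausted = False
except StopIteration:
    exhausted = True

# ...unless a default value is supplied as the second argument.
fallback = next(it, None)  # None
```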

Show some training images

import matplotlib.pyplot as plt
import numpy as np

# function to show an image

def imshow(img):
    img = img / 2 + 0.5     # unnormalize: map [-1, 1] back to [0, 1]
    npimg = img.numpy()     # convert the tensor to a NumPy array
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # np.transpose reorders the axes from (C, H, W) to (H, W, C)
    plt.show()

# get some random training images
# trainloader yields batches of images and labels; since shuffle=True, the batch differs on each run
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
# torchvision.utils.make_grid tiles several images into a single grid image
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

horse bird deer truck

2. Define a convolutional neural network

Copy the neural network from the previous section and modify it to take 3-channel images instead of 1-channel images.

nn.Conv2d: applies a 2D convolution over an input signal composed of several input planes

  • nn.Conv2d(in_channels,out_channels,kernel_size)

nn.MaxPool2d: applies 2D max pooling over an input signal composed of several input planes

  • nn.MaxPool2d(kernel_size,stride)

nn.Linear: applies a linear transformation to the input data, $y = xA^T + b$

  • nn.Linear(in_features,out_features)
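To see how these three layers change the tensor shape, the sketch below builds layers with the same hyperparameters as the Net class in this section and passes one random 3x32x32 image through them. For a convolution with kernel size k and stride 1 (no padding), the output size is in - k + 1; 2x2 pooling with stride 2 halves it. The last lines check the nn.Linear formula by hand.

```python
import torch
import torch.nn as nn

# Layers with the same hyperparameters as the Net class in this section.
conv1 = nn.Conv2d(3, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)
fc1 = nn.Linear(16 * 5 * 5, 120)

x = torch.randn(1, 3, 32, 32)  # one random 3-channel 32x32 image
x1 = conv1(x)                  # 32 - 5 + 1 = 28
x2 = pool(x1)                  # 28 / 2 = 14
x3 = conv2(x2)                 # 14 - 5 + 1 = 10
x4 = pool(x3)                  # 10 / 2 = 5
flat = torch.flatten(x4, 1)    # 16 * 5 * 5 = 400 features per image
for name, t in [("conv1", x1), ("pool", x2), ("conv2", x3), ("pool", x4), ("flatten", flat)]:
    print(name, tuple(t.shape))

# nn.Linear computes y = x A^T + b, with A = fc1.weight and b = fc1.bias.
y = fc1(flat)
y_manual = flat @ fc1.weight.T + fc1.bias
print(torch.allclose(y, y_manual, atol=1e-5))
```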

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # convolution
        # 3-channel 32x32 input -> six 5x5 filters -> 6-channel 28x28 output (32 - 5 + 1 = 28)
        self.pool = nn.MaxPool2d(2, 2)  # max pooling
        # 6-channel 28x28 -> 2x2 pooling -> 14x14; stride = kernel_size, so the windows do not overlap (28 / 2 = 14)
        self.conv2 = nn.Conv2d(6, 16, 5)  # convolution
        # 6-channel 14x14 -> sixteen 5x5 filters -> 16-channel 10x10 output (14 - 5 + 1 = 10)
        # self.pool is reused after conv2, so no second pooling layer is defined:
        # 16-channel 10x10 -> 2x2 pooling -> 5x5 (10 / 2 = 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # linear transformation
        # the 16 feature maps of 5x5 are flattened into 16 * 5 * 5 = 400 inputs of the first FC layer (F5)
        self.fc2 = nn.Linear(120, 84)
        # second FC layer (F6)
        self.fc3 = nn.Linear(84, 10)
        # third FC layer: outputs 10 raw class scores (logits)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # convolution -> activation -> pooling
        x = self.pool(F.relu(self.conv2(x)))  # convolution -> activation -> pooling
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # last layer: plain linear output, no activation
        return x

net = Net()

3. Define a loss function and optimizer

We use a classification cross-entropy loss and SGD with momentum.

torch.nn.CrossEntropyLoss: computes the cross-entropy loss between the input and the target. It is suitable for classification problems with C classes. The input is expected to contain the raw, unnormalized scores (logits) for each class.
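Because the input is raw scores, the loss applies log-softmax internally. The sketch below, using random logits and hypothetical targets, recomputes the same value by hand: take the log-softmax over the classes, pick the entry of the target class, negate, and average over the batch (the default reduction='mean').

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # raw scores: 4 samples, 10 classes
targets = torch.tensor([3, 8, 8, 0])  # class indices

loss = nn.CrossEntropyLoss()(logits, targets)

# Same value by hand: log-softmax, select each sample's target class, negate, average.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), targets].mean()

print(loss.item(), manual.item())
```

This is also why Net ends with a plain linear layer: adding a softmax before this loss would apply it twice.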

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

criterion type:<class 'torch.nn.modules.loss.CrossEntropyLoss'>

loss type:<class 'torch.Tensor'>

4. Train the network

We loop over the data iterator, feed the inputs to the network, and optimize.

The enumerate() function wraps an iterable object (such as a list, tuple or string) so that iterating over it yields each element together with its index. It is generally used in for loops. The syntax is enumerate(sequence, start=0):

  • sequence: a sequence, iterator, or other object that supports iteration
  • start: the value the index starts counting from (default 0)
sequence = ['one', 'two', 'three']
# Normal for loop with a manual counter
i = 0
for e in sequence:
    print(i, e)
    i += 1
# for loop using enumerate
for i, e in enumerate(sequence, 0):
    print(i, e)
0 one
1 two
2 three
0 one
1 two
2 three
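for epoch in range(2):  # loop over the dataset twice
    run_loss = 0.0  # running loss, reset every 2000 mini-batches
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + loss + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # use loss.item() (a Python float); accumulating the loss tensor itself
        # would keep its autograd graph alive and make memory usage grow
        run_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1}],loss:{run_loss / 2000:.3f}')
            run_loss = 0.0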
[1, 2000],loss:2.268
[1, 4000],loss:2.029
[1, 6000],loss:1.834
[1, 8000],loss:1.666
[2, 2000],loss:1.459
[2, 4000],loss:1.418
[2, 6000],loss:1.373
[2, 8000],loss:1.355
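The optimizer.zero_grad() call in the loop is needed because backward() adds new gradients to the ones already stored rather than overwriting them. A minimal sketch with a single hypothetical scalar parameter:

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Two backward passes without clearing: the gradients are summed into w.grad.
(2 * w).backward()
(2 * w).backward()
accumulated = w.grad.item()  # 4.0 rather than 2.0

# zero_grad() clears the stored gradients before the next backward pass.
optimizer.zero_grad()
(2 * w).backward()
fresh = w.grad.item()        # 2.0
print(accumulated, fresh)    # 4.0 2.0
```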

Save the trained model

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)  # save only the learned parameters

More details on saving PyTorch models are given in the PyTorch documentation.

5. Test the network on the test data

We have trained the network for 2 passes (epochs) over the training dataset. To check how well it performs, we compare the class label predicted by the network with the ground truth. If the prediction is correct, we add the sample to the list of correct predictions.

First, let's display a few images from the test set.

dataitertest = iter(testloader)
images, labels = next(dataitertest)
# labels here are the categories represented by numbers
print(labels)
# show images
imshow(torchvision.utils.make_grid(images))
print('groundtruth:', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
tensor([3, 8, 8, 0])

groundtruth: cat   ship  ship  plane

Next, let's load the saved model back (this is not necessary here, since the network is still in memory, but it shows how to do it).

net = Net()
net.load_state_dict(torch.load(PATH))
<All keys matched successfully>
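The save/load round trip can be checked in isolation. This sketch assumes a small nn.Linear as a stand-in for the trained Net and a hypothetical file name in a temporary directory; load_state_dict returns the lists of missing and unexpected keys, which print as the message shown above when both are empty.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the trained Net

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")  # hypothetical file name
    torch.save(model.state_dict(), path)   # saves only the parameters

    restored = nn.Linear(4, 2)             # same architecture, fresh random weights
    result = restored.load_state_dict(torch.load(path))
    print(result)                          # <All keys matched successfully>

# Every saved tensor now equals the corresponding restored tensor.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
```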

Now let's see what the neural network predicts for the examples above:

outputs = net(images)
print(outputs)
tensor([[-0.5511, -1.2592,  1.0451,  1.7341,  0.2255,  1.0719,  0.3474, -0.0722,
         -0.7703, -1.7738],
        [ 2.9296,  4.5538, -0.4796, -1.7549, -2.4294, -2.7830, -3.4919, -3.0665,
          4.3148,  2.5193],
        [ 2.0322,  2.4424,  0.4408, -1.1508, -1.1923, -1.9300, -2.9568, -1.5784,
          2.8175,  2.0967],
        [ 3.1805,  2.2340,  0.1468, -1.6451, -0.8934, -2.9459, -3.4108, -2.2368,
          4.2390,  2.2832]], grad_fn=<AddmmBackward0>)

The outputs are energies (scores) for the 10 classes for each of the 4 images. The higher the energy for a class, the more the network thinks that the image belongs to that class. So let's get the index of the highest energy:

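torch.max(input) returns the maximum value of all elements in the input tensor; torch.max(input, dim, keepdim=False) returns a (values, indices) pair holding the maximum of each slice along dimension dim and its position.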

torch.max(tensor, 0) returns the largest value in each column together with its row index; torch.max(tensor, 1), used below, returns the largest value in each row together with its column index, which for the (batch, classes) output is the predicted class.
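A small worked example of the three forms on a 2x3 tensor (the values are arbitrary):

```python
import torch

t = torch.tensor([[1.0, 5.0, 2.0],
                  [7.0, 0.0, 3.0]])

overall = torch.max(t)               # tensor(7.): maximum over all elements
col_vals, col_idx = torch.max(t, 0)  # per column: values [7, 5, 3], row indices [1, 0, 1]
row_vals, row_idx = torch.max(t, 1)  # per row: values [5, 7], column indices [1, 0]

# For a (batch, classes) score tensor, dim=1 gives the predicted class of each sample.
print(row_idx.tolist())  # [1, 0]
```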

_,predicted = torch.max(outputs,1)
print("predicted:",''.join('%5s'%classes[predicted[j]] for j in range(4)))
predicted:   cat  car ship ship

Two of the four predictions (cat and the second ship) are correct, an accuracy of 50% on this batch.
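The batch accuracy can be computed directly from the ground-truth labels printed above and the predicted indices (cat=3, car=1, ship=8), using the same comparison as the full test loop:

```python
import torch

# Ground truth and predictions for the 4 test images shown above.
labels = torch.tensor([3, 8, 8, 0])     # cat, ship, ship, plane
predicted = torch.tensor([3, 1, 8, 8])  # cat, car, ship, ship

correct = (predicted == labels).sum().item()
accuracy = correct / labels.size(0)
print(correct, accuracy)  # 2 0.5
```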

Next, let's look at how the network performs on the whole test dataset.

totalnum = 0
correctnum = 0
# We are not training, so there is no need to compute gradients for the outputs
with torch.no_grad():
    for data in testloader:
        images,labels = data
        # Forward propagation
        outputs = net(images)
        _,predicted = torch.max(outputs,1)
        # totalnum counts all test images, correctnum counts the correctly predicted ones
        totalnum += labels.size(0)
        correctnum += (predicted==labels).sum().item()

print("Accuracy of the network on the 10000 test images:%d %%"%(100*correctnum/totalnum))
Accuracy of the network on the 10000 test images:55 %

Randomly picking one of the 10 classes would give an accuracy of 10%, so the trained network does clearly better than chance. Next, let's analyze which classes the network performs well on and which it does not.

zip([iterable, ...]) takes iterable objects as arguments, packs their corresponding elements into tuples, and returns an iterator over these tuples. Because the tuples are produced lazily, this saves memory; use list() to turn the result into a list.

a = [1, 2, 3]
b = [4, 5, 6]
# If you get "TypeError: 'list' object is not callable", a variable named list is
# shadowing the built-in list(); pay attention to naming!
print(list(zip(a, b)))
[(1, 4), (2, 5), (3, 6)]
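Two further properties of zip are worth knowing when pairing labels with predictions: it stops at the shortest input, and zip(*pairs) reverses the pairing. A small sketch with hypothetical values:

```python
names = ['cat', 'ship', 'plane']
labels = [3, 8, 0, 9]

# zip stops at the shortest input, so the extra label 9 is silently dropped.
pairs = list(zip(names, labels))
print(pairs)  # [('cat', 3), ('ship', 8), ('plane', 0)]

# zip(*pairs) "unzips" the tuples back into separate sequences.
unzipped_names, unzipped_labels = zip(*pairs)
print(unzipped_names)   # ('cat', 'ship', 'plane')
print(unzipped_labels)  # (3, 8, 0)
```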
# Dictionaries storing the number of correct predictions and the total number of samples for each class
correct_pred = {classname:0 for classname in classes}
total_pred = {classname:0 for classname in classes}
# Predict and count
with torch.no_grad():
    for data in testloader:
        images,labels = data
        outputs = net(images)
        _,predictions = torch.max(outputs,1)
        for label,prediction in zip(labels,predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

for classname,correct_count in correct_pred.items():
    accuracy = 100*float(correct_count)/total_pred[classname]
    print('accuracy of %5s:%2d %%'%(classname,accuracy))
accuracy of plane:54 %
accuracy of   car:74 %
accuracy of  bird:49 %
accuracy of   cat:31 %
accuracy of  deer:53 %
accuracy of   dog:47 %
accuracy of  frog:60 %
accuracy of horse:58 %
accuracy of  ship:69 %
accuracy of truck:54 %
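The same per-class counting can also be done without a Python-level loop over samples. This is an alternative, vectorized sketch using torch.bincount rather than the dictionary approach above; the labels and predictions are hypothetical values for a handful of images.

```python
import torch

num_classes = 10
# Hypothetical labels and predictions for a few test images.
labels = torch.tensor([3, 8, 8, 0, 3, 1])
predictions = torch.tensor([3, 1, 8, 8, 5, 1])

# Count how often each class occurs, and how often it was predicted correctly.
total_per_class = torch.bincount(labels, minlength=num_classes)
correct_per_class = torch.bincount(labels[predictions == labels], minlength=num_classes)

# Per-class accuracy; classes without samples are skipped to avoid dividing by zero.
for c in range(num_classes):
    if total_per_class[c] > 0:
        acc = 100.0 * correct_per_class[c].item() / total_per_class[c].item()
        print(f"class {c}: {acc:.0f} %")
```

With the full test set, the counts from each batch would be summed before computing the percentages.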

Training on GPU

A GPU (graphics processing unit) is a processor specialized in image- and graphics-related computations. Just as a tensor can be transferred to the GPU, so can a neural network. This requires a CUDA-capable device.
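A minimal sketch of moving a model and its data to the GPU, assuming a small nn.Linear as a stand-in for Net and random inputs; it falls back to the CPU when no CUDA device is available, so it runs on either.

```python
import torch
import torch.nn as nn

# Use the first CUDA device if one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

model = nn.Linear(10, 2).to(device)  # stand-in for Net(); .to() moves its parameters

# Inputs and labels must be on the same device as the model at every step.
inputs = torch.randn(4, 10).to(device)
labels = torch.tensor([0, 1, 1, 0]).to(device)

outputs = model(inputs)
loss = nn.CrossEntropyLoss()(outputs, labels)
print(outputs.device, loss.item())
```

In the training loop above, this corresponds to calling net.to(device) once and moving inputs and labels to device inside the loop.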
