PyTorch learning - Datasets & DataLoaders

Posted by mark_h_uk on Thu, 30 Dec 2021 02:33:01 +0100

Code for processing data samples can get messy and hard to maintain; ideally, we want our dataset code to be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset, which allow you to use pre-loaded datasets as well as your own data. A Dataset stores the samples and their corresponding labels, and a DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
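As a minimal sketch of how the two primitives fit together, the snippet below wraps two made-up tensors in torch.utils.data.TensorDataset (a ready-made Dataset subclass that is not otherwise used in this post) and iterates over them with a DataLoader. The data is invented purely for illustration.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Made-up data: 10 samples with 3 features each, plus integer labels 0/1
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

# The Dataset stores samples and labels; indexing returns one (feature, label) pair
dataset = TensorDataset(features, labels)
x0, y0 = dataset[0]
print(x0, y0)  # tensor([0., 1., 2.]) tensor(0)

# The DataLoader wraps the Dataset in an iterable that yields batches
loader = DataLoader(dataset, batch_size=4)
for batch_x, batch_y in loader:
    # Batches hold 4, 4, then the remaining 2 samples
    print(batch_x.shape, batch_y.shape)
```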

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass torch.utils.data.Dataset and implement functions specific to the particular data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets.

1. Load dataset

Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando's article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST dataset with the following parameters:

  • root is the path where the training/test data is stored,
  • train specifies whether to load the training or the test set,
  • download=True downloads the data from the Internet if it is not available at root,
  • transform and target_transform specify the feature and label transformations.

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Out:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
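The code above only sets transform. As an illustration of target_transform, here is a sketch of a hypothetical helper, one_hot_label, that turns an integer class index into a one-hot float vector using torch.nn.functional.one_hot. The helper name and the choice of one-hot encoding are assumptions for this example; FashionMNIST works fine with plain integer labels.

```python
import torch
import torch.nn.functional as F

def one_hot_label(label: int, num_classes: int = 10) -> torch.Tensor:
    """Hypothetical target_transform: integer class index -> one-hot float vector."""
    return F.one_hot(torch.tensor(label), num_classes=num_classes).float()

print(one_hot_label(3))  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

# Usage sketch (not executed here because it would download the data):
# datasets.FashionMNIST(root="data", train=True, download=True,
#                       transform=ToTensor(), target_transform=one_hot_label)
```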

2. Iterating and visualizing the dataset

We can manually index the dataset like a list: training_data[index]. We use matplotlib to visualize some samples in the training data.

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
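Indexing with a single integer returns one (image, label) tuple. The sketch below uses a made-up TensorDataset of random 1×28×28 tensors as a stand-in for training_data, so it runs without downloading anything, and also shows torch.utils.data.Subset, which selects a list of indices and behaves like a smaller Dataset (handy for prototyping on a slice of the data).

```python
import torch
from torch.utils.data import TensorDataset, Subset

# Stand-in for training_data: 100 fake 1x28x28 "images" with labels 0-9 (made-up data)
images = torch.rand(100, 1, 28, 28)
targets = torch.arange(100) % 10
stand_in = TensorDataset(images, targets)

# Integer indexing returns a single (image, label) tuple, like training_data[index]
img, label = stand_in[7]
print(img.shape, label.item())  # torch.Size([1, 28, 28]) 7

# Subset keeps only the listed indices and behaves like a smaller Dataset
first_ten = Subset(stand_in, list(range(10)))
print(len(first_ten))  # 10
```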

3. Create a custom dataset

A custom Dataset class must implement three functions: __init__, __len__, and __getitem__. Take a look at this implementation: the FashionMNIST images are stored in a directory img_dir, and their labels are stored separately in a CSV file annotations_file.
In the next sections, we'll break down what happens in each of these functions.

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, header=None)  # the annotations CSV has no header row
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
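To see the three required methods in action without any image files on disk, here is a sketch of a hypothetical InMemoryDataset that follows the same structure as CustomImageDataset but keeps its samples in a Python list. The class name, the fake 1×4×4 uint8 "images" and the scaling transform are all assumptions made for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class InMemoryDataset(Dataset):
    """Hypothetical Dataset that keeps samples in a list instead of reading files."""

    def __init__(self, samples, labels, transform=None, target_transform=None):
        # Runs once: store the data and the optional transforms
        self.samples = samples
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Number of samples in the dataset
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (sample, label) pair, applying the transforms if provided
        sample = self.samples[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)
        return sample, label

# Made-up data: 6 uint8 "images" of shape 1x4x4, filled with the values 0..5
samples = [torch.full((1, 4, 4), i, dtype=torch.uint8) for i in range(6)]
labels = [0, 1, 2, 0, 1, 2]
ds = InMemoryDataset(samples, labels, transform=lambda t: t.float() / 255)

# DataLoader stacks the returned tuples into batched tensors
loader = DataLoader(ds, batch_size=4)
batch_x, batch_y = next(iter(loader))
print(batch_x.shape, batch_x.dtype, batch_y)  # torch.Size([4, 1, 4, 4]) torch.float32 tensor([0, 1, 2, 0])
```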

__init__

The __init__ function runs once when the Dataset object is instantiated. We initialize the directory containing the images, the annotations file, and both transforms (transform and target_transform).

The annotations CSV file looks like this:

tshirt1.jpg, 0
tshirt2.jpg, 0
...
ankleboot999.jpg, 9

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file, header=None)  # the annotations CSV has no header row
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
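One detail worth checking: pd.read_csv treats the first line of a file as column names by default, so a header-less file like the one shown above needs header=None, otherwise the first sample is silently lost. The sketch below parses made-up annotations in the same layout from an in-memory string and shows what iloc[idx, 0] and iloc[idx, 1] return; img_dir is a hypothetical directory name, and skipinitialspace=True strips the space after each comma.

```python
import io
import os
import pandas as pd

# Made-up annotations in the same layout as above: "<file name>, <label>", no header row
csv_text = "tshirt1.jpg, 0\ntshirt2.jpg, 0\nankleboot999.jpg, 9\n"

# header=None keeps the first line as data; skipinitialspace drops the space after each comma
img_labels = pd.read_csv(io.StringIO(csv_text), header=None, skipinitialspace=True)

img_dir = "img_dir"  # hypothetical image directory
for idx in range(len(img_labels)):
    # Column 0 is the file name, column 1 the integer label, as used by __getitem__
    path = os.path.join(img_dir, img_labels.iloc[idx, 0])
    label = img_labels.iloc[idx, 1]
    print(path, label)

# Without header=None, the first sample would be consumed as column names
print(len(pd.read_csv(io.StringIO(csv_text))))  # 2
```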

__len__

The __len__ function returns the number of samples in our dataset.

Example:

def __len__(self):
    return len(self.img_labels)

__getitem__

The __getitem__ function loads and returns a sample from the dataset at the given index idx. Based on the index, it identifies the image's location on disk, converts it to a tensor using read_image, retrieves the corresponding label from the CSV data in self.img_labels, calls the transform functions on them (if applicable), and returns the tensor image and corresponding label in a tuple.

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

4. Prepare data for training using DataLoaders

A Dataset retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's multiprocessing to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity for us behind a simple API.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
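To make the batching behaviour concrete, the sketch below uses a made-up 100-sample TensorDataset: with batch_size=64 the DataLoader yields ceil(100 / 64) = 2 batches, the last one holding the remaining 36 samples, unless drop_last=True is set. The same arithmetic gives 938 batches per epoch for the 60,000 FashionMNIST training images. Multiprocessing is switched on through the num_workers argument, which is left at its default here to keep the example single-process.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

# Made-up dataset of 100 samples (the values 0..99)
ds = TensorDataset(torch.arange(100))

loader = DataLoader(ds, batch_size=64, shuffle=True)
print(len(loader))  # 2 batches: ceil(100 / 64)
print([b[0].shape[0] for b in loader])  # [64, 36]: the last batch is smaller

# drop_last=True discards the incomplete final batch
loader_drop = DataLoader(ds, batch_size=64, shuffle=True, drop_last=True)
print(len(loader_drop))  # 1

# Same arithmetic for FashionMNIST's 60,000 training images
print(math.ceil(60000 / 64))  # 938
```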

5. Iterate through the DataLoader

We have loaded the dataset into the DataLoader and can iterate through it as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels, respectively). Because we specified shuffle=True, the data is reshuffled after we iterate over all batches (for finer-grained control over the data loading order, take a look at Samplers).

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")


Out:

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 5
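The sketch below illustrates the shuffling behaviour on a made-up 10-sample TensorDataset: with shuffle=True each pass over the DataLoader generally yields a new permutation, and passing a seeded torch.Generator makes that sequence reproducible. It also passes an explicit SequentialSampler, one of the built-in Samplers, through the sampler argument, which keeps the original order.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

# Made-up dataset of 10 samples (the values 0..9)
ds = TensorDataset(torch.arange(10))

def epoch_order(loader):
    # Concatenate the sample values seen during one full pass over the loader
    return torch.cat([batch[0] for batch in loader]).tolist()

# shuffle=True draws a fresh permutation each epoch; the seeded generator makes it reproducible
g = torch.Generator().manual_seed(0)
shuffled = DataLoader(ds, batch_size=4, shuffle=True, generator=g)
epoch1 = epoch_order(shuffled)
epoch2 = epoch_order(shuffled)
print(epoch1, epoch2)  # two permutations of 0..9; exact values depend on the seed

# An explicit sampler controls the order directly; SequentialSampler keeps it unchanged
ordered = DataLoader(ds, batch_size=4, sampler=SequentialSampler(ds))
print(epoch_order(ordered))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```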

6. Further Reading

torch.utils.data API
