Use of Dataset and DataLoader for data loading

Posted by thenior on Tue, 08 Feb 2022 03:51:21 +0100

Data loading in PyTorch

In deep learning, the amount of data is usually very large, far too large to run forward computation and back propagation through the model in one pass. Instead, we typically shuffle the whole dataset randomly, split it into batches, and preprocess the data at the same time.
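
As a toy illustration of this shuffle-then-batch idea (a plain-Python sketch, before any PyTorch machinery; the numbers are made up):

import random

data = list(range(10))  # a toy "dataset" of 10 samples
random.shuffle(data)    # randomly disrupt the order
batch_size = 4
batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
print(batches)          # e.g. [[3, 7, 0, 9], [5, 1, 8, 2], [6, 4]]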

1. Introduction to the Dataset class

PyTorch provides the dataset base class torch.utils.data.Dataset; by inheriting this base class, we can quickly load our own data.
The source code of torch.utils.data.Dataset is as follows:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
  • __len__ method: lets the global len() function return the number of samples in the dataset.
  • __getitem__ method: returns the data for a given index, e.g. dataset[i] gets the i-th sample.
  • __add__ method: merges two datasets into one, as shown in the sketch below.
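
A minimal sketch of a subclass that overrides these three behaviors (the ToyDataset name and its in-memory samples are made up for illustration):

from torch.utils.data import Dataset

class ToyDataset(Dataset):
    # A tiny in-memory dataset where each sample is just a number
    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

a = ToyDataset([1, 2, 3])
b = ToyDataset([4, 5])
print(len(a))      # 3, via __len__
print(a[0])        # 1, via __getitem__
print(len(a + b))  # 5, __add__ returns a ConcatDataset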

2. Data loading example

Data source: http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection

Data introduction: SMS Spam Collection is a classic dataset for spam SMS identification, drawn entirely from real SMS content. It contains 4831 normal messages and 747 spam messages, all saved in a single text file. Each line records the full content of one message, and each line begins with ham or spam to mark the message as normal or spam.

The data are as follows: each line starts with the label (ham or spam), followed by a tab and then the message text. A quick way to peek at the first few lines of the raw file (a minimal sketch, assuming the data/SMSSpamCollection path used in section 3):
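
with open('data/SMSSpamCollection', encoding='utf-8') as f:  # path is an assumption
    for _ in range(5):
        print(f.readline().strip())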

Then instantiate the Dataset (the MyDataset class defined in section 3 below) and iterate to get the data:

d = MyDataset()
for i in range(len(d)):
    print(i, d[i])

Taking the first 5 rows, the output is one line per sample: the index followed by a (label, content) tuple.

3. Data loading code

from torch.utils.data import Dataset

data_path = 'data/SMSSpamCollection'

# Complete dataset class
class MyDataset(Dataset):
    def __init__(self):
        # Read the whole file into memory; each line is "<label>\t<message>"
        with open(data_path, encoding='utf-8') as f:
            self.lines = f.readlines()

    def __getitem__(self, index):
        # Get the line corresponding to the index
        cur_line = self.lines[index].strip()
        # Separate the label from the content: the first 4 characters
        # cover both "ham\t" and "spam", so strip() yields a clean label
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label, content

    def __len__(self):
        return len(self.lines)

if __name__ == '__main__':
    my_dataset = MyDataset()
    print(my_dataset[0])    # (label, content) of the first message
    print(len(my_dataset))  # total number of messages

4. Iterating over the dataset

The method above can read the data, but several features are still missing:

  • Batching the data
  • Shuffling the data
  • Loading the data in parallel with multiple worker processes

In PyTorch, torch.utils.data.DataLoader provides all of the features above.

Example of how to use DataLoader:

from torch.utils.data import DataLoader

dataset = MyDataset()
data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True, num_workers=2, drop_last=True)

# Traverse to obtain each batch; with the default collate_fn, each batch
# arrives as a list of 10 labels and a list of 10 message strings
for index, (label, content) in enumerate(data_loader):
    print(index, label, content)
    print("*" * 100)

Where the parameters mean:

  1. dataset: an instance of the dataset defined in advance

  2. batch_size: the number of samples per batch; common values are 128, 256, etc.

  3. shuffle: bool, indicating whether to shuffle the data before each pass over it

  4. num_workers: the number of worker processes used to load the data

  5. drop_last=True means that when the number of samples is not evenly divisible by batch_size, the last incomplete batch is dropped

  6. Note
    (1) len(dataset) = number of samples in the dataset
    (2) len(dataloader) = math.ceil(number of samples / batch_size), rounded up, when drop_last=False; with drop_last=True it is rounded down instead (a quick check follows below)
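
A minimal sketch to verify this relationship, assuming the MyDataset class from section 3 is in scope:

import math
from torch.utils.data import DataLoader

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=10, drop_last=False)
assert len(loader) == math.ceil(len(dataset) / 10)  # rounded up

loader = DataLoader(dataset, batch_size=10, drop_last=True)
assert len(loader) == len(dataset) // 10            # rounded down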

Topics: AI Pytorch Deep Learning