Data loading in PyTorch
In deep learning, the amount of data is usually very large, so large that we cannot run the forward pass and back-propagation over the whole dataset at once. Instead, we usually shuffle the whole dataset randomly, split it into batches, and preprocess the data at the same time.
1. Introduction to the Dataset class
PyTorch provides a base class for datasets, torch.utils.data.Dataset. By subclassing this base class, we can load data quickly.
The source code of torch.utils.data.Dataset is as follows:
```python
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
```
- `__len__`: returns the number of samples, so that the global `len()` function works on the dataset.
- `__getitem__`: returns the sample at a given index, e.g. `dataset[i]` returns the i-th sample.
- `__add__`: concatenates two datasets into one (see the sketch after this list).
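For example, here is a minimal sketch of `__add__` in action; ToyDataset is a made-up class used only for illustration:

```python
from torch.utils.data import Dataset

# Hypothetical toy dataset, just to demonstrate __add__
class ToyDataset(Dataset):
    def __init__(self, items):
        self.items = items

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)

a = ToyDataset([1, 2, 3])
b = ToyDataset([4, 5])
merged = a + b        # Dataset.__add__ returns ConcatDataset([a, b])
print(len(merged))    # 5
print(merged[3])      # 4 -- indexing continues into the second dataset
```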
2. Data loading example
Data source: http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
Data introduction: SMS Spam Collection is a classic dataset for spam SMS detection, drawn entirely from real SMS content. It contains 4,831 normal messages and 747 spam messages, all saved in a single text file. Each line records the full content of one message and starts with the label ham or spam, marking it as a normal or a spam message respectively.
The data look like this:
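For illustration, the first few lines of the file look roughly as follows (each line is a label, a tab character, then the message text; the exact content comes from the downloaded file):

```
ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham	Ok lar... Joking wif u oni...
spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
ham	U dun say so early hor... U c already then say...
ham	Nah I don't think he goes to usf, he lives around here though
```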
Then instantiate the Dataset (the MyDataset class defined in the next section) and iterate over it to get the data:
```python
d = MyDataset()
for i in range(len(d)):
    print(i, d[i])
```
Taking the first 5 rows, the output looks like this:
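Illustrative output for the sample lines shown above; each item is the (label, content) tuple returned by `__getitem__`:

```
0 ('ham', 'Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...')
1 ('ham', 'Ok lar... Joking wif u oni...')
2 ('spam', "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's")
3 ('ham', 'U dun say so early hor... U c already then say...')
4 ('ham', "Nah I don't think he goes to usf, he lives around here though")
```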
3. Data loading code
```python
import torch
from torch.utils.data import Dataset

data_path = 'data/SMSSpamCollection'

# Complete dataset class
class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path, encoding='utf-8').readlines()

    def __getitem__(self, index):
        # Get the line corresponding to the index
        cur_line = self.lines[index].strip()
        # Separate the label from the message content
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label, content

    def __len__(self):
        return len(self.lines)

if __name__ == '__main__':
    my_dataset = MyDataset()
    print(my_dataset[0])
    print(len(my_dataset))
```
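Note that `__init__` reads the whole file into memory with `readlines()`, which keeps `__getitem__` a cheap list lookup. For a file too large to fit in memory, you would instead store line offsets and read lines lazily.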
4. Iterating over the dataset
The method above can read the data, but several things still need to be implemented:
- Batching the data
- Shuffling the data
- Loading the data in parallel with multiple workers.
PyTorch's torch.utils.data.DataLoader provides all of the features listed above.
Example of how to use DataLoader:
```python
from torch.utils.data import DataLoader

dataset = MyDataset()
data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True,
                         num_workers=2, drop_last=True)

# Iterate to obtain each batch
for index, (label, content) in enumerate(data_loader):
    print(index, label, content)
    print("*" * 100)
```
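Note that `label` and `content` are batches here: with string data, the DataLoader's default collate function returns each of them as a list of `batch_size` strings rather than a tensor.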
The parameters mean the following:

- dataset: an instance of the dataset defined in advance
- batch_size: the number of samples in each batch; 128, 256, etc. are common choices
- shuffle: bool, whether to reshuffle the data each time it is iterated over
- num_workers: the number of worker processes used to load the data in parallel
- drop_last: if True, drop the last, incomplete batch when the number of samples is not evenly divisible by batch_size
Note:
(1) len(dataset) = the number of samples in the dataset
(2) len(dataloader) = math.ceil(number of samples / batch_size), i.e. rounded up (with drop_last=True the incomplete final batch is dropped, so it is rounded down instead)
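A minimal sketch checking both relationships, assuming the MyDataset class and data file from above:

```python
import math
from torch.utils.data import DataLoader

my_dataset = MyDataset()

# Without drop_last, the final incomplete batch is kept,
# so the number of batches is rounded up:
loader = DataLoader(dataset=my_dataset, batch_size=10, shuffle=True)
assert len(loader) == math.ceil(len(my_dataset) / 10)

# With drop_last=True the incomplete final batch is discarded,
# so the number of batches is rounded down instead:
loader = DataLoader(dataset=my_dataset, batch_size=10, shuffle=True, drop_last=True)
assert len(loader) == len(my_dataset) // 10
```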