[Machine Learning] Using pyplot to draw handwritten digits from the MNIST dataset

Posted by WLW on Mon, 24 Jan 2022 07:13:15 +0100

MNIST is a dataset of handwritten digits published by Yann LeCun, a leading figure in artificial intelligence. The training set contains 60,000 images with labels, and the test set contains 10,000 images with labels. It is a good dataset for beginners to practice on.

Dataset representation

  1. Labels: the digits are 0-9, 10 classes in total, and each label is the integer 0-9 of the digit shown in the corresponding image;
  2. Images: each image is a 28 × 28 matrix. Each element of the matrix is a grayscale value (0-255), so every image is represented by 28 × 28 = 784 values;
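To make the 28 × 28 versus 784 relationship concrete, here is a minimal sketch on a synthetic array (random values standing in for a real MNIST image, not actual data). The loader later in this post stores each image as one flat row of 784 values, and reshape(28, 28) recovers the matrix; with row-major order, pixel (r, c) sits at flat index r * 28 + c.

```python
import numpy as np

# Synthetic stand-in for one MNIST image: a 28x28 matrix of grayscale values 0-255
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)

# Row-major flattening gives the 784-element vector stored per image
flat = img.reshape(784)
print(flat.shape)  # (784,)

# Pixel (row, col) lives at flat index row * 28 + col
row, col = 5, 17
assert flat[row * 28 + col] == img[row, col]

# Reshaping back recovers the original matrix exactly
restored = flat.reshape(28, 28)
assert np.array_equal(restored, img)
```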

Download dataset

Download address: http://yann.lecun.com/exdb/mnist/
The dataset consists of four gzip-compressed files:

  • train-images-idx3-ubyte.gz: training set images (9912422 bytes)
  • train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
  • t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
  • t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
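The load_mnist function later in this post opens the files without the .gz suffix, so the archives have to be decompressed first (gunzip on the command line also works). A minimal standard-library sketch, assuming all four archives sit in one directory; decompress_mnist is a hypothetical helper name, not part of any library:

```python
import gzip
import os
import shutil

# The four MNIST archive names listed above
MNIST_FILES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]


def decompress_mnist(directory):
    """Hypothetical helper: decompress each MNIST .gz archive next to the original.

    Returns the list of decompressed file paths; missing archives are skipped.
    """
    written = []
    for name in MNIST_FILES:
        src = os.path.join(directory, name)
        if not os.path.exists(src):
            continue
        dst = src[:-3]  # strip the ".gz" suffix
        # Stream the decompressed bytes into the output file
        with gzip.open(src, 'rb') as f_in, open(dst, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        written.append(dst)
    return written
```

Streaming with shutil.copyfileobj avoids holding a whole decompressed file in memory at once.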

Parsing dataset files

LeCun's website documents the file format. The points that need attention are:

  1. The files store raw binary data;
  2. Integers are stored in big-endian byte order (most significant byte first);
  3. Label files (training and test): the first 32-bit integer is the magic number (2049), the second is the number of labels (60000 for the training set, 10000 for the test set), and each subsequent byte is one label value;
  4. Image files (training and test): the first 32-bit integer is the magic number (2051), the second is the number of images (60000 for the training set, 10000 for the test set), the third is the number of rows of each image, and the fourth is the number of columns (28 each). Each subsequent byte is one pixel, stored row by row, where 0 means background (white) and 255 means foreground (black).
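The header layout above can be checked without the real files. A minimal sketch that builds tiny label and image files in memory with struct.pack (synthetic data following the documented layout), parses the headers as big-endian unsigned 32-bit integers ('>I'), and shows that reading the same bytes as little-endian ('<I') produces meaningless values:

```python
import struct

# Synthetic label file: magic 2049, count 3, then one unsigned byte per label
label_bytes = struct.pack('>II', 2049, 3) + bytes([5, 0, 4])

# Parse the 8-byte header as two big-endian unsigned 32-bit integers
lbl_magic, count = struct.unpack('>II', label_bytes[:8])
labels = list(label_bytes[8:8 + count])
print(lbl_magic, count, labels)  # 2049 3 [5, 0, 4]

# Reading the same header as little-endian gives nonsense values
wrong_magic, wrong_count = struct.unpack('<II', label_bytes[:8])
print(wrong_magic, wrong_count)  # 17301504 50331648

# Synthetic image file: magic 2051, 1 image, 28 rows, 28 cols (16-byte header),
# followed by rows * cols pixel bytes per image (all zeros here)
image_bytes = struct.pack('>IIII', 2051, 1, 28, 28) + bytes(28 * 28)
img_magic, num, rows, cols = struct.unpack('>IIII', image_bytes[:16])
pixels = image_bytes[16:16 + num * rows * cols]
print(img_magic, num, rows, cols, len(pixels))  # 2051 1 28 28 784
```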

Draw the digits

Use pyplot to plot one example of each digit 0-9. The steps are:

  1. Load the image and label data;
  2. Create a 2 * 5 grid of subplots with the subplots method;
  3. Flatten the array of axes and draw the digits 0-9 with the imshow method;
  4. Display the figure with plt.show().

Code

  1. Load data (load_data.py)
import os
import struct
import numpy as np


def load_mnist(path, kind='train'):
    """Load MNIST images and labels of the given kind ('train' or 't10k') from path."""
    # Build the paths of the (decompressed) label and image files
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        # Header: two big-endian unsigned 32-bit integers (8 bytes),
        # the magic number (2049) and the number of labels
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError('Invalid magic number %d in %s' % (magic, labels_path))
        # The rest of the file is one unsigned byte per label
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        # Header: four big-endian unsigned 32-bit integers (16 bytes),
        # the magic number (2051), number of images, rows and cols
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError('Invalid magic number %d in %s' % (magic, images_path))
        # The rest is one unsigned byte per pixel; reshape into a num x 784 matrix
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(num, rows * cols)

    return images, labels
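A few cheap checks right after loading make it easier to notice a wrong file or a mismatch between images and labels. A minimal sketch; check_mnist_arrays is a hypothetical helper (not from the original code), and it assumes the arrays have the shapes and dtypes that load_mnist returns:

```python
import numpy as np


def check_mnist_arrays(images, labels, expected_count=None):
    """Hypothetical sanity check for the (images, labels) pair returned by load_mnist.

    Raises ValueError with a descriptive message if something looks off.
    """
    # Each image should be one flat row of 28 * 28 = 784 pixels
    if images.ndim != 2 or images.shape[1] != 28 * 28:
        raise ValueError('expected images of shape (n, 784), got %s' % (images.shape,))
    # There must be exactly one label per image
    if len(images) != len(labels):
        raise ValueError('%d images but %d labels' % (len(images), len(labels)))
    # Optionally compare with the known set size (60000 train, 10000 test)
    if expected_count is not None and len(labels) != expected_count:
        raise ValueError('expected %d samples, got %d' % (expected_count, len(labels)))
    # Both arrays are read as unsigned bytes
    if images.dtype != np.uint8 or labels.dtype != np.uint8:
        raise ValueError('expected uint8 arrays')
    # Labels are digits
    if labels.size and labels.max() > 9:
        raise ValueError('labels must be in the range 0-9')
    return True
```

For example, check_mnist_arrays(images, labels, expected_count=60000) after loading the training set, or expected_count=10000 for kind='t10k'.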

  2. Draw the digits (main.py)
import load_data
import matplotlib.pyplot as plt

# Load the training set; the path is the directory containing the decompressed files
images, labels = load_data.load_mnist('/Users/wowo/Documents/0-Tensorflow')

# Create a figure with a 2 x 5 grid of subplots that share their x and y axes
fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True,
)

# Flatten the 2 x 5 array of axes into a 1-D array of 10 axes
ax = ax.flatten()
for i in range(10):
    # Take the first image labelled i and reshape it into a 28 x 28 matrix
    img = images[labels == i][0].reshape(28, 28)
    # Draw the digit; the 'Greys' colormap shows low values as white, high values as black
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

# Hide the tick marks; because the axes are shared, this applies to every subplot
ax[0].set_xticks([])
ax[0].set_yticks([])
# Reduce the padding between subplots
plt.tight_layout()
# Display the figure
plt.show()
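The expression images[labels == i][0] relies on NumPy boolean-mask indexing: labels == i yields an array of True/False values, and indexing with it keeps only the matching rows. A minimal sketch on synthetic data (ten made-up labels, and "images" that are just their own row index) shows what it selects, plus np.argmax as an alternative that returns the index directly instead of copying every matching image first:

```python
import numpy as np

# Synthetic labels and "images" (each image is just its row index, for illustration)
labels = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=np.uint8)
images = np.arange(len(labels)).reshape(-1, 1)

# labels == 1 is a boolean mask; indexing with it keeps only rows where it is True
mask = labels == 1
print(mask.nonzero()[0])  # [3 6 8]
print(images[mask][0])    # [3]  -> the first image labelled 1

# argmax on a boolean array returns the position of the first True
first = int(np.argmax(labels == 1))
print(first)              # 3
```

One caveat on the design choice: np.argmax returns 0 when no element is True, so it is only safe when every digit is known to occur, which is the case for MNIST. For ten digits the extra copy made by the mask is cheap, so either form works.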

References

  1. http://yann.lecun.com/exdb/mnist/
  2. https://www.cnblogs.com/xianhan/p/9145966.html
