Analysis of GitHub Open Source Project Hyperspectral-Classification

Posted by system on Wed, 11 Sep 2019 15:02:57 +0200

GitHub link: Hyperspectral-Classification Pytorch.

Project brief introduction

The project's author is from Xidian University; it is a hyperspectral image classification program based on PyTorch. The project is compatible with Python 2.7 and Python 3.5+, built on the PyTorch deep learning and GPU computing framework, and uses the Visdom visualization server.

The predefined public datasets are:

  • University of Pavia
  • Pavia Center
  • Kennedy Space Center
  • Indian Pines
  • Botswana

Users can also add custom datasets, such as the "Hyperspectral Data Set of Data Fusion Competition 2018" DFC2018_HSI. The developer should add a new entry for the CUSTOM_DATASETS_CONFIG variable and define a specific data loader for its use case.

The tool implements several SVM variants from the scikit-learn library and many state-of-the-art deep networks implemented in PyTorch:

  • SVM (linear, RBF and multi-kernel, with grid search)
  • SGD (fast optimization of a linear SVM using stochastic gradient descent)
  • Baseline neural network (4 fully connected layers with dropout)
  • 1D CNN (Deep Convolutional Neural Networks for Hyperspectral Image Classification, Hu et al., Journal of Sensors 2015)
  • Semi-supervised 1D CNN (Autoencodeurs pour la visualisation d'images hyperspectrales [Autoencoders for the visualization of hyperspectral images], Boulch et al., GRETSI 2017)
  • 2D CNN (Hyperspectral CNN for Image Classification & Band Selection, with Application to Face Recognition, Sharma et al., technical report 2018)
  • Semi-supervised 2D CNN (A Semi-supervised Convolutional Neural Network for Hyperspectral Image Classification, Liu et al., Remote Sensing Letters 2017)
  • 3D CNN (3-D Deep Learning Approach for Remote Sensing Image Classification, Hamida et al., TGRS 2018)
  • 3D FCN (Contextual Deep CNN Based Hyperspectral Classification, Lee and Kwon, IGARSS 2016)
  • 3D CNN (Deep Feature Extraction and Classification of Hyperspectral Images Based on Convolutional Neural Networks, Chen et al., TGRS 2016)
  • 3D CNN (Spectral-Spatial Classification of Hyperspectral Imagery with 3D Convolutional Neural Network, Li et al., Remote Sensing 2017)
  • 3D CNN (HSI-CNN: A Novel Convolution Neural Network for Hyperspectral Image, Luo et al., ICPR 2018)
  • Multi-scale 3D CNN (Multi-scale 3D Deep Convolutional Neural Network for Hyperspectral Image Classification, He et al., ICIP 2017)

Users can also add custom deep networks by modifying the models.py file. This means creating a new class for the custom network and extending the get_model function, as sketched below.
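A minimal sketch of what that involves (the class name, layer sizes, and branch contents here are hypothetical, only illustrating the pattern the existing models follow):

import torch.nn as nn

class MyNet(nn.Module):
    """Hypothetical custom network, following the conventions of models.py."""
    def __init__(self, input_channels, n_classes):
        super(MyNet, self).__init__()
        self.fc = nn.Linear(input_channels, n_classes)

    def forward(self, x):
        return self.fc(x)

# ...and, inside get_model(), a new branch mirroring the existing ones
# (optim is torch.optim; kwargs, n_bands, n_classes as in get_model below):
# elif name == 'mynet':
#     model = MyNet(n_bands, n_classes)
#     lr = kwargs.setdefault('learning_rate', 0.01)
#     optimizer = optim.SGD(model.parameters(), lr=lr)
#     criterion = nn.CrossEntropyLoss(weight=kwargs['weights'])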

Analysis of the Modules and Functions of the Project

utils.py

get_device(ordinal)

Functions:

Determine whether the computation device is the CPU or a GPU according to the input parameter.

Input and output:
Input:
  • ordinal: an int indicating which GPU to use (a negative value selects the CPU)
Output:
  • device: a torch.device object representing where the computation runs (CPU or GPU)
Code:
def get_device(ordinal):
    # Use GPU ?
    if ordinal < 0:
        print("Computation on CPU")
        device = torch.device('cpu')
    elif torch.cuda.is_available():
        print("Computation on CUDA GPU device {}".format(ordinal))
        device = torch.device('cuda:{}'.format(ordinal))
    else:
        print("/!\\ CUDA was requested but is not available! Computation will go on CPU. /!\\")
        device = torch.device('cpu')
    return device
Analysis:

In fact, it is a simple branch structure:

  • ordinal < 0: CPU
  • ordinal >= 0 and torch.cuda.is_available() == True: GPU
  • ordinal >= 0 and torch.cuda.is_available() == False: CPU (with a warning)
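A quick usage sketch (assuming the project's utils.py is importable):

from utils import get_device

device = get_device(-1)   # prints "Computation on CPU" and returns torch.device('cpu')
device = get_device(0)    # uses cuda:0 if CUDA is available, otherwise falls back to CPU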

open_file(dataset)

Functions:

Open the file for the specified data set.

Input and output:
Input:
  • dataset: the complete path of the dataset file, such as C:\Datasets\OwnData\OwnData.mat.
Output

(taking .mat as an example, because most of the files read are .mat files):

  • A dictionary with variable names as keys and the data as values.
Code:
def open_file(dataset):
    _, ext = os.path.splitext(dataset)
    ext = ext.lower()
    if ext == '.mat':
        # Load Matlab array
        return io.loadmat(dataset)
    elif ext == '.tif' or ext == '.tiff':
        # Load TIFF file
        return misc.imread(dataset)
    elif ext == '.hdr':
        img = spectral.open_image(dataset)
        return img.load()
    else:
        raise ValueError("Unknown file format: {}".format(ext))
Analysis:

The key point is the os.path.splitext(path) call in _, ext = os.path.splitext(dataset).
This function splits the input path into file name + extension and returns the two parts in turn; _, ext means that only the extension is kept and stored in the variable ext. A different way of opening the file is then chosen according to the extension.

Note that when you open a .mat file, the return value is a dictionary with the variable names as keys and the data as values. To extract the data, we need a dictionary access on the corresponding key, such as img = open_file(folder + 'OwnData.mat')['Data'].
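A small usage sketch (the file path and the variable name 'Data' follow the example above and are illustrative; assumes the project's utils.py is importable):

from utils import open_file

mat = open_file('C:/Datasets/OwnData/OwnData.mat')  # returns a dict for .mat files
img = mat['Data']                                   # access the array by its variable name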

convert_to_color_()

Functions:

Convert the label array to RGB color-coded image.

Input and output:
Input:
  • arr_2d: 2D int array of labels
  • palette: dict mapping each label number to an RGB tuple (three values)
Output:
  • arr_3d: 2D image of color-encoded labels in int RGB format (three channels)
Code:
def convert_to_color_(arr_2d, palette=None):
    """Convert an array of labels to RGB color-encoded image.

    Args:
        arr_2d: int 2D array of labels
        palette: dict of colors used (label number -> RGB tuple)    # Which label corresponds to what color (RGB three values)

    Returns:
        arr_3d: int 2D images of color-encoded labels in RGB format     # RGB three-channel image

    """
    arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)    # Determine the dimension and encoding method, row as arr_2d column, encoding method is uint8
    # Abnormal error reporting
    if palette is None:
        raise Exception("Unknown color palette")

    for c, i in palette.items():
        m = arr_2d == c
        arr_3d[m] = i

    return arr_3d
Analysis:

(Provisional)

convert_from_color_()

Functions:

Convert an RGB color-encoded image back to a 2D array of grayscale labels.

Input and output:
Input:
  • arr_3d: int 2D image of color-coded labels on 3 channels
  • palette: dict of colors used (RGB tuple -> label number)
Output:
  • arr_2d: int 2D array of labels
Code:
def convert_from_color_(arr_3d, palette=None):
    """Convert an RGB-encoded image to grayscale labels.

    Args:
        arr_3d: int 2D image of color-coded labels on 3 channels
        palette: dict of colors used (RGB tuple -> label number)

    Returns:
        arr_2d: int 2D array of labels

    """
    if palette is None:
        raise Exception("Unknown color palette")

    arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)

    for c, i in palette.items():
        m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
        arr_2d[m] = i

    return arr_2d
Analysis:

(Provisional)
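Although the detailed analysis is left open, a quick round-trip demo of the two conversion functions may help (the two-label palette is made up for illustration; assumes the project's utils.py is importable):

import numpy as np
from utils import convert_to_color_, convert_from_color_

palette = {0: (0, 0, 0), 1: (255, 0, 0)}             # label -> RGB tuple
invert_palette = {v: k for k, v in palette.items()}  # RGB tuple -> label

arr_2d = np.array([[0, 1], [1, 0]])
arr_3d = convert_to_color_(arr_2d, palette=palette)
print(arr_3d.shape)                                  # (2, 2, 3)
arr_back = convert_from_color_(arr_3d, palette=invert_palette)
print(np.array_equal(arr_back, arr_2d))              # True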

display_predictions()

Functions:

The visdom visualization service is used to visualize the prediction results.

Input and output:
Input:
  • pred: the prediction results, two-dimensional
  • vis: the visdom visualization service
  • gt: ground truth
  • caption: the chart title
Output:
  • a chart displayed on the visdom server web page
Code:
def display_predictions(pred, vis, gt=None, caption=""):        # caption subtitle
    if gt is None:
        vis.images([np.transpose(pred, (2, 0, 1))],
                    opts={'caption': caption})
    else:
        vis.images([np.transpose(pred, (2, 0, 1)),
                    np.transpose(gt, (2, 0, 1))],
                    nrow=2,
                    opts={'caption': caption})
Analysis:

The whole function is a simple branching structure, which can be divided into two cases: gt is None and gt is not None.

When gt is None:
The vis.images() function draws a list of images. It requires as input a B x C x H x W (C -> channel; H -> height; W -> width) tensor or a list of images of the same size. It arranges the images in a grid of size (B / nrow, nrow).

The adjustable parameters of vis.images() are as follows:

  • nrow: number of images per row
  • padding: padding around the images; all four sides are padded equally
  • opts.jpgquality: JPG quality (number 0-100; default = 100)
  • opts.caption: the title of the image

So vis.images([np.transpose(pred, (2, 0, 1))], opts={'caption': caption}) has two main parts.

  • [np.transpose(pred, (2, 0, 1))] is the image to be visualized.
  • opts={'caption': caption} is an optional setting.

np.transpose() swaps array dimensions (see another blog for details: Dimensional Exchange Function - a.transpose(m,n,r)). Because the original image dimension ordering defaults to H * W * C, while vis.images() requires C * H * W, the dimension ordering is changed from (0, 1, 2) to (2, 0, 1) via the np.transpose() function.

opts={'caption': caption} adds a title to the chart.

When gt is not None:
For pred and gt alike, the dimensions are exchanged through the np.transpose() function, and the parameter nrow needs to be set to 2.
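A quick shape check of the transpose (the dimensions are borrowed from PaviaU purely for illustration):

import numpy as np

pred = np.zeros((610, 340, 3))              # H x W x C, as produced upstream
print(np.transpose(pred, (2, 0, 1)).shape)  # (3, 610, 340), i.e. C x H x W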

display_dataset()

Functions:

Select three bands as the RGB channels and display the resulting RGB composite image.

Input and output:
Input:
  • img: 3D hyperspectral image
  • gt: 2D array labels
  • bands: tuple of RGB bands to select
  • labels: list of label class names
  • palette: dict of colors
  • display (optional): type of display, if any

However, only img, bands, and vis are actually used; gt, labels, and palette are not.

Output:
  • visdom Server Web Site Display Chart
Code:
def display_dataset(img, gt, bands, labels, palette, vis):
    """Display the specified dataset.

    Args:
        img: 3D hyperspectral image
        gt: 2D array labels
        bands: tuple of RGB bands to select
        labels: list of label class names
        palette: dict of colors
        display (optional): type of display, if any

    """
    print("Image has dimensions {}x{} and {} channels".format(*img.shape))
    rgb = spectral.get_rgb(img, bands)          # Extract RGB data from SpyFile objects or numpy arrays for display.
    rgb /= np.max(rgb)                          # Normalize to [0, 1] by the maximum
    rgb = np.asarray(255 * rgb, dtype='uint8')  # Scale to [0, 255] and convert to a uint8 ndarray

    # Display the RGB composite image to display RGB composite images
    caption = "RGB (bands {}, {}, {})".format(*bands)       # * To disassemble variables
    # send to visdom server
    vis.images([np.transpose(rgb, (2, 0, 1))],
                opts={'caption': caption})
Analysis:

First, rgb = spectral.get_rgb(img, bands) extracts the bands specified in img as the RGB bands. Then rgb /= np.max(rgb) divides by the maximum, scaling the values to [0, 1]. Finally rgb = np.asarray(255 * rgb, dtype='uint8') scales the values to [0, 255] and sets dtype='uint8', i.e. an RGB image encoded as uint8.

Then comes the visdom server part. First the title is set: caption = "RGB (bands {}, {}, {})".format(*bands), where format(*bands) unpacks bands (a tuple, I guess) with * and formats each element separately. vis.images() is then called to visualize rgb; its parameters were parsed in the display_predictions() section above.
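A small demo of the scaling steps and of the * unpacking in format() (the band indices are made up):

import numpy as np

rgb = np.random.rand(4, 4, 3)                # toy float image
rgb /= np.max(rgb)                           # scale to [0, 1]
rgb = np.asarray(255 * rgb, dtype='uint8')   # scale to [0, 255], encode as uint8
print(rgb.dtype)                             # uint8

bands = (55, 41, 12)
print("RGB (bands {}, {}, {})".format(*bands))  # * unpacks the tuple
# RGB (bands 55, 41, 12)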

explore_spectrums()

(Provisional)

plot_spectrums()

(Provisional)

build_dataset()

Functions:

Create a list of training samples based on images and masks.

Input and output:
Input:
  • mat: 3D hyperspectral matrix to extract the spectrums from # for extracting the hyperspectral matrix of the spectrum
  • gt: 2D ground truth
  • ignored_labels (optional): list of classes to ignore, e.g. 0 to remove
Output:
  • samples, labels: arrays of the training samples and their corresponding labels.
Code:
def build_dataset(mat, gt, ignored_labels=None):
    """Create a list of training samples based on an image and a mask.

    Args:
        mat: 3D hyperspectral matrix to extract the spectrums from      # Hyperspectral Matrix for Spectrum Extraction
        gt: 2D ground truth
        ignored_labels (optional): list of classes to ignore, e.g. 0 to remove
        unlabeled pixels
        return_indices (optional): bool set to True to return the indices of
        the chosen samples

    """
    samples = []
    labels = []
    # Check that image and ground truth have the same 2D dimensions
    assert mat.shape[:2] == gt.shape[:2]    # Check whether the dimensions match, such as PaviaU's mat and gt are (610, 340)

    for label in np.unique(gt):
        if label in ignored_labels:
            continue
        else:
            indices = np.nonzero(gt == label)       # gt == label gives a boolean array (1 where the element equals label, 0 otherwise); np.nonzero extracts the indices of all non-zero elements, i.e. all pixels of this class
            samples += list(mat[indices])
            labels += len(indices[0]) * [label]
    return np.asarray(samples), np.asarray(labels)
Analysis:

First, check whether the array dimensions match, implemented with the assert keyword, where assert condition is equivalent to if not condition: raise AssertionError().
assert mat.shape[:2] == gt.shape[:2] checks whether the first two dimensions of mat and gt are the same.
mat.shape[:2] extracts the first two dimensions of the mat array.

The np.unique() function returns the sorted unique values of gt as an ndarray.
The return value of np.unique() is then traversed, and for each traversed label, np.nonzero(gt == label) obtains the indices of the elements of gt equal to label, returned as indices.

Then the corresponding index elements in mat are extended to samples by samples += list(mat[indices]).

The following is illustrated by a simple example:

import numpy as np

mat = np.array([[0,0,0,0,0],[0,100,200,300,0],[0,200,300,200,0],[0,300,200,100,0],[0,0,0,0,0]])
gt = np.array([[0,0,0,0,0],[0,1,2,3,0],[0,2,3,2,0],[0,3,2,1,0],[0,0,0,0,0]])
ignored_labels = [0]
samples = []
labels = []

# Check that image and ground truth have the same 2D dimensions
assert mat.shape[:2] == gt.shape[:2]    # Check whether the dimensions match, such as PaviaU's mat and gt are (610, 340)

for label in np.unique(gt):
    if label in ignored_labels:
        continue
    else:
        indices = np.nonzero(gt == label)       # indices of all pixels whose label equals this class
        samples += list(mat[indices])
        labels += len(indices[0]) * [label]
print(mat)
# [[  0   0   0   0   0]
#  [  0 100 200 300   0]
#  [  0 200 300 200   0]
#  [  0 300 200 100   0]
#  [  0   0   0   0   0]]
print(gt)
# [[0 0 0 0 0]
#  [0 1 2 3 0]
#  [0 2 3 2 0]
#  [0 3 2 1 0]
#  [0 0 0 0 0]]
print(samples)
# [100, 100, 200, 200, 200, 200, 300, 300, 300]
print(labels)
# [1, 1, 2, 2, 2, 2, 3, 3, 3]

get_random_pos()

Functions:

Return the corners of a random window in the input image.

Input and output:
Input:
  • img: 2D (or more) image, e.g. RGB or grayscale image
  • window_shape: (width, height) tuple of the window
Output:
  • xmin, xmax, ymin, ymax: tuple of the corners of the window

The four values give two opposite corners of the window, i.e. the points (x1, y1) and (x2, y2).

Code:
def get_random_pos(img, window_shape):
    """ Return the corners of a random window in the input image

    Args:
        img: 2D (or more) image, e.g. RGB or grayscale image
        window_shape: (width, height) tuple of the window

    Returns:
        xmin, xmax, ymin, ymax: tuple of the corners of the window

    """
    # Idea: first pick a corner at random, then add w and h to it to form the window.
    w, h = window_shape
    W, H = img.shape[:2]        # Getting the first two dimensions of img
    x1 = random.randint(0, W - w - 1)       # random integer in the closed interval [0, W - w - 1]
    x2 = x1 + w
    y1 = random.randint(0, H - h - 1)
    y2 = y1 + h
    return x1, x2, y1, y2
Analysis:

The width and height are extracted from the input tuple by w, h = window_shape, and the first two dimensions W and H of img are obtained by W, H = img.shape[:2].

Then the position of one corner is generated. The function used is random.randint(a, b), which generates a random integer in the closed interval [a, b]. The random range in the W dimension is [0, W - w - 1], and likewise for the H dimension.

Adding w and h to the corner (x1, y1) gives the opposite corner (x2, y2).

Finally, the four values x1, x2, y1 and y2 (i.e. xmin, xmax, ymin, ymax) are returned as the function's return values.
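A usage sketch (image dimensions roughly PaviaU-like, for illustration; assumes the project's utils.py is importable):

import numpy as np
from utils import get_random_pos

img = np.zeros((610, 340, 103))
x1, x2, y1, y2 = get_random_pos(img, (20, 20))
patch = img[x1:x2, y1:y2]     # the randomly placed 20 x 20 window
print(patch.shape[:2])        # (20, 20)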

sliding_window()

Functions:

Create a sliding-window generator over the input image.

Input and output:
Input:
  • image: 2D+ image to slide the window on, e.g. RGB or hyperspectral
  • step: int stride of the sliding window
  • window_size: int tuple, width and height of the window
  • with_data (optional): bool set to True to return both the data and the corner indices
Output:

When with_data is True, it yields image[x:x + w, y:y + h], x, y, w, h, i.e. the window data together with the window position parameters. When with_data is False, it yields only x, y, w, h, i.e. just the position parameters of the window.

Code:
def sliding_window(image, step=10, window_size=(20, 20), with_data=True):
    """Sliding window generator over an input image.        # Sliding Window Generator on Input Image

    Args:
        image: 2D+ image to slide the window on, e.g. RGB or hyperspectral
        step: int stride of the sliding window
        window_size: int tuple, width and height of the window
        with_data (optional): bool set to True to return both the data and the
        corner indices
    Yields:
        ([data], x, y, w, h) where x and y are the top-left corner of the
        window, (w,h) the window size

    """
    # slide a window across the image
    w, h = window_size
    W, H = image.shape[:2]
    offset_w = (W - w) % step
    offset_h = (H - h) % step
    for x in range(0, W - w + offset_w, step):
        if x + w > W:
            x = W - w
        for y in range(0, H - h + offset_h, step):
            if y + h > H:
                y = H - h

            if with_data:
                yield image[x:x + w, y:y + h], x, y, w, h
            else:
                yield x, y, w, h
Analysis:

The window size and image size are obtained by w, h = window_size and W, H = image.shape[:2]. Then offset_w and offset_h are defined so that the window slides over an appropriate range.

The yield keyword creates a generator.

A function containing yield is a generator rather than an ordinary function. A generator supports the next() function; next is equivalent to producing "the next" value. Each call to next() resumes where the previous yield stopped (rather than restarting the function from the beginning), runs until the next yield is reached, and returns the value generated there.

For a detailed explanation of yield, see The Use of yield in python: the simplest and clearest explanation.

For this function, every call to next() slides the window one step further from the previous position.
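A small demo of the generator behaviour (assumes the project's utils.py is importable):

import numpy as np
from utils import sliding_window

img = np.zeros((50, 50))
gen = sliding_window(img, step=10, window_size=(20, 20), with_data=False)
print(next(gen))   # (0, 0, 20, 20)
print(next(gen))   # (0, 10, 20, 20) -- the window slid one step along the second axis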

count_sliding_window()

Functions:

Count the number of windows in an image.

Input and output:
Input:
  • image: 2D+ image to slide the window on, e.g. RGB or hyperspectral, ...
  • step: int stride of the sliding window
  • window_size: int tuple, width and height of the window
Output:
  • int number of windows
Code:

def count_sliding_window(top, step=10, window_size=(20, 20)):
    """ Count the number of windows in an image.        # Calculate the number of windows in an image

    Args:
        image: 2D+ image to slide the window on, e.g. RGB or hyperspectral, ...
        step: int stride of the sliding window
        window_size: int tuple, width and height of the window
    Returns:
        int number of windows
    """
    sw = sliding_window(top, step, window_size, with_data=False)
    return sum(1 for _ in sw)
Analysis:

The window generator sw is obtained by calling the sliding_window() function with with_data=False, and then sum(1 for _ in sw) adds 1 for every window yielded while traversing sw, giving the total number of windows.
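Continuing the toy example from sliding_window() above: with a 50 x 50 image, a 20 x 20 window and step 10, there are 3 window positions per axis, hence 9 windows (assumes the project's utils.py is importable):

import numpy as np
from utils import count_sliding_window

img = np.zeros((50, 50))
print(count_sliding_window(img, step=10, window_size=(20, 20)))   # 9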

grouper()

Functions:

Browse an iterable in groups of n elements.

Input and output:
Input:
  • n: int, size of the groups
  • iterable: the iterable to Browse
Output:
  • chunk of n elements from the iterable
Code:
def grouper(n, iterable):       # grouping
    """ Browse an iterable by grouping n elements by n elements.        # Browse iterable by grouping n elements with n elements

    Args:
        n: int, size of the groups
        iterable: the iterable to Browse    iteration
    Yields:
        chunk of n elements from the iterable       Iterable n Element block

    """
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk
Analysis:

(Provisional)
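A quick demo of the chunking behaviour (assumes the project's utils.py is importable):

from utils import grouper

print(list(grouper(3, range(8))))
# [(0, 1, 2), (3, 4, 5), (6, 7)] -- the last chunk may be shorter than n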

metrics()

Functions:

Compute and print metrics, including accuracy, confusion matrix, F1 scores and the kappa coefficient.

Input and output:
Input:
  • prediction: list of predicted labels
  • target: list of target labels
  • ignored_labels (optional): list of labels to ignore, e.g. 0 for undef
  • n_classes (optional): number of classes, max(target) by default
Output:
  • accuracy
  • F1 score by class
  • confusion matrix
Code:
def metrics(prediction, target, ignored_labels=[], n_classes=None):         # Compute the evaluation metrics
    """Compute and print metrics (accuracy, confusion matrix and F1 scores).

    Args:
        prediction: list of predicted labels
        target: list of target labels
        ignored_labels (optional): list of labels to ignore, e.g. 0 for undef
        n_classes (optional): number of classes, max(target) by default
    Returns:
        accuracy, F1 score by class, confusion matrix
    """
    ignored_mask = np.zeros(target.shape[:2], dtype=np.bool)
    for l in ignored_labels:
        ignored_mask[target == l] = True
    ignored_mask = ~ignored_mask
    target = target[ignored_mask]
    prediction = prediction[ignored_mask]

    results = {}

    n_classes = np.max(target) + 1 if n_classes is None else n_classes

    cm = confusion_matrix(
        target,
        prediction,
        labels=range(n_classes))

    results["Confusion matrix"] = cm

    # Compute global accuracy
    total = np.sum(cm)
    accuracy = sum([cm[x][x] for x in range(len(cm))])
    accuracy *= 100 / float(total)

    results["Accuracy"] = accuracy

    # Compute F1 score
    F1scores = np.zeros(len(cm))
    for i in range(len(cm)):
        try:
            F1 = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i]))
        except ZeroDivisionError:
            F1 = 0.
        F1scores[i] = F1

    results["F1 scores"] = F1scores

    # Compute kappa coefficient
    pa = np.trace(cm) / float(total)
    pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / \
        float(total * total)
    kappa = (pa - pe) / (1 - pe)
    results["Kappa"] = kappa

    return results
Analysis:

(A note here: I am not clear about how the prediction and target data are organized, mainly because the code np.max(target) exists.)

The ignored_mask part is a mask created so that the part labeled ignored_labels is not considered.

results = {} defines the output as a dictionary, which is the return value of the function. At first, results is just an empty dictionary; key-value pairs are then added to it gradually.

Computing the confusion matrix

cm = confusion_matrix(target, prediction, labels=range(n_classes)) calls the confusion_matrix() function (for details see: python sklearn calculates confusion_matrix() function of confusion matrix). Simply put, this code computes the confusion matrix (an array) from the three parameters target, prediction and labels, and assigns the result to cm.

Then results["Confusion matrix"] = cm adds the key-value pair "Confusion matrix": cm to the dictionary results.

Calculating classification accuracy

The total number of samples is computed by total = np.sum(cm): summing all the elements of the confusion matrix cm gives the total sample count.

Then the number of correctly classified samples is computed: the sum of the diagonal elements of the confusion matrix cm is the number of correctly classified samples.

The ratio of the two, times 100, is the final accuracy: accuracy *= 100 / float(total).

import numpy as np
cm = np.array([[1,1,1],[2,2,2],[3,3,3]])
print(cm)
# [[1 1 1]
#  [2 2 2]
#  [3 3 3]]
print(len(cm))
# 3
print(range(len(cm)))
# range(0, 3)
print('\n')
for i in range(len(cm)):
    print(i)
# 0
# 1
# 2

Then results["Accuracy"] = accuracy adds the key-value pair "Accuracy": accuracy to the dictionary results.

Calculate F1 score

The F1 score is a statistical measure of a classification model's accuracy. It takes into account both the precision and the recall of the classification model; the F1 score can be regarded as the harmonic mean of the model's precision and recall.

This part of the code just applies a fixed formula directly.

Calculating kappa coefficient

It is also a standard metric, again just a fixed formula; a small worked example follows below. (Provisional)
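A small worked example on a toy 2 x 2 confusion matrix (the numbers are made up for illustration):

import numpy as np

cm = np.array([[50, 2],
               [5, 43]])
total = np.sum(cm)                             # 100 samples

accuracy = 100 * np.trace(cm) / float(total)
print(accuracy)                                # 93.0

# Per-class F1, as in the code above
for i in range(len(cm)):
    f1 = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i]))
    print("F1 of class {}: {:.3f}".format(i, f1))   # 0.935, then 0.925

# Kappa coefficient
pa = np.trace(cm) / float(total)                                          # observed agreement
pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total ** 2)  # chance agreement
kappa = (pa - pe) / (1 - pe)
print("Kappa: {:.3f}".format(kappa))           # 0.859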

Return result

result of dictionary type is used as the return value of the function.

show_results()

Functions:

The results are output in text form on the visdom interface.

Input and output:
Input:
  • results: a dictionary containing four keys: Confusion matrix, Accuracy, F1 scores and Kappa.
  • vis: the visdom visualization service.
  • label_values: None by default.
  • agregated: False by default.
Output:
  • text: text output on the visdom page.
Code:
def show_results(results, vis, label_values=None, agregated=False):         # Visualization module
    text = ""
    # what the agregated branch is for is unclear to me
    if agregated:
        accuracies = [r["Accuracy"] for r in results]
        kappas = [r["Kappa"] for r in results]
        F1_scores = [r["F1 scores"] for r in results]

        F1_scores_mean = np.mean(F1_scores, axis=0)
        F1_scores_std = np.std(F1_scores, axis=0)
        cm = np.mean([r["Confusion matrix"] for r in results], axis=0)
        text += "Agregated results :\n"
    else:
        cm = results["Confusion matrix"]
        accuracy = results["Accuracy"]
        F1scores = results["F1 scores"]
        kappa = results["Kappa"]

    vis.heatmap(cm, opts={'title': "Confusion matrix", 
                          'marginbottom': 150,
                          'marginleft': 150,
                          'width': 500,
                          'height': 500,
                          'rownames': label_values, 'columnnames': label_values})
    text += "Confusion matrix :\n"
    text += str(cm)
    text += "---\n"

    if agregated:
        text += ("Accuracy: {:.03f} +- {:.03f}\n".format(np.mean(accuracies),
                                                         np.std(accuracies)))
    else:
        text += "Accuracy : {:.03f}%\n".format(accuracy)
    text += "---\n"

    text += "F1 scores :\n"
    if agregated:
        for label, score, std in zip(label_values, F1_scores_mean,
                                     F1_scores_std):
            text += "\t{}: {:.03f} +- {:.03f}\n".format(label, score, std)
    else:
        for label, score in zip(label_values, F1scores):
            text += "\t{}: {:.03f}\n".format(label, score)
    text += "---\n"

    if agregated:
        text += ("Kappa: {:.03f} +- {:.03f}\n".format(np.mean(kappas),
                                                      np.std(kappas)))
    else:
        text += "Kappa: {:.03f}\n".format(kappa)

    vis.text(text.replace('\n', '<br/>'))
    print(text)
Analysis:

The whole idea is to traverse the key-value pairs of results, extend the text with text += ..., and finally print the result on visdom.

Because we don't know what agregated does and it defaults to False, we only consider the case agregated = False.

First, the value of the corresponding key is obtained by accessing the key of the dictionary result.

Then a heatmap is drawn by the vis.heatmap() function. It takes an N x M tensor X that specifies the value at each position of the heatmap; here that is cm. The options set the title ('title': "Confusion matrix"), the size ('marginbottom': 150, 'marginleft': 150, 'width': 500, 'height': 500), and the row and column labels ('rownames': label_values, 'columnnames': label_values).

For cm, it is first converted to a string by str(cm), and then appended to text with +=.

For Accuracy, it is likewise appended to text: text += "Accuracy : {:.03f}%\n".format(accuracy).

For F1scores, the corresponding label_values are needed. zip(label_values, F1scores) combines the corresponding elements of the iterables label_values and F1scores into tuples, returned as an object (for details on zip(): Python zip() function). The object returned by zip() is then traversed with a for loop while the text is extended: for label, score in zip(label_values, F1scores): text += "\t{}: {:.03f}\n".format(label, score).

For Kappa, it is simply appended to text: text += "Kappa: {:.03f}\n".format(kappa).

The vis.text() function prints text in a box and can be used to embed arbitrary HTML. It takes a text string as input; opts is currently not specifically supported. text.replace('\n', '<br/>') replaces the newlines with HTML line breaks: vis.text(text.replace('\n', '<br/>')).
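A small demo of the zip() loop and the newline replacement (class names and scores are made up):

label_values = ['Undefined', 'Asphalt', 'Meadows']
F1scores = [0.0, 0.912, 0.884]

text = ""
for label, score in zip(label_values, F1scores):
    text += "\t{}: {:.03f}\n".format(label, score)
print(text.replace('\n', '<br/>'))
# 	Undefined: 0.000<br/>	Asphalt: 0.912<br/>	Meadows: 0.884<br/>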

sample_gt()

Functions:

Extract a fixed percentage of samples from the label array gt.

It should be emphasized that the samples split into training and test sets do not include those labeled as ignored_labels; only valid samples are split.

Input and output:
Input:
  • gt: a 2D array of int labels
  • percentage: [0, 1] float
Output:
  • train_gt: 2D arrays of int labels
  • test_gt: 2D arrays of int labels
Code:
def sample_gt(gt, train_size, mode='random'):
    """Extract a fixed percentage of samples from an array of labels.   Extract a fixed percentage sample from the tag array.

    Args:
        gt: a 2D array of int labels
        percentage: [0, 1] float
    Returns:
        train_gt, test_gt: 2D arrays of int labels

    """
    indices = np.nonzero(gt)
    X = list(zip(*indices)) # x,y features
    y = gt[indices].ravel() # classes
    train_gt = np.zeros_like(gt)
    test_gt = np.zeros_like(gt)
    if train_size > 1:
       train_size = int(train_size)
    
    if mode == 'random':
       train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y)
       train_indices = [list(t) for t in zip(*train_indices)]
       test_indices = [list(t) for t in zip(*test_indices)]
       train_gt[train_indices] = gt[train_indices]
       test_gt[test_indices] = gt[test_indices]
    elif mode == 'fixed':
       print("Sampling {} with train size = {}".format(mode, train_size))
       train_indices, test_indices = [], []
       for c in np.unique(gt):
           if c == 0:
              continue
           indices = np.nonzero(gt == c)
           X = list(zip(*indices)) # x,y features

           train, test = sklearn.model_selection.train_test_split(X, train_size=train_size)
           train_indices += train
           test_indices += test
       train_indices = [list(t) for t in zip(*train_indices)]
       test_indices = [list(t) for t in zip(*test_indices)]
       train_gt[train_indices] = gt[train_indices]
       test_gt[test_indices] = gt[test_indices]

    elif mode == 'disjoint':
        train_gt = np.copy(gt)
        test_gt = np.copy(gt)
        for c in np.unique(gt):
            mask = gt == c
            for x in range(gt.shape[0]):
                first_half_count = np.count_nonzero(mask[:x, :])
                second_half_count = np.count_nonzero(mask[x:, :])
                try:
                    ratio = first_half_count / second_half_count
                    if ratio > 0.9 * train_size and ratio < 1.1 * train_size:
                        break
                except ZeroDivisionError:
                    continue
            mask[:x, :] = 0
            train_gt[mask] = 0

        test_gt[train_gt > 0] = 0
    else:
        raise ValueError("{} sampling is not implemented yet.".format(mode))
    return train_gt, test_gt
Analysis:

indices = np.nonzero(gt) retrieves the indices of the non-zero elements in gt. The return value is a tuple of two arrays, representing the indices in the x and y directions respectively.

In X = list(zip(*indices)), the return value of np.nonzero(gt) is first unpacked into its two arrays by *, then the corresponding elements are combined into tuples by the zip() function and returned as an object, which is finally converted to a list by list(). The line y = gt[indices].ravel() first obtains the elements corresponding to the indices via gt[indices], and then flattens them into a one-dimensional array with ravel().

The functions of this part of the code can be represented by this demo:

gt = np.array([[0,0,0,0],[0,1,2,0],[0,3,4,0],[0,0,0,0]])
print(gt)
# [[0 0 0 0]
#  [0 1 2 0]
#  [0 3 4 0]
#  [0 0 0 0]]
indices = np.nonzero(gt)
X = list(zip(*indices))  # Index in the form of x, y features (x, y)
y = gt[indices].ravel()  # classes
print(X)
# [(1, 1), (1, 2), (2, 1), (2, 2)]
print(type(X))
# <class 'list'>
print(y)
# [1 2 3 4]
print(type(y))
# <class 'numpy.ndarray'>

train_gt = np.zeros_like(gt) and test_gt = np.zeros_like(gt) initialize train_gt and test_gt to all 0 arrays with the same dimension as gt.

Since the default mode is random, only mode = random is parsed here.

train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y) mainly calls the sklearn.model_selection.train_test_split() function, which splits data into a training set and a test set. train_indices and test_indices are the result of the split and are element indices of the form [(4, 2), (3, 3), (2, 2), (3, 2), (1, 1), (4, 3)].

The line train_indices = [list(t) for t in zip(*train_indices)] transforms the representation of train_indices into this form: [[4, 3, 2, 3, 1, 4], [2, 3, 2, 2, 1, 3]].

The line train_gt[train_indices] = gt[train_indices] extracts the training-set samples from gt. The other, non-training-set positions keep their initialized value of 0, acting as ignored_labels.

Attached below is a small demo to help better understand:

import numpy as np
import sklearn.model_selection
gt = np.array([[0,0,0,0],[0,1,2,0],[0,3,4,0],[1,2,3,4],[4,3,3,2],[0,0,0,0]])
print(gt)
# [[0 0 0 0]
#  [0 1 2 0]
#  [0 3 4 0]
#  [1 2 3 4]
#  [4 3 3 2]
#  [0 0 0 0]]
indices = np.nonzero(gt)
X = list(zip(*indices))  # Index in the form of x, y features (x, y)
y = gt[indices].ravel()  # classes
print(X)
# [(1, 1), (1, 2), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2), (3, 3), (4, 0), (4, 1), (4, 2), (4, 3)]
print(type(X))
# <class 'list'>
print(y)
# [1 2 3 4 1 2 3 4 4 3 3 2]
print(type(y))
# <class 'numpy.ndarray'>
train_gt = np.zeros_like(gt)
test_gt = np.zeros_like(gt)
train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=0.5, stratify=y)
print(train_indices)
# [(4, 2), (3, 3), (2, 2), (3, 2), (1, 1), (4, 3)]
print(test_indices)
# [(1, 2), (3, 1), (2, 1), (4, 0), (3, 0), (4, 1)]
print('________________________')
y_train = []
for i, j in train_indices:
    y_train.append(gt[i][j])
y_test = []
for i, j in test_indices:
    y_test.append(gt[i][j])
print(y_train)
# [3, 4, 2, 1, 2, 3]
print(y_test)
# [2, 3, 3, 4, 4, 1]
print('________________________')
train_indices = [list(t) for t in zip(*train_indices)]
print(train_indices)
# [[4, 3, 2, 3, 1, 4], [2, 3, 2, 2, 1, 3]]
train_gt[train_indices] = gt[train_indices]
print(train_gt)
# [[0 0 0 0]
#  [0 1 0 0]
#  [0 0 4 0]
#  [0 0 3 4]
#  [0 0 3 2]
#  [0 0 0 0]]

compute_imf_weights()

(Provisional)

camel_to_snake()

(Provisional)

module.py

_addindent()

(Provisional)

class Module(object)

This part is the base class of all neural network modules.

The file is read-only, i.e. it's better not to change it.

So it is only covered briefly here.

models.py

class Baseline(nn.Module)

Defines a class that inherits from nn.Module.

Properties:

nothing

Method:

weight_init()

def weight_init(m):
    if isinstance(m, nn.Linear):        # Determine whether the types are the same
        init.kaiming_normal_(m.weight)      # A Weight Initialization Method
        init.zeros_(m.bias)

Used to initialize the weights and biases.

First, if isinstance(m, nn.Linear) checks whether the input m is an instance of nn.Linear (or a subclass of it), a robustness check so that only linear layers are initialized.

The weights are initialized by init.kaiming_normal_(m.weight), where kaiming_normal_() is a weight initialization method.

init.zeros_(m.bias) initializes biases to zero.

__init__()

def __init__(self, input_channels, n_classes, dropout=False):   # Initialization of class attributes
    super(Baseline, self).__init__()
    self.use_dropout = dropout
    if dropout:
        self.dropout = nn.Dropout(p=0.5)

    self.fc1 = nn.Linear(input_channels, 2048)
    self.fc2 = nn.Linear(2048, 4096)
    self.fc3 = nn.Linear(4096, 2048)
    self.fc4 = nn.Linear(2048, n_classes)

    self.apply(self.weight_init)

This part is the initialization of the class, including whether to use dropout (True or False) and the network's layers with their in/out channel sizes. At the same time, the network parameters (weights and biases) are initialized by self.apply(self.weight_init).

forward(self, x)

def forward(self, x):
    x = F.relu(self.fc1(x))
    if self.use_dropout:
        x = self.dropout(x)
    x = F.relu(self.fc2(x))
    if self.use_dropout:
        x = self.dropout(x)
    x = F.relu(self.fc3(x))
    if self.use_dropout:
        x = self.dropout(x)
    x = self.fc4(x)
    return x 

This method defines the forward propagation process as follows:

input
nn.Linear(input_channels, 2048)
relu()
dropout()
nn.Linear(2048, 4096)
relu()
dropout()
nn.Linear(4096, 2048)
relu()
dropout()
nn.Linear(2048, n_classes)
output

class HuEtAl(nn.Module)

Properties:

nothing

Method:

weight_init()

def weight_init(m):
    # [All the trainable parameters in our CNN should be initialized to
    # be a random value between −0.05 and 0.05.]
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
        init.uniform_(m.weight, -0.05, 0.05)
        init.zeros_(m.bias)

The weights of the model are initialized to random values between -0.05 and 0.05, and the biases are initialized to 0.

_get_final_flattened_size()

def _get_final_flattened_size(self):    # Get the final flat size
    with torch.no_grad():
        x = torch.zeros(1, 1, self.input_channels)      # Generate a tensor of all 0 for 1 x 1 x input_channels
        x = self.pool(self.conv(x))         # First convolution, then pooling
    return x.numel()        # numel() returns the number of elements in a tensor

First, torch.no_grad() is set. Disabling gradient computation is very useful for inference, when you are sure that Tensor.backward() will not be called; it reduces memory consumption compared to computing with requires_grad=True.

x = torch.zeros(1, 1, self.input_channels) generates a tensor of all 0 of 1 * 1 * input_channels.

x = self.pool(self.conv(x)), first convolution, then pooling.

return x.numel() returns the number of elements in the tensor.

I don't yet see why this is done here (it becomes clear below: the result is used as the input size of the first fully connected layer).

__init__()

    def __init__(self, input_channels, n_classes, kernel_size=None, pool_size=None):
        super(HuEtAl, self).__init__()
        if kernel_size is None:
           # [In our experiments, k1 is better to be [ceil](n1/9)]
           kernel_size = math.ceil(input_channels / 9)
        if pool_size is None:
           # The authors recommand that k2's value is chosen so that the pooled features have 30~40 values
           # ceil(kernel_size/5) gives the same values as in the paper so let's assume it's okay
           pool_size = math.ceil(kernel_size / 5)
        self.input_channels = input_channels

        # [The first hidden convolution layer C1 filters the n1 x 1 input data with 20 kernels of size k1 x 1]
        self.conv = nn.Conv1d(1, 20, kernel_size)
        self.pool = nn.MaxPool1d(pool_size)
        self.features_size = self._get_final_flattened_size()
        # [n4 is set to be 100]
        self.fc1 = nn.Linear(self.features_size, 100)
        self.fc2 = nn.Linear(100, n_classes)
        self.apply(self.weight_init)

This part initializes kernel_size and pool_size, reads input_channels, and defines the structure of the network layers.

The choice of kernel_size and pool_size follows the paper. kernel_size is math.ceil(input_channels / 9), i.e. input_channels / 9 rounded up; pool_size = math.ceil(kernel_size / 5), i.e. kernel_size / 5 rounded up.

self.input_channels = input_channels stores the number of input channels.

For the structure of the network layer:

The first hidden convolution layer C1 filters n1*1 input data with 20 kernels of size k1*1, self.conv = nn.Conv1d(1, 20, kernel_size).

The pool layer unification is self.pool = nn.MaxPool1d(pool_size).

self.features_size = self._get_final_flattened_size() gets flattened size as the input dimension of the first linear layer (full connection layer).

Then two linear (fully connected) layers are defined. The first is self.fc1 = nn.Linear(self.features_size, 100), with input dimension self.features_size and output dimension 100. The second is self.fc2 = nn.Linear(100, n_classes), with input dimension 100 and output dimension n_classes.

forward()

def forward(self, x):
    # [In our design architecture, we choose the hyperbolic tangent function tanh(u)]
    x = x.squeeze(dim=-1).squeeze(dim=-1)
    x = x.unsqueeze(1)
    x = self.conv(x)
    x = torch.tanh(self.pool(x))
    x = x.view(-1, self.features_size)
    x = torch.tanh(self.fc1(x))
    x = self.fc2(x)
    return x

x = x.squeeze(dim=-1).squeeze(dim=-1) is a data preprocessing step that removes the last and the second-to-last dimensions of x (when they are of size 1).

x = x.unsqueeze(1) adds another dimension at the second dimension (the dimension index starts at 0).

Demonstrate with a small demo (note: here x is tensor type):

import numpy as np

x = np.array([1,2,3])
print(x.shape)
# (3,)
x = x.reshape(3,1,1)
print(x.shape)
# (3, 1, 1)
x = x.squeeze(-1).squeeze(-1)
print(x.shape)
# (3,)
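The unsqueeze step can be demonstrated the same way (np.expand_dims is the numpy analogue of tensor.unsqueeze):

import numpy as np

x = np.array([1, 2, 3])
x = np.expand_dims(x, 1)   # analogous to x.unsqueeze(1) on a tensor
print(x.shape)
# (3, 1)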

Then there is the forward propagation process of the network.

input
nn.Conv1d(1, 20, kernel_size)
nn.MaxPool1d(pool_size)
tanh()
view(-1, self.features_size)
nn.Linear(self.features_size, 100)
tanh()
nn.Linear(100, n_classes)
output

(Other network models, for the time being...)

get_model()

Functions:

Instantiate a model by name and obtain it together with adequate hyperparameters (Instantiate and obtain a model with adequate hyperparameters).

Input and output:

Input:
  • name: name of the model, a string
  • kwargs: hyperparameters, dictionary type; **kwargs denotes a variable number of keyword arguments
Output:
  • model: PyTorch network
  • optimizer: PyTorch optimizer
  • criterion: PyTorch loss Function
  • kwargs: Hyperparameters with rational defaults

Code and parse:

def get_model(name, **kwargs):
    """
    Instantiate and obtain a model with adequate hyperparameters

    Args:
        name: string of the model name
        kwargs: hyperparameters
    Returns:
        model: PyTorch network
        optimizer: PyTorch optimizer
        criterion: PyTorch loss Function
        kwargs: hyperparameters with sane defaults
    """
    device = kwargs.setdefault('device', torch.device('cpu'))   # Get the value of the key device in the dictionary kwargs, otherwise return the default value of cpu.
    n_classes = kwargs['n_classes']                             # Gets the value of the key n_classes in the dictionary kwargs.
    n_bands = kwargs['n_bands']                                 # Gets the value of the key n_bands in the dictionary kwargs.
    weights = torch.ones(n_classes)
    weights[torch.LongTensor(kwargs['ignored_labels'])] = 0.
    weights = weights.to(device)            # Put it on cpu or gpu
    weights = kwargs.setdefault('weights', weights)

First of all, it should be emphasized that the hyperparameters are stored in kwargs as key-value pairs. By accessing the dictionary kwargs, the commonly used hyperparameters device, n_classes and n_bands are obtained, with default values set through the setdefault() function.

if name == 'nn':
	......
elif name == 'hamida':
	......

In this part, branching statements set the appropriate hyperparameters for the chosen model, such as learning_rate, optimizer, criterion, epoch and batch_size.

model = model.to(device)
epoch = kwargs.setdefault('epoch', 100)
kwargs.setdefault('scheduler', optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=epoch//4, verbose=True))
#kwargs.setdefault('scheduler', None)
kwargs.setdefault('batch_size', 100)
kwargs.setdefault('supervision', 'full')
kwargs.setdefault('flip_augmentation', False)
kwargs.setdefault('radiation_augmentation', False)
kwargs.setdefault('mixture_augmentation', False)
kwargs['center_pixel'] = center_pixel

The main function used in this part is setdefault(): when the model's parameters are incomplete, the missing parameters are set to default values.

But in practice the model's parameters are usually already set in the preceding if-else, so these calls mainly act as a fallback that fills in anything still missing.
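A small demo of the setdefault() behaviour this relies on:

kwargs = {'n_classes': 9, 'n_bands': 103}
batch_size = kwargs.setdefault('batch_size', 100)   # key missing -> insert default and return it
print(batch_size)                                   # 100
print(kwargs.setdefault('n_classes', 42))           # 9 -- existing keys are left untouched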

val()

Functions:

Calculate the accuracy on the validation set.

Input and output:

Input:
  • net
  • data_loader
  • device
  • supervision
Output:
  • accuracy / total: the validation accuracy; the variable accuracy accumulates the number of correct predictions and total the number of evaluated samples.

Code and parse:

Function definition
def val(net, data_loader, device='cpu', supervision='full'):
# TODO : fix me using metrics()
Initialization and Pre-operation
accuracy, total = 0., 0.
ignored_labels = data_loader.dataset.ignored_labels
  • Initialize accuracy and total to float type 0
  • Get ignored_labels
Start testing
for batch_idx, (data, target) in enumerate(data_loader):

Note that during validation the val set is traversed only once, i.e. a single epoch.

Disable gradients
with torch.no_grad():

Because this part only performs evaluation and does not train the network, it is wrapped in with torch.no_grad().

Select device
# Load the data into the GPU if required
data, target = data.to(device), target.to(device)

There's nothing to say about putting data on device.

Obtaining predictions in different ways
if supervision == 'full':
    output = net(data)
elif supervision == 'semi':
    outs = net(data)
    output, rec = outs
    
_, output = torch.max(output, dim=1)

Generally supervision is 'full', so we only look at the fully supervised case. output = net(data), nothing special to say.

_, output = torch.max(output, dim=1) gets the predicted classes: torch.max(output, dim=1) finds the largest element in each row and returns (values, indices). Assigning to _, output keeps only the indices, and the index number represents the predicted category.
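A quick demo of the row-wise torch.max (toy logits):

import torch

output = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1]])
values, indices = torch.max(output, dim=1)   # maximum of each row
print(indices)                               # tensor([1, 0]) -> predicted class per sample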

Statistical accuracy and total
for out, pred in zip(output.view(-1), target.view(-1)):
    if out.item() in ignored_labels:
        continue
    else:
        accuracy += out.item() == pred.item()
        total += 1

There's nothing to say. It's easy to understand.

Return accuracy / total
return accuracy / total

save_model()

def save_model(model, model_name, dataset_name, **kwargs):
     model_dir = './checkpoints/' + model_name + "/" + dataset_name + "/"
     if not os.path.isdir(model_dir):
         os.makedirs(model_dir, exist_ok=True)
     if isinstance(model, torch.nn.Module):
         filename = str('wk') + "_epoch{epoch}_{metric:.2f}".format(**kwargs)
         tqdm.write("Saving neural network weights in {}".format(filename))
         torch.save(model.state_dict(), model_dir + filename + '.pth')
     else:
         filename = str('wk')
         tqdm.write("Saving model params in {}".format(filename))
         joblib.dump(model, model_dir + filename + '.pkl')

It's easy to see, and it doesn't need to be carefully studied now.

train()

Functions:

Encapsulates the network training procedure (a training loop to optimize a network for several epochs with a specified loss).

Input and output:

Input:
  • net: a PyTorch model
  • optimizer: a PyTorch optimizer
  • data_loader: a PyTorch dataset loader
  • epoch: int specifying the number of training epochs
  • criterion: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLoss
  • device (optional): torch device to use (defaults to CPU)
  • display_iter (optional): number of iterations before refreshing the display (False/None to switch off).
  • scheduler (optional): PyTorch scheduler, adjusting the learning rate lr based on the epoch
  • val_loader (optional): validation dataset
  • supervision (optional): 'full' or 'semi'
Output:
  • nothing

Code and parse:

Definition and Information of Functions
def train(net, optimizer, criterion, data_loader, epoch, scheduler=None,
          display_iter=100, device=torch.device('cpu'), display=None,
          val_loader=None, supervision='full'):
    """
    Training loop to optimize a network for several epochs and a specified loss

    Args:
        net: a PyTorch model
        optimizer: a PyTorch optimizer
        data_loader: a PyTorch dataset loader
        epoch: int specifying the number of training epochs
        criterion: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLoss
        device (optional): torch device to use (defaults to CPU)
        display_iter (optional): number of iterations before refreshing the display (False/None to switch off).
        scheduler (optional): PyTorch scheduler
        val_loader (optional): validation dataset
        supervision (optional): 'full' or 'semi'
    """
Robustness Detection of Loss Function
if criterion is None:
    raise Exception("Missing criterion. You must specify a loss function.")

In this part, if the loss function criterion does not exist, an exception is raised with an error message.

This increases the robustness of the program without affecting its main functionality.

Initialize some variables
net.to(device)

save_epoch = epoch // 20 if epoch > 20 else 1

losses = np.zeros(1000000)
mean_losses = np.zeros(100000000)
iter_ = 1
loss_win, val_win = None, None
val_accuracies = []
Training network
Start the epoch cycle
for e in tqdm(range(1, epoch + 1), desc="Training the network"):

range(1, epoch + 1) indicates that the number of cycles is epoch.

tqdm() creates a progress bar with descriptive information desc="Training the network".

Setting model as training mode
# Set the network to training mode
net.train()
avg_loss = 0.

For the train() function, the PyTorch official documentation is as follows:

train(mode=True)

Sets the module in training mode.

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

  • Parameters

    mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

  • Returns

    self

  • Return type

    Module

Here net.train() puts net into training mode.

avg_loss = 0. initializes avg_loss to a float zero.

batch training in epoch
for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):

enumerate(data_loader) forms an indexed sequence over data_loader; printing enumerate(train_loader) gives something like <enumerate object at 0x000001D86C258750>.

The return value of len(data_loader) is the number of batches in an epoch (how many batches the samples are divided into). Here the training set has np.count_nonzero(train_gt) = 4063 samples, corresponding to len(train_loader) = 41. Since batch_size is set to 100, the 4063 training samples are divided into 41 batches (the last batch has fewer than batch_size samples).

By traversing the sequence produced by enumerate(data_loader), the index is assigned to batch_idx, indicating the batch number, and the data and ground truth are assigned to (data, target), i.e. the input data and the true values.

data target to device
# Load the data into the GPU if required
data, target = data.to(device), target.to(device)

Put data and target on the corresponding device; the default is cpu, but generally it is a gpu.

Forward propagation
optimizer.zero_grad()
if supervision == 'full':
    output = net(data)
    loss = criterion(output, target)
elif supervision == 'semi':
    outs = net(data)
    output, rec = outs
    loss = criterion[0](output, target) + net.aux_loss_weight * criterion[1](rec, data)
else:
    raise ValueError("supervision mode \"{}\" is unknown.".format(supervision))

First, zero the gradient of optimizer: optimizer.zero_grad().

Then different training procedures are chosen according to the supervision mode. Since it is generally full supervision, only the fully supervised case is analyzed:

if supervision == 'full':
    output = net(data)
    loss = criterion(output, target)
  • Calculate the predicted value output: output = net(data)
  • Calculating loss function loss: criterion(output, target)
Back propagation
loss.backward()
optimizer.step()
  • Back propagation: loss.backward()
  • Optimization: optimizer.step()
Calculated loss
avg_loss += loss.item()
losses[iter_] = loss.item()
mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_ + 1])
  • avg_loss: initialized to 0; the loss value computed for each batch is added to it.
  • losses: ndarray type, where each position stores the loss value of the corresponding iteration (batch). For example, index 1 stores the loss value of the first iteration (first batch).
  • mean_losses: ndarray type, which stores at index iter_ the mean of losses[max(0, iter_ - 100):iter_ + 1], i.e. a moving average over the last 100 iterations (presumably to smooth the displayed loss curve).
Drawing the Training loss and Validation accuracy curves
if display_iter and iter_ % display_iter == 0: 

Here's why such an if judgment is needed.

First, the meaning of the variable display_iter:

display_iter (optional): number of iterations before refreshing the display (False/None to switch off).

Simply put, whenever the iteration counter iter_ reaches an integer multiple of display_iter (for example 100, 200, ...), the display is refreshed.

So when display_iter is non-zero and iter_ is an integer multiple of display_iter (iter_ modulo display_iter equals 0), the Training loss and Validation accuracy curves are updated.

string = 'Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
string = string.format(
	e, epoch, batch_idx *
	len(data), len(data) * len(data_loader),
	100. * batch_idx / len(data_loader), mean_losses[iter_])

An example printed by this code (from my own run), containing six values:

Train (epoch 55/100) [31100/32200 (97%)]        Loss: 0.024220
  • e: the current epoch number (which pass over the training set this is).
  • epoch: the total number of epochs (how many passes over the training set in total).
  • batch_idx * len(data): the number of samples this epoch has trained so far. batch_idx is the current batch number and len(data) is the number of samples per batch.
  • len(data) * len(data_loader): the number of samples to be trained in this epoch. len(data_loader) is the total number of batches, len(data) the number of samples per batch.
  • 100. * batch_idx / len(data_loader): the fraction of this epoch already trained, as a percentage.
  • mean_losses[iter_]: the (smoothed) loss value at the iter_-th iteration (batch).
update = None if loss_win is None else 'append'
loss_win = display.line(
	X=np.arange(iter_ - display_iter, iter_),
	Y=mean_losses[iter_ - display_iter:iter_],
	win=loss_win,
	update=update,
	opts={'title': "Training loss",
		'xlabel': "Iterations",
		'ylabel': "Loss"
		}
)
tqdm.write(string)

The first statement, update = None if loss_win is None else 'append', means: if loss_win is None, update is None; otherwise update is 'append'.

loss_win starts out as None (it is initialized before the training loop), so update is None on the first refresh and 'append' on every later one:

update: None
iter_: 100
display_iter: 100
--------------------------------------------------------
Train (epoch 3/100) [1700/4100 (41%)]   Loss: 0.916481


update: append
iter_: 200
display_iter: 100
--------------------------------------------------------
Train (epoch 5/100) [3500/4100 (85%)]   Loss: 0.695566


update: append
iter_: 300
display_iter: 100
--------------------------------------------------------
Train (epoch 8/100) [1200/4100 (29%)]   Loss: 0.599200
loss_win = display.line(
	X=np.arange(iter_ - display_iter, iter_),
	Y=mean_losses[iter_ - display_iter:iter_],
	win=loss_win,
	update=update,
	opts={'title': "Training loss",
		'xlabel': "Iterations",
		'ylabel': "Loss"
		}
)
tqdm.write(string)

Here display is vis, the visdom visualization client, and the function used is vis.line(). From the visdom documentation on vis.line():

vis.line

This function draws a line graph. It takes an N or NxM tensor Y that specifies the values of the M lines (each connecting N points) to draw. It also takes an optional X tensor that specifies the corresponding x-axis values; X can be an N tensor (in which case all lines share the same x-axis values) or have the same size as Y.

The following opts are supported:

  • opts.fillarea: fill the area below the line (boolean)
  • opts.colormap: colormap (string; default = 'Viridis')
  • opts.markers: show markers (boolean; default = false)
  • opts.markersymbol: marker symbol (string; default = 'dot')
  • opts.markersize: marker size (number; default = 10)
  • opts.legend: list containing the legend names

win=loss_win specifies which window to draw into; presumably, with many windows open, visdom would otherwise not know which one to update.

The results of the first three refreshes were shown in a screenshot at this point:

(Screenshot of the visdom "Training loss" window; the original image link is broken.)

Simply put, the first time iter_ equals display_iter, a new window is created and the first segment of the curve is drawn; each later time iter_ reaches an integer multiple of display_iter, the newly computed segment is appended to the image already drawn in that same window.
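A minimal self-contained sketch of this create-then-append pattern (assuming a visdom server is running on the default localhost:8097; the data here is invented):

import numpy as np
import visdom

vis = visdom.Visdom()
loss_win = None
for step in range(1, 4):
    x = np.arange((step - 1) * 100, step * 100)
    y = np.random.rand(100)                      # stand-in for mean losses
    # The first call creates the window; later calls append to it
    loss_win = vis.line(X=x, Y=y, win=loss_win,
                        update=None if loss_win is None else 'append',
                        opts={'title': 'Training loss',
                              'xlabel': 'Iterations',
                              'ylabel': 'Loss'})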

Finally, tqdm.write(string) prints the status string without breaking the progress bar.

if len(val_accuracies) > 0:
    val_win = display.line(Y=np.array(val_accuracies),
                           X=np.arange(len(val_accuracies)),
                           win=val_win,
                           opts={'title': "Validation accuracy",
                                 'xlabel': "Epochs",
                                 'ylabel': "Accuracy"
                                })

This part plots the Validation accuracy curve into its own designated window; the mechanics are the same as above, so nothing more to add.

Increment the iteration counter
iter_ += 1
Release variables that are no longer needed
del(data, target, loss, output)

A note on del:

The counterpart of the __init__() method is the __del__() method: __init__() initializes a Python object, while __del__() is invoked when an object is about to be destroyed, i.e. when it is reclaimed by the system. When a program no longer needs an object, the memory it occupies must be released; this process is called garbage collection. Python reclaims the memory of unreferenced objects automatically, so developers normally need not manage it. The del statement used here simply drops the local names' references to these objects, allowing the garbage collector to reclaim them.

Simply put, after a batch has been processed, its data, target, loss and output are no longer needed, so they can be reclaimed to free memory: the next batch will produce new data, targets, loss and output.

So far, an epoch is over

Calculate avg_loss, val_accuracies, metric
avg_loss /= len(data_loader)

if val_loader is not None:
    val_acc = val(net, val_loader, device=device, supervision=supervision)
    val_accuracies.append(val_acc)
    metric = -val_acc
else:
    metric = avg_loss

avg_loss is the mean of the losses over all batches of the epoch.

The branch that follows: if val_loader is None, metric = avg_loss. Otherwise, the val() function computes the accuracy on the validation set, the result is appended to val_accuracies, and metric is set to -val_acc (negated, presumably so that a smaller metric always means a better model, consistent with the loss-based case).

Save the weights
# Save the weights
if e % save_epoch == 0:
    save_model(net, camel_to_snake(str(net.__class__.__name__)), data_loader.dataset.name, epoch=e, metric=abs(metric))

An example of a saved file name:

wk_epoch60_0.99.pth
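The file name is assembled from the snake-cased model class name, the epoch and abs(metric). A hedged sketch of what a camel_to_snake helper typically looks like (the project's actual implementation may differ):

import re

def camel_to_snake(name):
    """Convert e.g. 'HamidaEtAl' to 'hamida_et_al'."""
    s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

print(camel_to_snake("HamidaEtAl"))  # hamida_et_al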

test()

Functions:

Test a model on a specific image

Input and output:

Input:
  • net
  • img: the image to run inference on
  • hyperparams: dictionary of hyperparameters
Output:
  • probs: array of accumulated class scores, of shape W × H × n_classes

Code and parse:

Function definition
def test(net, img, hyperparams):
    """
    Test a model on a specific image
    """
Set the model to test mode
net.eval()
Extract the hyperparameters
patch_size = hyperparams['patch_size']
center_pixel = hyperparams['center_pixel']
batch_size, device = hyperparams['batch_size'], hyperparams['device']
n_classes = hyperparams['n_classes']

kwargs = {'step': hyperparams['test_stride'], 'window_size': (patch_size, patch_size)}
  • patch_size: the size of the sliding window; a window larger than 1 includes spatial context.
  • center_pixel: if True, only the center pixel of the window is classified; the rest of the window serves as context.
  • batch_size
  • device
  • kwargs: A dictionary with step size and window_size (tuple type)
Initialize return result probs
probs = np.zeros(img.shape[:2] + (n_classes,))

img has shape W * H * C, and img.shape[:2] + (n_classes,) takes the first two dimensions of img and appends n_classes as the third dimension, giving W * H * n_classes.

Demonstrate with a small demo:

shape = (340,680,103)
print(shape[:2] + (10,))
# (340, 680, 10)

Initializing probs therefore yields an all-zero array of shape W * H * n_classes.

Compute the total number of iterations
iterations = count_sliding_window(img, **kwargs) // batch_size

count_sliding_window() counts how many windows the sliding window generates over the whole image; integer division by batch_size turns that into the total number of batches, i.e. the number of iterations.

Start iteration
for batch in tqdm(grouper(batch_size, sliding_window(img, **kwargs)),
                  total=(iterations),
                  desc="Inference on the image"
                  ):

tqdm() wraps the iterator in a progress bar; total is the upper limit of the bar, and desc is the string describing it.

grouper() chunks an iterable into groups of n elements; here each chunk is batch_size windows taken from the sliding_window() generator.
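A minimal sketch of such a grouper, following the classic itertools recipe (the project's actual helper may differ in how it handles the final short chunk):

import itertools

def grouper(n, iterable):
    """Yield successive chunks of n items; the last chunk may be shorter."""
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk

print(list(grouper(2, range(5))))  # [(0, 1), (2, 3), (4,)]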

Extract data:
with torch.no_grad():
    if patch_size == 1:
        data = [b[0][0, 0] for b in batch]
        data = np.copy(data)
        data = torch.from_numpy(data)
    else:
        data = [b[0] for b in batch]
        data = np.copy(data)
        data = data.transpose(0, 3, 1, 2)
        data = torch.from_numpy(data)
        data = data.unsqueeze(1)

The following debugging code, added by myself, prints batch and b so we can inspect them:

# --------------------------------------------------------------------------------------------------------------------------------------------------------------
print('batch',batch)
for b in batch:
    print('b:',b)
    print('b[0]:',b[0])
    print('b[0][0, 0]:',b[0][0, 0])
os.system('pause')
# --------------------------------------------------------------------------------------------------------------------------------------------------------------

Only the patch_size == 1 case is explained in detail here.

First, the form of b: b is a tuple (because its elements have different types):

b: (array([[[0.077625, 0.09325 , 0.0695  , 0.045   , 0.035625, 0.0375  ,
         0.03425 , 0.0345  , 0.0415  , 0.039875, 0.03475 , 0.031875,
         0.029   , 0.025875, 0.02625 , 0.026125, 0.021   , 0.017375,
         0.017125, 0.01925 , 0.021   , 0.02525 , 0.028125, 0.028875,
         0.0305  , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345  ,
         0.035625, 0.036375, 0.035625, 0.034   , 0.033875, 0.030125,
         0.026   , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215  ,
         0.02125 , 0.0195  , 0.018875, 0.017625, 0.016875, 0.015875,
         0.017125, 0.017875, 0.018   , 0.016125, 0.012875, 0.013   ,
         0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
         0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145  , 0.019125,
         0.0235  , 0.030375, 0.04025 , 0.051625, 0.0615  , 0.073875,
         0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
         0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
         0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
         0.281125, 0.279875, 0.279875, 0.28525 , 0.286   , 0.28025 ,
         0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885  , 0.293125,
         0.295125]]], dtype=float32), 0, 2, 1, 1)

So batch should be a list or tuple of many such b tuples (most likely a list).

So b[0] is the part of the data, that is:

array([[[0.077625, 0.09325 , 0.0695  , 0.045   , 0.035625, 0.0375  ,
         0.03425 , 0.0345  , 0.0415  , 0.039875, 0.03475 , 0.031875,
         0.029   , 0.025875, 0.02625 , 0.026125, 0.021   , 0.017375,
         0.017125, 0.01925 , 0.021   , 0.02525 , 0.028125, 0.028875,
         0.0305  , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345  ,
         0.035625, 0.036375, 0.035625, 0.034   , 0.033875, 0.030125,
         0.026   , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215  ,
         0.02125 , 0.0195  , 0.018875, 0.017625, 0.016875, 0.015875,
         0.017125, 0.017875, 0.018   , 0.016125, 0.012875, 0.013   ,
         0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
         0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145  , 0.019125,
         0.0235  , 0.030375, 0.04025 , 0.051625, 0.0615  , 0.073875,
         0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
         0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
         0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
         0.281125, 0.279875, 0.279875, 0.28525 , 0.286   , 0.28025 ,
         0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885  , 0.293125,
         0.295125]]], dtype=float32)

b[0] has a shape of:

(1, 1, 103)

Then b[0][0,0] and its shape are:

[0.077625, 0.09325 , 0.0695  , 0.045   , 0.035625, 0.0375  ,
 0.03425 , 0.0345  , 0.0415  , 0.039875, 0.03475 , 0.031875,
 0.029   , 0.025875, 0.02625 , 0.026125, 0.021   , 0.017375,
 0.017125, 0.01925 , 0.021   , 0.02525 , 0.028125, 0.028875,
 0.0305  , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345  ,
 0.035625, 0.036375, 0.035625, 0.034   , 0.033875, 0.030125,
 0.026   , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215  ,
 0.02125 , 0.0195  , 0.018875, 0.017625, 0.016875, 0.015875,
 0.017125, 0.017875, 0.018   , 0.016125, 0.012875, 0.013   ,
 0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
 0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145  , 0.019125,
 0.0235  , 0.030375, 0.04025 , 0.051625, 0.0615  , 0.073875,
 0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
 0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
 0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
 0.281125, 0.279875, 0.279875, 0.28525 , 0.286   , 0.28025 ,
 0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885  , 0.293125,
 0.295125]
 
(103,)

Having understood the composition and types of b, b[0] and b[0][0, 0], come back to this line:

data = [b[0][0, 0] for b in batch]

The right-hand side is a list comprehension: for each b obtained from batch, b[0][0, 0] is evaluated (the sample's spectrum) and put into the resulting list as one element.

A small demo of what this does:

a1 = np.array([[[0.079625, 0.074   , 0.06025 , 0.0695  , 0.0635  , 0.0355  ,
         0.02225 , 0.02475 , 0.024125, 0.028   , 0.027125, 0.026875,
         0.023375, 0.020125, 0.019   , 0.017   , 0.0155  , 0.01525 ,
         0.015875, 0.01575 , 0.015625, 0.015375, 0.018375, 0.0235  ,
         0.026   , 0.025375, 0.02525 , 0.02575 , 0.027375, 0.029375,
         0.02975 , 0.028375, 0.027125, 0.026875, 0.027   , 0.025125,
         0.02375 , 0.020875, 0.018625, 0.02025 , 0.02175 , 0.0225  ,
         0.022125, 0.02125 , 0.0205  , 0.021625, 0.0235  , 0.02275 ,
         0.020125, 0.018375, 0.017625, 0.01925 , 0.021875, 0.02075 ,
         0.0175  , 0.01725 , 0.01825 , 0.017375, 0.016125, 0.018125,
         0.01925 , 0.017125, 0.016625, 0.016   , 0.016125, 0.02175 ,
         0.030625, 0.04225 , 0.056875, 0.073125, 0.09    , 0.10625 ,
         0.126625, 0.153125, 0.1825  , 0.21275 , 0.24225 , 0.269625,
         0.289625, 0.304125, 0.315625, 0.319   , 0.311625, 0.31925 ,
         0.341625, 0.347625, 0.3435  , 0.3435  , 0.342125, 0.33875 ,
         0.335125, 0.33025 , 0.330625, 0.3355  , 0.334375, 0.326125,
         0.317625, 0.318875, 0.321375, 0.321125, 0.321625, 0.3275  ,
         0.3305  ]]])
a2 = np.array([[[0.115   , 0.0825  , 0.058125, 0.038875, 0.04375 , 0.046875,
         0.047   , 0.041375, 0.032625, 0.024875, 0.022375, 0.021125,
         0.02075 , 0.022375, 0.021625, 0.019125, 0.017375, 0.015   ,
         0.013375, 0.016625, 0.019875, 0.0235  , 0.028375, 0.0305  ,
         0.030875, 0.032375, 0.035   , 0.03725 , 0.04    , 0.042375,
         0.042625, 0.043375, 0.042375, 0.039875, 0.0385  , 0.036375,
         0.033875, 0.032375, 0.03075 , 0.029   , 0.0295  , 0.029625,
         0.030625, 0.027625, 0.024875, 0.02525 , 0.025375, 0.02625 ,
         0.026375, 0.026   , 0.0265  , 0.026875, 0.026875, 0.027125,
         0.02675 , 0.024125, 0.022875, 0.021625, 0.01975 , 0.019125,
         0.0195  , 0.019875, 0.019625, 0.02025 , 0.023625, 0.0295  ,
         0.03625 , 0.046625, 0.06175 , 0.076875, 0.09225 , 0.108625,
         0.130875, 0.157375, 0.182625, 0.209   , 0.2285  , 0.246   ,
         0.262   , 0.272625, 0.27975 , 0.276875, 0.269125, 0.279   ,
         0.299875, 0.3005  , 0.294   , 0.291625, 0.293   , 0.290375,
         0.289   , 0.292875, 0.2935  , 0.290875, 0.287375, 0.278625,
         0.27325 , 0.278375, 0.281   , 0.27725 , 0.282375, 0.294   ,
         0.298125]]])
batch = [a1, a2]
data = [b[0, 0] for b in batch]
data = np.copy(data)
print('data:',data)
# data: [[0.079625 0.074    0.06025  0.0695   0.0635   0.0355   0.02225  0.02475
#   0.024125 0.028    0.027125 0.026875 0.023375 0.020125 0.019    0.017
#   0.0155   0.01525  0.015875 0.01575  0.015625 0.015375 0.018375 0.0235
#   0.026    0.025375 0.02525  0.02575  0.027375 0.029375 0.02975  0.028375
#   0.027125 0.026875 0.027    0.025125 0.02375  0.020875 0.018625 0.02025
#   0.02175  0.0225   0.022125 0.02125  0.0205   0.021625 0.0235   0.02275
#   0.020125 0.018375 0.017625 0.01925  0.021875 0.02075  0.0175   0.01725
#   0.01825  0.017375 0.016125 0.018125 0.01925  0.017125 0.016625 0.016
#   0.016125 0.02175  0.030625 0.04225  0.056875 0.073125 0.09     0.10625
#   0.126625 0.153125 0.1825   0.21275  0.24225  0.269625 0.289625 0.304125
#   0.315625 0.319    0.311625 0.31925  0.341625 0.347625 0.3435   0.3435
#   0.342125 0.33875  0.335125 0.33025  0.330625 0.3355   0.334375 0.326125
#   0.317625 0.318875 0.321375 0.321125 0.321625 0.3275   0.3305  ]
#  [0.115    0.0825   0.058125 0.038875 0.04375  0.046875 0.047    0.041375
#   0.032625 0.024875 0.022375 0.021125 0.02075  0.022375 0.021625 0.019125
#   0.017375 0.015    0.013375 0.016625 0.019875 0.0235   0.028375 0.0305
#   0.030875 0.032375 0.035    0.03725  0.04     0.042375 0.042625 0.043375
#   0.042375 0.039875 0.0385   0.036375 0.033875 0.032375 0.03075  0.029
#   0.0295   0.029625 0.030625 0.027625 0.024875 0.02525  0.025375 0.02625
#   0.026375 0.026    0.0265   0.026875 0.026875 0.027125 0.02675  0.024125
#   0.022875 0.021625 0.01975  0.019125 0.0195   0.019875 0.019625 0.02025
#   0.023625 0.0295   0.03625  0.046625 0.06175  0.076875 0.09225  0.108625
#   0.130875 0.157375 0.182625 0.209    0.2285   0.246    0.262    0.272625
#   0.27975  0.276875 0.269125 0.279    0.299875 0.3005   0.294    0.291625
#   0.293    0.290375 0.289    0.292875 0.2935   0.290875 0.287375 0.278625
#   0.27325  0.278375 0.281    0.27725  0.282375 0.294    0.298125]]
print('data.shape:',data.shape)
# data.shape: (2, 103)

Here batch is given only two elements, and the resulting shape is (2, 103): each row is one sample.

If the sequence batch held 100 elements, data would have shape (100, 103), i.e. (batch_size, channels).

So, looking back at this part of the code for patch_size == 1:

with torch.no_grad():
    if patch_size == 1:
        data = [b[0][0, 0] for b in batch]
        data = np.copy(data)
        data = torch.from_numpy(data)

First: this is inference, so the block runs under torch.no_grad(), which disables gradient tracking.

data = [b[0][0, 0] for b in batch] extracts the data part of each element of batch (the spectrum of each sample) and collects them into a list called data.

The data is then converted from list to array via np.copy().

data = torch.from_numpy(data) converts data from array to tensor.

Obtain the predicted value output
indices = [b[1:] for b in batch]
data = data.to(device)
output = net(data)
if isinstance(output, tuple):
    output = output[0]
output = output.to('cpu')

if patch_size == 1 or center_pixel:
    output = output.numpy()
else:
    output = np.transpose(output.numpy(), (0, 2, 3, 1))

indices = [b[1:] for b in batch] collects the position information (x, y, w, h) of each window.

data = data.to(device) puts data on the corresponding device.

output = net(data) gets the predicted value of data.

if isinstance(output, tuple): output = output[0] handles networks that return a tuple, e.g. the semi-supervised models that also return a reconstruction; only the class scores output[0] are kept.

output = output.to('cpu') transfers output to cpu.

Then, when patch_size == 1 or center_pixel is set, output is converted to an array with output = output.numpy(); otherwise it is also transposed so that the class dimension comes last.

Accumulate the results
for (x, y, w, h), out in zip(indices, output):
    if center_pixel:
        probs[x + w // 2, y + h // 2] += out
    else:
        probs[x:x + w, y:y + h] += out

First, recall that the result array probs was initialized to all zeros (generally only one ignored label is used, with label value 0).

In short, each iteration of test() takes one batch of batch_size windows from grouper(), runs the network on them, and accumulates each prediction into probs at the corresponding image position; once every batch has been processed, probs holds the accumulated class scores for the whole image.
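A toy sketch of this accumulation (shapes and indices invented for illustration; the update line mirrors the center_pixel branch above):

import numpy as np

W, H, n_classes = 4, 4, 3
probs = np.zeros((W, H, n_classes))
indices = [(0, 0, 1, 1), (1, 1, 1, 1)]    # (x, y, w, h) for each window
output = np.random.rand(2, n_classes)     # one score vector per window
for (x, y, w, h), out in zip(indices, output):
    probs[x + w // 2, y + h // 2] += out  # same update as in test()
prediction = np.argmax(probs, axis=-1)    # final class map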

The remaining details are skipped.

inference.py

This file largely duplicates code in main.py.

(Analysis deferred.)

datasets.py

This file contains PyTorch data sets for hyperspectral images and related assistants.

DATASETS_CONFIG+Update

Data Set Configuration

DATASETS_CONFIG is the dataset configuration dictionary. Each key is a dataset_name, and the corresponding value holds that dataset's urls, img, and gt.

DATASETS_CONFIG = {
        'PaviaC': {
            'urls': ['http://www.ehu.eus/ccwintco/uploads/e/e3/Pavia.mat',      # urls are links
                     'http://www.ehu.eus/ccwintco/uploads/5/53/Pavia_gt.mat'],
            'img': 'Pavia.mat',
            'gt': 'Pavia_gt.mat'
            },
        'PaviaU': {
            'urls': ['http://www.ehu.eus/ccwintco/uploads/e/ee/PaviaU.mat',
                     'http://www.ehu.eus/ccwintco/uploads/5/50/PaviaU_gt.mat'],
            'img': 'PaviaU.mat',
            'gt': 'PaviaU_gt.mat'
            },
        'KSC': {
            'urls': ['http://www.ehu.es/ccwintco/uploads/2/26/KSC.mat',
                     'http://www.ehu.es/ccwintco/uploads/a/a6/KSC_gt.mat'],
            'img': 'KSC.mat',
            'gt': 'KSC_gt.mat'
            },
        'IndianPines': {
            'urls': ['http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat',
                     'http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat'],
            'img': 'Indian_pines_corrected.mat',
            'gt': 'Indian_pines_gt.mat'
            },
        'Botswana': {
            'urls': ['http://www.ehu.es/ccwintco/uploads/7/72/Botswana.mat',
                     'http://www.ehu.es/ccwintco/uploads/5/58/Botswana_gt.mat'],
            'img': 'Botswana.mat',
            'gt': 'Botswana_gt.mat',
            }
    }
Update data set configuration
try:
    from custom_datasets import CUSTOM_DATASETS_CONFIG
    DATASETS_CONFIG.update(CUSTOM_DATASETS_CONFIG)
except ImportError:
    pass

Essentially, the dictionary CUSTOM_DATASETS_CONFIG is merged into DATASETS_CONFIG via the dictionary update() method; if custom_datasets cannot be imported, the prebuilt configuration is used unchanged.
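A small demo of this merge; the entries below are shortened stand-ins, and the custom file name is hypothetical:

DATASETS_CONFIG = {'PaviaC': {'img': 'Pavia.mat'}}
CUSTOM_DATASETS_CONFIG = {'DFC2018_HSI': {'img': 'my_dataset.mat'}}  # hypothetical entry

DATASETS_CONFIG.update(CUSTOM_DATASETS_CONFIG)
print(list(DATASETS_CONFIG))  # ['PaviaC', 'DFC2018_HSI']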

class TqdmUpTo(tqdm)

A tqdm subclass used as a download progress bar; it provides the update_to hook passed to urlretrieve below.

(Analysis deferred.)

get_dataset()

Functions:

Download and read the data set.

Input and output:

Input:
  • dataset_name: string with the name of the dataset
  • target_folder (optional): folder to store the datasets; defaults to ./ (I generally point this at my own dataset location).
  • datasets (optional): dataset configuration dictionary, defaults to prebuilt one. Generally set to DATASETS_CONFIG.
Output:
  • img: 3D hyperspectral image (WxHxB), B band.
  • gt: 2D int array of labels
  • label_values: list of class names
  • ignored_labels: list of int classes to ignore
  • rgb_bands: int tuple, corresponding to red, green and blue bands

Code and parse:

Initialization parameters:
def get_dataset(dataset_name, target_folder="./", datasets=DATASETS_CONFIG):
# def get_dataset(dataset_name, target_folder="C:\\Users\\73416\\PycharmProjects\\HSIproject\\Datasets\\", datasets=DATASETS_CONFIG):
    """ Gets the dataset specified by name and return the related components.
    Args:
        dataset_name: string with the name of the dataset
        target_folder (optional): folder to store the datasets, defaults to ./  
        datasets (optional): dataset configuration dictionary, defaults to prebuilt one
    Returns:
        img: 3D hyperspectral image (WxHxB)
        gt: 2D int array of labels                      # Tag array
        label_values: list of class names               # Class list
        ignored_labels: list of int classes to ignore
        rgb_bands: int tuple that correspond to red, green and blue bands           # int tuple, corresponding to red, green and blue bands
    """
    target_folder = "C:\\Datasets\\"       # Self-adding, modifying the path of data sets
    # print(target_folder)  # Self adding

    palette = None

    # When the name of the input data set is not in the dataset dictionary datasets=DATASETS_CONFIG, an error dataset is unknown.
    if dataset_name not in datasets.keys():
        raise ValueError("{} dataset is unknown.".format(dataset_name))

    # Dictionary operation to obtain data set dictionary datasets, the key is the value of dataset_name (urls, img and gt)
    dataset = datasets[dataset_name]

    folder = target_folder + datasets[dataset_name].get('folder', dataset_name + '/')
    # folder: C: Datasets PaviaU/

This part initializes the parameters:

  • target_folder: the root folder for datasets, e.g. target_folder = "C:\\Datasets\\"
  • palette: the palette, initialized to None.
  • dataset: the entry of the datasets dictionary for dataset_name (holding its urls, img and gt)
  • folder: the folder of the specific dataset, e.g. C:\\Datasets\\PaviaU/
Download data sets:
# Download the dataset if is not present
if dataset.get('download', True):
    # If the folder (e.g. C:\Datasets\PaviaU/) does not exist, create it
    if not os.path.isdir(folder):
        os.mkdir(folder)
    # Download data set (temporary pass)
    for url in datasets[dataset_name]['urls']:
        # download the files
        filename = url.split('/')[-1]
        if not os.path.exists(folder + filename):
            with TqdmUpTo(unit='B', unit_scale=True, miniters=1,
                      desc="Downloading {}".format(filename)) as t:
                urlretrieve(url, filename=folder + filename,
                                 reporthook=t.update_to)
elif not os.path.isdir(folder):
   print("WARNING: {} is not downloadable.".format(dataset_name))

When dataset.get('download', True) is truthy, the specified dataset is downloaded.

First, check whether the folder exists under the specified path folder, os.path.isdir(folder). If not, create folders under the specified path folder.

The download code itself (fetching the urls, creating the progress bar, etc.) is only sketched here.

There is also a robustness check on dataset_name (the ValueError above), likewise skipped for now.

Read Data Set + Preprocess:
Data set reading:
# Read data sets
if dataset_name == 'PaviaC':
    # Load the image
    # Open folder + 'Pavia.mat' with the project's own open_file() function. The return value is a dict; ['pavia'] extracts the image array.
    img = open_file(folder + 'Pavia.mat')['pavia']

    # RGB bands used for display (why these particular indices is not explained)
    rgb_bands = (55, 41, 12)

    # Open folder + 'Pavia_gt.mat' the same way; ['pavia_gt'] extracts the ground-truth array.
    gt = open_file(folder + 'Pavia_gt.mat')['pavia_gt']

    # label_values lists the class names; the integer values in gt index into this list.
    label_values = ["Undefined", "Water", "Trees", "Asphalt",
                    "Self-Blocking Bricks", "Bitumen", "Tiles", "Shadows",
                    "Meadows", "Bare Soil"]

    ignored_labels = [0]

elif dataset_name == 'PaviaU':
    # Load the image
    img = open_file(folder + 'PaviaU.mat')['paviaU']

    rgb_bands = (55, 41, 12)

    gt = open_file(folder + 'PaviaU_gt.mat')['paviaU_gt']

    label_values = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees',
                    'Painted metal sheets', 'Bare Soil', 'Bitumen',
                    'Self-Blocking Bricks', 'Shadows']

    ignored_labels = [0]

elif dataset_name == 'IndianPines':
    # Load the image
    img = open_file(folder + 'Indian_pines_corrected.mat')
    img = img['indian_pines_corrected']

    rgb_bands = (43, 21, 11)  # AVIRIS sensor

    gt = open_file(folder + 'Indian_pines_gt.mat')['indian_pines_gt']

    label_values = ["Undefined", "Alfalfa", "Corn-notill", "Corn-mintill",
                    "Corn", "Grass-pasture", "Grass-trees",
                    "Grass-pasture-mowed", "Hay-windrowed", "Oats",
                    "Soybean-notill", "Soybean-mintill", "Soybean-clean",
                    "Wheat", "Woods", "Buildings-Grass-Trees-Drives",
                    "Stone-Steel-Towers"]

    ignored_labels = [0]

elif dataset_name == 'Botswana':
    # Load the image
    img = open_file(folder + 'Botswana.mat')['Botswana']

    rgb_bands = (75, 33, 15)

    gt = open_file(folder + 'Botswana_gt.mat')['Botswana_gt']
    label_values = ["Undefined", "Water", "Hippo grass",
                    "Floodplain grasses 1", "Floodplain grasses 2",
                    "Reeds", "Riparian", "Firescar", "Island interior",
                    "Acacia woodlands", "Acacia shrublands",
                    "Acacia grasslands", "Short mopane", "Mixed mopane",
                    "Exposed soils"]

    ignored_labels = [0]

elif dataset_name == 'KSC':
    # Load the image
    img = open_file(folder + 'KSC.mat')['KSC']

    rgb_bands = (43, 21, 11)  # AVIRIS sensor

    gt = open_file(folder + 'KSC_gt.mat')['KSC_gt']

    label_values = ["Undefined", "Scrub", "Willow swamp",
                    "Cabbage palm hammock", "Cabbage palm/oak hammock",
                    "Slash pine", "Oak/broadleaf hammock",
                    "Hardwood swamp", "Graminoid marsh", "Spartina marsh",
                    "Cattail marsh", "Salt marsh", "Mud flats", "Wate"]

    ignored_labels = [0]
else:
    # See Custom Data Set Module for more details
    # Custom dataset
    img, gt, rgb_bands, ignored_labels, label_values, palette = CUSTOM_DATASETS_CONFIG[dataset_name]['loader'](folder)

This part reads the downloaded dataset files, including the 3D image and the 2D labels. The procedure is essentially the same for every dataset.

Every time you read a data set file, you do these things (for example, dataset_name='PaviaC'):

  • Read the data. Open folder + 'Pavia.mat' with the project's own open_file() function; the return value is a dict, and ['pavia'] extracts the image. Code: img = open_file(folder + 'Pavia.mat')['pavia']
  • Pick the RGB bands (the particular indices are not explained). Code: rgb_bands = (55, 41, 12)
  • Read the ground truth. Open folder + 'Pavia_gt.mat' the same way; ['pavia_gt'] extracts the labels. Code: gt = open_file(folder + 'Pavia_gt.mat')['pavia_gt']
  • Define label_values, the list of class names. Code: label_values = ["Undefined", "Water", "Trees", "Asphalt", "Self-Blocking Bricks", "Bitumen", "Tiles", "Shadows", "Meadows", "Bare Soil"]
  • Define ignored_labels, usually [0]. Code: ignored_labels = [0]

**Note:** when the dataset being processed is not one of the project's predefined datasets (i.e. the user's own dataset), the values are returned by the final else branch. See CUSTOM_DATASETS_CONFIG for details.

else:
    # See Custom Data Set Module for more details
    # Custom dataset
    img, gt, rgb_bands, ignored_labels, label_values, palette = CUSTOM_DATASETS_CONFIG[dataset_name]['loader'](folder)
Processing NaN:
# Processing NaN
# Filter NaN out
nan_mask = np.isnan(img.sum(axis=-1))
if np.count_nonzero(nan_mask) > 0:
   print("Warning: NaN have been found in the data. It is preferable to remove them beforehand. Learning on NaN data is disabled.")
img[nan_mask] = 0
gt[nan_mask] = 0
ignored_labels.append(0)
ignored_labels = list(set(ignored_labels))

This is an uncommon case; skipped for now.

Normalization:
# Normalization
img = np.asarray(img, dtype='float32')
img = (img - np.min(img)) / (np.max(img) - np.min(img))

First, img is cast to float32 with img = np.asarray(img, dtype='float32'); then it is min-max normalized: img = (img - np.min(img)) / (np.max(img) - np.min(img)).
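A tiny demo of this min-max normalization on a toy array:

import numpy as np

img = np.array([[10., 20.],
                [30., 50.]], dtype='float32')
img = (img - np.min(img)) / (np.max(img) - np.min(img))
print(img)
# [[0.   0.25]
#  [0.5  1.  ]]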

Return value:
return img, gt, label_values, ignored_labels, rgb_bands, palette
  • img: 3D hyperspectral image (WxHxB), B band.
  • gt: 2D int array of labels
  • label_values: list of class names
  • ignored_labels: list of int classes to ignore
  • rgb_bands: int tuple, corresponding to red, green and blue bands
  • palette: None by default (custom loaders may supply one).

class HyperX(torch.utils.data.Dataset)

This is a general class for hyperspectral scenes.

class HyperX(torch.utils.data.Dataset):

The class is named HyperX, and the inherited parent class is torch.utils.data.Dataset.

__init__(self, data, gt, **hyperparams):

Functions:

Class properties are initialized.

Input and output:
Input:
  • data: the 3D hyperspectral image
  • gt: the 2D array of labels
  • **hyperparams: a dictionary of hyperparameters. The ** syntax means this parameter collects any number of keyword arguments (e.g. a=1, b=2, c=3) into a dict.
Output:

None.

Code and parse:
Read img, gt and hyperparameters
class HyperX(torch.utils.data.Dataset):
    """ Generic class for a hyperspectral scene """

    def __init__(self, data, gt, **hyperparams):    # **hyperparams collects an arbitrary number of keyword hyperparameters into a dict
        """
        Args:
            data: 3D hyperspectral image
            gt: 2D array of labels
            patch_size: int, size of the spatial neighbourhood
            center_pixel: bool, set to True to consider only the label of the
                          center pixel
            data_augmentation: bool, set to True to perform random flips
            supervision: 'full' or 'semi' supervised algorithms
        """
        super(HyperX, self).__init__()
        # Read img
        self.data = data
        # Read gt
        self.label = gt
        # Read hyperparameters
        self.name = hyperparams['dataset']
        self.patch_size = hyperparams['patch_size']
        self.ignored_labels = set(hyperparams['ignored_labels'])
        self.flip_augmentation = hyperparams['flip_augmentation']
        self.radiation_augmentation = hyperparams['radiation_augmentation'] 
        self.mixture_augmentation = hyperparams['mixture_augmentation'] 
        self.center_pixel = hyperparams['center_pixel']
        supervision = hyperparams['supervision']

In the line that reads ignored_labels, self.ignored_labels = set(hyperparams['ignored_labels']), the set() function converts the elements of an iterable into a set type (removing duplicates).

A small demo of a set() function:

x = set('runoob')
print(x)
# {'r', 'b', 'o', 'u', 'n'}
print(type(x))
# <class 'set'>

There's nothing else to say.

Supervision mode:
# Supervision mode
# Fully supervised: use all pixels with label not ignored.
if supervision == 'full':
    mask = np.ones_like(gt)
    for l in self.ignored_labels:
        mask[gt == l] = 0
# Semi-supervised: use all pixels, except padding.
elif supervision == 'semi':
    mask = np.ones_like(gt)

Full supervision uses all pixels whose label is not ignored: the mask sets the positions of gt whose value is in ignored_labels to 0 and every other position to 1.

Semi-supervision uses all pixels except padding, so the mask is all ones.
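A tiny demo of the full-supervision mask construction (toy gt, ignored label 0):

import numpy as np

gt = np.array([[0, 1],
               [2, 0]])
ignored_labels = {0}
mask = np.ones_like(gt)
for l in ignored_labels:
    mask[gt == l] = 0
print(mask)
# [[0 1]
#  [1 0]]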

Get the index:
x_pos, y_pos = np.nonzero(mask)
p = self.patch_size // 2
self.indices = np.array([(x,y) for x,y in zip(x_pos, y_pos) if x > p and x < data.shape[0] - p and y > p and y < data.shape[1] - p])
self.labels = [self.label[x,y] for x,y in self.indices]
np.random.shuffle(self.indices)

Since the ignored_labels positions in mask are 0, np.nonzero() returns the indices of the non-zero elements of mask as a tuple of two arrays: one with the x-axis indices and one with the y-axis indices.

p = self.patch_size // 2 ("//" is integer division; "/" is floating-point division) is the half-width of a patch. It is used to exclude pixels within p of the image border, so that every extracted patch_size × patch_size patch fits entirely inside the image.

The next step keeps only the indices in the valid range, i.e. p < x < data.shape[0] - p and p < y < data.shape[1] - p. The method: zip() pairs the coordinates into (x, y) tuples, the comprehension keeps those inside the range, and the result is assigned to self.indices.

Then the labels under the corresponding index indices are obtained by self.labels = [self.label[x,y] for x,y in self.indices].

np.random.shuffle(self.indices) shuffles self.indices in place. (self.labels was built before the shuffle, so it no longer lines up with indices; but __getitem__ reads labels directly from self.label via the coordinates, so this causes no bug.)
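A small demo of np.nonzero() and the zip() pairing on a toy mask:

import numpy as np

mask = np.array([[0, 1],
                 [1, 1]])
x_pos, y_pos = np.nonzero(mask)
print(x_pos, y_pos)  # [0 1 1] [1 0 1]
print([(int(x), int(y)) for x, y in zip(x_pos, y_pos)])  # [(0, 1), (1, 0), (1, 1)]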

flip(*arrays)

Functions:

Performs random horizontal or vertical flips of the input arrays. (This is a data augmentation step; it is called from the flip_augmentation branch of __getitem__.)

Input and output:
Input:
  • *arrays: any number of arrays (here, data and label).

The * syntax means this parameter collects any number of positional arguments (e.g. 1, 2, 3) into a tuple.

Output:
  • arrays: the same arrays, possibly flipped.
Code and parse:
def flip(*arrays):
    horizontal = np.random.random() > 0.5
    vertical = np.random.random() > 0.5
    if horizontal:
        arrays = [np.fliplr(arr) for arr in arrays]
    if vertical:
        arrays = [np.flipud(arr) for arr in arrays]
    return arrays

Whether to flip horizontally (left-right) or vertically (up-down) is decided independently at random: np.random.random() draws a number in [0, 1), so the comparison > 0.5 is True with probability 0.5.

When horizontal is True, each array is reversed horizontally (left and right) by traversing arrays.

When vertical is True, each array is reversed vertically (up and down) by traversing arrays.
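A quick demo of the two flip primitives on a toy array:

import numpy as np

a = np.array([[1, 2],
              [3, 4]])
print(np.fliplr(a))  # [[2 1]
                     #  [4 3]]  horizontal (left-right) flip
print(np.flipud(a))  # [[3 4]
                     #  [1 2]]  vertical (up-down) flip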

radiation_noise()

Functions:

Add noise to data.

Input and output:
Input:
  • data: the data to process.
  • alpha_range: range for the random scaling factor alpha; defaults to (0.9, 1.1)
  • beta: weight of the noise term; defaults to 1/25
Output:
  • alpha * data + beta * noise: a combination of data and noise
Code and parse:
@staticmethod
def radiation_noise(data, alpha_range=(0.9, 1.1), beta=1/25):
    alpha = np.random.uniform(*alpha_range)
    noise = np.random.normal(loc=0., scale=1.0, size=data.shape)
    return alpha * data + beta * noise

First, alpha is drawn from alpha_range: np.random.uniform(a, b) generates a random number in [a, b). The * in np.random.uniform(*alpha_range) unpacks the tuple into positional arguments, so alpha = np.random.uniform(*alpha_range) is equivalent to np.random.uniform(0.9, 1.1).

noise is drawn from a normal distribution: np.random.normal(loc=0., scale=1.0, size=data.shape) generates normally distributed random numbers with mean 0 and standard deviation 1.0, with the same shape as data.

Finally, return the combination of data and noise: alpha * data + beta * noise.
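A small demo of the two random draws involved (the shape is chosen for illustration):

import numpy as np

alpha_range = (0.9, 1.1)
alpha = np.random.uniform(*alpha_range)   # same as np.random.uniform(0.9, 1.1)
noise = np.random.normal(loc=0., scale=1.0, size=(2, 3))
print(alpha)        # e.g. 1.027...
print(noise.shape)  # (2, 3)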

mixture_noise()

This one is harder to follow; it is only skimmed here.

__len__()

Returns the length of the indices attribute of the object.

def __len__(self):
    return len(self.indices)

__getitem__()

Functions:

Returns the image patch of size patch_size * patch_size centered at the i-th stored location.

Data augmentation is applied at the same time, if enabled.

Here item refers to "image block".

Input and output:
Input:
  • i: Index of the location taken
Output:
  • data: (Batch x) Planes x Channels x Width x Height
  • label: the corresponding label(s)
Code and parse:
Get the image block:
def __getitem__(self, i):
    x, y = self.indices[i]
    x1, y1 = x - self.patch_size // 2, y - self.patch_size // 2
    x2, y2 = x1 + self.patch_size, y1 + self.patch_size
    
    data = self.data[x1:x2, y1:y2]
    label = self.label[x1:x2, y1:y2]

This part extracts the image patch of size patch_size * patch_size centered on the coordinate pair self.indices[i].

The geometry is roughly as follows ((x, y) is the center of the patch):

|        |      | x2, y2 |
| :----: | :--: | :----: |
|        | x, y |        |
| x1, y1 |      |        |

x1 and y1 are obtained by subtracting self.patch_size // 2 from x and y; x2 and y2 are then obtained by adding self.patch_size to x1 and y1.

Then the corresponding data and label image blocks are obtained by [x1:x2, y1:y2].
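A quick check of the corner arithmetic (patch_size and center chosen for illustration):

patch_size = 5
x, y = 10, 20
x1, y1 = x - patch_size // 2, y - patch_size // 2   # (8, 18)
x2, y2 = x1 + patch_size, y1 + patch_size           # (13, 23)
# data[x1:x2, y1:y2] is then a 5 x 5 window centered on (10, 20)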

Data augmentation:
if self.flip_augmentation and self.patch_size > 1:
    # Perform data augmentation (only on 2D patches)
    data, label = self.flip(data, label)
if self.radiation_augmentation and np.random.random() < 0.1:
        data = self.radiation_noise(data)
if self.mixture_augmentation and np.random.random() < 0.2:
        data = self.mixture_noise(data, label)

Data augmentation is not performed by default; it requires self.flip_augmentation == True, self.radiation_augmentation == True, or self.mixture_augmentation == True.

With self.flip_augmentation == True, self.patch_size > 1 is also required, so flips are only applied to 2D patches.

With self.radiation_augmentation == True, np.random.random() < 0.1 must also hold, so the operation runs with only 10% probability.

With self.mixture_augmentation == True, np.random.random() < 0.2 must also hold, so the operation runs with only 20% probability.

data and label are converted to ndarray type:
# Copy the data into numpy arrays (PyTorch doesn't like numpy views)
data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
label = np.asarray(np.copy(label), dtype='int64')

This section converts data and label into fresh ndarrays; np.copy() ensures PyTorch receives a real, contiguous array rather than a numpy view (hence the comment "PyTorch doesn't like numpy views").

The data patch is stored channels-last, i.e. W * H * C, while PyTorch expects channels first, C * W * H; transpose((2, 0, 1)) moves the channel axis to the front.

data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32') therefore first copies data as an ndarray, then transposes the dimension order from W * H * C to C * W * H, and finally casts each element to float32.

For label, the code simply copies it into an ndarray and sets the element type to int64.

ndarray to tensor:
# Load the data into PyTorch tensors
data = torch.from_numpy(data)
label = torch.from_numpy(label)

A straightforward application of torch.from_numpy(); nothing more to say.

Extract the center label if needed:
# Extract the center label if needed
if self.center_pixel and self.patch_size > 1:
    label = label[self.patch_size // 2, self.patch_size // 2]
# Remove unused dimensions when we work with invidual spectrums
elif self.patch_size == 1:
    data = data[:, 0, 0]
    label = label[0, 0]

When self.center_pixel == True and self.patch_size > 1, only the center element of the patch_size × patch_size label patch is kept: label = label[self.patch_size // 2, self.patch_size // 2].

The elif branch handles patch_size == 1: the singleton spatial dimensions are stripped, so data[:, 0, 0] is the single pixel's spectrum and label[0, 0] is its scalar label.

# Add a fourth dimension for 3D CNN
if self.patch_size > 1:
    # Make 4D data ((Batch x) Planes x Channels x Width x Height)
    data = data.unsqueeze(0)

This adds a size-1 dimension at the front of data, the "Planes" axis in the comment, so each item becomes Planes x Channels x Width x Height (the DataLoader later prepends the Batch dimension).

The unsqueeze() function adds a dimension; unsqueeze(0) inserts the new dimension at position 0.
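A quick demo of unsqueeze(0) (shapes chosen for illustration):

import torch

data = torch.zeros(103, 5, 5)   # Channels x Width x Height
data = data.unsqueeze(0)        # insert a size-1 dimension at position 0
print(data.shape)               # torch.Size([1, 103, 5, 5])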

Return value:
return data, label

The return value is data, label. The type is tensor.

