Deep learning projects (5) - Preparation of the training module

Posted by mysterbx on Wed, 02 Mar 2022 12:03:17 +0100

1. Steps in the training module

The training module is generally saved in train.py, which usually includes the following steps:

  1. Import the required modules (standard library, third-party libraries, cv2, torch, torchvision); if the project contains model.py (a custom network model file), loss.py (custom loss functions), utils.py (various custom utility methods), or config.py (the overall project configuration), these also need to be imported

  2. Command line parsing

  3. Dataset loading

  4. Check whether the model save directory exists; if not, create it;

  5. Instantiate the network model;

  6. Instantiate the loss function and the optimizer;

  7. Prepare the event files so that the training process can be visualized with tensorboard --logdir=run;

  8. Check whether training should resume from the last checkpoint; if so, load the checkpoint model;

  9. Start training, looping over the epochs:
    – set the gradients to zero;
    – compute the loss;
    – back-propagate;
    – update the weight parameters;
    – update the learning rate in the optimizer (optional)

  10. Log and visualize the metrics;

  11. Evaluate the model on the validation set (adjust the model's hyperparameters according to its loss and metrics on the validation data)

2. Code examples of the training module

2.1 Training module demonstration I

import os
import torch
from torch.utils.data import DataLoader
from torch import nn
import argparse
from tensorboardX import SummaryWriter

from data_preparation.data_preparation import FileDateset
from model.Baseline import Base_model
from model.ops import pytorch_LSD


def parse_args():
    parser = argparse.ArgumentParser()
    # To train from scratch keep default=None; to continue training, set default to '/**.pth'
    parser.add_argument("--model_name", type=str, default=None, help="Checkpoint to load to continue training, e.g. '/50.pth'; None trains from scratch")
    parser.add_argument("--batch-size", type=int, default=16, help="")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate (default: 3e-4)')
    parser.add_argument('--train_data', default="./data_preparation/Synthetic/TRAIN", help='Data set path')
    parser.add_argument('--val_data', default="./data_preparation/Synthetic/VAL", help='Validation sample path')
    parser.add_argument('--checkpoints_dir', default="./checkpoints/AEC_baseline", help='Path to the model checkpoint directory (for continuing training)')
    parser.add_argument('--event_dir', default="./event_file/AEC_baseline", help='Directory of the TensorBoard event files')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    print("GPU Available:", torch.cuda.is_available())  # True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate Dataset
    train_set = FileDateset(dataset_path=args.train_data)  # Instantiated training dataset
    val_set = FileDateset(dataset_path=args.val_data)  # Instantiate validation dataset

    # Data Loader 
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True)

    # ###########    Directory for saving checkpoints (created if it does not exist)   ############
    if not os.path.exists(args.checkpoints_dir):
        os.makedirs(args.checkpoints_dir)

    ################################
    #       Instantiate the model        #
    ################################
    model = Base_model().to(device)  # Instantiate the model
    # summary(model, input_size=(322, 999))  # Model output: torch.Size([64, 322, 999])
    # ###########    Loss function    ############
    criterion = nn.MSELoss(reduction='mean')  # 'reduce' and 'size_average' are deprecated; reduction='mean' is sufficient

    ###############################
    #      Create the optimizer      #
    ###############################
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,20], gamma=0.1)

    # ###########    TensorBoard visualization summary  ############
    writer = SummaryWriter(args.event_dir)  # Create event file

    # ###########    Load model checkpoints   ############
    start_epoch = 0
    if args.model_name:
        print("Load model:", args.checkpoints_dir + args.model_name)
        checkpoint = torch.load(args.checkpoints_dir + args.model_name)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint['epoch']
        # lr_schedule.load_state_dict(checkpoint['lr_schedule'])  # Load lr_scheduler

    for epoch in range(start_epoch, args.epochs):
        model.train()  # Switch the model to training mode
        for batch_idx, (train_X, train_mask, train_nearend_mic_magnitude, train_nearend_magnitude) in enumerate(
                train_loader):
            train_X = train_X.to(device)  # Far-end speech concatenated with the microphone signal [batch_size, 322, 999] (batch, F, T)
            train_mask = train_mask.to(device)  # IRM [batch_size, 161, 999]
            train_nearend_mic_magnitude = train_nearend_mic_magnitude.to(device)
            train_nearend_magnitude = train_nearend_magnitude.to(device)

            # Forward propagation
            pred_mask = model(train_X)  # [batch_size, 322, 999]--> [batch_size, 161, 999]
            train_loss = criterion(pred_mask, train_mask)

            # Near end voice signal spectrum = mask * microphone signal spectrum [batch_size, 161, 999]
            pred_near_spectrum = pred_mask * train_nearend_mic_magnitude
            train_lsd = pytorch_LSD(train_nearend_magnitude, pred_near_spectrum)

            # Back propagation
            optimizer.zero_grad()  # Clear gradient
            train_loss.backward()  # Back propagation
            optimizer.step()  # Update parameters

        # ###########    Visual printing (metrics of the last batch in the epoch)   ############
        print('Train Epoch: {} Loss: {:.6f} LSD: {:.6f}'.format(epoch + 1, train_loss.item(), train_lsd.item()))

        # ###########    TensorBoard visualization summary  ############
        # lr_schedule.step()  # Learning rate decay
        # writer.add_scalar(tag="lr", scalar_value=optimizer.state_dict()['param_groups'][0]['lr'], global_step=epoch + 1)
        writer.add_scalar(tag="train_loss", scalar_value=train_loss.item(), global_step=epoch + 1)
        writer.add_scalar(tag="train_lsd", scalar_value=train_lsd.item(), global_step=epoch + 1)
        writer.flush()

        # Performance of neural network on validation data set
        model.eval()  # Switch the model to evaluation mode
        # No gradient is required for testing
        with torch.no_grad():
            for val_batch_idx, (val_X, val_mask, val_nearend_mic_magnitude, val_nearend_magnitude) in enumerate(
                    val_loader):
                val_X = val_X.to(device)  # Far-end speech concatenated with the microphone signal [batch_size, 322, 999] (batch, F, T)
                val_mask = val_mask.to(device)  # IRM [batch_size, 161, 999]
                val_nearend_mic_magnitude = val_nearend_mic_magnitude.to(device)
                val_nearend_magnitude = val_nearend_magnitude.to(device)

                # Forward propagation
                val_pred_mask = model(val_X)
                val_loss = criterion(val_pred_mask, val_mask)

                # Near end voice signal spectrum = mask * microphone signal spectrum [batch_size, 161, 999]
                val_pred_near_spectrum = val_pred_mask * val_nearend_mic_magnitude
                val_lsd = pytorch_LSD(val_nearend_magnitude, val_pred_near_spectrum)

            # ###########    Visual printing   ############
            print('  val Epoch: {} \tLoss: {:.6f}\tlsd: {:.6f}'.format(epoch + 1, val_loss.item(), val_lsd.item()))
            ######################
            #  Update TensorBoard  #
            ######################
            writer.add_scalar(tag="val_loss", scalar_value=val_loss.item(), global_step=epoch + 1)
            writer.add_scalar(tag="val_lsd", scalar_value=val_lsd.item(), global_step=epoch + 1)
            writer.flush()

        # ###########    Save model (every 10 epochs)   ############
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch + 1,
                # 'lr_schedule': lr_schedule.state_dict()
            }
            torch.save(checkpoint, '%s/%d.pth' % (args.checkpoints_dir, epoch + 1))


if __name__ == "__main__":
    main()
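
Once training has produced a checkpoint, it can be reloaded for evaluation or inference. The snippet below is a minimal sketch (not part of the original script) of how a checkpoint saved by the loop above could be restored; the file path './checkpoints/AEC_baseline/20.pth' and the epoch number are assumptions, and Base_model comes from the project's own model package.

import torch
from model.Baseline import Base_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Base_model().to(device)

# The training loop saves a dict with "model", "optimizer" and "epoch" keys
checkpoint = torch.load("./checkpoints/AEC_baseline/20.pth", map_location=device)  # assumed path
model.load_state_dict(checkpoint["model"])
print("Checkpoint saved after epoch:", checkpoint["epoch"])

model.eval()  # switch to evaluation mode before running inference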

2.2 Training module demonstration II

Author: Shi Lang
 Link: https://www.zhihu.com/question/406133826/answer/1334319004
 Source: Zhihu
 The copyright belongs to the author. For commercial reprint, please contact the author for authorization. For non-commercial reprint, please indicate the source.

# Imports required by this snippet (the Net class, the imshow helper and the classes
# tuple come from the original tutorial and are assumed to be defined elsewhere)
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim

# Define network
net = Net()

# Define data
# Data preprocessing: 1. convert to tensor, 2. normalize
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Training set
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
# Validation set
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# Define loss function and optimizer 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Start training
net.train()
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # Set gradient to 0
        # zero the parameter gradients
        optimizer.zero_grad()
        # Ask for loss
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # Gradient back propagation
        loss.backward()
        # Update parameters by gradient
        optimizer.step()

        # visualization
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

# View the effect on the validation set
dataiter = iter(testloader)
images, labels = next(dataiter)  # use the built-in next(); the .next() method is not available in recent PyTorch versions

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

net.eval()
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))
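
To go beyond inspecting four images, the evaluation can be extended to the whole test set. The following sketch (an addition, not part of the quoted answer) accumulates correct predictions over testloader to report overall accuracy, the kind of metric one would track per epoch when tuning hyperparameters.

# Accuracy over the full test set (uses net and testloader defined above)
correct = 0
total = 0
net.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy on the test images: %.2f %%' % (100.0 * correct / total))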

3. Hyperparameter input in the training module

3.1 Hyperparameter definition

In the process of machine learning:
Hyperparameters = parameters whose values are set manually before training begins.
Model parameters = parameters whose values are obtained through training.

Usually, we need to optimize the hyperparameters, selecting an optimal set of them to improve the performance and effectiveness of learning.
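
As a small illustration (not from the original text) of this distinction: the learning rate is a hyperparameter chosen by hand before training, while a layer's weight and bias are model parameters that the optimizer updates during training.

import torch
from torch import nn

layer = nn.Linear(4, 2)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)  # lr is a hyperparameter, set by hand
print([name for name, _ in layer.named_parameters()])     # ['weight', 'bias'] are model parameters, learned from data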

3.2 Common hyperparameters in deep learning

A deep learning project has many configurable parameters, which generally fall into the following three categories:

  1. Dataset parameters (file path, batch_size, etc.)
  2. Training parameters (learning rate, training epoch, etc.)
  3. Model parameters (input size, output size)

These parameters can be kept in a class or a dictionary and then saved as JSON (see the sketch below).
You have to implement this handling yourself, but it is a minor matter; after writing it a few times you will find the style you prefer. It is not an essential part of a deep learning project.
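
A minimal sketch of this idea, grouping the three categories of parameters into a dictionary and persisting them as JSON; the file name and the concrete values (taken from demonstration I) are only illustrative assumptions.

import json

config = {
    "data": {"train_data": "./data_preparation/Synthetic/TRAIN", "batch_size": 16},  # dataset parameters
    "train": {"lr": 3e-4, "epochs": 20},                                             # training parameters
    "model": {"input_size": 322, "output_size": 161},                                # model parameters
}

# Save the configuration next to the code or the checkpoints
with open("config.json", "w") as f:
    json.dump(config, f, indent=4)

# Reload it later, e.g. when resuming an experiment
with open("config.json") as f:
    config = json.load(f)
print(config["train"]["lr"])  # 0.0003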

Topics: Algorithm Machine Learning Deep Learning