Convolutional neural network UNET learning

Posted by blackmamba on Thu, 24 Feb 2022 15:52:48 +0100

Video link (Bilibili)

nn.Sequential

nn.Sequential is an ordered container: neural network modules are added to the computation graph in the order in which they are passed to the constructor. Alternatively, an ordered dictionary (OrderedDict) whose values are neural network modules can be passed as the constructor argument.
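
A minimal sketch of both construction styles (the layer sizes and submodule names here are arbitrary examples, not from the original post):

import torch.nn as nn
from collections import OrderedDict

# Modules run in the order they are passed in.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)

# Equivalent construction from an OrderedDict; each submodule also gets a name.
named_block = nn.Sequential(OrderedDict([
    ("conv", nn.Conv2d(3, 16, kernel_size=3, padding=1)),
    ("bn", nn.BatchNorm2d(16)),
    ("relu", nn.ReLU(inplace=True)),
]))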

nn.Sequential documentation link

Batchnorm

First, why do deep networks need batchnorm? In deep learning, especially in CV, we normalize the input data because a deep neural network mainly learns the distribution of the training data in order to generalize well to the test set. If the data in each input batch has a different distribution, training obviously becomes harder. On the other hand, the distribution of the data also changes after passing through each layer of the network; this phenomenon is called internal covariate shift, and it makes learning harder for the next layer. Batchnorm (batch normalization) was proposed to solve this problem of shifting distributions.
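
A rough sketch of what batchnorm computes for an (N, C, H, W) tensor; this is illustrative only, and the real nn.BatchNorm2d additionally tracks running statistics for inference and learns gamma and beta:

import torch

def batchnorm_sketch(x, gamma, beta, eps=1e-5):
    # Normalize each channel over the batch and spatial dimensions.
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_hat + beta  # learnable scale and shift

x = torch.randn(4, 3, 8, 8)
gamma = torch.ones(1, 3, 1, 1)
beta = torch.zeros(1, 3, 1, 1)
print(batchnorm_sketch(x, gamma, beta).shape)  # torch.Size([4, 3, 8, 8])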

Batchnorm reference link

Conv2d

in_channels: number of channels of the input.
out_channels: number of channels of the output.
kernel_size: size of the convolution kernel. If the parameter is an integer q, the kernel is q × q.
stride: step size, i.e. how far the kernel moves at each step of the convolution. The default is 1. The kernel slides over the input image from left to right and from top to bottom. If the parameter is an integer, that value is used in both the horizontal and vertical directions; if it is a tuple such as stride=(2, 1), the 2 is the step along the height (h) and the 1 is the step along the width (w).
padding: padding. The default is 0 (no padding).
dilation: dilation of the kernel. Normally each kernel element lines up with an adjacent input pixel, so a 3×3 kernel acts on a 3×3 region of the input; this corresponds to the default dilation = 1. Larger values insert gaps between the kernel elements, enlarging the receptive field without adding parameters.
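
A quick check of how these parameters affect the output shape (the channel counts and image size are arbitrary):

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3,
                 stride=2, padding=1, dilation=1)
x = torch.randn(1, 3, 64, 64)
print(conv(x).shape)  # torch.Size([1, 16, 32, 32]): stride 2 halves H and W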

nn.Conv2d documentation link

ModuleList

nn.ModuleList is a class that can hold any subclasses of nn.Module (such as nn.Conv2d, nn.Linear, etc.) in a list. It is used like Python's built-in list, with the usual operations such as extend and append. Unlike an ordinary list, however, modules added to an nn.ModuleList are registered with the whole network, and their parameters are automatically added to the network's parameters.
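
A small sketch of this registration behaviour (the layer sizes are arbitrary):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Parameters of modules held in an nn.ModuleList appear in
        # net.parameters() and move with net.to(device); a plain Python
        # list would not register them.
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = Net()
print(sum(p.numel() for p in net.parameters()))  # 3 * (10*10 + 10) = 330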

nn.ModuleList documentation link

MaxPool2d

kernel_size (int or tuple) - the window size for max pooling
stride (int or tuple, optional) - the step size of the window movement; the default is kernel_size
padding (int or tuple, optional) - the number of zeros implicitly padded on each edge of the input
dilation (int or tuple, optional) - controls the spacing between elements within the window
return_indices - if True, the indices of the maximum values are returned along with the outputs, which is useful for the later upsampling (unpooling) operation
ceil_mode - if True, use ceiling instead of the default floor when computing the output size
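
A short sketch of return_indices paired with nn.MaxUnpool2d for upsampling (the tensor sizes are arbitrary):

import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.randn(1, 1, 4, 4)
y, indices = pool(x)       # y: (1, 1, 2, 2), plus the positions of the maxima
x_up = unpool(y, indices)  # back to (1, 1, 4, 4); zeros except at the maxima
print(y.shape, x_up.shape)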

nn.MaxPool2d documentation link

tqdm

tqdm is a Python progress-bar package. In deep learning, displaying the training process as a progress bar makes the console output easier to follow. The common functions and usage of tqdm are described below.
According to the official documentation, the common parameters of tqdm are:
desc (str): prefix for the progress bar
mininterval (float): minimum progress display update interval [default: 0.1] seconds
maxinterval (float): maximum progress display update interval [default: 10] seconds; only takes effect with dynamic_miniters
miniters (int or float): minimum progress display update interval, in iterations; if set to 0 together with dynamic_miniters, tqdm automatically adjusts it to match mininterval
ascii (bool or str): if True, use ASCII characters for the bar; with the default False, Unicode is used
ncols (int): width of the entire output message
nrows (int): height of the progress bar
dynamic_ncols (bool): continuously adapt ncols and nrows to the environment
smoothing (float): exponential moving average smoothing factor for speed and ETA estimates
bar_format (str): specify a custom bar format string
position (int): the line offset at which to print the progress bar; allows managing multiple bars
colour (str): the colour of the progress bar
set_postfix: a method for appending extra stats (e.g. the current loss) to the end of the bar
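
A minimal usage sketch combining several of these parameters; the sleep call is just a stand-in for a real training step:

import time
from tqdm import tqdm

loop = tqdm(range(100), desc="Epoch 1", ncols=80, colour="green")
for step in loop:
    time.sleep(0.01)                         # stand-in for a training step
    loop.set_postfix(loss=1.0 / (step + 1))  # live stats at the end of the bar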

tqdm documentation link

model.py

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    # Two 3x3 convolutions, each followed by batchnorm and ReLU;
    # padding=1 keeps the spatial size unchanged.
    def __init__(self,in_channels,out_channels):
        super(DoubleConv,self).__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels,out_channels,3,1,1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self,x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self,in_channels=3,out_channels=1,features=[64,128,256,512],
    ):
        super(UNET,self).__init__()
        self.ups=nn.ModuleList()
        self.downs=nn.ModuleList()
        self.pool=nn.MaxPool2d(kernel_size=2,stride=2)


        #Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels,feature))
            in_channels=feature

        # UP part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2,feature,kernel_size=2,stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2,feature))

        self.bottleneck=DoubleConv(features[-1],features[-1]*2)
        self.final_conv=nn.Conv2d(features[0],out_channels,kernel_size=1)

    def forward(self,x):
        skip_connections=[]
        for down in self.downs:
            x=down(x)
            skip_connections.append(x)
            x=self.pool(x)

        x=self.bottleneck(x)
        skip_connections=skip_connections[::-1]
        for idx in range(0,len(self.ups),2):
            x=self.ups[idx](x)
            skip_connection=skip_connections[idx//2]
            if x.shape != skip_connection.shape:
                # If H/W were not divisible by 16, pooling floors the size;
                # resize the upsampled tensor to match the skip connection.
                x=TF.resize(x,size=skip_connection.shape[2:])

            concat_skip=torch.cat((skip_connection,x),dim=1)
            x=self.ups[idx+1](concat_skip)

        return self.final_conv(x)

def test():
    # Batch of 3 single-channel 512x384 images; the output mask must match.
    x=torch.randn((3,1,512,384))
    print(x.shape)
    model=UNET(in_channels=1,out_channels=1)
    preds=model(x)
    print(preds.shape)
    assert preds.shape==x.shape

if __name__ == '__main__':
    test()

train.py

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 15
NUM_WORKERS = 2
IMAGE_WIDTH = 512  # 1918 originally
IMAGE_HEIGHT = 384  # 1280 originally
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "mydata/data/train_images/"
TRAIN_MASK_DIR = "mydata/data/train_masks/"
VAL_IMG_DIR = "mydata/data/val_images/"
VAL_MASK_DIR = "mydata/data/val_masks/"

def train_fn(loader, model, optimizer, loss_fn, scaler):
    # One epoch of mixed-precision training.
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

def main():
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("mydata/my_checkpoint.pth.tar"), model)


    check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()
    max_score = 0
    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # check accuracy
        Dice_score=check_accuracy(val_loader, model, device=DEVICE)
        if Dice_score>max_score:
            max_score=Dice_score
            # save model
            checkpoint = {
                "state_dict": model.state_dict(),
                # "optimizer":optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)

            # print some examples to a folder
            save_predictions_as_imgs(
                val_loader, model, folder="mydata/saved_images/", device=DEVICE
            )


if __name__ == "__main__":
    main()

utils.py

import torch
import torchvision
from dataset import CarvanaDataset
from torch.utils.data import DataLoader

def save_checkpoint(state, filename="mydata/my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_ds = CarvanaDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = CarvanaDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )
    Dice_score=dice_score/len(loader)
    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {Dice_score}")
    model.train()
    return Dice_score

def save_predictions_as_imgs(
    loader, model, folder="mydata/saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

dataset.py

import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index])
        # mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask != 0.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

Tip: if you want to train this code yourself, please download the Kaggle dataset first.

Topics: Deep Learning CNN