18. Hook functions and CAM visualization

Posted by ThunderVike on Thu, 06 Feb 2020 07:01:40 +0100

1, The concept of Hook function

1.1 Reasons for introducing hooks

PyTorch runs on a dynamic computation graph. After the dynamic graph has been executed, some intermediate variables (such as the gradients of feature maps and of other non-leaf nodes) are released, but sometimes we need to access these intermediate variables. In that case the hook mechanism lets us attach extra functions to the main computation in order to obtain or change those intermediate variables
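
A minimal illustration of the problem, using the same computation graph as the example in section 2.1:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)               # non-leaf intermediate result
y = torch.mul(a, torch.add(w, 1))
y.backward()

print(w.grad)   # tensor([5.])  -- gradients of leaf nodes are kept
print(a.grad)   # None          -- the non-leaf gradient has already been released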

1.2 Hook function mechanism

Hook function mechanism: it does not change the main body (forward and backward propagation) but attaches extra functionality to it, like hanging a pendant on a hook


The __call__() function in nn.Module is implemented with the hook mechanism. The whole __call__ function is divided into four parts:

  • forward_pre_hook
  • forward
  • forward_hook
  • backward_hook

The __call__() function first executes the forward_pre_hook functions, then the forward propagation, then the forward_hook functions, and finally attaches the backward_hook functions, which fire during back-propagation.
Therefore, the forward pass not only performs the forward computation but also provides hook interfaces for additional operations and functionality
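
To make the order concrete, below is a simplified sketch of how __call__() dispatches the hooks (not the actual torch source; the real implementation in torch/nn/modules/module.py handles more cases):

def simplified_call(module, *input):
    # 1. forward_pre_hooks can inspect (or replace) the input before forward
    for hook in module._forward_pre_hooks.values():
        result = hook(module, input)
        if result is not None:
            input = result if isinstance(result, tuple) else (result,)

    # 2. the actual forward pass
    output = module.forward(*input)

    # 3. forward_hooks see both the input and the output
    for hook in module._forward_hooks.values():
        hook(module, input, output)

    # 4. backward_hooks are attached to the autograd graph here; they only
    #    fire later, when backward() is called
    return output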

1.3 The four hook functions

They fall into three categories: hooks for a Tensor, hooks for forward propagation, and hooks for backward propagation

  1. torch.Tensor.register_hook(hook)
  2. torch.nn.Module.register_forward_hook
  3. torch.nn.Module.register_forward_pre_hook
  4. torch.nn.Module.register_backward_hook

2, Hook function and feature extraction

2.1 Tensor.register_hook

hook(grad)

Function: register a back-propagation hook for a tensor

The hook function takes a single argument, the gradient of the tensor, and returns either a new gradient Tensor (which replaces the original gradient) or None

Example: obtain and change the gradient of a non-leaf node through hook functions

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # Set random seed


# ----------------------------------- 1 tensor hook 1 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()        # list for storing the gradient of tensor a

    def grad_hook(grad):
        a_grad.append(grad)

    handle = a.register_hook(grad_hook) # Register the defined function to the corresponding tensor

    y.backward()

    # View gradients (a, b and y are non-leaf nodes, so their .grad is None)
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]: ", a_grad[0])
    handle.remove()


# ----------------------------------- 2 tensor hook 2 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):               # Define hook function to modify tensor gradient
        grad *= 2
        return grad*3                  # The gradient returned through return will overwrite the original gradient

    handle = w.register_hook(grad_hook)

    y.backward()

    # View gradient
    print("w.grad: ", w.grad)
    handle.remove()

Operation result: the first block prints gradient: tensor([5.]) tensor([2.]) None None None (the gradients of the non-leaf nodes a, b and y have already been released) and a_grad[0]: tensor([2.]); the second block prints w.grad: tensor([30.]), because the original gradient 5 is doubled in place and then overridden by the returned grad * 3

Explain:

  • A list is defined and the hook function appends the tensor's gradient to it, so the gradient is still available after backward has finished
  • If the hook function returns a gradient via return, the returned gradient overrides the original one

2.2 Module.register_forward_hook

hook(module, input, output)

Function: register the forward propagation hook function of the module

Return value: None

Parameters:

  • module: current network layer
  • input: input data of the current network layer
  • output: output data of the current network layer

Example: obtain the feature maps during forward propagation

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # Set random seed

# ----------------------------------- 3 Module.register_forward_hook and pre hook -----------------------------------
# flag = 0
flag = 1
if flag:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # Initialize network
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # Register hooks
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    # net.conv1.register_forward_pre_hook(forward_pre_hook)
    # net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)

    # loss_fnc = nn.L1Loss()
    # target = torch.randn_like(output)
    # loss = loss_fnc(target, output)
    # loss.backward()

    # observation
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

2.3 Module.register_forward_pre_hook

hook(module, input)

Function: register a hook that runs before the module's forward propagation

Return value: None

Parameters:

  • module: current network layer
  • Input: current network layer input data

2.4 Module.register_backward_hook

hook(module, grad_input, grad_output)

Function: register a back-propagation hook on the module

Return value: Tensor or None

Parameters:

  • module: current network layer
  • grad_input: gradient of the current network layer's input
  • grad_output: gradient of the current network layer's output

Example: use of forward_pre_hook and backward_hook

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # Set random seed


# ----------------------------------- 3 Module.register_forward_hook and pre hook -----------------------------------
# flag = 0
flag = 1
if flag:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # Initialize network
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # Register hooks
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()

2.5 Using hook functions to visualize feature maps

# -*- coding:utf-8 -*-
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools import set_seed
import torchvision.models as models

set_seed(1)  # Set random seed

# ----------------------------------- feature map visualization -----------------------------------
# flag = 0
flag = 1
if flag:
    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # data
    path_img = "./lena.png"     # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    if img_transforms is not None:
        img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)    # chw --> bchw

    # Model
    alexnet = models.alexnet(pretrained=True)

    # Register hooks
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():  # named_modules() returns every submodule of the network together with its name

        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())

            n1, n2 = name.split(".")

            def hook_func(m, i, o):
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():
        fmap = fmap_list[0]
        fmap.transpose_(0, 1)   # (1, C, H, W) -> (C, 1, H, W): treat each channel as a one-channel image

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)
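
After running the script, the feature-map grids can be viewed in TensorBoard, e.g. with tensorboard --logdir=runs (runs is the default log directory SummaryWriter uses when only a comment is given).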



3, CAM and Grad-CAM

3.1 CAM

CAM: Class activation map

Function: based on the network's output, analyze which parts of the image the network focuses on

Basic idea: compute a weighted sum over the channels of the network's final feature maps to obtain an attention map

Operation mechanism: apply global average pooling to the final feature maps so that each channel becomes a single value (one neuron per channel); these values are fed into a final fc layer that outputs the class scores, and the fc weights connecting these neurons to the predicted class are used as the weights of the corresponding feature maps
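
To make this concrete, here is a minimal sketch of the CAM computation, assuming a network that ends with global average pooling followed by a single fc layer (as in ResNet); compute_cam and its arguments are illustrative names, not code from the article:

def compute_cam(feature_map, fc_weight, class_idx):
    # feature_map: (C, H, W) torch tensor, output of the last conv layer for one image
    # fc_weight:   (num_classes, C) weight of the fc layer that follows global average pooling
    weights = fc_weight[class_idx]                            # (C,) fc weights of the target class
    cam = (weights[:, None, None] * feature_map).sum(dim=0)   # weighted sum over channels -> (H, W)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)                            # normalize to [0, 1] for display
    return cam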

Disadvantages: the network architecture has to be modified, because the final feature maps must be followed by global average pooling to obtain the weights; the later layers therefore often need to be changed and the network retrained

CAM: <Learning Deep Features for Discriminative Localization>

3.2 Grad-CAM

Grad-CAM: an improved version of CAM that uses gradients as the weights of the feature maps

Operation mechanism: take the class score from the output vector and run backward to compute the gradient of every pixel of every feature map (one gradient map per feature map); average each gradient map to obtain the weight of the corresponding feature map; compute the weighted sum of the feature maps with these weights; finally, pass the result through a ReLU activation to obtain the class activation map
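
A minimal Grad-CAM sketch built on the hooks from section 2; the choice of layer (alexnet.features[10], the last conv layer of torchvision's AlexNet), the random input standing in for a preprocessed image, and all variable names are our own assumptions:

import torch
import torch.nn.functional as F
import torchvision.models as models

alexnet = models.alexnet(pretrained=True)
alexnet.eval()

fmap_list, grad_list = list(), list()

def forward_hook(module, data_input, data_output):
    fmap_list.append(data_output)                                    # feature map A
    data_output.register_hook(lambda grad: grad_list.append(grad))   # gradient of A (tensor hook)

alexnet.features[10].register_forward_hook(forward_hook)   # last conv layer of AlexNet

img = torch.randn(1, 3, 224, 224)                 # stand-in for a preprocessed image
scores = alexnet(img)
scores[0, scores.argmax()].backward()             # back-propagate the top class score

fmap = fmap_list[0][0].detach()                   # (C, H, W)
grad = grad_list[0][0]                            # (C, H, W)
weights = grad.mean(dim=(1, 2))                   # average each gradient map -> (C,)
cam = F.relu((weights[:, None, None] * fmap).sum(dim=0))   # weighted sum + ReLU -> (H, W)
cam = F.interpolate(cam[None, None], size=(224, 224),
                    mode='bilinear', align_corners=False)[0, 0]   # upsample to input size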

Grad-CAM: <Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization>

Experimental example:

Analysis and code: https://xuanlan.zhihu.com/p/75894080

