SDMG-R model learning notes

Posted by bdata28 on Tue, 01 Mar 2022 23:44:58 +0100

This algorithm, from SenseTime, is used for key information extraction (KIE) and is integrated into the mmocr package, which must be used together with mmcv. As an aside, mmcv is built around hook-based programming, which is still quite painful to debug; I'll share the framework logic of mmcv when I have time.

Model structure

The whole structure can be divided into three modules: the dual modality fusion module, the graph reasoning module, and the classification module.

The input data of the model consists of the picture, the detected text coordinate regions, and the text content of each region, for example:

{"file_name": "xxxx.jpg", "height": 1191, "width": 1685, "annotations": [{"box": [566, 113, 1095, 113, 1095, 145, 566, 145], "text": "yyyy", "label": 0}, {"box": [1119, 130, 1472, 130, 1472, 147, 1119, 147], "text": "aaaaa", "label": 1}, {"box": [299, 146, 392, 146, 392, 170, 299, 170], "text": "cccc", "label": 2}, {"box": [1447, 187, 1545, 187, 1545, 201, 1447, 201], "text": "dddd", "label": 0},]}

The first is the dual modality fusion module. Visual features are extracted with a U-Net plus RoI pooling, and semantic features are extracted with a Bi-LSTM. The multimodal features are then fused via a Kronecker product and fed into the spatial multimodal reasoning model (the graph reasoning module) to extract the final node features. Finally, the classification module performs the multi-class prediction.

Dual modality fusion module
Detailed steps of visual feature extraction:

  • Resize the original picture to a fixed input size (512x512 in this paper);
  • Feed it into a U-Net, which serves as the visual feature extractor, and take the feature map of the last CNN layer;
  • Map each text region's coordinates (at the input resolution) onto that last feature map and extract the visual features of the corresponding text region via RoI pooling;

Corresponding code:
Location: mmocr\models\kie\extractors\sdmgr.py

    def extract_feat(self, img, gt_bboxes):
        if self.visual_modality:
            # Visual feature extraction
            x = super().extract_feat(img)[-1]
            feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
            return feats.view(feats.size(0), -1)
        return None
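
Here bbox2roi comes from mmdet: it concatenates the per-image box lists into a single [K, 5] tensor whose first column is the batch index, which is the format the RoI extractor expects. The max-pool then collapses each pooled feature map, and view(feats.size(0), -1) flattens the result into one visual feature vector per text box.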

U-Net: a network architecture for image segmentation.
Detailed explanation: see "Deep understanding of the deep learning segmentation network U-Net".
Corresponding code location: mmocr\models\common\backbones\unet.py

RoI Pooling: a pooling layer applied to RoIs. Its defining property is that the input feature regions may have arbitrary sizes while the output feature map has a fixed size.
Detailed explanation: see "ROI Pooling layer analysis".
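
To illustrate the fixed-output property, here is a small sketch using torchvision's roi_pool (purely for illustration; the extractor in the code above is built from mmdet's RoI extractor family instead):

import torch
from torchvision.ops import roi_pool

feat = torch.randn(1, 64, 128, 128)           # feature map from the backbone
boxes = [torch.tensor([[10., 20., 50., 40.],  # two RoIs of different sizes
                       [0., 0., 100., 30.]])]
out = roi_pool(feat, boxes, output_size=(7, 7), spatial_scale=1.0)
print(out.shape)  # torch.Size([2, 64, 7, 7]) -- fixed size regardless of RoI size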

Detailed steps of text semantic feature extraction:

  • First, a character table is collected. The paper uses a dictionary of 91 characters, covering digits (0-9), letters (a-z, A-Z), and special characters relevant to the task (such as "/", "\", ".", "$", "¥", ":", "-", "*", "#", etc.); characters not in the table are uniformly marked as "unknown";
  • Each character of the text is then mapped through this table to a 32-dimensional embedding (a one-hot index fed into an embedding layer);
  • The embedded sequence is fed into a Bi-LSTM to extract 256-dimensional semantic features; a minimal sketch of these encoding steps follows below;
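
A minimal sketch of these two steps (the character table here is a toy stand-in for the real 91-character dictionary, and the code mirrors, but is not, the mmocr implementation):

import torch
import torch.nn as nn

# toy character table; index 0 is reserved for padding
char_dict = {c: i + 1 for i, c in enumerate('0123456789abcdef/$:-*#')}
UNK = len(char_dict) + 1  # characters outside the table map to "unknown"

def encode(text, max_len):
    ids = [char_dict.get(c, UNK) for c in text][:max_len]
    return ids + [0] * (max_len - len(ids))  # right-pad with 0

embed = nn.Embedding(UNK + 1, 32, padding_idx=0)  # 32-dim character embedding
rnn = nn.LSTM(input_size=32, hidden_size=256, num_layers=1, batch_first=True)

ids = torch.tensor([encode('12a/#', 8)])  # [1, 8]
out, _ = rnn(embed(ids))                  # [1, 8, 256] per-character features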

Corresponding code:
Location: mmocr\models\kie\heads\sdmgr_head.py

    def forward(self, relations, texts, x=None):
        node_nums, char_nums = [], []
        for text in texts:
            node_nums.append(text.size(0))
            char_nums.append((text > 0).sum(-1))

        max_num = max([char_num.max() for char_num in char_nums])
        all_nodes = torch.cat([
            torch.cat(
                [text,
                 text.new_zeros(text.size(0), max_num - text.size(1))], -1)
            for text in texts
        ])
        embed_nodes = self.node_embed(all_nodes.clamp(min=0).long())
        rnn_nodes, _ = self.rnn(embed_nodes)

        nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2])
        all_nums = torch.cat(char_nums)
        valid = all_nums > 0
        nodes[valid] = rnn_nodes[valid].gather(
            1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
                -1, -1, rnn_nodes.size(-1))).squeeze(1)
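
The gather call at the end deserves a note: for each non-empty text it selects the LSTM output at the position of the last real (non-padding) character, yielding one 256-dimensional vector per node. A toy illustration of the same indexing trick:

import torch

rnn_out = torch.arange(24.).view(2, 3, 4)  # [num_texts, seq_len, feat]
lengths = torch.tensor([2, 3])             # non-padding lengths of the texts
idx = (lengths - 1).view(-1, 1, 1).expand(-1, -1, rnn_out.size(-1))
last = rnn_out.gather(1, idx).squeeze(1)   # picks timestep 1 and timestep 2
print(last)  # rows rnn_out[0, 1] and rnn_out[1, 2]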

Visual + text semantic feature fusion steps:
The multimodal features are fused via a Kronecker product; see the fusion formula in the paper.

Corresponding code:

# Block is a custom class in the code; it appears to implement the
# Kronecker-product-style fusion
self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)

# Fuse the image features and the text features
if x is not None:
    nodes = self.fusion([x, nodes])
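
To give intuition, a stripped-down sketch of Kronecker-product fusion (my own illustration; the actual Block class is a more elaborate bilinear fusion module):

import torch
import torch.nn as nn

class KroneckerFusion(nn.Module):
    """Toy fusion: outer product of the two modality vectors,
    flattened, then projected back to the node dimension."""

    def __init__(self, visual_dim=64, text_dim=256, out_dim=256):
        super().__init__()
        self.proj = nn.Linear(visual_dim * text_dim, out_dim)

    def forward(self, v, t):
        # v: [N, visual_dim], t: [N, text_dim]
        outer = v.unsqueeze(-1) * t.unsqueeze(-2)  # [N, visual_dim, text_dim]
        return self.proj(outer.flatten(1))         # [N, out_dim]

fused = KroneckerFusion()(torch.randn(5, 64), torch.randn(5, 256))  # [5, 256]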

Graph reasoning module
The paper treats the document image as a graph and produces the final node features through a multimodal graph reasoning model; see Eqs. (10)-(14) in the paper for the exact update formulas.
The relation between every two nodes is encoded as a 5-dimensional feature. Reading it off the code below (with the normalization constant $d = 10$), the relation between boxes $i$ and $j$ is

$r_{ij} = \left(\frac{x_j - x_i}{d},\ \frac{y_j - y_i}{d},\ \frac{w_i}{h_i},\ \frac{h_j}{h_i},\ \frac{w_j}{h_i}\right)$

which corresponds to Eqs. (7)-(9) in the paper. The source code:

# boxes holds all the text boxes of one document, with shape [num_boxes, 8];
# the 8 values are the four corner points (x, y) of the box, ordered
# top-left, top-right, bottom-right, bottom-left

def compute_relation(boxes, norm: float = 10.):
    """Compute relation between every two boxes."""
    # Get minimal axis-aligned bounding boxes for each of the boxes
    # yapf: disable
    bboxes = np.concatenate(
        [boxes[:, 0::2].min(axis=1, keepdims=True),
         boxes[:, 1::2].min(axis=1, keepdims=True),
         boxes[:, 0::2].max(axis=1, keepdims=True),
         boxes[:, 1::2].max(axis=1, keepdims=True)],
        axis=1).astype(np.float32)
    # yapf: enable
    x1, y1 = boxes[:, 0:1], boxes[:, 1:2]
    x2, y2 = boxes[:, 4:5], boxes[:, 5:6]
    w, h = np.maximum(x2 - x1 + 1, 1), np.maximum(y2 - y1 + 1, 1)
    dx = (x1.T - x1) / norm
    dy = (y1.T - y1) / norm
    xhh, xwh = h.T / h, w.T / h
    whs = w / h + np.zeros_like(xhh)
    relation = np.stack([dx, dy, whs, xhh, xwh], -1).astype(np.float32)
    # bboxes = np.concatenate([x1, y1, x2, y2], -1).astype(np.float32)
    return relation, bboxes
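
A quick check of the shapes (illustrative):

import numpy as np

boxes = np.array([[566, 113, 1095, 113, 1095, 145, 566, 145],
                  [299, 146, 392, 146, 392, 170, 299, 170]], dtype=np.float32)
relation, bboxes = compute_relation(boxes)
print(relation.shape)  # (2, 2, 5) -- one 5-d feature per ordered box pair
print(bboxes.shape)    # (2, 4)   -- axis-aligned (x1, y1, x2, y2) per box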

The information between text nodes is then embedded into the edge weights; the source code for this part lives mainly in the GNNLayer class.

As the paper describes, the node features are optimized gradually in an iterative way; see Eqs. (13)-(14) in the paper for details:

        # Graph reasoning module

        # Equation 10
        all_edges = torch.cat(
            [rel.view(-1, rel.size(-1)) for rel in relations])
        embed_edges = self.edge_embed(all_edges.float())
        embed_edges = F.normalize(embed_edges)

        for gnn_layer in self.gnn_layers:
            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)


class GNNLayer(nn.Module):

    def __init__(self, node_dim=256, edge_dim=256):
        super().__init__()
        self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
        self.coef_fc = nn.Linear(node_dim, 1)
        self.out_fc = nn.Linear(node_dim, node_dim)
        self.relu = nn.ReLU()

    def forward(self, nodes, edges, nums):
        start, cat_nodes = 0, []
        for num in nums:
            sample_nodes = nodes[start:start + num]
            cat_nodes.append(
                torch.cat([
                    sample_nodes.unsqueeze(1).expand(-1, num, -1),
                    sample_nodes.unsqueeze(0).expand(num, -1, -1)
                ], -1).view(num**2, -1))
            start += num
        # Formula 11
        cat_nodes = torch.cat([torch.cat(cat_nodes), edges], -1)
        # Equations 12-13
        cat_nodes = self.relu(self.in_fc(cat_nodes))
        coefs = self.coef_fc(cat_nodes)

        # Equation 14
        start, residuals = 0, []
        for num in nums:
            residual = F.softmax(
                -torch.eye(num).to(coefs.device).unsqueeze(-1) * 1e9 +
                coefs[start:start + num**2].view(num, num, -1), 1)
            residuals.append(
                (residual *
                 cat_nodes[start:start + num**2].view(num, num, -1)).sum(1))
            start += num**2

        nodes += self.relu(self.out_fc(torch.cat(residuals)))
        return nodes, cat_nodes
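
Design-wise, coef_fc produces one attention coefficient per ordered node pair; the softmax over dimension 1 normalizes over incoming neighbours, with the -torch.eye(num) * 1e9 term masking out each node's pairing with itself; the aggregated messages are then added back to the nodes as a residual. A quick smoke test of the shapes (assuming torch, nn and F are imported as in the surrounding module):

import torch

layer = GNNLayer(node_dim=256, edge_dim=256)
nodes = torch.randn(3, 256)        # one document with 3 text boxes
edges = torch.randn(3 * 3, 256)    # embedded relation feature per ordered pair
out_nodes, cat_nodes = layer(nodes, edges, [3])
print(out_nodes.shape)  # torch.Size([3, 256])
print(cat_nodes.shape)  # torch.Size([9, 256])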

Multi-classification module
This part consists of two Linear layers, one for nodes and one for edges:

        self.node_cls = nn.Linear(node_embed, num_classes)
        self.edge_cls = nn.Linear(edge_embed, 2)

        # node_cls shape: [node_num, num_classes]; edge_cls shape: [node_num^2, 2]
        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)

Source code:

class SDMGRHead(BaseModule):

    def __init__(self,
                 num_chars=92,
                 visual_dim=64,
                 fusion_dim=1024,
                 node_input=32,
                 node_embed=256,
                 edge_input=5,
                 edge_embed=256,
                 num_gnn=2,
                 num_classes=26,
                 loss=dict(type='SDMGRLoss'),
                 bidirectional=False,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     override=dict(name='edge_embed'),
                     mean=0,
                     std=0.01)):
        super().__init__(init_cfg=init_cfg)
        # Text and visual information fusion module
        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
        self.node_embed = nn.Embedding(num_chars, node_input, 0)
        hidden = node_embed // 2 if bidirectional else node_embed

        # Single-layer LSTM
        self.rnn = nn.LSTM(
            input_size=node_input,
            hidden_size=hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional)
        # Graph reasoning module
        self.edge_embed = nn.Linear(edge_input, edge_embed)
        self.gnn_layers = nn.ModuleList(
            [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
        # Classification module
        self.node_cls = nn.Linear(node_embed, num_classes)
        self.edge_cls = nn.Linear(edge_embed, 2)
        self.loss = build_loss(loss)

    def forward(self, relations, texts, x=None):
        # relations: the pairwise relation features between nodes; shape
        #   [batch, num_boxes, num_boxes, 5], where the 5 values correspond
        #   to Eqs. (7)-(9) above
        # texts: the text content; shape
        #   [batch, num_boxes, max_chars_per_box]
        # x: the visual (image) features
        node_nums, char_nums = [], []
        for text in texts:
            node_nums.append(text.size(0))
            char_nums.append((text > 0).sum(-1))

        # Length of the longest text in this batch
        max_num = max([char_num.max() for char_num in char_nums])

        # Pad every text to that length
        all_nodes = torch.cat([
            torch.cat(
                [text,
                 text.new_zeros(text.size(0), max_num - text.size(1))], -1)
            for text in texts
        ])

        # Embed the encoded text and run it through the LSTM
        embed_nodes = self.node_embed(all_nodes.clamp(min=0).long())
        rnn_nodes, _ = self.rnn(embed_nodes)

        # Keep the LSTM output at the last non-padding character of each text
        nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2])
        all_nums = torch.cat(char_nums)
        valid = all_nums > 0
        nodes[valid] = rnn_nodes[valid].gather(
            1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
                -1, -1, rnn_nodes.size(-1))).squeeze(1)

        # Fuse the visual features and the text features
        if x is not None:
            nodes = self.fusion([x, nodes])

        # Graph reasoning module: embed the edge relations derived from the
        # spatial layout of every pair of text boxes (an important signal)
        all_edges = torch.cat(
            [rel.view(-1, rel.size(-1)) for rel in relations])
        embed_edges = self.edge_embed(all_edges.float())
        embed_edges = F.normalize(embed_edges)

        # Although the input is a batch, the per-document results are
        # concatenated in the output:
        #   nodes.shape     = [sum(batch_box_num), 256]
        #   cat_nodes.shape = [sum(batch_box_num^2), 256]
        for gnn_layer in self.gnn_layers:
            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)

        # Multi-classification module
        #   node_cls.shape = [sum(batch_box_num), num_classes]
        #   edge_cls.shape = [sum(batch_box_num^2), 2]
        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
        return node_cls, edge_cls
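
To make the output shapes concrete: for a batch of two documents with 3 and 5 text boxes, node_cls comes out as [8, num_classes] (3 + 5 nodes) and edge_cls as [34, 2] (3^2 + 5^2 = 34 ordered node pairs).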

Model application

  • It is applicable to documents with a relatively fixed layout
  • The visual module can be switched off simply by declaring it in the configuration file, as below

model = dict(
    type='SDMGR',
    backbone=dict(type='UNet', base_channels=16),
    bbox_head=dict(
        # num_chars is the size of the character dictionary
        type='SDMGRHead', visual_dim=16, num_chars=123, num_classes=23, num_gnn=4),
    visual_modality=False,  # this parameter controls whether the visual module is used
    train_cfg=None,
    test_cfg=None,
    class_list=f'{data_root}/../class_list.txt')
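
When visual_modality is False, extract_feat (shown earlier) returns None, so the `if x is not None` fusion branch in SDMGRHead.forward is skipped and the nodes keep their purely textual features.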

Self Q&A
  • How is the edge relationship in the dataset initialized?
    As a [box_num, box_num] matrix: the diagonal (a node paired with itself) is -1, and an off-diagonal entry is 1 when the two boxes carry the same edge label, 0 otherwise. Source code location: the list_to_numpy function in mmocr\datasets\kie_dataset.py
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    edges = (edges & labels == 1).astype(np.int32)
                np.fill_diagonal(edges, -1)
                labels = np.concatenate([labels, edges], -1)
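
For instance, with edge labels [1, 1, 2] attached to three boxes, the initialization yields:

import numpy as np

edges = np.array([1, 1, 2])
edges = (edges[:, None] == edges[None, :]).astype(np.int32)
np.fill_diagonal(edges, -1)
# edges is now:
# [[-1  1  0]
#  [ 1 -1  0]
#  [ 0  0 -1]]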
  • What is edge_pred used for?
    It enters the loss: cross entropy is computed between it and the initialized edge gold labels, returning loss_edge and acc_edge. Note that acc_edge excludes the self-to-self relation from its calculation, so acc_edge is always 100%; loss_edge does take the self relation into account, so it carries real signal, although its value is very small relative to loss_node. The final loss is loss_node + loss_edge. The relevant code is as follows:

Location: mmocr\models\kie\losses\sdmgr_loss.py

    # Breakdown of the loss computation
    def forward(self, node_preds, edge_preds, gts):
        node_gts, edge_gts = [], []
        for gt in gts:
            node_gts.append(gt[:, 0])
            edge_gts.append(gt[:, 1:].contiguous().view(-1))
        node_gts = torch.cat(node_gts).long()
        edge_gts = torch.cat(edge_gts).long()

        node_valids = torch.nonzero(
            node_gts != self.ignore, as_tuple=False).view(-1)
        edge_valids = torch.nonzero(edge_gts != -1, as_tuple=False).view(-1)
        return dict(
            loss_node=self.node_weight * self.loss_node(node_preds, node_gts),
            loss_edge=self.edge_weight * self.loss_edge(edge_preds, edge_gts),
            acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
            acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))

Location: the _parse_losses function in mmdet\models\detectors\base.py

    def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
                which may be a weighted sum of all losses, log_vars contains \
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')
        # Sum up every entry whose key contains 'loss'
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

Topics: AI Deep Learning NLP