Hengyuan Cloud (GpuShare) | Medical image segmentation: MT-UNet

Posted by Pyro4816 on Wed, 09 Mar 2022 12:04:00 +0100

Our community has new technology-sharing partners 🎉🎉🎉
A warm welcome 👏
As a dedicated reposter, I have to celebrate the only way I know how: repost, repost, repost right away~

Article source | Hengyuan Cloud community

Original address | New Mixed Transformer Module (MTM)

Original author | Dong Dong


Existing problems: Although U-Net has achieved great success in medical image segmentation, it cannot explicitly model long-range dependencies. In recent years, Vision Transformers have become an alternative segmentation architecture. Their self-attention (SA) inherently captures long-range correlations.
However, Transformers usually rely on large-scale pre-training and have high computational complexity. In addition, SA only models self-affinities within a single sample and ignores potential correlations across the whole dataset.
Thesis method: The paper proposes a new Mixed Transformer Module (MTM) that learns intra-affinities and inter-affinities simultaneously. MTM first computes the affinities within windows efficiently with Local-Global Gaussian-Weighted Self-Attention (LGG-SA). It then mines the relationships between data samples with External Attention (EA). Using MTM, the authors build MT-UNet, a model for medical image segmentation.


As shown in Figure 1, the network uses an encoder-decoder structure.

  1. To reduce computational cost, MTMs are used only in the deep layers, where the spatial size is small.
  2. Classical convolutions are still used in the shallow layers. Shallow layers mainly capture local information and contain more high-resolution detail.
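A quick count shows why MTMs are reserved for the deep, low-resolution stages. Full self-attention computes one score per pair of tokens, so its cost grows with the fourth power of the side length. The resolutions below are hypothetical (a 224 × 224 input downsampled by 4, 8, 16 and 32) and are not taken from the MT-UNet configuration:

```python
# Illustrative only: the stage resolutions are hypothetical
# (a 224x224 input downsampled by 4, 8, 16 and 32), not MT-UNet's config.
def attention_scores(side: int) -> int:
    """Pairwise scores full self-attention computes on a side x side map."""
    tokens = side * side
    return tokens * tokens


for side in (56, 28, 14, 7):
    print(f"{side}x{side}: {side * side} tokens, {attention_scores(side):,} scores")
```

Halving the side length cuts the score count by a factor of 16. This is why spending attention only on the small deep feature maps keeps the cost manageable.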


As shown in Figure 2, the MTM consists mainly of LGG-SA and EA.

LGG-SA models short- and long-range dependencies at different granularities, while EA mines the correlations between samples.

The MTM replaces the original Transformer encoder. This improves performance on vision tasks while reducing time complexity.

LGG-SA (Local-Global Gaussian-Weighted Self-Attention)

Standard SA treats all tokens alike and has no built-in preference for nearby ones. LGG-SA instead focuses more on adjacent regions by combining local-global self-attention with a Gaussian mask. According to the authors' experiments, this both improves model performance and saves computation. The detailed design of the module is shown in Figure 3.
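To see why a Gaussian mask makes the attention favour nearby tokens, here is a small NumPy sketch on a 1-D row of seven tokens. The penalty form -(w·d² + b) mirrors the GaussianTrans code later in this post. The values w = 0.5 and b = 0 and the function name are arbitrary choices for illustration:

```python
import numpy as np


def gaussian_weighted_softmax(logits, positions, center, w=0.5, b=0.0):
    """Softmax over logits after subtracting a distance penalty w * d**2 + b.

    Adding -(w * d**2 + b) to a logit multiplies its unnormalized weight by
    exp(-w * d**2) * exp(-b); the constant exp(-b) cancels in the softmax.
    """
    d2 = (positions - center) ** 2
    z = logits - (w * d2 + b)
    z = z - z.max()  # numerical stability
    p = np.exp(z)
    return p / p.sum()


positions = np.arange(7)
logits = np.zeros(7)  # equal scores: plain SA would attend uniformly
weights = gaussian_weighted_softmax(logits, positions, center=3)
print(np.round(weights, 3))
```

With equal raw scores, the mask alone concentrates the weight on the centre token and its immediate neighbours.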

Local-global self-attention

In computer vision, correlations between adjacent regions are often more important than those between distant regions. So, when computing the attention map, there is no need to spend as much computation on distant regions as on nearby ones.

Local-global self-attention is therefore proposed. It works in three steps:

  1. In stage 1 of the figure above, each local window contains four tokens, and local SA computes the intra-affinities within each window.
  2. The tokens in each window are aggregated into a single global token that represents the window's main information. Among the aggregation functions tried, lightweight dynamic convolution (LDConv) performs best.
  3. Once the whole down-sampled feature map is obtained, global SA can be run at much lower cost (stage 2 of the figure above).
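The three steps above can be sketched in NumPy as follows. This is a simplification, not the authors' module. It assumes single-head SA with identity projections (q = k = v), mean pooling as a stand-in for LDConv, no Gaussian mask, and H and W divisible by the window size. The function names are hypothetical:

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(tokens):
    """Single-head SA with identity projections; tokens: (..., n, c)."""
    c = tokens.shape[-1]
    scores = tokens @ np.swapaxes(tokens, -1, -2) / np.sqrt(c)  # (..., n, n)
    return softmax(scores) @ tokens


def local_global_attention(x, win):
    """x: (H, W, C) with H and W divisible by win."""
    H, W, C = x.shape
    ph, pw = H // win, W // win
    # Stage 1: cut into windows -> (ph, pw, win*win, C), attend inside each window
    windows = x.reshape(ph, win, pw, win, C).transpose(0, 2, 1, 3, 4)
    windows = windows.reshape(ph, pw, win * win, C)
    local = self_attention(windows)
    # Aggregate each window into one global token (mean as a stand-in for LDConv)
    global_tokens = local.mean(axis=2)  # (ph, pw, C)
    # Stage 2: SA over the ph*pw global tokens instead of all H*W tokens
    g = self_attention(global_tokens.reshape(ph * pw, C)).reshape(ph, pw, C)
    return local, g


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 16))
local, g = local_global_attention(x, win=2)
print(local.shape, g.shape)  # (4, 4, 4, 16) (4, 4, 16)
```

For this 8 × 8 map, the local stage computes 16 windows × 4² = 256 scores and the global stage 16² = 256 scores. Full SA would compute 64² = 4096.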

The input feature map is denoted \(X \in \mathbb{R}^{H \times W \times C}\).
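The local-global data flow can be summarized as below. This is read off the CSAttention code later in this post, in our own notation, so the paper's own equations may be written differently:

- s is the window side length.
- LSA is the window attention and Agg is the LDConv aggregation.
- GWAA is the Gaussian-weighted axial attention.
- Up_s is bilinear upsampling by s, and [·, ·] is channel concatenation.
- H and W are assumed to be multiples of s.

```latex
\begin{aligned}
X_{\text{local}}  &= \mathrm{LSA}(X) \in \mathbb{R}^{H \times W \times C},\\
X_{\text{agg}}    &= \mathrm{Agg}(X_{\text{local}}) \in \mathbb{R}^{\frac{H}{s} \times \frac{W}{s} \times C},\\
X_{\text{global}} &= \mathrm{GWAA}(X_{\text{agg}}),\\
Y &= \mathrm{Conv}_{1 \times 1}\big(\big[\mathrm{Up}_{s}(X_{\text{global}}),\; X_{\text{local}}\big]\big) \in \mathbb{R}^{H \times W \times C}.
\end{aligned}
```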

The code for the local-window self-attention in stage 1 is as follows:

# Imports used by all the snippets in this post
import math

import numpy as np
import torch
import torch.nn as nn


class WinAttention(nn.Module):
    """Self-attention inside non-overlapping local windows (stage 1 of LGG-SA)."""

    def __init__(self, configs, dim):
        super(WinAttention, self).__init__()
        self.window_size = configs["win_size"]
        self.attention = Attention(dim, configs)

    def forward(self, x):  # x: (b, n, c) with n = h * w and h == w
        b, n, c = x.shape
        h, w = int(np.sqrt(n)), int(np.sqrt(n))
        x = x.permute(0, 2, 1).contiguous().view(b, c, h, w)
        if h % self.window_size != 0:
            # Pad up to a multiple of the window size; the padded bottom-right
            # corner copies x's bottom-right patch, the other padded strips stay zero
            right_size = h + self.window_size - h % self.window_size
            new_x = x.new_zeros((b, c, right_size, right_size))
            new_x[:, :, 0:x.shape[2], 0:x.shape[3]] = x[:]
            new_x[:, :, x.shape[2]:,
                  x.shape[3]:] = x[:, :, (x.shape[2] - right_size):,
                                   (x.shape[3] - right_size):]
            x = new_x
            b, c, h, w = x.shape
        # Split into (h / win) x (w / win) windows of win x win tokens each
        x = x.view(b, c, h // self.window_size, self.window_size,
                   w // self.window_size, self.window_size)
        x = x.permute(0, 2, 4, 3, 5,
                      1).contiguous().view(b, h // self.window_size,
                                           w // self.window_size,
                                           self.window_size * self.window_size,
                                           c)  # (b, p, p, win, c)
        x = self.attention(x)  # (b, p, p, win, c): self-attention of the tokens within each local window
        return x
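The view/permute chain that cuts the map into windows is easy to get wrong, and so is its inverse in CSAttention further down. This NumPy sketch reproduces both index orders and checks that they round-trip. The helper names `window_partition` and `window_merge` are hypothetical, and NumPy `reshape`/`transpose` stand in for PyTorch `view`/`permute`:

```python
import numpy as np


def window_partition(x, ws):
    """(b, c, h, w) -> (b, h//ws, w//ws, ws*ws, c), same order as WinAttention."""
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // ws, ws, w // ws, ws)
    return x.transpose(0, 2, 4, 3, 5, 1).reshape(b, h // ws, w // ws, ws * ws, c)


def window_merge(x):
    """Inverse used in CSAttention: (b, p, p, ws*ws, c) -> (b, c, p*ws, p*ws)."""
    b, p, _, win, c = x.shape
    ws = int(np.sqrt(win))
    x = x.reshape(b, p, p, ws, ws, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, p * ws, p * ws, c).transpose(0, 3, 1, 2)


x = np.arange(2 * 3 * 8 * 8).reshape(2, 3, 8, 8)
windows = window_partition(x, ws=4)
print(windows.shape)  # (2, 2, 2, 16, 3)
assert np.array_equal(window_merge(windows), x)
```

The first window's tokens are the top-left 4 × 4 patch in row-major order, which is what the local attention treats as one window.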

The code for the aggregation function (lightweight dynamic convolution) is as follows:

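class DlightConv(nn.Module):
    """Lightweight dynamic convolution (LDConv): merges the tokens of each
    window into one global token using input-dependent weights."""

    def __init__(self, dim, configs):
        super(DlightConv, self).__init__()
        # Predicts one weight per token of a win_size x win_size window
        self.linear = nn.Linear(dim, configs["win_size"] * configs["win_size"])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):  # (b, p, p, win, c)
        h = x
        avg_x = torch.mean(x, dim=-2)  # (b, p, p, c): average token of each window
        x_prob = self.softmax(self.linear(avg_x))  # (b, p, p, win): weights sum to 1

        x = torch.mul(h,
                      x_prob.unsqueeze(-1))  # (b, p, p, win, c)
        x = torch.sum(x, dim=-2)  # (b, p, p, c): weighted sum over the window
        return x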
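A NumPy mirror of `DlightConv.forward` makes the "dynamic" part explicit. The per-token weights are predicted from the window's average token, so they change with the input. Here `weight` and `bias` play the role of the transposed `nn.Linear` weight and its bias. The function name and the random values are illustrative:

```python
import numpy as np


def dlight_aggregate(x, weight, bias):
    """x: (b, p, p, win, c); weight: (c, win); bias: (win,) -> (b, p, p, c)."""
    avg = x.mean(axis=-2)                       # (b, p, p, c): window summary
    logits = avg @ weight + bias                # (b, p, p, win): one score per token
    logits = logits - logits.max(axis=-1, keepdims=True)
    prob = np.exp(logits)
    prob = prob / prob.sum(axis=-1, keepdims=True)  # softmax over the window
    return (x * prob[..., None]).sum(axis=-2)   # convex combination of the tokens


rng = np.random.default_rng(0)
b, p, win, c = 1, 2, 4, 8
weight = rng.standard_normal((c, win))
bias = rng.standard_normal(win)
x = rng.standard_normal((b, p, p, win, c))
print(dlight_aggregate(x, weight, bias).shape)  # (1, 2, 2, 8)
```

Because the softmax weights sum to one, the aggregated token is a convex combination of the window's tokens. A window of identical tokens therefore aggregates to that same token.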

Gaussian-Weighted Axial Attention

Local SA (LSA) uses the original SA. The global stage instead uses the proposed Gaussian-Weighted Axial Attention (GWAA). GWAA increases the attention weight given to adjacent regions through a learnable Gaussian matrix. Because it is axial attention, it also reduces time complexity.

  1. In the figure above, the feature at row 3, column 3 of the stage-2 feature map is linearly projected to obtain the query \(q_{i,j}\).
  2. All features in the same row and column as that point are linearly projected to obtain the keys \(K_{i,j}\) and values \(V_{i,j}\).
  3. The Euclidean distances between the point and the positions of all entries of \(K_{i,j}\) and \(V_{i,j}\) are collected in \(D_{i,j}\).
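The steps above, plus the Gaussian weighting, can be sketched for a single position (i, j) in NumPy. The names q, K, V and D follow the text. The projection matrices are random, w and b are arbitrary, and there is no 1/√d scaling, matching the axial branch of the Attention code below. As in the GaussianTrans code, the row and the column get separate softmaxes whose outputs are summed. The function name is hypothetical:

```python
import numpy as np


def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def gwaa_at(feat, i, j, Wq, Wk, Wv, w=0.5, b=0.0):
    """GWAA output at (i, j); feat: (H, W, C), Wq/Wk/Wv: (C, C)."""
    q = feat[i, j] @ Wq                          # q_{i,j}: (C,)
    row, col = feat[i, :, :], feat[:, j, :]      # features in row i and column j
    K_row, V_row = row @ Wk, row @ Wv            # (W, C)
    K_col, V_col = col @ Wk, col @ Wv            # (H, C)
    d_row = (np.arange(feat.shape[1]) - j) ** 2  # squared distances along the row
    d_col = (np.arange(feat.shape[0]) - i) ** 2  # squared distances along the column
    a_row = softmax(K_row @ q - (w * d_row + b))
    a_col = softmax(K_col @ q - (w * d_col + b))
    return a_row @ V_row + a_col @ V_col


rng = np.random.default_rng(0)
feat = rng.standard_normal((5, 5, 8))
Wq, Wk, Wv = (0.1 * rng.standard_normal((8, 8)) for _ in range(3))
out = gwaa_at(feat, 2, 2, Wq, Wk, Wv)
print(out.shape)  # (8,)
```

As w grows, the Gaussian shrinks. In the limit each axis attends only to (i, j) itself, so the output approaches twice that position's value vector.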

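The paper then gives the final Gaussian-weighted axial attention output and a simplified form of it.

In the implementation, the Gaussian enters as a learnable penalty on the squared distance, w·d² + b. This penalty is subtracted from the row and column attention logits before the softmax, as the GaussianTrans code below shows.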
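Reading the GaussianTrans code below, the output it computes at position (i, j) can be written as follows. This uses our notation, so the paper's own equations may be written differently:

- K, V and D are split into their row and column parts.
- The square of D is element-wise.
- w and b are the learnable scalars `self.shift` and `self.bias`.

The second identity shows why this is a Gaussian weighting. Each SA weight is multiplied by exp(-w d²) and renormalized, which is a Gaussian of the distance with w = 1/(2σ²). The constant b cancels in the softmax.

```latex
\hat{z}_{i,j}
  = \operatorname{softmax}\!\Big(q_{i,j}\big(K^{\mathrm{row}}_{i,j}\big)^{\top} - w\big(D^{\mathrm{row}}_{i,j}\big)^{2} - b\Big) V^{\mathrm{row}}_{i,j}
  + \operatorname{softmax}\!\Big(q_{i,j}\big(K^{\mathrm{col}}_{i,j}\big)^{\top} - w\big(D^{\mathrm{col}}_{i,j}\big)^{2} - b\Big) V^{\mathrm{col}}_{i,j}

\frac{\exp\big(q_{i,j} k_m^{\top} - w\,d_m^{2} - b\big)}{\sum_n \exp\big(q_{i,j} k_n^{\top} - w\,d_n^{2} - b\big)}
  = \frac{e^{-w d_m^{2}}\, e^{q_{i,j} k_m^{\top}}}{\sum_n e^{-w d_n^{2}}\, e^{q_{i,j} k_n^{\top}}},
  \qquad w = \frac{1}{2\sigma^{2}}
```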

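The attention code is as follows. The axial branch produces the row and column scores for GWAA. The other branch is the multi-head SA used inside the local windows.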

class Attention(nn.Module):
    """Multi-head SA over window tokens (axial=False), or single-head
    row/column attention scores for GWAA (axial=True)."""

    def __init__(self, dim, configs, axial=False):
        super(Attention, self).__init__()
        self.axial = axial
        self.dim = dim
        self.num_head = configs["head"]
        self.attention_head_size = int(self.dim / configs["head"])
        self.all_head_size = self.num_head * self.attention_head_size

        self.query_layer = nn.Linear(self.dim, self.all_head_size)
        self.key_layer = nn.Linear(self.dim, self.all_head_size)
        self.value_layer = nn.Linear(self.dim, self.all_head_size)

        self.out = nn.Linear(self.dim, self.dim)
        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # (..., all_head_size) -> (..., num_head, attention_head_size)
        new_x_shape = x.size()[:-1] + (self.num_head, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x

    def forward(self, x):
        if self.axial:
            # Global branch, x: (b, p, p, c). Only the raw row and column scores
            # are returned; GaussianTrans adds the Gaussian term and the softmax.
            # Row attention (single head)
            b, h, w, c = x.shape
            mixed_query_layer = self.query_layer(x)
            mixed_key_layer = self.key_layer(x)
            mixed_value_layer = self.value_layer(x)

            query_layer_x = mixed_query_layer.view(b * h, w, -1)
            key_layer_x = mixed_key_layer.view(b * h, w, -1).transpose(-1, -2)  # (b*h, -1, w)
            attention_scores_x = torch.matmul(query_layer_x,
                                              key_layer_x)  # (b*h, w, w)
            attention_scores_x = attention_scores_x.view(b, -1, w,
                                                         w)  # (b, h, w, w)

            # Column attention (single head)
            query_layer_y = mixed_query_layer.permute(
                0, 2, 1, 3).contiguous().view(b * w, h, -1)
            key_layer_y = mixed_key_layer.permute(
                0, 2, 1, 3).contiguous().view(b * w, h, -1).transpose(-1, -2)  # (b*w, -1, h)
            attention_scores_y = torch.matmul(query_layer_y,
                                              key_layer_y)  # (b*w, h, h)
            attention_scores_y = attention_scores_y.view(b, -1, h,
                                                         h)  # (b, w, h, h)

            return attention_scores_x, attention_scores_y, mixed_value_layer

        else:
            # Local branch, x: (b, p, p, win, c): multi-head SA inside each window
            mixed_query_layer = self.query_layer(x)
            mixed_key_layer = self.key_layer(x)
            mixed_value_layer = self.value_layer(x)

            query_layer = self.transpose_for_scores(mixed_query_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()  # (b, p, p, head, win, c)
            key_layer = self.transpose_for_scores(mixed_key_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()
            value_layer = self.transpose_for_scores(mixed_value_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()

            attention_scores = torch.matmul(query_layer,
                                            key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(
                self.attention_head_size)
            atten_probs = self.softmax(attention_scores)

            context_layer = torch.matmul(
                atten_probs, value_layer)  # (b, p, p, head, win, h)
            context_layer = context_layer.permute(0, 1, 2, 4, 3,
                                                  5).contiguous()  # (b, p, p, win, head, h)
            new_context_layer_shape = context_layer.size()[:-2] + (
                self.all_head_size, )
            context_layer = context_layer.view(*new_context_layer_shape)
            attention_output = self.out(context_layer)

        return attention_output
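The row/column split in the axial branch is what keeps GWAA cheap. A quick count in plain Python compares two numbers on an h × w map:

- the number of scores full SA would compute;
- the number of entries in `attention_scores_x` and `attention_scores_y` above.

The function names are hypothetical:

```python
def full_scores(h: int, w: int) -> int:
    """Scores full SA computes over an h x w map: (h*w)^2."""
    return (h * w) ** 2


def axial_scores(h: int, w: int) -> int:
    """Entries of attention_scores_x (b, h, w, w) plus attention_scores_y (b, w, h, h), per batch item."""
    return h * w * (h + w)


for side in (7, 14, 28):
    print(side, full_scores(side, side), axial_scores(side, side))
```

On an s × s map the ratio is s/2, so the saving grows with resolution.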

The Gaussian weighting code is as follows:

class GaussianTrans(nn.Module):
    """Adds a learnable Gaussian distance penalty to the axial attention
    scores and aggregates the row and column values."""

    def __init__(self):
        super(GaussianTrans, self).__init__()
        self.bias = nn.Parameter(-torch.abs(torch.randn(1)))
        self.shift = nn.Parameter(torch.abs(torch.randn(1)))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: (b, h, w, c), atten_x_full: (b, h, w, w),
        # atten_y_full: (b, w, h, h), value_full: (b, h, w, c)
        x, atten_x_full, atten_y_full, value_full = x
        new_value_full = torch.zeros_like(value_full)

        for r in range(x.shape[1]):  # row
            for c in range(x.shape[2]):  # col
                atten_x = atten_x_full[:, r, c, :]  # (b, w)
                atten_y = atten_y_full[:, c, r, :]  # (b, h)

                # Squared distance from (r, c) to every position in its row / column
                dis_x = torch.tensor([(h - c)**2 for h in range(x.shape[2])],
                                     device=x.device)  # (w,)
                dis_y = torch.tensor([(w - r)**2 for w in range(x.shape[1])],
                                     device=x.device)  # (h,)

                dis_x = -(self.shift * dis_x + self.bias)
                dis_y = -(self.shift * dis_y + self.bias)

                atten_x = self.softmax(dis_x + atten_x)
                atten_y = self.softmax(dis_y + atten_y)

                # Weighted sum of the row values plus the column values (needs h == w)
                new_value_full[:, r, c, :] = torch.sum(
                    atten_x.unsqueeze(dim=-1) * value_full[:, r, :, :] +
                    atten_y.unsqueeze(dim=-1) * value_full[:, :, c, :],
                    dim=-2)
        return new_value_full
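GaussianTrans loops over every (row, column) pair in Python, which is slow on larger maps. The same computation can be written with broadcasting and `einsum`. The NumPy sketch below checks itself against a direct port of the loop. The function names are hypothetical, and `shift`/`bias` are plain floats rather than learnable parameters. `torch.einsum` accepts the same subscript strings, so the idea should carry over to PyTorch:

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def gaussian_trans_vectorized(atten_x, atten_y, value, shift, bias):
    """atten_x: (b, h, w, w), atten_y: (b, w, h, h), value: (b, h, w, d)."""
    b, h, w, d = value.shape
    # dis_x[c, k] = -(shift * (k - c)^2 + bias); dis_y[r, k] likewise along rows
    dis_x = -(shift * (np.arange(w)[None, :] - np.arange(w)[:, None]) ** 2 + bias)
    dis_y = -(shift * (np.arange(h)[None, :] - np.arange(h)[:, None]) ** 2 + bias)
    ax = softmax(atten_x + dis_x)  # (b, h, w, w): row weights for every (r, c)
    ay = softmax(atten_y + dis_y)  # (b, w, h, h): column weights for every (c, r)
    row_part = np.einsum("brck,brkd->brcd", ax, value)  # sum over row r
    col_part = np.einsum("bcrk,bkcd->brcd", ay, value)  # sum over column c
    return row_part + col_part


def gaussian_trans_loop(atten_x, atten_y, value, shift, bias):
    """Direct NumPy port of the GaussianTrans double loop, for comparison."""
    b, h, w, d = value.shape
    out = np.zeros_like(value)
    for r in range(h):
        for col in range(w):
            ax = softmax(atten_x[:, r, col, :] - (shift * (np.arange(w) - col) ** 2 + bias))
            ay = softmax(atten_y[:, col, r, :] - (shift * (np.arange(h) - r) ** 2 + bias))
            out[:, r, col, :] = ((ax[..., None] * value[:, r, :, :]).sum(-2)
                                 + (ay[..., None] * value[:, :, col, :]).sum(-2))
    return out


rng = np.random.default_rng(0)
b, p, d = 2, 4, 3
atten_x = rng.standard_normal((b, p, p, p))
atten_y = rng.standard_normal((b, p, p, p))
value = rng.standard_normal((b, p, p, d))
fast = gaussian_trans_vectorized(atten_x, atten_y, value, 0.7, -0.3)
assert np.allclose(fast, gaussian_trans_loop(atten_x, atten_y, value, 0.7, -0.3))
```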

The complete code of the local-global self-attention module is as follows:

class CSAttention(nn.Module):
    """Local-global self-attention (LGG-SA): window SA, LDConv aggregation,
    then Gaussian-weighted axial attention on the aggregated tokens."""

    def __init__(self, dim, configs):
        super(CSAttention, self).__init__()
        self.win_atten = WinAttention(configs, dim)
        self.dlightconv = DlightConv(dim, configs)
        self.global_atten = Attention(dim, configs, axial=True)
        self.gaussiantrans = GaussianTrans()
        # Upsample the p x p global map back to the window-level resolution
        self.up = nn.UpsamplingBilinear2d(scale_factor=configs["win_size"])
        self.queeze = nn.Conv2d(2 * dim, dim, 1)  # fuse local and global features

    def forward(self, x):
        """
        :param x: size(b, n, c)
        """
        origin_size = x.shape
        _, origin_h, origin_w, _ = origin_size[0], int(np.sqrt(
            origin_size[1])), int(np.sqrt(origin_size[1])), origin_size[2]
        x = self.win_atten(x)  # (b, p, p, win, c)
        b, p, p, win, c = x.shape
        # Stitch the windows back into a full map for the local branch
        h = x.view(b, p, p, int(np.sqrt(win)), int(np.sqrt(win)),
                   c).permute(0, 1, 3, 2, 4, 5).contiguous()
        h = h.view(b, p * int(np.sqrt(win)), p * int(np.sqrt(win)),
                   c).permute(0, 3, 1, 2).contiguous()  # (b, c, h, w)

        x = self.dlightconv(x)  # (b, p, p, c): one global token per window
        atten_x, atten_y, mixed_value = self.global_atten(
            x)  # (b, h, w, w) (b, w, h, h) (b, h, w, c), where h = w = p
        gaussian_input = (x, atten_x, atten_y, mixed_value)
        x = self.gaussiantrans(gaussian_input)  # (b, h, w, c)
        x = x.permute(0, 3, 1, 2).contiguous()  # (b, c, h, w)

        x = self.up(x)
        x = self.queeze(torch.cat((x, h), dim=1)).permute(0, 2, 3,
                                                           1).contiguous()  # (b, H, W, c)
        x = x[:, :origin_h, :origin_w, :].contiguous()  # crop away any padding
        x = x.view(b, -1, c)

        return x

EA (External Attention)

External attention (EA) addresses the fact that SA cannot exploit relationships between different input samples.

SA computes attention scores from each sample's own linear projections. EA instead has all samples share two memory units, \(M_K\) and \(M_V\) (see Figure 2), which capture the most important information of the whole dataset.
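Below is a single-head NumPy sketch of external attention. It includes the double normalization used in MEAttention: a softmax over the tokens, then L1 normalization over the memory slots. `M_k` and `M_v` are plain arrays standing in for the learned weights of `linear_0` and `linear_1`. The sizes are arbitrary, and the multi-head split and projections are omitted:

```python
import numpy as np


def external_attention(x, M_k, M_v):
    """x: (b, n, d); M_k: (d, s) and M_v: (s, d) are shared by every sample."""
    attn = x @ M_k                                   # (b, n, s): token-to-memory affinities
    attn = np.exp(attn - attn.max(axis=1, keepdims=True))
    attn = attn / attn.sum(axis=1, keepdims=True)    # softmax over the n tokens
    attn = attn / (1e-9 + attn.sum(axis=2, keepdims=True))  # L1 over the s memory slots
    return attn @ M_v                                # (b, n, d)


rng = np.random.default_rng(0)
b, n, d, s = 2, 10, 8, 16
M_k = rng.standard_normal((d, s))
M_v = rng.standard_normal((s, d))
x = rng.standard_normal((b, n, d))
y = external_attention(x, M_k, M_v)
print(y.shape)  # (2, 10, 8)
```

At inference, each sample interacts only with the shared memories, never directly with the other samples in the batch. The cost is therefore O(n·s) instead of the O(n²) of SA. What links the samples is that `M_k` and `M_v` are learned from the whole training set.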

EA code is as follows:

class MEAttention(nn.Module):
    """Multi-head external attention; linear_0 and linear_1 act as the
    memory units M_K and M_V shared by all samples."""

    def __init__(self, dim, configs):
        super(MEAttention, self).__init__()
        self.num_heads = configs["head"]
        self.coef = 4
        self.query_liner = nn.Linear(dim, dim * self.coef)
        self.num_heads = self.coef * self.num_heads
        self.k = 256 // self.coef  # number of memory slots
        self.linear_0 = nn.Linear(dim * self.coef // self.num_heads, self.k)  # M_K
        self.linear_1 = nn.Linear(self.k, dim * self.coef // self.num_heads)  # M_V

        self.proj = nn.Linear(dim * self.coef, dim)

    def forward(self, x):
        B, N, C = x.shape
        x = self.query_liner(x)  # (b, n, 4c)
        x = x.view(B, N, self.num_heads, -1).permute(0, 2, 1,
                                                     3)  # (b, h, n, 4c/h)

        attn = self.linear_0(x)  # (b, h, n, 256/4)

        # Double normalization: softmax over the n tokens, then L1 over the memory slots
        attn = attn.softmax(dim=-2)  # (b, h, n, 256/4)
        attn = attn / (1e-9 + attn.sum(dim=-1, keepdim=True))  # (b, h, n, 256/4)

        x = self.linear_1(attn).permute(0, 2, 1, 3).reshape(B, N, -1)  # (b, n, 4c)

        x = self.proj(x)  # (b, n, c)

        return x

