CVPR 2021 LIIF super-resolution code reading

Posted by shack on Tue, 01 Feb 2022 20:37:43 +0100

paper
code

Abstract

The physical world presents visual signals in a continuous way, but computers store and display images as discrete 2D pixel arrays. This paper studies continuous image representation: the Local Implicit Image Function (LIIF) takes an image coordinate and the 2D deep features around that coordinate as inputs, and predicts the RGB value at the given coordinate. By training an encoder together with the LIIF representation through a self-supervised super-resolution task, a continuous representation of a pixel image is generated, so super-resolution at arbitrary scales becomes possible, even up to ×30, far beyond the scales seen during training. Modeling an image as a function over a continuous domain allows images of arbitrary resolution to be restored and generated. The idea of implicit functions is to represent an object as a function that maps coordinates to the corresponding signals (such as the signed distance to a 3D object surface, or the RGB value of an image); a neural implicit function is one parameterized by a deep neural network. To share knowledge across instances instead of fitting a separate implicit function for each object, an encoder-based method predicts a latent code for each object; the implicit function is then shared by all objects and takes the latent code as an additional input.

Local Implicit Image Function

In the LIIF representation, each continuous image $I^{(i)}$ is represented by a 2D feature map $M^{(i)} \in \mathbb{R}^{H \times W \times D}$. A neural implicit function $f_\theta$ (with parameters $\theta$), shared by all images, is parameterized as an MLP and takes the form $s = f(z, x)$ ($\theta$ omitted for brevity), where $z$ is a vector, $x \in \mathcal{X}$ is a 2D coordinate in the continuous image domain, and $s \in \mathcal{S}$ is the predicted signal (an RGB value).

With $f$ defined, each vector $z$ can be regarded as representing a function $f(z, \cdot): \mathcal{X} \to \mathcal{S}$, i.e. a continuous image that maps coordinates to RGB values. The $H \times W$ feature vectors of $M^{(i)}$ (called latent codes) are assumed to be evenly distributed over the 2D continuous image domain of $I^{(i)}$, and each of them is assigned a 2D coordinate.

For an image $I^{(i)}$, the RGB value at coordinate $x_q$ is defined as $I^{(i)}(x_q) = f(z^*, x_q - v^*)$, where $z^*$ is the latent code in $M^{(i)}$ nearest to $x_q$ (in Euclidean distance) and $v^*$ is the coordinate of that latent code in the image domain. For example, when $z^*_{11}$ is the nearest latent code to $x_q$ under the current definition, then $z^* = z^*_{11}$ and $v^*$ is the coordinate of $z^*_{11}$.

Under the implicit function $f$ shared across all images, a continuous image is represented by a 2D feature map $M^{(i)} \in \mathbb{R}^{H \times W \times D}$, viewed as $H \times W$ latent codes evenly distributed over the 2D domain. Each latent code $z$ in $M^{(i)}$ represents a local patch of the continuous image and is responsible for predicting the signal at the coordinates nearest to it.
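
As a minimal sketch (not the repository code, which is shown later), assuming a toy MLP for $f_\theta$ and an $8 \times 8$ feature map, the nearest-latent-code query $s = f(z^*, x_q - v^*)$ looks like this:

import torch
import torch.nn as nn

D = 64                                                  # latent code dimension (assumed)
f = nn.Sequential(nn.Linear(D + 2, 256), nn.ReLU(), nn.Linear(256, 3))  # toy f_theta

feat = torch.randn(D, 8, 8)                             # M^(i): an 8x8 grid of latent codes
H, W = feat.shape[-2:]
ys = -1 + 1 / H + (2 / H) * torch.arange(H)             # latent code coordinates in [-1, 1]
xs = -1 + 1 / W + (2 / W) * torch.arange(W)

x_q = torch.tensor([0.13, -0.42])                       # query coordinate
iy = (ys - x_q[0]).abs().argmin()                       # nearest latent code along height
ix = (xs - x_q[1]).abs().argmin()                       # nearest latent code along width
z_star = feat[:, iy, ix]                                # z*
v_star = torch.stack([ys[iy], xs[ix]])                  # v*: coordinate of z*

s = f(torch.cat([z_star, x_q - v_star]))                # predicted RGB value at x_q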

Obtaining normalized coordinates and RGB values from the image:

import torch
import torch.nn.functional as F

def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers, normalized to [-1, 1]
        (or to the given ranges).
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)  # half the extent of one pixel
        seq = v0 + r + (2 * r) * torch.arange(n).float()  # pixel-center coordinates
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
    
coord = make_coord((h, w)) # h, w are the height and width of the SR target

def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (3, H, W)
    """
    coord = make_coord(img.shape[-2:])   #(h*w,2)--(h*w,[x,y])
    rgb = img.view(3, -1).permute(1, 0)  #(h*w,3)--(h*w,[R,G,B])
    return coord, rgb
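
A quick sanity check of the shapes (a random tensor stands in for a real LR image):

img = torch.rand(3, 48, 48)                             # hypothetical low-resolution image
coord, rgb = to_pixel_samples(img)
print(coord.shape, rgb.shape)                           # torch.Size([2304, 2]) torch.Size([2304, 3])
print(coord.min().item(), coord.max().item())           # about -1 + 1/48 and 1 - 1/48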

Feature unfolding

To enrich the information contained in each latent code, the feature map $M^{(i)}$ is unfolded into $\hat{M}^{(i)}$, where each entry of $\hat{M}^{(i)}$ is the concatenation of the $3 \times 3$ neighborhood of latent codes in $M^{(i)}$:
$$\hat{M}^{(i)}_{jk} = \mathrm{Concat}\big(\{M^{(i)}_{j+l,\,k+m}\}_{l,m \in \{-1,0,1\}}\big)$$
where $\mathrm{Concat}$ denotes the concatenation of a set of vectors, and $M^{(i)}$ is zero-padded outside its boundary.
In code, a feature map $feat$ of shape $[N, C, LR_H, LR_W]$ is unfolded into $[N, C \cdot 3 \cdot 3, LR_H, LR_W]$:

feat = F.unfold(feat, 3, padding=1).view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
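
A quick shape check of this unfolding (toy tensor, assuming the imports from the first code block):

feat = torch.randn(1, 64, 12, 12)                       # [N, C, LR_H, LR_W]
feat = F.unfold(feat, 3, padding=1).view(1, 64 * 9, 12, 12)
print(feat.shape)                                       # torch.Size([1, 576, 12, 12])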

Local ensemble

$s = f(z, x)$ gives discontinuous predictions: since the signal at $x_q$ is predicted by querying the nearest latent code $z^*$ in $M^{(i)}$, $z^*$ switches abruptly from one latent code to another as $x_q$ moves across the image domain. Around the coordinates where this switch happens, the signals at two infinitesimally close coordinates are predicted from different latent codes, so as long as the learned implicit function $f$ is not perfect, discontinuous patterns can appear at the boundaries where $z^*$ switches. To solve this problem, the local ensemble technique extends the representation of each latent code to
$$I^{(i)}(x_q) = \sum_{t \in \{00,01,10,11\}} \frac{S_t}{S}\, f(z^*_t,\, x_q - v^*_t)$$
where $z^*_t$ ($t \in \{00,01,10,11\}$) are the nearest latent codes in the top-left, top-right, bottom-left and bottom-right subspaces, $v^*_t$ is the coordinate of $z^*_t$, and $S_t$ is the area of the rectangle between $x_q$ and $v^*_{t'}$, where $t'$ is diagonal to $t$ (00 is paired with 11, and 10 with 01). The weights are normalized by $S = \sum_t S_t$. The feature map $M^{(i)}$ is mirror-padded outside its boundary, so the formula also applies to coordinates near the border.

This makes the local patch represented by each latent code overlap with its neighbors, so that at every coordinate there are four latent codes that each predict a signal independently. The four predictions are then merged by voting with normalized confidences, where each confidence is proportional to the area of the rectangle between the query point and the coordinate of the diagonally opposite latent code; thus the closer a latent code is to the query, the higher its weight. This voting realizes a continuous transition at the coordinates where $z^*$ would otherwise switch (the dashed lines in the figure in the paper). A toy illustration of these weights follows the code below.

vx_lst = [-1, 1]
vy_lst = [-1, 1]
eps_shift = 1e-6

rx = 2 / feat.shape[-2] / 2  #2/H/2
ry = 2 / feat.shape[-1] / 2  #2/W/2

feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() #[LR_H,LR_W,2]
feat_coord = feat_coord.permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])#[N,2,LR_H,LR_W]

preds = []
areas = []
for vx in vx_lst:
    for vy in vy_lst:
        coord_ = coord.clone()#[N,SR_H*SR_W,2]
        coord_[:, :, 0] += vx * rx + eps_shift
        coord_[:, :, 1] += vy * ry + eps_shift
        coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

        q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,C*9,1,SR_H*SR_W]
        q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,C*9]

        q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,2,1,SR_H*SR_W]
        q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,2]

        rel_coord = coord - q_coord #[N,SR_H*SR_W,2]
        rel_coord[:, :, 0] *= feat.shape[-2]
        rel_coord[:, :, 1] *= feat.shape[-1]
        inp = torch.cat([q_feat, rel_coord], dim=-1) #[N,SR_H*SR_W,C*9+2]

        if self.cell_decode:
            rel_cell = cell.clone()
            rel_cell[:, :, 0] *= feat.shape[-2]
            rel_cell[:, :, 1] *= feat.shape[-1]
            inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]

        bs, q = coord.shape[:2] #bs=N q=SR_H*SR_W
        #[N*SR_H*SR_W,C*9+2+2] --> [N*SR_H*SR_W,3]
        pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) #[N,SR_H*SR_W,3]
        preds.append(pred) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]

        area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
        areas.append(area + 1e-9) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]


tot_area = torch.stack(areas).sum(dim=0) #[N,SR_H*SR_W]
if self.local_ensemble:
    # each prediction is weighted by the area of the rectangle diagonally opposite
    # its latent code, so the areas of the 00/11 and 01/10 corners are swapped
    t = areas[0]; areas[0] = areas[3]; areas[3] = t #swap(areas[0],areas[3])
    t = areas[1]; areas[1] = areas[2]; areas[2] = t #swap(areas[1],areas[2])
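
As a toy illustration (not part of the repository code) of why the diagonal areas behave like bilinear-interpolation weights, consider four latent codes at the corners of a unit cell and a query point inside it:

corners = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])  # v*_t for t = 00, 01, 10, 11
x_q = torch.tensor([0.25, 0.75])                                  # query coordinate inside the cell

# S_t is the area of the rectangle between x_q and the corner diagonal to t
diag = corners[[3, 2, 1, 0]]                                      # 00<->11, 01<->10
S = torch.abs((x_q - diag).prod(dim=1))
w = S / S.sum()
print(w)  # tensor([0.1875, 0.5625, 0.0625, 0.1875]) -- nearest corner (0, 1) gets the largest weight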

Cell decoding

To render a pixel image at an arbitrary resolution from the LIIF representation, given the target resolution, a straightforward method is to query the continuous representation $I^{(*)}$ at the center coordinate of each pixel. However, the predicted RGB value of the query pixel is then independent of the pixel's size, so all information within the pixel area other than its center is discarded, which may be suboptimal. Cell decoding adds the pixel shape as an extra input:
$$s = f_{cell}(z, [x, c])$$
where $c = [c_h, c_w]$ contains the height and width of the query pixel and $[x, c]$ is the concatenation of $x$ and $c$, i.e. $c$ is an additional input.
$f_{cell}(z, [x, c])$ can be read as: render the RGB value of a pixel of shape $c$ centered at coordinate $x$. For a $64 \times 64$ target resolution, $c$ is $1/64$ of the image height and width. In the limit $c \to 0$, $f_{cell}(z, [x, c])$ should agree with $f(z, x)$, i.e. a continuous image can be regarded as an image with infinitely small pixels.

cell = torch.ones_like(coord) #[SR_H*SR_W,2] [1*2/SR_H,1*2/SR_W]
cell[:, 0] *= 2 / h
cell[:, 1] *= 2 / w 

# inside query_rgb, where cell already has a batch dimension: [N,SR_H*SR_W,2]
if self.cell_decode:
    rel_cell = cell.clone()
    rel_cell[:, :, 0] *= feat.shape[-2]
    rel_cell[:, :, 1] *= feat.shape[-1]
    inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]
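
As a toy check (not repository code), for a 64 × 64 target the cell value is 2/64 in the normalized [-1, 1] domain, i.e. 1/64 of the image extent, matching the description above:

h = w = 64
coord = make_coord((h, w))                              # (4096, 2)
cell = torch.ones_like(coord)
cell[:, 0] *= 2 / h
cell[:, 1] *= 2 / w
print(cell[0])                                          # tensor([0.0312, 0.0312]) == 2/64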

LIIF class full code

class LIIF(nn.Module):
    def __init__(self, encoder_spec, imnet_spec=None,
                 local_ensemble=True, feat_unfold=True, cell_decode=True):
        super().__init__()
        self.local_ensemble = local_ensemble
        self.feat_unfold = feat_unfold
        self.cell_decode = cell_decode
        self.encoder = models.make(encoder_spec)

        #print("self.encoder.out_dim",self.encoder.out_dim)
        if imnet_spec is not None:
            imnet_in_dim = self.encoder.out_dim     #64
            if self.feat_unfold:
                imnet_in_dim *= 9
            imnet_in_dim += 2 # append the 2D coordinate [x, y] of the query pixel
            if self.cell_decode:
                imnet_in_dim += 2 # append [cell_h, cell_w], the height and width of the query pixel
            self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
        else:
            self.imnet = None

    def gen_feat(self, inp):
        self.feat = self.encoder(inp)
        return self.feat

    def query_rgb(self, coord, cell=None):
        #coord [N,SR_H*SR_W,2]
        #cell [N,SR_H*SR_W,2]
        feat = self.feat #[N,C,LR_H,LR_W]

        if self.imnet is None:
            ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
            ret = ret[:, :, 0, :].permute(0, 2, 1)
            return ret

        if self.feat_unfold:
            # [N,C*3*3,H,W]
            feat = F.unfold(feat, 3, padding=1).view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        if self.local_ensemble:
            vx_lst = [-1, 1]
            vy_lst = [-1, 1]
            eps_shift = 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        # field radius (global: [-1, 1])
        rx = 2 / feat.shape[-2] / 2  #2/H/2
        ry = 2 / feat.shape[-1] / 2  #2/W/2

        feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() #[LR_H,LR_W,2]
        feat_coord = feat_coord.permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])#[N,2,LR_H,LR_W]

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                coord_ = coord.clone()#[N,SR_H*SR_W,2]
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

                q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,C*9,1,SR_H*SR_W]
                q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,C*9]

                q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,2,1,SR_H*SR_W]
                q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,2]

                rel_coord = coord - q_coord #[N,SR_H*SR_W,2]
                rel_coord[:, :, 0] *= feat.shape[-2]
                rel_coord[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([q_feat, rel_coord], dim=-1) #[N,SR_H*SR_W,C*9+2]

                if self.cell_decode:
                    rel_cell = cell.clone()
                    rel_cell[:, :, 0] *= feat.shape[-2]
                    rel_cell[:, :, 1] *= feat.shape[-1]
                    inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]

                bs, q = coord.shape[:2] #bs=N q=SR_H*SR_W
                #[N*SR_H*SR_W,C*9+2+2] --> [N*SR_H*SR_W,3]
                pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) #[N,SR_H*SR_W,3]
                preds.append(pred) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]

                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]


        tot_area = torch.stack(areas).sum(dim=0) #[N,SR_H*SR_W]
        if self.local_ensemble:
            # each prediction is weighted by the area of the rectangle diagonally opposite
            # its latent code, so the areas of the 00/11 and 01/10 corners are swapped
            t = areas[0]; areas[0] = areas[3]; areas[3] = t #swap(areas[0],areas[3])
            t = areas[1]; areas[1] = areas[2]; areas[2] = t #swap(areas[1],areas[2])
        ret = 0
        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)
        return ret

    def forward(self, inp, coord, cell):
        self.gen_feat(inp)
        return self.query_rgb(coord, cell)
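
A minimal usage sketch (not part of the post's original code; the spec dictionaries below are placeholders modeled on the repository's config files, and the exact encoder/imnet names and arguments come from those configs):

# Hypothetical example: super-resolve a low-resolution batch to an arbitrary target size
encoder_spec = {'name': 'edsr-baseline', 'args': {'no_upsampling': True}}
imnet_spec = {'name': 'mlp', 'args': {'out_dim': 3, 'hidden_list': [256, 256, 256, 256]}}
model = LIIF(encoder_spec, imnet_spec).cuda()

lr = torch.rand(1, 3, 48, 48).cuda()                    # low-resolution input
h, w = 200, 300                                         # arbitrary target resolution
coord = make_coord((h, w)).cuda().unsqueeze(0)          # [1, h*w, 2]
cell = torch.ones_like(coord)
cell[:, :, 0] *= 2 / h                                  # pixel height in the [-1, 1] domain
cell[:, :, 1] *= 2 / w                                  # pixel width in the [-1, 1] domain

pred = model(lr, coord, cell)                           # [1, h*w, 3]
sr = pred.view(1, h, w, 3).permute(0, 3, 1, 2)          # [1, 3, h, w]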

Topics: Computer Vision, Deep Learning