SRGAN image super resolution reconstruction (Keras)

Posted by plouka on Mon, 06 Dec 2021 03:52:03 +0100

Preface

SRGAN is a network that uses a GAN to perform image super-resolution reconstruction. After training, only the generator is needed to reconstruct low-resolution images. The network consists of a generator and a discriminator, and the training process is not very stable. It is commonly used for super-resolution of satellite and remote-sensing imagery.
Here we use the high-resolution DIV2K dataset.
Dataset download link: https://pan.baidu.com/s/1UBle5Cu74TRifcAVz14cDg Extraction code: luly
GitHub code address: uploading

1, SRGAN

1. Training steps

The training idea of the SRGAN network is shown in the figure below:

The training steps are as follows:
(1) The low-resolution image is fed into the generator network, which produces a high-resolution image.
(2) The generated high-resolution image is fed into the discriminator network to judge real vs. fake, and its output is compared against labels of 0 and 1.
(3) Features are extracted from the original high-resolution image and the generated one with the first 9 layers of VGG19, and a loss is computed between the two sets of features.
(4) The loss is propagated back to the generator to continue training.
This is the training process of SRGAN; a minimal sketch of how these steps wire together follows, and then we will implement them one by one.
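As a rough sketch (not the full implementation), the four steps above combine into a single Keras model. It assumes generator, discriminator, and vgg (the truncated VGG19 feature extractor) have already been built, as they are in the complete code later in this post:

from keras.layers import Input
from keras.models import Model

# (1) the low-resolution input goes into the generator
img_lr = Input(shape=(128, 128, 3))
fake_hr = generator(img_lr)
# (2) the (frozen) discriminator scores the generated image
discriminator.trainable = False
validity = discriminator(fake_hr)
# (3) VGG19 features of the generated image, later compared with real-image features
fake_features = vgg(fake_hr)
# (4) adversarial loss plus VGG feature (content) loss train the generator
combined = Model(img_lr, [validity, fake_features])
combined.compile(loss=['binary_crossentropy', 'mse'],
                 loss_weights=[1e-3, 1], optimizer='adam')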

2. Generator

The network structure of the generator is shown in the following figure:

The generator is mainly composed of two parts. The first part is a stack of residual blocks (the red box in the figure), and the second part is the upsampling part (the blue box in the figure), which enlarges the image.
Residual block: contains two 3x3 convolutions
Upsampling: implemented with UpSampling2D
The generator code is as follows.
Since the whole SRGAN is defined as a class, parameters that are not defined in this snippet (such as self.lr_shape, self.gf, and self.n_residual_blocks) are explained in the comments:

def build_generator(self):
	
        def residual_block(layer_input, filters):
            """Residual block described in paper #Residual block ''“
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            """Layers used during upsampling #Upsampling block ''“
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        # Low resolution image input
        # self.lr_shape the size of the low resolution image
        img_lr = Input(shape=self.lr_shape)

        # Pre-residual block
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # Propagate through residual blocks
        # self.gf: number of convolution filters in the generator's residual blocks
        r = residual_block(c1, self.gf)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r, self.gf)

        # Post-residual block
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        # Upsampling
        #Up sampling
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # Generate high resolution output
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

        return Model(img_lr, gen_hr)

3. Discriminator

The discriminator judges whether the generated picture is real or fake; its output is compared against 1 (a real picture) and 0 (a fake picture). Here, the 0s and 1s are not single scalars but label matrices matching the discriminator's output size, as sketched below.
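As a concrete example (a small sketch, assuming the 512x512 input and the four stride-2 convolutions used below), the validity map is 32x32x1, so the real/fake labels are built like this:

import numpy as np

batch_size = 2
patch = int(512 / 2**4)                      # four stride-2 convs: 512 -> 32
disc_patch = (patch, patch, 1)               # discriminator output shape per image
valid = np.ones((batch_size,) + disc_patch)  # labels for real HR images
fake = np.zeros((batch_size,) + disc_patch)  # labels for generated images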

The discriminator network is composed of blocks containing a convolution, BN, and a LeakyReLU activation, and finally outputs a map of values between 0 and 1; it is effectively a binary classification network. The code is as follows:

 def build_discriminator(self):
        #Here self.df =64
        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer #The discriminator mainly contains the convolution block ""“
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # Input img
        #self.hr_shape size of high resolution picture
        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)

        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

The network is thus divided into a generator and a discriminator, which confront each other during training in order to reach a good balance.

2, Other preparations

1. Data reading

During training, we enlarge 128x128 images into 512x512 images. The goal of the generator network is to keep the enlarged images sharp.
When reading the data, each picture is resized to 512x512 to serve as the supervision (ground-truth) data, and to 128x128 to serve as the training input. The reading code is as follows:

from glob import glob
import numpy as np
import cv2
#======================
# Used to read the high-resolution dataset
#======================
#Data preprocessing: each original image is turned into a large (HR) and a small (LR) version
class DataLoader():
    #Initialization, the size of the reconstructed clear image
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res
    #Read data from folder
    def load_data(self, batch_size=1, is_testing=False):
        # Note: both modes read from the train folder; is_testing only
        # controls whether random flipping is applied below
        path = glob('./datasets/%s/train/*' % (self.dataset_name))

        #Randomly sample batch_size pictures from the dataset
        batch_images = np.random.choice(path, size=batch_size)

        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = self.imread(img_path)

            #Compute the low-resolution size (a quarter of the HR size)
            h, w = self.img_res
            low_h, low_w = int(h / 4), int(w / 4)

            #Resize: full size as the HR image, quarter size as the LR image
            #(cv2.resize takes (width, height))
            img_hr = cv2.resize(img, self.img_res)
            img_lr = cv2.resize(img, (low_w, low_h))

            # If in training mode, randomly flip the pair for data augmentation
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        #Normalize: dividing the 0-255 pixel values by 127.5 gives 0-2; subtracting 1 maps them to -1 to 1
        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.

        return imgs_hr, imgs_lr  # arrays of shape (batch, H, W, C)

    #Read a picture and convert it from BGR (OpenCV order) to RGB
    def imread(self, path):
        img = cv2.imread(path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
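A quick sanity check of the loader (assuming the DIV2K images have been extracted to ./datasets/DIV/train/); the printed shapes should be (2, 512, 512, 3) and (2, 128, 128, 3), with values in [-1, 1]:

loader = DataLoader(dataset_name='DIV', img_res=(512, 512))
imgs_hr, imgs_lr = loader.load_data(batch_size=2)
print(imgs_hr.shape, imgs_lr.shape, imgs_hr.min(), imgs_hr.max())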

2. VGG19 feature extraction

The first 9 layers of a pretrained VGG19 are used to extract features. The original code raises an error here (the error message is discussed in a separate blog post), so change the code as follows:

    def build_vgg(self):
        # Build the VGG model; only the features of layer 9 are used
        vgg = VGG19(weights="imagenet", input_shape=self.hr_shape, include_top=False)
        return Model(vgg.input, outputs=vgg.layers[9].output)
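If you want to confirm which layer vgg.layers[9] actually is, printing the layer names is the reliable check (in the stock Keras VGG19 it should be block3_conv3). A small optional snippet:

from keras.applications import VGG19

# weights=None avoids downloading the ImageNet weights just for this check
vgg = VGG19(weights=None, include_top=False, input_shape=(512, 512, 3))
for i, layer in enumerate(vgg.layers[:12]):
    print(i, layer.name)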

3. Complete code of the training process

import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
from keras.applications import VGG19
from keras.layers import Input, Dense, BatchNormalization, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam
from data_loader import DataLoader

class SRGAN():
    def __init__(self):
        # Input shape
        self.channels = 3
        self.lr_height = 128                 # Low resolution height
        self.lr_width = 128                  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        self.hr_height = self.lr_height*4   # High resolution height
        self.hr_width = self.lr_width*4     # High resolution width
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)

        # Number of residual blocks in the generator
        self.n_residual_blocks = 16

        optimizer = Adam(0.0002, 0.5)

        # VGG19 feature extractor
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        #Data path
        self.dataset_name = 'DIV'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))

        # Discriminator output size (512 / 2^4 = 32 after four stride-2 convs);
        # used to build the 0/1 label matrices for the discriminator loss
        patch = int(self.hr_height / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of convolution kernels of generator and discriminator
        self.gf = 64
        self.df = 64

        # Build and compile the discriminator (mse loss)
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # generator 
        self.generator = self.build_generator()

        # High- and low-resolution image inputs
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # The generator produces fake high-resolution pictures
        fake_hr = self.generator(img_lr)

        # Feature extraction on the generated image
        fake_features = self.vgg(fake_hr)

        # Freeze the discriminator while training the combined model
        self.discriminator.trainable = False

        # The discriminator scores the generated image
        validity = self.discriminator(fake_hr)

        # Combined model: adversarial loss (weight 1e-3) + VGG content loss (weight 1)
        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)


    def build_vgg(self):
        # Build the VGG model; only the features of layer 9 are used
        vgg = VGG19(weights="imagenet", input_shape=self.hr_shape, include_top=False)
        return Model(vgg.input, outputs=vgg.layers[9].output)

    def build_generator(self):

        def residual_block(layer_input, filters):
            """Residual block described in paper"""
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        # Low resolution image input
        img_lr = Input(shape=self.lr_shape)

        # Pre-residual block
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # Propagate through residual blocks
        r = residual_block(c1, self.gf)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r, self.gf)

        # Post-residual block
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        # Upsampling
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # Generate high resolution output
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

        return Model(img_lr, gen_hr)

    def build_discriminator(self):
        #Here self.df =64
        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # Input img
        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)

        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

    def train(self, epochs, batch_size=1, sample_interval=50):

        os.makedirs('weights', exist_ok=True)
        start_time = datetime.datetime.now()

        for epoch in range(epochs):

            # ----------------------
            #  Train the discriminator
            # ----------------------

            # Load data
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # Generating high resolution data from low resolution data
            fake_hr = self.generator.predict(imgs_lr)
            # Label matrices: 1 for real images, 0 for fake ones
            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            # Train the discriminator on real and fake batches
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------
            #  Train the generator
            # ------------------

            # Sample a new batch of images
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # The generators want the discriminators to label the generated images as real
            valid = np.ones((batch_size,) + self.disc_patch)

            # vgg19 feature extraction
            image_features = self.vgg.predict(imgs_hr)

            # Train the generator through the combined model
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            elapsed_time = datetime.datetime.now() - start_time
            # Print the progress
            print("%d time: %s, d_loss: %.4f, g_loss: %.4f" % (epoch, elapsed_time, d_loss[0], g_loss[0]))

            # If at save interval => save generated image samples and generator weights
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                self.generator.save('weights/epoch%d.h5' % epoch)

    def sample_images(self, epoch):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2

        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)
        #-----------------------------------------------
        # Rescale from [-1, 1] back to [0, 1] so the images
        # display correctly in matplotlib
        #-----------------------------------------------
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        # Save generated images and the high resolution originals
        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()

        # Save low resolution images for comparison
        for i in range(r):
            fig = plt.figure()
            plt.imshow(imgs_lr[i])
            fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
            plt.close()

if __name__ == '__main__':
    gan = SRGAN()
    gan.train(epochs=30000, batch_size=2, sample_interval=200)

During training, pictures like the following will appear in the images directory.
Low-resolution input:

Comparison (generated image on the left, original image on the right):

4. Prediction process

The prediction process only needs the generator, and the input picture size is not limited (the input shape is [None, None, 3], since the network is fully convolutional). Take out the generator and use it alone:

#Generator code
from keras.layers import Input
from keras.layers import BatchNormalization, Activation, Add
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import  Model
def build_generator():
    def residual_block(layer_input, filters):
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
        d = Activation('relu')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Add()([d, layer_input])
        return d

    def deconv2d(layer_input):
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
        u = Activation('relu')(u)
        return u

    # Arbitrary input size: the network is fully convolutional
    img_lr = Input(shape=[None, None, 3])
    # Part 1: the low-resolution input first passes through a convolution + ReLU
    c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
    c1 = Activation('relu')(c1)

    # Part 2: 16 residual blocks; each contains two convolution + BN stages and a residual connection.
    r = residual_block(c1, 64)
    for _ in range(15):
        r = residual_block(r, 64)

    # Part 3: upsampling enlarges height and width; two upsampling steps give 4x the original resolution.
    c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
    c2 = BatchNormalization(momentum=0.8)(c2)
    c2 = Add()([c2, c1])
    u1 = deconv2d(c2)
    u2 = deconv2d(u1)
    gen_hr = Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

    return Model(img_lr, gen_hr)

Prediction code:

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from generator_model import build_generator

before_image = Image.open(r"Female_person.jpg")
gen_model = build_generator()
gen_model.load_weights('weights/epoch14800.h5')
# gen_model.summary()
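# Paste the image onto a gray RGB canvas to guarantee a 3-channel input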
new_img = Image.new('RGB', before_image.size, (128, 128, 128))
new_img.paste(before_image)
# plt.imshow(new_img)
# plt.show()

new_image = np.array(new_img)/127.5 - 1
# Add a batch dimension: the network expects 4-D input
new_image = np.expand_dims(new_image, axis=0)  # [batch_size, h, w, c]
# Map the generator output from [-1, 1] back to [0, 255]
fake = (gen_model.predict(new_image)*0.5 + 0.5)*255
#Convert the np array to uint8 and back into a PIL image
fake = Image.fromarray(np.uint8(fake[0]))

fake.save("out.png")
titles = ['Generated', 'Original']
plt.subplot(1, 2, 1)
plt.imshow(before_image)
plt.subplot(1, 2, 2)
plt.imshow(fake)
plt.show()

Reconstruction effect:

Summary

That's all!

Topics: Deep Learning keras