2x super-resolution reconstruction of the flowers dataset with ESPCN

Posted by phpcat on Wed, 01 Jan 2020 00:32:53 +0100

Unify the sample size: images larger than the target are cropped, and images that are too small are padded with zeros.

def resize_image_with_crop_or_pad(image, target_height, target_width):
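
A minimal sketch of that behavior (the 300x300 and 150x150 input shapes here are just made-up examples): an image larger than the target is center-cropped, and a smaller one is zero-padded.

import tensorflow as tf

big = tf.zeros([300, 300, 3])
small = tf.zeros([150, 150, 3])
print(tf.image.resize_image_with_crop_or_pad(big, 200, 200).shape)    # (200, 200, 3)
print(tf.image.resize_image_with_crop_or_pad(small, 200, 200).shape)  # (200, 200, 3)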

 

Network structure

At input time the images are normalized to values between 0 and 1. The last convolution layer outputs 12 channels: the 2x2 upscaling factor applies to 3 color channels, so 2 * 2 * 3 = 12. Every layer uses the tanh activation, except the last layer, which has no activation function.
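
The upscaling itself is done by tf.depth_to_space (the "pixel shuffle"), which rearranges r*r*C feature channels into an r-times larger spatial grid. A minimal sketch with the shapes used in this post (r = 2, C = 3):

import tensorflow as tf

feat = tf.zeros([4, 100, 100, 12])  # (batch, H, W, r*r*C) with r=2, C=3
up = tf.depth_to_space(feat, 2)     # channels are rearranged into space
print(up.shape)                     # (4, 200, 200, 3)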

dbatch concatenates the generated y_pred and images along the batch dimension; this tensor is used later for image quality evaluation.
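
In the script below this is a single tf.concat along the batch axis, and the metric functions later split the evaluated array back into the original and predicted halves:

dbatch = tf.concat([tf.cast(images, tf.float32), y_pred], 0)  # (2*batch_size, 200, 200, 3)
# inside the metric functions, after d_batch = dbatch.eval():
im1, im2 = np.split(d_batch, 2)                               # originals, predictions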

y_pred is clamped between a minimum and a maximum value to avoid bright spots (overflowed pixels) that would make the final generated image look unclear.

y_pred = y_predt * 255.0
y_pred = tf.maximum(y_pred, 0)
y_pred = tf.minimum(y_pred, 255)
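
Just as a side note (not a change to the code above): the two ops are equivalent to a single clip,

y_pred = tf.clip_by_value(y_predt * 255.0, 0.0, 255.0)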

 

Since the input samples are normalized, the target images also need to be normalized when computing the loss.

cost = tf.reduce_mean(tf.pow(tf.cast(images, tf.float32) / 255.0 - y_predt, 2))
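
Because the loss is an MSE on images scaled to [0, 1], it maps directly to a PSNR on the 0-255 scale: with MSE_255 = 255^2 * cost, PSNR = 20*log10(255 / sqrt(MSE_255)) = -10*log10(cost). A quick sanity check against the cost value c returned by the training loop (only an approximation of the printed PSNR, which averages per image and per channel):

approx_psnr = -10.0 * np.log10(c)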

The higher the PSNR (Peak Signal-to-Noise Ratio) and SSIM (Structural Similarity Index), the closer the reconstructed pixel values are to the reference.

PSNR computation

Based on RGB: compute the MSE over the three RGB channels, then take the result into the PSNR formula (averaging over images and channels).

Based on YUV: extract the Y (luma) component of the image in YUV space and compute the PSNR on the Y component only.
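
Both variants use the same formula, with MAX = 255 for 8-bit images, as implemented in batch_mse_psnr and batch_y_psnr below:

y = 0.3 * r + 0.59 * g + 0.11 * b           # luma weights used for the YUV variant
psnr = 20 * np.log10(255.0 / np.sqrt(mse))  # PSNR from the (per-image) MSE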

 

import tensorflow as tf
from datasets import flowers
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.contrib.slim as slim


def batch_mse_psnr(dbatch):
    # RGB PSNR: per-image, per-channel MSE, then PSNR averaged over images and channels
    im1, im2 = np.split(dbatch, 2)
    mse = ((im1 - im2) ** 2).mean(axis=(1, 2))
    psnr = np.mean(20 * np.log10(255.0 / np.sqrt(mse)))
    return np.mean(mse), psnr


def batch_y_psnr(dbatch):
    # PSNR computed on the Y (luma) component only
    r, g, b = np.split(dbatch, 3, axis=3)
    y = np.squeeze(0.3 * r + 0.59 * g + 0.11 * b)
    im1, im2 = np.split(y, 2)
    mse = ((im1 - im2) ** 2).mean(axis=(1, 2))
    psnr = np.mean(20 * np.log10(255.0 / np.sqrt(mse)))
    return psnr


def batch_ssim(dbatch):
    # Simplified single-window SSIM per image/channel, averaged over the batch
    im1, im2 = np.split(dbatch, 2)
    imgsize = im1.shape[1] * im1.shape[2]
    avg1 = im1.mean((1, 2), keepdims=1)
    avg2 = im2.mean((1, 2), keepdims=1)
    std1 = im1.std((1, 2), ddof=1)
    std2 = im2.std((1, 2), ddof=1)
    cov = ((im1 - avg1) * (im2 - avg2)).mean((1, 2)) * imgsize / (imgsize - 1)
    avg1 = np.squeeze(avg1)
    avg2 = np.squeeze(avg2)
    k1 = 0.01
    k2 = 0.03
    c1 = (k1 * 255) ** 2
    c2 = (k2 * 255) ** 2
    c3 = c2 / 2
    return np.mean(
        (2 * avg1 * avg2 + c1) * 2 * (cov + c3) / (avg1 ** 2 + avg2 ** 2 + c1) / (std1 ** 2 + std2 ** 2 + c2))


def showresult(subplot, title, orgimg, thisimg, dopsnr=True):
    # Show one image in a subplot; optionally report PSNR / Y-PSNR / SSIM against the original
    p = plt.subplot(subplot)
    p.axis('off')
    p.imshow(np.asarray(thisimg[0], dtype='uint8'))
    if dopsnr:
        conimg = np.concatenate((orgimg, thisimg))
        mse, psnr = batch_mse_psnr(conimg)
        ypsnr = batch_y_psnr(conimg)
        ssim = batch_ssim(conimg)
        p.set_title(title + str(int(psnr)) + " y:" + str(int(ypsnr)) + " s:" + str(ssim))
    else:
        p.set_title(title)


height = width = 200
batch_size = 4
DATA_DIR = "D:/tmp/data/flowers"

# Select the validation split of the dataset
dataset = flowers.get_split('validation', DATA_DIR)
# Create a provider
provider = slim.dataset_data_provider.DatasetDataProvider(dataset, num_readers=2)
# Get image and label tensors from the provider
[image, label] = provider.get(['image', 'label'])
print(image.shape) # (?, ?, 3)

# Crop or pad the picture to a uniform size
distorted_image = tf.image.resize_image_with_crop_or_pad(image, height, width)  # crop if larger, zero-pad if smaller
print(distorted_image.shape) # (200, 200, 3)
################################################
images, labels = tf.train.batch([distorted_image, label], batch_size=batch_size)
print(images.shape) # (4, 200, 200, 3)

x_smalls = tf.image.resize_images(images, (np.int32(height / 2), np.int32(width / 2)))  # Downscale by 2x in each dimension
x_smalls2 = x_smalls / 255.0
print(x_smalls2.shape) # (4, 100, 100, 3)
# Baseline restorations with standard interpolation methods
x_nearests = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.NEAREST_NEIGHBOR)
x_bilins = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.BILINEAR)
x_bicubics = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.BICUBIC)

net = slim.conv2d(x_smalls2, 64, 5, activation_fn=tf.nn.tanh)
print(net.shape) # (4, 100, 100, 64)
net = slim.conv2d(net, 32, 3, activation_fn=tf.nn.tanh)
print(net.shape) # (4, 100, 100, 32)
net = slim.conv2d(net, 12, 3, activation_fn=None)  # 12 = 2*2*3: 2x upscaling factor for 3 channels
print(net.shape) # (4, 100, 100, 12)
y_predt = tf.depth_to_space(net, 2)
print(y_predt.shape) # (4, 200, 200, 3)

y_pred = y_predt * 255.0
y_pred = tf.maximum(y_pred, 0)
y_pred = tf.minimum(y_pred, 255)

dbatch = tf.concat([tf.cast(images, tf.float32), y_pred], 0)
print(dbatch.shape) # (8, 200, 200, 3)
cost = tf.reduce_mean(tf.pow(tf.cast(images, tf.float32) / 255.0 - y_predt, 2))
optimizer = tf.train.AdamOptimizer(0.000001).minimize(cost)
training_epochs = 20000
display_step = 200

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# Start the input queue runners
tf.train.start_queue_runners(sess=sess)

# Training loop
for epoch in range(training_epochs):

    _, c = sess.run([optimizer, cost])

    # Periodically report training metrics
    if epoch % display_step == 0:
        d_batch = dbatch.eval()
        mse, psnr = batch_mse_psnr(d_batch)
        ypsnr = batch_y_psnr(d_batch)
        ssim = batch_ssim(d_batch)
        print("Epoch:", '%04d' % (epoch + 1),
              "cost=", "{:.9f}".format(c), "psnr", psnr, "ypsnr", ypsnr, "ssim", ssim)

print("complete!")

imagesv, label_batch, x_smallv, x_nearestv, x_bilinv, x_bicubicv, y_predv = sess.run(
    [images, labels, x_smalls, x_nearests, x_bilins, x_bicubics, y_pred])
print("primary", np.shape(imagesv), "Zoomed", np.shape(x_smallv), label_batch)

### Display results
plt.figure(figsize=(20, 10))

showresult(161, "org", imagesv, imagesv, False)
showresult(162, "small/4", imagesv, x_smallv, False)
showresult(163, "near", imagesv, x_nearestv)
showresult(164, "biline", imagesv, x_bilinv)
showresult(165, "bicubic", imagesv, x_bicubicv)
showresult(166, "pred", imagesv, y_predv)

plt.show()

 

 

The effect is not that good...

Topics: network