100 cases of deep learning - Generative Adversarial Network (GAN) handwritten digit generation | day 18

Posted by stueee on Sat, 15 Jan 2022 03:11:36 +0100

šŸ”± Hello, I'm šŸ‘‰ Classmate K. The 100 cases of deep learning series will be updated continuously. You are welcome to like šŸ‘, bookmark ā­, and follow šŸ‘€

This article uses a GAN to generate handwritten digits, focusing on understanding the structure of the GAN model and how to build it.

1, Preliminary work

šŸš€ My environment:

  • Language environment: Python 3.6.5
  • Editor: Jupyter Notebook
  • Deep learning environment: TensorFlow 2.4.1

šŸš€ Must-reads for deep learning beginners:

  1. Deep learning for beginners | Chapter 1: Configuring the deep learning environment
  2. Deep learning for beginners | Chapter 2: Using the editor - Jupyter Notebook

šŸš€ Previous highlights - convolutional neural networks:

  1. 100 cases of deep learning - convolutional neural network (CNN) MNIST handwritten digit recognition | day 1
  2. 100 cases of deep learning - convolutional neural network (CNN) color image classification | day 2
  3. 100 cases of deep learning - convolutional neural network (CNN) clothing image classification | day 3
  4. 100 cases of deep learning - convolutional neural network (CNN) flower recognition | day 4
  5. 100 cases of deep learning - convolutional neural network (CNN) weather recognition | day 5
  6. 100 cases of deep learning - convolutional neural network (VGG-16) recognizing the One Piece Straw Hat crew | day 6
  7. 100 cases of deep learning - convolutional neural network (VGG-19) recognizing characters from Ling Cage | day 7
  8. 100 cases of deep learning - convolutional neural network (ResNet-50) bird recognition | day 8
  9. 100 cases of deep learning - convolutional neural network (AlexNet) hands-on tutorial | day 11
  10. 100 cases of deep learning - convolutional neural network (CNN) CAPTCHA recognition | day 12
  11. 100 cases of deep learning - convolutional neural network (Inception V3) sign language recognition | day 13
  12. 100 cases of deep learning - convolutional neural network (Inception-ResNet-v2) traffic sign recognition | day 14
  13. 100 cases of deep learning - convolutional neural network (CNN) license plate recognition | day 15
  14. 100 cases of deep learning - convolutional neural network (CNN) recognizing Pokémon's Ash and friends | day 16
  15. 100 cases of deep learning - convolutional neural network (CNN) attention detection | day 17

šŸš€ Previous highlights - recurrent neural networks:

  1. 100 cases of deep learning - recurrent neural network (RNN) stock prediction | day 9
  2. 100 cases of deep learning - recurrent neural network (LSTM) stock prediction | day 10

šŸš€ This article is selected from the column: 100 cases of deep learning

1. Set up the GPU

If you are using a CPU, you can comment out this part of the code.

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  # Allocate GPU memory on demand instead of all at once
    tf.config.set_visible_devices([gpus[0]],"GPU")

# Print the GPU information to confirm that the GPU is available
print(gpus)
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D

import matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib

2. Define training parameters

img_shape  = (28, 28, 1)
latent_dim = 200

2, What is a generative adversarial network

1. Brief introduction

A Generative Adversarial Network (GAN) consists of a generator and a discriminator. The two models keep learning and improving through adversarial training.

  • Generator: generates data (mostly images) to "fool" the discriminator.

  • Discriminator: judges whether an image is real or machine-generated, in order to catch the "fake data" produced by the generator.

2. Applications

GANs have a wide range of applications, including image synthesis, style transfer, photo restoration, photo editing, data augmentation, and more.

1) Style transfer

Image style transfer applies the style of image A to image B to produce a new image.

2) Image generation

GANs can generate not only faces but also other types of images, such as comic and anime characters.

3, Network structure

In short, the generator produces handwritten digit images and the discriminator judges whether those images are real or fake. The two compete against each other and keep improving in the process, until the generator can produce images realistic enough that the discriminator can no longer tell real from fake. The structure diagram is as follows:

GAN steps:

  • 1. The generator receives a random noise vector and returns a generated image.
  • 2. The generated digit image is sent to the discriminator together with digit images from the real dataset.
  • 3. The discriminator receives both real and fake images and returns a probability between 0 and 1, where 1 means real and 0 means fake.
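For reference, this adversarial game is usually written as the minimax objective from the original GAN paper, where D is the discriminator, G the generator, x a real sample and z a noise vector:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

The discriminator tries to maximize this value (scoring real images close to 1 and fake images close to 0), while the generator tries to minimize it by making D(G(z)) approach 1.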

4, Build the generator

def build_generator():
    # ======================================= #
    #     Generator: takes a random noise vector as input and generates an image
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),               # LeakyReLU activation (allows a small negative slope)
        layers.BatchNormalization(momentum=0.8),   # Batch normalization
        
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)
    ])

    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)

    return Model(noise, img)
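As an optional sanity check (not part of the original code; the variable names below are purely illustrative), you can build the generator once and confirm that a batch of noise vectors is mapped to images of the expected shape:

# Illustrative check: noise of shape (4, latent_dim) -> images of shape (4, 28, 28, 1)
gen_check  = build_generator()
test_noise = np.random.normal(0, 1, (4, latent_dim))
test_imgs  = gen_check.predict(test_noise)
print(test_imgs.shape)  # Expected: (4, 28, 28, 1), pixel values in [-1, 1] because of the tanh output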

5, Build the discriminator

def build_discriminator():
    # ===================================== #
    #   Discriminator: judges whether the input image is real or fake
    # ===================================== #
    model = Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')
    ])

    img = layers.Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)

  • Discriminator training: the discriminator improves by learning to correctly classify real and generated input images.
  • Generator training: the generator improves by trying to make the discriminator classify its generated images as real.
# Create discriminator
discriminator = build_discriminator()
# Define optimizer
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

# Create generator
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)

# Freeze the discriminator's weights inside the combined model,
# so that training the combined model only updates the generator
discriminator.trainable = False

# The discriminator scores the generated (fake) image
validity = discriminator(img)

# Combined model: noise -> generator -> discriminator -> realness score
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
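Similarly, an optional check (illustrative only, not from the original post) confirms that the combined model maps a batch of noise vectors directly to one realness score per sample:

# Illustrative check: noise -> generator -> discriminator -> probability of being "real"
test_noise = np.random.normal(0, 1, (4, latent_dim))
scores = combined.predict(test_noise)
print(scores.shape)  # Expected: (4, 1), each value in (0, 1) from the sigmoid output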

6, Train the model

1. Save sample images

def sample_images(epoch):
    """
    Save a 4x4 grid of images generated from random noise for the given epoch
    """
    row, col = 4, 4
    noise = np.random.normal(0, 1, (row*col, latent_dim))
    gen_imgs = generator.predict(noise)

    fig, axs = plt.subplots(row, col)
    cnt = 0
    for i in range(row):
        for j in range(col):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    os.makedirs("images", exist_ok=True)  # Make sure the output directory exists
    fig.savefig("images/%05d.png" % epoch)
    plt.close()

2. Train the model

train_on_batch: this function takes a single batch of data, performs backpropagation, and then updates the model parameters. The batch can be of any size, i.e. no explicit batch size needs to be specified; it provides fine-grained control over training.
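As a small illustration (the arrays below are made up for demonstration and are not part of the tutorial), a single call performs one gradient step and returns the loss together with any metrics specified at compile time:

# Illustrative only: one training step on a hand-made batch of "fake" images
x_batch = np.random.normal(0, 1, (32, 28, 28, 1))  # 32 random images
y_batch = np.zeros((32, 1))                        # label them all as fake (0)
d_loss, d_acc = discriminator.train_on_batch(x_batch, y_batch)  # returns [loss, accuracy], as compiled with metrics=['accuracy']
print(d_loss, d_acc)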

def train(epochs, batch_size=128, sample_interval=50):
    # Load data
    (train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()

    # Normalize the images to the range [-1, 1] to match the generator's tanh output
    train_images = (train_images - 127.5) / 127.5
    # Add a channel dimension: (60000, 28, 28) -> (60000, 28, 28, 1)
    train_images = np.expand_dims(train_images, axis=3)

    # Create labels: 1 for real images, 0 for generated (fake) images
    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    # Training loop
    for epoch in range(epochs): 

        # Randomly sample batch_size real images from the training set
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx]      
        
        # Generate random noise vectors
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # The generator turns the noise into images; gen_imgs has shape (batch_size, 28, 28, 1)
        gen_imgs = generator.predict(noise)
        
        # Train the discriminator on real and generated images
        d_loss_true = discriminator.train_on_batch(imgs, true)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Average the losses over the real and fake batches
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)

        # Train the generator through the combined model (the discriminator is frozen there)
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, true)
        
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # Periodically save sample images
        if epoch % sample_interval == 0:
            sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)
0 [D loss: 0.587824, acc.: 67.77%] [G loss: 0.634870]
1 [D loss: 0.387015, acc.: 74.22%] [G loss: 0.541133]
2 [D loss: 0.380705, acc.: 63.67%] [G loss: 0.455188]
3 [D loss: 0.408720, acc.: 56.25%] [G loss: 0.405431]
4 [D loss: 0.445802, acc.: 52.34%] [G loss: 0.343866]
......
176 [D loss: 0.394246, acc.: 66.41%] [G loss: 0.648134]
177 [D loss: 0.393966, acc.: 66.21%] [G loss: 0.640118]
178 [D loss: 0.402815, acc.: 65.62%] [G loss: 0.641665]
179 [D loss: 0.404573, acc.: 65.82%] [G loss: 0.647686]
180 [D loss: 0.394707, acc.: 67.19%] [G loss: 0.631329]
......

7, Generate a dynamic graph (GIF)

If the error ModuleNotFoundError: No module named 'imageio' is reported, you can install the imageio library with: pip install imageio.

import imageio

def compose_gif():
    # Directory containing the saved sample images
    data_dir = "F:/jupyter notebook/DL-100-days/code/images"
    data_dir = pathlib.Path(data_dir)
    paths    = sorted(data_dir.glob('*'))  # Sort so the frames appear in training order
    
    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave("test.gif", gif_images, fps=2)
    
compose_gif()
F:\jupyter notebook\DL-100-days\code\images\00000.png
F:\jupyter notebook\DL-100-days\code\images\00200.png
F:\jupyter notebook\DL-100-days\code\images\00400.png
F:\jupyter notebook\DL-100-days\code\images\00600.png
F:\jupyter notebook\DL-100-days\code\images\00800.png
F:\jupyter notebook\DL-100-days\code\images\01000.png
.....

Display of image generation process (about 50s):


šŸš€ From column: 100 cases of deep learning

To be continued ~

The series is continuously updated. Likes šŸ‘, bookmarks ā­, and follows šŸ‘€ are welcome.

  • Like šŸ‘: your likes give me the motivation to keep updating
  • Bookmark ā­: so you can come back to the articles at any time
  • Follow šŸ‘€: follow me to be the first to receive new articles

Topics: TensorFlow Deep Learning GAN