100 cases of deep learning - generative adversarial network (GAN) handwritten digit generation | day 18

Posted by stueee on Sat, 15 Jan 2022 03:11:36 +0100

This article uses a GAN model to generate handwritten digits, with a focus on understanding the structure of a GAN and how to build one.

1, Preliminary work

🚀 My environment:

  • Language environment: Python 3.6.5
  • Editor: Jupyter Notebook
  • Deep learning environment: TensorFlow 2.4.1

1. Set GPU

If you are using a CPU, you can comment out this part of the code.

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  # Allocate GPU memory on demand instead of all at once

# Print the graphics card information and confirm that the GPU is available
print(gpus)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D

import matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib

2. Define training parameters

img_shape  = (28, 28, 1)
latent_dim = 200
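
As a quick illustration of what these two parameters control, the sketch below draws a batch of latent vectors the same way the training loop does later, and works out how many pixel values the generator has to emit per image. The batch size of 16 and the names `n_samples` and `n_pixels` are chosen here for illustration only.

```python
import numpy as np

img_shape = (28, 28, 1)   # height, width, channels of one MNIST digit
latent_dim = 200          # length of the random vector fed to the generator

n_samples = 16            # illustrative batch size (assumption)

# One row per image to generate, one column per latent dimension.
noise = np.random.normal(0, 1, (n_samples, latent_dim))

# The generator's final Dense layer needs one output unit per pixel.
n_pixels = int(np.prod(img_shape))

print(noise.shape)   # (16, 200)
print(n_pixels)      # 784
```

A larger `latent_dim` gives the generator more random input to work with, while `img_shape` fixes the size of its output.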

2, What is a generative adversarial network

1. Brief introduction

A generative adversarial network (GAN) consists of a generator and a discriminator. The two models keep learning and improving through adversarial training.

  • Generator: generates data (mostly images) to "fool" the discriminator.

  • Discriminator: judges whether an image is real or machine generated, in order to pick out the "fake data" produced by the generator.

2. Application fields

GANs have a wide range of applications, including image synthesis, style transfer, photo restoration, photo editing, data augmentation and so on.

1) Style transfer

Image style transfer applies the style of image A to image B to produce a new image.

2) Image generation

GANs can generate not only faces but also other kinds of pictures, such as comic characters.

3, Network structure

In short, the generator produces handwritten digit images and the discriminator judges whether an image is real. The two compete with each other and keep improving through that competition, until the generator can produce pictures realistic enough that the discriminator can no longer tell real from fake. The structure diagram is as follows:

GAN steps:

  • 1. The Generator receives random numbers and returns a generated image.
  • 2. The generated digit image is sent to the Discriminator together with digit images from the real data set.
  • 3. The Discriminator receives the real and fake images and returns a probability between 0 and 1 for each, where 1 means real and 0 means fake.
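
The three steps above can be sketched with plain NumPy. This is only a toy forward pass under stated assumptions: `toy_generator` and `toy_discriminator` are hypothetical, untrained stand-ins with random weights, and the "real" images are random placeholders rather than MNIST digits. It shows the shapes and value ranges involved, not the Keras models built later in this article.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, n_pixels, batch = 200, 28 * 28, 4

# Hypothetical, untrained stand-ins for the two networks (random weights).
W_g = rng.normal(0, 0.05, (latent_dim, n_pixels))
w_d = rng.normal(0, 0.05, (n_pixels, 1))

def toy_generator(z):
    # Step 1: noise in, image-like vector in [-1, 1] out (tanh).
    return np.tanh(z @ W_g)

def toy_discriminator(x):
    # Step 3: image in, probability of "real" in (0, 1) out (sigmoid).
    return 1.0 / (1.0 + np.exp(-(x @ w_d)))

noise = rng.normal(0, 1, (batch, latent_dim))
fake_imgs = toy_generator(noise)
real_imgs = rng.uniform(-1, 1, (batch, n_pixels))  # placeholder for data set images

# Step 2: real and generated images go through the same discriminator.
scores = toy_discriminator(np.concatenate([real_imgs, fake_imgs]))
print(scores.shape)  # (8, 1)
```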

4, Build generator

def build_generator():
    # ======================================= #
    #   Generator: takes a vector of random noise and generates a picture
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),               # Leaky ReLU activation
        layers.BatchNormalization(momentum=0.8),   # Batch normalization
        layers.Dense(np.prod(img_shape), activation='tanh'),  # One output per pixel, in [-1, 1]
        layers.Reshape(img_shape)                  # Reshape the flat vector back to (28, 28, 1)
    ])

    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)

    return Model(noise, img)

5, Build discriminator

def build_discriminator():
    # ===================================== #
    #   Discriminator: judges whether the input picture is real or generated
    # ===================================== #
    model = Sequential([
        layers.Flatten(input_shape=img_shape),   # Flatten the (28, 28, 1) picture into a vector
        layers.Dense(1, activation='sigmoid')    # Probability that the picture is real
    ])

    img = layers.Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)

  • Discriminator training principle: it is shown real and generated pictures with their correct labels and improves by learning to tell them apart.
  • Generator training principle: the pictures it generates are scored by the discriminator, and it improves by pushing those scores closer to "real".
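
To make the two training principles concrete, here is a small numeric sketch of the binary cross-entropy loss with the labels each side uses. The discriminator scores are made-up numbers and `binary_crossentropy` is a hand-written helper for illustration, not the Keras implementation.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean binary cross-entropy; clipping avoids log(0).
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred)
                           + (1 - y_true) * np.log(1 - y_pred))))

# Made-up discriminator outputs for illustration.
d_on_real = np.array([0.9, 0.8])   # scores for real images
d_on_fake = np.array([0.3, 0.1])   # scores for generated images

# Discriminator update: real images labelled 1, generated images labelled 0.
d_loss = 0.5 * (binary_crossentropy(np.ones(2), d_on_real)
                + binary_crossentropy(np.zeros(2), d_on_fake))

# Generator update: the same generated images, but labelled 1,
# so the loss is small only when the discriminator is fooled.
g_loss = binary_crossentropy(np.ones(2), d_on_fake)

print(d_loss < g_loss)  # True
```

With these numbers the discriminator is doing well, so its loss is low and the generator's loss is high. As the generator improves, the balance shifts the other way.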
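# Create discriminator
discriminator = build_discriminator()
# Define optimizer
optimizer = tf.keras.optimizers.Adam(1e-4)
# Compile the discriminator; the accuracy metric is needed for the training log
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# Create generator
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)

# Freeze the discriminator inside the combined model so that only the generator is updated there
discriminator.trainable = False

# Let the discriminator score the generated (fake) picture
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)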

6, Training model

1. Save sample pictures

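def sample_images(epoch):
    """Save a 4x4 grid of generated sample pictures"""
    row, col = 4, 4
    noise = np.random.normal(0, 1, (row*col, latent_dim))
    gen_imgs = generator.predict(noise)

    fig, axs = plt.subplots(row, col)
    cnt = 0
    for i in range(row):
        for j in range(col):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    os.makedirs("images", exist_ok=True)   # Make sure the output folder exists
    fig.savefig("images/%05d.png" % epoch)
    plt.close(fig)                         # Release the figure to avoid memory build-up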

2. Training model

train_on_batch: this function takes a single batch of data, runs one forward and backward pass, and updates the model parameters once. The batch can be of any size, so no fixed batch size has to be declared in advance. It gives fine-grained control over the training loop.
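
To show the idea of "one batch in, one parameter update out" without relying on Keras internals, here is a conceptual sketch in plain NumPy. `toy_train_on_batch` is a hypothetical helper written for this illustration: it performs one gradient step of logistic regression and is not how Keras implements `train_on_batch`.

```python
import numpy as np

def toy_train_on_batch(w, b, x, y, lr=0.1):
    """One gradient step of logistic regression on a single batch.

    Hypothetical helper for illustration; not part of Keras.
    Returns the updated parameters and the loss measured before the update.
    """
    p = 1.0 / (1.0 + np.exp(-(x @ w + b)))          # forward pass
    p_safe = np.clip(p, 1e-7, 1 - 1e-7)
    loss = float(np.mean(-(y * np.log(p_safe) + (1 - y) * np.log(1 - p_safe))))
    grad_w = x.T @ (p - y) / len(x)                  # backward pass
    grad_b = float(np.mean(p - y))
    return w - lr * grad_w, b - lr * grad_b, loss    # parameter update

rng = np.random.default_rng(0)
w, b = np.zeros(3), 0.0
losses = []
for batch_size in (8, 32, 5):                        # batch size may change per call
    x = rng.normal(size=(batch_size, 3))
    y = (x[:, 0] > 0).astype(float)                  # label depends on the first feature
    w, b, loss = toy_train_on_batch(w, b, x, y)
    losses.append(loss)

print(len(losses))  # 3
```

Each call handles a batch of a different size, which mirrors why the training loop below can pass any `batch_size` it likes.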

def train(epochs, batch_size=128, sample_interval=50):
    # Load data
    (train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()

    # Scale the pictures to the [-1, 1] range
    train_images = (train_images - 127.5) / 127.5
    # Add a channel dimension: (60000, 28, 28) -> (60000, 28, 28, 1)
    train_images = np.expand_dims(train_images, axis=3)

    # create label
    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    # Cycle training
    for epoch in range(epochs): 

        # Randomly select batch_size pictures
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx]      
        # Generate noise
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # The generator turns the noise into pictures; gen_imgs has shape (batch_size, 28, 28, 1)
        gen_imgs = generator.predict(noise)
        # Training discriminator 
        d_loss_true = discriminator.train_on_batch(imgs, true)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Return loss value
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)

        # Training generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, true)
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # Save sample picture
        if epoch % sample_interval == 0:
            sample_images(epoch)

train(epochs=30000, batch_size=256, sample_interval=200)

0 [D loss: 0.587824, acc.: 67.77%] [G loss: 0.634870]
1 [D loss: 0.387015, acc.: 74.22%] [G loss: 0.541133]
2 [D loss: 0.380705, acc.: 63.67%] [G loss: 0.455188]
3 [D loss: 0.408720, acc.: 56.25%] [G loss: 0.405431]
4 [D loss: 0.445802, acc.: 52.34%] [G loss: 0.343866]
176 [D loss: 0.394246, acc.: 66.41%] [G loss: 0.648134]
177 [D loss: 0.393966, acc.: 66.21%] [G loss: 0.640118]
178 [D loss: 0.402815, acc.: 65.62%] [G loss: 0.641665]
179 [D loss: 0.404573, acc.: 65.82%] [G loss: 0.647686]
180 [D loss: 0.394707, acc.: 67.19%] [G loss: 0.631329]
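
The log format implies that the discriminator returns a loss and an accuracy for each batch, and that each line reports their average over the real and the fake batch. The sketch below reproduces that averaging and formatting step with made-up numbers. The values are illustrative only and are not taken from the run above.

```python
import numpy as np

# Made-up [loss, accuracy] pairs, shaped like the return value of
# train_on_batch on a model compiled with metrics=['accuracy'].
d_loss_true = [0.40, 0.75]   # discriminator on real images
d_loss_fake = [0.60, 0.50]   # discriminator on generated images

# Element-wise mean: index 0 is the averaged loss, index 1 the averaged accuracy.
d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)

line = "%d [D loss: %f, acc.: %.2f%%]" % (0, d_loss[0], 100 * d_loss[1])
print(line)  # 0 [D loss: 0.500000, acc.: 62.50%]
```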

7, Generate an animated GIF

If you get the error ModuleNotFoundError: No module named 'imageio', install the library with: pip install imageio

import imageio

def compose_gif():
    # Picture address
    data_dir = "F:/jupyter notebook/DL-100-days/code/images"
    data_dir = pathlib.Path(data_dir)
    paths    = sorted(data_dir.glob('*'))   # Sort so the frames follow the training order
    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    # Write the frames to a GIF (the file name here is just an example)
    imageio.mimsave("gan_training.gif", gif_images)

compose_gif()

F:\jupyter notebook\DL-100-days\code\images\00000.png
F:\jupyter notebook\DL-100-days\code\images\00200.png
F:\jupyter notebook\DL-100-days\code\images\00400.png
F:\jupyter notebook\DL-100-days\code\images\00600.png
F:\jupyter notebook\DL-100-days\code\images\00800.png
F:\jupyter notebook\DL-100-days\code\images\01000.png

Display of image generation process (about 50s):

