A Generative Adversarial Network (GAN) is a powerful tool that creates meaningful, realistic images from noise. However, these networks are notorious for their extremely slow training. Recently, several methods have been proposed to speed up GAN training. This software implements a GAN equipped with such methods. Consider the realistic-looking MNIST images generated by ShortGAN after only 100 training epochs:
ShortGAN is a perfect fit for educational purposes. Its concise, intuitively clear code reveals the complex machinery of Generative Adversarial Networks in a way that is accessible to inexperienced users. It does not require a GPU (though it helps if you have one), and you can run it on your desktop and get the first results in 10-40 minutes!
- Python 3
- CPU or GPU
The neural network consists of:
- A Generator network that generates synthetic images of digits.
- A Discriminator network that classifies images as real or fake.
- Pretraining the generator on real MNIST samples to jump-start Generator training.
- Training the GAN where the generator and discriminator compete against each other.
The Generator creates synthetic images from random noise vectors. Here's how its layers function:
- Input Layer: Accepts a noise vector of size `noise_dim` (e.g., 128).
- Fully Connected Layer 1: A linear transformation followed by a ReLU activation. This projects the noise into a higher-dimensional space (256 units).
- Convolutional Block:
  - `Unflatten`: Reshapes the data into a feature map of size (1, 16, 16) to work with convolutional layers.
  - `Conv2d` layers: Add spatial context to the feature maps, progressively learning localized patterns.
  - `BatchNorm2d`: Normalizes activations, speeding up convergence.
  - `ConvTranspose2d`: Upsamples the feature maps, reconstructing a higher-resolution image.
- Fully Connected Layer 2:
- Maps the convolutional outputs to a 28x28 grayscale image.
- Includes a Tanh activation to normalize pixel values to [-1, 1], matching the MNIST normalization.
- Spectral Normalization: Ensures smoother gradients and prevents the generator from overproducing high-frequency details.
- Convolution and Transpose Convolution: Learn to downsample and upsample, preserving realistic image structure.
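The architecture described above can be sketched as a PyTorch module. The exact layer widths, kernel sizes, and where spectral normalization is applied are assumptions for illustration, not the project's exact code:

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Sketch of the generator: FC -> conv block -> FC, with Tanh output.

    Channel counts and kernel sizes are illustrative assumptions.
    """

    def __init__(self, noise_dim=128):
        super().__init__()
        # Fully Connected Layer 1: project noise into 256 units
        self.fc1 = nn.Sequential(nn.Linear(noise_dim, 256), nn.ReLU())
        # Convolutional block: reshape, add spatial context, normalize, upsample
        self.conv = nn.Sequential(
            nn.Unflatten(1, (1, 16, 16)),                        # 256 -> (1, 16, 16)
            nn.Conv2d(1, 32, kernel_size=3, padding=1),          # learn local patterns
            nn.BatchNorm2d(32),                                  # speed up convergence
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4,
                               stride=2, padding=1),             # upsample to 32x32
            nn.ReLU(),
        )
        # Fully Connected Layer 2: map to a 28x28 image; spectral norm for
        # smoother gradients, Tanh for pixel values in [-1, 1]
        self.fc2 = nn.Sequential(
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(16 * 32 * 32, 28 * 28)),
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.fc1(z)
        x = self.conv(x)
        return self.fc2(x).view(-1, 1, 28, 28)


g = Generator()
imgs = g(torch.randn(4, 128))
print(imgs.shape)  # torch.Size([4, 1, 28, 28])
```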
Pretraining is mandatory for ShortGAN convergence. In a nutshell: in a conventional GAN without pretraining, the generator learns simultaneously with the discriminator. The discriminator usually trains much faster than the generator and becomes overconfident: whatever the untrained generator creates, the discriminator classifies as close to 100% 'fake'. From the training perspective it looks like this: in whatever direction the generator starts its training, the discriminator signals that the direction is wrong. Thus, the generator makes very little progress over several epochs. Pretraining helps the generator break this vicious circle by teaching it to produce images vaguely resembling MNIST digits before full GAN training begins (no discriminator is involved at this stage!). Here's how it works:
- Objective: Minimize the difference between generated images and real MNIST samples using the Smooth L1 Loss (a regression loss function).
- Procedure:
- Feed random noise to the generator to produce fake images.
- Compare generated images to real images (from the MNIST dataset).
- Backpropagate the loss to update the generator’s parameters.
- Purpose:
- Initializes the generator close to the image manifold, avoiding mode collapse during GAN training.
- Speeds up convergence during adversarial training.
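The pretraining procedure above can be sketched as a short loop. The function name, optimizer choice, and hyperparameters here are assumptions:

```python
import torch
import torch.nn as nn


def pretrain_generator(generator, real_batches, noise_dim=128, lr=2e-4, epochs=1):
    """Sketch of generator pretraining with Smooth L1 regression loss.

    `real_batches` is any iterable of real-image tensors; the optimizer
    and learning rate are illustrative assumptions.
    """
    opt = torch.optim.Adam(generator.parameters(), lr=lr)
    criterion = nn.SmoothL1Loss()  # regression loss between fake and real images
    for _ in range(epochs):
        for real in real_batches:
            z = torch.randn(real.size(0), noise_dim)
            fake = generator(z)           # feed noise to produce fake images
            loss = criterion(fake, real)  # compare to real MNIST samples
            opt.zero_grad()
            loss.backward()               # backpropagate to update the generator
            opt.step()
    return generator
```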
Once the generator is pretrained, GAN training starts, alternately improving the Discriminator and the Generator:
- Discriminator Training:
- Input: A mix of real images (labeled as 1) and fake images from the generator (labeled as 0).
- Loss Function: Binary Cross-Entropy Loss (BCELoss).
- Goal: Maximize the discriminator's ability to distinguish real from fake samples.
- Generator Training:
- Input: Random noise vectors.
- Loss Function: BCELoss, where the generator aims to fool the discriminator into classifying fake images as real.
- Goal: Minimize the discriminator’s success rate, forcing the generator to produce more realistic images.
- Epoch Flow:
- Update the discriminator.
- Train the generator twice as often as the discriminator to counteract discriminator overfitting.
- Add Gaussian noise to real images to make the discriminator robust to noisy inputs.
- Smooth the labels for real samples (e.g., use values like 0.9 instead of 1.0) to prevent the discriminator from becoming overly confident.
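The epoch flow above (label smoothing, Gaussian input noise, two generator updates per discriminator update) can be sketched as a single training step. All names, noise levels, and label values are assumptions:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()


def train_step(G, D, opt_G, opt_D, real, noise_dim=128):
    """Sketch of one adversarial step; hyperparameters are assumptions."""
    batch = real.size(0)
    real_labels = torch.full((batch, 1), 0.9)  # smoothed labels instead of 1.0
    fake_labels = torch.zeros(batch, 1)

    # Discriminator update: Gaussian noise on real images for robustness
    noisy_real = real + 0.05 * torch.randn_like(real)
    fake = G(torch.randn(batch, noise_dim)).detach()
    d_loss = bce(D(noisy_real), real_labels) + bce(D(fake), fake_labels)
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # Generator updated twice per discriminator update
    for _ in range(2):
        fake = G(torch.randn(batch, noise_dim))
        g_loss = bce(D(fake), torch.ones(batch, 1))  # fool D into saying "real"
        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()
    return d_loss.item(), g_loss.item()
```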
The GAN uses Binary Cross-Entropy Loss:

- Discriminator Loss:

$$\mathcal{L}_D = -\frac{1}{N}\sum_{i=1}^{N}\left[\, y_i \log D(x_i) + (1 - y_i)\log\bigl(1 - D(x_i)\bigr) \right]$$

  - $y_i$: Real labels for real images ($\approx 1$) or fake images ($\approx 0$).
  - $D(x_i)$: Discriminator's confidence score for an input.

- Generator Loss:

$$\mathcal{L}_G = -\frac{1}{N}\sum_{i=1}^{N} \log D\bigl(G(z_i)\bigr)$$

  - $G(z_i)$: Generated images.

Typical loss values during healthy training:

- Discriminator: $\mathcal{L}_D \approx 0.5$–$1.5$ indicates balanced performance.
- Generator: $\mathcal{L}_G \approx 0.7$–$1.2$ suggests the generator is producing plausible samples.
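As a quick sanity check on these ranges: a maximally uncertain discriminator that outputs 0.5 for every input incurs a BCE loss of ln 2 ≈ 0.693 per sample, which sits inside the "balanced" ranges quoted above.

```python
import math
import torch
import torch.nn as nn

# BCE loss of a discriminator that is 50/50 on every input is ln 2,
# regardless of the labels.
bce = nn.BCELoss()
pred = torch.full((4, 1), 0.5)                    # maximally uncertain scores
labels = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
loss = bce(pred, labels)
print(round(loss.item(), 3))  # 0.693
```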
- Goodfellow, I., et al. (2014). Generative Adversarial Networks. arXiv:1406.2661
- Miyato, T., et al. (2018). Spectral Normalization for Generative Adversarial Networks. arXiv:1802.05957
Here is a `gan_env.yaml` file to set up the required Python environment using Anaconda. Run `conda env create -f gan_env.yaml` in your terminal to install the packages.
```yaml
name: gan_env
channels:
  - defaults
  - conda-forge
dependencies:
  - python=3.9
  - pytorch=1.13.0
  - torchvision=0.14.0
  - numpy
  - matplotlib
  - tqdm
  - pip
  - pip:
      - torch-summary
```
- Run the code to pretrain the generator.
- Train the GAN.
- Visualize the generated images at different stages of training.
Feel free to experiment with the hyperparameters or architecture to further improve the results!