
PyTorch Project to Build a GAN Model on MNIST Dataset
This project compares a Vanilla GAN and a WGAN on the task of generating realistic MNIST digit images. It evaluates both models with quantitative metrics, the Fréchet Inception Distance (FID) and the Inception Score, and examines how well each architecture produces high-quality images.
Project Overview
In this project, we enter the world of GANs through a comparative experiment between two models, Vanilla GAN and WGAN, generating realistic handwritten digits from the MNIST dataset and assessing which model does it better!
We begin by training both GANs on the real images, after which each generator produces fake images. Then we do some fun visualization, placing real and fake images side by side to judge how close they are.
After that, we compute the FID, or Fréchet Inception Distance, which measures how similar the generated images are to the real ones (lower is better). We also calculate the Inception Score, which measures image diversity and quality (higher is better).
By the end, you will understand how the Vanilla GAN and WGAN compare at generating realistic and diverse images. Quite exciting, right?
Prerequisites
- Good knowledge of GANs.
- Experience with Python and its libraries, especially PyTorch.
- An understanding of neural networks, specifically the generators and discriminators used in GANs.
- Some prior understanding of image transforms such as resizing and normalization.
- Familiarity with performance metrics like FID (Fréchet Inception Distance) and Inception Score.
- Experience with visualization tools such as matplotlib and the torchvision utilities.
Approach
The approach begins with loading the MNIST dataset and preprocessing the images for use in GANs. We then initialize and train two models, Vanilla GAN and WGAN, each with a generator and a discriminator. Each generator learns to create fake images from random noise vectors, while its discriminator learns to tell real images from fake ones. After training, we generate fake images from both models and compare them visually. Then we measure the quality of the generated images using the FID score, which captures how similar the generated images are to the real ones, and the Inception Score, which indicates how diverse and recognizable the generated images are. This lets us compare the Vanilla GAN with the WGAN on how realistic their images are and determine which model gives better results.
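The two models differ mainly in their training objectives. As a rough sketch, and not the project's own training code, the losses can be written as below. The function names are hypothetical. The vanilla discriminator is assumed to output probabilities (for example through a sigmoid), while the WGAN discriminator, usually called a critic, is assumed to output unbounded scores.

```python
import torch
import torch.nn.functional as F


def vanilla_d_loss(d_real_prob, d_fake_prob):
    """Vanilla GAN discriminator loss: binary cross-entropy on real/fake labels."""
    real_loss = F.binary_cross_entropy(d_real_prob, torch.ones_like(d_real_prob))
    fake_loss = F.binary_cross_entropy(d_fake_prob, torch.zeros_like(d_fake_prob))
    return real_loss + fake_loss


def vanilla_g_loss(d_fake_prob):
    """Non-saturating generator loss: push D(G(z)) toward 1."""
    return F.binary_cross_entropy(d_fake_prob, torch.ones_like(d_fake_prob))


def wgan_critic_loss(c_real_score, c_fake_score):
    """WGAN critic loss: minimize E[C(fake)] - E[C(real)]."""
    return c_fake_score.mean() - c_real_score.mean()


def wgan_g_loss(c_fake_score):
    """WGAN generator loss: raise the critic's score on fake images."""
    return -c_fake_score.mean()
```

The WGAN objective is only meaningful when the critic is constrained to be Lipschitz; the original WGAN formulation does this by clipping the critic's weights to a small range after each update.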
Workflow
The workflow comprises the following steps:
- Load and Preprocess the MNIST Training Images.
- Initialize and Train the Vanilla GAN and WGAN Models.
- Generate Fake Images with the Vanilla GAN and WGAN Generators.
- Display Real vs. Fake Images Side by Side for Comparison.
- Compute the FID Score between Real and Fake Images.
- Calculate the Inception Score to Measure the Quality and Diversity of the Generated Images.
- Compare the Results of Both Models to Determine Which One Performs Best.
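To make the Inception Score step concrete, here is a minimal sketch of the score itself, exp(E_x[KL(p(y|x) || p(y))]), computed from an array of class probabilities. In practice those probabilities come from a pretrained Inception v3 classifier applied to the generated images; here a plain NumPy array stands in for them, and the function name and `eps` parameter are assumptions made for illustration.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """Inception Score from an (N, num_classes) array of class probabilities."""
    probs = np.asarray(probs, dtype=np.float64)
    marginal = probs.mean(axis=0, keepdims=True)  # p(y), averaged over all images
    # KL divergence between each image's p(y|x) and the marginal p(y)
    kl = (probs * (np.log(probs + eps) - np.log(marginal + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))
```

The score ranges from 1, when every image gets the same prediction distribution, up to the number of classes, when each image is classified with full confidence and the classes are covered evenly.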
Methodology
- Use random noise vectors as input to the generators to create fake images.
- Train the discriminators to tell the difference between real and fake images.
- Preprocess input images by resizing, cropping, and normalization.
- Assess model performance using FID and Inception Scores to obtain a quantitative measure of image quality.
- Evaluate both GANs visually by arranging generated images in grids and comparing them.
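The FID mentioned above fits a Gaussian to the feature vectors of the real images and another to those of the fake images, then computes ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^(1/2)). The sketch below assumes the features have already been extracted (normally from a pretrained Inception v3 network) and are passed in as NumPy arrays; the function name is hypothetical.

```python
import numpy as np
from scipy.linalg import sqrtm


def frechet_distance(feats_real, feats_fake):
    """Fréchet distance between Gaussians fitted to two (N, D) feature arrays."""
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_f = np.cov(feats_fake, rowvar=False)
    covmean = sqrtm(sigma_r @ sigma_f)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts caused by numerical error
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(sigma_r + sigma_f - 2.0 * covmean))
```

Identical feature sets give a distance of about zero, and the value grows as the means or covariances of the two sets drift apart, which is why a lower FID indicates generated images closer to the real ones.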
Data Collection and Preparation
Data Collection
- This project utilizes the MNIST dataset, which contains handwritten digits ranging from 0 to 9.
- The dataset is available for download via the torchvision.datasets library.
- The dataset comprises 60,000 training and 10,000 test images, each 28x28 pixels.
- All images are grayscale, and the dataset offers clean digit images well suited to training image generation models.
- The dataset is preprocessed by resizing, normalizing, and converting into tensors compatible with GAN models.
Data Preparation Workflow
- Load the MNIST dataset with torchvision.datasets.MNIST.
- Apply preprocessing transformations such as resizing, normalization, and grayscale-to-RGB conversion.
- Convert images to PyTorch tensors for use with neural networks.
- Use a DataLoader to load images in batches for efficient training.
- Move images to the appropriate device (CPU or GPU) for processing by the model.
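MNIST images are single-channel and 28x28, whereas the Inception v3 network behind FID and the Inception Score expects three-channel inputs, 299x299 in torchvision's implementation. The resizing, grayscale-to-RGB, and device steps above might look like the following sketch. The helper name `to_inception_input` is hypothetical, the input is assumed to be normalized to [-1, 1], and the ImageNet mean/std normalization that pretrained weights usually expect is left out for brevity.

```python
import torch
import torch.nn.functional as F

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_inception_input(batch, size=299):
    """Turn (N, 1, 28, 28) images in [-1, 1] into (N, 3, size, size) images in [0, 1]."""
    batch = batch.to(device)
    batch = (batch + 1.0) / 2.0        # undo Normalize(mean=0.5, std=0.5)
    batch = batch.repeat(1, 3, 1, 1)   # copy the gray channel into R, G, and B
    batch = F.interpolate(batch, size=(size, size), mode="bilinear", align_corners=False)
    return batch.clamp(0.0, 1.0)
```

Repeating the gray channel keeps the digit unchanged in appearance while giving the network the channel count it was built for.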
Code Explanation
STEP 1:
This code installs the required Python libraries: torch for deep learning, torchvision for computer vision tasks, matplotlib for plotting and graphing, and scipy, whose matrix square root function is imported in the next cell. This prepares your environment for machine learning and data visualization.
!pip install torch torchvision matplotlib scipy
This code imports all the needed libraries for deep learning and computer vision tasks. Libraries include PyTorch (torch), torchvision to build datasets and transform images, optimization and neural network libraries, and libraries for plotting. This code also imports all the tools for working with images, like PIL and matplotlib, to visualize results.
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.linalg import sqrtm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable  # legacy wrapper; plain tensors suffice since PyTorch 0.4
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision import datasets, models
This code sets up the MNIST dataset for training. It converts input images to tensors and normalizes them. The dataset is loaded in batches of size 64 and shuffled for training using a DataLoader.
# MNIST Dataset Setup
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
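As a quick check of what the normalization does: ToTensor scales pixels to [0, 1], and Normalize with mean 0.5 and std 0.5 then maps them to [-1, 1], a range commonly paired with a tanh output layer in a generator. The small sketch below reproduces that arithmetic on a few values and shows the inverse needed before displaying images; the helper names are hypothetical.

```python
import torch


def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize for a single channel."""
    return (x - mean) / std


def denormalize(x, mean=0.5, std=0.5):
    """Invert the normalization so images are back in [0, 1] for display."""
    return x * std + mean


pixels = torch.tensor([0.0, 0.5, 1.0])  # sample values as produced by ToTensor
scaled = normalize(pixels)              # values now span [-1, 1]
restored = denormalize(scaled)          # back to the original [0, 1] values
```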
This function visualizes a set of training MNIST images. It selects a batch of training images and arranges them in a grid for display using matplotlib. The number of images to show is controlled by the num_images argument, whose default value is 10.
def visualize_mnist_data(train_loader, num_images=10):
    """Visualize real MNIST images"""
    data_iter = iter(train_loader)
    images, labels = next(data_iter)
    grid = vutils.make_grid(images[:num_images], nrow=5, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0))
    plt.title('Real MNIST Images')
    plt.axis('off')
    plt.show()

# Visualize MNIST data
visualize_mnist_data(train_loader)
STEP 2:
This code defines the generator and discriminator of the Vanilla GAN. The Generator receives a random noise vector (z) and produces a 28x28 image using fully connected layers, batch normalization, and LeakyReLU activations.
The Discriminator determines whether an input image is real or fake by flattening the image and passing it through several fully connected layers with LeakyReLU activations and dropout for regularization. It finally outputs a probability indicating whether the image is real or fake.