Generative Adversarial Networks¶
Thus far, we have discussed several generative models. A generative model learns the structure of a set of input data. In doing so, the model learns to generate new data that it has never seen before in the training data.
A Generative Adversarial Network (GAN) is an example of a generative model.
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms
We designed this tutorial to run on Google Colab. To enable a GPU, open the "Runtime" menu, select "Change Runtime Type", and choose "GPU". Then run the next two lines of code to download the MNIST dataset.
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz
Then, let's begin by loading the MNIST training data.
mnist_data = datasets.MNIST('./', train=True, download=True, transform=transforms.ToTensor())
Q1. Model¶
A generative adversarial network (GAN) model consists of two models:
- A Generator network $G$ that takes in a latent embedding (usually random noise) and generates an image like those that exist in the training data
- A Discriminator network $D$ that tries to distinguish between real images from the training data, and fake images produced by the generator
In essence, we have two neural networks that are adversaries: the generator wants to fool the discriminator, and the discriminator wants to avoid being fooled.
Let's set up a simple generator and a discriminator to start:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(300, 100),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        out = self.model(x)
        return out.view(x.size(0))

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 300),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(300, 28*28),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.model(x).view(x.size(0), 1, 28, 28)
        return out
To make the GAN faster to train, both the Discriminator and Generator are fully-connected networks. You could instead use convolutional layers and the other techniques we discussed.
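As an illustration of the convolutional alternative, here is a minimal sketch of a convolutional discriminator. The class name `ConvDiscriminator`, the channel counts, and the kernel sizes are assumptions chosen for illustration; they are not part of this tutorial's model. The sketch accepts the same input and returns the same output shape as `Discriminator`.

```python
import torch
import torch.nn as nn

class ConvDiscriminator(nn.Module):
    """Hypothetical convolutional stand-in for the fully-connected Discriminator."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # 1 x 28 x 28 -> 16 x 14 x 14
            nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 16 x 14 x 14 -> 32 x 7 x 7
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # One logit per image, as in the fully-connected version
        self.classifier = nn.Linear(32 * 7 * 7, 1)

    def forward(self, x):
        h = self.features(x)
        h = h.view(x.size(0), -1)
        return self.classifier(h).view(x.size(0))

# Quick check on a batch of 5 random "images"
scores = ConvDiscriminator()(torch.rand(5, 1, 28, 28))
print(scores.shape)
```

Strided convolutions are used here instead of pooling, which is a common choice in GAN discriminators. Whether this trains better than the fully-connected version on MNIST is something to verify experimentally.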
One difference between these models and the previous models we've built is that we are using an nn.LeakyReLU activation. Leaky ReLU is a variation of the ReLU activation that lets some information through, even when its input is less than 0. The layer nn.LeakyReLU(0.2, inplace=True) performs the computation: x if x > 0 else 0.2 * x.
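To make the computation concrete, the short sketch below compares `nn.LeakyReLU(0.2)` with the elementwise rule written out by hand. The input values are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn as nn

# Arbitrary example inputs: two negative values, zero, and one positive value
x = torch.tensor([-2.0, -0.5, 0.0, 1.5])

leaky = nn.LeakyReLU(0.2)
builtin = leaky(x)

# The same rule written by hand: x if x > 0 else 0.2 * x
manual = torch.where(x > 0, x, 0.2 * x)

print(builtin)
print(manual)
```

Negative inputs are scaled by 0.2 rather than set to zero, so the two printed tensors should match.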
Part (a)¶
What tensor shape does the discriminator take as input? What does it produce as an output?
# Discuss the answers with your tutorial group
Part (b)¶
What tensor shape does the generator take as input? What does it produce as an output?
# Discuss the answers with your tutorial group
Part (c)¶
Explain what might be an advantage of using the leaky ReLU activation, compared to a ReLU activation.
# Discuss the answers with your tutorial group
Q2. Training a GAN¶
To train a GAN, we need to find a suitable loss function. Consider the following quantity:
P(D correctly identifies real image) + P(D correctly identifies image generated by G)
A good discriminator would want to maximize the above quantity by altering its parameters. Likewise, a good generator would want to minimize the above quantity. In fact, the only term that the generator controls is P(D correctly identifies image generated by G). So, the best thing for the generator to do is to alter its parameters to generate images that can fool D.
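For reference, this informal quantity is usually formalized with log-probabilities as the standard GAN minimax objective. The notation is an assumption added here: $x$ is a real image drawn from the data distribution $p_{\text{data}}$, $z$ is a noise vector drawn from $p_z$, and $D(\cdot)$ is the discriminator's output probability (the sigmoid of the logit that our `Discriminator` returns).

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

The first term rewards the discriminator for assigning high probability to real images. The second rewards it for assigning low probability to generated images, and is the only term that depends on $G$.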
Since we are looking at class probabilities, we will use binary cross entropy loss.
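The sketch below shows what binary cross entropy computes when the model outputs raw logits, as our discriminator does. It compares `nn.BCEWithLogitsLoss` with the formula written out by hand. The logits and labels are made-up values for illustration only.

```python
import torch
import torch.nn as nn

# Made-up discriminator logits and target labels (1 = real, 0 = fake)
logits = torch.tensor([2.0, -1.0, 0.5])
labels = torch.tensor([1.0, 0.0, 0.0])

# Built-in loss: applies the sigmoid internally, then averages over the batch
loss = nn.BCEWithLogitsLoss()(logits, labels)

# The same quantity by hand: -[y log p + (1 - y) log(1 - p)], averaged
p = torch.sigmoid(logits)
manual = -(labels * torch.log(p) + (1 - labels) * torch.log(1 - p)).mean()

print(loss.item(), manual.item())
```

The two values should agree up to floating-point error. The built-in version is preferred in practice because it is more numerically stable for large logits.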
Here is a rudimentary training loop to train a GAN. For every minibatch of data, we train the discriminator for one iteration, and then we train the generator for one iteration.
For the discriminator, we use the label 1 to represent a real image, and 0 to represent a fake image.
Part (a)¶
Fill in the training code below. The lines that you need to change are marked with a #TODO comment.
def train(generator, discriminator, device, lr=0.001, num_epochs=5):
    # We will use the binary cross-entropy loss
    criterion = nn.BCEWithLogitsLoss()

    # From: https://colab.research.google.com/github/smartgeometry-ucl/dl4g/blob/master/gan.ipynb#scrollTo=-o7DOryPkUOb
    # GAN training can be unstable. In this case, the strong momentum
    # for the gradient prevents convergence. One possible explanation is that the
    # strong momentum does not allow the two players in the adversarial game to react
    # to each other quickly enough. Decreasing beta1 (the exponential decay for the
    # gradient moving average in [0,1], lower is faster decay) from the default 0.9
    # to 0.5 allows for quicker reactions.
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    # Use the MNIST data as training data
    train_loader = torch.utils.data.DataLoader(mnist_data, batch_size=100, shuffle=True)

    # We will track 16 noise vectors, and what image the generator creates
    # from these 16 noise vectors.
    num_test_samples = 16
    test_noise = torch.randn(num_test_samples, 100).to(device)

    # Move generator and discriminator weights to the GPU
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    for epoch in range(num_epochs):
        # put the generator/discriminator in training mode
        generator.train()
        discriminator.train()

        for n, (images, _) in enumerate(train_loader):
            # === Train the Discriminator ===
            noise = torch.randn(images.size(0), 100).to(device)
            fake_images = generator(noise)
            inputs = torch.cat([images.to(device), fake_images])
            labels = None # TODO: Create a vector denoting that the real images
                          # should have label 1, and fake images should have label 0
            d_outputs = discriminator(inputs)
            d_loss = criterion(d_outputs, labels.to(device))
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # === Train the Generator ===
            noise = torch.randn(images.size(0), 100)
            fake_images = generator(noise.to(device))
            outputs = discriminator(fake_images.to(device))
            g_loss = criterion(outputs,
                               None) # TODO: what should this be?
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        # Report the average scores
        scores = torch.sigmoid(d_outputs)
        real_score = scores[:images.size(0)].data.mean()
        fake_score = scores[images.size(0):].data.mean()
        print('Epoch [%d/%d], d_loss: %.4f, g_loss: %.4f, '
              'D(x): %.2f, D(G(z)): %.2f'
              % (epoch + 1, num_epochs, d_loss.item(), g_loss.item(), real_score, fake_score))

        # Plot images generated from the 16 noise vectors
        generator.eval()
        discriminator.eval()
        test_images = generator(test_noise).cpu()
        plt.figure(figsize=(9, 3))
        for k in range(16):
            plt.subplot(2, 8, k+1)
            plt.imshow(test_images[k,:].data.numpy().reshape(28, 28), cmap='Greys')
        plt.show()
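The comment in the training code about beta1 can be illustrated with a small sketch. It tracks only the exponential moving average of the gradient, which is Adam's first moment, and leaves out bias correction and the second-moment scaling, so it is a simplified picture rather than Adam itself. The gradient sequence is a made-up signal whose sign flips partway through, standing in for an opponent that has changed its behaviour.

```python
def ema(grads, beta):
    """Exponential moving average of a gradient sequence (no bias correction)."""
    m, history = 0.0, []
    for g in grads:
        m = beta * m + (1 - beta) * g
        history.append(m)
    return history

# Hypothetical gradient signal: 20 steps of +1, then the sign flips to -1
grads = [1.0] * 20 + [-1.0] * 10

def steps_to_react(beta):
    """Number of steps after the flip until the moving average turns negative."""
    after_flip = ema(grads, beta)[20:]
    return next(i for i, m in enumerate(after_flip, start=1) if m < 0)

slow = steps_to_react(0.9)  # default beta1
fast = steps_to_react(0.5)  # beta1 used in the training code above
print(slow, fast)
```

With the lower beta the average follows the new gradient direction after fewer steps. This matches the intuition that each player can react more quickly to the other.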
Part (b)¶
Train the network for at least 20 epochs.
discriminator = Discriminator()
generator = Generator()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#train(generator, discriminator, device, lr=0.0002, num_epochs=20)
Part (c)¶
GANs are notoriously difficult to train. One difficulty is that a training curve is no longer as helpful as it was for a supervised learning problem! The generator and discriminator losses tend to bounce up and down, since both the generator and discriminator are changing over time. Tuning hyperparameters is also much more difficult, because we don't have the training curve to guide us. Newer GAN models like Wasserstein GAN try to alleviate some of these issues, but are beyond the scope of this course.
To compound the difficulty of hyperparameter tuning, GANs also take a long time to train. It is tempting to stop training early, but the effects of hyperparameters may not be noticeable until later on in training.
Discuss what trend you notice in the discriminator and generator loss.
# Discuss the answers with your tutorial group
Part (d)¶
You might have noticed in the images generated by our simple GAN that the model seems to only output a small number of digit types. This phenomenon is called mode collapse. Explain if mode collapse manifested in your GAN. If so, how? Why does mode collapse occur in a GAN?
# Discuss the answers with your tutorial group