from pathlib import Path
from typing import Callable, Sequence
import numpy as np
from matplotlib import pyplot as plt
%config InlineBackend.figure_formats = {'retina', 'png'}
import torch
from torch import Tensor, nn, optim
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchinfo import summary
from torchvision import transforms as T
from torchvision.utils import make_grid
from tqdm import tqdm
SEED = 42
PROJECT_PATH = Path(".").resolve()
FIGURE_PATH = PROJECT_PATH / "figures"
DATASET_PATH = Path.home() / "datasets"
Generative Adversarial Networks
Generative Adversarial Networks (GANs) represent an innovative class of unsupervised neural networks that have revolutionized the field of artificial intelligence. To learn how they work, I implemented the foundational “vanilla” GAN and its more complex counterpart, the Deep Convolutional GAN (DCGAN), from scratch, and tested both on the MNIST Digits and Fashion-MNIST toy datasets.
Introduction
Generative Adversarial Networks (GANs) are an innovative class of unsupervised neural networks that have revolutionized the field of artificial intelligence. They were first introduced in Generative Adversarial Networks (Goodfellow et al. 2014) and consist of two separate neural networks: the generator (creates data) and the discriminator (evaluates data authenticity). The generator aims to fool the discriminator by producing realistic data, while the discriminator tries to differentiate real from fake. Over iterations, the generator’s data becomes more convincing.
As an analogy, consider two kids, one drawing counterfeit money (“Generator”) and another assessing its realism (“Discriminator”). Over time, the counterfeit drawings become increasingly convincing.
Vanilla GAN
The most fundamental variant of GAN is the “vanilla” GAN, where “vanilla” signifies the model in its original and most straightforward form rather than a flavor. To better understand its mechanism, I’ve illustrated its structure in Figure 1.
- Generator \(G(z; w_g)\) takes random noise \(z\) as input and produces fabricated data \(x_f\).
  - \(z\) is the input noise vector, sampled from a Gaussian distribution \(p_z(z)\).
  - \(w_g\) denotes the generator network's weights.
  - \(x_f\) is a fabricated data sample passed to the discriminator.
- Discriminator \(D(x; w_d)\) differentiates between real and generated data.
  - \(x\) is the input vector, which comes either from the real dataset (\(x_r \sim p_\textrm{data}(x)\)) or from the set of fabricated samples (\(x_f = G(z \sim p_z(z); w_g)\)).
  - \(w_d\) denotes the discriminator network's weights.
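To make this notation concrete, here is a minimal NumPy sketch of the data flow from \(z\) to \(D(G(z))\). It assumes hypothetical single-layer toy networks with random weights; the names G, D, w_g and w_d mirror the symbols above, the sizes are illustrative, and the actual models later in this post are deeper PyTorch networks.

```python
import numpy as np

rng = np.random.default_rng(0)

NZ, N_PIXELS, BATCH = 100, 28 * 28, 4  # hypothetical toy sizes

# Toy weights: w_g maps noise to pixels, w_d maps pixels to a single logit.
w_g = rng.normal(0.0, 0.02, size=(NZ, N_PIXELS))
w_d = rng.normal(0.0, 0.02, size=(N_PIXELS, 1))


def G(z: np.ndarray, w_g: np.ndarray) -> np.ndarray:
    """Toy generator: noise z -> fabricated sample x_f with values in [-1, 1]."""
    return np.tanh(z @ w_g)


def D(x: np.ndarray, w_d: np.ndarray) -> np.ndarray:
    """Toy discriminator: sample x -> probability that x is real."""
    return 1.0 / (1.0 + np.exp(-(x @ w_d)))


z = rng.standard_normal((BATCH, NZ))  # z ~ p_z(z) = N(0, I)
x_f = G(z, w_g)                       # fabricated samples x_f = G(z; w_g)
scores = D(x_f, w_d)                  # D(G(z)), one score per sample

print(x_f.shape, scores.shape)  # (4, 784) (4, 1)
```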
Objective Function
The interaction between the Generator and the Discriminator can be quantified by their objective or loss functions:
- Discriminator’s Objective: For real data \(x\), \(D\) wants \(D(x)\) near \(1\). For generated data \(G(z)\), it targets \(D(G(z))\) close to \(0\). It maximizes:
\[ \mathcal{L}(D) = \log(D(x)) + \log(1 - D(G(z))). \]
- Generator’s Objective: \(G\) aims for \(D(G(z))\) to approach \(1\), which it pursues by minimizing:
\[ \mathcal{L}(G) = \log(1 - D(G(z))). \]
Both \(G\) and \(D\) continuously improve to outperform each other in this game.
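The training code later in this post uses binary cross-entropy (nn.BCELoss). The NumPy sketch below re-implements BCE for illustration (it is not PyTorch's code) to show how it maps onto these objectives: with target 1 it equals \(-\log D(x)\), and with target 0 it equals \(-\log(1 - D(G(z)))\), so minimizing BCE on real and fake samples maximizes \(\mathcal{L}(D)\). For the Generator, the training loop scores fakes against target 1, i.e. it minimizes \(-\log D(G(z))\) instead of \(\log(1 - D(G(z)))\). This is the non-saturating variant suggested by Goodfellow et al. (2014), which gives much stronger gradients early in training when the Discriminator easily rejects fakes. The probability values are hypothetical.

```python
import numpy as np


def bce(p: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-sample binary cross-entropy for probabilities p and 0/1 targets."""
    return -(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))


d_real = np.array([0.9, 0.6])  # hypothetical D(x) on two real samples
d_fake = np.array([0.2, 0.7])  # hypothetical D(G(z)) on two fake samples

# Discriminator: target 1 on real and 0 on fake gives exactly -L(D).
loss_d = bce(d_real, np.ones(2)) + bce(d_fake, np.zeros(2))
objective_d = np.log(d_real) + np.log(1.0 - d_fake)
print(np.allclose(loss_d, -objective_d))  # True

# Generator as trained in this post: target 1 on fakes, i.e. minimize -log D(G(z)).
loss_g = bce(d_fake, np.ones(2))
print(np.allclose(loss_g, -np.log(d_fake)))  # True

# Gradient w.r.t. the discriminator's logit a (D = sigmoid(a)) when D rejects a fake:
d = 0.01
grad_saturating = -d             # d/da of log(1 - sigmoid(a)): vanishes as D -> 0
grad_nonsaturating = -(1.0 - d)  # d/da of -log(sigmoid(a)): stays close to -1
print(grad_saturating, grad_nonsaturating)
```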
Minimax Game in GANs
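Vanilla GANs are structured around the minimax game from game theory. Averaging the per-sample objectives above over real data \(x \sim p_\textrm{data}(x)\) and noise \(z \sim p_z(z)\) gives:
\[ \min_{G}\max_{D} \mathcal{L}(D, G) = \mathbb{E}_{x \sim p_\textrm{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log(1 - D(G(z)))\big] \]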
In essence:
- Discriminator: Maximizes its capacity to differentiate real data from generated.
- Generator: Minimizes the discriminator’s success rate by producing superior forgeries.
The iterative competition refines both, targeting a proficient Generator and a perceptive Discriminator.
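This game has a well-defined solution, following the argument in Goodfellow et al. (2014). Write \(p_g\) for the distribution of generated samples \(G(z)\) with \(z \sim p_z(z)\) (a symbol introduced here for convenience). For a fixed Generator, the averaged objective becomes an integral that the Discriminator can maximize pointwise, by setting the derivative of the integrand with respect to \(D(x)\) to zero:

\[ \mathcal{L}(D, G) = \int_x \Big[ p_\textrm{data}(x) \log D(x) + p_g(x) \log\big(1 - D(x)\big) \Big] \, dx \]

\[ \frac{p_\textrm{data}(x)}{D(x)} - \frac{p_g(x)}{1 - D(x)} = 0 \quad\Longrightarrow\quad D^*(x) = \frac{p_\textrm{data}(x)}{p_\textrm{data}(x) + p_g(x)} \]

When the Generator reproduces the data distribution, \(p_g = p_\textrm{data}\), the optimal Discriminator outputs \(D^*(x) = \frac{1}{2}\) for every input, and the objective equals \(2 \log \frac{1}{2} = -\log 4\), which Goodfellow et al. show is the global minimum over Generators.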
Prepare Environment
Next, we prepare the development environment in the following steps:
- Import necessary libraries, primarily PyTorch and Matplotlib.
- Define constants, including project path and seed, for consistency.
- Determine the computational device (e.g., GPU).
- Provide a weight initialization helper function.
# Disable debugging/profiling functionalities for speed-up
torch.autograd.set_detect_anomaly(False)
# profile() and emit_nvtx() are context managers and stay inactive unless entered
# with `with`, so these two calls only make the intent explicit.
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    # Allow CuDNN internal benchmarking for architecture-specific optimizations
    torch.backends.cudnn.benchmark = True
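def weights_init(net: nn.Module) -> None:
    """Initialize weights from N(0, 0.02) (BatchNorm scales from N(1, 0.02)) and biases to zero.

    It visits all submodules itself, so weights_init(model) is enough;
    model.apply(weights_init) also works but re-initializes nested modules several times.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)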
Generator
The Generator in GANs acts as an artist, crafting data.
- Input: Takes random noise, typically from a standard normal distribution.
- Architecture: Uses dense layers, progressively increasing data dimensions.
- Output: Reshapes the data to the desired format (e.g., an image). The output activation is often tanh, which bounds values to \([-1, 1]\).
- Objective: Generate data indistinguishable from real by the Discriminator.
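Why tanh? The data loaders in this post use T.ToTensor() followed by T.Normalize(0.5, 0.5), which maps 8-bit pixel intensities first to \([0, 1]\) and then, via \((x - 0.5)/0.5\), to \([-1, 1]\): the same range tanh produces, so real and fake samples live on the same scale. A minimal NumPy sketch of that mapping and of one way to invert it for display (the helper names are hypothetical):

```python
import numpy as np


def to_model_range(pixels_uint8: np.ndarray) -> np.ndarray:
    """Mimic T.ToTensor() + T.Normalize(0.5, 0.5): [0, 255] -> [0, 1] -> [-1, 1]."""
    x = pixels_uint8.astype(np.float32) / 255.0  # ToTensor scales to [0, 1]
    return (x - 0.5) / 0.5                       # Normalize(mean=0.5, std=0.5)


def to_display_range(x: np.ndarray) -> np.ndarray:
    """Invert the normalization so values in [-1, 1] can be shown as images in [0, 1]."""
    return (x + 1.0) / 2.0


pixels = np.array([0, 128, 255], dtype=np.uint8)
normalized = to_model_range(pixels)
print(normalized)  # values in [-1, 1]

# A tanh output layer covers exactly this range.
fake = np.tanh(np.array([-3.0, 0.0, 3.0]))
print(to_display_range(fake))  # values in [0, 1]
```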
class Generator(nn.Module):
    def __init__(self, out_dim: Sequence[int], nz: int = 100, ngf: int = 256, alpha: float = 0.2):
        """
        :param out_dim: output image dimension / shape
        :param nz: size of the latent vector $z$
        :param ngf: size of feature maps (units in the hidden layers) in the generator
        :param alpha: negative slope of leaky ReLU activation
        """
        super().__init__()
        self.out_dim = out_dim
        self.model = nn.Sequential(
            nn.Linear(nz, ngf),
            nn.LeakyReLU(alpha, inplace=True),
            nn.Linear(ngf, 2 * ngf),
            nn.LeakyReLU(alpha, inplace=True),
            nn.Linear(2 * ngf, 4 * ngf),  # must output 4 * ngf to feed the next layer
            nn.LeakyReLU(alpha, inplace=True),
            nn.Linear(4 * ngf, int(np.prod(self.out_dim))),
            nn.Tanh(),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.model(x)
        # Unflatten to (batch, *out_dim), e.g. (batch, 1, 28, 28)
        x = torch.reshape(x, (x.size(0), *self.out_dim))
        return x


summary(Generator(out_dim=(1, 28, 28)), input_size=[128, 100])
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Generator [128, 1, 28, 28] --
├─Sequential: 1-1 [128, 784] --
│ └─Linear: 2-1 [128, 256] 25,856
│ └─LeakyReLU: 2-2 [128, 256] --
│ └─Linear: 2-3 [128, 512] 131,584
│ └─LeakyReLU: 2-4 [128, 512] --
│ └─Linear: 2-5 [128, 1024] 525,312
│ └─LeakyReLU: 2-6 [128, 1024] --
│ └─Linear: 2-7 [128, 784] 803,600
│ └─Tanh: 2-8 [128, 784] --
==========================================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
Total mult-adds (M): 190.25
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 2.64
Params size (MB): 5.95
Estimated Total Size (MB): 8.63
==========================================================================================
Discriminator
The Discriminator is GAN’s evaluator, distinguishing real from fake data.
- Input: Takes either real data samples or those from the Generator.
- Architecture: Employs dense layers for binary classification of the input.
- Output: Uses a sigmoid activation, yielding a score between 0 and 1 that reflects how likely the input is to be real.
- Objective: Recognize real data and identify fake data from the Generator.
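This post pairs the sigmoid output with nn.BCELoss. One caveat worth knowing: the PyTorch documentation states that BCELoss clamps its log outputs to be at least -100, because once the sigmoid saturates to exactly 0 or 1 in floating point, the logarithm would be \(-\infty\). The commonly recommended alternative is to drop the final nn.Sigmoid() and use nn.BCEWithLogitsLoss, which works directly on the logit and is documented as more numerically stable. The NumPy sketch below re-implements both formulas (it is not PyTorch's code) and compares them on a hypothetical, extremely confident wrong prediction:

```python
import numpy as np


def bce_from_probability(logit: float, target: float) -> float:
    """Sigmoid then BCE, clamping log terms at -100 as documented for nn.BCELoss."""
    p = 1.0 / (1.0 + np.exp(-logit))
    with np.errstate(divide="ignore"):  # log(0) -> -inf before clamping
        log_p = max(np.log(p), -100.0)
        log_1mp = max(np.log(1.0 - p), -100.0)
    return -(target * log_p + (1.0 - target) * log_1mp)


def bce_from_logit(logit: float, target: float) -> float:
    """Numerically stable BCE computed on the logit (the BCEWithLogitsLoss form)."""
    return max(logit, 0.0) - logit * target + np.log1p(np.exp(-abs(logit)))


# Moderate logits: both formulas agree.
print(bce_from_probability(2.0, 0.0), bce_from_logit(2.0, 0.0))

# A fake sample the discriminator is extremely sure is real (logit 40, target 0):
print(bce_from_probability(40.0, 0.0))  # clamped: 100.0
print(bce_from_logit(40.0, 0.0))        # about 40.0, the exact loss
```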
class Discriminator(nn.Module):
    def __init__(self, input_dim: Sequence[int], ndf: int = 128, alpha: float = 0.2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(input_dim)), 4 * ndf),
            nn.LeakyReLU(alpha, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(4 * ndf, 2 * ndf),
            nn.LeakyReLU(alpha, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(2 * ndf, ndf),
            nn.LeakyReLU(alpha, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(ndf, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Flatten images to (batch, C * H * W)
        x = torch.reshape(x, (x.size(0), -1))
        return self.model(x)


summary(Discriminator(input_dim=(1, 28, 28)), input_size=[128, 1, 28, 28])
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Discriminator [128, 1] --
├─Sequential: 1-1 [128, 1] --
│ └─Linear: 2-1 [128, 512] 401,920
│ └─LeakyReLU: 2-2 [128, 512] --
│ └─Dropout: 2-3 [128, 512] --
│ └─Linear: 2-4 [128, 256] 131,328
│ └─LeakyReLU: 2-5 [128, 256] --
│ └─Dropout: 2-6 [128, 256] --
│ └─Linear: 2-7 [128, 128] 32,896
│ └─LeakyReLU: 2-8 [128, 128] --
│ └─Dropout: 2-9 [128, 128] --
│ └─Linear: 2-10 [128, 1] 129
│ └─Sigmoid: 2-11 [128, 1] --
==========================================================================================
Total params: 566,273
Trainable params: 566,273
Non-trainable params: 0
Total mult-adds (M): 72.48
==========================================================================================
Input size (MB): 0.40
Forward/backward pass size (MB): 0.92
Params size (MB): 2.27
Estimated Total Size (MB): 3.59
==========================================================================================
Training Loop
The training process is iterative:
- Update Discriminator: With the Generator static, improve the Discriminator’s detection of real vs. fake.
- Update Generator: With a static Discriminator, enhance the Generator’s ability to deceive.
Training continues until the Discriminator can no longer tell the Generator's output apart from real data. At equilibrium, the Discriminator sees every input as equally likely to be real or fake and assigns it a probability of \(\frac{1}{2}\).
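A small NumPy check of this equilibrium on hypothetical discrete distributions over five outcomes. For a fixed Generator, the best Discriminator is \(D^*(x) = p_\textrm{data}(x) / (p_\textrm{data}(x) + p_g(x))\) (Goodfellow et al. 2014), where \(p_g\) denotes the distribution of generated samples. The sketch shows that \(D^*\) is exactly \(\frac{1}{2}\) once \(p_g = p_\textrm{data}\), and that the game value \(-\log 4 + 2\,\mathrm{JSD}(p_\textrm{data} \,\|\, p_g)\) is larger whenever the two distributions differ.

```python
import numpy as np


def optimal_discriminator(p_data: np.ndarray, p_g: np.ndarray) -> np.ndarray:
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) for a fixed generator."""
    return p_data / (p_data + p_g)


def game_value(p_data: np.ndarray, p_g: np.ndarray, d: np.ndarray) -> float:
    """E_{x~p_data}[log D(x)] + E_{x~p_g}[log(1 - D(x))] over a discrete support."""
    return float(np.sum(p_data * np.log(d)) + np.sum(p_g * np.log(1.0 - d)))


def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    m = 0.5 * (p + q)
    return float(0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m)))


p_data = np.array([0.1, 0.2, 0.4, 0.2, 0.1])  # hypothetical "real" distribution
p_bad = np.array([0.3, 0.3, 0.2, 0.1, 0.1])   # a generator that is still off
p_good = p_data.copy()                        # a generator that matches the data

for name, p_g in [("mismatched", p_bad), ("matched", p_good)]:
    d_star = optimal_discriminator(p_data, p_g)
    value = game_value(p_data, p_g, d_star)
    print(name, d_star.round(3), round(value, 4))

# The matched case prints D* = 0.5 everywhere and value -log 4 (about -1.3863).
```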
Switching between .eval() and .train() modes inside the training loop initially seemed promising for faster training. However, the two modes change how layers like BatchNorm2d and Dropout behave, which made the GAN diverge. Switching modes also carries its own overhead.
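The mode matters because these layers compute different functions in each mode: BatchNorm normalizes with the current batch statistics in training mode and with its running estimates in evaluation mode, while nn.Dropout(p) zeroes each activation with probability p and rescales the survivors by 1 / (1 - p) in training mode but is the identity in evaluation mode. A minimal NumPy re-implementation of the dropout behavior (not PyTorch's code; dropout is a hypothetical helper), using the p = 0.3 of the vanilla discriminator:

```python
import numpy as np


def dropout(x: np.ndarray, p: float, training: bool, rng: np.random.Generator) -> np.ndarray:
    """Inverted dropout: random masking plus rescaling in training, identity in eval."""
    if not training or p == 0.0:
        return x
    keep = rng.random(x.shape) >= p  # keep each unit with probability 1 - p
    return np.where(keep, x / (1.0 - p), 0.0)


rng = np.random.default_rng(0)
x = np.ones(100_000)

train_out = dropout(x, p=0.3, training=True, rng=rng)
eval_out = dropout(x, p=0.3, training=False, rng=rng)

print((train_out == 0).mean())      # roughly 0.3 of the units are dropped
print(train_out.mean())             # roughly 1.0: rescaling preserves the expectation
print(np.array_equal(eval_out, x))  # True: eval mode passes inputs through unchanged
```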
def train_step(
    generator: nn.Module,
    discriminator: nn.Module,
    optim_G: optim.Optimizer,
    optim_D: optim.Optimizer,
    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    real_data: torch.Tensor,
    noise_dim: int,
    device: torch.device,
) -> tuple[float, float]:
    batch_size = real_data.size(0)
    real_data = real_data.to(device)

    # Labels for real and fake data
    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)

    ### Train Discriminator
    # Real data
    output_real = discriminator(real_data)
    loss_D_real = criterion(output_real, real_labels)

    # Fake data (detached so this step does not update the generator)
    noise = torch.randn(batch_size, noise_dim, dtype=torch.float, device=device)
    fake_data = generator(noise)
    output_fake = discriminator(fake_data.detach())
    loss_D_fake = criterion(output_fake, fake_labels)

    # Backprop and optimize for discriminator
    loss_D = (loss_D_real + loss_D_fake) / 2.0
    discriminator.zero_grad()
    loss_D.backward()
    optim_D.step()

    ### Train Generator
    # Recompute fake data's discriminator scores with the updated discriminator
    output_fake = discriminator(fake_data)
    # Non-saturating loss: score fakes against real labels, i.e. minimize -log D(G(z))
    loss_G = criterion(output_fake, real_labels)

    # Backprop and optimize for generator
    generator.zero_grad()
    loss_G.backward()
    optim_G.step()

    return loss_G.item(), loss_D.item()
Evaluation
Before evaluation, we configured the learning rate (LR), the optimizer’s \(\beta\) parameters, weight decay, the number of epochs, the batch size, and the data loader settings shared by all experiments. We used the MNIST Digits and Fashion-MNIST datasets for assessment.
OPTIMIZER_LR = 0.0002
L2_NORM = 1e-5  # passed to AdamW as its (decoupled) weight decay
OPTIMIZER_BETAS = (0.5, 0.999)
N_EPOCHS = 100
BATCH_SIZE = 128

loader_kwargs = {
    "num_workers": 8,
    "pin_memory": True,
    "shuffle": True,
    "batch_size": BATCH_SIZE,
    "prefetch_factor": 16,
    "pin_memory_device": device.type,
    "persistent_workers": False,
}
MNIST Digits Dataset
The MNIST (Modified National Institute of Standards and Technology) dataset is a well-known collection of handwritten digits, extensively used in the fields of machine learning and computer vision for training and testing purposes. Its simplicity and size make it a popular choice for introductory courses and experiments in image recognition.
In total, the dataset contains 70,000 grayscale images of handwritten digits (from 0 to 9). Each image is 28x28 pixels.
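A little arithmetic ties these numbers to the models and settings in this post. The pure-Python sketch assumes the combined 70,000-image dataset, BATCH_SIZE = 128 and N_EPOCHS = 100 from the configuration above, and the DataLoader default of keeping the last, smaller batch (drop_last=False):

```python
import math

N_IMAGES = 70_000  # MNIST train (60,000) + test (10,000), combined via ConcatDataset
HEIGHT, WIDTH = 28, 28
BATCH_SIZE = 128
N_EPOCHS = 100

# Each flattened image is the input vector of the vanilla discriminator.
n_pixels = HEIGHT * WIDTH
print(n_pixels)  # 784

# Parameters of the discriminator's first nn.Linear(784, 4 * 128): weights + biases.
first_layer_params = n_pixels * 512 + 512
print(first_layer_params)  # 401920, matching the torchinfo summary

# Optimizer steps: the last batch of each epoch holds the remaining 112 images.
batches_per_epoch = math.ceil(N_IMAGES / BATCH_SIZE)
print(batches_per_epoch, batches_per_epoch * N_EPOCHS)  # 547 54700
```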
def get_minst_dataset(transform=None) -> Dataset:
    from torchvision.datasets import MNIST

    root = str(DATASET_PATH)
    trainset = MNIST(root=root, train=True, download=True, transform=transform)
    testset = MNIST(root=root, train=False, download=True, transform=transform)
    # Combine train and test datasets for more samples.
    dataset = ConcatDataset([trainset, testset])
    return dataset
IMG_DIM = (1, 28, 28)
NOISE_DIM = 100

transform = T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])

data = get_minst_dataset(transform=transform)
dataloader = DataLoader(data, **loader_kwargs)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# benchmark_noise is used for the animation to show how the output evolves on the same vectors
benchmark_noise = torch.randn(16 * 16, NOISE_DIM, device=device)

generator = Generator(out_dim=IMG_DIM, nz=NOISE_DIM).to(device)
generator.apply(weights_init)

discriminator = Discriminator(input_dim=IMG_DIM).to(device)
discriminator.apply(weights_init)

optimizer_G = optim.AdamW(
    generator.parameters(),
    lr=OPTIMIZER_LR,
    betas=OPTIMIZER_BETAS,
    weight_decay=L2_NORM,
)
optimizer_D = optim.AdamW(
    discriminator.parameters(),
    lr=OPTIMIZER_LR,
    betas=OPTIMIZER_BETAS,
    weight_decay=L2_NORM,
)
criterion = nn.BCELoss().to(device)

animation = []
g_losses, d_losses = [], []
for _ in tqdm(range(N_EPOCHS), unit="epochs"):
    generator.train()
    discriminator.train()
    for samples_real, _ in dataloader:
        g_loss, d_loss = train_step(generator, discriminator, optimizer_G, optimizer_D, criterion, samples_real, NOISE_DIM, device)
        g_losses.append(g_loss)
        d_losses.append(d_loss)

    # Snapshot the generator's output on the fixed benchmark noise after each epoch
    generator.eval()
    with torch.inference_mode():
        images = generator(benchmark_noise)
    images = images.detach().cpu()
    images = make_grid(images, nrow=16, normalize=True)
    animation.append(images)
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [03:43<00:00, 2.23s/epochs]
MNIST Fashion Dataset
The Fashion MNIST dataset is a collection of grayscale images of 10 different categories of clothing items, designed as a more challenging alternative to the classic MNIST dataset of handwritten digits. Each image in the dataset is 28x28 pixels. The 10 categories include items like t-shirts/tops, trousers, pullovers, dresses, coats, sandals, and more. With 70,000 images, Fashion MNIST is commonly used for benchmarking machine learning algorithms, especially in image classification tasks.
IMG_DIM = (1, 28, 28)
NOISE_DIM = 100


def get_mnist_fashion_dataset(transform=None) -> Dataset:
    from torchvision.datasets import FashionMNIST

    root = str(DATASET_PATH)
    trainset = FashionMNIST(root=root, train=True, download=True, transform=transform)
    testset = FashionMNIST(root=root, train=False, download=True, transform=transform)
    # Combine train and test datasets for more samples.
    dataset = ConcatDataset([trainset, testset])
    return dataset


transform = T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])

data = get_mnist_fashion_dataset(transform=transform)
dataloader = DataLoader(data, **loader_kwargs)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# benchmark_noise is used for the animation to show how the output evolves on the same vectors
benchmark_noise = torch.randn(16 * 16, NOISE_DIM, device=device)

generator = Generator(out_dim=IMG_DIM, nz=NOISE_DIM).to(device)
generator.apply(weights_init)

discriminator = Discriminator(input_dim=IMG_DIM).to(device)
discriminator.apply(weights_init)

optimizer_G = optim.AdamW(
    generator.parameters(),
    lr=OPTIMIZER_LR,
    betas=OPTIMIZER_BETAS,
    weight_decay=L2_NORM,
)
optimizer_D = optim.AdamW(
    discriminator.parameters(),
    lr=OPTIMIZER_LR,
    betas=OPTIMIZER_BETAS,
    weight_decay=L2_NORM,
)
criterion = nn.BCELoss().to(device)

animation = []
g_losses, d_losses = [], []
for _ in tqdm(range(N_EPOCHS), unit="epochs"):
    generator.train()
    discriminator.train()
    for samples_real, _ in dataloader:
        g_loss, d_loss = train_step(generator, discriminator, optimizer_G, optimizer_D, criterion, samples_real, NOISE_DIM, device)
        g_losses.append(g_loss)
        d_losses.append(d_loss)

    # Snapshot the generator's output on the fixed benchmark noise after each epoch
    generator.eval()
    with torch.inference_mode():
        images = generator(benchmark_noise)
    images = images.detach().cpu()
    images = make_grid(images, nrow=16, normalize=True)
    animation.append(images)
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [03:49<00:00, 2.30s/epochs]
DCGAN
DCGAN, short for Deep Convolutional Generative Adversarial Network, differs from vanilla GAN by using convolutional layers. This design makes DCGAN better for image data. With specific architectural guidelines, DCGAN trains more consistently and generates clearer images than vanilla GANs across various hyperparameters.
Setting Up DCGANs
Generator
class Generator(nn.Module):
    def __init__(self, nz: int = 100, ngf: int = 32, nc: int = 1):
        """
        :param nz: size of the latent z vector
        :param ngf: size of feature maps in generator
        :param nc: number of channels in the training images.
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.ConvTranspose2d(nz, 4 * ngf, 4, 1, 0, bias=False),
            nn.BatchNorm2d(4 * ngf),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(4 * ngf, 2 * ngf, 3, 2, 1, bias=False),
            nn.BatchNorm2d(2 * ngf),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(2 * ngf, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Treat the latent vector as a (nz, 1, 1) feature map
        x = torch.reshape(x, (x.size(0), -1, 1, 1))
        return self.layers(x)


summary(Generator(), input_size=(128, 100))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Generator [128, 1, 28, 28] --
├─Sequential: 1-1 [128, 1, 28, 28] --
│ └─ConvTranspose2d: 2-1 [128, 128, 4, 4] 204,800
│ └─BatchNorm2d: 2-2 [128, 128, 4, 4] 256
│ └─ReLU: 2-3 [128, 128, 4, 4] --
│ └─ConvTranspose2d: 2-4 [128, 64, 7, 7] 73,728
│ └─BatchNorm2d: 2-5 [128, 64, 7, 7] 128
│ └─ReLU: 2-6 [128, 64, 7, 7] --
│ └─ConvTranspose2d: 2-7 [128, 32, 14, 14] 32,768
│ └─BatchNorm2d: 2-8 [128, 32, 14, 14] 64
│ └─ReLU: 2-9 [128, 32, 14, 14] --
│ └─ConvTranspose2d: 2-10 [128, 1, 28, 28] 512
│ └─Tanh: 2-11 [128, 1, 28, 28] --
==========================================================================================
Total params: 312,256
Trainable params: 312,256
Non-trainable params: 0
Total mult-adds (G): 1.76
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 24.26
Params size (MB): 1.25
Estimated Total Size (MB): 25.56
==========================================================================================
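The spatial sizes in this summary follow from the transposed-convolution output formula, which with dilation 1 and no output padding (as in the layers above) is \(H_\textrm{out} = (H_\textrm{in} - 1)\,s - 2p + k\) for kernel size \(k\), stride \(s\) and padding \(p\). The pure-Python sketch below walks the generator's four layers from the \(1 \times 1\) latent input to a \(28 \times 28\) image and recomputes the parameter counts (bias=False, so each layer has \(C_\textrm{in} C_\textrm{out} k^2\) weights, plus two parameters per BatchNorm channel). Note the \(3 \times 3\) kernel in the second layer: a \(4 \times 4\) kernel there would give 8 instead of 7, and doubling from 8 leads to 32 rather than 28.

```python
def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of nn.ConvTranspose2d with dilation=1 and output_padding=0."""
    return (size - 1) * stride - 2 * padding + kernel


nz, ngf, nc = 100, 32, 1
# (in_channels, out_channels, kernel, stride, padding) for each generator layer
layers = [
    (nz, 4 * ngf, 4, 1, 0),
    (4 * ngf, 2 * ngf, 3, 2, 1),
    (2 * ngf, ngf, 4, 2, 1),
    (ngf, nc, 4, 2, 1),
]

size = 1  # the latent vector is reshaped to (nz, 1, 1)
sizes, params = [], []
for c_in, c_out, k, s, p in layers:
    size = conv_transpose_out(size, k, s, p)
    sizes.append(size)
    params.append(c_in * c_out * k * k)  # conv weights only, no bias terms

batchnorm_params = 2 * (4 * ngf + 2 * ngf + ngf)  # weight and bias per channel

print(sizes)                           # [4, 7, 14, 28]
print(params)                          # [204800, 73728, 32768, 512]
print(sum(params) + batchnorm_params)  # 312256, the total in the summary
```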
Discriminator
class Discriminator(nn.Module):
    def __init__(self, ndf: int = 32, nc: int = 1, alpha: float = 0.2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(alpha, inplace=True),
            nn.Conv2d(ndf, 2 * ndf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(2 * ndf),
            nn.LeakyReLU(alpha, inplace=True),
            nn.Conv2d(2 * ndf, 4 * ndf, 3, 2, 1, bias=False),
            nn.BatchNorm2d(4 * ndf),
            nn.LeakyReLU(alpha, inplace=True),
            nn.Conv2d(4 * ndf, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Flatten (batch, 1, 1, 1) to (batch, 1) to match the label shape
        x = torch.reshape(x, (x.size(0), -1))
        return x


summary(Discriminator(), input_size=(BATCH_SIZE, 1, 28, 28))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Discriminator [128, 1] --
├─Sequential: 1-1 [128, 1, 1, 1] --
│ └─Conv2d: 2-1 [128, 32, 14, 14] 512
│ └─BatchNorm2d: 2-2 [128, 32, 14, 14] 64
│ └─LeakyReLU: 2-3 [128, 32, 14, 14] --
│ └─Conv2d: 2-4 [128, 64, 7, 7] 32,768
│ └─BatchNorm2d: 2-5 [128, 64, 7, 7] 128
│ └─LeakyReLU: 2-6 [128, 64, 7, 7] --
│ └─Conv2d: 2-7 [128, 128, 4, 4] 73,728
│ └─BatchNorm2d: 2-8 [128, 128, 4, 4] 256
│ └─LeakyReLU: 2-9 [128, 128, 4, 4] --
│ └─Conv2d: 2-10 [128, 1, 1, 1] 2,048
│ └─Sigmoid: 2-11 [128, 1, 1, 1] --
==========================================================================================
Total params: 109,504
Trainable params: 109,504
Non-trainable params: 0
Total mult-adds (M): 369.68
==========================================================================================
Input size (MB): 0.40
Forward/backward pass size (MB): 23.46
Params size (MB): 0.44
Estimated Total Size (MB): 24.30
==========================================================================================
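The discriminator mirrors the generator with ordinary convolutions, whose output size (dilation 1) is \(H_\textrm{out} = \lfloor (H_\textrm{in} + 2p - k)/s \rfloor + 1\). A pure-Python sketch, with the layer settings copied from the class above, confirms the \(28 \to 14 \to 7 \to 4 \to 1\) progression and the 109,504 parameters reported by torchinfo:

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


ndf, nc = 32, 1
# (in_channels, out_channels, kernel, stride, padding, followed_by_batchnorm)
layers = [
    (nc, ndf, 4, 2, 1, True),
    (ndf, 2 * ndf, 4, 2, 1, True),
    (2 * ndf, 4 * ndf, 3, 2, 1, True),
    (4 * ndf, 1, 4, 1, 0, False),
]

size, sizes, total_params = 28, [], 0
for c_in, c_out, k, s, p, has_bn in layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)
    total_params += c_in * c_out * k * k  # conv weights, bias=False
    if has_bn:
        total_params += 2 * c_out         # BatchNorm2d weight and bias

print(sizes)         # [14, 7, 4, 1]
print(total_params)  # 109504
```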
Evaluation
MNIST Digits Dataset
IMG_DIM = (1, 28, 28)
NOISE_DIM = 128

transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(0.5, 0.5),
    ]
)
data = get_minst_dataset(transform)
dataloader = DataLoader(data, **loader_kwargs)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# benchmark_noise is used for the animation to show how the output evolves on the same vectors
benchmark_noise = torch.randn(16 * 16, NOISE_DIM, device=device)

generator = Generator(nz=NOISE_DIM, ngf=32, nc=IMG_DIM[0]).to(device)
generator.apply(weights_init)

discriminator = Discriminator(ndf=32, nc=IMG_DIM[0]).to(device)
discriminator.apply(weights_init)

optimizer_G = optim.AdamW(
    generator.parameters(),
    lr=OPTIMIZER_LR,
    betas=OPTIMIZER_BETAS,
    weight_decay=L2_NORM,
)
optimizer_D = optim.AdamW(
    discriminator.parameters(),
    lr=OPTIMIZER_LR,
    betas=OPTIMIZER_BETAS,
    weight_decay=L2_NORM,
)
criterion = nn.BCELoss().to(device)

animation = []
g_losses, d_losses = [], []
for _ in tqdm(range(N_EPOCHS), unit="epochs"):
    generator.train()
    discriminator.train()
    for samples_real, _ in dataloader:
        g_loss, d_loss = train_step(generator, discriminator, optimizer_G, optimizer_D, criterion, samples_real, NOISE_DIM, device)
        g_losses.append(g_loss)
        d_losses.append(d_loss)

    # Snapshot the generator's output on the fixed benchmark noise after each epoch
    generator.eval()
    with torch.inference_mode():
        images = generator(benchmark_noise)
    images = images.detach().cpu()
    images = make_grid(images, nrow=16, normalize=True)
    animation.append(images)
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [04:54<00:00, 2.94s/epochs]
MNIST Fashion Dataset
IMG_DIM = (1, 28, 28)
NOISE_DIM = 128

transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(0.5, 0.5),
    ]
)
data = get_mnist_fashion_dataset(transform)
dataloader = DataLoader(data, **loader_kwargs)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# benchmark_noise is used for the animation to show how the output evolves on the same vectors
benchmark_noise = torch.randn(16 * 16, NOISE_DIM, device=device)

generator = Generator(nz=NOISE_DIM, ngf=32, nc=IMG_DIM[0]).to(device)
generator.apply(weights_init)

discriminator = Discriminator(ndf=32, nc=IMG_DIM[0]).to(device)
discriminator.apply(weights_init)

optimizer_G = optim.AdamW(
    generator.parameters(),
    lr=OPTIMIZER_LR,
    betas=OPTIMIZER_BETAS,
    weight_decay=L2_NORM,
)
optimizer_D = optim.AdamW(
    discriminator.parameters(),
    lr=OPTIMIZER_LR,
    betas=OPTIMIZER_BETAS,
    weight_decay=L2_NORM,
)
criterion = nn.BCELoss().to(device)

animation = []
g_losses, d_losses = [], []
for _ in tqdm(range(N_EPOCHS), unit="epochs"):
    generator.train()
    discriminator.train()
    for samples_real, _ in dataloader:
        g_loss, d_loss = train_step(generator, discriminator, optimizer_G, optimizer_D, criterion, samples_real, NOISE_DIM, device)
        g_losses.append(g_loss)
        d_losses.append(d_loss)

    # Snapshot the generator's output on the fixed benchmark noise after each epoch
    generator.eval()
    with torch.inference_mode():
        images = generator(benchmark_noise)
    images = images.detach().cpu()
    images = make_grid(images, nrow=16, normalize=True)
    animation.append(images)
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [04:55<00:00, 2.95s/epochs]
Conclusion
Generative Adversarial Networks (GANs) represent an innovative class of unsupervised neural networks that have significantly impacted the field of artificial intelligence (AI). They consist of two components: a Generator that improves its output and a Discriminator that enhances its evaluative skills. In a competitive yet symbiotic relationship, these two networks converge towards a dynamic equilibrium. This interaction exemplifies the strength of GANs and the adaptability of adversarial learning in AI, blending creative generation with critical assessment.
In this post, I explored the original GAN, often referred to as the “vanilla” GAN, and its convolutional successor, DCGAN, with the aim of understanding the basic mechanics of how GANs operate. Meanwhile, others have taken this technology much further, applying it to a range of innovative and fascinating new areas.