- GANs[1] pioneered adversarial generative modeling with a zero-sum game between generator and discriminator, representing one of the most influential frameworks in deep learning history
- From DCGAN[2] to StyleGAN3[18], GAN architectures have undergone five generations of evolution, with image generation quality leaping from blurry noise to photo-realistic fidelity
- WGAN[3] redefined the training objective with Wasserstein distance, substantially easing the original GAN's training instability and mode collapse issues; CycleGAN[10] achieved unpaired image translation
- This article includes two Google Colab implementations: DCGAN handwritten digit generation (with a simplified FID-style evaluation) and SeqGAN text sequence generation (training discrete GANs with reinforcement learning)
I. The Art of Adversarial Learning: GAN's Core Philosophy
In 2014, Ian Goodfellow and colleagues proposed Generative Adversarial Networks (GANs) in a landmark paper[1]. Its core idea stems from an elegant analogy: the game between a counterfeiter and the police.
The counterfeiter (generator G) tries to produce counterfeit bills that are indistinguishable from real ones; the police (discriminator D) try to tell real from fake. As the game progresses, the counterfeiter becomes increasingly skilled at forging, and the police become increasingly skilled at detection — until the counterfeits are so perfect that even the police cannot distinguish them.
The mathematical form of this game is a minimax zero-sum game:
min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]
Where:
G(z): Generator, maps random noise z to fake samples
D(x): Discriminator, outputs probability that x is a real sample
p_data: Real data distribution
p_z: Noise distribution (typically Gaussian or uniform)
Optimal discriminator: D*(x) = p_data(x) / (p_data(x) + p_g(x))
Nash equilibrium: p_g = p_data, D*(x) = 0.5 (completely indistinguishable)
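The optimal-discriminator and equilibrium statements above can be checked numerically on a tiny discrete example. The sketch below is only an illustration: the three-outcome distributions and the helper names `optimal_discriminator` and `value` are assumptions made for this example, not part of the original paper's code.

```python
import math

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) for each outcome x."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]

def value(p_data, p_g, d):
    """V(D, G) = E_{p_data}[log D(x)] + E_{p_g}[log(1 - D(x))] on a finite support."""
    real_term = sum(pd * math.log(dx) for pd, dx in zip(p_data, d) if pd > 0)
    fake_term = sum(pg * math.log(1 - dx) for pg, dx in zip(p_g, d) if pg > 0)
    return real_term + fake_term

p_data = [0.5, 0.3, 0.2]     # toy "real" distribution over three outcomes
p_g_bad = [0.2, 0.3, 0.5]    # a generator that has not matched it yet
p_g_good = list(p_data)      # a generator at equilibrium (p_g = p_data)

d_bad = optimal_discriminator(p_data, p_g_bad)
d_good = optimal_discriminator(p_data, p_g_good)
v_bad = value(p_data, p_g_bad, d_bad)
v_good = value(p_data, p_g_good, d_good)

print("D* at equilibrium:", d_good)
print("V at equilibrium:", v_good, "vs -log 4 =", -math.log(4))
print("V for the mismatched generator:", v_bad)
```

At equilibrium every entry of D* is 0.5 and the value function equals -log 4; any mismatch between p_g and p_data lets the optimal discriminator push the value above that floor.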
GAN training is a two-step alternating process:
- Train discriminator: Fix G, train D to distinguish between real samples and fake samples generated by G
- Train generator: Fix D, train G to generate samples that D misclassifies as real
II. Training Nightmares: GAN's Three Major Challenges
GAN's theory is elegant, but training is extremely difficult. Three core problems plagued the research community for years:
Mode Collapse
The generator "takes shortcuts" — learning to generate only a few specific samples that fool the discriminator, rather than covering the entire data distribution. For example, a GAN trained on MNIST might only learn to generate the digits 1 and 7.
Training Instability
The original GAN's loss function is based on Jensen-Shannon divergence. When the real distribution and generated distribution have no overlap (almost certainly the case in high-dimensional spaces), JS divergence is constant at log(2) — the gradient is zero, and the generator cannot learn.
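To make the saturation concrete, the following sketch computes the JS divergence between two point masses on a small grid. The grid size and the `point_mass` helper are assumptions chosen purely for illustration.

```python
import math

def kl(p, q):
    """KL(p || q) in nats; terms with p(x) = 0 contribute zero."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def point_mass(position, size=10):
    """Distribution that puts all probability on a single grid cell."""
    return [1.0 if i == position else 0.0 for i in range(size)]

real = point_mass(0)
js_by_offset = {offset: js(real, point_mass(offset)) for offset in range(6)}

for offset, divergence in js_by_offset.items():
    print(f"offset={offset}  JS={divergence:.4f}")
print("log 2 =", round(math.log(2), 4))
```

As soon as the two supports stop overlapping, the divergence jumps to log 2 and stays there however far apart the distributions are, so it carries no information about which direction would bring them closer.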
Evaluation Difficulty
GANs have no counterpart to the VAE's ELBO that can be tracked during training. Training loss does not reflect generation quality: discriminator outputs hovering around 0.5 could indicate either equilibrium or collapse.
III. DCGAN: Establishing Design Guidelines for Image GANs
Radford et al.[2] proposed DCGAN (Deep Convolutional GAN), establishing a set of classic design guidelines for image GANs:
| Guideline | Generator | Discriminator |
|---|---|---|
| Up/Downsampling | Transposed convolution (stride 2) | Strided convolution (stride 2) |
| Normalization | BatchNorm (except output layer) | BatchNorm (except input layer) |
| Activation | ReLU (Tanh for output) | LeakyReLU(0.2) |
| Fully connected layers | Not used | Not used |
These seemingly simple guidelines had profound significance — they transformed GAN training from "random success" to "reproducible." DCGAN also demonstrated that the generator's latent space has meaningful structure: z(man with glasses) - z(man) + z(woman) ≈ z(woman with glasses).
IV. WGAN: Redefining the Distance Metric
WGAN[3] (Wasserstein GAN) was a watershed for GAN training stability. Arjovsky et al. changed the training objective from JS divergence to the Wasserstein-1 distance (Earth Mover's Distance):
W(p_data, p_g) = inf_{γ∈Π(p_data, p_g)} E_{(x,y)~γ}[||x - y||]
Intuition: The minimum "work" needed to "transport" distribution p_g into p_data
WGAN objective (Kantorovich-Rubinstein duality):
The critic maximizes E_{x~p_data}[f_w(x)] - E_{z~p_z}[f_w(G(z))], so as a loss to minimize:
L_critic = E_{z~p_z}[f_w(G(z))] - E_{x~p_data}[f_w(x)]
L_G = -E_{z~p_z}[f_w(G(z))]
f_w is a 1-Lipschitz continuous "critic" (replacing the discriminator)
The key advantage of Wasserstein distance: even when two distributions have no overlap, it still provides meaningful gradients. In addition, the critic's estimate of the Wasserstein distance tracks sample quality: as the estimate falls, generated samples improve, so the training curve becomes a usable progress signal. The original GAN loss offers no such signal.
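A one-dimensional example illustrates the contrast with JS divergence. For two equal-size samples on the real line, the Wasserstein-1 distance reduces to the mean gap between sorted values; the sketch below relies on that special case, and the sample values are made up for illustration.

```python
def wasserstein_1d(xs, ys):
    """W1 between two equal-size 1-D samples: mean gap between sorted values."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

real = [0.0, 0.1, 0.2, 0.3]

# Slide a copy of the sample toward the real one; for shifts of 1.0 or more
# the two supports do not overlap at all.
distances = {
    shift: wasserstein_1d(real, [x + shift for x in real])
    for shift in (4.0, 2.0, 1.0, 0.0)
}

for shift, dist in distances.items():
    print(f"shift={shift}  W1={dist:.3f}")
```

The distance shrinks in proportion to the shift even while the supports are disjoint, which is the property that gives the generator a useful gradient where JS divergence is flat.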
WGAN-GP[4] further replaced the original WGAN's weight clipping with gradient penalty to enforce the Lipschitz constraint, achieving even more stable training.
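A minimal PyTorch sketch of the gradient penalty might look as follows. The toy critic, the tensor shapes, and the helper name `gradient_penalty` are assumptions for illustration; the penalty weight of 10 follows the WGAN-GP paper, and the critic loss uses the minimization form (fake score minus real score).

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """Mean of (||grad_x critic(x_hat)||_2 - 1)^2 on random interpolates x_hat."""
    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dimensions.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the penalty differentiable w.r.t. critic weights.
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x_hat, create_graph=True)
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(8, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
real = torch.randn(4, 8)   # stand-in for a batch of real samples
fake = torch.randn(4, 8)   # stand-in for a batch of generator outputs

lambda_gp = 10.0
critic_loss = (critic(fake).mean() - critic(real).mean()
               + lambda_gp * gradient_penalty(critic, real, fake))
critic_loss.backward()
print(f"critic loss with penalty: {critic_loss.item():.4f}")
```

The penalty is evaluated on points interpolated between real and generated samples, which softly pushes the critic's gradient norm toward 1 there instead of clipping its weights.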
V. The StyleGAN Series: The Pinnacle of Image Generation
Karras et al. at NVIDIA produced a series of studies[5][6][7][18] that pushed GAN image generation quality to incredible heights:
| Model | Year | Key Innovation | Impact |
|---|---|---|---|
| ProGAN[5] | 2018 | Progressive training (4x4 → 1024x1024) | First 1024x1024 photo-realistic faces |
| StyleGAN[6] | 2019 | Mapping network + AdaIN style injection | Fine-grained hierarchical style control |
| StyleGAN2[7] | 2020 | Weight demodulation + path length regularization | Eliminated droplet artifacts, smoother latent space |
| StyleGAN3[18] | 2021 | Fixed aliasing, continuous signal processing | Sub-pixel equivariance, eliminated "texture sticking" |
StyleGAN's core architectural innovation is separating content from style:
StyleGAN Generator:
z ∈ Z (512-dim) → Mapping Network (8-layer MLP) → w ∈ W (512-dim)
↓
Constant 4×4 → [AdaIN(w)] → Conv → [AdaIN(w)] → Upsample → ...
Coarse style: pose/face shape
Mid-level style: features/hair
Fine style: skin/texture
W space is more "disentangled" than Z space → linear interpolation produces meaningful transitions
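The style-injection step can be sketched as a small PyTorch module. This is an illustrative approximation rather than NVIDIA's implementation: the class name, the `1 + scale` parameterization, and the tensor sizes are assumptions.

```python
import torch
import torch.nn as nn

class AdaIN(nn.Module):
    """Adaptive instance normalization driven by a style vector w (illustrative)."""

    def __init__(self, w_dim, channels):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        # Learned affine map from w to one scale and one bias per channel.
        self.to_style = nn.Linear(w_dim, 2 * channels)

    def forward(self, x, w):
        scale, bias = self.to_style(w).chunk(2, dim=1)
        scale = scale[:, :, None, None]
        bias = bias[:, :, None, None]
        # Normalize each feature map, then re-style it with w's statistics.
        return (1 + scale) * self.norm(x) + bias

torch.manual_seed(0)
layer = AdaIN(w_dim=512, channels=16)
x = torch.randn(2, 16, 8, 8)   # feature maps at some resolution
w = torch.randn(2, 512)        # style vectors from the mapping network
out = layer(x, w)
print(out.shape)
```

Because the normalization removes each feature map's own mean and variance, the statistics of the output are set entirely by w, which is what allows different layers to receive different styles.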
VI. Image Translation: Pix2Pix and CycleGAN
Pix2Pix: Paired Image Translation
Isola et al.[9] proposed Pix2Pix, applying conditional GANs[8] to image translation — edge maps to photos, semantic labels to street scenes, grayscale to color. It uses a U-Net generator and a PatchGAN discriminator (judging real/fake on 70x70 patches rather than the entire image).
CycleGAN: Unpaired Image Translation
CycleGAN[10] solved an even more challenging problem: image translation without paired data. Its core is cycle consistency loss — if you translate a horse into a zebra, then translate the zebra back, you should get the original horse:
Cycle Consistency:
x → G(x) → F(G(x)) ≈ x (forward cycle)
y → F(y) → G(F(y)) ≈ y (backward cycle)
L_cycle = ||F(G(x)) - x||₁ + ||G(F(y)) - y||₁
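The cycle loss is short enough to write out directly. In the sketch below the two "translators" are toy functions standing in for the real networks, and `nn.L1Loss` averages over elements, so it matches the L1 norm above up to a constant factor; all names are illustrative.

```python
import torch
import torch.nn as nn

def cycle_consistency_loss(G, F_map, x, y):
    """L1 cycle loss: F_map(G(x)) should return to x and G(F_map(y)) to y."""
    l1 = nn.L1Loss()  # mean absolute error over all elements
    forward_cycle = l1(F_map(G(x)), x)
    backward_cycle = l1(G(F_map(y)), y)
    return forward_cycle + backward_cycle

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)   # batch from domain X (e.g. horses)
y = torch.randn(2, 3, 8, 8)   # batch from domain Y (e.g. zebras)

# Toy translators: doubling and halving are exact inverses, so the cycle closes.
G_toy = lambda t: 2.0 * t
F_toy = lambda t: 0.5 * t
loss_consistent = cycle_consistency_loss(G_toy, F_toy, x, y)

# Replacing F with the identity breaks the cycle and the loss becomes positive.
loss_broken = cycle_consistency_loss(G_toy, lambda t: t, x, y)

print(loss_consistent.item(), loss_broken.item())
```

In the full CycleGAN objective this term is added to the two adversarial losses with a weighting coefficient, so that the translators cannot satisfy the discriminators by ignoring their input.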
VII. GANs for Text: The Challenge of Discrete Sequences
Applying GANs to text generation faces a fundamental barrier: sampling discrete tokens is non-differentiable. You cannot backpropagate through argmax or categorical sampling.
SeqGAN[11] elegantly solved this problem using a reinforcement learning approach:
SeqGAN Framework:
Generator = Policy Network
Discriminator = Reward Function
Action = Select next token
State = Currently generated token sequence
Training Process:
1. G autoregressively generates complete sequences
2. D scores complete sequences (real/fake)
3. Use Monte Carlo rollout to estimate Q-values of intermediate states
4. Update G with REINFORCE policy gradient
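Both the barrier and the workaround show up in a few lines of PyTorch. The sketch below uses a three-token vocabulary and a made-up reward of 0.8 in place of a discriminator score; both are assumptions for illustration.

```python
import torch

logits = torch.tensor([1.0, 0.5, -0.5], requires_grad=True)
dist = torch.distributions.Categorical(logits=logits)

# Sampling returns an integer token index with no gradient path to the logits.
token = dist.sample()
print("sampled token requires grad:", token.requires_grad)

# REINFORCE differentiates the log-probability of the sampled token instead
# and weights it by the reward.
log_prob = dist.log_prob(token)
reward = 0.8  # hypothetical discriminator score for the finished sequence
loss = -(reward * log_prob)
loss.backward()
print("gradient on logits:", logits.grad)
```

The sampled token is a constant as far as autograd is concerned, but its log-probability is a smooth function of the logits, which is why the policy gradient can update the generator without differentiating through the sample.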
Zhang et al.[17] proposed an alternative approach — performing adversarial training in the latent feature space rather than discrete token space, using kernelized discrepancy metrics to match distributions, circumventing the discrete sampling problem.
VIII. Training Stabilization Techniques
Beyond the WGAN series, several key stabilization techniques have been developed:
| Technique | Core Idea | Effect |
|---|---|---|
| Spectral Normalization[12] | Divide each discriminator layer's weight matrix by its largest singular value | Lightweight Lipschitz constraint with almost no additional computational cost |
| Progressive Training[5] | Start from low resolution, progressively add layers | More stable high-resolution image generation |
| Truncation Trick[13] | Limit z range during inference (truncation) | Sacrifice diversity for quality |
| Feature Matching[14] | Match statistics of discriminator intermediate layers | Reduce mode collapse |
| Two Time-Scale[15] | Use different learning rates for D and G | Theoretically converges to local Nash equilibrium |
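As one concrete instance from the table, spectral normalization can be reproduced by hand for a single linear layer. The sketch computes the largest singular value exactly, whereas practical implementations such as PyTorch's `torch.nn.utils.spectral_norm` approximate it with power iteration; the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(64, 32, bias=False)   # stand-in for one discriminator layer
weight = layer.weight.detach()

sigma = torch.linalg.matrix_norm(weight, ord=2)   # largest singular value
weight_sn = weight / sigma                        # spectrally normalized weight

# With spectral norm 1, the layer cannot stretch the gap between two inputs.
x1, x2 = torch.randn(5, 64), torch.randn(5, 64)
input_gap = (x1 - x2).norm(dim=1)
output_gap = ((x1 - x2) @ weight_sn.T).norm(dim=1)

print("sigma before:", float(sigma))
print("sigma after: ", float(torch.linalg.matrix_norm(weight_sn, ord=2)))
print("1-Lipschitz on these inputs:", bool((output_gap <= input_gap + 1e-5).all()))
```

Applying the same normalization to every layer bounds the Lipschitz constant of the whole discriminator, provided the activations between layers are themselves 1-Lipschitz (as ReLU and LeakyReLU are).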
IX. Evaluation Metrics: How to Measure Generation Quality
| Metric | Computation | What It Measures | Limitations |
|---|---|---|---|
| IS[14] | Exponential of the average KL divergence between each image's Inception class prediction and the marginal class distribution | Sharpness + Diversity | Doesn't compare to real distribution |
| FID[15] | Frechet distance of Inception features between real and fake images | Quality + Diversity + Proximity to real distribution | Requires many samples, depends on Inception |
| LPIPS | Distance in perceptual feature space | Perceptual similarity | Pairwise comparison, not distribution-level |
FID[15] is currently the most widely used GAN evaluation metric. It compares the mean and covariance of real and generated images in the Inception-v3 feature space, with lower values indicating better generation quality. Typical FID scores: StyleGAN2 on FFHQ achieves approximately 2.8; BigGAN[13] reports single-digit FID on class-conditional ImageNet, with the exact value depending on resolution and model variant.
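The Inception Score row in the table can be made concrete with a toy calculation. In the sketch below the "classifier outputs" are hand-written probability vectors over three classes instead of real Inception predictions, which is an assumption made to keep the example self-contained.

```python
import math

def inception_score(probs):
    """exp(mean over images of KL(p(y|x) || p(y))) for per-image class distributions."""
    n = len(probs)
    k = len(probs[0])
    marginal = [sum(p[j] for p in probs) / n for j in range(k)]
    mean_kl = sum(
        sum(pj * math.log(pj / mj) for pj, mj in zip(p, marginal) if pj > 0)
        for p in probs
    ) / n
    return math.exp(mean_kl)

# Confident predictions spread over all classes: sharp and diverse.
sharp_diverse = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Confident predictions that all land on one class: sharp but mode-collapsed.
sharp_collapsed = [[1.0, 0.0, 0.0]] * 3
# Uniform predictions: the classifier cannot tell what the images show.
blurry = [[1 / 3, 1 / 3, 1 / 3]] * 3

scores = {
    "sharp_diverse": inception_score(sharp_diverse),
    "sharp_collapsed": inception_score(sharp_collapsed),
    "blurry": inception_score(blurry),
}
for name, score in scores.items():
    print(f"{name}: {score:.3f}")
```

Only the sharp and diverse set reaches the maximum score, which equals the number of classes. No real images enter the computation, which is the limitation noted in the table and the gap that FID fills.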
X. GANs vs. Diffusion Models: A Paradigm Shift
In 2021, Dhariwal and Nichol[16] published the landmark paper "Diffusion Models Beat GANs on Image Synthesis," showing that diffusion models can surpass the best GANs on FID on ImageNet benchmarks. But this doesn't mean GANs are obsolete:
| Aspect | GAN | Diffusion Model |
|---|---|---|
| Generation Speed | Single forward pass (~20ms) | Tens to hundreds of iterative steps (~2-10s) |
| Training Stability | Adversarial training is unstable | Simple denoising objective, stable |
| Mode Coverage | Prone to mode collapse | Better distribution coverage |
| Image Quality | Extremely high (StyleGAN series) | Extremely high (Imagen, DALL-E 3) |
| Controllability | StyleGAN's W space | Text-conditioned guidance |
| Use Cases | Real-time rendering, video, super-resolution | Text-to-image, editing |
In scenarios requiring real-time generation (games, video, interactive applications), GAN's single-pass inference speed remains unmatched.
XI. Hands-on Lab 1: DCGAN Handwritten Digit Generation (Google Colab)
The following experiment implements DCGAN from scratch, trains it on MNIST to generate handwritten digits, and tracks training quality with a simplified FID-style score computed on pixel features rather than Inception features.
# ============================================================
# Lab 1: DCGAN — MNIST Handwritten Digit Generation + FID Evaluation
# Environment: Google Colab (GPU)
# ============================================================
# --- 0. Installation ---
!pip install -q torchvision scipy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from scipy import linalg
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# --- 1. Data Loading ---
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])
dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                     transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)
print(f"Training images: {len(dataset)}")
# --- 2. Generator (DCGAN Architecture) ---
nz = 100 # Latent vector dimension
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # Input: z [B, 100, 1, 1]
            nn.ConvTranspose2d(nz, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # [B, 256, 4, 4]
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # [B, 128, 8, 8]
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # [B, 64, 16, 16]
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
            # [B, 1, 32, 32]
        )

    def forward(self, z):
        return self.main(z)
# --- 3. Discriminator (DCGAN Architecture) ---
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # [B, 1, 32, 32]
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 64, 16, 16]
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 128, 8, 8]
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 256, 4, 4]
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # [B, 1, 1, 1]
        )

    def forward(self, x):
        return self.main(x).view(-1)
# --- 4. Initialization ---
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
G = Generator().to(device).apply(weights_init)
D = Discriminator().to(device).apply(weights_init)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")
criterion = nn.BCELoss()
optG = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
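# --- 5. Simplified FID Computation ---
def compute_fid(real_images, fake_images, n_features=64, eps=1e-6):
    """Simplified FID proxy: Frechet distance on pooled pixel features, not Inception.

    Values are only comparable within this notebook, not with published FID scores.
    """
    # Average-pool each image to side x side so n_features values remain per image
    side = int(round(n_features ** 0.5))
    pool = torch.nn.functional.adaptive_avg_pool2d
    real_flat = pool(real_images, side).flatten(1).cpu().numpy().astype(np.float64)
    fake_flat = pool(fake_images, side).flatten(1).cpu().numpy().astype(np.float64)
    # A small diagonal offset keeps the matrix square root numerically stable
    offset = eps * np.eye(real_flat.shape[1])
    mu_r, sigma_r = real_flat.mean(axis=0), np.cov(real_flat, rowvar=False) + offset
    mu_f, sigma_f = fake_flat.mean(axis=0), np.cov(fake_flat, rowvar=False) + offset
    diff = mu_r - mu_f
    covmean = linalg.sqrtm(sigma_r @ sigma_f)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard numerical imaginary residue
    fid = diff @ diff + np.trace(sigma_r + sigma_f - 2 * covmean)
    return float(fid)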
# --- 6. Training Loop ---
n_epochs = 25
G_losses, D_losses, fid_scores = [], [], []
print("\nTraining DCGAN...")
for epoch in range(n_epochs):
    for i, (real, _) in enumerate(dataloader):
        B = real.size(0)
        real = real.to(device)
        real_label = torch.ones(B, device=device)
        fake_label = torch.zeros(B, device=device)
        # --- Train Discriminator ---
        D.zero_grad()
        out_real = D(real)
        loss_real = criterion(out_real, real_label)
        noise = torch.randn(B, nz, 1, 1, device=device)
        fake = G(noise)
        out_fake = D(fake.detach())  # detach: no generator gradients in the D step
        loss_fake = criterion(out_fake, fake_label)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optD.step()
        # --- Train Generator ---
        G.zero_grad()
        out_fake2 = D(fake)
        loss_G = criterion(out_fake2, real_label)  # non-saturating loss: label fakes as real
        loss_G.backward()
        optG.step()
    # Record the last-batch losses once per epoch (the loss plot uses an epoch axis)
    G_losses.append(loss_G.item())
    D_losses.append(loss_D.item())
    # Compute FID (every 5 epochs)
    if (epoch + 1) % 5 == 0:
        G.eval()
        with torch.no_grad():
            fake_sample = G(torch.randn(500, nz, 1, 1, device=device))
        real_sample = torch.stack([dataset[i][0] for i in range(500)]).to(device)
        fid = compute_fid(real_sample, fake_sample)
        fid_scores.append((epoch + 1, fid))
        print(f"Epoch {epoch+1}/{n_epochs}: G_loss={loss_G.item():.4f}, "
              f"D_loss={loss_D.item():.4f}, FID={fid:.1f}")
        G.train()
# --- 7. Visualization ---
G.eval()
with torch.no_grad():
    generated = G(fixed_noise).cpu()
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# Generated images
axes[0].imshow(make_grid(generated, nrow=8, normalize=True, padding=2).permute(1, 2, 0),
               cmap='gray')
axes[0].set_title(f'Generated Digits (Epoch {n_epochs})', fontsize=13)
axes[0].axis('off')
# Loss curves
axes[1].plot(G_losses, label='Generator', color='#0077b6')
axes[1].plot(D_losses, label='Discriminator', color='#b8922e')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Training Loss', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)
# FID curve
if fid_scores:
    epochs_f, fids = zip(*fid_scores)
    axes[2].plot(epochs_f, fids, 'o-', color='#e63946', linewidth=2)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('FID (lower = better)')
    axes[2].set_title('FID Score', fontsize=13)
    axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# --- 8. Latent Space Interpolation ---
z1 = torch.randn(1, nz, 1, 1, device=device)
z2 = torch.randn(1, nz, 1, 1, device=device)
alphas = torch.linspace(0, 1, 10)
interpolated = []
with torch.no_grad():
    for a in alphas:
        z = z1 * (1 - a) + z2 * a
        interpolated.append(G(z).cpu())
fig, axes = plt.subplots(1, 10, figsize=(20, 2))
for i, img in enumerate(interpolated):
    axes[i].imshow(img[0, 0], cmap='gray')
    axes[i].set_title(f'α={alphas[i]:.1f}', fontsize=9)
    axes[i].axis('off')
plt.suptitle('Latent Space Interpolation', fontsize=14)
plt.tight_layout()
plt.show()
print("Lab 1 Complete!")
XII. Hands-on Lab 2: SeqGAN Text Sequence Generation (Google Colab)
The following experiment implements the core concept of SeqGAN — training a generator with policy gradient to produce discrete token sequences, with a discriminator evaluating sequence quality.
# ============================================================
# Lab 2: SeqGAN Concept Implementation — Text Sequence Generation
# Environment: Google Colab (GPU or CPU)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# --- 1. Data Preparation: Simple English Sentence Patterns ---
# Example sentences showing the target pattern (illustrative only; the "real"
# training set is generated from the word lists below)
templates = [
    "the cat sat on the mat",
    "the dog ran in the park",
    "a bird flew over the tree",
    "the fish swam in the lake",
    "a boy read in the room",
    "the girl sang on the stage",
    "a man walked to the store",
    "the sun set in the west",
    "a car drove on the road",
    "the wind blew through the door",
]
# Expand dataset
import random
random.seed(42)
subjects = ["the cat", "the dog", "a bird", "the fish", "a boy",
            "the girl", "a man", "the sun", "a car", "the wind"]
verbs = ["sat", "ran", "flew", "swam", "read", "sang", "walked", "set", "drove", "blew"]
preps = ["on", "in", "over", "through", "to", "under", "near", "past", "by", "from"]
objects = ["the mat", "the park", "the tree", "the lake", "the room",
           "the stage", "the store", "the west", "the road", "the door"]
real_sentences = []
for _ in range(2000):
    s = f"{random.choice(subjects)} {random.choice(verbs)} {random.choice(preps)} {random.choice(objects)}"
    real_sentences.append(s)
# Build vocabulary
all_words = set()
for s in real_sentences:
    all_words.update(s.split())
vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
for w in sorted(all_words):
    vocab[w] = len(vocab)
idx2word = {v: k for k, v in vocab.items()}
vocab_size = len(vocab)
SEQ_LEN = 8 # <bos> + 6 words + <eos>
print(f"Vocab size: {vocab_size}, Seq len: {SEQ_LEN}")
def encode_sentence(s):
    # <bos> + up to SEQ_LEN-2 word ids + <eos>, right-padded with <pad>
    tokens = [1] + [vocab.get(w, 0) for w in s.split()][:SEQ_LEN-2] + [2]
    tokens += [0] * (SEQ_LEN - len(tokens))
    return tokens
real_data = torch.tensor([encode_sentence(s) for s in real_sentences]).to(device)
print(f"Real data shape: {real_data.shape}")
# --- 2. Generator (Recurrent Neural Network Policy Network) ---
class SeqGenerator(nn.Module):
    def __init__(self, vocab_size, emb_dim=64, hidden_dim=128):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim

    def forward(self, x, hidden=None):
        emb = self.emb(x)
        out, hidden = self.lstm(emb, hidden)
        logits = self.fc(out)
        return logits, hidden

    def generate(self, batch_size, max_len=SEQ_LEN):
        """Autoregressive sequence generation"""
        x = torch.full((batch_size, 1), 1, dtype=torch.long, device=device)  # <bos>
        hidden = None
        sequences = [x]
        log_probs_list = []
        for _ in range(max_len - 1):
            logits, hidden = self.forward(x, hidden)
            probs = F.softmax(logits[:, -1], dim=-1)
            dist = torch.distributions.Categorical(probs)
            x = dist.sample().unsqueeze(1)
            log_probs_list.append(dist.log_prob(x.squeeze(1)))
            sequences.append(x)
        sequences = torch.cat(sequences, dim=1)          # [B, max_len]
        log_probs = torch.stack(log_probs_list, dim=1)   # [B, max_len - 1]
        return sequences, log_probs
# --- 3. Discriminator (Convolutional Neural Network Sequence Classifier) ---
class SeqDiscriminator(nn.Module):
    def __init__(self, vocab_size, emb_dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.convs = nn.ModuleList([
            nn.Conv1d(emb_dim, 32, k, padding=k//2) for k in [2, 3, 4]
        ])
        self.fc = nn.Linear(32 * 3, 1)
        self.drop = nn.Dropout(0.2)

    def forward(self, x):
        emb = self.emb(x).transpose(1, 2)  # [B, emb, L]
        # Max-over-time pooling for each kernel width
        conv_outs = [F.relu(conv(emb)).max(dim=2).values for conv in self.convs]
        out = torch.cat(conv_outs, dim=1)
        return torch.sigmoid(self.fc(self.drop(out))).squeeze(1)
# --- 4. Initialization ---
G = SeqGenerator(vocab_size).to(device)
D = SeqDiscriminator(vocab_size).to(device)
optG = torch.optim.Adam(G.parameters(), lr=1e-3)
optD = torch.optim.Adam(D.parameters(), lr=1e-3)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")
# --- 5. Pre-train Generator (MLE) ---
print("\n--- Pre-training Generator (MLE) ---")
for epoch in range(30):
    idx = torch.randperm(len(real_data))[:256]
    batch = real_data[idx]
    # Teacher forcing: predict token t+1 from tokens up to t
    logits, _ = G(batch[:, :-1])
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), batch[:, 1:].reshape(-1),
                           ignore_index=0)
    optG.zero_grad()
    loss.backward()
    optG.step()
    if (epoch + 1) % 10 == 0:
        print(f" Epoch {epoch+1}: MLE loss = {loss.item():.4f}")
# --- 6. Pre-train Discriminator ---
print("\n--- Pre-training Discriminator ---")
for epoch in range(20):
    idx = torch.randperm(len(real_data))[:128]
    real_batch = real_data[idx]
    with torch.no_grad():
        fake_batch, _ = G.generate(128)
    real_score = D(real_batch)
    fake_score = D(fake_batch)
    loss_D = -torch.log(real_score + 1e-8).mean() - torch.log(1 - fake_score + 1e-8).mean()
    optD.zero_grad()
    loss_D.backward()
    optD.step()
    if (epoch + 1) % 10 == 0:
        print(f" Epoch {epoch+1}: D loss = {loss_D.item():.4f}, "
              f"D(real)={real_score.mean():.3f}, D(fake)={fake_score.mean():.3f}")
# --- 7. Adversarial Training (SeqGAN Policy Gradient) ---
print("\n--- Adversarial Training (SeqGAN) ---")
g_rewards_history, d_scores_history = [], []
for epoch in range(50):
    # --- Train G (Policy Gradient) ---
    fake_seqs, log_probs = G.generate(64)
    with torch.no_grad():
        rewards = D(fake_seqs)  # [B]
    # Simplification: the full-sequence reward is applied to every time step,
    # instead of estimating per-step Q-values with Monte Carlo rollouts.
    # REINFORCE with a mean-reward baseline to reduce variance
    baseline = rewards.mean()
    pg_loss = -((rewards - baseline).unsqueeze(1) * log_probs).mean()
    optG.zero_grad()
    pg_loss.backward()
    nn.utils.clip_grad_norm_(G.parameters(), 5.0)
    optG.step()
    # --- Train D ---
    idx = torch.randperm(len(real_data))[:64]
    real_batch = real_data[idx]
    with torch.no_grad():
        fake_batch, _ = G.generate(64)
    real_score = D(real_batch)
    fake_score = D(fake_batch)
    loss_D = -torch.log(real_score + 1e-8).mean() - torch.log(1 - fake_score + 1e-8).mean()
    optD.zero_grad()
    loss_D.backward()
    optD.step()
    g_rewards_history.append(rewards.mean().item())
    d_scores_history.append(fake_score.mean().item())
    if (epoch + 1) % 10 == 0:
        print(f" Epoch {epoch+1}: G_reward={rewards.mean():.3f}, "
              f"D(fake)={fake_score.mean():.3f}, D(real)={real_score.mean():.3f}")
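# --- 8. Generation and Evaluation ---
G.eval()
D.eval()  # disable dropout so the printed scores are deterministic
print("\n=== Generated Sentences ===")
with torch.no_grad():
    seqs, _ = G.generate(20)
    for seq in seqs:
        tokens = seq.tolist()
        if 2 in tokens:  # cut everything after the first <eos>
            tokens = tokens[:tokens.index(2)]
        words = [idx2word.get(t, "?") for t in tokens if t > 2]
        sentence = " ".join(words)
        score = D(seq.unsqueeze(0)).item()
        print(f" [{score:.3f}] {sentence}")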
# --- 9. Training Process Visualization ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].plot(g_rewards_history, color='#0077b6', linewidth=1.5)
axes[0].set_xlabel('Adversarial Epoch')
axes[0].set_ylabel('Average Reward')
axes[0].set_title('Generator Reward (D score on fake)', fontsize=13)
axes[0].grid(True, alpha=0.3)
axes[1].plot(d_scores_history, color='#b8922e', linewidth=1.5, label='D(fake)')
axes[1].axhline(y=0.5, color='#e63946', linestyle='--', alpha=0.7, label='Equilibrium')
axes[1].set_xlabel('Adversarial Epoch')
axes[1].set_ylabel('D(fake)')
axes[1].set_title('Discriminator Score on Fake Sequences', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# --- 10. Real vs Generated Sentence Comparison ---
print("\n=== Real vs Generated Comparison ===")
print("Real sentences:")
for s in real_sentences[:5]:
print(f" {s}")
print("\nGenerated sentences:")
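with torch.no_grad():
    seqs, _ = G.generate(5)
    for seq in seqs:
        tokens = seq.tolist()
        if 2 in tokens:  # cut everything after the first <eos>
            tokens = tokens[:tokens.index(2)]
        words = [idx2word.get(t, "?") for t in tokens if t > 2]
        print(f" {' '.join(words)}")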
print("\nLab 2 Complete!")
XIII. Conclusion and Outlook
The story of GANs is one of the most fascinating chapters in deep learning. From Goodfellow's original 2014 paper[1] to StyleGAN3's[18] photo-realistic fidelity, GANs fundamentally changed our understanding of "whether machines can create" in less than a decade.
Reviewing the core trajectory:
- Theoretical Foundation: The minimax game framework[1] provided an elegant generative modeling paradigm
- Stabilization Breakthrough: WGAN[3] and spectral normalization[12] transformed training from "art" to "engineering"
- Quality Pinnacle: The StyleGAN series[6][7] achieved generation quality indistinguishable by the human eye
- Application Explosion: Pix2Pix[9] and CycleGAN[10] opened up infinite possibilities for image translation
- Paradigm Competition: Diffusion models[16] surpassed GANs in quality, but GANs maintain their speed advantage
Looking ahead, GANs have not left the stage. Works like GigaGAN extend GANs to billion-parameter scales; distillation techniques enable diffusion models to learn single-step generation, which brings them close to GAN-style generators at inference time; and the reward model used in alignment techniques like RLHF plays a role loosely analogous to a GAN discriminator. The adversarial idea of letting two systems improve together through competition will continue to influence the direction of AI development.



