Key Findings
  • GANs[1] pioneered adversarial generative modeling with a zero-sum game between generator and discriminator, representing one of the most influential frameworks in deep learning history
  • From DCGAN[2] to StyleGAN3[18], GAN architectures have undergone five generations of evolution, with image generation quality leaping from blurry noise to photo-realistic fidelity
  • WGAN[3] redefined the training objective with the Wasserstein distance, substantially mitigating the original GAN's training instability and mode collapse; CycleGAN[10] achieved unpaired image translation
  • This article includes two Google Colab implementations: DCGAN handwritten digit generation (with FID evaluation) and SeqGAN text sequence generation (training discrete GANs with reinforcement learning)

I. The Art of Adversarial Learning: GAN's Core Philosophy

In 2014, Ian Goodfellow and colleagues proposed Generative Adversarial Networks (GANs) in a landmark paper[1]. Its core idea stems from an elegant analogy: the game between a counterfeiter and the police.

The counterfeiter (generator G) tries to produce counterfeit bills that are indistinguishable from real ones; the police (discriminator D) try to tell real from fake. As the game progresses, the counterfeiter becomes increasingly skilled at forging, and the police become increasingly skilled at detection — until the counterfeits are so perfect that even the police cannot distinguish them.

The mathematical form of this game is a minimax zero-sum game:

min_G max_D  V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]

Where:
  G(z): Generator, maps random noise z to fake samples
  D(x): Discriminator, outputs probability that x is a real sample
  p_data: Real data distribution
  p_z: Noise distribution (typically Gaussian or uniform)

Optimal discriminator: D*(x) = p_data(x) / (p_data(x) + p_g(x))
Nash equilibrium: p_g = p_data, D*(x) = 0.5 (completely indistinguishable)
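
The optimal-discriminator formula above can be checked numerically. The sketch below uses two small discrete distributions (the arrays `p_data` and `p_g` are made-up illustration values) and verifies the known identity V(D*, G) = -log 4 + 2·JSD(p_data || p_g), where JSD is the Jensen-Shannon divergence.

```python
import numpy as np

# Toy discrete distributions over 4 outcomes (illustrative values only).
p_data = np.array([0.1, 0.4, 0.4, 0.1])
p_g = np.array([0.4, 0.1, 0.1, 0.4])

def value(d, p_data, p_g):
    """V(D, G) on a discrete space: sum_x p_data*log D + p_g*log(1 - D)."""
    return float(np.sum(p_data * np.log(d) + p_g * np.log(1.0 - d)))

def kl(p, q):
    return float(np.sum(p * np.log(p / q)))

def jsd(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

d_star = p_data / (p_data + p_g)      # optimal discriminator D*(x)
v_star = value(d_star, p_data, p_g)

# Any other discriminator scores lower, e.g. the constant D(x) = 0.5.
v_half = value(np.full(4, 0.5), p_data, p_g)

print(v_star, -np.log(4) + 2 * jsd(p_data, p_g))  # the two numbers agree
print(v_half)                                      # equals -log 4
```

When p_g equals p_data, the JSD term vanishes and D* is 0.5 everywhere, which is the equilibrium described above.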

GAN training is a two-step alternating process:

  1. Train discriminator: Fix G, train D to distinguish between real samples and fake samples generated by G
  2. Train generator: Fix D, train G to generate samples that D misclassifies as real

II. Training Nightmares: GAN's Three Major Challenges

GAN's theory is elegant, but training is extremely difficult. Three core problems plagued the research community for years:

Mode Collapse

The generator "takes shortcuts" — learning to generate only a few specific samples that fool the discriminator, rather than covering the entire data distribution. For example, a GAN trained on MNIST might only learn to generate the digits 1 and 7.

Training Instability

The original GAN's loss function is based on Jensen-Shannon divergence. When the real distribution and generated distribution have no overlap (almost certainly the case in high-dimensional spaces), JS divergence is constant at log(2) — the gradient is zero, and the generator cannot learn.
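
A small numerical sketch makes the saturation concrete. It places the "real" and "generated" distributions as point masses on a discrete grid (a deliberately simplified setup chosen for illustration) and shows that the JS divergence stays at log 2 no matter how close the two masses are, as long as they do not coincide.

```python
import numpy as np

def jsd(p, q):
    """Jensen-Shannon divergence (natural log) between discrete distributions."""
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0                      # 0 * log 0 is treated as 0
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def point_mass(position, size=20):
    p = np.zeros(size)
    p[position] = 1.0
    return p

real = point_mass(0)
# Move the generated point mass closer and closer to the real one.
divergences = [jsd(real, point_mass(k)) for k in (15, 10, 5, 1)]
print(divergences)        # every entry equals log 2
print(jsd(real, real))    # 0.0 only once the supports coincide
```

Because the value does not change as the generated mass moves closer, it carries no information about which direction the generator should move.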

Evaluation Difficulty

GANs don't have an ELBO loss like VAEs to track. Training loss doesn't reflect generation quality: a discriminator output hovering around 0.5 could indicate equilibrium or collapse.
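
As a reference point, the sketch below computes the binary cross-entropy values that appear when the discriminator outputs 0.5 for every input. It assumes the common formulation in which the discriminator loss is the sum of a real term and a fake term, and the generator uses the non-saturating loss. The helper `bce` is written here for illustration.

```python
import math

def bce(prediction, target):
    """Binary cross-entropy for a single prediction in (0, 1)."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))

d_out = 0.5  # discriminator output for every sample, real or fake

loss_d = bce(d_out, 1.0) + bce(d_out, 0.0)   # real term + fake term = 2 ln 2
loss_g = bce(d_out, 1.0)                     # non-saturating generator loss = ln 2

print(round(loss_d, 4), round(loss_g, 4))    # 1.3863 0.6931
```

The same numbers show up whether the generator matches the data perfectly or the discriminator has simply stopped discriminating, which is why loss curves alone cannot separate the two cases.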

III. DCGAN: Establishing Design Guidelines for Image GANs

Radford et al.[2] proposed DCGAN (Deep Convolutional GAN), establishing a set of classic design guidelines for image GANs:

Guideline              | Generator                         | Discriminator
Up/Downsampling        | Transposed convolution (stride 2) | Strided convolution (stride 2)
Normalization          | BatchNorm (except output layer)   | BatchNorm (except input layer)
Activation             | ReLU (Tanh for output)            | LeakyReLU(0.2)
Fully connected layers | Not used                          | Not used

These seemingly simple guidelines had profound significance — they transformed GAN training from "random success" to "reproducible." DCGAN also demonstrated that the generator's latent space has meaningful structure: z(man with glasses) - z(man) + z(woman) ≈ z(woman with glasses).
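
The up/downsampling guideline can be sanity-checked with the standard output-size formulas for convolutions. The sketch below assumes 4x4 kernels and a 32x32 target resolution (illustrative choices, not settings prescribed by the DCGAN paper) and traces the spatial size through a four-layer generator and its mirrored discriminator.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution (no output_padding, dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel

def conv_out(size, kernel, stride, padding):
    """Output size of a regular convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) per layer; assumed 32x32 target resolution.
gen_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
disc_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size, gen_sizes = 1, []
for k, s, p in gen_layers:
    size = conv_transpose_out(size, k, s, p)
    gen_sizes.append(size)

size, disc_sizes = 32, []
for k, s, p in disc_layers:
    size = conv_out(size, k, s, p)
    disc_sizes.append(size)

print(gen_sizes)   # [4, 8, 16, 32]
print(disc_sizes)  # [16, 8, 4, 1]
```

With kernel 4, stride 2, and padding 1, each layer exactly doubles or halves the resolution, which is why this combination is so common in DCGAN-style code.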

IV. WGAN: Redefining the Distance Metric

WGAN[3] (Wasserstein GAN) was a watershed for GAN training stability. Arjovsky et al. changed the training objective from JS divergence to the Wasserstein-1 distance (Earth Mover's Distance):

W(p_data, p_g) = inf_{γ∈Π(p_data, p_g)} E_{(x,y)~γ}[||x - y||]

Intuition: The minimum "work" needed to "transport" distribution p_g into p_data

WGAN loss (Kantorovich-Rubinstein duality):
  L_critic = E_{z~p_z}[f_w(G(z))] - E_{x~p_data}[f_w(x)]   (minimized by the critic; its negative estimates W)
  L_G = -E_{z~p_z}[f_w(G(z))]

  f_w is a 1-Lipschitz continuous "critic" (replacing the discriminator)

The key advantage of Wasserstein distance: even when two distributions have no overlap, it still provides meaningful gradients. The critic's estimate of the Wasserstein distance also tracks sample quality: as the estimate decreases, generated samples tend to improve. The original GAN loss offers no such signal.
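
For one-dimensional samples of equal size, the Wasserstein-1 distance reduces to the mean absolute gap between sorted values, which makes the "meaningful signal without overlap" property easy to see. The following is a toy sketch under that 1-D assumption, with `wasserstein_1d` being a helper written for illustration. It is not how a WGAN critic estimates the distance in practice.

```python
import numpy as np

def wasserstein_1d(u, v):
    """W1 between two equal-size 1-D samples: mean gap between sorted values."""
    u = np.sort(np.asarray(u, dtype=float))
    v = np.sort(np.asarray(v, dtype=float))
    return float(np.mean(np.abs(u - v)))

rng = np.random.default_rng(0)
real = rng.uniform(0.0, 1.0, size=1000)

# "Generated" samples: the real samples shifted by theta, so the supports
# are disjoint whenever theta > 1.
distances = {theta: wasserstein_1d(real, real + theta) for theta in (8.0, 4.0, 2.0, 0.0)}
print(distances)   # each distance equals its theta
```

The distance shrinks smoothly as the shift decreases, so a generator parameterized by the shift would always know which way to move.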

WGAN-GP[4] further replaced the original WGAN's weight clipping with gradient penalty to enforce the Lipschitz constraint, achieving even more stable training.
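
A minimal PyTorch sketch of the gradient penalty follows. The function name `gradient_penalty`, the tiny two-layer critic, and the 2-D toy data are assumptions made for illustration. The coefficient 10 is the default suggested in the WGAN-GP paper.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """WGAN-GP penalty: lambda * E[(||grad_x critic(x_hat)||_2 - 1)^2]."""
    batch = real.size(0)
    # One interpolation coefficient per sample, broadcast over remaining dims.
    eps = torch.rand(batch, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    grads, = torch.autograd.grad(
        outputs=scores, inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,   # keep the graph so the penalty is differentiable
    )
    grad_norm = grads.flatten(1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()

# Usage with a tiny illustrative critic on 2-D points.
torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
real = torch.randn(8, 2)
fake = torch.randn(8, 2) + 3.0

gp = gradient_penalty(critic, real, fake)
# The critic minimizes E[f(fake)] - E[f(real)] plus the penalty.
loss_critic = critic(fake).mean() - critic(real).mean() + gp
loss_critic.backward()
print(gp.item())
```

The penalty is evaluated on random interpolations between real and generated samples rather than everywhere, a sampling choice that keeps the constraint cheap to enforce.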

V. The StyleGAN Series: The Pinnacle of Image Generation

Karras et al. at NVIDIA produced a series of studies[5][6][7][18] that pushed GAN image generation quality to incredible heights:

Model         | Year | Key Innovation                                   | Impact
ProGAN[5]     | 2018 | Progressive training (4x4 → 1024x1024)           | First 1024x1024 photo-realistic faces
StyleGAN[6]   | 2019 | Mapping network + AdaIN style injection          | Fine-grained hierarchical style control
StyleGAN2[7]  | 2020 | Weight demodulation + path length regularization | Eliminated droplet artifacts, smoother latent space
StyleGAN3[18] | 2021 | Fixed aliasing, continuous signal processing     | Sub-pixel equivariance, eliminated "texture sticking"

StyleGAN's core architectural innovation is separating content from style:

StyleGAN Generator:
  z ∈ Z (512-dim)  →  Mapping Network (8-layer MLP)  →  w ∈ W (512-dim)
                                                     ↓
  Constant 4×4     →  [AdaIN(w)] → Conv → [AdaIN(w)] → Upsample → ...
                       Coarse style      Mid-level style    Fine style
                       (pose/face shape) (features/hair)    (skin/texture)

  W space is more "disentangled" than Z space → linear interpolation produces meaningful transitions
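
The AdaIN step in the diagram can be sketched in a few lines of NumPy. Each channel is normalized to zero mean and unit variance per sample, then rescaled and shifted by values derived from w. The affine map from w to scale and bias is a random stand-in here (in the real model it is learned), and the shapes are illustrative.

```python
import numpy as np

def adain(x, style_scale, style_bias, eps=1e-5):
    """Adaptive instance normalization.

    x:           feature maps, shape [batch, channels, height, width]
    style_scale: per-channel scale, shape [batch, channels]
    style_bias:  per-channel bias,  shape [batch, channels]
    """
    mean = x.mean(axis=(2, 3), keepdims=True)
    std = x.std(axis=(2, 3), keepdims=True)
    normalized = (x - mean) / (std + eps)
    return style_scale[:, :, None, None] * normalized + style_bias[:, :, None, None]

rng = np.random.default_rng(0)
w = rng.normal(size=(2, 512))               # intermediate latent w
features = rng.normal(size=(2, 8, 4, 4))    # 8 channels at 4x4 resolution

# Hypothetical affine map from w to (scale, bias); random weights here.
affine = rng.normal(size=(512, 16)) * 0.05
style = w @ affine
scale, bias = 1.0 + style[:, :8], style[:, 8:]

out = adain(features, scale, bias)
print(out.mean(axis=(2, 3))[0, :3], bias[0, :3])   # channel means match the bias
```

After AdaIN, the per-channel statistics of the feature maps are set entirely by the style, which is what lets different layers control coarse, mid-level, and fine attributes.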

VI. Image Translation: Pix2Pix and CycleGAN

Pix2Pix: Paired Image Translation

Isola et al.[9] proposed Pix2Pix, applying conditional GANs[8] to image translation — edge maps to photos, semantic labels to street scenes, grayscale to color. It uses a U-Net generator and a PatchGAN discriminator (judging real/fake on 70x70 patches rather than the entire image).
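
The 70x70 figure is the receptive field of one output unit of the discriminator, and it can be derived by walking backward through the convolution stack. The layer layout below (three stride-2 and two stride-1 convolutions, all with 4x4 kernels) is an assumption based on the commonly used PatchGAN configuration.

```python
def receptive_field(layers):
    """Receptive field of one output unit for a stack of (kernel, stride) convs."""
    rf = 1
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf

# Assumed layout: three stride-2 and two stride-1 convolutions, all 4x4.
patchgan_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan_layers))  # 70
```

Each output value therefore judges only a 70x70 window of the input, and the discriminator's final decision is the average over all windows.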

CycleGAN: Unpaired Image Translation

CycleGAN[10] solved an even more challenging problem: image translation without paired data. Its core is cycle consistency loss — if you translate a horse into a zebra, then translate the zebra back, you should get the original horse:

Cycle Consistency:
  x → G(x) → F(G(x)) ≈ x   (forward cycle)
  y → F(y) → G(F(y)) ≈ y   (backward cycle)

  L_cycle = ||F(G(x)) - x||₁ + ||G(F(y)) - y||₁
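
A toy NumPy sketch of the cycle consistency loss follows. The "generators" are simple elementwise functions chosen so that one pair is an exact inverse and the other is not. The sketch averages the absolute error rather than summing it, which differs from the L1 norm above only by a constant factor.

```python
import numpy as np

def cycle_consistency_loss(G, F, x, y):
    """Mean absolute reconstruction error over both translation cycles."""
    forward = np.mean(np.abs(F(G(x)) - x))    # x -> G(x) -> F(G(x))
    backward = np.mean(np.abs(G(F(y)) - y))   # y -> F(y) -> G(F(y))
    return float(forward + backward)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 8, 8))   # toy "domain X" batch
y = rng.normal(size=(4, 3, 8, 8))   # toy "domain Y" batch

G = lambda a: 2.0 * a + 1.0              # stand-in for the X -> Y generator
F_inverse = lambda b: (b - 1.0) / 2.0    # exact inverse of G
F_poor = lambda b: b                      # ignores what G did

loss_good = cycle_consistency_loss(G, F_inverse, x, y)
loss_poor = cycle_consistency_loss(G, F_poor, x, y)
print(loss_good)   # approximately 0
print(loss_poor)   # clearly positive
```

The loss is low only when the two mappings undo each other, which is the constraint that substitutes for paired supervision.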

VII. GANs for Text: The Challenge of Discrete Sequences

Applying GANs to text generation faces a fundamental barrier: sampling discrete tokens is non-differentiable. You cannot backpropagate through argmax or categorical sampling.

SeqGAN[11] elegantly solved this problem using a reinforcement learning approach:

SeqGAN Framework:
  Generator = Policy Network
  Discriminator = Reward Function
  Action = Select next token
  State = Currently generated token sequence

  Training Process:
  1. G autoregressively generates complete sequences
  2. D scores complete sequences (real/fake)
  3. Use Monte Carlo rollout to estimate Q-values of intermediate states
  4. Update G with REINFORCE policy gradient
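
Step 3 is the least obvious part of the procedure, so here is a standard-library sketch of a Monte Carlo rollout. The uniform `toy_policy` and the `toy_discriminator` that scores the fraction of 1-tokens are stand-ins invented for illustration. In SeqGAN both are neural networks.

```python
import random

VOCAB = [0, 1]   # toy vocabulary
MAX_LEN = 6      # length of a complete sequence

def toy_policy(prefix, rng):
    """Stand-in generator: samples the next token uniformly at random."""
    return rng.choice(VOCAB)

def toy_discriminator(sequence):
    """Stand-in discriminator: 'realness' is the fraction of tokens equal to 1."""
    return sum(sequence) / len(sequence)

def rollout_q(prefix, policy, discriminator, n_rollouts=200, seed=0):
    """Estimate Q(prefix) by completing the prefix n_rollouts times and
    averaging the discriminator score of the finished sequences."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_rollouts):
        seq = list(prefix)
        while len(seq) < MAX_LEN:
            seq.append(policy(seq, rng))
        total += discriminator(seq)
    return total / n_rollouts

q_good = rollout_q([1, 1, 1], toy_policy, toy_discriminator)
q_bad = rollout_q([0, 0, 0], toy_policy, toy_discriminator)
q_full = rollout_q([1] * MAX_LEN, toy_policy, toy_discriminator)
print(q_good, q_bad, q_full)   # near 0.75, near 0.25, exactly 1.0
```

The rollout gives every intermediate prefix its own reward estimate, so early tokens receive credit or blame before the sequence is finished.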

Zhang et al.[17] proposed an alternative approach — performing adversarial training in the latent feature space rather than discrete token space, using kernelized discrepancy metrics to match distributions, circumventing the discrete sampling problem.

VIII. Training Stabilization Techniques

Beyond the WGAN series, several key stabilization techniques have been developed:

Technique                  | Core Idea                                                                      | Effect
Spectral Normalization[12] | Divide each discriminator layer's weight matrix by its largest singular value | Lightweight Lipschitz constraint with almost no additional computational cost
Progressive Training[5]    | Start from low resolution, progressively add layers                            | More stable high-resolution image generation
Truncation Trick[13]       | Limit z range during inference (truncation)                                    | Sacrifice diversity for quality
Feature Matching[14]       | Match statistics of discriminator intermediate layers                          | Reduce mode collapse
Two Time-Scale[15]         | Use different learning rates for D and G                                       | Theoretically converges to local Nash equilibrium
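
Spectral normalization is cheap because the largest singular value is estimated with power iteration rather than a full SVD. The NumPy sketch below runs many iterations on a fixed random matrix for clarity. In training, implementations typically run a single iteration per step and reuse the vector `u` between steps. The function name and matrix shape are illustrative assumptions.

```python
import numpy as np

def spectral_normalize(weight, n_iters=100, seed=0):
    """Divide a 2-D weight matrix by a power-iteration estimate of its
    largest singular value."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=weight.shape[0])
    for _ in range(n_iters):
        v = weight.T @ u
        v /= np.linalg.norm(v)
        u = weight @ v
        u /= np.linalg.norm(u)
    sigma = float(u @ weight @ v)   # estimated largest singular value
    return weight / sigma, sigma

rng = np.random.default_rng(1)
W = rng.normal(size=(64, 32))
W_sn, sigma_est = spectral_normalize(W)

exact_sigma = np.linalg.svd(W, compute_uv=False)[0]
top_after = np.linalg.svd(W_sn, compute_uv=False)[0]
print(sigma_est, exact_sigma)   # estimate vs exact value
print(top_after)                # close to 1
```

A linear layer whose weight has spectral norm 1 is 1-Lipschitz, so normalizing every layer bounds the Lipschitz constant of the whole discriminator.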

IX. Evaluation Metrics: How to Measure Generation Quality

Metric  | Computation                                                          | What It Measures                                     | Limitations
IS[14]  | KL divergence of Inception model predictions                         | Sharpness + Diversity                                | Doesn't compare to real distribution
FID[15] | Frechet distance of Inception features between real and fake images | Quality + Diversity + Proximity to real distribution | Requires many samples, depends on Inception
LPIPS   | Distance in perceptual feature space                                 | Perceptual similarity                                | Pairwise comparison, not distribution-level

FID[15] is currently the most widely used GAN evaluation metric. It compares the mean and covariance of real and generated images in the Inception-v3 feature space, with lower values indicating better generation quality. Typical FID scores: StyleGAN2 on FFHQ achieves approximately 2.8, BigGAN[13] on ImageNet 128x128 achieves approximately 6.9.
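
For comparison with FID, the Inception Score from the table can be computed directly from class probabilities. The sketch below takes the probabilities as given (in practice they come from an Inception classifier) and uses three hand-built cases to show what the score rewards. The 10-class setup is an illustrative assumption.

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """IS = exp( mean_x KL( p(y|x) || p(y) ) ) from class probabilities.

    probs: array [n_samples, n_classes], each row a softmax output.
    """
    probs = np.asarray(probs, dtype=float)
    marginal = probs.mean(axis=0, keepdims=True)   # p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))

n_classes = 10
sharp_and_diverse = np.eye(n_classes)                          # one confident sample per class
sharp_but_collapsed = np.tile(np.eye(n_classes)[0], (10, 1))   # always class 0
blurry = np.full((10, n_classes), 1.0 / n_classes)             # uniform predictions

is_diverse = inception_score(sharp_and_diverse)
is_collapsed = inception_score(sharp_but_collapsed)
is_blurry = inception_score(blurry)
print(is_diverse, is_collapsed, is_blurry)   # about 10, about 1, about 1
```

The score never looks at real images, so a generator can score well while still missing the data distribution. FID addresses this by comparing feature statistics against real data.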

X. GANs vs. Diffusion Models: A Paradigm Shift

In 2021, Dhariwal and Nichol[16] published the landmark paper "Diffusion Models Beat GANs on Image Synthesis," showing that diffusion models comprehensively surpass the best GANs on FID. But this doesn't mean GANs are obsolete:

Aspect             | GAN                                       | Diffusion Model
Generation Speed   | Single forward pass (~20ms)               | Tens to hundreds of iterative steps (~2-10s)
Training Stability | Adversarial training is unstable          | Simple denoising objective, stable
Mode Coverage      | Prone to mode collapse                    | Better distribution coverage
Image Quality      | Extremely high (StyleGAN series)          | Extremely high (Imagen, DALL-E 3)
Controllability    | StyleGAN's W space                        | Text-conditioned guidance
Use Cases          | Real-time rendering, video, super-resolution | Text-to-image, editing

In scenarios requiring real-time generation (games, video, interactive applications), the single-pass inference speed of GANs remains a major advantage.

XI. Hands-on Lab 1: DCGAN Handwritten Digit Generation (Google Colab)

The following experiment implements DCGAN from scratch, training on MNIST to generate handwritten digits, and computing a simplified, pixel-based FID proxy to track training quality.

# ============================================================
# Lab 1: DCGAN — MNIST Handwritten Digit Generation + FID Evaluation
# Environment: Google Colab (GPU)
# ============================================================
# --- 0. Installation ---
!pip install -q torchvision scipy

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from scipy import linalg

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Data Loading ---
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])
dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                      transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                          shuffle=True, num_workers=2)
print(f"Training images: {len(dataset)}")

# --- 2. Generator (DCGAN Architecture) ---
nz = 100  # Latent vector dimension

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # Input: z [B, 100, 1, 1]
            nn.ConvTranspose2d(nz, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # [B, 256, 4, 4]
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # [B, 128, 8, 8]
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # [B, 64, 16, 16]
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
            # [B, 1, 32, 32]
        )

    def forward(self, z):
        return self.main(z)

# --- 3. Discriminator (DCGAN Architecture) ---
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # [B, 1, 32, 32]
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 64, 16, 16]
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 128, 8, 8]
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 256, 4, 4]
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1)

# --- 4. Initialization ---
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)

G = Generator().to(device).apply(weights_init)
D = Discriminator().to(device).apply(weights_init)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")

criterion = nn.BCELoss()
optG = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

fixed_noise = torch.randn(64, nz, 1, 1, device=device)

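# --- 5. Simplified FID Computation ---
def compute_fid(real_images, fake_images, n_features=64, eps=1e-6):
    """Simplified FID: Frechet distance on pooled pixel features.

    This is a cheap proxy, not the standard FID (which uses Inception-v3
    features). Images are average-pooled to n_features values each so the
    covariance matrices can be estimated from a few hundred samples.
    """
    side = int(round(n_features ** 0.5))  # e.g. 64 features -> 8x8 grid

    def features(images):
        pooled = nn.functional.adaptive_avg_pool2d(images, side)
        return pooled.flatten(1).cpu().numpy().astype(np.float64)

    real_flat, fake_flat = features(real_images), features(fake_images)

    mu_r, sigma_r = real_flat.mean(axis=0), np.cov(real_flat, rowvar=False)
    mu_f, sigma_f = fake_flat.mean(axis=0), np.cov(fake_flat, rowvar=False)

    # Small diagonal offset keeps the matrix square root numerically stable
    # when some pooled pixels have (near-)zero variance.
    offset = np.eye(sigma_r.shape[0]) * eps
    diff = mu_r - mu_f
    covmean = linalg.sqrtm((sigma_r + offset) @ (sigma_f + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_r + sigma_f - 2 * covmean)
    return float(fid)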

# --- 6. Training Loop ---
n_epochs = 25
G_losses, D_losses, fid_scores = [], [], []

print("\nTraining DCGAN...")
for epoch in range(n_epochs):
    for i, (real, _) in enumerate(dataloader):
        B = real.size(0)
        real = real.to(device)
        real_label = torch.ones(B, device=device)
        fake_label = torch.zeros(B, device=device)

        # --- Train Discriminator ---
        D.zero_grad()
        out_real = D(real)
        loss_real = criterion(out_real, real_label)

        noise = torch.randn(B, nz, 1, 1, device=device)
        fake = G(noise)
        out_fake = D(fake.detach())
        loss_fake = criterion(out_fake, fake_label)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optD.step()

        # --- Train Generator ---
        G.zero_grad()
        out_fake2 = D(fake)
        loss_G = criterion(out_fake2, real_label)
        loss_G.backward()
        optG.step()

    G_losses.append(loss_G.item())
    D_losses.append(loss_D.item())

    # Compute FID (every 5 epochs)
    if (epoch + 1) % 5 == 0:
        G.eval()
        with torch.no_grad():
            fake_sample = G(torch.randn(500, nz, 1, 1, device=device))
        real_sample = torch.stack([dataset[i][0] for i in range(500)]).to(device)
        fid = compute_fid(real_sample, fake_sample)
        fid_scores.append((epoch + 1, fid))
        print(f"Epoch {epoch+1}/{n_epochs}: G_loss={loss_G.item():.4f}, "
              f"D_loss={loss_D.item():.4f}, FID={fid:.1f}")
        G.train()

# --- 7. Visualization ---
G.eval()
with torch.no_grad():
    generated = G(fixed_noise).cpu()

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Generated images
axes[0].imshow(make_grid(generated, nrow=8, normalize=True, padding=2).permute(1, 2, 0),
               cmap='gray')
axes[0].set_title(f'Generated Digits (Epoch {n_epochs})', fontsize=13)
axes[0].axis('off')

# Loss curves
axes[1].plot(G_losses, label='Generator', color='#0077b6')
axes[1].plot(D_losses, label='Discriminator', color='#b8922e')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Training Loss', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# FID curve
if fid_scores:
    epochs_f, fids = zip(*fid_scores)
    axes[2].plot(epochs_f, fids, 'o-', color='#e63946', linewidth=2)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('FID (lower = better)')
    axes[2].set_title('FID Score', fontsize=13)
    axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# --- 8. Latent Space Interpolation ---
z1 = torch.randn(1, nz, 1, 1, device=device)
z2 = torch.randn(1, nz, 1, 1, device=device)
alphas = torch.linspace(0, 1, 10)

interpolated = []
with torch.no_grad():
    for a in alphas:
        z = z1 * (1 - a) + z2 * a
        interpolated.append(G(z).cpu())

fig, axes = plt.subplots(1, 10, figsize=(20, 2))
for i, img in enumerate(interpolated):
    axes[i].imshow(img[0, 0], cmap='gray')
    axes[i].set_title(f'α={alphas[i]:.1f}', fontsize=9)
    axes[i].axis('off')
plt.suptitle('Latent Space Interpolation', fontsize=14)
plt.tight_layout()
plt.show()

print("Lab 1 Complete!")

XII. Hands-on Lab 2: SeqGAN Text Sequence Generation (Google Colab)

The following experiment implements the core concept of SeqGAN — training a generator with policy gradient to produce discrete token sequences, with a discriminator evaluating sequence quality.

# ============================================================
# Lab 2: SeqGAN Concept Implementation — Text Sequence Generation
# Environment: Google Colab (GPU or CPU)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Data Preparation: Simple English Sentence Patterns ---
# Example sentences illustrating the target pattern (for reference only;
# the training set below is sampled from the word lists)
templates = [
    "the cat sat on the mat",
    "the dog ran in the park",
    "a bird flew over the tree",
    "the fish swam in the lake",
    "a boy read in the room",
    "the girl sang on the stage",
    "a man walked to the store",
    "the sun set in the west",
    "a car drove on the road",
    "the wind blew through the door",
]

# Expand dataset
import random
random.seed(42)
subjects = ["the cat", "the dog", "a bird", "the fish", "a boy",
            "the girl", "a man", "the sun", "a car", "the wind"]
verbs = ["sat", "ran", "flew", "swam", "read", "sang", "walked", "set", "drove", "blew"]
preps = ["on", "in", "over", "through", "to", "under", "near", "past", "by", "from"]
objects = ["the mat", "the park", "the tree", "the lake", "the room",
           "the stage", "the store", "the west", "the road", "the door"]

real_sentences = []
for _ in range(2000):
    s = f"{random.choice(subjects)} {random.choice(verbs)} {random.choice(preps)} {random.choice(objects)}"
    real_sentences.append(s)

# Build vocabulary
all_words = set()
for s in real_sentences:
    all_words.update(s.split())
vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
for w in sorted(all_words):
    vocab[w] = len(vocab)
idx2word = {v: k for k, v in vocab.items()}
vocab_size = len(vocab)
SEQ_LEN = 8  # <bos> + 6 words + <eos>

print(f"Vocab size: {vocab_size}, Seq len: {SEQ_LEN}")

def encode_sentence(s):
    tokens = [1] + [vocab.get(w, 0) for w in s.split()][:SEQ_LEN-2] + [2]
    tokens += [0] * (SEQ_LEN - len(tokens))
    return tokens

real_data = torch.tensor([encode_sentence(s) for s in real_sentences]).to(device)
print(f"Real data shape: {real_data.shape}")

# --- 2. Generator (Recurrent Neural Network Policy Network) ---
class SeqGenerator(nn.Module):
    def __init__(self, vocab_size, emb_dim=64, hidden_dim=128):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim

    def forward(self, x, hidden=None):
        emb = self.emb(x)
        out, hidden = self.lstm(emb, hidden)
        logits = self.fc(out)
        return logits, hidden

    def generate(self, batch_size, max_len=SEQ_LEN):
        """Autoregressive sequence generation"""
        x = torch.full((batch_size, 1), 1, dtype=torch.long, device=device)  # <bos>
        hidden = None
        sequences = [x]
        log_probs_list = []

        for _ in range(max_len - 1):
            logits, hidden = self.forward(x, hidden)
            probs = F.softmax(logits[:, -1], dim=-1)
            dist = torch.distributions.Categorical(probs)
            x = dist.sample().unsqueeze(1)
            log_probs_list.append(dist.log_prob(x.squeeze(1)))
            sequences.append(x)

        sequences = torch.cat(sequences, dim=1)
        log_probs = torch.stack(log_probs_list, dim=1)
        return sequences, log_probs

# --- 3. Discriminator (Convolutional Neural Network Sequence Classifier) ---
class SeqDiscriminator(nn.Module):
    def __init__(self, vocab_size, emb_dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.convs = nn.ModuleList([
            nn.Conv1d(emb_dim, 32, k, padding=k//2) for k in [2, 3, 4]
        ])
        self.fc = nn.Linear(32 * 3, 1)
        self.drop = nn.Dropout(0.2)

    def forward(self, x):
        emb = self.emb(x).transpose(1, 2)  # [B, emb, L]
        conv_outs = [F.relu(conv(emb)).max(dim=2).values for conv in self.convs]
        out = torch.cat(conv_outs, dim=1)
        return torch.sigmoid(self.fc(self.drop(out))).squeeze(1)

# --- 4. Initialization ---
G = SeqGenerator(vocab_size).to(device)
D = SeqDiscriminator(vocab_size).to(device)
optG = torch.optim.Adam(G.parameters(), lr=1e-3)
optD = torch.optim.Adam(D.parameters(), lr=1e-3)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")

# --- 5. Pre-train Generator (MLE) ---
print("\n--- Pre-training Generator (MLE) ---")
for epoch in range(30):
    idx = torch.randperm(len(real_data))[:256]
    batch = real_data[idx]
    logits, _ = G(batch[:, :-1])
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), batch[:, 1:].reshape(-1),
                           ignore_index=0)
    optG.zero_grad()
    loss.backward()
    optG.step()
    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1}: MLE loss = {loss.item():.4f}")

# --- 6. Pre-train Discriminator ---
print("\n--- Pre-training Discriminator ---")
for epoch in range(20):
    idx = torch.randperm(len(real_data))[:128]
    real_batch = real_data[idx]
    with torch.no_grad():
        fake_batch, _ = G.generate(128)

    real_score = D(real_batch)
    fake_score = D(fake_batch)
    loss_D = -torch.log(real_score + 1e-8).mean() - torch.log(1 - fake_score + 1e-8).mean()
    optD.zero_grad()
    loss_D.backward()
    optD.step()
    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1}: D loss = {loss_D.item():.4f}, "
              f"D(real)={real_score.mean():.3f}, D(fake)={fake_score.mean():.3f}")

# --- 7. Adversarial Training (SeqGAN Policy Gradient) ---
print("\n--- Adversarial Training (SeqGAN) ---")
g_rewards_history, d_scores_history = [], []

for epoch in range(50):
    # --- Train G (Policy Gradient) ---
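    fake_seqs, log_probs = G.generate(64)
    D.eval()  # disable dropout so the rewards are not noisy
    with torch.no_grad():
        rewards = D(fake_seqs)  # [B]
    D.train()

    # REINFORCE with a mean-reward baseline to reduce variance.
    # Simplification: every token shares the final sequence-level reward
    # (no Monte Carlo rollout as in the full SeqGAN algorithm).
    baseline = rewards.mean()
    pg_loss = -((rewards - baseline).unsqueeze(1) * log_probs).mean()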

    optG.zero_grad()
    pg_loss.backward()
    nn.utils.clip_grad_norm_(G.parameters(), 5.0)
    optG.step()

    # --- Train D ---
    idx = torch.randperm(len(real_data))[:64]
    real_batch = real_data[idx]
    with torch.no_grad():
        fake_batch, _ = G.generate(64)

    real_score = D(real_batch)
    fake_score = D(fake_batch)
    loss_D = -torch.log(real_score + 1e-8).mean() - torch.log(1 - fake_score + 1e-8).mean()
    optD.zero_grad()
    loss_D.backward()
    optD.step()

    g_rewards_history.append(rewards.mean().item())
    d_scores_history.append(fake_score.mean().item())

    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1}: G_reward={rewards.mean():.3f}, "
              f"D(fake)={fake_score.mean():.3f}, D(real)={real_score.mean():.3f}")

# --- 8. Generation and Evaluation ---
G.eval()
D.eval()  # disable dropout so the printed scores are deterministic
print("\n=== Generated Sentences ===")
with torch.no_grad():
    seqs, _ = G.generate(20)
    for seq in seqs:
        words = [idx2word.get(t.item(), "?") for t in seq if t.item() > 2]
        sentence = " ".join(words)
        score = D(seq.unsqueeze(0)).item()
        print(f"  [{score:.3f}] {sentence}")

# --- 9. Training Process Visualization ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(g_rewards_history, color='#0077b6', linewidth=1.5)
axes[0].set_xlabel('Adversarial Epoch')
axes[0].set_ylabel('Average Reward')
axes[0].set_title('Generator Reward (D score on fake)', fontsize=13)
axes[0].grid(True, alpha=0.3)

axes[1].plot(d_scores_history, color='#b8922e', linewidth=1.5, label='D(fake)')
axes[1].axhline(y=0.5, color='#e63946', linestyle='--', alpha=0.7, label='Equilibrium')
axes[1].set_xlabel('Adversarial Epoch')
axes[1].set_ylabel('D(fake)')
axes[1].set_title('Discriminator Score on Fake Sequences', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# --- 10. Real vs Generated Sentence Comparison ---
print("\n=== Real vs Generated Comparison ===")
print("Real sentences:")
for s in real_sentences[:5]:
    print(f"  {s}")
print("\nGenerated sentences:")
with torch.no_grad():
    seqs, _ = G.generate(5)
    for seq in seqs:
        words = [idx2word.get(t.item(), "?") for t in seq if t.item() > 2]
        print(f"  {' '.join(words)}")

print("\nLab 2 Complete!")

XIII. Conclusion and Outlook

The story of GANs is one of the most fascinating chapters in deep learning. From Goodfellow's original 2014 paper[1] to StyleGAN3's[18] photo-realistic fidelity, GANs fundamentally changed our understanding of "whether machines can create" in less than a decade.

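Reviewing the core trajectory: the original GAN[1] defined the adversarial game, DCGAN[2] made image GAN training reproducible, WGAN[3] and WGAN-GP[4] stabilized the training objective, the StyleGAN series[5][6][7][18] reached photo-realistic quality, Pix2Pix[9] and CycleGAN[10] brought image translation, and SeqGAN[11] extended the idea to discrete sequences.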

Looking ahead, GANs have not left the stage. Works like GigaGAN extend GANs to billion-parameter scales; distillation techniques enable diffusion models to learn single-step generation, in some cases with the help of adversarial losses; and the reward model used in alignment techniques like RLHF plays a role loosely analogous to a GAN discriminator. The adversarial idea of letting two systems improve together through competition will continue to influence the direction of AI development.