Key Findings
  • GAN[1] 以生成器與判別器的零和博弈開創了對抗式生成建模,是深度學習史上最具影響力的框架之一
  • 從 DCGAN[2] 到 StyleGAN3[18],GAN 架構經歷五代演進,影像生成品質從模糊噪點躍升至照片級真實感
  • WGAN[3] 以 Wasserstein 距離重新定義訓練目標,解決了原始 GAN 的訓練不穩定與模式崩潰問題;CycleGAN[10] 實現了無配對影像翻譯
  • 本文附兩個 Google Colab 實作:DCGAN 手寫數字生成(含 FID 評估)、SeqGAN 文本序列生成(以強化學習訓練離散 GAN)

一、對抗的藝術:GAN 的核心哲學

2014 年,Ian Goodfellow 和同事們在一篇里程碑式的論文[1]中提出了生成對抗網路(Generative Adversarial Network, GAN)。它的核心思想源自一個巧妙的類比:偽鈔製造者與警察的博弈

製造者(生成器 G)試圖製造以假亂真的鈔票;警察(判別器 D)試圖分辨真假。隨著博弈的進行,製造者越來越擅長偽造,警察也越來越擅長鑑別——直到偽鈔完美到連警察都無法區分。

這個博弈的數學形式是一個 minimax 零和遊戲

min_G max_D  V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]

其中:
  G(z): 生成器,將隨機噪聲 z 映射為假樣本
  D(x): 判別器,輸出 x 為真實樣本的機率
  p_data: 真實資料分布
  p_z: 噪聲分布(通常為高斯或均勻分布)

最優判別器: D*(x) = p_data(x) / (p_data(x) + p_g(x))
Nash 均衡: p_g = p_data, D*(x) = 0.5(完全無法區分)

GAN 的訓練是一個兩步交替的過程:

  1. 訓練判別器:固定 G,用真實樣本和 G 生成的假樣本訓練 D 區分真假
  2. 訓練生成器:固定 D,訓練 G 生成讓 D 誤判為真的樣本

二、訓練的噩夢:GAN 的三大挑戰

GAN 的理論優美,但訓練極其困難。三個核心問題困擾了研究界多年:

模式崩潰(Mode Collapse)

生成器「偷懶」——只學會生成幾種特定樣本就能騙過判別器,而非覆蓋整個資料分布。例如,訓練在 MNIST 上的 GAN 可能只會生成數字 1 和 7。

訓練不穩定

原始 GAN 的損失函數基於 Jensen-Shannon 散度。當真實分布和生成分布沒有重疊時(在高維空間中幾乎必然如此),JS 散度恆為 log(2)——梯度為零,生成器無法學習。

評估困難

GAN 沒有像 VAE 那樣的 ELBO 損失可以追蹤。訓練損失不反映生成品質——判別器損失趨近 0.5 可能是均衡,也可能是崩潰。

三、DCGAN:建立影像 GAN 的設計準則

Radford 等人[2]提出的 DCGAN(Deep Convolutional GAN)為影像 GAN 建立了一組經典設計準則:

準則生成器判別器
上/下取樣轉置卷積(stride 2)步幅卷積(stride 2)
歸一化BatchNorm(除輸出層)BatchNorm(除輸入層)
激活函數ReLU(輸出用 Tanh)LeakyReLU(0.2)
全連接層不使用不使用

這些看似簡單的準則意義深遠——它們讓 GAN 訓練從「隨機成功」變為「可重現」。DCGAN 也展示了生成器的潛在空間具有有意義的結構:z(戴眼鏡男人) - z(男人) + z(女人) ≈ z(戴眼鏡女人)

四、WGAN:重新定義距離度量

WGAN[3](Wasserstein GAN)是 GAN 訓練穩定性的分水嶺。Arjovsky 等人將訓練目標從 JS 散度改為 Wasserstein-1 距離(Earth Mover's Distance):

W(p_data, p_g) = inf_{γ∈Π(p_data, p_g)} E_{(x,y)~γ}[||x - y||]

直觀理解: 把分布 p_g 「搬運」成 p_data 所需的最小「工作量」

WGAN 損失(Kantorovich-Rubinstein 對偶):
  L_critic = E_{x~p_data}[f_w(x)] - E_{z~p_z}[f_w(G(z))]
  L_G = -E_{z~p_z}[f_w(G(z))]

  f_w 是 1-Lipschitz 連續的「評論家」(取代判別器)

Wasserstein 距離的關鍵優勢:即使兩個分布沒有重疊,它仍然提供有意義的梯度。WGAN 的損失值與生成品質正相關——損失越低,生成越好。這是原始 GAN 做不到的。

WGAN-GP[4] 進一步以梯度懲罰取代原始 WGAN 的權重裁剪來強制 Lipschitz 約束,實現了更穩定的訓練。

五、StyleGAN 系列:影像生成的巔峰

Karras 等人在 NVIDIA 的一系列研究[5][6][7][18]將 GAN 的影像生成品質推向了令人難以置信的高度:

模型年份關鍵創新影響
ProGAN[5]2018漸進式訓練(4×4 → 1024×1024)首次 1024² 照片級人臉
StyleGAN[6]2019映射網路 + AdaIN 風格注入精細的風格分層控制
StyleGAN2[7]2020權重解調 + 路徑長度正則化消除液滴偽影、更平滑的潛在空間
StyleGAN3[18]2021修復混疊、連續信號處理次像素等變性、消除「紋理黏著」

StyleGAN 的核心架構創新是分離內容與風格

StyleGAN 生成器:
  z ∈ Z (512-dim)  →  Mapping Network (8層 MLP)  →  w ∈ W (512-dim)
                                                     ↓
  Constant 4×4     →  [AdaIN(w)] → Conv → [AdaIN(w)] → Upsample → ...
                       粗糙風格         中間風格         精細風格
                       (姿態/臉型)     (五官/髮型)       (膚色/紋理)

  w 空間比 z 空間更「解糾纏」→ 線性插值產生有意義的漸變

六、影像翻譯:Pix2Pix 與 CycleGAN

Pix2Pix:配對影像翻譯

Isola 等人[9]提出的 Pix2Pix 將條件 GAN[8]應用於影像翻譯——邊緣圖→照片、語義標籤→街景、黑白→彩色。它使用 U-Net 生成器和 PatchGAN 判別器(只判斷 70×70 區塊的真假,而非整張影像)。

CycleGAN:無配對影像翻譯

CycleGAN[10] 解決了一個更有挑戰性的問題:沒有配對資料的影像翻譯。它的核心是循環一致性損失——如果把馬翻譯成斑馬,再把斑馬翻譯回去,應該得到原始的馬:

循環一致性:
  x → G(x) → F(G(x)) ≈ x   (前向循環)
  y → F(y) → G(F(y)) ≈ y   (反向循環)

  L_cycle = ||F(G(x)) - x||₁ + ||G(F(y)) - y||₁

七、GAN 應用於文本:離散序列的挑戰

將 GAN 應用於文本生成面臨一個根本障礙:離散 token 的取樣不可微分。你無法對 argmax 或 categorical sampling 做反向傳播。

SeqGAN[11] 以強化學習的思路巧妙地解決了這個問題:

SeqGAN 框架:
  生成器 = 策略網路 (Policy Network)
  判別器 = 獎勵函數 (Reward Function)
  動作 = 選擇下一個 token
  狀態 = 已生成的 token 序列

  訓練流程:
  1. G 自迴歸生成完整序列
  2. D 對完整序列評分(真/假)
  3. 用 Monte Carlo rollout 估計中間狀態的 Q 值
  4. 以 REINFORCE 策略梯度更新 G

Zhang 等人[17]則提出了另一條路線——在潛在特徵空間而非離散 token 空間進行對抗訓練,以核化差異度量匹配分布,避開了離散取樣的問題。

八、訓練穩定化技術

除了 WGAN 系列,還有幾項關鍵的穩定化技術:

技術核心思想效果
頻譜歸一化[12]對判別器每層權重矩陣除以其最大奇異值輕量級 Lipschitz 約束,幾乎不增加計算成本
漸進式訓練[5]從低解析度開始,逐步增加層高解析度影像生成更穩定
截斷技巧[13]推論時限制 z 的範圍(truncation)犧牲多樣性換取品質
特徵匹配[14]匹配判別器中間層的統計量減少模式崩潰
兩時間尺度[15]D 和 G 使用不同學習率理論上收斂至局部 Nash 均衡

九、評估指標:如何衡量生成品質

指標計算方式衡量什麼局限
IS[14]Inception 模型預測的 KL 散度清晰度 + 多樣性不比較真實分布
FID[15]真假影像 Inception 特徵的 Fréchet 距離品質 + 多樣性 + 與真實分布的接近度需要大量樣本、對 Inception 依賴
LPIPS感知特徵空間的距離感知相似度配對比較,非分布層級

FID[15] 是當前最廣泛使用的 GAN 評估指標。它比較真實影像和生成影像在 Inception-v3 特徵空間中的均值和協方差,數值越低代表生成品質越好。典型的 FID 分數:StyleGAN2 在 FFHQ 上約 2.8,BigGAN[13] 在 ImageNet 128×128 上約 6.9。

十、GAN vs. 擴散模型:典範的交替

2021 年,Dhariwal 與 Nichol[16]發表了標誌性論文「Diffusion Models Beat GANs on Image Synthesis」,展示擴散模型在 FID 上全面超越最佳 GAN。但這並不意味著 GAN 過時了:

面向GAN擴散模型
生成速度一次前向傳播(~20ms)數十到數百步迭代(~2-10s)
訓練穩定性對抗訓練不穩定簡單的去噪目標,穩定
模式覆蓋容易模式崩潰更好的分布覆蓋
影像品質極高(StyleGAN 系列)極高(Imagen, DALL-E 3)
可控性StyleGAN 的 W 空間文字條件引導
應用場景即時渲染、影片、超解析度文字生成影像、編輯

在需要即時生成的場景(遊戲、影片、互動應用),GAN 的單次推論速度仍無可替代。

十一、Hands-on Lab 1:DCGAN 手寫數字生成(Google Colab)

以下實驗從零實現 DCGAN,在 MNIST 上訓練生成手寫數字,並計算 FID 分數追蹤訓練品質。

# ============================================================
# Lab 1: DCGAN — MNIST 手寫數字生成 + FID 評估
# 環境: Google Colab (GPU)
# ============================================================
# --- 0. 安裝 ---
!pip install -q torchvision scipy

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from scipy import linalg

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. 資料載入 ---
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])
dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                      transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                          shuffle=True, num_workers=2)
print(f"Training images: {len(dataset)}")

# --- 2. 生成器(DCGAN 架構)---
nz = 100  # 潛在向量維度

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # 輸入: z [B, 100, 1, 1]
            nn.ConvTranspose2d(nz, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # [B, 256, 4, 4]
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # [B, 128, 8, 8]
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # [B, 64, 16, 16]
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
            # [B, 1, 32, 32]
        )

    def forward(self, z):
        return self.main(z)

# --- 3. 判別器(DCGAN 架構)---
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # [B, 1, 32, 32]
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 64, 16, 16]
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 128, 8, 8]
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # [B, 256, 4, 4]
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1)

# --- 4. 初始化 ---
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)

G = Generator().to(device).apply(weights_init)
D = Discriminator().to(device).apply(weights_init)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")

criterion = nn.BCELoss()
optG = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# --- 5. 簡易 FID 計算 ---
def compute_fid(real_images, fake_images, n_features=64):
    """簡化版 FID: 使用影像像素特徵而非 Inception"""
    real_flat = real_images.view(real_images.size(0), -1).cpu().numpy()
    fake_flat = fake_images.view(fake_images.size(0), -1).cpu().numpy()

    mu_r, sigma_r = real_flat.mean(axis=0), np.cov(real_flat, rowvar=False)
    mu_f, sigma_f = fake_flat.mean(axis=0), np.cov(fake_flat, rowvar=False)

    diff = mu_r - mu_f
    covmean = linalg.sqrtm(sigma_r @ sigma_f)
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_r + sigma_f - 2 * covmean)
    return fid

# --- 6. 訓練迴圈 ---
n_epochs = 25
G_losses, D_losses, fid_scores = [], [], []

print("\nTraining DCGAN...")
for epoch in range(n_epochs):
    for i, (real, _) in enumerate(dataloader):
        B = real.size(0)
        real = real.to(device)
        real_label = torch.ones(B, device=device)
        fake_label = torch.zeros(B, device=device)

        # --- 訓練判別器 ---
        D.zero_grad()
        out_real = D(real)
        loss_real = criterion(out_real, real_label)

        noise = torch.randn(B, nz, 1, 1, device=device)
        fake = G(noise)
        out_fake = D(fake.detach())
        loss_fake = criterion(out_fake, fake_label)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optD.step()

        # --- 訓練生成器 ---
        G.zero_grad()
        out_fake2 = D(fake)
        loss_G = criterion(out_fake2, real_label)
        loss_G.backward()
        optG.step()

    G_losses.append(loss_G.item())
    D_losses.append(loss_D.item())

    # 計算 FID(每 5 個 epoch)
    if (epoch + 1) % 5 == 0:
        G.eval()
        with torch.no_grad():
            fake_sample = G(torch.randn(500, nz, 1, 1, device=device))
        real_sample = torch.stack([dataset[i][0] for i in range(500)]).to(device)
        fid = compute_fid(real_sample, fake_sample)
        fid_scores.append((epoch + 1, fid))
        print(f"Epoch {epoch+1}/{n_epochs}: G_loss={loss_G.item():.4f}, "
              f"D_loss={loss_D.item():.4f}, FID={fid:.1f}")
        G.train()

# --- 7. 視覺化結果 ---
G.eval()
with torch.no_grad():
    generated = G(fixed_noise).cpu()

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 生成影像
axes[0].imshow(make_grid(generated, nrow=8, normalize=True, padding=2).permute(1, 2, 0),
               cmap='gray')
axes[0].set_title('Generated Digits (Epoch 25)', fontsize=13)
axes[0].axis('off')

# 損失曲線
axes[1].plot(G_losses, label='Generator', color='#0077b6')
axes[1].plot(D_losses, label='Discriminator', color='#b8922e')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Training Loss', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# FID 曲線
if fid_scores:
    epochs_f, fids = zip(*fid_scores)
    axes[2].plot(epochs_f, fids, 'o-', color='#e63946', linewidth=2)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('FID (lower = better)')
    axes[2].set_title('FID Score', fontsize=13)
    axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# --- 8. 潛在空間插值 ---
z1 = torch.randn(1, nz, 1, 1, device=device)
z2 = torch.randn(1, nz, 1, 1, device=device)
alphas = torch.linspace(0, 1, 10)

interpolated = []
with torch.no_grad():
    for a in alphas:
        z = z1 * (1 - a) + z2 * a
        interpolated.append(G(z).cpu())

fig, axes = plt.subplots(1, 10, figsize=(20, 2))
for i, img in enumerate(interpolated):
    axes[i].imshow(img[0, 0], cmap='gray')
    axes[i].set_title(f'α={alphas[i]:.1f}', fontsize=9)
    axes[i].axis('off')
plt.suptitle('Latent Space Interpolation', fontsize=14)
plt.tight_layout()
plt.show()

print("Lab 1 Complete!")

十二、Hands-on Lab 2:SeqGAN 文本序列生成(Google Colab)

以下實驗實現 SeqGAN 的核心概念——以策略梯度訓練生成器產生離散 token 序列,判別器評估序列品質。

# ============================================================
# Lab 2: SeqGAN 概念實現 — 文本序列生成
# 環境: Google Colab (GPU 或 CPU)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. 資料準備:簡單英文句型 ---
# 用規則生成的「真實」句子訓練判別器
templates = [
    "the cat sat on the mat",
    "the dog ran in the park",
    "a bird flew over the tree",
    "the fish swam in the lake",
    "a boy read in the room",
    "the girl sang on the stage",
    "a man walked to the store",
    "the sun set in the west",
    "a car drove on the road",
    "the wind blew through the door",
]

# 擴展資料集
import random
random.seed(42)
subjects = ["the cat", "the dog", "a bird", "the fish", "a boy",
            "the girl", "a man", "the sun", "a car", "the wind"]
verbs = ["sat", "ran", "flew", "swam", "read", "sang", "walked", "set", "drove", "blew"]
preps = ["on", "in", "over", "through", "to", "under", "near", "past", "by", "from"]
objects = ["the mat", "the park", "the tree", "the lake", "the room",
           "the stage", "the store", "the west", "the road", "the door"]

real_sentences = []
for _ in range(2000):
    s = f"{random.choice(subjects)} {random.choice(verbs)} {random.choice(preps)} {random.choice(objects)}"
    real_sentences.append(s)

# 建立詞彙表
all_words = set()
for s in real_sentences:
    all_words.update(s.split())
vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
for w in sorted(all_words):
    vocab[w] = len(vocab)
idx2word = {v: k for k, v in vocab.items()}
vocab_size = len(vocab)
SEQ_LEN = 8  # <bos> + 6 words + <eos>

print(f"Vocab size: {vocab_size}, Seq len: {SEQ_LEN}")

def encode_sentence(s):
    tokens = [1] + [vocab.get(w, 0) for w in s.split()][:SEQ_LEN-2] + [2]
    tokens += [0] * (SEQ_LEN - len(tokens))
    return tokens

real_data = torch.tensor([encode_sentence(s) for s in real_sentences]).to(device)
print(f"Real data shape: {real_data.shape}")

# --- 2. 生成器(LSTM 策略網路)---
class SeqGenerator(nn.Module):
    def __init__(self, vocab_size, emb_dim=64, hidden_dim=128):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim

    def forward(self, x, hidden=None):
        emb = self.emb(x)
        out, hidden = self.lstm(emb, hidden)
        logits = self.fc(out)
        return logits, hidden

    def generate(self, batch_size, max_len=SEQ_LEN):
        """自迴歸生成序列"""
        x = torch.full((batch_size, 1), 1, dtype=torch.long, device=device)  # <bos>
        hidden = None
        sequences = [x]
        log_probs_list = []

        for _ in range(max_len - 1):
            logits, hidden = self.forward(x, hidden)
            probs = F.softmax(logits[:, -1], dim=-1)
            dist = torch.distributions.Categorical(probs)
            x = dist.sample().unsqueeze(1)
            log_probs_list.append(dist.log_prob(x.squeeze(1)))
            sequences.append(x)

        sequences = torch.cat(sequences, dim=1)
        log_probs = torch.stack(log_probs_list, dim=1)
        return sequences, log_probs

# --- 3. 判別器(CNN 序列分類器)---
class SeqDiscriminator(nn.Module):
    def __init__(self, vocab_size, emb_dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.convs = nn.ModuleList([
            nn.Conv1d(emb_dim, 32, k, padding=k//2) for k in [2, 3, 4]
        ])
        self.fc = nn.Linear(32 * 3, 1)
        self.drop = nn.Dropout(0.2)

    def forward(self, x):
        emb = self.emb(x).transpose(1, 2)  # [B, emb, L]
        conv_outs = [F.relu(conv(emb)).max(dim=2).values for conv in self.convs]
        out = torch.cat(conv_outs, dim=1)
        return torch.sigmoid(self.fc(self.drop(out))).squeeze(1)

# --- 4. 初始化 ---
G = SeqGenerator(vocab_size).to(device)
D = SeqDiscriminator(vocab_size).to(device)
optG = torch.optim.Adam(G.parameters(), lr=1e-3)
optD = torch.optim.Adam(D.parameters(), lr=1e-3)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")

# --- 5. 預訓練生成器(MLE)---
print("\n--- Pre-training Generator (MLE) ---")
for epoch in range(30):
    idx = torch.randperm(len(real_data))[:256]
    batch = real_data[idx]
    logits, _ = G(batch[:, :-1])
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), batch[:, 1:].reshape(-1),
                           ignore_index=0)
    optG.zero_grad()
    loss.backward()
    optG.step()
    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1}: MLE loss = {loss.item():.4f}")

# --- 6. 預訓練判別器 ---
print("\n--- Pre-training Discriminator ---")
for epoch in range(20):
    idx = torch.randperm(len(real_data))[:128]
    real_batch = real_data[idx]
    with torch.no_grad():
        fake_batch, _ = G.generate(128)

    real_score = D(real_batch)
    fake_score = D(fake_batch)
    loss_D = -torch.log(real_score + 1e-8).mean() - torch.log(1 - fake_score + 1e-8).mean()
    optD.zero_grad()
    loss_D.backward()
    optD.step()
    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1}: D loss = {loss_D.item():.4f}, "
              f"D(real)={real_score.mean():.3f}, D(fake)={fake_score.mean():.3f}")

# --- 7. 對抗訓練(SeqGAN 策略梯度)---
print("\n--- Adversarial Training (SeqGAN) ---")
g_rewards_history, d_scores_history = [], []

for epoch in range(50):
    # --- 訓練 G(策略梯度)---
    fake_seqs, log_probs = G.generate(64)
    with torch.no_grad():
        rewards = D(fake_seqs)  # [B]

    # REINFORCE: reward baseline
    baseline = rewards.mean()
    pg_loss = -((rewards - baseline).unsqueeze(1) * log_probs).mean()

    optG.zero_grad()
    pg_loss.backward()
    nn.utils.clip_grad_norm_(G.parameters(), 5.0)
    optG.step()

    # --- 訓練 D ---
    idx = torch.randperm(len(real_data))[:64]
    real_batch = real_data[idx]
    with torch.no_grad():
        fake_batch, _ = G.generate(64)

    real_score = D(real_batch)
    fake_score = D(fake_batch)
    loss_D = -torch.log(real_score + 1e-8).mean() - torch.log(1 - fake_score + 1e-8).mean()
    optD.zero_grad()
    loss_D.backward()
    optD.step()

    g_rewards_history.append(rewards.mean().item())
    d_scores_history.append(fake_score.mean().item())

    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1}: G_reward={rewards.mean():.3f}, "
              f"D(fake)={fake_score.mean():.3f}, D(real)={real_score.mean():.3f}")

# --- 8. 生成與評估 ---
G.eval()
print("\n=== Generated Sentences ===")
with torch.no_grad():
    seqs, _ = G.generate(20)
    for seq in seqs:
        words = [idx2word.get(t.item(), "?") for t in seq if t.item() > 2]
        sentence = " ".join(words)
        score = D(seq.unsqueeze(0)).item()
        print(f"  [{score:.3f}] {sentence}")

# --- 9. 訓練過程視覺化 ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(g_rewards_history, color='#0077b6', linewidth=1.5)
axes[0].set_xlabel('Adversarial Epoch')
axes[0].set_ylabel('Average Reward')
axes[0].set_title('Generator Reward (D score on fake)', fontsize=13)
axes[0].grid(True, alpha=0.3)

axes[1].plot(d_scores_history, color='#b8922e', linewidth=1.5, label='D(fake)')
axes[1].axhline(y=0.5, color='#e63946', linestyle='--', alpha=0.7, label='Equilibrium')
axes[1].set_xlabel('Adversarial Epoch')
axes[1].set_ylabel('D(fake)')
axes[1].set_title('Discriminator Score on Fake Sequences', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# --- 10. 真實 vs 生成句子比較 ---
print("\n=== Real vs Generated Comparison ===")
print("Real sentences:")
for s in real_sentences[:5]:
    print(f"  {s}")
print("\nGenerated sentences:")
with torch.no_grad():
    seqs, _ = G.generate(5)
    for seq in seqs:
        words = [idx2word.get(t.item(), "?") for t in seq if t.item() > 2]
        print(f"  {' '.join(words)}")

print("\nLab 2 Complete!")

十三、結語與展望

GAN 的故事是深度學習中最精彩的篇章之一。從 2014 年 Goodfellow 的原始論文[1]到 StyleGAN3[18]的照片級真實感,GAN 在不到十年間徹底改變了我們對「機器能否創造」的認知。

回顧核心脈絡:

展望未來,GAN 並未退場。GigaGAN 等工作將 GAN 擴展到十億參數規模;蒸餾技術讓擴散模型學習一步生成(本質上是學習成為 GAN);而 GAN 的判別器思想也被廣泛融入 RLHF 等對齊技術中。對抗的思想——讓兩個系統在競爭中共同進步——將持續影響 AI 的發展方向。