- GAN[1]は生成器と識別器のゼロサムゲームにより敵対的生成モデリングの先駆けとなり、深層学習史上最も影響力のあるフレームワークの一つである
- DCGAN[2]からStyleGAN3[18]まで、GANアーキテクチャは五世代の進化を遂げ、画像生成品質はぼやけたノイズから写真のようにリアルな忠実度へと飛躍した
- WGAN[3]はWasserstein距離で訓練目標を再定義し、オリジナルGANの訓練不安定性とモード崩壊の問題を解決した。CycleGAN[10]はペアなし画像変換を実現した
- 本記事には二つのGoogle Colab実装を収録:DCGAN手書き数字生成(FID評価付き)とSeqGANテキストシーケンス生成(強化学習で離散GANを訓練)
I. 敵対的学習のアート:GANの中核哲学
2014年、Ian Goodfellowらは画期的な論文[1]で敵対的生成ネットワーク(GAN)を提案した。その中核的アイデアは優雅なアナロジーに由来する:偽造者と警察のゲーム。
偽造者(生成器G)は本物と見分けがつかない偽札を作ろうとし、警察(識別器D)は本物と偽物を見分けようとする。ゲームが進むにつれ、偽造者は偽造の腕をますます上げ、警察は検出能力をますます高める——偽造品が完璧になり、警察ですら区別できなくなるまでこのゲームは続く。
このゲームの数学的形式はミニマックスゼロサムゲームである:
min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]
ここで:
G(z):生成器。ランダムノイズzを偽サンプルにマッピング
D(x):識別器。xが実サンプルである確率を出力
p_data:実データ分布
p_z:ノイズ分布(通常はガウス分布または一様分布)
最適識別器:D*(x) = p_data(x) / (p_data(x) + p_g(x))
ナッシュ均衡:p_g = p_data, D*(x) = 0.5(完全に区別不可能)
GAN訓練は二段階の交互プロセスである:
- 識別器の訓練:Gを固定し、Dを訓練して実サンプルとGが生成した偽サンプルを区別する
- 生成器の訓練:Dを固定し、Gを訓練してDが本物と誤分類するサンプルを生成する
II. 訓練の悪夢:GANの三大課題
GANの理論は優雅だが、訓練は極めて困難である。三つの中核的な問題が研究コミュニティを何年も悩ませた:
モード崩壊
生成器が「近道」をとる——データ分布全体をカバーするのではなく、識別器を騙せる特定の少数サンプルのみを生成することを学習する。例えば、MNISTで訓練されたGANが数字の1と7のみを生成するようになることがある。
訓練の不安定性
オリジナルGANの損失関数はJensen-Shannonダイバージェンスに基づいている。実分布と生成分布に重なりがない場合(高次元空間ではほぼ確実にそうなる)、JSダイバージェンスはlog(2)で一定となり、勾配がゼロとなって生成器は学習できなくなる。
評価の困難さ
GANにはVAEのようなELBO損失がなく追跡が困難である。訓練損失は生成品質を反映しない——識別器の損失が0.5に近づくことは均衡を意味する可能性もあれば、崩壊を意味する可能性もある。
III. DCGAN:画像GANの設計ガイドラインの確立
Radfordら[2]はDCGAN(Deep Convolutional GAN)を提案し、画像GANのための古典的な設計ガイドラインを確立した:
| ガイドライン | 生成器 | 識別器 |
|---|---|---|
| アップ/ダウンサンプリング | 転置畳み込み(stride 2) | ストライド畳み込み(stride 2) |
| 正規化 | BatchNorm(出力層を除く) | BatchNorm(入力層を除く) |
| 活性化関数 | ReLU(出力はTanh) | LeakyReLU(0.2) |
| 全結合層 | 使用しない | 使用しない |
これら一見シンプルなガイドラインは深い意義を持っていた——GANの訓練を「偶然の成功」から「再現可能」なものへと変えたのである。DCGANはまた、生成器の潜在空間が意味のある構造を持つことを実証した:z(眼鏡をかけた男性) - z(男性) + z(女性) ≈ z(眼鏡をかけた女性)。
IV. WGAN:距離メトリクスの再定義
WGAN[3](Wasserstein GAN)はGAN訓練安定性の分水嶺であった。Arjovskyらは訓練目標をJSダイバージェンスからWasserstein-1距離(Earth Mover's Distance)に変更した:
W(p_data, p_g) = inf_{γ∈Π(p_data, p_g)} E_{(x,y)~γ}[||x - y||]
直感的理解:分布p_gをp_dataに「輸送」するために必要な最小の「仕事量」
WGAN損失(Kantorovich-Rubinsteinの双対性):
L_critic = E_{x~p_data}[f_w(x)] - E_{z~p_z}[f_w(G(z))]
L_G = -E_{z~p_z}[f_w(G(z))]
f_wは1-リプシッツ連続の「critic」(識別器を置き換える)
Wasserstein距離の主要な利点:二つの分布に重なりがなくても、意味のある勾配を提供する。WGANの損失値は生成品質と正の相関を持つ——損失が低いほど生成品質が高い。これはオリジナルGANでは実現できなかったことである。
WGAN-GP[4]はさらに、オリジナルWGANのウェイトクリッピングを勾配ペナルティに置き換えてリプシッツ制約を施行し、より安定した訓練を実現した。
V. StyleGANシリーズ:画像生成の頂点
NVIDIAのKarrasらによる一連の研究[5][6][7][18]は、GANの画像生成品質を驚異的な高みへと押し上げた:
| モデル | 年 | 主要イノベーション | インパクト |
|---|---|---|---|
| ProGAN[5] | 2018 | プログレッシブ訓練(4x4 → 1024x1024) | 初の1024x1024フォトリアリスティック顔生成 |
| StyleGAN[6] | 2019 | マッピングネットワーク + AdaINスタイル注入 | きめ細かな階層的スタイル制御 |
| StyleGAN2[7] | 2020 | ウェイトデモジュレーション + パス長正則化 | 水滴アーティファクトの除去、より滑らかな潜在空間 |
| StyleGAN3[18] | 2021 | エイリアシング修正、連続信号処理 | サブピクセル等変性、「テクスチャ固着」の解消 |
StyleGANの中核的なアーキテクチャイノベーションはコンテンツとスタイルの分離である:
StyleGAN生成器:
z ∈ Z (512次元) → マッピングネットワーク (8層MLP) → w ∈ W (512次元)
↓
定数4×4 → [AdaIN(w)] → Conv → [AdaIN(w)] → アップサンプル → ...
粗いスタイル 中間スタイル 細かいスタイル
(ポーズ/顔の形状) (特徴/髪) (肌/テクスチャ)
W空間はZ空間よりも「もつれ解消」されている → 線形補間で意味のある遷移が生成される
VI. 画像変換:Pix2PixとCycleGAN
Pix2Pix:ペアあり画像変換
Isolaら[9]はPix2Pixを提案し、条件付きGAN[8]を画像変換に適用した——エッジマップから写真へ、セマンティックラベルからストリートシーンへ、グレースケールからカラーへ。U-Net生成器とPatchGAN識別器(画像全体ではなく70x70パッチ単位で真偽を判定)を使用する。
CycleGAN:ペアなし画像変換
CycleGAN[10]はさらに困難な問題を解決した:ペアデータなしでの画像変換。その中核はサイクル一貫性損失である——馬をシマウマに変換し、そのシマウマを元に戻すと、元の馬が得られるはずである:
サイクル一貫性:
x → G(x) → F(G(x)) ≈ x (順方向サイクル)
y → F(y) → G(F(y)) ≈ y (逆方向サイクル)
L_cycle = ||F(G(x)) - x||₁ + ||G(F(y)) - y||₁
VII. テキスト向けGAN:離散シーケンスの課題
GANをテキスト生成に適用するには根本的な障壁がある:離散トークンのサンプリングは微分不可能である。argmaxやカテゴリカルサンプリングを通じて逆伝播を行うことはできない。
SeqGAN[11]は強化学習アプローチを用いてこの問題をエレガントに解決した:
SeqGANフレームワーク:
生成器 = 方策ネットワーク
識別器 = 報酬関数
行動 = 次のトークンの選択
状態 = 現在生成されたトークンシーケンス
訓練プロセス:
1. Gが自己回帰的に完全なシーケンスを生成
2. Dが完全なシーケンスをスコアリング(本物/偽物)
3. モンテカルロロールアウトで中間状態のQ値を推定
4. REINFORCE方策勾配でGを更新
Zhangら[17]は代替アプローチを提案した——離散トークン空間ではなく潜在特徴空間で敵対的訓練を行い、カーネル化された差異メトリクスを使って分布をマッチングすることで、離散サンプリング問題を回避する。
VIII. 訓練安定化テクニック
WGANシリーズ以外にも、いくつかの重要な安定化テクニックが開発されている:
| テクニック | 中核アイデア | 効果 |
|---|---|---|
| スペクトル正規化[12] | 識別器の各層の重み行列を最大特異値で除算 | 追加計算コストをほぼ要さない軽量なリプシッツ制約 |
| プログレッシブ訓練[5] | 低解像度から開始し、段階的に層を追加 | より安定した高解像度画像生成 |
| 截断トリック[13] | 推論時にzの範囲を制限(截断) | 多様性を犠牲にして品質を向上 |
| 特徴マッチング[14] | 識別器の中間層の統計量をマッチング | モード崩壊を軽減 |
| 二重タイムスケール[15] | DとGに異なる学習率を使用 | 理論的に局所ナッシュ均衡に収束 |
IX. 評価メトリクス:生成品質の測定方法
| メトリクス | 計算方法 | 測定対象 | 限界 |
|---|---|---|---|
| IS[14] | Inceptionモデル予測のKLダイバージェンス | 鮮明さ + 多様性 | 実分布との比較なし |
| FID[15] | 実画像と生成画像のInception特徴のフレシェ距離 | 品質 + 多様性 + 実分布への近さ | 多数のサンプルが必要、Inceptionに依存 |
| LPIPS | 知覚的特徴空間での距離 | 知覚的類似性 | ペア比較であり、分布レベルではない |
FID[15]は現在最も広く使用されているGAN評価メトリクスである。Inception-v3の特徴空間における実画像と生成画像の平均と共分散を比較し、値が低いほど生成品質が高いことを示す。典型的なFIDスコア:StyleGAN2のFFHQでの値は約2.8、BigGAN[13]のImageNet 128x128での値は約6.9。
X. GAN vs 拡散モデル:パラダイムシフト
2021年、DhariwalとNichol[16]は画期的な論文「拡散モデルがGANに勝利」を発表し、拡散モデルがFIDで最良のGANを包括的に上回ることを示した。しかし、これはGANが時代遅れになったことを意味するわけではない:
| 観点 | GAN | 拡散モデル |
|---|---|---|
| 生成速度 | 単一フォワードパス(約20ms) | 数十〜数百の反復ステップ(約2-10秒) |
| 訓練安定性 | 敵対的訓練は不安定 | シンプルなデノイジング目標、安定 |
| モードカバレッジ | モード崩壊を起こしやすい | 分布カバレッジが優れる |
| 画像品質 | 極めて高い(StyleGANシリーズ) | 極めて高い(Imagen、DALL-E 3) |
| 制御性 | StyleGANのW空間 | テキスト条件付きガイダンス |
| ユースケース | リアルタイムレンダリング、動画、超解像 | テキストから画像、編集 |
リアルタイム生成が求められるシナリオ(ゲーム、動画、インタラクティブアプリケーション)では、GANのシングルパス推論速度は依然として他の追随を許さない。
XI. ハンズオンラボ1:DCGAN手書き数字生成(Google Colab)
以下の実験では、DCGANをスクラッチから実装し、MNISTで手書き数字を生成する訓練を行い、FIDスコアを計算して訓練品質を追跡する。
# ============================================================
# Lab 1: DCGAN — MNIST手書き数字生成 + FID評価
# 環境: Google Colab (GPU)
# ============================================================
# --- 0. インストール ---
!pip install -q torchvision scipy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from scipy import linalg
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# --- 1. データ読み込み ---
transform = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # [-1, 1]
])
dataset = torchvision.datasets.MNIST(root='./data', train=True,
transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
shuffle=True, num_workers=2)
print(f"Training images: {len(dataset)}")
# --- 2. 生成器(DCGANアーキテクチャ) ---
nz = 100 # 潜在ベクトル次元
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
# 入力: z [B, 100, 1, 1]
nn.ConvTranspose2d(nz, 256, 4, 1, 0, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
# [B, 256, 4, 4]
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
# [B, 128, 8, 8]
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
# [B, 64, 16, 16]
nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
nn.Tanh()
# [B, 1, 32, 32]
)
def forward(self, z):
return self.main(z)
# --- 3. 識別器(DCGANアーキテクチャ) ---
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
# [B, 1, 32, 32]
nn.Conv2d(1, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# [B, 64, 16, 16]
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
# [B, 128, 8, 8]
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
# [B, 256, 4, 4]
nn.Conv2d(256, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x).view(-1)
# --- 4. 初期化 ---
def weights_init(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.normal_(m.weight, 0.0, 0.02)
elif isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
G = Generator().to(device).apply(weights_init)
D = Discriminator().to(device).apply(weights_init)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")
criterion = nn.BCELoss()
optG = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# --- 5. 簡易FID計算 ---
def compute_fid(real_images, fake_images, n_features=64):
"""簡易FID:Inceptionの代わりに画像ピクセル特徴を使用"""
real_flat = real_images.view(real_images.size(0), -1).cpu().numpy()
fake_flat = fake_images.view(fake_images.size(0), -1).cpu().numpy()
mu_r, sigma_r = real_flat.mean(axis=0), np.cov(real_flat, rowvar=False)
mu_f, sigma_f = fake_flat.mean(axis=0), np.cov(fake_flat, rowvar=False)
diff = mu_r - mu_f
covmean = linalg.sqrtm(sigma_r @ sigma_f)
if np.iscomplexobj(covmean):
covmean = covmean.real
fid = diff @ diff + np.trace(sigma_r + sigma_f - 2 * covmean)
return fid
# --- 6. 訓練ループ ---
n_epochs = 25
G_losses, D_losses, fid_scores = [], [], []
print("\nTraining DCGAN...")
for epoch in range(n_epochs):
for i, (real, _) in enumerate(dataloader):
B = real.size(0)
real = real.to(device)
real_label = torch.ones(B, device=device)
fake_label = torch.zeros(B, device=device)
# --- 識別器の訓練 ---
D.zero_grad()
out_real = D(real)
loss_real = criterion(out_real, real_label)
noise = torch.randn(B, nz, 1, 1, device=device)
fake = G(noise)
out_fake = D(fake.detach())
loss_fake = criterion(out_fake, fake_label)
loss_D = loss_real + loss_fake
loss_D.backward()
optD.step()
# --- 生成器の訓練 ---
G.zero_grad()
out_fake2 = D(fake)
loss_G = criterion(out_fake2, real_label)
loss_G.backward()
optG.step()
G_losses.append(loss_G.item())
D_losses.append(loss_D.item())
# FID計算(5エポックごと)
if (epoch + 1) % 5 == 0:
G.eval()
with torch.no_grad():
fake_sample = G(torch.randn(500, nz, 1, 1, device=device))
real_sample = torch.stack([dataset[i][0] for i in range(500)]).to(device)
fid = compute_fid(real_sample, fake_sample)
fid_scores.append((epoch + 1, fid))
print(f"Epoch {epoch+1}/{n_epochs}: G_loss={loss_G.item():.4f}, "
f"D_loss={loss_D.item():.4f}, FID={fid:.1f}")
G.train()
# --- 7. 可視化 ---
G.eval()
with torch.no_grad():
generated = G(fixed_noise).cpu()
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# 生成画像
axes[0].imshow(make_grid(generated, nrow=8, normalize=True, padding=2).permute(1, 2, 0),
cmap='gray')
axes[0].set_title('Generated Digits (Epoch 25)', fontsize=13)
axes[0].axis('off')
# 損失曲線
axes[1].plot(G_losses, label='Generator', color='#0077b6')
axes[1].plot(D_losses, label='Discriminator', color='#b8922e')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Training Loss', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)
# FID曲線
if fid_scores:
epochs_f, fids = zip(*fid_scores)
axes[2].plot(epochs_f, fids, 'o-', color='#e63946', linewidth=2)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('FID (lower = better)')
axes[2].set_title('FID Score', fontsize=13)
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# --- 8. 潜在空間補間 ---
z1 = torch.randn(1, nz, 1, 1, device=device)
z2 = torch.randn(1, nz, 1, 1, device=device)
alphas = torch.linspace(0, 1, 10)
interpolated = []
with torch.no_grad():
for a in alphas:
z = z1 * (1 - a) + z2 * a
interpolated.append(G(z).cpu())
fig, axes = plt.subplots(1, 10, figsize=(20, 2))
for i, img in enumerate(interpolated):
axes[i].imshow(img[0, 0], cmap='gray')
axes[i].set_title(f'α={alphas[i]:.1f}', fontsize=9)
axes[i].axis('off')
plt.suptitle('Latent Space Interpolation', fontsize=14)
plt.tight_layout()
plt.show()
print("Lab 1 Complete!")
XII. ハンズオンラボ2:SeqGANテキストシーケンス生成(Google Colab)
以下の実験では、SeqGANの中核概念を実装する——方策勾配を使って生成器を訓練し離散トークンシーケンスを生成し、識別器がシーケンスの品質を評価する。
# ============================================================
# Lab 2: SeqGAN概念実装 — テキストシーケンス生成
# 環境: Google Colab (GPU or CPU)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# --- 1. データ準備:シンプルな英語文パターン ---
# ルールベースで生成した「実」文で識別器を訓練
templates = [
"the cat sat on the mat",
"the dog ran in the park",
"a bird flew over the tree",
"the fish swam in the lake",
"a boy read in the room",
"the girl sang on the stage",
"a man walked to the store",
"the sun set in the west",
"a car drove on the road",
"the wind blew through the door",
]
# データセット拡張
import random
random.seed(42)
subjects = ["the cat", "the dog", "a bird", "the fish", "a boy",
"the girl", "a man", "the sun", "a car", "the wind"]
verbs = ["sat", "ran", "flew", "swam", "read", "sang", "walked", "set", "drove", "blew"]
preps = ["on", "in", "over", "through", "to", "under", "near", "past", "by", "from"]
objects = ["the mat", "the park", "the tree", "the lake", "the room",
"the stage", "the store", "the west", "the road", "the door"]
real_sentences = []
for _ in range(2000):
s = f"{random.choice(subjects)} {random.choice(verbs)} {random.choice(preps)} {random.choice(objects)}"
real_sentences.append(s)
# 語彙構築
all_words = set()
for s in real_sentences:
all_words.update(s.split())
vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
for w in sorted(all_words):
vocab[w] = len(vocab)
idx2word = {v: k for k, v in vocab.items()}
vocab_size = len(vocab)
SEQ_LEN = 8 # <bos> + 6 words + <eos>
print(f"Vocab size: {vocab_size}, Seq len: {SEQ_LEN}")
def encode_sentence(s):
tokens = [1] + [vocab.get(w, 0) for w in s.split()][:SEQ_LEN-2] + [2]
tokens += [0] * (SEQ_LEN - len(tokens))
return tokens
real_data = torch.tensor([encode_sentence(s) for s in real_sentences]).to(device)
print(f"Real data shape: {real_data.shape}")
# --- 2. 生成器(リカレントニューラルネットワーク方策ネットワーク) ---
class SeqGenerator(nn.Module):
def __init__(self, vocab_size, emb_dim=64, hidden_dim=128):
super().__init__()
self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
self.fc = nn.Linear(hidden_dim, vocab_size)
self.hidden_dim = hidden_dim
def forward(self, x, hidden=None):
emb = self.emb(x)
out, hidden = self.lstm(emb, hidden)
logits = self.fc(out)
return logits, hidden
def generate(self, batch_size, max_len=SEQ_LEN):
"""自己回帰的シーケンス生成"""
x = torch.full((batch_size, 1), 1, dtype=torch.long, device=device) # <bos>
hidden = None
sequences = [x]
log_probs_list = []
for _ in range(max_len - 1):
logits, hidden = self.forward(x, hidden)
probs = F.softmax(logits[:, -1], dim=-1)
dist = torch.distributions.Categorical(probs)
x = dist.sample().unsqueeze(1)
log_probs_list.append(dist.log_prob(x.squeeze(1)))
sequences.append(x)
sequences = torch.cat(sequences, dim=1)
log_probs = torch.stack(log_probs_list, dim=1)
return sequences, log_probs
# --- 3. 識別器(畳み込みニューラルネットワークシーケンス分類器) ---
class SeqDiscriminator(nn.Module):
def __init__(self, vocab_size, emb_dim=64):
super().__init__()
self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
self.convs = nn.ModuleList([
nn.Conv1d(emb_dim, 32, k, padding=k//2) for k in [2, 3, 4]
])
self.fc = nn.Linear(32 * 3, 1)
self.drop = nn.Dropout(0.2)
def forward(self, x):
emb = self.emb(x).transpose(1, 2) # [B, emb, L]
conv_outs = [F.relu(conv(emb)).max(dim=2).values for conv in self.convs]
out = torch.cat(conv_outs, dim=1)
return torch.sigmoid(self.fc(self.drop(out))).squeeze(1)
# --- 4. 初期化 ---
G = SeqGenerator(vocab_size).to(device)
D = SeqDiscriminator(vocab_size).to(device)
optG = torch.optim.Adam(G.parameters(), lr=1e-3)
optD = torch.optim.Adam(D.parameters(), lr=1e-3)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")
# --- 5. 生成器の事前訓練(MLE) ---
print("\n--- 生成器の事前訓練 (MLE) ---")
for epoch in range(30):
idx = torch.randperm(len(real_data))[:256]
batch = real_data[idx]
logits, _ = G(batch[:, :-1])
loss = F.cross_entropy(logits.reshape(-1, vocab_size), batch[:, 1:].reshape(-1),
ignore_index=0)
optG.zero_grad()
loss.backward()
optG.step()
if (epoch + 1) % 10 == 0:
print(f" Epoch {epoch+1}: MLE loss = {loss.item():.4f}")
# --- 6. 識別器の事前訓練 ---
print("\n--- 識別器の事前訓練 ---")
for epoch in range(20):
idx = torch.randperm(len(real_data))[:128]
real_batch = real_data[idx]
with torch.no_grad():
fake_batch, _ = G.generate(128)
real_score = D(real_batch)
fake_score = D(fake_batch)
loss_D = -torch.log(real_score + 1e-8).mean() - torch.log(1 - fake_score + 1e-8).mean()
optD.zero_grad()
loss_D.backward()
optD.step()
if (epoch + 1) % 10 == 0:
print(f" Epoch {epoch+1}: D loss = {loss_D.item():.4f}, "
f"D(real)={real_score.mean():.3f}, D(fake)={fake_score.mean():.3f}")
# --- 7. 敵対的訓練 (SeqGAN方策勾配) ---
print("\n--- 敵対的訓練 (SeqGAN) ---")
g_rewards_history, d_scores_history = [], []
for epoch in range(50):
# --- Gの訓練(方策勾配) ---
fake_seqs, log_probs = G.generate(64)
with torch.no_grad():
rewards = D(fake_seqs) # [B]
# REINFORCE:報酬ベースライン
baseline = rewards.mean()
pg_loss = -((rewards - baseline).unsqueeze(1) * log_probs).mean()
optG.zero_grad()
pg_loss.backward()
nn.utils.clip_grad_norm_(G.parameters(), 5.0)
optG.step()
# --- Dの訓練 ---
idx = torch.randperm(len(real_data))[:64]
real_batch = real_data[idx]
with torch.no_grad():
fake_batch, _ = G.generate(64)
real_score = D(real_batch)
fake_score = D(fake_batch)
loss_D = -torch.log(real_score + 1e-8).mean() - torch.log(1 - fake_score + 1e-8).mean()
optD.zero_grad()
loss_D.backward()
optD.step()
g_rewards_history.append(rewards.mean().item())
d_scores_history.append(fake_score.mean().item())
if (epoch + 1) % 10 == 0:
print(f" Epoch {epoch+1}: G_reward={rewards.mean():.3f}, "
f"D(fake)={fake_score.mean():.3f}, D(real)={real_score.mean():.3f}")
# --- 8. 生成と評価 ---
G.eval()
print("\n=== 生成文 ===")
with torch.no_grad():
seqs, _ = G.generate(20)
for seq in seqs:
words = [idx2word.get(t.item(), "?") for t in seq if t.item() > 2]
sentence = " ".join(words)
score = D(seq.unsqueeze(0)).item()
print(f" [{score:.3f}] {sentence}")
# --- 9. 訓練プロセスの可視化 ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].plot(g_rewards_history, color='#0077b6', linewidth=1.5)
axes[0].set_xlabel('Adversarial Epoch')
axes[0].set_ylabel('Average Reward')
axes[0].set_title('Generator Reward (D score on fake)', fontsize=13)
axes[0].grid(True, alpha=0.3)
axes[1].plot(d_scores_history, color='#b8922e', linewidth=1.5, label='D(fake)')
axes[1].axhline(y=0.5, color='#e63946', linestyle='--', alpha=0.7, label='Equilibrium')
axes[1].set_xlabel('Adversarial Epoch')
axes[1].set_ylabel('D(fake)')
axes[1].set_title('Discriminator Score on Fake Sequences', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# --- 10. 実文と生成文の比較 ---
print("\n=== 実文 vs 生成文 比較 ===")
print("実文:")
for s in real_sentences[:5]:
print(f" {s}")
print("\n生成文:")
with torch.no_grad():
seqs, _ = G.generate(5)
for seq in seqs:
words = [idx2word.get(t.item(), "?") for t in seq if t.item() > 2]
print(f" {' '.join(words)}")
print("\nLab 2 Complete!")
XIII. 結論と展望
GANの物語は、深層学習における最も魅力的な章の一つである。Goodfellowの2014年の原論文[1]からStyleGAN3[18]のフォトリアリスティックな忠実度まで、GANは10年足らずで「機械は創造できるのか」という我々の理解を根本的に変えた。
中核的な軌跡を振り返る:
- 理論的基盤:ミニマックスゲームフレームワーク[1]がエレガントな生成モデリングパラダイムを提供した
- 安定化のブレークスルー:WGAN[3]とスペクトル正規化[12]が訓練を「アート」から「エンジニアリング」へ変えた
- 品質の頂点:StyleGANシリーズ[6][7]が人間の目では区別できない生成品質を達成した
- 応用の爆発:Pix2Pix[9]とCycleGAN[10]が画像変換の無限の可能性を切り開いた
- パラダイム競争:拡散モデル[16]がGANを品質で上回ったが、GANは速度面での優位性を保持している
今後を見据えると、GANは舞台を去っていない。GigaGANのような研究はGANを数十億パラメータ規模に拡張し、蒸留技術により拡散モデルがシングルステップ生成を学習できるようになり(本質的にGANになることを学ぶ)、そしてGANの識別器の概念はRLHFなどのアラインメント技術に広く取り入れられている。二つのシステムが競争を通じて共に向上する——この敵対的なアイデアは、AI開発の方向性に影響を与え続けるだろう。



