主要な発見
  • 壊滅的忘却[1]は深層学習の根本的な欠陥です——モデルが新しいタスクを学習すると、旧タスクの知識を劇的に忘却し、人間のように継続的に学習することができません
  • 継続学習の3大戦略カテゴリ:正則化手法(EWC[3]、SI[4])は重要なパラメータを保護、アーキテクチャ手法(Progressive Networks[6])はネットワークを動的に拡張、リプレイ手法(ER[9]、GR[10])は旧タスクのサンプルを再訪
  • 大規模言語モデルも壊滅的忘却に直面します[15]——BERT/GPTの継続ファインチューニングは以前のタスク能力を急速に消失させ、経験リプレイが現在最も効果的な緩和戦略です
  • 本記事にはGoogle Colabハンズオンラボ2つが含まれます:EWCによるMNIST画像分類の忘却防止と、BERTの継続学習によるマルチタスクテキスト分類。いずれもブラウザで直接実行可能です

1. なぜAIは忘れるのか:壊滅的忘却の本質

人間は古い知識を忘れることなく新しいスキルを継続的に学習できます——自転車に乗れるようになった後、水泳を学んでも自転車の乗り方を忘れることはありません。しかし、深層ニューラルネットワークは根本的な問題に直面します:壊滅的忘却(Catastrophic Forgetting)[1][2]

タスクAで訓練済みのネットワークをその後タスクBで訓練すると、タスクAの性能は急激に低下します——緩やかな劣化ではなく、壊滅的な崩壊です。これは、勾配降下法が現在のタスクを最適化するためにすべてのパラメータを無差別に更新し、旧タスクにとって重要だった重みを上書きしてしまうために起こります。

壊滅的忘却の本質:

タスクAの訓練完了:θ* = argmin_θ L_A(θ)
                     → パラメータがAの最適解に収束

タスクBの逐次訓練:θ** = argmin_θ L_B(θ)、θ*から開始
                     → パラメータがAの最適点から離れ、Aの性能が崩壊

根本原因:安定性-可塑性ジレンマ
  - 安定しすぎ → 新タスクを学習できない(未学習)
  - 可塑的すぎ → 旧タスクを忘却する(壊滅的忘却)
  - 目標:両者のバランスを見つけること

具体例(画像分類):
  ステップ1:数字0-4で訓練 → 精度98%
  ステップ2:数字5-9で訓練 → 5-9の精度97%、しかし0-4は約20%に低下
  原因:5-9の勾配が0-4を区別するための重要な重みを破壊

壊滅的忘却は画像分類に限定されません。言語モデルの継続ファインチューニングでも同様に深刻です[15]:感情分類でファインチューニングしたBERTを、その後固有表現認識でファインチューニングすると、感情分類能力が著しく劣化します。大規模言語モデルの時代においてこの問題はさらに重要になっています——毎回ゼロから再訓練することなく、モデルが新しいデータから継続的に学習できることが求められています。

2. 継続学習の全体像:3大戦略カテゴリ

継続学習(ライフロング学習とも呼ばれる)の研究目標は、モデルが複数のタスクを順次学習しながら、新タスクの習得と旧タスクの性能維持の両方を実現することです[12][13]。異なる解決アプローチに基づき、手法は3つの主要カテゴリに分類できます[17]

戦略カテゴリ核心的アイデア代表的手法利点制限事項
正則化手法損失関数にペナルティ項を追加して重要なパラメータの変更を制約EWC[3]、SI[4]、LwF[5]旧データの保存不要、固定メモリフットプリントタスク数が増えると保護能力が劣化
アーキテクチャ手法異なるタスクに対して異なるネットワーク構造やサブネットワークを割り当てProgressive Nets[6]、PackNet[7]、HAT[8]ゼロ忘却(ハード分離)モデルサイズがタスク数とともに増加
リプレイ手法旧タスクのサンプルを保存または生成し、新タスク学習時に共同訓練ER[9]、GR[10]、GEM[11]、DER++[14]シンプル、効果的、他手法と組み合わせ可能旧サンプル保存用の追加メモリが必要

継続学習には3つの評価シナリオレベルもあります:

継続学習の3つのシナリオ:

1. タスク漸進学習(Task-IL):
   推論時にタスクIDが既知 → 最も簡単
   例:「これはタスクBのデータです、Bの分類ヘッドを使ってください」

2. ドメイン漸進学習(Domain-IL):
   同じタスク構造だがデータ分布が変化 → 中程度の難易度
   例:同じ10クラス分類だが画像スタイルがスケッチから写真に変化

3. クラス漸進学習(Class-IL):
   推論時にタスクIDが不明、学習済み全クラスを区別する必要あり → 最も難しい
   例:最初に0-4を学習、次に5-9を学習、テスト時に全数字0-9を区別

難易度順:Task-IL < Domain-IL < Class-IL
実用アプリケーションではClass-ILが現実のニーズに最も近い

3. 正則化手法:EWCと知識蒸留

Elastic Weight Consolidation(EWC)

EWC[3]は継続学習で最も影響力のある正則化手法で、神経科学のシナプス固化にインスパイアされています——重要なシナプス結合は保護されるべきで、重要でないものは自由に更新できるべきです。

核心的な問題は:各パラメータの旧タスクに対する「重要度」をどう測定するか?EWCの答えはフィッシャー情報行列です:

EWC損失関数:

L_total(θ) = L_B(θ) + (λ/2) Σ_i F_i (θ_i - θ*_A,i)²

ここで:
  L_B(θ):     新タスクBの損失
  θ*_A:       旧タスクA訓練後の最適パラメータ
  F_i:        フィッシャー情報行列の対角要素(パラメータiのタスクAに対する重要度)
  λ:          正則化強度(安定性-可塑性バランスを制御)

フィッシャー情報行列(対角近似):
  F_i = E_{x~D_A} [(∂ log p(y|x,θ) / ∂θ_i)²]

直感的理解:
  F_iが大 → パラメータiはタスクAにとって非常に重要 → 変更を強く制約
  F_iが小 → パラメータiはタスクAにとって重要でない → タスクB学習のため自由に更新

幾何学的視点:
  タスクAの最適点θ*_A周辺に「低損失の谷」が存在
  フィッシャー行列はこの谷の形状を記述
  EWCはタスクBの最適化を谷の延長方向に沿って導く
  → AとBの両方で良く機能するパラメータを見つける

Synaptic Intelligence(SI)

SI[4]はEWCのオンライン代替手法です。EWCが各タスク後にフィッシャー行列を計算する必要があるのに対し、SIは訓練中にリアルタイムで各パラメータの重要度を蓄積します——訓練中に「辿った経路」に沿った各パラメータの損失削減への貢献を追跡します。

Learning without Forgetting(LwF)

LwF[5]は異なるアプローチを取ります——パラメータではなく出力を保護します。新タスクを学習する前に、まず新タスクのデータを旧モデルに通して「ソフトラベル」を取得し、新タスクの訓練中に知識蒸留損失を同時に使用して旧タスクの出力分布を変化させないようにします。最大の利点は旧タスクのデータを一切保存する必要がないことです。

4. アーキテクチャ手法:Progressive Networksと動的拡張

アーキテクチャ手法の哲学は:限られたパラメータ空間内で新旧タスクのバランスに苦労するよりも、各タスクに専用のネットワーク容量を割り当てるというものです。

Progressive Neural Networks

Rusuらが提案したProgressive Networks[6]は最も直接的なアプローチです——新タスクごとに、既存のものと並んで新しいネットワーク列が「成長」し、ラテラル接続により新タスクが旧タスクで学習した特徴を再利用できます:

Progressive Neural Networks:

タスク1:[列1] ← 通常の訓練
タスク2:[列1](凍結)←─ ラテラル接続 ──→ [列2] ← これだけを訓練
タスク3:[列1](凍結)←─┐                [列2](凍結)←─┐
                              └─ ラテラル接続 ──→                      └─→ [列3]

利点:絶対的ゼロ忘却(旧列は凍結)
欠点:パラメータ数が線形に増加(Tタスク = T×パラメータ)

PackNetとHAT

PackNet[7]とHAT[8]は固定サイズのネットワーク内でマルチタスク学習を達成しようとします:

手法戦略メカニズム特徴
PackNet[7]反復的プルーニング訓練 → 重要でない重みをプルーニング → 次タスク用に容量を解放各タスクが専用のスパースサブネットワークを取得
HAT[8]ハードアテンションマスク各タスク用にバイナリマスクを学習し、使用済みニューロンを保護マスクは勾配最適化可能で、容量を自動割り当て

5. リプレイ手法:メモリバッファと生成リプレイ

経験リプレイ手法は認知科学の記憶固化からインスピレーションを得ています——人間は睡眠中に日中の経験を「リプレイ」して記憶を固化させます。継続学習では、リプレイ手法は新タスクの学習中に旧タスクのサンプルを混ぜ込みます[9]

Experience Replay(ER)

最も直接的なアプローチ:固定サイズのメモリバッファを維持し、各旧タスクから少数の代表的サンプルを格納します。新タスクの学習時に、各ミニバッチは新タスクデータとバッファからサンプリングしたデータを混合します:

Experience Replayのフロー:

メモリバッファ M(固定サイズ、例:200サンプル)

タスクtの学習:
  各ミニバッチに対して:
    batch_new = sample(D_t)           # 新タスクデータ
    batch_old = sample(M)             # バッファから旧データをサンプル
    loss = L(batch_new) + L(batch_old)  # 結合損失
    θを更新

  タスク完了後:
    D_tから代表的サンプルをMに追加(リザーバサンプリングまたはハーディングを使用)

リザーバサンプリング:
  確率 |M| / n でn番目のサンプルをバッファに追加、
  以前に見た各サンプルが等しい確率で選ばれることを保証

重要な発見(Rolnick et al., 2019):
  クラスあたりわずか1〜5サンプルで忘却を大幅に削減
  → 最小限のメモリコストで大きな忘却防止効果

Generative Replay(GR)

Shinら[10]はエレガントな代替案を提案しました:実際の旧データを保存する代わりに、生成モデル(GANやVAEなど)を訓練して旧タスクの仮想サンプルを生成します。これはプライバシーに敏感なシナリオで特に価値があります——医療データは保存できませんが、その分布は生成モデルで再構成できます。

GEMとDER++

GEM[11](Gradient Episodic Memory)はメモリ内のサンプルを使用して勾配制約を計算します:新タスクの勾配更新はメモリサンプル上の旧タスク損失を増加させてはなりません。DER++[14]は経験リプレイと知識蒸留を組み合わせます——旧データのラベルだけでなく、旧モデルのソフト出力(ロジット)もリプレイし、「ダークナレッジ」の形でより豊かな情報を保存します。

6. テキストAIの継続学習:言語モデルの継続ファインチューニング

大規模言語モデルの継続学習は現在の研究のフロンティアにあります[16]。企業がBERTやGPTを新しいタスクやドメインに継続的に適応させたい場合、壊滅的忘却が既存の能力に深刻な影響を与えます:

言語モデルの継続学習シナリオ:

1. 継続タスクファインチューニング:
   BERT → 感情分析 → 固有表現認識 → QA → テキスト要約
   問題:後のファインチューニングが以前のタスク能力を破壊

2. 継続ドメイン適応:
   汎用LLM → 金融ドメイン → 法律ドメイン → 医療ドメイン
   問題:新ドメインの知識が旧ドメインの専門性を上書き

3. 継続事前学習:
   基盤モデル → 新しいドキュメント/知識を継続的に吸収
   問題:新しい知識が基礎的な言語理解を損なう可能性

言語モデル忘却の固有の課題:
  - 極めて高いパラメータ共有(全タスクが同じTransformerを共有)
  - より深刻な表現空間の干渉(セマンティックオーバーラップが大きい)
  - タスクヘッドは分離できるが、基盤表現の分離は困難

Scialomらの研究[15]は、経験リプレイが現在、言語モデルの継続学習で最も効果的な手法であることを実証しています——新タスク学習時に少量の旧タスクサンプルを混入させることで、忘却が大幅に削減されます。これはNLPシナリオではEWCのような正則化手法よりも効果的です。言語タスクのパラメータ重要度分布がより均一であり、正則化制約の識別力が制限されるためです。

7. ハンズオンラボ1:EWCによるMNIST画像分類の忘却防止(Google Colab)

以下の実験では、Split MNISTで3つの戦略を比較します:(1) ナイーブファインチューニング、(2) EWC正則化、(3) 経験リプレイ(ER)。壊滅的忘却現象とその緩和を視覚的に実証します。

# ============================================================
# Lab 1: Continual Learning — EWC vs Experience Replay vs Naive Fine-tuning (Split MNIST)
# Environment: Google Colab (CPU is sufficient)
# ============================================================
# --- 0. Installation ---
!pip install -q torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import copy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Data Preparation: Split MNIST ---
# Task A: Digits 0-4, Task B: Digits 5-9
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_data = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.MNIST('./data', train=False, transform=transform)

def filter_by_labels(dataset, labels):
    """Filter data by specific labels"""
    mask = torch.zeros(len(dataset.targets), dtype=torch.bool)
    for l in labels:
        mask |= (dataset.targets == l)
    indices = mask.nonzero(as_tuple=True)[0]
    return torch.utils.data.Subset(dataset, indices)

task_a_labels = [0, 1, 2, 3, 4]
task_b_labels = [5, 6, 7, 8, 9]

train_a = filter_by_labels(train_data, task_a_labels)
train_b = filter_by_labels(train_data, task_b_labels)
test_a = filter_by_labels(test_data, task_a_labels)
test_b = filter_by_labels(test_data, task_b_labels)

loader_a = torch.utils.data.DataLoader(train_a, batch_size=128, shuffle=True)
loader_b = torch.utils.data.DataLoader(train_b, batch_size=128, shuffle=True)
test_loader_a = torch.utils.data.DataLoader(test_a, batch_size=256)
test_loader_b = torch.utils.data.DataLoader(test_b, batch_size=256)

print(f"Task A (digits 0-4): {len(train_a)} train, {len(test_a)} test")
print(f"Task B (digits 5-9): {len(train_b)} train, {len(test_b)} test")

# --- 2. Simple CNN Model ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 10)  # All 10 classes share the output

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total

# --- 3. EWC Implementation ---
class EWC:
    def __init__(self, model, dataloader, device, n_samples=200):
        self.params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
        self.fisher = self._compute_fisher(model, dataloader, device, n_samples)

    def _compute_fisher(self, model, dataloader, device, n_samples):
        """Compute Fisher Information Matrix (diagonal approximation)"""
        fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
        model.eval()
        count = 0
        for x, y in dataloader:
            if count >= n_samples:
                break
            x, y = x.to(device), y.to(device)
            model.zero_grad()
            output = model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            for n, p in model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    fisher[n] += p.grad.data.pow(2) * x.size(0)
            count += x.size(0)
        fisher = {n: f / count for n, f in fisher.items()}
        return fisher

    def penalty(self, model):
        """EWC regularization term"""
        loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        return loss

# --- 4. Experience Replay Memory Buffer ---
class ReplayBuffer:
    def __init__(self, capacity=200):
        self.capacity = capacity
        self.buffer_x = []
        self.buffer_y = []

    def add_from_loader(self, loader, n_samples):
        """Sample randomly from loader and add to buffer"""
        all_x, all_y = [], []
        for x, y in loader:
            all_x.append(x)
            all_y.append(y)
        all_x = torch.cat(all_x)
        all_y = torch.cat(all_y)
        indices = torch.randperm(len(all_x))[:n_samples]
        self.buffer_x.append(all_x[indices])
        self.buffer_y.append(all_y[indices])

    def sample(self, batch_size):
        all_x = torch.cat(self.buffer_x)
        all_y = torch.cat(self.buffer_y)
        indices = torch.randperm(len(all_x))[:batch_size]
        return all_x[indices], all_y[indices]

# --- 5. Training Function ---
def train_task(model, loader, optimizer, epochs, ewc=None, ewc_lambda=0,
               replay_buffer=None, replay_batch=32):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            output = model(x)
            loss = F.cross_entropy(output, y)

            # EWC regularization
            if ewc is not None:
                loss += ewc_lambda * ewc.penalty(model)

            # Experience replay
            if replay_buffer is not None and len(replay_buffer.buffer_x) > 0:
                rx, ry = replay_buffer.sample(replay_batch)
                rx, ry = rx.to(device), ry.to(device)
                r_output = model(rx)
                loss += F.cross_entropy(r_output, ry)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

# --- 6. Experiment: Comparing Three Strategies ---
n_epochs = 5
results = {}

# Strategy 1: Naive Fine-tuning
print("\n=== Strategy 1: Naive Fine-tuning ===")
model_naive = SimpleCNN().to(device)
opt = torch.optim.Adam(model_naive.parameters(), lr=1e-3)

train_task(model_naive, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_naive, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

train_task(model_naive, loader_b, opt, n_epochs)
acc_a_after_b = evaluate(model_naive, test_loader_a)
acc_b_after_b = evaluate(model_naive, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Naive'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# Strategy 2: EWC
print("\n=== Strategy 2: EWC (λ=400) ===")
model_ewc = SimpleCNN().to(device)
opt = torch.optim.Adam(model_ewc.parameters(), lr=1e-3)

train_task(model_ewc, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_ewc, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

# Compute Fisher matrix
ewc = EWC(model_ewc, loader_a, device)

train_task(model_ewc, loader_b, opt, n_epochs, ewc=ewc, ewc_lambda=400)
acc_a_after_b = evaluate(model_ewc, test_loader_a)
acc_b_after_b = evaluate(model_ewc, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['EWC'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# Strategy 3: Experience Replay
print("\n=== Strategy 3: Experience Replay (buffer=200) ===")
model_er = SimpleCNN().to(device)
opt = torch.optim.Adam(model_er.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=200)

train_task(model_er, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_er, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

# Add Task A samples to buffer
buffer.add_from_loader(loader_a, n_samples=200)

train_task(model_er, loader_b, opt, n_epochs, replay_buffer=buffer)
acc_a_after_b = evaluate(model_er, test_loader_a)
acc_b_after_b = evaluate(model_er, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Replay'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# --- 7. Visualization ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left plot: Task A accuracy (forgetting extent)
strategies = list(results.keys())
acc_before = [results[s][0] for s in strategies]
acc_after = [results[s][1] for s in strategies]

x_pos = np.arange(len(strategies))
width = 0.35
bars1 = axes[0].bar(x_pos - width/2, acc_before, width, label='After Task A', color='#0077b6')
bars2 = axes[0].bar(x_pos + width/2, acc_after, width, label='After Task B', color='#e63946')

axes[0].set_ylabel('Task A Accuracy')
axes[0].set_title('Catastrophic Forgetting: Task A Performance', fontsize=13)
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(strategies)
axes[0].legend()
axes[0].set_ylim(0, 1.05)
axes[0].grid(True, alpha=0.3, axis='y')

for bar in bars1:
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)
for bar in bars2:
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)

# Right plot: Forgetting amount
forgetting = [results[s][0] - results[s][1] for s in strategies]
colors = ['#e63946' if f > 0.3 else '#b8922e' if f > 0.1 else '#2a9d8f' for f in forgetting]
bars = axes[1].bar(strategies, forgetting, color=colors, edgecolor='white', linewidth=1.5)

axes[1].set_ylabel('Forgetting (↓ better)')
axes[1].set_title('Amount of Forgetting on Task A', fontsize=13)
axes[1].grid(True, alpha=0.3, axis='y')

for bar, f in zip(bars, forgetting):
    axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{f:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

# --- 8. Per-Class Accuracy Analysis ---
print("\n=== Per-class Accuracy After Task B ===")
print(f"{'Class':<8} {'Naive':>8} {'EWC':>8} {'Replay':>8}")
print("-" * 36)

models = {'Naive': model_naive, 'EWC': model_ewc, 'Replay': model_er}
for digit in range(10):
    test_digit = filter_by_labels(test_data, [digit])
    loader_digit = torch.utils.data.DataLoader(test_digit, batch_size=256)
    accs = []
    for name in ['Naive', 'EWC', 'Replay']:
        acc = evaluate(models[name], loader_digit)
        accs.append(acc)
    marker = " ← Task A" if digit < 5 else ""
    print(f"  {digit:<6} {accs[0]:>8.1%} {accs[1]:>8.1%} {accs[2]:>8.1%}{marker}")

print("\nLab 1 Complete!")

8. ハンズオンラボ2:BERTの継続学習によるマルチタスクテキスト分類(Google Colab)

以下の実験では、言語モデルにおける壊滅的忘却を実証します:BERTを2つのテキスト分類タスクで順次ファインチューニングし、忘却現象を観察し、経験リプレイで緩和します。

# ============================================================
# Lab 2: BERT Continual Learning — Catastrophic Forgetting and Mitigation in Multi-Task Text Classification
# Environment: Google Colab (GPU recommended, CPU works but slower)
# ============================================================
# --- 0. Installation ---
!pip install -q transformers datasets torch

import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Load Two Text Classification Tasks ---
print("\n--- Loading Datasets ---")

# Task A: SST-2 Sentiment Classification (Positive/Negative)
sst2 = load_dataset("glue", "sst2")
print(f"Task A (SST-2): {len(sst2['train'])} train, {len(sst2['validation'])} val")

# Task B: MRPC Semantic Equivalence (Equivalent/Not Equivalent)
mrpc = load_dataset("glue", "mrpc")
print(f"Task B (MRPC):  {len(mrpc['train'])} train, {len(mrpc['validation'])} val")

# --- 2. Tokenizer ---
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_sst2(examples):
    return tokenizer(examples['sentence'], truncation=True, padding='max_length', max_length=64)

def tokenize_mrpc(examples):
    return tokenizer(examples['sentence1'], examples['sentence2'],
                     truncation=True, padding='max_length', max_length=128)

sst2_tok = sst2.map(tokenize_sst2, batched=True)
mrpc_tok = mrpc.map(tokenize_mrpc, batched=True)

for ds in [sst2_tok, mrpc_tok]:
    ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# Use subsets (Colab-friendly)
sst2_train = sst2_tok["train"].shuffle(seed=42).select(range(1000))
sst2_val = sst2_tok["validation"]
mrpc_train = mrpc_tok["train"].shuffle(seed=42).select(range(800))
mrpc_val = mrpc_tok["validation"]

# --- 3. Training and Evaluation Utilities ---
def make_loader(dataset, batch_size=16, shuffle=True):
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def evaluate_task(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

def train_epoch(model, loader, optimizer, replay_data=None, replay_ratio=0.3):
    model.train()
    total_loss = 0
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Experience replay
        if replay_data is not None and len(replay_data) > 0:
            n_replay = max(1, int(input_ids.size(0) * replay_ratio))
            indices = random.sample(range(len(replay_data)), min(n_replay, len(replay_data)))

            r_ids = torch.stack([replay_data[i]['input_ids'] for i in indices]).to(device)
            r_mask = torch.stack([replay_data[i]['attention_mask'] for i in indices]).to(device)
            r_labels = torch.tensor([replay_data[i]['label'] for i in indices]).to(device)

            r_outputs = model(input_ids=r_ids, attention_mask=r_mask, labels=r_labels)
            loss = loss + r_outputs.loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)

# --- 4. Create Replay Memory ---
def create_replay_buffer(dataset, n_samples=100):
    """Sample from dataset to create replay buffer"""
    indices = random.sample(range(len(dataset)), min(n_samples, len(dataset)))
    buffer = []
    for i in indices:
        item = dataset[i]
        buffer.append({
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'label': item['label'].item() if isinstance(item['label'], torch.Tensor) else item['label']
        })
    return buffer

# --- 5. Experiment 1: Naive Sequential Fine-tuning (Demonstrating Forgetting) ---
print("\n" + "="*60)
print("Experiment 1: Naive Sequential Fine-tuning")
print("="*60)

model_naive = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_naive.parameters(), lr=2e-5, weight_decay=0.01)

# Train Task A (SST-2)
print("\n--- Training on Task A (SST-2) ---")
loader_a = make_loader(sst2_train)
val_loader_a = make_loader(sst2_val, shuffle=False)

naive_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}

for epoch in range(3):
    loss = train_epoch(model_naive, loader_a, opt)
    acc = evaluate_task(model_naive, val_loader_a)
    naive_history['task_a_on_a'].append(acc)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")

acc_a_before = naive_history['task_a_on_a'][-1]

# Train Task B (MRPC)
print("\n--- Training on Task B (MRPC) ---")
loader_b = make_loader(mrpc_train)
val_loader_b = make_loader(mrpc_val, shuffle=False)

for epoch in range(3):
    loss = train_epoch(model_naive, loader_b, opt)
    acc_a = evaluate_task(model_naive, val_loader_a)
    acc_b = evaluate_task(model_naive, val_loader_b)
    naive_history['task_a_on_b'].append(acc_a)
    naive_history['task_b_on_b'].append(acc_b)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")

acc_a_after_naive = naive_history['task_a_on_b'][-1]
acc_b_naive = naive_history['task_b_on_b'][-1]

# --- 6. Experiment 2: Experience Replay ---
print("\n" + "="*60)
print("Experiment 2: Experience Replay (buffer=100)")
print("="*60)

model_replay = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_replay.parameters(), lr=2e-5, weight_decay=0.01)

# Train Task A
print("\n--- Training on Task A (SST-2) ---")
replay_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}

for epoch in range(3):
    loss = train_epoch(model_replay, loader_a, opt)
    acc = evaluate_task(model_replay, val_loader_a)
    replay_history['task_a_on_a'].append(acc)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")

# Create replay buffer
replay_buffer = create_replay_buffer(sst2_train, n_samples=100)
print(f"\nReplay buffer: {len(replay_buffer)} samples from Task A")

# Train Task B + Replay
print("\n--- Training on Task B (MRPC) with Replay ---")
for epoch in range(3):
    loss = train_epoch(model_replay, loader_b, opt, replay_data=replay_buffer)
    acc_a = evaluate_task(model_replay, val_loader_a)
    acc_b = evaluate_task(model_replay, val_loader_b)
    replay_history['task_a_on_b'].append(acc_a)
    replay_history['task_b_on_b'].append(acc_b)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")

acc_a_after_replay = replay_history['task_a_on_b'][-1]
acc_b_replay = replay_history['task_b_on_b'][-1]

# --- 7. Visualization Comparison ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Left plot: Task A accuracy over training
epochs_a = list(range(1, 4))
epochs_b = list(range(4, 7))
all_epochs = epochs_a + epochs_b

naive_a_curve = naive_history['task_a_on_a'] + naive_history['task_a_on_b']
replay_a_curve = replay_history['task_a_on_a'] + replay_history['task_a_on_b']

axes[0].plot(all_epochs, naive_a_curve, 'o-', color='#e63946', linewidth=2, label='Naive')
axes[0].plot(all_epochs, replay_a_curve, 's-', color='#0077b6', linewidth=2, label='Replay')
axes[0].axvline(x=3.5, color='gray', linestyle='--', alpha=0.5)
axes[0].text(2, 0.55, 'Task A\n(SST-2)', ha='center', fontsize=10, color='gray')
axes[0].text(5, 0.55, 'Task B\n(MRPC)', ha='center', fontsize=10, color='gray')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('SST-2 Accuracy')
axes[0].set_title('Task A (SST-2) Accuracy Over Time', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim(0.5, 1.0)

# Middle plot: Final results comparison
categories = ['Task A\n(SST-2)', 'Task B\n(MRPC)']
naive_scores = [acc_a_after_naive, acc_b_naive]
replay_scores = [acc_a_after_replay, acc_b_replay]

x = np.arange(len(categories))
width = 0.35

axes[1].bar(x - width/2, naive_scores, width, label='Naive', color='#e63946', alpha=0.85)
axes[1].bar(x + width/2, replay_scores, width, label='Replay', color='#0077b6', alpha=0.85)
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Final Performance Comparison', fontsize=13)
axes[1].set_xticks(x)
axes[1].set_xticklabels(categories)
axes[1].legend()
axes[1].set_ylim(0.5, 1.0)
axes[1].grid(True, alpha=0.3, axis='y')

# Right plot: Forgetting amount
forgetting_naive = acc_a_before - acc_a_after_naive
forgetting_replay = replay_history['task_a_on_a'][-1] - acc_a_after_replay

bars = axes[2].bar(['Naive', 'Replay'], [forgetting_naive, forgetting_replay],
                    color=['#e63946', '#0077b6'], edgecolor='white', linewidth=1.5)
axes[2].set_ylabel('Forgetting (↓ better)')
axes[2].set_title('SST-2 Forgetting After MRPC Training', fontsize=13)
axes[2].grid(True, alpha=0.3, axis='y')

for bar, f in zip(bars, [forgetting_naive, forgetting_replay]):
    axes[2].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.005,
                 f'{f:.3f}', ha='center', va='bottom', fontsize=13, fontweight='bold')

plt.tight_layout()
plt.show()

# --- 8. Inference Demo ---
print("\n=== Inference Demo ===")
test_sentences = [
    ("This movie is absolutely wonderful!", "SST-2 (Positive)"),
    ("A terrible waste of time and money.", "SST-2 (Negative)"),
    ("The film was average, nothing special.", "SST-2 (Neutral-ish)"),
]

print("\n--- Naive Model ---")
model_naive.eval()
for text, label in test_sentences:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
    with torch.no_grad():
        logits = model_naive(**inputs).logits
    pred = "Positive" if logits.argmax().item() == 1 else "Negative"
    conf = torch.softmax(logits, dim=-1).max().item()
    print(f"  [{pred} {conf:.1%}] {text}  (expected: {label})")

print("\n--- Replay Model ---")
model_replay.eval()
for text, label in test_sentences:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
    with torch.no_grad():
        logits = model_replay(**inputs).logits
    pred = "Positive" if logits.argmax().item() == 1 else "Negative"
    conf = torch.softmax(logits, dim=-1).max().item()
    print(f"  [{pred} {conf:.1%}] {text}  (expected: {label})")

print("\nLab 2 Complete!")

9. 判断フレームワーク:企業はどの継続学習戦略を選ぶべきか

データの利用可能性、プライバシー制約、計算予算に基づき、企業は以下のフレームワークで適切な継続学習アプローチを選択できます:

条件推奨手法根拠
旧データの保存が可能、十分なメモリExperience Replay(ER / DER++)[9][14]最もシンプルで効果的;タスクあたり200サンプルで忘却を大幅に削減
プライバシー制約、旧データの保存不可EWC[3] + LwF[5]フィッシャー行列またはモデルスナップショットの保存のみ、生データ不要
多くのタスクが継続的に増加PackNet[7]またはHAT[8]固定モデル容量内でマルチタスクをサポート、追加ストレージ不要
少数のタスクだがゼロ忘却が必要Progressive Networks[6]完全な分離、ゼロ忘却を保証、ミッションクリティカルなシナリオに適する
言語モデルの継続ファインチューニングExperience Replay + 学習率スケジューリング[15]Transformerアーキテクチャに最も効果的;EWCはNLPでの効果が限定的
プライバシー制約 + 十分な計算資源Generative Replay(GR)[10]仮想の旧データを生成、プライバシーと忘却防止のバランスを実現
判断ツリー:

1. 旧タスクの実データを保存できるか?
   ├── はい → Experience Replay(DER++推奨)
   └── いいえ → 2へ

2. モデル容量を拡張できるか?
   ├── はい → Progressive Networks(ゼロ忘却)
   └── いいえ → 3へ

3. 計算予算は十分か?
   ├── はい → Generative Replay(GAN/VAEで仮想データを生成)
   └── いいえ → EWC + LwF(フィッシャー行列 + 旧モデルスナップショットのみ必要)

10. 結論と展望

壊滅的忘却[1]は、深層学習が真の人工知能に向かう道で最もコアとなる障壁の一つです。継続的に学習できないシステムは——どんなに強力であっても——単なる静的なツールであり、進化するインテリジェントな存在ではありません。

主要テーマの振り返り:

今後を見据えると、継続学習は学術研究からエンジニアリング実践へと移行しています。基盤モデルが普及するにつれ[17]、企業は毎回ゼロから再訓練するのではなく、モデルが新しいデータ、タスク、ドメインに継続的に適応することを必要としています。スパースな動的計算(MoE)モデルは異なるタスクに異なるエキスパートサブネットワークを自然に割り当て、継続学習のアーキテクチャ的可能性を提供します。パラメータ効率の良いファインチューニング(LoRAファインチューニング、Adapters)はバックボーンを凍結し小さなモジュールのみを訓練するもので、各タスクに軽量な専用適応を提供します——これは本質的に継続学習戦略そのものです。AIシステムが人間のように忘れることなく継続的に学習することを学んだとき、私たちは真の汎用知能にまた一歩近づくことになります。