Key Findings
  • 災難性遺忘[1]是深度學習的根本缺陷——當模型學習新任務時,會急劇遺忘舊任務的知識,這使得 AI 無法像人類一樣持續學習
  • 終身學習的三大策略類別:正則化方法(EWC[3]、SI[4])保護重要參數、架構方法(Progressive Networks[6])動態擴展網路、重播方法(ER[9]、GR[10])回顧舊任務樣本
  • 大型語言模型同樣面臨災難性遺忘[15]——持續微調 BERT/GPT 會迅速遺忘先前任務的能力,經驗重播是目前最有效的緩解方案
  • 本文附兩個 Google Colab 實作:EWC 防止 MNIST 影像分類遺忘、BERT 持續學習多任務文本分類,可直接在瀏覽器中執行

一、為何 AI 會遺忘:災難性遺忘的本質

人類可以持續學習新技能而不遺忘舊知識——學會騎自行車後,再學游泳不會讓你忘記如何騎車。然而,深度神經網路卻面臨一個根本性問題:災難性遺忘(Catastrophic Forgetting)[1][2]

當一個已在任務 A 上訓練好的網路,接著在任務 B 上訓練時,任務 A 的表現會急劇下降——不是逐漸退化,而是災難性地崩潰。這是因為梯度下降會無差別地更新所有參數以優化當前任務,覆蓋掉對舊任務至關重要的權重。

災難性遺忘的本質:

任務 A 訓練完成: θ* = argmin_θ L_A(θ)
                     → 參數收斂到 A 的最優解

任務 B 接續訓練: θ** = argmin_θ L_B(θ), 從 θ* 出發
                     → 參數遠離 A 的最優解,A 的表現崩潰

根本原因: 穩定性-可塑性困境 (Stability-Plasticity Dilemma)
  - 太穩定 → 無法學習新任務(欠擬合)
  - 太可塑 → 遺忘舊任務(災難性遺忘)
  - 目標: 在兩者之間找到平衡

具體表現(以影像分類為例):
  Step 1: 在數字 0-4 上訓練 → 準確率 98%
  Step 2: 在數字 5-9 上訓練 → 5-9 準確率 97%, 但 0-4 降至 ~20%
  原因: 5-9 的梯度更新破壞了區分 0-4 的關鍵權重

災難性遺忘不僅存在於影像分類中。語言模型持續微調時同樣嚴重[15]:在情感分類上微調的 BERT,接著在命名實體識別上微調後,情感分類能力會大幅退化。這個問題在大型語言模型時代變得更加關鍵——我們希望模型能持續從新資料中學習,而不是每次都從頭預訓練。

二、終身學習全景:三大策略類別

終身學習(Continual Learning / Lifelong Learning)的研究目標是讓模型在序列地學習多個任務時,既能學好新任務,又能保持舊任務的表現[12][13]。根據解決思路的不同,可分為三大類[17]

策略類別核心思想代表方法優勢局限
正則化方法在損失函數中加入懲罰項,限制重要參數的變動EWC[3]、SI[4]、LwF[5]不需存儲舊資料、記憶體固定任務數量增加後保護能力下降
架構方法為不同任務分配不同的網路結構或子網路Progressive Nets[6]、PackNet[7]、HAT[8]完全無遺忘(硬隔離)模型大小隨任務數增長
重播方法保存或生成舊任務的樣本,在學新任務時一起訓練ER[9]、GR[10]、GEM[11]、DER++[14]簡單有效、與其他方法可組合需要額外記憶體存儲舊樣本

終身學習的評估場景也有三種層次:

終身學習的三種場景:

1. Task-Incremental Learning (Task-IL):
   推論時知道當前是哪個任務 → 最簡單
   例: 「這是任務 B 的資料,用 B 的分類頭」

2. Domain-Incremental Learning (Domain-IL):
   任務結構相同,但資料分布改變 → 中等難度
   例: 同樣的 10 類分類,但影像風格從素描變為照片

3. Class-Incremental Learning (Class-IL):
   推論時不知道任務身份,需從所有已學類別中區分 → 最困難
   例: 先學 0-4,再學 5-9,測試時需區分 0-9 所有數字

難度排序: Task-IL < Domain-IL < Class-IL
實際應用中,Class-IL 最接近真實需求

三、正則化方法:EWC 與知識蒸餾

Elastic Weight Consolidation(EWC)

EWC[3] 是終身學習最具影響力的正則化方法,其靈感來自神經科學中的突觸鞏固——重要的突觸連接應該被保護,不那麼重要的可以自由更新。

核心問題是:如何衡量每個參數對舊任務的「重要性」?EWC 的答案是 Fisher 資訊矩陣

EWC 損失函數:

L_total(θ) = L_B(θ) + (λ/2) Σ_i F_i (θ_i - θ*_A,i)²

其中:
  L_B(θ):     新任務 B 的損失
  θ*_A:       舊任務 A 訓練完成後的最優參數
  F_i:        Fisher 資訊矩陣的對角元素(參數 i 對任務 A 的重要性)
  λ:          正則化強度(控制穩定性-可塑性平衡)

Fisher 資訊矩陣(對角近似):
  F_i = E_{x~D_A} [(∂ log p(y|x,θ) / ∂θ_i)²]

直觀理解:
  F_i 大 → 參數 i 對任務 A 很重要 → 強力限制其變動
  F_i 小 → 參數 i 對任務 A 不重要 → 可以自由更新以學習任務 B

幾何視角:
  任務 A 的最優解 θ*_A 周圍存在一個「低損失山谷」
  Fisher 矩陣描述了這個山谷的形狀
  EWC 引導任務 B 的優化沿著山谷延伸的方向移動
  → 找到同時適合 A 和 B 的參數

Synaptic Intelligence(SI)

SI[4] 是 EWC 的在線替代方案。EWC 需要在每個任務結束後計算 Fisher 矩陣,而 SI 在訓練過程中即時累積每個參數的重要性——追蹤每個參數在訓練中「走過的路徑」對損失下降的貢獻。

Learning without Forgetting(LwF)

LwF[5] 走了另一條路——不保護參數,而是保護輸出。在學習新任務前,先用新任務的資料通過舊模型取得「軟標籤」,然後在學習新任務時,同時用知識蒸餾損失保持舊任務的輸出分布不變。它的最大優勢是完全不需要存儲舊任務的資料

四、架構方法:Progressive Networks 與動態擴展

架構方法的哲學是:與其在有限的參數空間中艱難地平衡新舊任務,不如為每個任務分配專屬的網路容量。

Progressive Neural Networks

Rusu 等人[6]提出的 Progressive Networks 是最直接的方案——每學一個新任務,就在旁邊「長出」一個新的網路欄(column),並透過側向連接(lateral connections)讓新任務復用舊任務學到的特徵:

Progressive Neural Networks:

任務 1:  [Column 1] ← 正常訓練
任務 2:  [Column 1](凍結)←─ 側向連接 ──→ [Column 2] ← 只訓練這個
任務 3:  [Column 1](凍結)←─┐                [Column 2](凍結)←─┐
                              └─ 側向連接 ──→                      └─→ [Column 3]

優點: 完全零遺忘(舊欄被凍結)
缺點: 參數量線性增長(T 個任務 = T 倍參數)

PackNet 與 HAT

PackNet[7] 和 HAT[8] 則試圖在固定大小的網路中實現多任務:

方法策略機制特點
PackNet[7]迭代剪枝訓練 → 剪掉不重要的權重 → 釋放容量給下個任務每個任務有專屬的稀疏子網路
HAT[8]硬注意力遮罩學習每個任務的二值遮罩,保護被佔用的神經元遮罩可梯度優化,自動分配容量

五、經驗重播方法:記憶緩衝與生成重播

經驗重播(Experience Replay)方法的靈感來自認知科學中的記憶鞏固——人類在睡眠中會「重播」白天的經歷以鞏固記憶。在終身學習中,重播方法在學習新任務時混入舊任務的樣本[9]

經驗重播(ER)

最直接的方法:維護一個固定大小的記憶緩衝(memory buffer),存儲每個舊任務的少量代表性樣本。學習新任務時,每個 mini-batch 混合新任務資料和從緩衝中取樣的舊資料:

經驗重播流程:

記憶緩衝 M(固定大小,如 200 個樣本)

學習任務 t:
  for each mini-batch:
    batch_new = sample(D_t)           # 新任務資料
    batch_old = sample(M)             # 從緩衝取樣舊資料
    loss = L(batch_new) + L(batch_old)  # 聯合損失
    更新 θ

  任務完成後:
    將 D_t 的代表性樣本加入 M(使用 reservoir sampling 或 herding)

Reservoir Sampling:
  以機率 |M| / n 將第 n 個樣本加入緩衝,
  確保每個已見過的樣本被選中的機率相等

關鍵發現(Rolnick et al., 2019):
  僅 1-5 個樣本/類別 就能大幅減少遺忘
  → 極小的記憶代價即可獲得顯著的防遺忘效果

生成式重播(Generative Replay)

Shin 等人[10]提出了一個巧妙的替代方案:不存儲真實舊資料,而是訓練一個生成模型(如 GAN 或 VAE)來生成舊任務的虛擬樣本。這在隱私敏感的場景中特別有價值——醫療資料不能被存儲,但可以用生成模型重建其分布。

GEM 與 DER++

GEM[11](Gradient Episodic Memory)使用記憶中的樣本計算梯度約束:新任務的梯度更新不能增加舊任務在記憶樣本上的損失。DER++[14] 則結合經驗重播與知識蒸餾——不僅重播舊資料的標籤,還重播舊模型的軟輸出(logits),以「暗知識」的形式保留更豐富的資訊。

六、文字 AI 的終身學習:持續微調語言模型

大型語言模型的終身學習是當前研究的前沿[16]。當企業希望 BERT 或 GPT 持續適應新任務或新領域時,災難性遺忘會嚴重影響已有能力:

語言模型的終身學習場景:

1. 持續任務微調(Continual Task Fine-tuning):
   BERT → 情感分析 → NER → QA → 文本摘要
   問題: 後面的微調破壞前面任務的能力

2. 持續領域適應(Continual Domain Adaptation):
   通用 LLM → 金融領域 → 法律領域 → 醫療領域
   問題: 新領域的知識覆蓋舊領域的專業知識

3. 持續預訓練(Continual Pre-training):
   基礎模型 → 持續吸收新文件/新知識
   問題: 新知識可能破壞語言理解的基礎能力

語言模型遺忘的特殊挑戰:
  - 參數共享度極高(所有任務共用同一 Transformer)
  - 表徵空間的干擾更嚴重(語義重疊多)
  - 任務頭可以分離,但底層表徵難以隔離

Scialom 等人[15]的研究表明,經驗重播是目前語言模型終身學習最有效的方法——在學習新任務時混入少量舊任務樣本,即可顯著減少遺忘。這比 EWC 等正則化方法在 NLP 場景中更為有效,因為語言任務的參數重要性分布更加均勻,正則化約束的區分力有限。

七、Hands-on Lab 1:EWC 防止 MNIST 影像分類遺忘(Google Colab)

以下實驗在 Split MNIST 上對比三種策略:(1) 樸素微調(Naive)、(2) EWC 正則化、(3) 經驗重播(ER),直觀展示災難性遺忘現象及其緩解。

# ============================================================
# Lab 1: 終身學習 — EWC vs 經驗重播 vs 樸素微調(Split MNIST)
# 環境: Google Colab (CPU 即可)
# ============================================================
# --- 0. 安裝 ---
!pip install -q torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import copy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. 資料準備: Split MNIST ---
# 任務 A: 數字 0-4, 任務 B: 數字 5-9
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_data = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.MNIST('./data', train=False, transform=transform)

def filter_by_labels(dataset, labels):
    """篩選特定標籤的資料"""
    mask = torch.zeros(len(dataset.targets), dtype=torch.bool)
    for l in labels:
        mask |= (dataset.targets == l)
    indices = mask.nonzero(as_tuple=True)[0]
    return torch.utils.data.Subset(dataset, indices)

task_a_labels = [0, 1, 2, 3, 4]
task_b_labels = [5, 6, 7, 8, 9]

train_a = filter_by_labels(train_data, task_a_labels)
train_b = filter_by_labels(train_data, task_b_labels)
test_a = filter_by_labels(test_data, task_a_labels)
test_b = filter_by_labels(test_data, task_b_labels)

loader_a = torch.utils.data.DataLoader(train_a, batch_size=128, shuffle=True)
loader_b = torch.utils.data.DataLoader(train_b, batch_size=128, shuffle=True)
test_loader_a = torch.utils.data.DataLoader(test_a, batch_size=256)
test_loader_b = torch.utils.data.DataLoader(test_b, batch_size=256)

print(f"Task A (digits 0-4): {len(train_a)} train, {len(test_a)} test")
print(f"Task B (digits 5-9): {len(train_b)} train, {len(test_b)} test")

# --- 2. 簡單 CNN 模型 ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 10)  # 所有 10 類共享輸出

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total

# --- 3. EWC 實現 ---
class EWC:
    def __init__(self, model, dataloader, device, n_samples=200):
        self.params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
        self.fisher = self._compute_fisher(model, dataloader, device, n_samples)

    def _compute_fisher(self, model, dataloader, device, n_samples):
        """計算 Fisher 資訊矩陣(對角近似)"""
        fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
        model.eval()
        count = 0
        for x, y in dataloader:
            if count >= n_samples:
                break
            x, y = x.to(device), y.to(device)
            model.zero_grad()
            output = model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            for n, p in model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    fisher[n] += p.grad.data.pow(2) * x.size(0)
            count += x.size(0)
        fisher = {n: f / count for n, f in fisher.items()}
        return fisher

    def penalty(self, model):
        """EWC 正則化項"""
        loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        return loss

# --- 4. 經驗重播記憶緩衝 ---
class ReplayBuffer:
    def __init__(self, capacity=200):
        self.capacity = capacity
        self.buffer_x = []
        self.buffer_y = []

    def add_from_loader(self, loader, n_samples):
        """從 loader 中隨機取樣加入緩衝"""
        all_x, all_y = [], []
        for x, y in loader:
            all_x.append(x)
            all_y.append(y)
        all_x = torch.cat(all_x)
        all_y = torch.cat(all_y)
        indices = torch.randperm(len(all_x))[:n_samples]
        self.buffer_x.append(all_x[indices])
        self.buffer_y.append(all_y[indices])

    def sample(self, batch_size):
        all_x = torch.cat(self.buffer_x)
        all_y = torch.cat(self.buffer_y)
        indices = torch.randperm(len(all_x))[:batch_size]
        return all_x[indices], all_y[indices]

# --- 5. 訓練函數 ---
def train_task(model, loader, optimizer, epochs, ewc=None, ewc_lambda=0,
               replay_buffer=None, replay_batch=32):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            output = model(x)
            loss = F.cross_entropy(output, y)

            # EWC 正則化
            if ewc is not None:
                loss += ewc_lambda * ewc.penalty(model)

            # 經驗重播
            if replay_buffer is not None and len(replay_buffer.buffer_x) > 0:
                rx, ry = replay_buffer.sample(replay_batch)
                rx, ry = rx.to(device), ry.to(device)
                r_output = model(rx)
                loss += F.cross_entropy(r_output, ry)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

# --- 6. 實驗: 三種策略對比 ---
n_epochs = 5
results = {}

# 策略 1: 樸素微調(Naive)
print("\n=== Strategy 1: Naive Fine-tuning ===")
model_naive = SimpleCNN().to(device)
opt = torch.optim.Adam(model_naive.parameters(), lr=1e-3)

train_task(model_naive, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_naive, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

train_task(model_naive, loader_b, opt, n_epochs)
acc_a_after_b = evaluate(model_naive, test_loader_a)
acc_b_after_b = evaluate(model_naive, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Naive'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# 策略 2: EWC
print("\n=== Strategy 2: EWC (λ=400) ===")
model_ewc = SimpleCNN().to(device)
opt = torch.optim.Adam(model_ewc.parameters(), lr=1e-3)

train_task(model_ewc, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_ewc, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

# 計算 Fisher 矩陣
ewc = EWC(model_ewc, loader_a, device)

train_task(model_ewc, loader_b, opt, n_epochs, ewc=ewc, ewc_lambda=400)
acc_a_after_b = evaluate(model_ewc, test_loader_a)
acc_b_after_b = evaluate(model_ewc, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['EWC'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# 策略 3: 經驗重播
print("\n=== Strategy 3: Experience Replay (buffer=200) ===")
model_er = SimpleCNN().to(device)
opt = torch.optim.Adam(model_er.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=200)

train_task(model_er, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_er, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

# 將任務 A 的樣本加入緩衝
buffer.add_from_loader(loader_a, n_samples=200)

train_task(model_er, loader_b, opt, n_epochs, replay_buffer=buffer)
acc_a_after_b = evaluate(model_er, test_loader_a)
acc_b_after_b = evaluate(model_er, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Replay'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# --- 7. 視覺化結果 ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 左圖: 任務 A 準確率(遺忘程度)
strategies = list(results.keys())
acc_before = [results[s][0] for s in strategies]
acc_after = [results[s][1] for s in strategies]

x_pos = np.arange(len(strategies))
width = 0.35
bars1 = axes[0].bar(x_pos - width/2, acc_before, width, label='After Task A', color='#0077b6')
bars2 = axes[0].bar(x_pos + width/2, acc_after, width, label='After Task B', color='#e63946')

axes[0].set_ylabel('Task A Accuracy')
axes[0].set_title('Catastrophic Forgetting: Task A Performance', fontsize=13)
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(strategies)
axes[0].legend()
axes[0].set_ylim(0, 1.05)
axes[0].grid(True, alpha=0.3, axis='y')

for bar in bars1:
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)
for bar in bars2:
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)

# 右圖: 遺忘量
forgetting = [results[s][0] - results[s][1] for s in strategies]
colors = ['#e63946' if f > 0.3 else '#b8922e' if f > 0.1 else '#2a9d8f' for f in forgetting]
bars = axes[1].bar(strategies, forgetting, color=colors, edgecolor='white', linewidth=1.5)

axes[1].set_ylabel('Forgetting (↓ better)')
axes[1].set_title('Amount of Forgetting on Task A', fontsize=13)
axes[1].grid(True, alpha=0.3, axis='y')

for bar, f in zip(bars, forgetting):
    axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{f:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

# --- 8. 逐類準確率分析 ---
print("\n=== Per-class Accuracy After Task B ===")
print(f"{'Class':<8} {'Naive':>8} {'EWC':>8} {'Replay':>8}")
print("-" * 36)

models = {'Naive': model_naive, 'EWC': model_ewc, 'Replay': model_er}
for digit in range(10):
    test_digit = filter_by_labels(test_data, [digit])
    loader_digit = torch.utils.data.DataLoader(test_digit, batch_size=256)
    accs = []
    for name in ['Naive', 'EWC', 'Replay']:
        acc = evaluate(models[name], loader_digit)
        accs.append(acc)
    marker = " ← Task A" if digit < 5 else ""
    print(f"  {digit:<6} {accs[0]:>8.1%} {accs[1]:>8.1%} {accs[2]:>8.1%}{marker}")

print("\nLab 1 Complete!")

八、Hands-on Lab 2:BERT 持續學習多任務文本分類(Google Colab)

以下實驗展示語言模型的災難性遺忘:在兩個文本分類任務上依序微調 BERT,觀察遺忘現象,並以經驗重播進行緩解。

# ============================================================
# Lab 2: BERT 持續學習 — 多任務文本分類的災難性遺忘與緩解
# 環境: Google Colab (GPU 推薦, CPU 可用但較慢)
# ============================================================
# --- 0. 安裝 ---
!pip install -q transformers datasets torch

import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. 載入兩個文本分類任務 ---
print("\n--- Loading Datasets ---")

# 任務 A: SST-2 情感分類(正/負)
sst2 = load_dataset("glue", "sst2")
print(f"Task A (SST-2): {len(sst2['train'])} train, {len(sst2['validation'])} val")

# 任務 B: MRPC 語義等價判斷(等價/不等價)
mrpc = load_dataset("glue", "mrpc")
print(f"Task B (MRPC):  {len(mrpc['train'])} train, {len(mrpc['validation'])} val")

# --- 2. Tokenizer ---
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_sst2(examples):
    return tokenizer(examples['sentence'], truncation=True, padding='max_length', max_length=64)

def tokenize_mrpc(examples):
    return tokenizer(examples['sentence1'], examples['sentence2'],
                     truncation=True, padding='max_length', max_length=128)

sst2_tok = sst2.map(tokenize_sst2, batched=True)
mrpc_tok = mrpc.map(tokenize_mrpc, batched=True)

for ds in [sst2_tok, mrpc_tok]:
    ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# 使用子集(Colab 友善)
sst2_train = sst2_tok["train"].shuffle(seed=42).select(range(1000))
sst2_val = sst2_tok["validation"]
mrpc_train = mrpc_tok["train"].shuffle(seed=42).select(range(800))
mrpc_val = mrpc_tok["validation"]

# --- 3. 訓練與評估工具 ---
def make_loader(dataset, batch_size=16, shuffle=True):
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def evaluate_task(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

def train_epoch(model, loader, optimizer, replay_data=None, replay_ratio=0.3):
    model.train()
    total_loss = 0
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # 經驗重播
        if replay_data is not None and len(replay_data) > 0:
            n_replay = max(1, int(input_ids.size(0) * replay_ratio))
            indices = random.sample(range(len(replay_data)), min(n_replay, len(replay_data)))

            r_ids = torch.stack([replay_data[i]['input_ids'] for i in indices]).to(device)
            r_mask = torch.stack([replay_data[i]['attention_mask'] for i in indices]).to(device)
            r_labels = torch.tensor([replay_data[i]['label'] for i in indices]).to(device)

            r_outputs = model(input_ids=r_ids, attention_mask=r_mask, labels=r_labels)
            loss = loss + r_outputs.loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)

# --- 4. 建立重播記憶 ---
def create_replay_buffer(dataset, n_samples=100):
    """從資料集中取樣建立重播緩衝"""
    indices = random.sample(range(len(dataset)), min(n_samples, len(dataset)))
    buffer = []
    for i in indices:
        item = dataset[i]
        buffer.append({
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'label': item['label'].item() if isinstance(item['label'], torch.Tensor) else item['label']
        })
    return buffer

# --- 5. 實驗 1: 樸素持續微調(展示遺忘)---
print("\n" + "="*60)
print("Experiment 1: Naive Sequential Fine-tuning")
print("="*60)

model_naive = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_naive.parameters(), lr=2e-5, weight_decay=0.01)

# 訓練任務 A (SST-2)
print("\n--- Training on Task A (SST-2) ---")
loader_a = make_loader(sst2_train)
val_loader_a = make_loader(sst2_val, shuffle=False)

naive_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}

for epoch in range(3):
    loss = train_epoch(model_naive, loader_a, opt)
    acc = evaluate_task(model_naive, val_loader_a)
    naive_history['task_a_on_a'].append(acc)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")

acc_a_before = naive_history['task_a_on_a'][-1]

# 訓練任務 B (MRPC)
print("\n--- Training on Task B (MRPC) ---")
loader_b = make_loader(mrpc_train)
val_loader_b = make_loader(mrpc_val, shuffle=False)

for epoch in range(3):
    loss = train_epoch(model_naive, loader_b, opt)
    acc_a = evaluate_task(model_naive, val_loader_a)
    acc_b = evaluate_task(model_naive, val_loader_b)
    naive_history['task_a_on_b'].append(acc_a)
    naive_history['task_b_on_b'].append(acc_b)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")

acc_a_after_naive = naive_history['task_a_on_b'][-1]
acc_b_naive = naive_history['task_b_on_b'][-1]

# --- 6. 實驗 2: 經驗重播 ---
print("\n" + "="*60)
print("Experiment 2: Experience Replay (buffer=100)")
print("="*60)

model_replay = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_replay.parameters(), lr=2e-5, weight_decay=0.01)

# 訓練任務 A
print("\n--- Training on Task A (SST-2) ---")
replay_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}

for epoch in range(3):
    loss = train_epoch(model_replay, loader_a, opt)
    acc = evaluate_task(model_replay, val_loader_a)
    replay_history['task_a_on_a'].append(acc)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")

# 建立重播緩衝
replay_buffer = create_replay_buffer(sst2_train, n_samples=100)
print(f"\nReplay buffer: {len(replay_buffer)} samples from Task A")

# 訓練任務 B + 重播
print("\n--- Training on Task B (MRPC) with Replay ---")
for epoch in range(3):
    loss = train_epoch(model_replay, loader_b, opt, replay_data=replay_buffer)
    acc_a = evaluate_task(model_replay, val_loader_a)
    acc_b = evaluate_task(model_replay, val_loader_b)
    replay_history['task_a_on_b'].append(acc_a)
    replay_history['task_b_on_b'].append(acc_b)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")

acc_a_after_replay = replay_history['task_a_on_b'][-1]
acc_b_replay = replay_history['task_b_on_b'][-1]

# --- 7. 視覺化比較 ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 左圖: Task A 準確率隨訓練變化
epochs_a = list(range(1, 4))
epochs_b = list(range(4, 7))
all_epochs = epochs_a + epochs_b

naive_a_curve = naive_history['task_a_on_a'] + naive_history['task_a_on_b']
replay_a_curve = replay_history['task_a_on_a'] + replay_history['task_a_on_b']

axes[0].plot(all_epochs, naive_a_curve, 'o-', color='#e63946', linewidth=2, label='Naive')
axes[0].plot(all_epochs, replay_a_curve, 's-', color='#0077b6', linewidth=2, label='Replay')
axes[0].axvline(x=3.5, color='gray', linestyle='--', alpha=0.5)
axes[0].text(2, 0.55, 'Task A\n(SST-2)', ha='center', fontsize=10, color='gray')
axes[0].text(5, 0.55, 'Task B\n(MRPC)', ha='center', fontsize=10, color='gray')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('SST-2 Accuracy')
axes[0].set_title('Task A (SST-2) Accuracy Over Time', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim(0.5, 1.0)

# 中圖: 最終結果對比
categories = ['Task A\n(SST-2)', 'Task B\n(MRPC)']
naive_scores = [acc_a_after_naive, acc_b_naive]
replay_scores = [acc_a_after_replay, acc_b_replay]

x = np.arange(len(categories))
width = 0.35

axes[1].bar(x - width/2, naive_scores, width, label='Naive', color='#e63946', alpha=0.85)
axes[1].bar(x + width/2, replay_scores, width, label='Replay', color='#0077b6', alpha=0.85)
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Final Performance Comparison', fontsize=13)
axes[1].set_xticks(x)
axes[1].set_xticklabels(categories)
axes[1].legend()
axes[1].set_ylim(0.5, 1.0)
axes[1].grid(True, alpha=0.3, axis='y')

# 右圖: 遺忘量
forgetting_naive = acc_a_before - acc_a_after_naive
forgetting_replay = replay_history['task_a_on_a'][-1] - acc_a_after_replay

bars = axes[2].bar(['Naive', 'Replay'], [forgetting_naive, forgetting_replay],
                    color=['#e63946', '#0077b6'], edgecolor='white', linewidth=1.5)
axes[2].set_ylabel('Forgetting (↓ better)')
axes[2].set_title('SST-2 Forgetting After MRPC Training', fontsize=13)
axes[2].grid(True, alpha=0.3, axis='y')

for bar, f in zip(bars, [forgetting_naive, forgetting_replay]):
    axes[2].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.005,
                 f'{f:.3f}', ha='center', va='bottom', fontsize=13, fontweight='bold')

plt.tight_layout()
plt.show()

# --- 8. 推論示範 ---
print("\n=== Inference Demo ===")
test_sentences = [
    ("This movie is absolutely wonderful!", "SST-2 (Positive)"),
    ("A terrible waste of time and money.", "SST-2 (Negative)"),
    ("The film was average, nothing special.", "SST-2 (Neutral-ish)"),
]

print("\n--- Naive Model ---")
model_naive.eval()
for text, label in test_sentences:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
    with torch.no_grad():
        logits = model_naive(**inputs).logits
    pred = "Positive" if logits.argmax().item() == 1 else "Negative"
    conf = torch.softmax(logits, dim=-1).max().item()
    print(f"  [{pred} {conf:.1%}] {text}  (expected: {label})")

print("\n--- Replay Model ---")
model_replay.eval()
for text, label in test_sentences:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
    with torch.no_grad():
        logits = model_replay(**inputs).logits
    pred = "Positive" if logits.argmax().item() == 1 else "Negative"
    conf = torch.softmax(logits, dim=-1).max().item()
    print(f"  [{pred} {conf:.1%}] {text}  (expected: {label})")

print("\nLab 2 Complete!")

九、決策框架:企業如何選擇終身學習策略

根據資料可用性、隱私約束與計算預算,企業可以用以下框架選擇適合的終身學習方案:

條件推薦方法理由
可存儲舊資料、記憶體充足經驗重播(ER / DER++)[9][14]最簡單有效,200 個樣本/任務即可大幅降低遺忘
隱私約束,不能存舊資料EWC[3] + LwF[5]僅需存 Fisher 矩陣或模型快照,無需原始資料
任務數量多且持續增長PackNet[7] 或 HAT[8]在固定模型容量中支持多任務,不需額外存儲
任務數量少但要求零遺忘Progressive Networks[6]完全隔離,零遺忘保證,適合關鍵任務場景
語言模型持續微調經驗重播 + 學習率調整[15]對 Transformer 架構最有效,EWC 在 NLP 中效果有限
隱私約束 + 充足計算生成式重播(GR)[10]生成虛擬舊資料,兼顧隱私與防遺忘
決策樹:

1. 是否可以存儲舊任務的真實資料?
   ├── 是 → 經驗重播(首選 DER++)
   └── 否 → 2

2. 模型容量是否可以增長?
   ├── 是 → Progressive Networks(零遺忘)
   └── 否 → 3

3. 計算預算是否充足?
   ├── 是 → 生成式重播(GAN/VAE 生成虛擬資料)
   └── 否 → EWC + LwF(僅需存 Fisher 矩陣 + 舊模型快照)

十、結語與展望

災難性遺忘[1]是深度學習通往真正人工智慧的核心障礙之一。一個無法持續學習的系統——無論多麼強大——都只是靜態的工具,而非進化的智慧體。

回顧核心脈絡:

展望未來,終身學習正在從學術研究走向工程實踐。隨著 Foundation Model 的普及[17],企業需要讓模型持續適應新資料、新任務和新領域——而不是每次都從頭預訓練。稀疏混合專家(MoE)模型天然地為不同任務分配不同的專家子網路,提供了架構層面的終身學習潛力;而參數高效微調(LoRA、Adapter)透過凍結主幹、只訓練小型模組,為每個任務提供了輕量級的專屬適配——這本質上就是一種終身學習策略。當 AI 系統學會像人類一樣持續學習而不遺忘,我們離真正的通用智慧又近了一步。