一、為何 AI 會遺忘:災難性遺忘的本質
人類可以持續學習新技能而不遺忘舊知識——學會騎自行車後,再學游泳不會讓你忘記如何騎車。然而,深度神經網路卻面臨一個根本性問題:災難性遺忘(Catastrophic Forgetting)[1][2]。
當一個已在任務 A 上訓練好的網路,接著在任務 B 上訓練時,任務 A 的表現會急劇下降——不是逐漸退化,而是災難性地崩潰。這是因為梯度下降會無差別地更新所有參數以優化當前任務,覆蓋掉對舊任務至關重要的權重。
災難性遺忘的本質:
任務 A 訓練完成: θ* = argmin_θ L_A(θ)
→ 參數收斂到 A 的最優解
任務 B 接續訓練: θ** = argmin_θ L_B(θ), 從 θ* 出發
→ 參數遠離 A 的最優解,A 的表現崩潰
根本原因: 穩定性-可塑性困境 (Stability-Plasticity Dilemma)
- 太穩定 → 無法學習新任務(欠擬合)
- 太可塑 → 遺忘舊任務(災難性遺忘)
- 目標: 在兩者之間找到平衡
具體表現(以影像分類為例):
Step 1: 在數字 0-4 上訓練 → 準確率 98%
Step 2: 在數字 5-9 上訓練 → 5-9 準確率 97%, 但 0-4 降至 ~20%
原因: 5-9 的梯度更新破壞了區分 0-4 的關鍵權重
災難性遺忘不僅存在於影像分類中。語言模型持續微調時同樣嚴重[15]:在情感分類上微調的 BERT,接著在命名實體識別上微調後,情感分類能力會大幅退化。這個問題在大型語言模型時代變得更加關鍵——我們希望模型能持續從新資料中學習,而不是每次都從頭預訓練。
二、終身學習全景:三大策略類別
終身學習(Continual Learning / Lifelong Learning)的研究目標是讓模型在序列地學習多個任務時,既能學好新任務,又能保持舊任務的表現[12][13]。根據解決思路的不同,可分為三大類[17]:
| 策略類別 | 核心思想 | 代表方法 | 優勢 | 局限 |
|---|---|---|---|---|
| 正則化方法 | 在損失函數中加入懲罰項,限制重要參數的變動 | EWC[3]、SI[4]、LwF[5] | 不需存儲舊資料、記憶體固定 | 任務數量增加後保護能力下降 |
| 架構方法 | 為不同任務分配不同的網路結構或子網路 | Progressive Nets[6]、PackNet[7]、HAT[8] | 完全無遺忘(硬隔離) | 模型大小隨任務數增長 |
| 重播方法 | 保存或生成舊任務的樣本,在學新任務時一起訓練 | ER[9]、GR[10]、GEM[11]、DER++[14] | 簡單有效、與其他方法可組合 | 需要額外記憶體存儲舊樣本 |
終身學習的評估場景也有三種層次:
終身學習的三種場景:
1. Task-Incremental Learning (Task-IL):
推論時知道當前是哪個任務 → 最簡單
例: 「這是任務 B 的資料,用 B 的分類頭」
2. Domain-Incremental Learning (Domain-IL):
任務結構相同,但資料分布改變 → 中等難度
例: 同樣的 10 類分類,但影像風格從素描變為照片
3. Class-Incremental Learning (Class-IL):
推論時不知道任務身份,需從所有已學類別中區分 → 最困難
例: 先學 0-4,再學 5-9,測試時需區分 0-9 所有數字
難度排序: Task-IL < Domain-IL < Class-IL
實際應用中,Class-IL 最接近真實需求
三、正則化方法:EWC 與知識蒸餾
Elastic Weight Consolidation(EWC)
EWC[3] 是終身學習最具影響力的正則化方法,其靈感來自神經科學中的突觸鞏固——重要的突觸連接應該被保護,不那麼重要的可以自由更新。
核心問題是:如何衡量每個參數對舊任務的「重要性」?EWC 的答案是 Fisher 資訊矩陣:
EWC 損失函數:
L_total(θ) = L_B(θ) + (λ/2) Σ_i F_i (θ_i - θ*_A,i)²
其中:
L_B(θ): 新任務 B 的損失
θ*_A: 舊任務 A 訓練完成後的最優參數
F_i: Fisher 資訊矩陣的對角元素(參數 i 對任務 A 的重要性)
λ: 正則化強度(控制穩定性-可塑性平衡)
Fisher 資訊矩陣(對角近似):
F_i = E_{x~D_A} [(∂ log p(y|x,θ) / ∂θ_i)²]
直觀理解:
F_i 大 → 參數 i 對任務 A 很重要 → 強力限制其變動
F_i 小 → 參數 i 對任務 A 不重要 → 可以自由更新以學習任務 B
幾何視角:
任務 A 的最優解 θ*_A 周圍存在一個「低損失山谷」
Fisher 矩陣描述了這個山谷的形狀
EWC 引導任務 B 的優化沿著山谷延伸的方向移動
→ 找到同時適合 A 和 B 的參數
Synaptic Intelligence(SI)
SI[4] 是 EWC 的在線替代方案。EWC 需要在每個任務結束後計算 Fisher 矩陣,而 SI 在訓練過程中即時累積每個參數的重要性——追蹤每個參數在訓練中「走過的路徑」對損失下降的貢獻。
Learning without Forgetting(LwF)
LwF[5] 走了另一條路——不保護參數,而是保護輸出。在學習新任務前,先用新任務的資料通過舊模型取得「軟標籤」,然後在學習新任務時,同時用知識蒸餾損失保持舊任務的輸出分布不變。它的最大優勢是完全不需要存儲舊任務的資料。
四、架構方法:Progressive Networks 與動態擴展
架構方法的哲學是:與其在有限的參數空間中艱難地平衡新舊任務,不如為每個任務分配專屬的網路容量。
Progressive Neural Networks
Rusu 等人[6]提出的 Progressive Networks 是最直接的方案——每學一個新任務,就在旁邊「長出」一個新的網路欄(column),並透過側向連接(lateral connections)讓新任務復用舊任務學到的特徵:
Progressive Neural Networks:
任務 1: [Column 1] ← 正常訓練
任務 2: [Column 1](凍結)←─ 側向連接 ──→ [Column 2] ← 只訓練這個
任務 3: [Column 1](凍結)←─┐ [Column 2](凍結)←─┐
└─ 側向連接 ──→ └─→ [Column 3]
優點: 完全零遺忘(舊欄被凍結)
缺點: 參數量線性增長(T 個任務 = T 倍參數)
PackNet 與 HAT
PackNet[7] 和 HAT[8] 則試圖在固定大小的網路中實現多任務:
| 方法 | 策略 | 機制 | 特點 |
|---|---|---|---|
| PackNet[7] | 迭代剪枝 | 訓練 → 剪掉不重要的權重 → 釋放容量給下個任務 | 每個任務有專屬的稀疏子網路 |
| HAT[8] | 硬注意力遮罩 | 學習每個任務的二值遮罩,保護被佔用的神經元 | 遮罩可梯度優化,自動分配容量 |
五、經驗重播方法:記憶緩衝與生成重播
經驗重播(Experience Replay)方法的靈感來自認知科學中的記憶鞏固——人類在睡眠中會「重播」白天的經歷以鞏固記憶。在終身學習中,重播方法在學習新任務時混入舊任務的樣本[9]。
經驗重播(ER)
最直接的方法:維護一個固定大小的記憶緩衝(memory buffer),存儲每個舊任務的少量代表性樣本。學習新任務時,每個 mini-batch 混合新任務資料和從緩衝中取樣的舊資料:
經驗重播流程:
記憶緩衝 M(固定大小,如 200 個樣本)
學習任務 t:
for each mini-batch:
batch_new = sample(D_t) # 新任務資料
batch_old = sample(M) # 從緩衝取樣舊資料
loss = L(batch_new) + L(batch_old) # 聯合損失
更新 θ
任務完成後:
將 D_t 的代表性樣本加入 M(使用 reservoir sampling 或 herding)
Reservoir Sampling:
以機率 |M| / n 將第 n 個樣本加入緩衝,
確保每個已見過的樣本被選中的機率相等
關鍵發現(Rolnick et al., 2019):
僅 1-5 個樣本/類別 就能大幅減少遺忘
→ 極小的記憶代價即可獲得顯著的防遺忘效果
生成式重播(Generative Replay)
Shin 等人[10]提出了一個巧妙的替代方案:不存儲真實舊資料,而是訓練一個生成模型(如 GAN 或 VAE)來生成舊任務的虛擬樣本。這在隱私敏感的場景中特別有價值——醫療資料不能被存儲,但可以用生成模型重建其分布。
GEM 與 DER++
GEM[11](Gradient Episodic Memory)使用記憶中的樣本計算梯度約束:新任務的梯度更新不能增加舊任務在記憶樣本上的損失。DER++[14] 則結合經驗重播與知識蒸餾——不僅重播舊資料的標籤,還重播舊模型的軟輸出(logits),以「暗知識」的形式保留更豐富的資訊。
六、文字 AI 的終身學習:持續微調語言模型
大型語言模型的終身學習是當前研究的前沿[16]。當企業希望 BERT 或 GPT 持續適應新任務或新領域時,災難性遺忘會嚴重影響已有能力:
語言模型的終身學習場景:
1. 持續任務微調(Continual Task Fine-tuning):
BERT → 情感分析 → NER → QA → 文本摘要
問題: 後面的微調破壞前面任務的能力
2. 持續領域適應(Continual Domain Adaptation):
通用 LLM → 金融領域 → 法律領域 → 醫療領域
問題: 新領域的知識覆蓋舊領域的專業知識
3. 持續預訓練(Continual Pre-training):
基礎模型 → 持續吸收新文件/新知識
問題: 新知識可能破壞語言理解的基礎能力
語言模型遺忘的特殊挑戰:
- 參數共享度極高(所有任務共用同一 Transformer)
- 表徵空間的干擾更嚴重(語義重疊多)
- 任務頭可以分離,但底層表徵難以隔離
Scialom 等人[15]的研究表明,經驗重播是目前語言模型終身學習最有效的方法——在學習新任務時混入少量舊任務樣本,即可顯著減少遺忘。這比 EWC 等正則化方法在 NLP 場景中更為有效,因為語言任務的參數重要性分布更加均勻,正則化約束的區分力有限。
七、Hands-on Lab 1:EWC 防止 MNIST 影像分類遺忘(Google Colab)
以下實驗在 Split MNIST 上對比三種策略:(1) 樸素微調(Naive)、(2) EWC 正則化、(3) 經驗重播(ER),直觀展示災難性遺忘現象及其緩解。
# ============================================================
# Lab 1: 終身學習 — EWC vs 經驗重播 vs 樸素微調(Split MNIST)
# 環境: Google Colab (CPU 即可)
# ============================================================
# --- 0. 安裝 ---
!pip install -q torch torchvision matplotlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import copy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# --- 1. 資料準備: Split MNIST ---
# 任務 A: 數字 0-4, 任務 B: 數字 5-9
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_data = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.MNIST('./data', train=False, transform=transform)
def filter_by_labels(dataset, labels):
"""篩選特定標籤的資料"""
mask = torch.zeros(len(dataset.targets), dtype=torch.bool)
for l in labels:
mask |= (dataset.targets == l)
indices = mask.nonzero(as_tuple=True)[0]
return torch.utils.data.Subset(dataset, indices)
task_a_labels = [0, 1, 2, 3, 4]
task_b_labels = [5, 6, 7, 8, 9]
train_a = filter_by_labels(train_data, task_a_labels)
train_b = filter_by_labels(train_data, task_b_labels)
test_a = filter_by_labels(test_data, task_a_labels)
test_b = filter_by_labels(test_data, task_b_labels)
loader_a = torch.utils.data.DataLoader(train_a, batch_size=128, shuffle=True)
loader_b = torch.utils.data.DataLoader(train_b, batch_size=128, shuffle=True)
test_loader_a = torch.utils.data.DataLoader(test_a, batch_size=256)
test_loader_b = torch.utils.data.DataLoader(test_b, batch_size=256)
print(f"Task A (digits 0-4): {len(train_a)} train, {len(test_a)} test")
print(f"Task B (digits 5-9): {len(train_b)} train, {len(test_b)} test")
# --- 2. 簡單 CNN 模型 ---
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.pool = nn.MaxPool2d(2)
self.fc1 = nn.Linear(64 * 7 * 7, 256)
self.fc2 = nn.Linear(256, 10) # 所有 10 類共享輸出
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
return self.fc2(x)
def evaluate(model, loader):
model.eval()
correct, total = 0, 0
with torch.no_grad():
for x, y in loader:
x, y = x.to(device), y.to(device)
pred = model(x).argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
return correct / total
# --- 3. EWC 實現 ---
class EWC:
def __init__(self, model, dataloader, device, n_samples=200):
self.params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
self.fisher = self._compute_fisher(model, dataloader, device, n_samples)
def _compute_fisher(self, model, dataloader, device, n_samples):
"""計算 Fisher 資訊矩陣(對角近似)"""
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
model.eval()
count = 0
for x, y in dataloader:
if count >= n_samples:
break
x, y = x.to(device), y.to(device)
model.zero_grad()
output = model(x)
loss = F.cross_entropy(output, y)
loss.backward()
for n, p in model.named_parameters():
if p.requires_grad and p.grad is not None:
fisher[n] += p.grad.data.pow(2) * x.size(0)
count += x.size(0)
fisher = {n: f / count for n, f in fisher.items()}
return fisher
def penalty(self, model):
"""EWC 正則化項"""
loss = 0
for n, p in model.named_parameters():
if p.requires_grad and n in self.fisher:
loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
return loss
# --- 4. 經驗重播記憶緩衝 ---
class ReplayBuffer:
def __init__(self, capacity=200):
self.capacity = capacity
self.buffer_x = []
self.buffer_y = []
def add_from_loader(self, loader, n_samples):
"""從 loader 中隨機取樣加入緩衝"""
all_x, all_y = [], []
for x, y in loader:
all_x.append(x)
all_y.append(y)
all_x = torch.cat(all_x)
all_y = torch.cat(all_y)
indices = torch.randperm(len(all_x))[:n_samples]
self.buffer_x.append(all_x[indices])
self.buffer_y.append(all_y[indices])
def sample(self, batch_size):
all_x = torch.cat(self.buffer_x)
all_y = torch.cat(self.buffer_y)
indices = torch.randperm(len(all_x))[:batch_size]
return all_x[indices], all_y[indices]
# --- 5. 訓練函數 ---
def train_task(model, loader, optimizer, epochs, ewc=None, ewc_lambda=0,
replay_buffer=None, replay_batch=32):
model.train()
for epoch in range(epochs):
total_loss = 0
for x, y in loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
output = model(x)
loss = F.cross_entropy(output, y)
# EWC 正則化
if ewc is not None:
loss += ewc_lambda * ewc.penalty(model)
# 經驗重播
if replay_buffer is not None and len(replay_buffer.buffer_x) > 0:
rx, ry = replay_buffer.sample(replay_batch)
rx, ry = rx.to(device), ry.to(device)
r_output = model(rx)
loss += F.cross_entropy(r_output, ry)
loss.backward()
optimizer.step()
total_loss += loss.item()
# --- 6. 實驗: 三種策略對比 ---
n_epochs = 5
results = {}
# 策略 1: 樸素微調(Naive)
print("\n=== Strategy 1: Naive Fine-tuning ===")
model_naive = SimpleCNN().to(device)
opt = torch.optim.Adam(model_naive.parameters(), lr=1e-3)
train_task(model_naive, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_naive, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")
train_task(model_naive, loader_b, opt, n_epochs)
acc_a_after_b = evaluate(model_naive, test_loader_a)
acc_b_after_b = evaluate(model_naive, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Naive'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)
# 策略 2: EWC
print("\n=== Strategy 2: EWC (λ=400) ===")
model_ewc = SimpleCNN().to(device)
opt = torch.optim.Adam(model_ewc.parameters(), lr=1e-3)
train_task(model_ewc, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_ewc, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")
# 計算 Fisher 矩陣
ewc = EWC(model_ewc, loader_a, device)
train_task(model_ewc, loader_b, opt, n_epochs, ewc=ewc, ewc_lambda=400)
acc_a_after_b = evaluate(model_ewc, test_loader_a)
acc_b_after_b = evaluate(model_ewc, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['EWC'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)
# 策略 3: 經驗重播
print("\n=== Strategy 3: Experience Replay (buffer=200) ===")
model_er = SimpleCNN().to(device)
opt = torch.optim.Adam(model_er.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=200)
train_task(model_er, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_er, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")
# 將任務 A 的樣本加入緩衝
buffer.add_from_loader(loader_a, n_samples=200)
train_task(model_er, loader_b, opt, n_epochs, replay_buffer=buffer)
acc_a_after_b = evaluate(model_er, test_loader_a)
acc_b_after_b = evaluate(model_er, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Replay'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)
# --- 7. 視覺化結果 ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 左圖: 任務 A 準確率(遺忘程度)
strategies = list(results.keys())
acc_before = [results[s][0] for s in strategies]
acc_after = [results[s][1] for s in strategies]
x_pos = np.arange(len(strategies))
width = 0.35
bars1 = axes[0].bar(x_pos - width/2, acc_before, width, label='After Task A', color='#0077b6')
bars2 = axes[0].bar(x_pos + width/2, acc_after, width, label='After Task B', color='#e63946')
axes[0].set_ylabel('Task A Accuracy')
axes[0].set_title('Catastrophic Forgetting: Task A Performance', fontsize=13)
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(strategies)
axes[0].legend()
axes[0].set_ylim(0, 1.05)
axes[0].grid(True, alpha=0.3, axis='y')
for bar in bars1:
axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)
for bar in bars2:
axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)
# 右圖: 遺忘量
forgetting = [results[s][0] - results[s][1] for s in strategies]
colors = ['#e63946' if f > 0.3 else '#b8922e' if f > 0.1 else '#2a9d8f' for f in forgetting]
bars = axes[1].bar(strategies, forgetting, color=colors, edgecolor='white', linewidth=1.5)
axes[1].set_ylabel('Forgetting (↓ better)')
axes[1].set_title('Amount of Forgetting on Task A', fontsize=13)
axes[1].grid(True, alpha=0.3, axis='y')
for bar, f in zip(bars, forgetting):
axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
f'{f:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()
# --- 8. 逐類準確率分析 ---
print("\n=== Per-class Accuracy After Task B ===")
print(f"{'Class':<8} {'Naive':>8} {'EWC':>8} {'Replay':>8}")
print("-" * 36)
models = {'Naive': model_naive, 'EWC': model_ewc, 'Replay': model_er}
for digit in range(10):
test_digit = filter_by_labels(test_data, [digit])
loader_digit = torch.utils.data.DataLoader(test_digit, batch_size=256)
accs = []
for name in ['Naive', 'EWC', 'Replay']:
acc = evaluate(models[name], loader_digit)
accs.append(acc)
marker = " ← Task A" if digit < 5 else ""
print(f" {digit:<6} {accs[0]:>8.1%} {accs[1]:>8.1%} {accs[2]:>8.1%}{marker}")
print("\nLab 1 Complete!")
八、Hands-on Lab 2:BERT 持續學習多任務文本分類(Google Colab)
以下實驗展示語言模型的災難性遺忘:在兩個文本分類任務上依序微調 BERT,觀察遺忘現象,並以經驗重播進行緩解。
# ============================================================
# Lab 2: BERT 持續學習 — 多任務文本分類的災難性遺忘與緩解
# 環境: Google Colab (GPU 推薦, CPU 可用但較慢)
# ============================================================
# --- 0. 安裝 ---
!pip install -q transformers datasets torch
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# --- 1. 載入兩個文本分類任務 ---
print("\n--- Loading Datasets ---")
# 任務 A: SST-2 情感分類(正/負)
sst2 = load_dataset("glue", "sst2")
print(f"Task A (SST-2): {len(sst2['train'])} train, {len(sst2['validation'])} val")
# 任務 B: MRPC 語義等價判斷(等價/不等價)
mrpc = load_dataset("glue", "mrpc")
print(f"Task B (MRPC): {len(mrpc['train'])} train, {len(mrpc['validation'])} val")
# --- 2. Tokenizer ---
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def tokenize_sst2(examples):
return tokenizer(examples['sentence'], truncation=True, padding='max_length', max_length=64)
def tokenize_mrpc(examples):
return tokenizer(examples['sentence1'], examples['sentence2'],
truncation=True, padding='max_length', max_length=128)
sst2_tok = sst2.map(tokenize_sst2, batched=True)
mrpc_tok = mrpc.map(tokenize_mrpc, batched=True)
for ds in [sst2_tok, mrpc_tok]:
ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])
# 使用子集(Colab 友善)
sst2_train = sst2_tok["train"].shuffle(seed=42).select(range(1000))
sst2_val = sst2_tok["validation"]
mrpc_train = mrpc_tok["train"].shuffle(seed=42).select(range(800))
mrpc_val = mrpc_tok["validation"]
# --- 3. 訓練與評估工具 ---
def make_loader(dataset, batch_size=16, shuffle=True):
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
def evaluate_task(model, loader):
model.eval()
correct, total = 0, 0
with torch.no_grad():
for batch in loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
preds = outputs.logits.argmax(dim=-1)
correct += (preds == labels).sum().item()
total += labels.size(0)
return correct / total
def train_epoch(model, loader, optimizer, replay_data=None, replay_ratio=0.3):
model.train()
total_loss = 0
for batch in loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
# 經驗重播
if replay_data is not None and len(replay_data) > 0:
n_replay = max(1, int(input_ids.size(0) * replay_ratio))
indices = random.sample(range(len(replay_data)), min(n_replay, len(replay_data)))
r_ids = torch.stack([replay_data[i]['input_ids'] for i in indices]).to(device)
r_mask = torch.stack([replay_data[i]['attention_mask'] for i in indices]).to(device)
r_labels = torch.tensor([replay_data[i]['label'] for i in indices]).to(device)
r_outputs = model(input_ids=r_ids, attention_mask=r_mask, labels=r_labels)
loss = loss + r_outputs.loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
return total_loss / len(loader)
# --- 4. 建立重播記憶 ---
def create_replay_buffer(dataset, n_samples=100):
"""從資料集中取樣建立重播緩衝"""
indices = random.sample(range(len(dataset)), min(n_samples, len(dataset)))
buffer = []
for i in indices:
item = dataset[i]
buffer.append({
'input_ids': item['input_ids'],
'attention_mask': item['attention_mask'],
'label': item['label'].item() if isinstance(item['label'], torch.Tensor) else item['label']
})
return buffer
# --- 5. 實驗 1: 樸素持續微調(展示遺忘)---
print("\n" + "="*60)
print("Experiment 1: Naive Sequential Fine-tuning")
print("="*60)
model_naive = BertForSequenceClassification.from_pretrained(
'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_naive.parameters(), lr=2e-5, weight_decay=0.01)
# 訓練任務 A (SST-2)
print("\n--- Training on Task A (SST-2) ---")
loader_a = make_loader(sst2_train)
val_loader_a = make_loader(sst2_val, shuffle=False)
naive_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}
for epoch in range(3):
loss = train_epoch(model_naive, loader_a, opt)
acc = evaluate_task(model_naive, val_loader_a)
naive_history['task_a_on_a'].append(acc)
print(f" Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")
acc_a_before = naive_history['task_a_on_a'][-1]
# 訓練任務 B (MRPC)
print("\n--- Training on Task B (MRPC) ---")
loader_b = make_loader(mrpc_train)
val_loader_b = make_loader(mrpc_val, shuffle=False)
for epoch in range(3):
loss = train_epoch(model_naive, loader_b, opt)
acc_a = evaluate_task(model_naive, val_loader_a)
acc_b = evaluate_task(model_naive, val_loader_b)
naive_history['task_a_on_b'].append(acc_a)
naive_history['task_b_on_b'].append(acc_b)
print(f" Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")
acc_a_after_naive = naive_history['task_a_on_b'][-1]
acc_b_naive = naive_history['task_b_on_b'][-1]
# --- 6. 實驗 2: 經驗重播 ---
print("\n" + "="*60)
print("Experiment 2: Experience Replay (buffer=100)")
print("="*60)
model_replay = BertForSequenceClassification.from_pretrained(
'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_replay.parameters(), lr=2e-5, weight_decay=0.01)
# 訓練任務 A
print("\n--- Training on Task A (SST-2) ---")
replay_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}
for epoch in range(3):
loss = train_epoch(model_replay, loader_a, opt)
acc = evaluate_task(model_replay, val_loader_a)
replay_history['task_a_on_a'].append(acc)
print(f" Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")
# 建立重播緩衝
replay_buffer = create_replay_buffer(sst2_train, n_samples=100)
print(f"\nReplay buffer: {len(replay_buffer)} samples from Task A")
# 訓練任務 B + 重播
print("\n--- Training on Task B (MRPC) with Replay ---")
for epoch in range(3):
loss = train_epoch(model_replay, loader_b, opt, replay_data=replay_buffer)
acc_a = evaluate_task(model_replay, val_loader_a)
acc_b = evaluate_task(model_replay, val_loader_b)
replay_history['task_a_on_b'].append(acc_a)
replay_history['task_b_on_b'].append(acc_b)
print(f" Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")
acc_a_after_replay = replay_history['task_a_on_b'][-1]
acc_b_replay = replay_history['task_b_on_b'][-1]
# --- 7. 視覺化比較 ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# 左圖: Task A 準確率隨訓練變化
epochs_a = list(range(1, 4))
epochs_b = list(range(4, 7))
all_epochs = epochs_a + epochs_b
naive_a_curve = naive_history['task_a_on_a'] + naive_history['task_a_on_b']
replay_a_curve = replay_history['task_a_on_a'] + replay_history['task_a_on_b']
axes[0].plot(all_epochs, naive_a_curve, 'o-', color='#e63946', linewidth=2, label='Naive')
axes[0].plot(all_epochs, replay_a_curve, 's-', color='#0077b6', linewidth=2, label='Replay')
axes[0].axvline(x=3.5, color='gray', linestyle='--', alpha=0.5)
axes[0].text(2, 0.55, 'Task A\n(SST-2)', ha='center', fontsize=10, color='gray')
axes[0].text(5, 0.55, 'Task B\n(MRPC)', ha='center', fontsize=10, color='gray')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('SST-2 Accuracy')
axes[0].set_title('Task A (SST-2) Accuracy Over Time', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim(0.5, 1.0)
# 中圖: 最終結果對比
categories = ['Task A\n(SST-2)', 'Task B\n(MRPC)']
naive_scores = [acc_a_after_naive, acc_b_naive]
replay_scores = [acc_a_after_replay, acc_b_replay]
x = np.arange(len(categories))
width = 0.35
axes[1].bar(x - width/2, naive_scores, width, label='Naive', color='#e63946', alpha=0.85)
axes[1].bar(x + width/2, replay_scores, width, label='Replay', color='#0077b6', alpha=0.85)
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Final Performance Comparison', fontsize=13)
axes[1].set_xticks(x)
axes[1].set_xticklabels(categories)
axes[1].legend()
axes[1].set_ylim(0.5, 1.0)
axes[1].grid(True, alpha=0.3, axis='y')
# 右圖: 遺忘量
forgetting_naive = acc_a_before - acc_a_after_naive
forgetting_replay = replay_history['task_a_on_a'][-1] - acc_a_after_replay
bars = axes[2].bar(['Naive', 'Replay'], [forgetting_naive, forgetting_replay],
color=['#e63946', '#0077b6'], edgecolor='white', linewidth=1.5)
axes[2].set_ylabel('Forgetting (↓ better)')
axes[2].set_title('SST-2 Forgetting After MRPC Training', fontsize=13)
axes[2].grid(True, alpha=0.3, axis='y')
for bar, f in zip(bars, [forgetting_naive, forgetting_replay]):
axes[2].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.005,
f'{f:.3f}', ha='center', va='bottom', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()
# --- 8. 推論示範 ---
print("\n=== Inference Demo ===")
test_sentences = [
("This movie is absolutely wonderful!", "SST-2 (Positive)"),
("A terrible waste of time and money.", "SST-2 (Negative)"),
("The film was average, nothing special.", "SST-2 (Neutral-ish)"),
]
print("\n--- Naive Model ---")
model_naive.eval()
for text, label in test_sentences:
inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
with torch.no_grad():
logits = model_naive(**inputs).logits
pred = "Positive" if logits.argmax().item() == 1 else "Negative"
conf = torch.softmax(logits, dim=-1).max().item()
print(f" [{pred} {conf:.1%}] {text} (expected: {label})")
print("\n--- Replay Model ---")
model_replay.eval()
for text, label in test_sentences:
inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
with torch.no_grad():
logits = model_replay(**inputs).logits
pred = "Positive" if logits.argmax().item() == 1 else "Negative"
conf = torch.softmax(logits, dim=-1).max().item()
print(f" [{pred} {conf:.1%}] {text} (expected: {label})")
print("\nLab 2 Complete!")
九、決策框架:企業如何選擇終身學習策略
根據資料可用性、隱私約束與計算預算,企業可以用以下框架選擇適合的終身學習方案:
| 條件 | 推薦方法 | 理由 |
|---|---|---|
| 可存儲舊資料、記憶體充足 | 經驗重播(ER / DER++)[9][14] | 最簡單有效,200 個樣本/任務即可大幅降低遺忘 |
| 隱私約束,不能存舊資料 | EWC[3] + LwF[5] | 僅需存 Fisher 矩陣或模型快照,無需原始資料 |
| 任務數量多且持續增長 | PackNet[7] 或 HAT[8] | 在固定模型容量中支持多任務,不需額外存儲 |
| 任務數量少但要求零遺忘 | Progressive Networks[6] | 完全隔離,零遺忘保證,適合關鍵任務場景 |
| 語言模型持續微調 | 經驗重播 + 學習率調整[15] | 對 Transformer 架構最有效,EWC 在 NLP 中效果有限 |
| 隱私約束 + 充足計算 | 生成式重播(GR)[10] | 生成虛擬舊資料,兼顧隱私與防遺忘 |
決策樹:
1. 是否可以存儲舊任務的真實資料?
├── 是 → 經驗重播(首選 DER++)
└── 否 → 2
2. 模型容量是否可以增長?
├── 是 → Progressive Networks(零遺忘)
└── 否 → 3
3. 計算預算是否充足?
├── 是 → 生成式重播(GAN/VAE 生成虛擬資料)
└── 否 → EWC + LwF(僅需存 Fisher 矩陣 + 舊模型快照)
十、結語與展望
災難性遺忘[1]是深度學習通往真正人工智慧的核心障礙之一。一個無法持續學習的系統——無論多麼強大——都只是靜態的工具,而非進化的智慧體。
回顧核心脈絡:
- 問題本質:穩定性-可塑性困境[2]是連結主義模型的根本挑戰,梯度下降對共享參數的無差別更新是遺忘的直接原因
- 正則化路線:EWC[3] 和 SI[4] 以參數重要性為指導保護舊知識,LwF[5] 則以知識蒸餾保護輸出分布
- 架構路線:Progressive Networks[6] 和 PackNet[7] 以結構隔離換取零遺忘保證
- 重播路線:經驗重播[9]以極小的記憶代價(每個任務幾百個樣本)實現最佳的防遺忘效果,DER++[14] 進一步融合暗知識蒸餾
- 語言模型場景:大型預訓練模型的終身學習[16]是當前最迫切的研究方向,經驗重播是目前最有效的方案
展望未來,終身學習正在從學術研究走向工程實踐。隨著 Foundation Model 的普及[17],企業需要讓模型持續適應新資料、新任務和新領域——而不是每次都從頭預訓練。稀疏混合專家(MoE)模型天然地為不同任務分配不同的專家子網路,提供了架構層面的終身學習潛力;而參數高效微調(LoRA、Adapter)透過凍結主幹、只訓練小型模組,為每個任務提供了輕量級的專屬適配——這本質上就是一種終身學習策略。當 AI 系統學會像人類一樣持續學習而不遺忘,我們離真正的通用智慧又近了一步。