- Catastrophic forgetting[1] is a fundamental flaw in deep learning: when a model learns a new task, it abruptly loses much of its knowledge of old tasks, preventing AI from learning continuously the way humans do
- Three major strategy categories for continual learning: Regularization methods (EWC[3], SI[4]) protect important parameters, Architecture methods (Progressive Networks[6]) dynamically expand networks, Replay methods (ER[9], GR[10]) revisit old task samples
- Large language models also face catastrophic forgetting[15] — continually fine-tuning BERT/GPT rapidly erases prior task capabilities, and experience replay is currently the most effective mitigation strategy
- This article includes two Google Colab hands-on labs: EWC preventing MNIST image classification forgetting, and BERT continual learning for multi-task text classification, both executable directly in the browser
1. Why AI Forgets: The Nature of Catastrophic Forgetting
Humans can continuously learn new skills without forgetting old knowledge — after learning to ride a bicycle, learning to swim does not cause you to forget how to ride. However, deep neural networks face a fundamental problem: Catastrophic Forgetting[1][2].
When a network already trained on Task A is subsequently trained on Task B, performance on Task A drops precipitously — not a gradual degradation, but a catastrophic collapse. This occurs because gradient descent indiscriminately updates all parameters to optimize the current task, overwriting weights that were critical for the old task.
The Nature of Catastrophic Forgetting:
Task A training complete: θ* = argmin_θ L_A(θ)
→ Parameters converge to A's optimal solution
Task B sequential training: θ** = argmin_θ L_B(θ), starting from θ*
→ Parameters move away from A's optimum, A's performance collapses
Root cause: Stability-Plasticity Dilemma
- Too stable → Cannot learn new tasks (underfitting)
- Too plastic → Forgets old tasks (catastrophic forgetting)
- Goal: Find a balance between the two
Concrete example (image classification):
Step 1: Train on digits 0-4 → Accuracy 98%
Step 2: Train on digits 5-9 → 5-9 accuracy 97%, but 0-4 drops to ~20%
Reason: Gradients for 5-9 destroy the key weights that distinguished 0-4
Catastrophic forgetting is not limited to image classification. It is equally severe during continual fine-tuning of language models[15]: a BERT fine-tuned for sentiment classification, when subsequently fine-tuned for named entity recognition, sees its sentiment classification capability severely degraded. This problem has become even more critical in the era of large language models — we want models to continuously learn from new data without retraining from scratch each time.
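The collapse described above can be reproduced in a few lines. The following toy sketch uses a deliberately degenerate setup (a single shared linear layer fit to two conflicting regression targets, both illustrative choices, not a benchmark): training on Task B drives the one shared weight away from Task A's optimum, and Task A's loss explodes.

```python
# Toy sketch of catastrophic forgetting: one shared linear layer,
# Task A fits y = 2x, Task B fits y = -3x (illustrative targets).
import torch

torch.manual_seed(0)
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y_a, y_b = 2 * x, -3 * x          # two conflicting tasks

model = torch.nn.Linear(1, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

def fit(y, steps=200):
    for _ in range(steps):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()

fit(y_a)                           # train on Task A
loss_a_before = torch.nn.functional.mse_loss(model(x), y_a).item()
fit(y_b)                           # sequential training on Task B
loss_a_after = torch.nn.functional.mse_loss(model(x), y_a).item()
print(f"Task A loss: {loss_a_before:.6f} -> {loss_a_after:.4f}")
```

Because the tasks share a single weight, gradient descent on Task B necessarily overwrites what Task A needed — the same mechanism, at scale, that destroys the 0-4 classifier above.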
2. The Continual Learning Landscape: Three Major Strategy Categories
The research goal of Continual Learning (also known as Lifelong Learning) is to enable models to learn multiple tasks sequentially while both mastering new tasks and maintaining performance on old ones[12][13]. Based on different solution approaches, methods can be categorized into three main classes[17]:
| Strategy Category | Core Idea | Representative Methods | Advantages | Limitations |
|---|---|---|---|---|
| Regularization Methods | Add penalty terms to the loss function to constrain changes in important parameters | EWC[3], SI[4], LwF[5] | No need to store old data, fixed memory footprint | Protection capability degrades as task count increases |
| Architecture Methods | Allocate different network structures or sub-networks for different tasks | Progressive Nets[6], PackNet[7], HAT[8] | Zero forgetting (hard isolation) | Model size grows with task count |
| Replay Methods | Store or generate old task samples and co-train while learning new tasks | ER[9], GR[10], GEM[11], DER++[14] | Simple, effective, composable with other methods | Requires additional memory for storing old samples |
Continual learning methods are also evaluated under three scenarios of increasing difficulty:
Three Continual Learning Scenarios:
1. Task-Incremental Learning (Task-IL):
Task identity is known at inference time → Easiest
Example: "This is Task B data, use B's classification head"
2. Domain-Incremental Learning (Domain-IL):
Same task structure, but data distribution changes → Medium difficulty
Example: Same 10-class classification, but image style changes from sketches to photos
3. Class-Incremental Learning (Class-IL):
Task identity is unknown at inference, must distinguish among all learned classes → Hardest
Example: Learn 0-4 first, then learn 5-9, at test time must distinguish all digits 0-9
Difficulty ranking: Task-IL < Domain-IL < Class-IL
In practical applications, Class-IL most closely matches real-world needs
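The practical difference between Task-IL and Class-IL comes down to which logits are considered at inference. A minimal sketch (the logit values and the 0-4 / 5-9 class split are made up for illustration): a model biased toward the recently learned Task B fails under Class-IL but still succeeds under Task-IL, where the known task identity masks the prediction to the right classes.

```python
# Task-IL vs Class-IL at evaluation time, on hypothetical 10-way logits
# from a model trained on digits 0-4 (Task A) then 5-9 (Task B).
import torch

logits = torch.tensor([[1.2, 0.3, 2.0, 0.1, 0.4,    # Task A classes 0-4
                        3.5, 0.2, 0.9, 0.0, 0.6]])  # Task B classes 5-9
true_label = 2  # a Task A sample

# Class-IL: argmax over all 10 classes -> biased toward recent Task B
pred_class_il = logits.argmax(dim=1).item()

# Task-IL: task identity known, restrict to Task A's classes 0-4
pred_task_il = logits[:, 0:5].argmax(dim=1).item()

print(pred_class_il, pred_task_il)  # 5 (wrong), 2 (correct)
```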
3. Regularization Methods: EWC and Knowledge Distillation
Elastic Weight Consolidation (EWC)
EWC[3] is the most influential regularization method in continual learning, inspired by synaptic consolidation in neuroscience — important synaptic connections should be protected, while less important ones can be freely updated.
The core question is: how to measure each parameter's "importance" to old tasks? EWC's answer is the Fisher Information Matrix:
EWC Loss Function:
L_total(θ) = L_B(θ) + (λ/2) Σ_i F_i (θ_i - θ*_A,i)²
Where:
L_B(θ): Loss on new task B
θ*_A: Optimal parameters after training on old task A
F_i: Diagonal elements of Fisher Information Matrix (importance of parameter i to task A)
λ: Regularization strength (controls stability-plasticity balance)
Fisher Information Matrix (diagonal approximation):
F_i = E_{x~D_A} [(∂ log p(y|x,θ) / ∂θ_i)²]
Intuitive understanding:
F_i large → Parameter i is very important to task A → Strongly constrain its changes
F_i small → Parameter i is not important to task A → Free to update for learning task B
Geometric perspective:
There exists a "low-loss valley" around task A's optimum θ*_A
The Fisher matrix describes the shape of this valley
EWC guides task B's optimization to move along the valley's extension
→ Find parameters that work well for both A and B
Synaptic Intelligence (SI)
SI[4] is an online alternative to EWC. While EWC needs to compute the Fisher matrix after each task, SI accumulates each parameter's importance in real-time during training — tracking each parameter's contribution to loss reduction along the "path traveled" during training.
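SI's online bookkeeping can be sketched in a few lines: under SGD each parameter moves by Δθ = -lr·g, so the per-step contribution -g·Δθ = lr·g² accumulates how much each parameter helped reduce the loss along its training path. This is a simplified single-tensor sketch; the stand-in quadratic loss and the names `omega` and `xi` are illustrative, not the paper's exact implementation.

```python
# Sketch of SI's path-integral importance tracking for one tensor.
import torch

torch.manual_seed(0)
theta = torch.randn(5, requires_grad=True)
opt = torch.optim.SGD([theta], lr=0.1)
theta_start = theta.detach().clone()
omega = torch.zeros(5)                 # running loss-reduction credit

for _ in range(100):
    loss = ((theta - torch.arange(5.0)) ** 2).sum()  # stand-in task loss
    opt.zero_grad()
    loss.backward()
    grad = theta.grad.detach().clone()
    old = theta.detach().clone()
    opt.step()
    # credit = -g * delta_theta  (nonnegative under plain SGD: lr * g^2)
    omega += -grad * (theta.detach() - old)

xi = 1e-3                              # damping to avoid division by zero
importance = omega / ((theta.detach() - theta_start) ** 2 + xi)
print(importance)
```

After each task, `importance` plays the role of EWC's Fisher diagonal in the same quadratic penalty — but no separate post-task pass over the data is needed.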
Learning without Forgetting (LwF)
LwF[5] takes a different approach — instead of protecting parameters, it protects outputs. Before learning a new task, it first passes new task data through the old model to obtain "soft labels," then during new task training, it simultaneously uses a knowledge distillation loss to keep the old task's output distribution unchanged. Its greatest advantage is that it requires no storage of old task data whatsoever.
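LwF's distillation term can be sketched as follows. This is a minimal illustration with stand-in linear models; the temperature `T = 2` and the `old_model` / `new_model` names are assumptions for the sketch, not values fixed by the paper.

```python
# Sketch of LwF's knowledge-distillation loss on new-task inputs.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 16)                 # new-task inputs
old_model = torch.nn.Linear(16, 10)    # frozen snapshot before new task
new_model = torch.nn.Linear(16, 10)    # model being fine-tuned

T = 2.0                                # softening temperature
with torch.no_grad():                  # soft labels from the old model
    soft_targets = F.softmax(old_model(x) / T, dim=1)

log_probs = F.log_softmax(new_model(x) / T, dim=1)
distill_loss = F.kl_div(log_probs, soft_targets,
                        reduction='batchmean') * T * T

# total loss would be: new-task cross-entropy + lambda * distill_loss
print(distill_loss.item())
```

Note that the soft labels are computed on *new-task* data — this is why LwF needs no stored old-task samples at all.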
4. Architecture Methods: Progressive Networks and Dynamic Expansion
The philosophy of architecture methods is: rather than struggling to balance new and old tasks within a limited parameter space, allocate dedicated network capacity for each task.
Progressive Neural Networks
Progressive Networks, proposed by Rusu et al.[6], take the most straightforward approach: for each new task, a new network column "grows" alongside the existing ones, with lateral connections letting the new column reuse features learned for old tasks:
Progressive Neural Networks:
Task 1: [Column 1] ← Normal training
Task 2: [Column 1] (frozen) ──lateral connections──→ [Column 2] ← Only this is trained
Task 3: [Column 1] (frozen) ──┐
        [Column 2] (frozen) ──┴──lateral connections──→ [Column 3] ← Only this is trained
Pros: Absolute zero forgetting (old columns are frozen)
Cons: Parameter count grows linearly (T tasks = T× parameters)
PackNet and HAT
PackNet[7] and HAT[8] attempt to achieve multi-task learning within a fixed-size network:
| Method | Strategy | Mechanism | Characteristics |
|---|---|---|---|
| PackNet[7] | Iterative pruning | Train → prune unimportant weights → free capacity for next task | Each task gets a dedicated sparse sub-network |
| HAT[8] | Hard attention masks | Learn binary masks for each task, protecting occupied neurons | Masks are gradient-optimizable, automatically allocating capacity |
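PackNet's pruning step can be sketched for a single weight tensor: after training a task, keep only the largest-magnitude weights as that task's dedicated sub-network, and free the rest for later tasks. The keep ratio and tensor shape below are illustrative choices.

```python
# Sketch of PackNet-style magnitude pruning for one weight tensor.
import torch

torch.manual_seed(0)
w = torch.randn(4, 4)                  # weights after training task 1
keep_ratio = 0.5                       # fraction reserved for task 1

k = int(w.numel() * keep_ratio)
threshold = w.abs().flatten().topk(k).values.min()
task1_mask = (w.abs() >= threshold)    # task 1's frozen sub-network

w_pruned = w * task1_mask              # freed weights get retrained for task 2
print(task1_mask.sum().item(), "weights kept for task 1")
```

After a brief retraining of the surviving weights, the masked-out half of the tensor is re-initialized and trained on the next task, so each task ends up owning a disjoint sparse sub-network within the fixed architecture.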
5. Replay Methods: Memory Buffers and Generative Replay
Experience Replay methods draw inspiration from memory consolidation in cognitive science — humans "replay" daytime experiences during sleep to consolidate memories. In continual learning, replay methods mix in old task samples while learning new tasks[9].
Experience Replay (ER)
The most straightforward approach: maintain a fixed-size memory buffer storing a small number of representative samples from each old task. When learning a new task, each mini-batch mixes new task data with samples drawn from the buffer:
Experience Replay Flow:
Memory Buffer M (fixed size, e.g., 200 samples)
Learning task t:
for each mini-batch:
batch_new = sample(D_t) # New task data
batch_old = sample(M) # Sample old data from buffer
loss = L(batch_new) + L(batch_old) # Joint loss
Update θ
After task completion:
Add representative samples from D_t to M (using reservoir sampling or herding)
Reservoir Sampling:
With probability |M| / n, the n-th stream sample replaces a uniformly chosen buffer entry,
ensuring every sample seen so far has the same probability |M| / n of being in the buffer
Key Finding (Rolnick et al., 2019):
Just 1-5 samples/class can significantly reduce forgetting
→ Minimal memory cost yields substantial anti-forgetting effects
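The reservoir sampling rule described above can be sketched directly (a minimal illustration; `reservoir_update` is a hypothetical helper name, not from any library):

```python
# Sketch of reservoir sampling: a fixed-size buffer that stays a
# uniform random sample of an unbounded stream.
import random

def reservoir_update(buffer, item, n_seen, capacity):
    """Maintain a uniform sample over the first n_seen stream items."""
    if len(buffer) < capacity:
        buffer.append(item)                        # fill phase
    elif random.random() < capacity / n_seen:      # prob |M| / n
        buffer[random.randrange(capacity)] = item  # evict a random slot

random.seed(0)
buffer = []
for n in range(1, 1001):                           # stream of 1000 samples
    reservoir_update(buffer, n, n, capacity=20)
print(len(buffer), "samples retained")
```

Because the buffer is updated per sample rather than per task, this works even when task boundaries are unknown — a key practical advantage over herding-style selection.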
Generative Replay (GR)
Shin et al.[10] proposed an elegant alternative: instead of storing real old data, train a generative model (such as a GAN or VAE) to generate virtual samples of old tasks. This is particularly valuable in privacy-sensitive scenarios — medical data cannot be stored, but its distribution can be reconstructed using a generative model.
GEM and DER++
GEM[11] (Gradient Episodic Memory) uses samples in memory to compute gradient constraints: new task gradient updates must not increase old task loss on memory samples. DER++[14] combines experience replay with knowledge distillation — replaying not only old data labels but also the old model's soft outputs (logits), preserving richer information in the form of "dark knowledge."
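The DER++ objective can be sketched as three terms: new-task cross-entropy, an MSE match to the logits stored alongside each buffered sample, and cross-entropy on the buffered labels. The weights `alpha` / `beta` and the stand-in linear model below are illustrative, not tuned values.

```python
# Sketch of the DER++ loss: label replay + logit ("dark knowledge") replay.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(8, 5)

new_x, new_y = torch.randn(4, 8), torch.randint(0, 5, (4,))
buf_x, buf_y = torch.randn(4, 8), torch.randint(0, 5, (4,))
buf_logits = torch.randn(4, 5)   # logits saved when buf_x entered the buffer

alpha, beta = 0.5, 0.5           # illustrative trade-off weights
buf_out = model(buf_x)
loss = (F.cross_entropy(model(new_x), new_y)          # new task
        + alpha * F.mse_loss(buf_out, buf_logits)     # logit replay
        + beta * F.cross_entropy(buf_out, buf_y))     # label replay
loss.backward()
print(loss.item())
```

Matching stored logits rather than only hard labels preserves the old model's full output distribution, which is why DER++ typically forgets less than plain ER at the same buffer size.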
6. Continual Learning for Text AI: Continual Fine-Tuning of Language Models
Continual learning for large language models is at the forefront of current research[16]. When enterprises want BERT or GPT to continuously adapt to new tasks or domains, catastrophic forgetting severely impacts existing capabilities:
Continual Learning Scenarios for Language Models:
1. Continual Task Fine-tuning:
BERT → Sentiment Analysis → NER → QA → Text Summarization
Problem: Later fine-tuning destroys earlier task capabilities
2. Continual Domain Adaptation:
General LLM → Financial domain → Legal domain → Medical domain
Problem: New domain knowledge overwrites old domain expertise
3. Continual Pre-training:
Foundation model → Continuously absorbing new documents/knowledge
Problem: New knowledge may damage foundational language understanding
Unique Challenges of Language Model Forgetting:
- Extremely high parameter sharing (all tasks share the same Transformer)
- More severe representation space interference (high semantic overlap)
- Task heads can be separated, but underlying representations are hard to isolate
Scialom et al.'s research[15] demonstrates that experience replay is currently the most effective method for language model continual learning — mixing in small amounts of old task samples when learning new tasks significantly reduces forgetting. This is more effective than regularization methods like EWC in NLP scenarios because the parameter importance distribution for language tasks is more uniform, limiting the discriminative power of regularization constraints.
7. Hands-on Lab 1: EWC Preventing MNIST Image Classification Forgetting (Google Colab)
The following experiment compares three strategies on Split MNIST: (1) Naive fine-tuning, (2) EWC regularization, and (3) Experience Replay (ER), visually demonstrating the catastrophic forgetting phenomenon and its mitigation.
# ============================================================
# Lab 1: Continual Learning — EWC vs Experience Replay vs Naive Fine-tuning (Split MNIST)
# Environment: Google Colab (CPU is sufficient)
# ============================================================
# --- 0. Installation ---
!pip install -q torch torchvision matplotlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import copy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# --- 1. Data Preparation: Split MNIST ---
# Task A: Digits 0-4, Task B: Digits 5-9
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_data = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.MNIST('./data', train=False, transform=transform)
def filter_by_labels(dataset, labels):
"""Filter data by specific labels"""
mask = torch.zeros(len(dataset.targets), dtype=torch.bool)
for l in labels:
mask |= (dataset.targets == l)
indices = mask.nonzero(as_tuple=True)[0]
return torch.utils.data.Subset(dataset, indices)
task_a_labels = [0, 1, 2, 3, 4]
task_b_labels = [5, 6, 7, 8, 9]
train_a = filter_by_labels(train_data, task_a_labels)
train_b = filter_by_labels(train_data, task_b_labels)
test_a = filter_by_labels(test_data, task_a_labels)
test_b = filter_by_labels(test_data, task_b_labels)
loader_a = torch.utils.data.DataLoader(train_a, batch_size=128, shuffle=True)
loader_b = torch.utils.data.DataLoader(train_b, batch_size=128, shuffle=True)
test_loader_a = torch.utils.data.DataLoader(test_a, batch_size=256)
test_loader_b = torch.utils.data.DataLoader(test_b, batch_size=256)
print(f"Task A (digits 0-4): {len(train_a)} train, {len(test_a)} test")
print(f"Task B (digits 5-9): {len(train_b)} train, {len(test_b)} test")
# --- 2. Simple CNN Model ---
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.pool = nn.MaxPool2d(2)
self.fc1 = nn.Linear(64 * 7 * 7, 256)
self.fc2 = nn.Linear(256, 10) # All 10 classes share the output
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
return self.fc2(x)
def evaluate(model, loader):
model.eval()
correct, total = 0, 0
with torch.no_grad():
for x, y in loader:
x, y = x.to(device), y.to(device)
pred = model(x).argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
return correct / total
# --- 3. EWC Implementation ---
class EWC:
def __init__(self, model, dataloader, device, n_samples=200):
self.params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
self.fisher = self._compute_fisher(model, dataloader, device, n_samples)
def _compute_fisher(self, model, dataloader, device, n_samples):
"""Compute Fisher Information Matrix (diagonal approximation)"""
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
model.eval()
count = 0
for x, y in dataloader:
if count >= n_samples:
break
x, y = x.to(device), y.to(device)
model.zero_grad()
output = model(x)
loss = F.cross_entropy(output, y)
loss.backward()
for n, p in model.named_parameters():
if p.requires_grad and p.grad is not None:
fisher[n] += p.grad.data.pow(2) * x.size(0)
count += x.size(0)
fisher = {n: f / count for n, f in fisher.items()}
return fisher
def penalty(self, model):
"""EWC regularization term"""
loss = 0
for n, p in model.named_parameters():
if p.requires_grad and n in self.fisher:
loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
return loss
# --- 4. Experience Replay Memory Buffer ---
class ReplayBuffer:
def __init__(self, capacity=200):
self.capacity = capacity
self.buffer_x = []
self.buffer_y = []
def add_from_loader(self, loader, n_samples):
"""Sample randomly from loader and add to buffer"""
all_x, all_y = [], []
for x, y in loader:
all_x.append(x)
all_y.append(y)
all_x = torch.cat(all_x)
all_y = torch.cat(all_y)
indices = torch.randperm(len(all_x))[:n_samples]
self.buffer_x.append(all_x[indices])
self.buffer_y.append(all_y[indices])
def sample(self, batch_size):
all_x = torch.cat(self.buffer_x)
all_y = torch.cat(self.buffer_y)
indices = torch.randperm(len(all_x))[:batch_size]
return all_x[indices], all_y[indices]
# --- 5. Training Function ---
def train_task(model, loader, optimizer, epochs, ewc=None, ewc_lambda=0,
replay_buffer=None, replay_batch=32):
model.train()
for epoch in range(epochs):
total_loss = 0
for x, y in loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
output = model(x)
loss = F.cross_entropy(output, y)
# EWC regularization
if ewc is not None:
loss += ewc_lambda * ewc.penalty(model)
# Experience replay
if replay_buffer is not None and len(replay_buffer.buffer_x) > 0:
rx, ry = replay_buffer.sample(replay_batch)
rx, ry = rx.to(device), ry.to(device)
r_output = model(rx)
loss += F.cross_entropy(r_output, ry)
loss.backward()
optimizer.step()
total_loss += loss.item()
# --- 6. Experiment: Comparing Three Strategies ---
n_epochs = 5
results = {}
# Strategy 1: Naive Fine-tuning
print("\n=== Strategy 1: Naive Fine-tuning ===")
model_naive = SimpleCNN().to(device)
opt = torch.optim.Adam(model_naive.parameters(), lr=1e-3)
train_task(model_naive, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_naive, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")
train_task(model_naive, loader_b, opt, n_epochs)
acc_a_after_b = evaluate(model_naive, test_loader_a)
acc_b_after_b = evaluate(model_naive, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Naive'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)
# Strategy 2: EWC
print("\n=== Strategy 2: EWC (λ=400) ===")
model_ewc = SimpleCNN().to(device)
opt = torch.optim.Adam(model_ewc.parameters(), lr=1e-3)
train_task(model_ewc, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_ewc, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")
# Compute Fisher matrix
ewc = EWC(model_ewc, loader_a, device)
train_task(model_ewc, loader_b, opt, n_epochs, ewc=ewc, ewc_lambda=400)
acc_a_after_b = evaluate(model_ewc, test_loader_a)
acc_b_after_b = evaluate(model_ewc, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['EWC'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)
# Strategy 3: Experience Replay
print("\n=== Strategy 3: Experience Replay (buffer=200) ===")
model_er = SimpleCNN().to(device)
opt = torch.optim.Adam(model_er.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=200)
train_task(model_er, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_er, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")
# Add Task A samples to buffer
buffer.add_from_loader(loader_a, n_samples=200)
train_task(model_er, loader_b, opt, n_epochs, replay_buffer=buffer)
acc_a_after_b = evaluate(model_er, test_loader_a)
acc_b_after_b = evaluate(model_er, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Replay'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)
# --- 7. Visualization ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Left plot: Task A accuracy (forgetting extent)
strategies = list(results.keys())
acc_before = [results[s][0] for s in strategies]
acc_after = [results[s][1] for s in strategies]
x_pos = np.arange(len(strategies))
width = 0.35
bars1 = axes[0].bar(x_pos - width/2, acc_before, width, label='After Task A', color='#0077b6')
bars2 = axes[0].bar(x_pos + width/2, acc_after, width, label='After Task B', color='#e63946')
axes[0].set_ylabel('Task A Accuracy')
axes[0].set_title('Catastrophic Forgetting: Task A Performance', fontsize=13)
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(strategies)
axes[0].legend()
axes[0].set_ylim(0, 1.05)
axes[0].grid(True, alpha=0.3, axis='y')
for bar in bars1:
axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)
for bar in bars2:
axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)
# Right plot: Forgetting amount
forgetting = [results[s][0] - results[s][1] for s in strategies]
colors = ['#e63946' if f > 0.3 else '#b8922e' if f > 0.1 else '#2a9d8f' for f in forgetting]
bars = axes[1].bar(strategies, forgetting, color=colors, edgecolor='white', linewidth=1.5)
axes[1].set_ylabel('Forgetting (↓ better)')
axes[1].set_title('Amount of Forgetting on Task A', fontsize=13)
axes[1].grid(True, alpha=0.3, axis='y')
for bar, f in zip(bars, forgetting):
axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
f'{f:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()
# --- 8. Per-Class Accuracy Analysis ---
print("\n=== Per-class Accuracy After Task B ===")
print(f"{'Class':<8} {'Naive':>8} {'EWC':>8} {'Replay':>8}")
print("-" * 36)
models = {'Naive': model_naive, 'EWC': model_ewc, 'Replay': model_er}
for digit in range(10):
test_digit = filter_by_labels(test_data, [digit])
loader_digit = torch.utils.data.DataLoader(test_digit, batch_size=256)
accs = []
for name in ['Naive', 'EWC', 'Replay']:
acc = evaluate(models[name], loader_digit)
accs.append(acc)
marker = " ← Task A" if digit < 5 else ""
print(f" {digit:<6} {accs[0]:>8.1%} {accs[1]:>8.1%} {accs[2]:>8.1%}{marker}")
print("\nLab 1 Complete!")
8. Hands-on Lab 2: BERT Continual Learning for Multi-Task Text Classification (Google Colab)
The following experiment demonstrates catastrophic forgetting in language models: sequentially fine-tuning BERT on two text classification tasks, observing the forgetting phenomenon, and mitigating it with experience replay.
# ============================================================
# Lab 2: BERT Continual Learning — Catastrophic Forgetting and Mitigation in Multi-Task Text Classification
# Environment: Google Colab (GPU recommended, CPU works but slower)
# ============================================================
# --- 0. Installation ---
!pip install -q transformers datasets torch
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW  # AdamW was removed from recent transformers releases; use torch's implementation
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# --- 1. Load Two Text Classification Tasks ---
print("\n--- Loading Datasets ---")
# Task A: SST-2 Sentiment Classification (Positive/Negative)
sst2 = load_dataset("glue", "sst2")
print(f"Task A (SST-2): {len(sst2['train'])} train, {len(sst2['validation'])} val")
# Task B: MRPC Semantic Equivalence (Equivalent/Not Equivalent)
mrpc = load_dataset("glue", "mrpc")
print(f"Task B (MRPC): {len(mrpc['train'])} train, {len(mrpc['validation'])} val")
# --- 2. Tokenizer ---
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def tokenize_sst2(examples):
return tokenizer(examples['sentence'], truncation=True, padding='max_length', max_length=64)
def tokenize_mrpc(examples):
return tokenizer(examples['sentence1'], examples['sentence2'],
truncation=True, padding='max_length', max_length=128)
sst2_tok = sst2.map(tokenize_sst2, batched=True)
mrpc_tok = mrpc.map(tokenize_mrpc, batched=True)
for ds in [sst2_tok, mrpc_tok]:
ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])
# Use subsets (Colab-friendly)
sst2_train = sst2_tok["train"].shuffle(seed=42).select(range(1000))
sst2_val = sst2_tok["validation"]
mrpc_train = mrpc_tok["train"].shuffle(seed=42).select(range(800))
mrpc_val = mrpc_tok["validation"]
# --- 3. Training and Evaluation Utilities ---
def make_loader(dataset, batch_size=16, shuffle=True):
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
def evaluate_task(model, loader):
model.eval()
correct, total = 0, 0
with torch.no_grad():
for batch in loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
preds = outputs.logits.argmax(dim=-1)
correct += (preds == labels).sum().item()
total += labels.size(0)
return correct / total
def train_epoch(model, loader, optimizer, replay_data=None, replay_ratio=0.3):
model.train()
total_loss = 0
for batch in loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
# Experience replay
if replay_data is not None and len(replay_data) > 0:
n_replay = max(1, int(input_ids.size(0) * replay_ratio))
indices = random.sample(range(len(replay_data)), min(n_replay, len(replay_data)))
r_ids = torch.stack([replay_data[i]['input_ids'] for i in indices]).to(device)
r_mask = torch.stack([replay_data[i]['attention_mask'] for i in indices]).to(device)
r_labels = torch.tensor([replay_data[i]['label'] for i in indices]).to(device)
r_outputs = model(input_ids=r_ids, attention_mask=r_mask, labels=r_labels)
loss = loss + r_outputs.loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
return total_loss / len(loader)
# --- 4. Create Replay Memory ---
def create_replay_buffer(dataset, n_samples=100):
"""Sample from dataset to create replay buffer"""
indices = random.sample(range(len(dataset)), min(n_samples, len(dataset)))
buffer = []
for i in indices:
item = dataset[i]
buffer.append({
'input_ids': item['input_ids'],
'attention_mask': item['attention_mask'],
'label': item['label'].item() if isinstance(item['label'], torch.Tensor) else item['label']
})
return buffer
# --- 5. Experiment 1: Naive Sequential Fine-tuning (Demonstrating Forgetting) ---
print("\n" + "="*60)
print("Experiment 1: Naive Sequential Fine-tuning")
print("="*60)
model_naive = BertForSequenceClassification.from_pretrained(
'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_naive.parameters(), lr=2e-5, weight_decay=0.01)
# Train Task A (SST-2)
print("\n--- Training on Task A (SST-2) ---")
loader_a = make_loader(sst2_train)
val_loader_a = make_loader(sst2_val, shuffle=False)
naive_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}
for epoch in range(3):
loss = train_epoch(model_naive, loader_a, opt)
acc = evaluate_task(model_naive, val_loader_a)
naive_history['task_a_on_a'].append(acc)
print(f" Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")
acc_a_before = naive_history['task_a_on_a'][-1]
# Train Task B (MRPC)
print("\n--- Training on Task B (MRPC) ---")
loader_b = make_loader(mrpc_train)
val_loader_b = make_loader(mrpc_val, shuffle=False)
for epoch in range(3):
loss = train_epoch(model_naive, loader_b, opt)
acc_a = evaluate_task(model_naive, val_loader_a)
acc_b = evaluate_task(model_naive, val_loader_b)
naive_history['task_a_on_b'].append(acc_a)
naive_history['task_b_on_b'].append(acc_b)
print(f" Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")
acc_a_after_naive = naive_history['task_a_on_b'][-1]
acc_b_naive = naive_history['task_b_on_b'][-1]
# --- 6. Experiment 2: Experience Replay ---
print("\n" + "="*60)
print("Experiment 2: Experience Replay (buffer=100)")
print("="*60)
model_replay = BertForSequenceClassification.from_pretrained(
'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_replay.parameters(), lr=2e-5, weight_decay=0.01)
# Train Task A
print("\n--- Training on Task A (SST-2) ---")
replay_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}
for epoch in range(3):
loss = train_epoch(model_replay, loader_a, opt)
acc = evaluate_task(model_replay, val_loader_a)
replay_history['task_a_on_a'].append(acc)
print(f" Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")
# Create replay buffer
replay_buffer = create_replay_buffer(sst2_train, n_samples=100)
print(f"\nReplay buffer: {len(replay_buffer)} samples from Task A")
# Train Task B + Replay
print("\n--- Training on Task B (MRPC) with Replay ---")
for epoch in range(3):
loss = train_epoch(model_replay, loader_b, opt, replay_data=replay_buffer)
acc_a = evaluate_task(model_replay, val_loader_a)
acc_b = evaluate_task(model_replay, val_loader_b)
replay_history['task_a_on_b'].append(acc_a)
replay_history['task_b_on_b'].append(acc_b)
print(f" Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")
acc_a_after_replay = replay_history['task_a_on_b'][-1]
acc_b_replay = replay_history['task_b_on_b'][-1]
# --- 7. Visualization Comparison ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# Left plot: Task A accuracy over training
epochs_a = list(range(1, 4))
epochs_b = list(range(4, 7))
all_epochs = epochs_a + epochs_b
naive_a_curve = naive_history['task_a_on_a'] + naive_history['task_a_on_b']
replay_a_curve = replay_history['task_a_on_a'] + replay_history['task_a_on_b']
axes[0].plot(all_epochs, naive_a_curve, 'o-', color='#e63946', linewidth=2, label='Naive')
axes[0].plot(all_epochs, replay_a_curve, 's-', color='#0077b6', linewidth=2, label='Replay')
axes[0].axvline(x=3.5, color='gray', linestyle='--', alpha=0.5)
axes[0].text(2, 0.55, 'Task A\n(SST-2)', ha='center', fontsize=10, color='gray')
axes[0].text(5, 0.55, 'Task B\n(MRPC)', ha='center', fontsize=10, color='gray')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('SST-2 Accuracy')
axes[0].set_title('Task A (SST-2) Accuracy Over Time', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim(0.5, 1.0)
# Middle plot: Final results comparison
categories = ['Task A\n(SST-2)', 'Task B\n(MRPC)']
naive_scores = [acc_a_after_naive, acc_b_naive]
replay_scores = [acc_a_after_replay, acc_b_replay]
x = np.arange(len(categories))
width = 0.35
axes[1].bar(x - width/2, naive_scores, width, label='Naive', color='#e63946', alpha=0.85)
axes[1].bar(x + width/2, replay_scores, width, label='Replay', color='#0077b6', alpha=0.85)
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Final Performance Comparison', fontsize=13)
axes[1].set_xticks(x)
axes[1].set_xticklabels(categories)
axes[1].legend()
axes[1].set_ylim(0.5, 1.0)
axes[1].grid(True, alpha=0.3, axis='y')
# Right plot: Forgetting amount
forgetting_naive = acc_a_before - acc_a_after_naive
forgetting_replay = replay_history['task_a_on_a'][-1] - acc_a_after_replay
bars = axes[2].bar(['Naive', 'Replay'], [forgetting_naive, forgetting_replay],
color=['#e63946', '#0077b6'], edgecolor='white', linewidth=1.5)
axes[2].set_ylabel('Forgetting (↓ better)')
axes[2].set_title('SST-2 Forgetting After MRPC Training', fontsize=13)
axes[2].grid(True, alpha=0.3, axis='y')
for bar, f in zip(bars, [forgetting_naive, forgetting_replay]):
axes[2].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.005,
f'{f:.3f}', ha='center', va='bottom', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()
# --- 8. Inference Demo ---
print("\n=== Inference Demo ===")
test_sentences = [
("This movie is absolutely wonderful!", "SST-2 (Positive)"),
("A terrible waste of time and money.", "SST-2 (Negative)"),
("The film was average, nothing special.", "SST-2 (Neutral-ish)"),
]
print("\n--- Naive Model ---")
model_naive.eval()
for text, label in test_sentences:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
    with torch.no_grad():
        logits = model_naive(**inputs).logits
    pred = "Positive" if logits.argmax().item() == 1 else "Negative"
    conf = torch.softmax(logits, dim=-1).max().item()
    print(f" [{pred} {conf:.1%}] {text} (expected: {label})")
print("\n--- Replay Model ---")
model_replay.eval()
for text, label in test_sentences:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
    with torch.no_grad():
        logits = model_replay(**inputs).logits
    pred = "Positive" if logits.argmax().item() == 1 else "Negative"
    conf = torch.softmax(logits, dim=-1).max().item()
    print(f" [{pred} {conf:.1%}] {text} (expected: {label})")
print("\nLab 2 Complete!")
9. Decision Framework: How Enterprises Should Choose Continual Learning Strategies
Based on data availability, privacy constraints, and computational budgets, enterprises can use the following framework to select the appropriate continual learning approach:
| Condition | Recommended Method | Rationale |
|---|---|---|
| Old data can be stored, sufficient memory | Experience Replay (ER / DER++)[9][14] | Simplest and most effective; 200 samples/task can significantly reduce forgetting |
| Privacy constraints, cannot store old data | EWC[3] + LwF[5] | Only need to store Fisher matrix or model snapshots, no raw data required |
| Many tasks and continuously growing | PackNet[7] or HAT[8] | Support multi-task within fixed model capacity, no extra storage needed |
| Few tasks but zero forgetting required | Progressive Networks[6] | Complete isolation, zero-forgetting guarantee, suitable for mission-critical scenarios |
| Language model continual fine-tuning | Experience Replay + learning rate scheduling[15] | Most effective for Transformer architectures; EWC has limited effect in NLP |
| Privacy constraints + sufficient compute | Generative Replay (GR)[10] | Generates virtual old data, balancing privacy and anti-forgetting |
Decision Tree:
1. Can real data from old tasks be stored?
├── Yes → Experience Replay (prefer DER++)
└── No → 2
2. Can model capacity grow?
├── Yes → Progressive Networks (zero forgetting)
└── No → 3
3. Is the compute budget sufficient?
├── Yes → Generative Replay (GAN/VAE generates virtual data)
└── No → EWC + LwF (only need Fisher matrix + old model snapshot)
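The three questions in the decision tree can be encoded as a tiny helper. This is an illustrative sketch only; the function name, argument names, and return strings are hypothetical, not from any library:

```python
def choose_cl_strategy(can_store_data: bool,
                       can_grow_capacity: bool,
                       compute_sufficient: bool) -> str:
    """Map the three decision-tree questions to a recommended method.

    Questions are asked in the same order as the tree above:
    1) Can real old-task data be stored?
    2) Can model capacity grow?
    3) Is the compute budget sufficient?
    """
    if can_store_data:
        return "Experience Replay (prefer DER++)"
    if can_grow_capacity:
        return "Progressive Networks"
    if compute_sufficient:
        return "Generative Replay"
    return "EWC + LwF"


# Example: privacy-constrained deployment with fixed capacity and tight compute
print(choose_cl_strategy(False, False, False))  # EWC + LwF
```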
10. Conclusion and Outlook
Catastrophic forgetting[1] is one of the core obstacles on deep learning's path toward true artificial intelligence. A system that cannot learn continuously — no matter how powerful — is merely a static tool, not an evolving intelligent entity.
Reviewing the key themes:
- Nature of the problem: The stability-plasticity dilemma[2] is a fundamental challenge for connectionist models, and indiscriminate gradient updates to shared parameters are the direct cause of forgetting
- Regularization approach: EWC[3] and SI[4] protect old knowledge guided by parameter importance, while LwF[5] protects output distributions through knowledge distillation
- Architecture approach: Progressive Networks[6] and PackNet[7] trade structural isolation for zero-forgetting guarantees
- Replay approach: Experience Replay[9] achieves the best anti-forgetting effect with minimal memory cost (a few hundred samples per task), and DER++[14] further integrates dark knowledge distillation
- Language model scenarios: Continual learning for large pre-trained models[16] is the most urgent current research direction, with experience replay being the most effective approach
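The DER++ objective mentioned above combines the current-task loss with two replay terms: a logit-matching MSE against stored "dark knowledge" logits and a cross-entropy on buffer labels. A minimal sketch of that loss (argument names and the alpha/beta defaults are illustrative, not canonical):

```python
import torch
import torch.nn.functional as F

def derpp_loss(logits_new, labels_new,
               logits_buf, stored_logits, labels_buf,
               alpha=0.5, beta=0.5):
    """DER++-style objective (sketch).

    logits_new / labels_new: model outputs and targets for the current batch
    logits_buf: model outputs on samples replayed from the buffer
    stored_logits: logits saved when those samples were first seen
    labels_buf: ground-truth labels of the buffer samples
    """
    ce_new = F.cross_entropy(logits_new, labels_new)        # current task
    distill = F.mse_loss(logits_buf, stored_logits)         # alpha: dark knowledge
    ce_buf = F.cross_entropy(logits_buf, labels_buf)        # beta: buffer labels
    return ce_new + alpha * distill + beta * ce_buf
```

In practice the buffer also stores the logits at insertion time, which is what distinguishes DER++ from plain experience replay.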
Looking ahead, continual learning is transitioning from academic research to engineering practice. As Foundation Models become widespread[17], enterprises need models that continuously adapt to new data, tasks, and domains rather than being retrained from scratch each time. Mixture-of-Experts (MoE) models with sparse dynamic computation naturally route different tasks to different expert sub-networks, offering architectural potential for continual learning; parameter-efficient fine-tuning (LoRA, Adapters) freezes the backbone and trains only small task-specific modules, which is essentially a continual learning strategy in itself. When AI systems learn to keep learning without forgetting, as humans do, we move another step closer to true general intelligence.
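The "freeze the backbone, train only a small module" idea behind LoRA can be sketched in a few lines of PyTorch. This is a minimal illustration, not the reference implementation; the class name and hyperparameter defaults are hypothetical:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: the pretrained weight is frozen and only a
    low-rank update (B @ A) is trained, one such module per task."""
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # backbone stays fixed across tasks
        # A is small random, B starts at zero so the initial output
        # exactly matches the frozen backbone
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)
```

Because each task gets its own low-rank module while the shared backbone never changes, old-task behavior is preserved by construction, which is exactly the continual learning property described above.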