Key Findings
  • Catastrophic forgetting[1] is a fundamental flaw in deep learning — when a model learns a new task, it drastically forgets old task knowledge, preventing AI from learning continuously like humans
  • Three major strategy categories for continual learning: Regularization methods (EWC[3], SI[4]) protect important parameters, Architecture methods (Progressive Networks[6]) dynamically expand networks, Replay methods (ER[9], GR[10]) revisit old task samples
  • Large language models also face catastrophic forgetting[15] — continually fine-tuning BERT/GPT rapidly erases prior task capabilities, and experience replay is currently the most effective mitigation strategy
  • This article includes two Google Colab hands-on labs: EWC preventing MNIST image classification forgetting, and BERT continual learning for multi-task text classification, both executable directly in the browser

1. Why AI Forgets: The Nature of Catastrophic Forgetting

Humans can continuously learn new skills without forgetting old knowledge — after learning to ride a bicycle, learning to swim does not cause you to forget how to ride. However, deep neural networks face a fundamental problem: Catastrophic Forgetting[1][2].

When a network already trained on Task A is subsequently trained on Task B, performance on Task A drops precipitously — not a gradual degradation, but a catastrophic collapse. This occurs because gradient descent indiscriminately updates all parameters to optimize the current task, overwriting weights that were critical for the old task.

The Nature of Catastrophic Forgetting:

Task A training complete: θ* = argmin_θ L_A(θ)
                     → Parameters converge to A's optimal solution

Task B sequential training: θ** = argmin_θ L_B(θ), starting from θ*
                     → Parameters move away from A's optimum, A's performance collapses

Root cause: Stability-Plasticity Dilemma
  - Too stable → Cannot learn new tasks (underfitting)
  - Too plastic → Forgets old tasks (catastrophic forgetting)
  - Goal: Find a balance between the two

Concrete example (image classification):
  Step 1: Train on digits 0-4 → Accuracy 98%
  Step 2: Train on digits 5-9 → 5-9 accuracy 97%, but 0-4 drops to ~20%
  Reason: Gradients for 5-9 destroy the key weights that distinguished 0-4

Catastrophic forgetting is not limited to image classification. It is equally severe during continual fine-tuning of language models[15]: a BERT fine-tuned for sentiment classification, when subsequently fine-tuned for named entity recognition, sees its sentiment classification capability severely degraded. This problem has become even more critical in the era of large language models — we want models to continuously learn from new data without retraining from scratch each time.

2. The Continual Learning Landscape: Three Major Strategy Categories

The research goal of Continual Learning (also known as Lifelong Learning) is to enable models to learn multiple tasks sequentially while both mastering new tasks and maintaining performance on old ones[12][13]. Based on different solution approaches, methods can be categorized into three main classes[17]:

| Strategy Category | Core Idea | Representative Methods | Advantages | Limitations |
|---|---|---|---|---|
| Regularization Methods | Add penalty terms to the loss function to constrain changes in important parameters | EWC[3], SI[4], LwF[5] | No need to store old data; fixed memory footprint | Protection capability degrades as task count increases |
| Architecture Methods | Allocate different network structures or sub-networks for different tasks | Progressive Nets[6], PackNet[7], HAT[8] | Zero forgetting (hard isolation) | Model size grows with task count |
| Replay Methods | Store or generate old task samples and co-train while learning new tasks | ER[9], GR[10], GEM[11], DER++[14] | Simple, effective, composable with other methods | Requires additional memory for storing old samples |

Continual learning is also evaluated under three scenarios of increasing difficulty:

Three Continual Learning Scenarios:

1. Task-Incremental Learning (Task-IL):
   Task identity is known at inference time → Easiest
   Example: "This is Task B data, use B's classification head"

2. Domain-Incremental Learning (Domain-IL):
   Same task structure, but data distribution changes → Medium difficulty
   Example: Same 10-class classification, but image style changes from sketches to photos

3. Class-Incremental Learning (Class-IL):
   Task identity is unknown at inference, must distinguish among all learned classes → Hardest
   Example: Learn 0-4 first, then learn 5-9, at test time must distinguish all digits 0-9

Difficulty ranking: Task-IL < Domain-IL < Class-IL
In practical applications, Class-IL most closely matches real-world needs
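The practical gap between the scenarios shows up at prediction time: under Task-IL the model may restrict its output to the known task's classes, while under Class-IL it must choose among every class seen so far. A minimal sketch over the digit split above (function names are illustrative):

```python
import torch

# Two tasks over MNIST digits, matching the example above (illustrative split)
task_classes = {0: [0, 1, 2, 3, 4],   # Task A
                1: [5, 6, 7, 8, 9]}   # Task B

def predict_task_il(logits, task_id):
    """Task-IL: task identity is given, so restrict the argmax to that
    task's classes by masking out all other logits."""
    mask = torch.full_like(logits, float('-inf'))
    mask[:, task_classes[task_id]] = 0.0
    return (logits + mask).argmax(dim=1)

def predict_class_il(logits):
    """Class-IL: no task identity; choose among every class learned so far."""
    return logits.argmax(dim=1)

logits = torch.randn(4, 10)
print(predict_task_il(logits, 0))   # predictions are always digits 0-4
print(predict_class_il(logits))     # predictions may be any digit 0-9
```

The same network can therefore look much stronger under Task-IL evaluation than under Class-IL, which is one reason the scenario must be reported alongside any continual learning result.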

3. Regularization Methods: EWC and Knowledge Distillation

Elastic Weight Consolidation (EWC)

EWC[3] is the most influential regularization method in continual learning, inspired by synaptic consolidation in neuroscience — important synaptic connections should be protected, while less important ones can be freely updated.

The core question is: how to measure each parameter's "importance" to old tasks? EWC's answer is the Fisher Information Matrix:

EWC Loss Function:

L_total(θ) = L_B(θ) + (λ/2) Σ_i F_i (θ_i - θ*_A,i)²

Where:
  L_B(θ):     Loss on new task B
  θ*_A:       Optimal parameters after training on old task A
  F_i:        Diagonal elements of Fisher Information Matrix (importance of parameter i to task A)
  λ:          Regularization strength (controls stability-plasticity balance)

Fisher Information Matrix (diagonal approximation):
  F_i = E_{x~D_A} [(∂ log p(y|x,θ) / ∂θ_i)²]

Intuitive understanding:
  F_i large → Parameter i is very important to task A → Strongly constrain its changes
  F_i small → Parameter i is not important to task A → Free to update for learning task B

Geometric perspective:
  There exists a "low-loss valley" around task A's optimum θ*_A
  The Fisher matrix describes the shape of this valley
  EWC guides task B's optimization to move along the valley's extension
  → Find parameters that work well for both A and B

Synaptic Intelligence (SI)

SI[4] is an online alternative to EWC. While EWC needs to compute the Fisher matrix after each task, SI accumulates each parameter's importance in real-time during training — tracking each parameter's contribution to loss reduction along the "path traveled" during training.
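The "path traveled" bookkeeping amounts to accumulating, at every optimizer step, minus the gradient times the actual parameter update. A self-contained sketch on a toy quadratic loss (variable names follow the paper's ω and ξ; in SI the resulting importances then enter an EWC-style quadratic penalty when training the next task):

```python
import torch

# Toy quadratic loss so the sketch is self-contained; any model/loss
# would use the same per-step bookkeeping.
theta = torch.tensor([1.0, -2.0], requires_grad=True)
theta_start = theta.detach().clone()
omega = torch.zeros_like(theta)   # running path integral of loss decrease
xi = 1e-3                         # damping term, as in the SI paper
lr = 0.1

for _ in range(50):
    loss = ((theta - torch.tensor([3.0, 0.5])) ** 2).sum()
    grad, = torch.autograd.grad(loss, theta)
    with torch.no_grad():
        delta = -lr * grad        # the actual parameter update this step
        omega += -grad * delta    # each parameter's contribution to loss reduction
        theta += delta

# Per-parameter importance for the finished task (used in the SI penalty)
importance = omega / ((theta.detach() - theta_start) ** 2 + xi)
```

Because ω is accumulated during ordinary training, SI adds almost no extra cost, whereas EWC requires a separate pass over the task's data to estimate the Fisher matrix.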

Learning without Forgetting (LwF)

LwF[5] takes a different approach — instead of protecting parameters, it protects outputs. Before learning a new task, it first passes new task data through the old model to obtain "soft labels," then during new task training, it simultaneously uses a knowledge distillation loss to keep the old task's output distribution unchanged. Its greatest advantage is that it requires no storage of old task data whatsoever.
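The LwF objective is just cross-entropy on the new labels plus a temperature-scaled distillation term toward the old model's soft labels. A minimal sketch (T, alpha, and the function name are illustrative; `old_logits` come from a frozen copy of the pre-update model applied to the same new-task inputs):

```python
import torch
import torch.nn.functional as F

def lwf_loss(new_logits, old_logits, labels, T=2.0, alpha=1.0):
    """Cross-entropy on the new task plus knowledge distillation toward
    the old model's temperature-softened output distribution."""
    ce = F.cross_entropy(new_logits, labels)
    kd = F.kl_div(F.log_softmax(new_logits / T, dim=1),
                  F.softmax(old_logits / T, dim=1),
                  reduction='batchmean') * (T * T)  # T^2 rescales gradients
    return ce + alpha * kd
```

When the current model's outputs match the old model's exactly, the distillation term vanishes and only the new-task loss remains; alpha trades off stability against plasticity, much like λ in EWC.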

4. Architecture Methods: Progressive Networks and Dynamic Expansion

The philosophy of architecture methods is: rather than struggling to balance new and old tasks within a limited parameter space, allocate dedicated network capacity for each task.

Progressive Neural Networks

Progressive Networks, proposed by Rusu et al.[6], take the most straightforward approach: for each new task, a new network column "grows" alongside the existing ones, with lateral connections that let the new column reuse features learned on old tasks:

Progressive Neural Networks:

Task 1:  [Column 1] ← Normal training
Task 2:  [Column 1] (frozen) ── lateral connections ──→ [Column 2] ← only this is trained
Task 3:  [Column 1] (frozen) ──┐
         [Column 2] (frozen) ──┴── lateral connections ──→ [Column 3] ← only this is trained

Pros: Absolute zero forgetting (old columns are frozen)
Cons: Parameter count grows linearly (T tasks = T× parameters)
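The column-growing scheme can be sketched with a small MLP column; real Progressive Nets use deeper ConvNets with adapter layers on the lateral connections. Everything here (class name, dimensions) is illustrative:

```python
import torch
import torch.nn as nn

class ProgressiveColumn(nn.Module):
    """One column with lateral connections from frozen earlier columns."""
    def __init__(self, in_dim, hidden, out_dim, prev_columns=()):
        super().__init__()
        # Held in a plain list on purpose: old columns are frozen and kept
        # out of this column's parameters, so an optimizer built from
        # new_column.parameters() trains only the new column.
        self.prev = list(prev_columns)
        for col in self.prev:
            for p in col.parameters():
                p.requires_grad_(False)
        self.fc1 = nn.Linear(in_dim, hidden)
        # One lateral adapter per earlier column, consuming its hidden features
        self.lateral = nn.ModuleList(nn.Linear(hidden, hidden) for _ in self.prev)
        self.fc2 = nn.Linear(hidden, out_dim)

    def hidden_out(self, x):
        h = torch.relu(self.fc1(x))
        for col, lat in zip(self.prev, self.lateral):
            h = h + lat(col.hidden_out(x))   # reuse frozen old-task features
        return h

    def forward(self, x):
        return self.fc2(self.hidden_out(x))

col1 = ProgressiveColumn(784, 64, 5)                        # Task 1
col2 = ProgressiveColumn(784, 64, 5, prev_columns=[col1])   # Task 2: train col2 only
```

Because old columns are frozen, Task 1 performance is preserved exactly; the price is the linear parameter growth noted above.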

PackNet and HAT

PackNet[7] and HAT[8] attempt to achieve multi-task learning within a fixed-size network:

| Method | Strategy | Mechanism | Characteristics |
|---|---|---|---|
| PackNet[7] | Iterative pruning | Train → prune unimportant weights → free capacity for next task | Each task gets a dedicated sparse sub-network |
| HAT[8] | Hard attention masks | Learn binary masks for each task, protecting occupied neurons | Masks are gradient-optimizable, automatically allocating capacity |
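PackNet's core step, keeping only the strongest weights for the current task and leaving the rest free, reduces to a magnitude-based mask. A per-tensor sketch (the real method prunes layer-wise and retrains after pruning; names and the keep ratio are illustrative):

```python
import torch

def packnet_masks(weight, free_mask, keep_ratio=0.5):
    """Assign the largest-magnitude currently-free weights to this task;
    everything else remains free capacity for later tasks."""
    scores = weight.abs() * free_mask.float()    # only unassigned weights compete
    k = int(free_mask.sum().item() * keep_ratio)
    threshold = scores.flatten().topk(k).values.min()
    task_mask = (scores >= threshold) & free_mask
    return task_mask, free_mask & ~task_mask     # (this task's weights, still free)

free = torch.ones(2, 5, dtype=torch.bool)        # fresh layer: all capacity free
w = torch.arange(1.0, 11.0).reshape(2, 5)
task1_mask, free = packnet_masks(w, free)        # task 1 claims the top half
```

At inference, applying a task's mask (and all earlier tasks' masks it was trained on top of) recovers that task's sub-network untouched, which is why PackNet achieves zero forgetting within its fixed capacity.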

5. Replay Methods: Memory Buffers and Generative Replay

Experience Replay methods draw inspiration from memory consolidation in cognitive science — humans "replay" daytime experiences during sleep to consolidate memories. In continual learning, replay methods mix in old task samples while learning new tasks[9].

Experience Replay (ER)

The most straightforward approach: maintain a fixed-size memory buffer storing a small number of representative samples from each old task. When learning a new task, each mini-batch mixes new task data with samples drawn from the buffer:

Experience Replay Flow:

Memory Buffer M (fixed size, e.g., 200 samples)

Learning task t:
  for each mini-batch:
    batch_new = sample(D_t)           # New task data
    batch_old = sample(M)             # Sample old data from buffer
    loss = L(batch_new) + L(batch_old)  # Joint loss
    Update θ

  After task completion:
    Add representative samples from D_t to M (using reservoir sampling or herding)

Reservoir Sampling:
  Replace a uniformly chosen buffer slot with the n-th sample with probability |M| / n,
  so every sample seen so far remains in the buffer with equal probability |M| / n

Key Finding (Rolnick et al., 2019):
  Just 1-5 samples/class can significantly reduce forgetting
  → Minimal memory cost yields substantial anti-forgetting effects
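The reservoir rule described above fits in a few lines: it maintains a uniform random subset of the stream without knowing the stream's length in advance. A sketch (buffer capacity and stream are illustrative):

```python
import random

def reservoir_update(buffer, item, n_seen, capacity):
    """Standard reservoir sampling: once the buffer is full, the n-th item
    replaces a uniformly chosen slot with probability capacity / n_seen."""
    if len(buffer) < capacity:
        buffer.append(item)
    else:
        j = random.randrange(n_seen)   # uniform index in [0, n_seen)
        if j < capacity:               # happens with probability capacity / n_seen
            buffer[j] = item

buffer = []
for n, sample in enumerate(range(1000), start=1):
    reservoir_update(buffer, sample, n, capacity=200)
```

For Lab 1 below, this would replace the one-shot random draw in `ReplayBuffer.add_from_loader` when data arrives as a stream rather than a complete loader.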

Generative Replay (GR)

Shin et al.[10] proposed an elegant alternative: instead of storing real old data, train a generative model (such as a GAN or VAE) to generate virtual samples of old tasks. This is particularly valuable in privacy-sensitive scenarios — medical data cannot be stored, but its distribution can be reconstructed using a generative model.

GEM and DER++

GEM[11] (Gradient Episodic Memory) uses samples in memory to compute gradient constraints: new task gradient updates must not increase old task loss on memory samples. DER++[14] combines experience replay with knowledge distillation — replaying not only old data labels but also the old model's soft outputs (logits), preserving richer information in the form of "dark knowledge."
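GEM enforces its constraint by solving a small quadratic program over per-task memory gradients. Its follow-up A-GEM (a simplification not covered above, named here explicitly) reduces this to a single reference gradient averaged over the memory, where the fix becomes a one-line projection. A sketch over flattened gradient vectors (names illustrative):

```python
import torch

def agem_project(g, g_ref):
    """If the proposed update direction conflicts with the memory gradient
    (negative dot product), project it onto the non-conflicting half-space:
    g' = g - (g·g_ref / g_ref·g_ref) g_ref."""
    dot = torch.dot(g, g_ref)
    if dot < 0:
        g = g - (dot / torch.dot(g_ref, g_ref)) * g_ref
    return g
```

After projection the update is guaranteed not to increase (to first order) the loss on memory samples, which is exactly the constraint GEM states per task.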

6. Continual Learning for Text AI: Continual Fine-Tuning of Language Models

Continual learning for large language models is at the forefront of current research[16]. When enterprises want BERT or GPT to continuously adapt to new tasks or domains, catastrophic forgetting severely impacts existing capabilities:

Continual Learning Scenarios for Language Models:

1. Continual Task Fine-tuning:
   BERT → Sentiment Analysis → NER → QA → Text Summarization
   Problem: Later fine-tuning destroys earlier task capabilities

2. Continual Domain Adaptation:
   General LLM → Financial domain → Legal domain → Medical domain
   Problem: New domain knowledge overwrites old domain expertise

3. Continual Pre-training:
   Foundation model → Continuously absorbing new documents/knowledge
   Problem: New knowledge may damage foundational language understanding

Unique Challenges of Language Model Forgetting:
  - Extremely high parameter sharing (all tasks share the same Transformer)
  - More severe representation space interference (high semantic overlap)
  - Task heads can be separated, but underlying representations are hard to isolate

Scialom et al.'s research[15] demonstrates that experience replay is currently the most effective method for language model continual learning — mixing in small amounts of old task samples when learning new tasks significantly reduces forgetting. This is more effective than regularization methods like EWC in NLP scenarios because the parameter importance distribution for language tasks is more uniform, limiting the discriminative power of regularization constraints.

7. Hands-on Lab 1: EWC Preventing MNIST Image Classification Forgetting (Google Colab)

The following experiment compares three strategies on Split MNIST: (1) Naive fine-tuning, (2) EWC regularization, and (3) Experience Replay (ER), visually demonstrating the catastrophic forgetting phenomenon and its mitigation.

# ============================================================
# Lab 1: Continual Learning — EWC vs Experience Replay vs Naive Fine-tuning (Split MNIST)
# Environment: Google Colab (CPU is sufficient)
# ============================================================
# --- 0. Installation ---
!pip install -q torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import copy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Data Preparation: Split MNIST ---
# Task A: Digits 0-4, Task B: Digits 5-9
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_data = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.MNIST('./data', train=False, transform=transform)

def filter_by_labels(dataset, labels):
    """Filter data by specific labels"""
    mask = torch.zeros(len(dataset.targets), dtype=torch.bool)
    for l in labels:
        mask |= (dataset.targets == l)
    indices = mask.nonzero(as_tuple=True)[0]
    return torch.utils.data.Subset(dataset, indices)

task_a_labels = [0, 1, 2, 3, 4]
task_b_labels = [5, 6, 7, 8, 9]

train_a = filter_by_labels(train_data, task_a_labels)
train_b = filter_by_labels(train_data, task_b_labels)
test_a = filter_by_labels(test_data, task_a_labels)
test_b = filter_by_labels(test_data, task_b_labels)

loader_a = torch.utils.data.DataLoader(train_a, batch_size=128, shuffle=True)
loader_b = torch.utils.data.DataLoader(train_b, batch_size=128, shuffle=True)
test_loader_a = torch.utils.data.DataLoader(test_a, batch_size=256)
test_loader_b = torch.utils.data.DataLoader(test_b, batch_size=256)

print(f"Task A (digits 0-4): {len(train_a)} train, {len(test_a)} test")
print(f"Task B (digits 5-9): {len(train_b)} train, {len(test_b)} test")

# --- 2. Simple CNN Model ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 10)  # All 10 classes share the output

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total

# --- 3. EWC Implementation ---
class EWC:
    def __init__(self, model, dataloader, device, n_samples=200):
        self.params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
        self.fisher = self._compute_fisher(model, dataloader, device, n_samples)

    def _compute_fisher(self, model, dataloader, device, n_samples):
        """Compute Fisher Information Matrix (diagonal approximation)"""
        fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
        model.eval()
        count = 0
        for x, y in dataloader:
            if count >= n_samples:
                break
            x, y = x.to(device), y.to(device)
            model.zero_grad()
            output = model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            for n, p in model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    # Batch-level approximation: squared batch-mean gradients
                    # stand in for the per-sample expectation in the diagonal Fisher
                    fisher[n] += p.grad.data.pow(2) * x.size(0)
            count += x.size(0)
        fisher = {n: f / count for n, f in fisher.items()}
        return fisher

    def penalty(self, model):
        """EWC regularization term"""
        loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        return loss

# --- 4. Experience Replay Memory Buffer ---
class ReplayBuffer:
    def __init__(self, capacity=200):
        self.capacity = capacity
        self.buffer_x = []
        self.buffer_y = []

    def add_from_loader(self, loader, n_samples):
        """Sample randomly from loader and add to buffer"""
        all_x, all_y = [], []
        for x, y in loader:
            all_x.append(x)
            all_y.append(y)
        all_x = torch.cat(all_x)
        all_y = torch.cat(all_y)
        indices = torch.randperm(len(all_x))[:n_samples]
        self.buffer_x.append(all_x[indices])
        self.buffer_y.append(all_y[indices])

    def sample(self, batch_size):
        all_x = torch.cat(self.buffer_x)
        all_y = torch.cat(self.buffer_y)
        indices = torch.randperm(len(all_x))[:batch_size]
        return all_x[indices], all_y[indices]

# --- 5. Training Function ---
def train_task(model, loader, optimizer, epochs, ewc=None, ewc_lambda=0,
               replay_buffer=None, replay_batch=32):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            output = model(x)
            loss = F.cross_entropy(output, y)

            # EWC regularization
            if ewc is not None:
                loss += ewc_lambda * ewc.penalty(model)

            # Experience replay
            if replay_buffer is not None and len(replay_buffer.buffer_x) > 0:
                rx, ry = replay_buffer.sample(replay_batch)
                rx, ry = rx.to(device), ry.to(device)
                r_output = model(rx)
                loss += F.cross_entropy(r_output, ry)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

# --- 6. Experiment: Comparing Three Strategies ---
n_epochs = 5
results = {}

# Strategy 1: Naive Fine-tuning
print("\n=== Strategy 1: Naive Fine-tuning ===")
model_naive = SimpleCNN().to(device)
opt = torch.optim.Adam(model_naive.parameters(), lr=1e-3)

train_task(model_naive, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_naive, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

train_task(model_naive, loader_b, opt, n_epochs)
acc_a_after_b = evaluate(model_naive, test_loader_a)
acc_b_after_b = evaluate(model_naive, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Naive'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# Strategy 2: EWC
print("\n=== Strategy 2: EWC (λ=400) ===")
model_ewc = SimpleCNN().to(device)
opt = torch.optim.Adam(model_ewc.parameters(), lr=1e-3)

train_task(model_ewc, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_ewc, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

# Compute Fisher matrix
ewc = EWC(model_ewc, loader_a, device)

train_task(model_ewc, loader_b, opt, n_epochs, ewc=ewc, ewc_lambda=400)
acc_a_after_b = evaluate(model_ewc, test_loader_a)
acc_b_after_b = evaluate(model_ewc, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['EWC'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# Strategy 3: Experience Replay
print("\n=== Strategy 3: Experience Replay (buffer=200) ===")
model_er = SimpleCNN().to(device)
opt = torch.optim.Adam(model_er.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=200)

train_task(model_er, loader_a, opt, n_epochs)
acc_a_after_a = evaluate(model_er, test_loader_a)
print(f"After Task A: Acc_A={acc_a_after_a:.4f}")

# Add Task A samples to buffer
buffer.add_from_loader(loader_a, n_samples=200)

train_task(model_er, loader_b, opt, n_epochs, replay_buffer=buffer)
acc_a_after_b = evaluate(model_er, test_loader_a)
acc_b_after_b = evaluate(model_er, test_loader_b)
print(f"After Task B: Acc_A={acc_a_after_b:.4f}, Acc_B={acc_b_after_b:.4f}")
print(f"Forgetting: {acc_a_after_a - acc_a_after_b:.4f}")
results['Replay'] = (acc_a_after_a, acc_a_after_b, acc_b_after_b)

# --- 7. Visualization ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left plot: Task A accuracy (forgetting extent)
strategies = list(results.keys())
acc_before = [results[s][0] for s in strategies]
acc_after = [results[s][1] for s in strategies]

x_pos = np.arange(len(strategies))
width = 0.35
bars1 = axes[0].bar(x_pos - width/2, acc_before, width, label='After Task A', color='#0077b6')
bars2 = axes[0].bar(x_pos + width/2, acc_after, width, label='After Task B', color='#e63946')

axes[0].set_ylabel('Task A Accuracy')
axes[0].set_title('Catastrophic Forgetting: Task A Performance', fontsize=13)
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(strategies)
axes[0].legend()
axes[0].set_ylim(0, 1.05)
axes[0].grid(True, alpha=0.3, axis='y')

for bar in bars1:
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)
for bar in bars2:
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=10)

# Right plot: Forgetting amount
forgetting = [results[s][0] - results[s][1] for s in strategies]
colors = ['#e63946' if f > 0.3 else '#b8922e' if f > 0.1 else '#2a9d8f' for f in forgetting]
bars = axes[1].bar(strategies, forgetting, color=colors, edgecolor='white', linewidth=1.5)

axes[1].set_ylabel('Forgetting (↓ better)')
axes[1].set_title('Amount of Forgetting on Task A', fontsize=13)
axes[1].grid(True, alpha=0.3, axis='y')

for bar, f in zip(bars, forgetting):
    axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{f:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

# --- 8. Per-Class Accuracy Analysis ---
print("\n=== Per-class Accuracy After Task B ===")
print(f"{'Class':<8} {'Naive':>8} {'EWC':>8} {'Replay':>8}")
print("-" * 36)

models = {'Naive': model_naive, 'EWC': model_ewc, 'Replay': model_er}
for digit in range(10):
    test_digit = filter_by_labels(test_data, [digit])
    loader_digit = torch.utils.data.DataLoader(test_digit, batch_size=256)
    accs = []
    for name in ['Naive', 'EWC', 'Replay']:
        acc = evaluate(models[name], loader_digit)
        accs.append(acc)
    marker = " ← Task A" if digit < 5 else ""
    print(f"  {digit:<6} {accs[0]:>8.1%} {accs[1]:>8.1%} {accs[2]:>8.1%}{marker}")

print("\nLab 1 Complete!")

8. Hands-on Lab 2: BERT Continual Learning for Multi-Task Text Classification (Google Colab)

The following experiment demonstrates catastrophic forgetting in language models: sequentially fine-tuning BERT on two text classification tasks, observing the forgetting phenomenon, and mitigating it with experience replay.

# ============================================================
# Lab 2: BERT Continual Learning — Catastrophic Forgetting and Mitigation in Multi-Task Text Classification
# Environment: Google Colab (GPU recommended, CPU works but slower)
# ============================================================
# --- 0. Installation ---
!pip install -q transformers datasets torch

import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW  # transformers' own AdamW was removed in recent versions
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Load Two Text Classification Tasks ---
print("\n--- Loading Datasets ---")

# Task A: SST-2 Sentiment Classification (Positive/Negative)
sst2 = load_dataset("glue", "sst2")
print(f"Task A (SST-2): {len(sst2['train'])} train, {len(sst2['validation'])} val")

# Task B: MRPC Semantic Equivalence (Equivalent/Not Equivalent)
mrpc = load_dataset("glue", "mrpc")
print(f"Task B (MRPC):  {len(mrpc['train'])} train, {len(mrpc['validation'])} val")

# --- 2. Tokenizer ---
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_sst2(examples):
    return tokenizer(examples['sentence'], truncation=True, padding='max_length', max_length=64)

def tokenize_mrpc(examples):
    return tokenizer(examples['sentence1'], examples['sentence2'],
                     truncation=True, padding='max_length', max_length=128)

sst2_tok = sst2.map(tokenize_sst2, batched=True)
mrpc_tok = mrpc.map(tokenize_mrpc, batched=True)

for ds in [sst2_tok, mrpc_tok]:
    ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# Use subsets (Colab-friendly)
sst2_train = sst2_tok["train"].shuffle(seed=42).select(range(1000))
sst2_val = sst2_tok["validation"]
mrpc_train = mrpc_tok["train"].shuffle(seed=42).select(range(800))
mrpc_val = mrpc_tok["validation"]

# --- 3. Training and Evaluation Utilities ---
def make_loader(dataset, batch_size=16, shuffle=True):
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def evaluate_task(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

def train_epoch(model, loader, optimizer, replay_data=None, replay_ratio=0.3):
    model.train()
    total_loss = 0
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Experience replay
        if replay_data is not None and len(replay_data) > 0:
            n_replay = max(1, int(input_ids.size(0) * replay_ratio))
            indices = random.sample(range(len(replay_data)), min(n_replay, len(replay_data)))

            r_ids = torch.stack([replay_data[i]['input_ids'] for i in indices]).to(device)
            r_mask = torch.stack([replay_data[i]['attention_mask'] for i in indices]).to(device)
            r_labels = torch.tensor([replay_data[i]['label'] for i in indices]).to(device)

            r_outputs = model(input_ids=r_ids, attention_mask=r_mask, labels=r_labels)
            loss = loss + r_outputs.loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)

# --- 4. Create Replay Memory ---
def create_replay_buffer(dataset, n_samples=100):
    """Sample from dataset to create replay buffer"""
    indices = random.sample(range(len(dataset)), min(n_samples, len(dataset)))
    buffer = []
    for i in indices:
        item = dataset[i]
        buffer.append({
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'label': item['label'].item() if isinstance(item['label'], torch.Tensor) else item['label']
        })
    return buffer

# --- 5. Experiment 1: Naive Sequential Fine-tuning (Demonstrating Forgetting) ---
print("\n" + "="*60)
print("Experiment 1: Naive Sequential Fine-tuning")
print("="*60)

model_naive = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_naive.parameters(), lr=2e-5, weight_decay=0.01)

# Train Task A (SST-2)
print("\n--- Training on Task A (SST-2) ---")
loader_a = make_loader(sst2_train)
val_loader_a = make_loader(sst2_val, shuffle=False)

naive_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}

for epoch in range(3):
    loss = train_epoch(model_naive, loader_a, opt)
    acc = evaluate_task(model_naive, val_loader_a)
    naive_history['task_a_on_a'].append(acc)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")

acc_a_before = naive_history['task_a_on_a'][-1]

# Train Task B (MRPC)
print("\n--- Training on Task B (MRPC) ---")
loader_b = make_loader(mrpc_train)
val_loader_b = make_loader(mrpc_val, shuffle=False)

for epoch in range(3):
    loss = train_epoch(model_naive, loader_b, opt)
    acc_a = evaluate_task(model_naive, val_loader_a)
    acc_b = evaluate_task(model_naive, val_loader_b)
    naive_history['task_a_on_b'].append(acc_a)
    naive_history['task_b_on_b'].append(acc_b)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")

acc_a_after_naive = naive_history['task_a_on_b'][-1]
acc_b_naive = naive_history['task_b_on_b'][-1]

# --- 6. Experiment 2: Experience Replay ---
print("\n" + "="*60)
print("Experiment 2: Experience Replay (buffer=100)")
print("="*60)

model_replay = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
).to(device)
opt = AdamW(model_replay.parameters(), lr=2e-5, weight_decay=0.01)

# Train Task A
print("\n--- Training on Task A (SST-2) ---")
replay_history = {'task_a_on_a': [], 'task_a_on_b': [], 'task_b_on_b': []}

for epoch in range(3):
    loss = train_epoch(model_replay, loader_a, opt)
    acc = evaluate_task(model_replay, val_loader_a)
    replay_history['task_a_on_a'].append(acc)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc:.4f}")

# Create replay buffer
replay_buffer = create_replay_buffer(sst2_train, n_samples=100)
print(f"\nReplay buffer: {len(replay_buffer)} samples from Task A")

# Train Task B + Replay
print("\n--- Training on Task B (MRPC) with Replay ---")
for epoch in range(3):
    loss = train_epoch(model_replay, loader_b, opt, replay_data=replay_buffer)
    acc_a = evaluate_task(model_replay, val_loader_a)
    acc_b = evaluate_task(model_replay, val_loader_b)
    replay_history['task_a_on_b'].append(acc_a)
    replay_history['task_b_on_b'].append(acc_b)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, SST-2 acc={acc_a:.4f}, MRPC acc={acc_b:.4f}")

acc_a_after_replay = replay_history['task_a_on_b'][-1]
acc_b_replay = replay_history['task_b_on_b'][-1]

# --- 7. Visualization Comparison ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Left plot: Task A accuracy over training
epochs_a = list(range(1, 4))
epochs_b = list(range(4, 7))
all_epochs = epochs_a + epochs_b

naive_a_curve = naive_history['task_a_on_a'] + naive_history['task_a_on_b']
replay_a_curve = replay_history['task_a_on_a'] + replay_history['task_a_on_b']

axes[0].plot(all_epochs, naive_a_curve, 'o-', color='#e63946', linewidth=2, label='Naive')
axes[0].plot(all_epochs, replay_a_curve, 's-', color='#0077b6', linewidth=2, label='Replay')
axes[0].axvline(x=3.5, color='gray', linestyle='--', alpha=0.5)
axes[0].text(2, 0.55, 'Task A\n(SST-2)', ha='center', fontsize=10, color='gray')
axes[0].text(5, 0.55, 'Task B\n(MRPC)', ha='center', fontsize=10, color='gray')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('SST-2 Accuracy')
axes[0].set_title('Task A (SST-2) Accuracy Over Time', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim(0.5, 1.0)

# Middle plot: Final results comparison
categories = ['Task A\n(SST-2)', 'Task B\n(MRPC)']
naive_scores = [acc_a_after_naive, acc_b_naive]
replay_scores = [acc_a_after_replay, acc_b_replay]

x = np.arange(len(categories))
width = 0.35

axes[1].bar(x - width/2, naive_scores, width, label='Naive', color='#e63946', alpha=0.85)
axes[1].bar(x + width/2, replay_scores, width, label='Replay', color='#0077b6', alpha=0.85)
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Final Performance Comparison', fontsize=13)
axes[1].set_xticks(x)
axes[1].set_xticklabels(categories)
axes[1].legend()
axes[1].set_ylim(0.5, 1.0)
axes[1].grid(True, alpha=0.3, axis='y')

# Right plot: Forgetting amount
forgetting_naive = acc_a_before - acc_a_after_naive
forgetting_replay = replay_history['task_a_on_a'][-1] - acc_a_after_replay

bars = axes[2].bar(['Naive', 'Replay'], [forgetting_naive, forgetting_replay],
                    color=['#e63946', '#0077b6'], edgecolor='white', linewidth=1.5)
axes[2].set_ylabel('Forgetting (↓ better)')
axes[2].set_title('SST-2 Forgetting After MRPC Training', fontsize=13)
axes[2].grid(True, alpha=0.3, axis='y')

for bar, f in zip(bars, [forgetting_naive, forgetting_replay]):
    axes[2].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.005,
                 f'{f:.3f}', ha='center', va='bottom', fontsize=13, fontweight='bold')

plt.tight_layout()
plt.show()

# --- 8. Inference Demo ---
print("\n=== Inference Demo ===")
test_sentences = [
    ("This movie is absolutely wonderful!", "SST-2 (Positive)"),
    ("A terrible waste of time and money.", "SST-2 (Negative)"),
    ("The film was average, nothing special.", "SST-2 (Neutral-ish)"),
]

print("\n--- Naive Model ---")
model_naive.eval()
for text, label in test_sentences:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
    with torch.no_grad():
        logits = model_naive(**inputs).logits
    pred = "Positive" if logits.argmax().item() == 1 else "Negative"
    conf = torch.softmax(logits, dim=-1).max().item()
    print(f"  [{pred} {conf:.1%}] {text}  (expected: {label})")

print("\n--- Replay Model ---")
model_replay.eval()
for text, label in test_sentences:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64).to(device)
    with torch.no_grad():
        logits = model_replay(**inputs).logits
    pred = "Positive" if logits.argmax().item() == 1 else "Negative"
    conf = torch.softmax(logits, dim=-1).max().item()
    print(f"  [{pred} {conf:.1%}] {text}  (expected: {label})")

print("\nLab 2 Complete!")

9. Decision Framework: How Enterprises Should Choose Continual Learning Strategies

Based on data availability, privacy constraints, and computational budgets, enterprises can use the following framework to select the appropriate continual learning approach:

| Condition | Recommended Method | Rationale |
|---|---|---|
| Old data can be stored, sufficient memory | Experience Replay (ER / DER++)[9][14] | Simplest and most effective; 200 samples/task can significantly reduce forgetting |
| Privacy constraints, cannot store old data | EWC[3] + LwF[5] | Only need to store Fisher matrix or model snapshots, no raw data required |
| Many tasks and continuously growing | PackNet[7] or HAT[8] | Support multi-task within fixed model capacity, no extra storage needed |
| Few tasks but zero forgetting required | Progressive Networks[6] | Complete isolation, zero-forgetting guarantee, suitable for mission-critical scenarios |
| Language model continual fine-tuning | Experience Replay + learning rate scheduling[15] | Most effective for Transformer architectures; EWC has limited effect in NLP |
| Privacy constraints + sufficient compute | Generative Replay (GR)[10] | Generates virtual old data, balancing privacy and anti-forgetting |
Decision Tree:

1. Can real data from old tasks be stored?
   ├── Yes → Experience Replay (prefer DER++)
   └── No → 2

2. Can model capacity grow?
   ├── Yes → Progressive Networks (zero forgetting)
   └── No → 3

3. Is the compute budget sufficient?
   ├── Yes → Generative Replay (GAN/VAE generates virtual data)
   └── No → EWC + LwF (only need Fisher matrix + old model snapshot)
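The decision tree above can be sketched as a small selection function. This is an illustrative encoding of the three questions, not a prescriptive API; the function name and boolean flags are invented for this example.

```python
def choose_strategy(can_store_data: bool,
                    can_grow_model: bool,
                    compute_budget_high: bool) -> str:
    """Toy encoding of the continual-learning decision tree above."""
    if can_store_data:
        # Question 1: real old-task data available -> replay wins
        return "Experience Replay (prefer DER++)"
    if can_grow_model:
        # Question 2: capacity can grow -> isolate tasks completely
        return "Progressive Networks"
    if compute_budget_high:
        # Question 3: enough compute to train a generator of old data
        return "Generative Replay"
    return "EWC + LwF"

# Example: privacy-constrained, fixed model, limited compute
print(choose_strategy(False, False, False))  # EWC + LwF
```

In practice the questions interact (e.g. partial data retention, hybrid methods like DER++ combining replay with distillation), so treat this as a first-pass filter rather than a final answer.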

10. Conclusion and Outlook

Catastrophic forgetting[1] is one of the core obstacles on deep learning's path toward true artificial intelligence. A system that cannot learn continuously — no matter how powerful — is merely a static tool, not an evolving intelligent entity.

Reviewing the key themes:

• Catastrophic forgetting stems from the stability-plasticity dilemma: gradient descent indiscriminately overwrites old-task weights while optimizing the new task
• Three strategy families address it: regularization (EWC[3], SI[4]), architecture expansion (Progressive Networks[6]), and replay (ER[9], GR[10])
• Large language models forget too[15]: in the BERT lab, a small replay buffer of old-task samples sharply reduced SST-2 forgetting compared to naive sequential fine-tuning
• Choosing a method is an engineering trade-off among data availability, privacy constraints, and compute budget (Section 9)

Looking ahead, continual learning is transitioning from academic research to engineering practice. As Foundation Models become widespread[17], enterprises need models that continuously adapt to new data, tasks, and domains, rather than being retrained from scratch each time.

Two architectural trends are promising. Sparsely activated Mixture-of-Experts (MoE) models naturally route different tasks to different expert sub-networks, offering built-in capacity isolation for continual learning. Parameter-efficient fine-tuning (LoRA, Adapters) freezes the backbone and trains only a small module per task, which is itself a lightweight continual learning strategy. When AI systems learn to keep learning without forgetting, just as humans do, we take another step toward true general intelligence.
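To make the "freeze the backbone, train a small module per task" idea concrete, here is a minimal PyTorch sketch of per-task low-rank adapters on a single linear layer. The `LoRALinear` class and its initialization are illustrative assumptions in the spirit of LoRA, not the API of any particular library; real deployments would use a framework such as Hugging Face PEFT.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update: y = Wx + (BA)x."""
    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # shared backbone stays frozen
        # B is zero-initialized, so the adapter starts as an identity update
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))

    def forward(self, x):
        return self.base(x) + x @ self.A.T @ self.B.T

# One shared backbone, one tiny adapter per task: no interference between tasks
backbone = nn.Linear(16, 8)
adapters = {task: LoRALinear(backbone, rank=2) for task in ["task_a", "task_b"]}

x = torch.randn(4, 16)
# Before training, each adapter reproduces the backbone exactly (B starts at zero)
assert torch.allclose(adapters["task_a"](x), backbone(x))
# Only A and B are trainable; backbone weights can never be overwritten
trainable = [n for n, p in adapters["task_a"].named_parameters() if p.requires_grad]
print(trainable)  # ['A', 'B']
```

Because each task owns its adapter and the backbone is never updated, training "task_b" cannot degrade "task_a" at all; the cost is a small per-task parameter overhead, echoing the zero-forgetting-by-isolation idea of Progressive Networks[6] at a fraction of the size.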