Key Findings
  • 知識蒸餾讓小模型「學會」大模型的推理方式——DistilBERT 只用 BERT 60% 的參數,卻保留 97% 的語言理解能力;TinyBERT 更達到 7.5 倍縮小、9.4 倍加速
  • LLM 蒸餾已從「模仿輸出」進化到「傳承推理」——DeepSeek-R1 蒸餾的 14B 模型在數學推理上超越了 32B 的 QwQ-32B-Preview,證明小模型也能擁有深度思考能力
  • NVIDIA Minitron 結合剪枝與蒸餾,從 15B 模型衍生出 8B/4B 版本,訓練 token 用量僅為從頭訓練的 1/40,MMLU 提升最高 16%
  • 擴散模型蒸餾讓圖片生成從 50 步壓縮至 1-4 步——SDXL Turbo 實現即時單步生成,LCM 訓練僅需 32 A100 GPU hours,Flux.1-schnell 以 Apache 2.0 授權開放

一、AI 的「大模型依賴症」:當規模化撞上現實的牆

2025 年,AI 產業正面臨一個弔詭的局面:模型越來越大、越來越強,但能真正把它們部署到生產環境的企業卻沒有等比增長。GPT-4 級別的模型需要數百 GB 記憶體、數千美元的月度推論成本;即便是開源的 LLaMA-70B,也需要至少兩張 A100 才能流暢運行。Harvard Business Review 指出[1],AI 系統的碳排放正以驚人速度膨脹——若不加以優化,到 2030 年每年將額外排放數千萬噸 CO₂。

MIT Sloan Management Review 的研究[2]從另一個角度印證了這一點:小型化、精準化的 AI 部署往往比「越大越好」的策略帶來更高的商業回報。企業真正需要的不是最大的模型,而是一個剛好夠用、成本可控、能大規模部署的模型

這正是知識蒸餾(Knowledge Distillation)要解決的核心問題:如何讓一個小模型「學會」大模型的能力,而不只是簡單地把大模型縮小?

與剪枝(Pruning)直接「砍掉」冗餘參數不同,蒸餾是一種知識轉移技術——它訓練一個全新的小模型(學生),讓它從大模型(教師)的行為中學習。學生模型不需要和教師有相同的架構,甚至可以用完全不同的設計。這意味著蒸餾能產生真正小巧、針對特定任務高度優化的模型,而不只是一個「打了折」的大模型。

二、技術演進:從 Hinton 的直覺到 LLM 的推理蒸餾

2.1 經典知識蒸餾:溫度與暗知識(2015)

2015 年,Geoffrey Hinton、Oriol Vinyals 和 Jeff Dean 在一篇看似不起眼的 workshop 論文[3]中提出了知識蒸餾的完整框架。核心洞察極其優雅:

大模型的「知識」不僅藏在它的最終預測(hard label)裡,更藏在那些「錯誤」的預測機率中——這些軟目標(soft targets)包含了類別之間的相似性結構,即所謂的「暗知識」(dark knowledge)。

舉個直觀的例子:一個圖片分類器看到一張「2」的手寫數字,它可能以 90% 的機率預測「2」,但同時給「3」分配了 5% 的機率,給「7」分配了 3% 的機率。這些「錯誤」的機率其實蘊含著寶貴的資訊——「2」在視覺上確實和「3」、「7」有些相似。如果學生只學 hard label(「這是 2」),它學不到這些細微的類別關係;但如果它學習教師的完整機率分佈,它就能繼承教師對整個問題空間的理解。

Hinton 引入了溫度參數 T(Temperature)來控制軟目標的「軟度」。溫度越高,機率分佈越平滑,暗知識越明顯:

# 知識蒸餾的核心公式
# 標準 Softmax:q_i = exp(z_i) / Σ exp(z_j)
# 帶溫度的 Softmax:q_i = exp(z_i / T) / Σ exp(z_j / T)
#
# T = 1:標準 softmax(硬分佈)
# T > 1:更軟的分佈(暗知識更明顯)
# 典型值:T = 3~20

# 蒸餾損失 = α × KL(soft_teacher ∥ soft_student) + (1-α) × CE(hard_label, student)
# α:蒸餾損失與真實標籤損失的權重平衡

這篇論文的影響力遠超其發表的 workshop 形式——截至 2025 年,它已被引用超過 20,000 次,成為模型壓縮領域引用最多的論文之一。

2.2 BERT 時代:DistilBERT 與 TinyBERT

知識蒸餾的威力在 NLP 領域得到了最鮮明的驗證。2019 年,HuggingFace 團隊發表的 DistilBERT[4] 可能是蒸餾技術最成功的商業化案例:

DistilBERT 的訓練使用了三重損失函數:語言模型損失、蒸餾損失(教師與學生的軟目標 KL 散度)、以及隱藏層的餘弦距離損失。這種多層次的知識轉移讓學生不僅學到了教師的輸出,還學到了中間表徵的結構。

華為的 TinyBERT[5] 更進一步,將蒸餾分為兩個階段:通用預訓練蒸餾和任務特定蒸餾。除了輸出層的軟目標,TinyBERT 還對齊了注意力矩陣和隱藏狀態的中間表徵,實現了7.5 倍縮小和 9.4 倍加速,同時保留了 BERT-Base 96.8% 的表現。

2.3 LLM 時代:當蒸餾遇上千億參數

隨著 LLM 規模爆發,蒸餾面臨全新挑戰:教師模型太大、生成是自回歸的、訓練成本極高。2024 年,三篇重要論文分別從不同角度突破了這些瓶頸:

MiniLLM[6](ICLR 2024)發現了一個關鍵問題:傳統的前向 KL 散度會讓學生在教師給出低機率的區域「過度分散」注意力,導致生成品質下降。解決方案是改用反向 KL 散度——讓學生專注於教師最有信心的區域。實驗證明,從 GPT-2 到 LLaMA-13B,MiniLLM 在各個規模上都優於標準蒸餾。

GKD(Generalized Knowledge Distillation)[7](Google DeepMind,ICLR 2024)則解決了另一個根本問題:傳統蒸餾中學生是在教師的生成序列上訓練的(off-policy),但推論時學生必須基於自己的輸出繼續生成(on-policy)。這種訓練-推論分佈不匹配會隨著序列長度累積越來越大的誤差。GKD 的解法是讓學生用自己的生成結果搭配教師反饋來訓練——從自己犯的錯誤中學習,而不是死記教師的標準答案。

DistiLLM[8](ICML 2024)引入了 Skew KL 散度,在前向和反向 KL 之間取得平衡,並結合自適應離策略訓練策略,實現了比現有方法4.3 倍的訓練加速

NVIDIA 的 Minitron[9](NeurIPS 2024)則展示了剪枝與蒸餾的強大組合:先用結構化剪枝(移除深度、寬度、注意力頭、MLP 通道)從 15B 模型中削出 8B/4B 的骨架,再用知識蒸餾恢復品質。結果令人印象深刻——訓練 token 僅為從頭訓練的 1/40,MMLU 基準提升高達 16%。這證明蒸餾不僅是壓縮工具,更是一種高效的模型「培育」策略。

2.4 推理蒸餾:DeepSeek-R1 的範式突破

2025 年初,DeepSeek-R1[10] 將知識蒸餾推向了一個全新的維度:推理能力的蒸餾。傳統蒸餾傳遞的是「答案」,DeepSeek-R1 傳遞的是「思考過程」。

研究團隊從 DeepSeek-R1(透過強化學習獲得推理能力的大型模型)中生成了 800K 條包含完整推理鏈(chain-of-thought)的訓練樣本,用這些樣本微調了 6 個開源模型(基於 Qwen2.5 和 Llama3,規模從 1.5B 到 70B)。結果震驚了整個社群:

DeepSeek-R1 的啟示是深遠的:知識蒸餾不再只是「壓縮」——它是一種能力遺傳機制,可以將大模型透過昂貴訓練(如 RL)獲得的高階能力,以極低成本「遺傳」給小模型。

三、實證數據:蒸餾壓縮效果全覽

模型技術壓縮效果品質保留來源
BERT → DistilBERT三重損失蒸餾40% 縮小、60% 加速保留 97% GLUE 分數Sanh et al., 2019
BERT → TinyBERT兩階段蒸餾7.5x 縮小、9.4x 加速保留 96.8% 表現Jiao et al., 2020
Nemotron 15B → 8B/4B剪枝 + 蒸餾1/40 訓練 tokenMMLU +16%(vs. 從頭訓練)Muralidharan et al., 2024
GPT-2/LLaMA 系列MiniLLM(反向 KL)120M–13B 全規模蒸餾優於標準 KL 蒸餾Gu et al., 2024
DeepSeek-R1 → 14B推理鏈蒸餾14B 學生數學推理超越 QwQ-32BDeepSeek-AI, 2025
SD v1.4 → BK-SDMBlock 剪枝 + 蒸餾參數減 30-50%FID 持平甚至更優Kim et al., 2024
SDXL → SDXL Turbo對抗蒸餾(ADD)50 步 → 1 步單步超越 LCM/GANsSauer et al., 2023
SD → LCM潛空間一致性蒸餾50 步 → 2-4 步僅 32 A100 hours 訓練Luo et al., 2023
Flux.1-pro → schnellLADD(潛空間對抗蒸餾)20-50 步 → 1-4 步高品質、Apache 2.0 開源Sauer et al., 2024

四、決策框架:蒸餾的收益、代價與適用邊界

知識蒸餾與剪枝、量化是模型壓縮的三大支柱,但它們的適用場景不同。在決定是否採用蒸餾之前,需要理解它獨特的優勢和限制:

維度直接使用教師模型蒸餾後的學生模型
模型大小完整參數(如 BERT: 110M, LLaMA: 70B)可自由設計學生架構;典型壓縮 2-10x
推論速度基準速度結構性加速(非稀疏依賴):2-9x 加速
精度完整精度保留 95-99% 教師能力(任務依賴)
訓練成本教師需完整訓練學生訓練成本遠低於從頭訓練(Minitron: 1/40 token)
架構靈活性固定架構學生可用完全不同的架構設計
部署彈性僅限高階 GPU可部署至 CPU、手機、邊緣裝置

策略性優勢

必須管理的風險

蒸餾 vs. 剪枝:如何選擇?

場景推薦技術原因
快速壓縮現有 LLM剪枝(Wanda/SparseGPT)無需重新訓練,小時級完成
需要改變模型架構蒸餾學生可用完全不同的架構
追求最佳壓縮效果剪枝 + 蒸餾(Minitron)先剪枝削骨架,再蒸餾恢復品質
傳遞高階推理能力推理蒸餾(DeepSeek-R1 式)推理鏈蒸餾是唯一被驗證的方法
擴散模型步數壓縮蒸餾(LCM/ADD/LADD)步數蒸餾是壓縮生成步數的主要手段
預算有限、需要快速驗證剪枝蒸餾需要額外的訓練資源和時間

五、Hands-on Lab:Google Colab 線上實驗室(CV 模型蒸餾)

讓我們從最經典的場景開始:用一個大的 ResNet-34 作為教師,蒸餾出一個小的 ResNet-18 學生。你會看到蒸餾如何讓學生超越「直接訓練」的表現上限。所有程式碼可直接在 Google Colab 免費 GPU 上執行。

打開 Google Colab,新建 Notebook,依序貼入以下程式碼:

5.1 Step 1 — 訓練教師模型 ResNet-34(約 5 分鐘)

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import time, copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"裝置: {device}")

# ---- 資料集 ----
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                         shuffle=False, num_workers=2)

# ---- 教師模型:ResNet-34(適配 CIFAR-10 的 32×32 輸入)----
teacher = models.resnet34(weights=None, num_classes=10)
teacher.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
teacher.maxpool = nn.Identity()
teacher = teacher.to(device)

# ---- 訓練教師 15 epochs ----
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)

print("訓練教師模型 ResNet-34...")
for epoch in range(15):
    teacher.train()
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = teacher(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    scheduler.step()
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1}/15 完成")

teacher.eval()
print("✓ 教師模型訓練完成")

5.2 Step 2 — 評估工具函數

def evaluate(model, dataloader, device):
    """計算測試集準確率"""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total

def count_params(model):
    """計算模型參數量"""
    return sum(p.numel() for p in model.parameters())

def measure_speed(model, device, input_size=(1, 3, 32, 32), n_runs=200):
    """測量推論延遲(ms)"""
    model.eval()
    dummy = torch.randn(*input_size).to(device)
    for _ in range(50):
        with torch.no_grad():
            model(dummy)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        with torch.no_grad():
            model(dummy)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs * 1000

# ---- 教師基準 ----
teacher_acc = evaluate(teacher, testloader, device)
teacher_params = count_params(teacher)
teacher_speed = measure_speed(teacher, device)

print(f"{'='*55}")
print(f"  教師模型 ResNet-34")
print(f"{'='*55}")
print(f"  準確率:   {teacher_acc:.2f}%")
print(f"  參數量:   {teacher_params:,}")
print(f"  延遲:     {teacher_speed:.2f} ms")
print(f"{'='*55}")

5.3 Step 3 — 知識蒸餾核心:訓練學生模型

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """
    ★ 知識蒸餾損失函數 ★
    - T: 溫度參數(越高,軟目標越平滑,暗知識越明顯)
    - alpha: 蒸餾損失的權重(1-alpha 為硬標籤損失的權重)
    """
    # 軟目標蒸餾損失(KL 散度)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)  # 乘以 T² 補償梯度縮放

    # 硬標籤交叉熵損失
    hard_loss = F.cross_entropy(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

# ---- 建立學生模型:ResNet-18(比教師小約 50%)----
student_distilled = models.resnet18(weights=None, num_classes=10)
student_distilled.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
student_distilled.maxpool = nn.Identity()
student_distilled = student_distilled.to(device)

# ---- 蒸餾訓練 15 epochs ----
optimizer_kd = optim.SGD(student_distilled.parameters(), lr=0.1,
                          momentum=0.9, weight_decay=5e-4)
scheduler_kd = optim.lr_scheduler.CosineAnnealingLR(optimizer_kd, T_max=15)

print("使用知識蒸餾訓練學生模型 ResNet-18...")
for epoch in range(15):
    student_distilled.train()
    teacher.eval()
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)

        # 教師推論(不計算梯度)
        with torch.no_grad():
            teacher_logits = teacher(inputs)

        # 學生推論
        student_logits = student_distilled(inputs)

        # ★ 蒸餾損失 ★
        loss = distillation_loss(
            student_logits, teacher_logits, targets,
            T=4.0,     # 溫度 4(讓暗知識更明顯)
            alpha=0.7   # 70% 蒸餾損失 + 30% 硬標籤損失
        )

        optimizer_kd.zero_grad()
        loss.backward()
        optimizer_kd.step()
    scheduler_kd.step()
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1}/15 完成")

print("✓ 蒸餾訓練完成")

5.4 Step 4 — 對照組:直接訓練學生(無蒸餾)

# 同樣的 ResNet-18,但不用蒸餾,直接用硬標籤訓練
student_baseline = models.resnet18(weights=None, num_classes=10)
student_baseline.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
student_baseline.maxpool = nn.Identity()
student_baseline = student_baseline.to(device)

optimizer_bl = optim.SGD(student_baseline.parameters(), lr=0.1,
                          momentum=0.9, weight_decay=5e-4)
scheduler_bl = optim.lr_scheduler.CosineAnnealingLR(optimizer_bl, T_max=15)

print("直接訓練對照組 ResNet-18(無蒸餾)...")
for epoch in range(15):
    student_baseline.train()
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer_bl.zero_grad()
        outputs = student_baseline(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer_bl.step()
    scheduler_bl.step()
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1}/15 完成")

print("✓ 對照組訓練完成")

5.5 Step 5 — 完整比較:蒸餾 vs. 直接訓練 vs. 教師

# 評估所有模型
distilled_acc = evaluate(student_distilled, testloader, device)
baseline_acc = evaluate(student_baseline, testloader, device)
student_params = count_params(student_distilled)
distilled_speed = measure_speed(student_distilled, device)
baseline_speed = measure_speed(student_baseline, device)

print(f"\n{'='*70}")
print(f"  知識蒸餾效果比較(CIFAR-10)")
print(f"{'='*70}")
print(f"{'模型':<22} {'準確率':>8} {'參數量':>14} {'延遲(ms)':>9} {'備註':>12}")
print(f"{'-'*70}")
print(f"{'教師 ResNet-34':<22} {teacher_acc:>7.2f}% {teacher_params:>13,} "
      f"{teacher_speed:>8.2f}  {'上界'}")
print(f"{'學生 直接訓練':<22} {baseline_acc:>7.2f}% {student_params:>13,} "
      f"{baseline_speed:>8.2f}  {'無蒸餾'}")
print(f"{'學生 知識蒸餾':<22} {distilled_acc:>7.2f}% {student_params:>13,} "
      f"{distilled_speed:>8.2f}  {'T=4, α=0.7'}")
print(f"{'-'*70}")

improvement = distilled_acc - baseline_acc
gap_closed = (distilled_acc - baseline_acc) / (teacher_acc - baseline_acc) * 100 \
    if teacher_acc > baseline_acc else 0

print(f"\n★ 關鍵發現:")
print(f"  • 蒸餾學生 vs. 直接訓練: {improvement:+.2f}% 精度提升")
print(f"  • 蒸餾讓學生縮小了 {gap_closed:.0f}% 的教師-學生差距")
print(f"  • 學生參數量僅教師的 {student_params/teacher_params*100:.0f}%,"
      f"但透過蒸餾獲得了更接近教師的表現")
print(f"  • 推論速度幾乎相同(學生架構一樣),但蒸餾版精度更高")

5.6 Step 6 — 探索溫度參數的影響

# 測試不同溫度下的蒸餾效果
temperatures = [1, 2, 4, 8, 16]
temp_results = []

for T in temperatures:
    student_t = models.resnet18(weights=None, num_classes=10)
    student_t.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    student_t.maxpool = nn.Identity()
    student_t = student_t.to(device)

    opt = optim.SGD(student_t.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=15)

    for epoch in range(15):
        student_t.train()
        teacher.eval()
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            with torch.no_grad():
                teacher_logits = teacher(inputs)
            student_logits = student_t(inputs)
            loss = distillation_loss(student_logits, teacher_logits, targets, T=T, alpha=0.7)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()

    acc = evaluate(student_t, testloader, device)
    temp_results.append({'T': T, 'acc': acc})
    print(f"  T={T:<3d} → 準確率: {acc:.2f}%")
    del student_t
    if device.type == 'cuda':
        torch.cuda.empty_cache()

print(f"\n{'='*50}")
print(f"  溫度參數 T 對蒸餾效果的影響")
print(f"{'='*50}")
print(f"{'溫度 T':>8} {'準確率':>10} {'vs. 直接訓練':>14}")
print(f"{'-'*50}")
for r in temp_results:
    delta = r['acc'] - baseline_acc
    print(f"{r['T']:>8d} {r['acc']:>9.2f}% {delta:>+13.2f}%")
print(f"{'-'*50}")
print(f"{'直接訓練':<8} {baseline_acc:>9.2f}%  {'(基線)':>13}")
print(f"\n→ 通常 T=3~8 效果最好。T 太低暗知識不足,T 太高信號被過度平滑。")

你會親眼看到的效果:蒸餾出的 ResNet-18 準確率會比直接訓練的 ResNet-18 高出約 0.5-2%。這看似不大,但要記住兩個模型架構完全相同——唯一的差異就是蒸餾帶來的「暗知識」。在學術論文中,這個量級的提升往往是數月研究的成果。

六、Hands-on Lab:LLM 知識蒸餾實驗室(語言模型)

CV 蒸餾展示了基礎原理。接下來我們在語言模型上動手——用 GPT-2 Medium(345M)作為教師,蒸餾到 GPT-2 Small(124M),在免費的 Google Colab 上即可完成。

打開 Google Colab,新建 Notebook,依序貼入以下程式碼:

6.1 安裝與載入模型

!pip install transformers datasets accelerate -q

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from datasets import load_dataset
import time, copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"裝置: {device}")

# 載入教師模型:GPT-2 Medium(345M 參數)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token
teacher = GPT2LMHeadModel.from_pretrained("gpt2-medium").to(device)
teacher.eval()

# 載入學生模型:GPT-2 Small(124M 參數)
student = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())

print(f"教師 GPT-2 Medium: {teacher_params:,} 參數")
print(f"學生 GPT-2 Small:  {student_params:,} 參數")
print(f"壓縮比: {teacher_params/student_params:.1f}x")

6.2 準備訓練資料

# 使用 WikiText-2 作為蒸餾語料
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# 預處理:將文字切分為固定長度的 token 序列
def tokenize_and_chunk(examples, max_length=128):
    tokens = tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt"
    )
    return tokens

# 過濾空行並取一個子集(免費 Colab 友好)
texts = [t for t in dataset["text"] if len(t.strip()) > 50][:2000]
print(f"使用 {len(texts)} 條文字進行蒸餾訓練")

# 建立 DataLoader
from torch.utils.data import DataLoader, TensorDataset

encodings = tokenizer(texts, truncation=True, max_length=128,
                      padding="max_length", return_tensors="pt")
train_dataset = TensorDataset(encodings["input_ids"], encodings["attention_mask"])
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

print(f"✓ 資料準備完成,共 {len(train_loader)} 個 batch")

6.3 定義評估函數

def measure_perplexity(model, texts, tokenizer, device, max_length=128):
    """在一組文字上計算困惑度"""
    model.eval()
    total_loss, total_tokens = 0, 0
    eval_texts = texts[:200]  # 取子集加速評估
    for text in eval_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True,
                           max_length=max_length).to(device)
        if inputs["input_ids"].size(1) < 2:
            continue
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        total_loss += outputs.loss.item() * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)
    return torch.exp(torch.tensor(total_loss / total_tokens)).item() if total_tokens > 0 else float('inf')

def generate_text(model, prompt, max_new_tokens=60):
    """生成文字,觀察品質"""
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            do_sample=True, temperature=0.7, top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# 評估文字
eval_texts = [t for t in dataset["text"] if len(t.strip()) > 100][:200]
test_prompts = [
    "The future of artificial intelligence",
    "Knowledge distillation is a technique that",
    "In machine learning, the relationship between",
]

# 記錄基準
teacher_ppl = measure_perplexity(teacher, eval_texts, tokenizer, device)
student_ppl_before = measure_perplexity(student, eval_texts, tokenizer, device)

print(f"教師 PPL: {teacher_ppl:.2f}")
print(f"學生 PPL(蒸餾前): {student_ppl_before:.2f}")

6.4 知識蒸餾訓練

def lm_distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
    """
    ★ 語言模型蒸餾損失 ★
    """
    # Shift: 預測下一個 token
    shift_student = student_logits[:, :-1, :].contiguous()
    shift_teacher = teacher_logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    # 軟目標蒸餾損失
    soft_loss = F.kl_div(
        F.log_softmax(shift_student / T, dim=-1),
        F.softmax(shift_teacher / T, dim=-1),
        reduction='batchmean'
    ) * (T * T)

    # 硬標籤損失
    hard_loss = F.cross_entropy(
        shift_student.view(-1, shift_student.size(-1)),
        shift_labels.view(-1),
        ignore_index=tokenizer.pad_token_id
    )

    return alpha * soft_loss + (1 - alpha) * hard_loss

# ---- 蒸餾訓練 ----
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5, weight_decay=0.01)
T = 3.0   # 溫度
alpha = 0.5  # 蒸餾權重

print("開始 LLM 知識蒸餾...")
print(f"  溫度 T={T}, 蒸餾權重 α={alpha}")
print(f"  教師: GPT-2 Medium (345M), 學生: GPT-2 Small (124M)\n")

student.train()
for epoch in range(3):
    total_loss = 0
    for batch_idx, (input_ids, attention_mask) in enumerate(train_loader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # 教師推論
        with torch.no_grad():
            teacher_outputs = teacher(input_ids=input_ids,
                                       attention_mask=attention_mask)

        # 學生推論
        student_outputs = student(input_ids=input_ids,
                                   attention_mask=attention_mask)

        # 蒸餾損失
        loss = lm_distillation_loss(
            student_outputs.logits, teacher_outputs.logits,
            input_ids, T=T, alpha=alpha
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()

        if (batch_idx + 1) % 50 == 0:
            print(f"  Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, "
                  f"Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"  Epoch {epoch+1}/3 完成, Avg Loss: {avg_loss:.4f}\n")

print("✓ LLM 蒸餾訓練完成")

6.5 結果比較

student_ppl_after = measure_perplexity(student, eval_texts, tokenizer, device)

print(f"\n{'='*65}")
print(f"  GPT-2 知識蒸餾結果")
print(f"{'='*65}")
print(f"{'模型':<25} {'困惑度(PPL)':>12} {'參數量':>14}")
print(f"{'-'*65}")
print(f"{'教師 GPT-2 Medium':<25} {teacher_ppl:>11.2f} {teacher_params:>13,}")
print(f"{'學生(蒸餾前)':<25} {student_ppl_before:>11.2f} {student_params:>13,}")
print(f"{'學生(蒸餾後)':<25} {student_ppl_after:>11.2f} {student_params:>13,}")
print(f"{'='*65}")

ppl_improvement = student_ppl_before - student_ppl_after
gap_closed = (student_ppl_before - student_ppl_after) / \
    (student_ppl_before - teacher_ppl) * 100 if student_ppl_before > teacher_ppl else 0

print(f"\n★ 關鍵發現:")
print(f"  • 困惑度降低: {ppl_improvement:.2f}(越低越好)")
print(f"  • 縮小教師-學生差距: {gap_closed:.1f}%")
print(f"  • 壓縮比: {teacher_params/student_params:.1f}x(參數量)")

print(f"\n{'='*65}")
print(f"  生成品質比較")
print(f"{'='*65}")
for p in test_prompts:
    print(f"\n  Prompt: {p}")
    print(f"  教師: {generate_text(teacher, p, max_new_tokens=40)}")
    print(f"  學生: {generate_text(student, p, max_new_tokens=40)}")

6.6 進階:使用 HuggingFace TRL 的 GKD Trainer

上面的 demo 使用最基礎的 KL 散度蒸餾。對於更大規模的 LLM 蒸餾,HuggingFace 的 TRL 庫[11]提供了開箱即用的 GKDTrainer,實作了 Google DeepMind 的 GKD 論文[7]

# pip install trl

from trl import GKDConfig, GKDTrainer

# GKD 訓練配置
training_args = GKDConfig(
    output_dir="./gkd-output",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    learning_rate=5e-5,
    lmbda=0.5,           # 教師混合比例(0 = 純 on-policy,1 = 純 off-policy)
    beta=0.5,             # Skew KL 散度的插值參數
    temperature=3.0,      # 蒸餾溫度
    max_new_tokens=128,   # 學生生成的最大 token 數
)

# 初始化 GKD Trainer
trainer = GKDTrainer(
    model=student_model,              # 學生模型
    teacher_model=teacher_model,      # 教師模型
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)

# 開始 on-policy 蒸餾
trainer.train()

# GKD 的核心優勢:
# 1. 學生在自己的生成序列上訓練(on-policy)
# 2. 消除了訓練-推論分佈不匹配問題
# 3. 支援多種散度度量(前向 KL、反向 KL、JSD)

GKD 在 Gemini 團隊的內部測試中已被驗證為 LLM 蒸餾的最佳實踐之一,特別適合需要長序列生成的場景(如摘要、翻譯、程式碼生成)。

七、擴散模型蒸餾:讓圖片生成從 50 步壓縮到 1 步

知識蒸餾在圖片生成領域的影響甚至比 NLP 更具戲劇性。擴散模型(Stable Diffusion、FLUX)的核心瓶頸是生成步數太多——每張圖片需要 20-50 步的迭代去噪。蒸餾技術正在從根本上解決這個問題。

7.1 漸進式蒸餾:步數減半的連鎖反應

Google 的 Salimans 與 Ho 在 ICLR 2022 發表的 Progressive Distillation[12] 是擴散模型蒸餾的先驅。核心思路極其直觀:訓練一個學生模型,讓它用 1 步完成教師 2 步的工作。重複這個過程 N 次,步數就從 2^N 降到 1。在 CIFAR-10 和 ImageNet 64×64 上,他們成功將 8192 步壓縮至 4 步。

7.2 LCM:2-4 步生成高解析度圖片

Latent Consistency Models(LCM)[13] 將蒸餾帶入了潛空間。LCM 不是直接蒸餾去噪步驟,而是訓練模型直接預測 ODE(常微分方程)的解——跳過中間步驟,一步到位。訓練僅需32 A100 GPU hours(約 $100 雲端成本),即可生成 768×768 的高品質圖片。

更重要的是,LCM 團隊同時發布了 LCM-LoRA——一種輕量級適配器,可以即插即用到任何基於 SD 的模型上。這意味著社群訓練的所有自定義模型(DreamBooth、LoRA 微調版等)都能直接獲得 2-4 步加速:

# LCM-LoRA:讓任何 Stable Diffusion 模型 2-4 步生成
!pip install diffusers transformers accelerate -q

from diffusers import DiffusionPipeline, LCMScheduler
import torch

# 載入任意 SD 模型
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# ★ 即插即用 LCM-LoRA ★
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# 僅需 4 步即可生成高品質圖片!
image = pipe(
    prompt="A futuristic cityscape at sunset, photorealistic",
    num_inference_steps=4,       # 原本需要 25-50 步
    guidance_scale=1.5,          # LCM 用較低的 guidance
).images[0]

image.save("lcm_result.png")
print("✓ 4 步即完成 SDXL 圖片生成!")

7.3 SDXL Turbo:即時單步生成

Stability AI 的 Adversarial Diffusion Distillation(ADD)[14] 將步數壓縮推到了極限:單步生成。ADD 巧妙地結合了兩種損失:

結果是 SDXL Turbo 在單步就能生成 512×512 圖片,品質超越了需要多步的 LCM 和傳統 GANs。雖然 SDXL Turbo 不開源模型權重,但這一技術路線(ADD)已經被 Flux.1-schnell 的 LADD 方法成功延伸到更大規模。

7.4 Flux.1-schnell:LADD 潛空間對抗蒸餾

Black Forest Labs 的 Flux.1-schnell 是目前開源社群中品質最高的快速生成模型之一。它使用了 LADD(Latent Adversarial Diffusion Distillation)[15]——ADD 的進化版:

# 使用 Flux.1-schnell(蒸餾模型)
from diffusers import FluxPipeline
import torch

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
).to("cuda")

# 僅需 4 步!
image = pipe(
    prompt="A serene Japanese garden with cherry blossoms",
    num_inference_steps=4,
    guidance_scale=0.0,  # schnell 不需要 classifier-free guidance
).images[0]

image.save("flux_schnell_result.png")
print("✓ Flux.1-schnell: 4 步生成完成")

7.5 BK-SDM 與 SnapFusion:架構蒸餾

除了步數蒸餾,另一條路線是架構蒸餾——直接讓模型本身變小。

BK-SDM[16](ECCV 2024)從 Stable Diffusion v1.4 的 U-Net 中移除了多個 residual 和 attention blocks,再用特徵蒸餾恢復品質。結果是參數減少 30-50%,FID 分數反而更好,訓練僅需 13 A100-days——比原始 SD 的 6,000+ A100-days 便宜 460 倍。

SnapFusion[17](NeurIPS 2023)則同時做了架構蒸餾和步數蒸餾,最終在手機上 2 秒內生成 512×512 圖片,將 50 步壓縮至 8 步。

方法會議蒸餾類型步數壓縮架構壓縮訓練成本
Progressive DistillationICLR 2022步數蒸餾8192→4 步中等
LCM / LCM-LoRA2023一致性蒸餾50→2-4 步32 A100 hrs
ADD(SDXL Turbo)ECCV 2024對抗蒸餾50→1 步
LADD(Flux.1-schnell)SIGGRAPH Asia 2024潛空間對抗蒸餾20-50→1-4 步
BK-SDMECCV 2024特徵蒸餾參數減 30-50%13 A100-days
SnapFusionNeurIPS 2023架構 + 步數蒸餾50→8 步架構精簡中等

八、生態系工具全景

從學術實作到企業級平台,知識蒸餾的工具生態已覆蓋完整技術棧:

基礎框架

LLM 蒸餾

擴散模型蒸餾

通用平台

九、從技術指標到商業影響

知識蒸餾的商業價值不僅在於「讓模型變小」,更在於它根本性地改變了 AI 的部署經濟學:

十、導入路徑:三階段落地策略

  1. 盤點蒸餾機會:找出推論成本最高、呼叫頻率最高的模型。NLP 分類任務首選 DistilBERT / TinyBERT 等現成蒸餾模型;圖片生成首選 LCM-LoRA 即插即用加速
  2. 從現有蒸餾模型開始:HuggingFace 上已有大量預蒸餾模型(DistilBERT、DistilGPT-2、LCM-LoRA、Flux.1-schnell)。先用這些現成方案驗證蒸餾在你的場景中是否有效
  3. 進階自定義蒸餾:當預建模型無法滿足需求時,使用 GKDTrainer(LLM 場景)或 Diffusers + LCMScheduler(圖片場景)訓練自定義蒸餾模型。若預算充足,考慮 Minitron 式的剪枝 + 蒸餾組合管線

知識蒸餾不是一項新技術——Hinton 在 2015 年就奠定了基礎。但過去兩年,從 MiniLLM 到 DeepSeek-R1,從 LCM 到 Flux.1-schnell,蒸餾技術經歷了質的飛躍。它不再只是「壓縮」——它是AI 能力的高效傳承機制,讓頂級模型的智慧流入每一個邊緣裝置、每一個 API 呼叫、每一個即時生成的圖片中。

如果您的團隊正在評估模型壓縮策略,或需要在成本、延遲與能力之間找到最佳平衡點,歡迎與我們進行深度技術對話。超智諮詢的研究團隊能夠陪伴您走完從教師模型選擇到學生模型上線的完整旅程。