- Knowledge distillation enables small models to "learn" how large models reason — DistilBERT uses only 60% of BERT's parameters yet retains 97% of its language understanding capability; TinyBERT achieves 7.5x compression and 9.4x speedup
- LLM distillation has evolved from "imitating outputs" to "inheriting reasoning" — DeepSeek-R1's distilled 14B model surpasses the 32B QwQ-32B-Preview in mathematical reasoning, proving that small models can possess deep thinking capabilities
- NVIDIA's Minitron combines pruning with distillation, deriving 8B/4B versions from a 15B model using only 1/40 of the training tokens required for training from scratch, with MMLU improvements of up to 16%
- Diffusion model distillation compresses image generation from 50 steps to 1-4 steps — SDXL Turbo achieves real-time single-step generation, LCM requires only 32 A100 GPU hours for training, and Flux.1-schnell is released under the Apache 2.0 license
1. AI's "Large Model Dependency": When Scaling Hits a Wall
In 2025, the AI industry faces a paradox: models keep getting larger and more capable, but the number of enterprises that can actually deploy them into production has not grown proportionally. GPT-4-class models require hundreds of gigabytes of memory and thousands of dollars in monthly inference costs; even the open-source LLaMA-70B needs at least two A100 GPUs to run smoothly. Harvard Business Review notes[1] that AI system carbon emissions are growing at an alarming rate — without optimization, they could add tens of millions of tons of CO2 annually by 2030.
Research from MIT Sloan Management Review[2] reinforces this from another angle: smaller, more targeted AI deployments often yield higher business returns than "bigger is better" strategies. What enterprises truly need is not the largest model, but one that is just capable enough, cost-effective, and deployable at scale.
This is precisely the core problem that Knowledge Distillation addresses: How can a small model "learn" the capabilities of a large model, rather than simply being a shrunken version of it?
Unlike neural network pruning, which directly "cuts away" redundant parameters, distillation is a knowledge transfer technique — it trains an entirely new small model (the student) to mimic the behavior of a large model (the teacher). The student does not need to share the teacher's architecture and can even use a completely different design. This means distillation can produce truly compact models that are highly optimized for specific tasks, rather than just a "discounted" version of the large model.
2. Technical Evolution: From Hinton's Intuition to LLM Reasoning Distillation
2.1 Classic Knowledge Distillation: Temperature and Dark Knowledge (2015)
In 2015, Geoffrey Hinton, Oriol Vinyals, and Jeff Dean proposed the complete framework for knowledge distillation in a seemingly modest workshop paper[3]. The core insight was remarkably elegant:
A large model's "knowledge" is not only encoded in its final predictions (hard labels), but also hidden within those "incorrect" prediction probabilities — these soft targets contain the similarity structure between classes, known as "dark knowledge."
Here is an intuitive example: an image classifier looking at a handwritten digit "2" might predict "2" with 90% probability, while assigning 5% to "3" and 3% to "7." These "incorrect" probabilities actually contain valuable information — "2" does share visual similarities with "3" and "7." If the student only learns from hard labels ("this is a 2"), it cannot capture these subtle inter-class relationships; but if it learns from the teacher's full probability distribution, it can inherit the teacher's understanding of the entire problem space.
Hinton introduced the temperature parameter T to control the "softness" of soft targets. Higher temperatures produce smoother probability distributions, making dark knowledge more prominent:
# Core formula for knowledge distillation
# Standard Softmax: q_i = exp(z_i) / Σ exp(z_j)
# Temperature-scaled Softmax: q_i = exp(z_i / T) / Σ exp(z_j / T)
#
# T = 1: Standard softmax (hard distribution)
# T > 1: Softer distribution (dark knowledge becomes more apparent)
# Typical values: T = 3~20
# Distillation loss = α × KL(soft_teacher ∥ soft_student) + (1-α) × CE(hard_label, student)
# α: weight balance between distillation loss and ground-truth label loss
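The effect of T is easy to see numerically. Below is a minimal, self-contained sketch of temperature-scaled softmax applied to made-up logits for the handwritten-"2" example (the logit values are illustrative, not from a real classifier):

```python
import torch
import torch.nn.functional as F

# Illustrative logits for classes 0-9: the model is confident in "2",
# with "3" and "7" as visually similar runners-up
logits = torch.tensor([0.5, 1.0, 6.0, 3.5, 0.2, 0.1, 0.3, 2.8, 0.4, 0.6])

for T in [1.0, 4.0, 16.0]:
    probs = F.softmax(logits / T, dim=0)
    print(f"T={T:>4}: P(2)={probs[2]:.3f}  P(3)={probs[3]:.3f}  "
          f"P(7)={probs[7]:.3f}")
```

At T=1 nearly all the mass sits on "2"; raising T flattens the distribution, so the relative weight of "3" and "7" — the dark knowledge — becomes a much larger share of the training signal. Past some point (T=16 here) the distribution is nearly uniform and the signal washes out, which is why moderate temperatures work best in practice.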
The impact of this paper far exceeded its workshop format — as of 2025, it has been cited over 20,000 times, making it one of the most cited papers in the model compression field.
2.2 The BERT Era: DistilBERT and TinyBERT
The power of knowledge distillation was most vividly demonstrated in the NLP domain. DistilBERT[4], published by the HuggingFace team in 2019, is perhaps the most successful commercial application of distillation technology:
- 40% smaller model: Reduced from BERT's 110M parameters to 66M
- 60% faster inference: Nearly halved latency on CPU
- 97% capability retained: Only a 3% performance drop on the GLUE benchmark
DistilBERT's training employed a triple loss function: language model loss, distillation loss (KL divergence between teacher and student soft targets), and cosine distance loss on hidden layers. This multi-level knowledge transfer enabled the student to learn not only the teacher's outputs but also the structure of its intermediate representations.
Huawei's TinyBERT[5] went further by splitting distillation into two stages: general pre-training distillation and task-specific distillation. In addition to output-layer soft targets, TinyBERT also aligned attention matrices and hidden-state intermediate representations, achieving 7.5x compression and 9.4x speedup while retaining 96.8% of BERT-Base's performance.
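The intermediate-representation losses these two systems rely on can be sketched in a few lines. The tensors below are random stand-ins for one aligned (student layer, teacher layer) pair, and the learned projection `proj` bridging the width mismatch is an illustrative assumption rather than either paper's exact configuration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, heads = 2, 16, 12
d_teacher, d_student = 768, 312   # TinyBERT-style width mismatch

# Random stand-ins for real hidden states and attention maps
h_teacher = torch.randn(batch, seq, d_teacher)
h_student = torch.randn(batch, seq, d_student)
attn_teacher = torch.softmax(torch.randn(batch, heads, seq, seq), dim=-1)
attn_student = torch.softmax(torch.randn(batch, heads, seq, seq), dim=-1)

# DistilBERT-style hidden-state loss: cosine distance between aligned layers
# (a learned projection maps the student width onto the teacher's)
proj = torch.nn.Linear(d_student, d_teacher)
hidden_loss = 1 - F.cosine_similarity(proj(h_student), h_teacher, dim=-1).mean()

# TinyBERT-style attention alignment: MSE between attention matrices
attn_loss = F.mse_loss(attn_student, attn_teacher)

intermediate_loss = hidden_loss + attn_loss
print(f"hidden loss: {hidden_loss.item():.4f}, attention MSE: {attn_loss.item():.4f}")
```

In a real training run these terms are summed with the output-layer soft-target loss, one pair per aligned layer.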
2.3 The LLM Era: When Distillation Meets Hundred-Billion-Parameter Models
As LLM scales exploded, distillation faced entirely new challenges: teacher models were too large, generation was autoregressive, and training costs were extremely high. In 2024, three important papers broke through these bottlenecks from different angles:
MiniLLM[6] (ICLR 2024) identified a key problem: traditional forward KL divergence causes the student to "over-distribute" attention across regions where the teacher assigns low probability, degrading generation quality. The solution was to use reverse KL divergence — allowing the student to focus on regions where the teacher is most confident. Experiments showed that from GPT-2 to LLaMA-13B, MiniLLM outperformed standard distillation across all scales.
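The asymmetry is easy to verify on a toy four-token vocabulary. The numbers below are made up purely to illustrate mode-covering vs. mode-seeking behavior; this is not MiniLLM's actual training objective:

```python
import torch

def kl(p, q, eps=1e-9):
    """KL(p || q) for discrete distributions; eps guards log(0)."""
    return (p * ((p + eps) / (q + eps)).log()).sum().item()

teacher = torch.tensor([0.49, 0.49, 0.01, 0.01])   # mass on tokens 0 and 1
spread  = torch.tensor([0.30, 0.30, 0.20, 0.20])   # student smears mass everywhere
focused = torch.tensor([0.50, 0.48, 0.01, 0.01])   # student sticks to teacher's modes

# Forward KL(teacher || student) tolerates the smeared student...
print("forward KL, spread :", round(kl(teacher, spread), 4))
print("forward KL, focused:", round(kl(teacher, focused), 4))
# ...but reverse KL(student || teacher) heavily penalizes mass placed
# where the teacher assigns almost none
print("reverse KL, spread :", round(kl(spread, teacher), 4))
print("reverse KL, focused:", round(kl(focused, teacher), 4))
```

Minimizing reverse KL drives `spread` toward `focused`: the student concentrates on regions where the teacher is confident, which is the generation-quality behavior MiniLLM reports.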
GKD (Generalized Knowledge Distillation)[7] (Google DeepMind, ICLR 2024) addressed another fundamental issue: in traditional distillation, the student trains on the teacher's generated sequences (off-policy), but during inference, the student must generate based on its own outputs (on-policy). This train-inference distribution mismatch accumulates increasingly larger errors as sequence length grows. GKD's solution is to let the student train on its own generated outputs with teacher feedback — learning from its own mistakes rather than memorizing the teacher's standard answers.
DistiLLM[8] (ICML 2024) introduced Skew KL divergence, striking a balance between forward and reverse KL, and combined it with an adaptive off-policy training strategy to achieve 4.3x training speedup over existing methods.
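Part of the appeal of skewing is numerical: plain forward KL explodes when the student assigns near-zero probability where the teacher does not, while mixing a fraction of the first distribution into the second keeps the divergence bounded. The mixing form below is one common parameterization of skew KL, used here as an illustrative assumption rather than a transcription of DistiLLM's exact definition:

```python
import torch

def kl(p, q, eps=1e-12):
    return (p * ((p + eps) / (q + eps)).log()).sum().item()

def skew_kl(p, q, alpha=0.1):
    # Mix a fraction alpha of p into the second argument before taking KL
    return kl(p, alpha * p + (1 - alpha) * q)

teacher = torch.tensor([0.5, 0.5, 0.0])
student = torch.tensor([0.5, 0.0, 0.5])   # zero mass where the teacher has 0.5

print("plain KL(p||q):", round(kl(teacher, student), 2))       # dominated by log(1/eps)
print("skew  KL, a=.1:", round(skew_kl(teacher, student), 2))  # stays bounded
```

DistiLLM combines skewed variants of both forward and reverse KL; the point illustrated here is only the boundedness that makes the objective easier to optimize.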
NVIDIA's Minitron[9] (NeurIPS 2024) demonstrated the powerful combination of pruning and distillation: first, structured pruning (removing depth, width, attention heads, and MLP channels) carves out 8B/4B skeletons from a 15B model, then knowledge distillation restores quality. The results were impressive — training tokens were only 1/40 of training from scratch, with MMLU benchmark improvements of up to 16%. This proved that distillation is not just a compression tool, but an efficient model "cultivation" strategy.
2.4 Reasoning Distillation: DeepSeek-R1's Paradigm Breakthrough
In early 2025, DeepSeek-R1[10] pushed knowledge distillation into an entirely new dimension: distillation of reasoning capabilities. Traditional distillation transfers "answers"; DeepSeek-R1 transfers "thinking processes."
The research team generated 800K training samples containing complete chain-of-thought reasoning from DeepSeek-R1 (a large model that acquired reasoning capabilities through reinforcement learning), and used these samples to fine-tune 6 open-source models (based on Qwen2.5 and Llama3, ranging from 1.5B to 70B). The results stunned the entire community:
- The distilled 14B model surpassed QwQ-32B-Preview in mathematical reasoning — a model twice its size
- For small models (14B and below), distillation significantly outperformed directly applying reinforcement learning on the small model
- Even the 1.5B distilled version demonstrated rudimentary reasoning and reflection capabilities
The implications of DeepSeek-R1 are profound: knowledge distillation is no longer just "compression" — it is a capability inheritance mechanism that can "pass down" advanced capabilities acquired through expensive training (such as RL) from large models to small models at minimal cost.
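Mechanically, this kind of reasoning distillation is plain supervised fine-tuning on teacher-generated traces. A minimal sketch of preparing one such sample follows; the `<think>` tag layout and the worked example are illustrative assumptions, not DeepSeek-R1's published data schema:

```python
def build_cot_sample(question: str, reasoning: str, answer: str) -> str:
    """Concatenate a problem, the teacher's full chain of thought, and the
    final answer into one training string for ordinary causal-LM SFT."""
    return (
        f"Question: {question}\n"
        f"<think>\n{reasoning}\n</think>\n"
        f"Answer: {answer}"
    )

sample = build_cot_sample(
    question="What is 17 * 24?",
    reasoning="17 * 24 = 17 * 20 + 17 * 4 = 340 + 68 = 408.",
    answer="408",
)
print(sample)
```

A corpus of such strings (800K samples in DeepSeek-R1's case) is then fed through a standard SFT pipeline; notably, no RL is applied to the student at all.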
3. Empirical Data: A Comprehensive Overview of Distillation Compression Results
| Model | Technique | Compression | Quality Retention | Source |
|---|---|---|---|---|
| BERT → DistilBERT | Triple loss distillation | 40% smaller, 60% faster | Retains 97% GLUE score | Sanh et al., 2019 |
| BERT → TinyBERT | Two-stage distillation | 7.5x smaller, 9.4x faster | Retains 96.8% performance | Jiao et al., 2020 |
| Nemotron 15B → 8B/4B | Pruning + distillation | 1/40 training tokens | MMLU +16% (vs. training from scratch) | Muralidharan et al., 2024 |
| GPT-2/LLaMA series | MiniLLM (reverse KL) | Full-scale distillation 120M–13B | Outperforms standard KL distillation | Gu et al., 2024 |
| DeepSeek-R1 → 14B | Chain-of-thought distillation | 14B student | Math reasoning surpasses QwQ-32B | DeepSeek-AI, 2025 |
| SD v1.4 → BK-SDM | Block pruning + distillation | 30-50% parameter reduction | FID on par or better | Kim et al., 2024 |
| SDXL → SDXL Turbo | Adversarial distillation (ADD) | 50 steps → 1 step | Single-step surpasses LCM/GANs | Sauer et al., 2023 |
| SD → LCM | Latent consistency distillation | 50 steps → 2-4 steps | Only 32 A100 hours training | Luo et al., 2023 |
| Flux.1-pro → schnell | LADD (latent adversarial distillation) | 20-50 steps → 1-4 steps | High quality, Apache 2.0 open source | Sauer et al., 2024 |
4. Decision Framework: Benefits, Costs, and Boundaries of Distillation
Knowledge distillation, pruning, and quantization are the three pillars of model compression, but they each apply to different scenarios. Before deciding whether to adopt distillation, it is important to understand its unique advantages and limitations:
| Dimension | Using the Teacher Model Directly | Distilled Student Model |
|---|---|---|
| Model Size | Full parameters (e.g., BERT: 110M, LLaMA: 70B) | Freely designable student architecture; typical 2-10x compression |
| Inference Speed | Baseline speed | Structural speedup (not sparsity-dependent): 2-9x faster |
| Accuracy | Full accuracy | Retains 95-99% of teacher capability (task-dependent) |
| Training Cost | Teacher requires full training | Student training cost far lower than training from scratch (Minitron: 1/40 tokens) |
| Architecture Flexibility | Fixed architecture | Student can use a completely different architecture design |
| Deployment Flexibility | High-end GPUs only | Deployable to CPUs, mobile devices, edge devices |
Strategic Advantages
- Architecture freedom: Unlike pruning, the distilled student model can use a completely different architecture. You can distill a Transformer teacher into a convolutional neural network student, or vice versa. This means deployment-side architecture choices are not constrained by the teacher model
- Depth of capability transfer: Distillation does not just compress models — it transfers "knowledge." DeepSeek-R1's reasoning distillation proves that even advanced reasoning capabilities acquired through RL can be distilled into small models
- Composable with other techniques: Minitron demonstrated the combined power of pruning + distillation, and quantization can further compress models after distillation. The combined effect of all three far exceeds any single technique
- One teacher, multiple students: A single powerful teacher can simultaneously distill multiple students at different scales, tailored for different deployment scenarios (cloud, edge, mobile)
Risks to Manage
- Higher training cost than pruning: Distillation requires retraining the student model, whereas pruning (especially SparseGPT/Wanda) can be completed in a single pass. For LLM scenarios, distillation training costs can be significant
- Requires teacher inference resources: During training, the teacher model must be continuously run to generate soft targets, increasing GPU memory and compute requirements
- Capacity gap problem: If the student model is too small, it may not fully absorb the teacher's knowledge. The student-to-teacher size ratio must be carefully tuned
- Task specificity: Distilled student models typically excel on specific tasks, but their generalization capability may not match the teacher's. If downstream tasks change frequently, this should be considered
- Data dependency: Distillation quality is highly dependent on the quality and coverage of training data. If the distillation data distribution does not match the actual deployment scenario, student model performance will suffer
Distillation vs. Pruning: How to Choose?
| Scenario | Recommended Technique | Rationale |
|---|---|---|
| Quickly compress an existing LLM | Pruning (Wanda/SparseGPT) | No retraining needed, completed in hours |
| Need to change model architecture | Distillation | Student can use a completely different architecture |
| Pursue optimal compression | Pruning + distillation (Minitron) | Prune first to create the skeleton, then distill to restore quality |
| Transfer advanced reasoning capabilities | Reasoning distillation (DeepSeek-R1 style) | Chain-of-thought distillation is currently the best-validated approach |
| Diffusion model step compression | Distillation (LCM/ADD/LADD) | Step distillation is the primary method for reducing generation steps |
| Limited budget, need quick validation | Pruning | Distillation requires additional training resources and time |
5. Hands-on Lab: Google Colab Workshop (CV Model Distillation)
Let us start with the most classic scenario: using a large ResNet-34 as the teacher and distilling it into a smaller ResNet-18 student. You will see how distillation enables the student to surpass the performance ceiling of "direct training." All code can be run directly on Google Colab's free GPU.
Open Google Colab, create a new Notebook, and paste the following code blocks in sequence:
5.1 Step 1 — Train the Teacher Model ResNet-34 (~5 minutes)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import time, copy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# ---- Dataset ----
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
shuffle=False, num_workers=2)
# ---- Teacher model: ResNet-34 (adapted for CIFAR-10's 32x32 input) ----
teacher = models.resnet34(weights=None, num_classes=10)
teacher.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
teacher.maxpool = nn.Identity()
teacher = teacher.to(device)
# ---- Train teacher for 15 epochs ----
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)
print("Training teacher model ResNet-34...")
for epoch in range(15):
teacher.train()
for inputs, targets in trainloader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = teacher(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
scheduler.step()
if (epoch + 1) % 5 == 0:
print(f" Epoch {epoch+1}/15 complete")
teacher.eval()
print("Teacher model training complete")
5.2 Step 2 — Evaluation Utility Functions
def evaluate(model, dataloader, device):
"""Calculate test set accuracy"""
model.eval()
correct, total = 0, 0
with torch.no_grad():
for inputs, targets in dataloader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return 100. * correct / total
def count_params(model):
"""Count model parameters"""
return sum(p.numel() for p in model.parameters())
def measure_speed(model, device, input_size=(1, 3, 32, 32), n_runs=200):
"""Measure inference latency (ms)"""
model.eval()
dummy = torch.randn(*input_size).to(device)
for _ in range(50):
with torch.no_grad():
model(dummy)
if device.type == 'cuda':
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_runs):
with torch.no_grad():
model(dummy)
if device.type == 'cuda':
torch.cuda.synchronize()
return (time.perf_counter() - start) / n_runs * 1000
# ---- Teacher baseline ----
teacher_acc = evaluate(teacher, testloader, device)
teacher_params = count_params(teacher)
teacher_speed = measure_speed(teacher, device)
print(f"{'='*55}")
print(f" Teacher Model ResNet-34")
print(f"{'='*55}")
print(f" Accuracy: {teacher_acc:.2f}%")
print(f" Parameters: {teacher_params:,}")
print(f" Latency: {teacher_speed:.2f} ms")
print(f"{'='*55}")
5.3 Step 3 — Knowledge Distillation Core: Training the Student Model
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
"""
Knowledge Distillation Loss Function
- T: Temperature parameter (higher = smoother soft targets, more prominent dark knowledge)
- alpha: Weight of distillation loss (1-alpha is the weight of hard label loss)
"""
# Soft target distillation loss (KL divergence)
soft_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reduction='batchmean'
) * (T * T) # Multiply by T^2 to compensate for gradient scaling
# Hard label cross-entropy loss
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
# ---- Create student model: ResNet-18 (~50% smaller than teacher) ----
student_distilled = models.resnet18(weights=None, num_classes=10)
student_distilled.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
student_distilled.maxpool = nn.Identity()
student_distilled = student_distilled.to(device)
# ---- Distillation training for 15 epochs ----
optimizer_kd = optim.SGD(student_distilled.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
scheduler_kd = optim.lr_scheduler.CosineAnnealingLR(optimizer_kd, T_max=15)
print("Training student model ResNet-18 with knowledge distillation...")
for epoch in range(15):
student_distilled.train()
teacher.eval()
for inputs, targets in trainloader:
inputs, targets = inputs.to(device), targets.to(device)
# Teacher inference (no gradient computation)
with torch.no_grad():
teacher_logits = teacher(inputs)
# Student inference
student_logits = student_distilled(inputs)
# Distillation loss
loss = distillation_loss(
student_logits, teacher_logits, targets,
T=4.0, # Temperature 4 (makes dark knowledge more prominent)
alpha=0.7 # 70% distillation loss + 30% hard label loss
)
optimizer_kd.zero_grad()
loss.backward()
optimizer_kd.step()
scheduler_kd.step()
if (epoch + 1) % 5 == 0:
print(f" Epoch {epoch+1}/15 complete")
print("Distillation training complete")
5.4 Step 4 — Control Group: Training the Student Directly (No Distillation)
# Same ResNet-18, but trained directly with hard labels — no distillation
student_baseline = models.resnet18(weights=None, num_classes=10)
student_baseline.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
student_baseline.maxpool = nn.Identity()
student_baseline = student_baseline.to(device)
optimizer_bl = optim.SGD(student_baseline.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
scheduler_bl = optim.lr_scheduler.CosineAnnealingLR(optimizer_bl, T_max=15)
print("Training control group ResNet-18 directly (no distillation)...")
for epoch in range(15):
student_baseline.train()
for inputs, targets in trainloader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer_bl.zero_grad()
outputs = student_baseline(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer_bl.step()
scheduler_bl.step()
if (epoch + 1) % 5 == 0:
print(f" Epoch {epoch+1}/15 complete")
print("Control group training complete")
5.5 Step 5 — Full Comparison: Distillation vs. Direct Training vs. Teacher
# Evaluate all models
distilled_acc = evaluate(student_distilled, testloader, device)
baseline_acc = evaluate(student_baseline, testloader, device)
student_params = count_params(student_distilled)
distilled_speed = measure_speed(student_distilled, device)
baseline_speed = measure_speed(student_baseline, device)
print(f"\n{'='*70}")
print(f" Knowledge Distillation Results Comparison (CIFAR-10)")
print(f"{'='*70}")
print(f"{'Model':<26} {'Accuracy':>8} {'Parameters':>14} {'Latency(ms)':>11} {'Note':>10}")
print(f"{'-'*70}")
print(f"{'Teacher ResNet-34':<26} {teacher_acc:>7.2f}% {teacher_params:>13,} "
f"{teacher_speed:>10.2f} {'Upper bound'}")
print(f"{'Student Direct Train':<26} {baseline_acc:>7.2f}% {student_params:>13,} "
f"{baseline_speed:>10.2f} {'No distill'}")
print(f"{'Student Distilled':<26} {distilled_acc:>7.2f}% {student_params:>13,} "
f"{distilled_speed:>10.2f} {'T=4, a=0.7'}")
print(f"{'-'*70}")
improvement = distilled_acc - baseline_acc
gap_closed = (distilled_acc - baseline_acc) / (teacher_acc - baseline_acc) * 100 \
if teacher_acc > baseline_acc else 0
print(f"\nKey Findings:")
print(f" - Distilled student vs. direct training: {improvement:+.2f}% accuracy gain")
print(f" - Distillation closed {gap_closed:.0f}% of the teacher-student gap")
print(f" - Student has only {student_params/teacher_params*100:.0f}% of teacher's parameters, "
f"yet achieves performance much closer to the teacher through distillation")
print(f" - Inference speed is nearly identical (same student architecture), "
f"but the distilled version is more accurate")
5.6 Step 6 — Exploring the Effect of the Temperature Parameter
# Test distillation results at different temperatures
temperatures = [1, 2, 4, 8, 16]
temp_results = []
for T in temperatures:
student_t = models.resnet18(weights=None, num_classes=10)
student_t.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
student_t.maxpool = nn.Identity()
student_t = student_t.to(device)
opt = optim.SGD(student_t.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=15)
for epoch in range(15):
student_t.train()
teacher.eval()
for inputs, targets in trainloader:
inputs, targets = inputs.to(device), targets.to(device)
with torch.no_grad():
teacher_logits = teacher(inputs)
student_logits = student_t(inputs)
loss = distillation_loss(student_logits, teacher_logits, targets, T=T, alpha=0.7)
opt.zero_grad()
loss.backward()
opt.step()
sch.step()
acc = evaluate(student_t, testloader, device)
temp_results.append({'T': T, 'acc': acc})
print(f" T={T:<3d} -> Accuracy: {acc:.2f}%")
del student_t
if device.type == 'cuda':
torch.cuda.empty_cache()
print(f"\n{'='*50}")
print(f" Effect of Temperature T on Distillation")
print(f"{'='*50}")
print(f"{'Temp T':>8} {'Accuracy':>10} {'vs. Direct':>14}")
print(f"{'-'*50}")
for r in temp_results:
delta = r['acc'] - baseline_acc
print(f"{r['T']:>8d} {r['acc']:>9.2f}% {delta:>+13.2f}%")
print(f"{'-'*50}")
print(f"{'Direct':<8} {baseline_acc:>9.2f}% {'(baseline)':>13}")
print(f"\n-> T=3~8 typically works best. Too low: insufficient dark knowledge. Too high: signal over-smoothed.")
What you will observe firsthand: The distilled ResNet-18 will be approximately 0.5-2% more accurate than the directly trained ResNet-18. This may seem small, but remember that both models have identical architectures — the only difference is the "dark knowledge" imparted through distillation. In academic papers, improvements of this magnitude often represent months of research effort.
6. Hands-on Lab: LLM Knowledge Distillation Workshop (Language Models)
The CV distillation lab demonstrated the fundamental principles. Now let us work with language models — using GPT-2 Medium (345M) as the teacher, distilling it into GPT-2 Small (124M), all achievable on Google Colab's free GPU.
Open Google Colab, create a new Notebook, and paste the following code blocks in sequence:
6.1 Installation and Model Loading
!pip install transformers datasets accelerate -q
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from datasets import load_dataset
import time, copy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Load teacher model: GPT-2 Medium (345M parameters)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token
teacher = GPT2LMHeadModel.from_pretrained("gpt2-medium").to(device)
teacher.eval()
# Load student model: GPT-2 Small (124M parameters)
student = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher GPT-2 Medium: {teacher_params:,} parameters")
print(f"Student GPT-2 Small: {student_params:,} parameters")
print(f"Compression ratio: {teacher_params/student_params:.1f}x")
6.2 Preparing Training Data
# Use WikiText-2 as the distillation corpus
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
# Filter empty lines and take a subset (free Colab friendly)
texts = [t for t in dataset["text"] if len(t.strip()) > 50][:2000]
print(f"Using {len(texts)} texts for distillation training")
# Create DataLoader
from torch.utils.data import DataLoader, TensorDataset
encodings = tokenizer(texts, truncation=True, max_length=128,
padding="max_length", return_tensors="pt")
train_dataset = TensorDataset(encodings["input_ids"], encodings["attention_mask"])
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
print(f"Data preparation complete, {len(train_loader)} batches total")
6.3 Defining the Evaluation Function
def measure_perplexity(model, texts, tokenizer, device, max_length=128):
"""Compute perplexity on a set of texts"""
model.eval()
total_loss, total_tokens = 0, 0
eval_texts = texts[:200] # Use subset for faster evaluation
for text in eval_texts:
inputs = tokenizer(text, return_tensors="pt", truncation=True,
max_length=max_length).to(device)
if inputs["input_ids"].size(1) < 2:
continue
with torch.no_grad():
outputs = model(**inputs, labels=inputs["input_ids"])
total_loss += outputs.loss.item() * inputs["input_ids"].size(1)
total_tokens += inputs["input_ids"].size(1)
return torch.exp(torch.tensor(total_loss / total_tokens)).item() if total_tokens > 0 else float('inf')
def generate_text(model, prompt, max_new_tokens=60):
"""Generate text to observe quality"""
model.eval()
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model.generate(
**inputs, max_new_tokens=max_new_tokens,
do_sample=True, temperature=0.7, top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Evaluation texts
eval_texts = [t for t in dataset["text"] if len(t.strip()) > 100][:200]
test_prompts = [
"The future of artificial intelligence",
"Knowledge distillation is a technique that",
"In machine learning, the relationship between",
]
# Record baselines
teacher_ppl = measure_perplexity(teacher, eval_texts, tokenizer, device)
student_ppl_before = measure_perplexity(student, eval_texts, tokenizer, device)
print(f"Teacher PPL: {teacher_ppl:.2f}")
print(f"Student PPL (before distillation): {student_ppl_before:.2f}")
6.4 Knowledge Distillation Training
def lm_distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
    """
    Language Model Distillation Loss
    - Shift logits so that position t predicts token t+1
    - Flatten to (batch*seq, vocab) before kl_div so that 'batchmean'
      averages per token, matching the scale of cross_entropy's mean
    """
    # Shift: predict next token
    shift_student = student_logits[:, :-1, :].contiguous()
    shift_teacher = teacher_logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    vocab_size = shift_student.size(-1)
    # Soft target distillation loss (per-token KL divergence)
    soft_loss = F.kl_div(
        F.log_softmax(shift_student / T, dim=-1).view(-1, vocab_size),
        F.softmax(shift_teacher / T, dim=-1).view(-1, vocab_size),
        reduction='batchmean'
    ) * (T * T)  # Multiply by T^2 to compensate for gradient scaling
    # Hard label cross-entropy loss
    hard_loss = F.cross_entropy(
        shift_student.view(-1, vocab_size),
        shift_labels.view(-1),
        ignore_index=tokenizer.pad_token_id  # pad == eos here, so EOS positions are masked too
    )
    return alpha * soft_loss + (1 - alpha) * hard_loss
# ---- Distillation training ----
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5, weight_decay=0.01)
T = 3.0 # Temperature
alpha = 0.5 # Distillation weight
print("Starting LLM knowledge distillation...")
print(f" Temperature T={T}, Distillation weight alpha={alpha}")
print(f" Teacher: GPT-2 Medium (345M), Student: GPT-2 Small (124M)\n")
student.train()
for epoch in range(3):
total_loss = 0
for batch_idx, (input_ids, attention_mask) in enumerate(train_loader):
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Teacher inference
with torch.no_grad():
teacher_outputs = teacher(input_ids=input_ids,
attention_mask=attention_mask)
# Student inference
student_outputs = student(input_ids=input_ids,
attention_mask=attention_mask)
# Distillation loss
loss = lm_distillation_loss(
student_outputs.logits, teacher_outputs.logits,
input_ids, T=T, alpha=alpha
)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
if (batch_idx + 1) % 50 == 0:
print(f" Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, "
f"Loss: {loss.item():.4f}")
avg_loss = total_loss / len(train_loader)
print(f" Epoch {epoch+1}/3 complete, Avg Loss: {avg_loss:.4f}\n")
print("LLM distillation training complete")
6.5 Results Comparison
student_ppl_after = measure_perplexity(student, eval_texts, tokenizer, device)
print(f"\n{'='*65}")
print(f" GPT-2 Knowledge Distillation Results")
print(f"{'='*65}")
print(f"{'Model':<30} {'Perplexity':>12} {'Parameters':>14}")
print(f"{'-'*65}")
print(f"{'Teacher GPT-2 Medium':<30} {teacher_ppl:>11.2f} {teacher_params:>13,}")
print(f"{'Student (before distill)':<30} {student_ppl_before:>11.2f} {student_params:>13,}")
print(f"{'Student (after distill)':<30} {student_ppl_after:>11.2f} {student_params:>13,}")
print(f"{'='*65}")
ppl_improvement = student_ppl_before - student_ppl_after
gap_closed = (student_ppl_before - student_ppl_after) / \
(student_ppl_before - teacher_ppl) * 100 if student_ppl_before > teacher_ppl else 0
print(f"\nKey Findings:")
print(f" - Perplexity reduced by: {ppl_improvement:.2f} (lower is better)")
print(f" - Teacher-student gap closed: {gap_closed:.1f}%")
print(f" - Compression ratio: {teacher_params/student_params:.1f}x (parameters)")
print(f"\n{'='*65}")
print(f" Generation Quality Comparison")
print(f"{'='*65}")
for p in test_prompts:
    print(f"\n  Prompt: {p}")
    print(f"  Teacher: {generate_text(teacher, p, max_new_tokens=40)}")
    print(f"  Student: {generate_text(student, p, max_new_tokens=40)}")
6.6 Advanced: Using HuggingFace TRL's GKD Trainer
The demo above uses basic KL divergence distillation. For larger-scale LLM distillation, HuggingFace's TRL library[11] provides an out-of-the-box GKDTrainer that implements Google DeepMind's GKD paper[7]:
# pip install trl
from trl import GKDConfig, GKDTrainer

# GKD training configuration
training_args = GKDConfig(
    output_dir="./gkd-output",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    learning_rate=5e-5,
    lmbda=0.5,           # Fraction of on-policy (student-generated) data: 0 = pure off-policy, 1 = pure on-policy
    beta=0.5,            # Interpolation of the generalized JSD: 0 = forward KL, 1 = reverse KL
    temperature=3.0,     # Distillation temperature
    max_new_tokens=128,  # Maximum tokens for student generation
)
# Initialize the GKD trainer
trainer = GKDTrainer(
    model=student_model,          # Student model
    teacher_model=teacher_model,  # Teacher model
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
# Start on-policy distillation
trainer.train()
# Core advantages of GKD:
# 1. Student trains on its own generated sequences (on-policy)
# 2. Eliminates train-inference distribution mismatch
# 3. Supports multiple divergence measures (forward KL, reverse KL, JSD)
GKD has been validated as one of the best practices for LLM distillation in the Gemini team's internal testing, and is particularly well-suited for scenarios requiring long sequence generation (such as summarization, translation, and code generation).
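The divergence options mentioned above can be made concrete. Below is a minimal PyTorch sketch of a generalized Jensen-Shannon divergence of the kind GKD interpolates between forward and reverse KL; the mixture convention used here is one reasonable choice and may differ in detail from TRL's internal implementation:

```python
import math
import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5, T=1.0):
    """Generalized JSD between student and teacher token distributions.

    beta=0 gives forward KL (teacher as target), beta=1 gives reverse KL;
    intermediate values compare both against their beta-weighted mixture.
    Schematic convention, not necessarily identical to TRL's internals.
    """
    s_logp = F.log_softmax(student_logits / T, dim=-1)
    t_logp = F.log_softmax(teacher_logits / T, dim=-1)
    if beta == 0.0:
        # KL(teacher || student): input is student log-probs, target is teacher
        return F.kl_div(s_logp, t_logp, log_target=True, reduction="batchmean")
    if beta == 1.0:
        # KL(student || teacher)
        return F.kl_div(t_logp, s_logp, log_target=True, reduction="batchmean")
    # Log of the mixture m = beta * p_student + (1 - beta) * p_teacher
    m_logp = torch.logsumexp(
        torch.stack([s_logp + math.log(beta), t_logp + math.log(1.0 - beta)]),
        dim=0,
    )
    return (beta * F.kl_div(m_logp, s_logp, log_target=True, reduction="batchmean")
            + (1.0 - beta) * F.kl_div(m_logp, t_logp, log_target=True, reduction="batchmean"))
```

Forward KL tends to make the student cover all teacher modes (and blur them), while reverse KL makes it commit to the modes it can actually model; the `beta` knob trades off between the two.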
7. Diffusion Model Distillation: Compressing Image Generation from 50 Steps to 1
The impact of knowledge distillation on image generation is even more dramatic than in NLP. The core bottleneck of diffusion models (Stable Diffusion, FLUX) is the number of sampling steps: each image requires 20-50 iterative denoising passes. Distillation is attacking this bottleneck at its root.
7.1 Progressive Distillation: The Chain Reaction of Step Halving
Google's Salimans and Ho published Progressive Distillation[12] at ICLR 2022, pioneering diffusion model distillation. The core idea is remarkably intuitive: train a student model to accomplish in 1 step what the teacher does in 2 steps. Repeat this process N times, and the step count drops from 2^N to 1. On CIFAR-10 and ImageNet 64x64, they successfully compressed 8,192 steps down to 4.
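The halving chain can be sketched as a training skeleton. Here `train_round` is a hypothetical stand-in for one full distillation round (it would train the student to match two teacher steps with one of its own); only the outer loop reflects the paper's procedure:

```python
import copy

def progressive_distillation(teacher, train_round, initial_steps=8192, target_steps=4):
    """Halve the sampling step count round by round.

    Each round, a student initialized from the teacher learns to cover two
    teacher steps in one; it then becomes the teacher for the next round.
    `train_round(teacher, student, num_student_steps)` is an illustrative
    placeholder, assumed to train the student in place.
    """
    steps = initial_steps
    while steps > target_steps:
        student = copy.deepcopy(teacher)           # student starts from teacher weights
        train_round(teacher, student, steps // 2)  # learn: 1 student step = 2 teacher steps
        teacher = student                          # promote student to teacher
        steps //= 2
    return teacher, steps

# With dummy stand-ins, 8192 steps collapse to 4 over 11 halvings:
schedule = []
_, final_steps = progressive_distillation(object(), lambda t, s, n: schedule.append(n))
# final_steps == 4; schedule == [4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4]
```

Because each round only asks the student to bridge a factor of two, the regression target stays easy to fit, which is why the chain can be repeated so many times without collapse.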
7.2 LCM: High-Resolution Image Generation in 2-4 Steps
Latent Consistency Models (LCM)[13] brought distillation into latent space. Rather than directly distilling denoising steps, LCM trains the model to directly predict ODE (ordinary differential equation) solutions — skipping intermediate steps entirely. Training requires only 32 A100 GPU hours (approximately $100 in cloud costs) to generate high-quality 768x768 images.
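At training time, "skipping intermediate steps" reduces to a self-consistency constraint: any two points on the same probability-flow ODE trajectory must map to the same solution. A schematic PyTorch sketch with hypothetical names (`f_student` for the model being trained, `f_ema` for its EMA target copy), not the LCM codebase:

```python
import torch
import torch.nn.functional as F

def consistency_loss(f_student, f_ema, x_t, x_t_prev, t, t_prev):
    """Consistency distillation objective (schematic).

    x_t and x_t_prev are assumed to lie on the same probability-flow ODE
    trajectory (x_t_prev obtained from x_t by one solver step guided by the
    teacher). The student must map both points to the same trajectory
    endpoint; the target comes from an EMA copy of the student, held fixed.
    """
    with torch.no_grad():
        target = f_ema(x_t_prev, t_prev)  # EMA network provides the target
    return F.mse_loss(f_student(x_t, t), target)
```

Once this constraint holds everywhere along the trajectory, a single evaluation of the model jumps directly from noise to the ODE solution, which is what enables 2-4 step sampling.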
More importantly, the LCM team simultaneously released LCM-LoRA — a lightweight adapter that can be plugged into any SD-based model. This means all community-trained custom models (DreamBooth, LoRA fine-tuned versions, etc.) can immediately gain 2-4 step acceleration:
# LCM-LoRA: enable 2-4 step generation for any Stable Diffusion model
# pip install diffusers transformers accelerate
from diffusers import DiffusionPipeline, LCMScheduler
import torch

# Load any SD model
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Plug-and-play LCM-LoRA
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Only 4 steps needed for high-quality image generation
image = pipe(
    prompt="A futuristic cityscape at sunset, photorealistic",
    num_inference_steps=4,  # Originally requires 25-50 steps
    guidance_scale=1.5,     # LCM uses lower guidance
).images[0]
image.save("lcm_result.png")
print("4-step SDXL image generation complete!")
7.3 SDXL Turbo: Real-Time Single-Step Generation
Stability AI's Adversarial Diffusion Distillation (ADD)[14] pushed step compression to its limit: single-step generation. ADD cleverly combines two losses:
- Score distillation loss: The student's one-step output is re-noised and passed through the teacher, whose denoising prediction then serves as the student's regression target, transferring the teacher's score function
- Adversarial loss: Adds a discriminator to ensure generated images are visually realistic
The result is that SDXL Turbo can generate 512x512 images in a single step, surpassing multi-step LCM and traditional GANs in quality. While SDXL Turbo's weights are released only under a non-commercial research license, the ADD approach itself has been successfully extended to larger scales by Flux.1-schnell's LADD method.
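The two-term objective can be sketched in a few lines of PyTorch. This is a schematic combination under assumed names (`student_img`, `teacher_denoised`, `disc_logits_fake`) with an illustrative weight, not Stability AI's implementation:

```python
import torch
import torch.nn.functional as F

def add_losses(student_img, teacher_denoised, disc_logits_fake, lam=2.5):
    """Schematic ADD-style objective (illustrative, not the paper's code).

    - Distillation term: pull the student's one-step sample toward the
      teacher's denoised estimate of the same (re-noised) sample.
    - Adversarial term: non-saturating GAN loss from a discriminator that
      scores the realism of the student's sample.
    lam is an illustrative weight, not the paper's exact setting.
    """
    distill = F.mse_loss(student_img, teacher_denoised.detach())
    adv = F.softplus(-disc_logits_fake).mean()  # equals -log sigmoid(D(fake))
    return adv + lam * distill
```

The distillation term keeps the student faithful to the teacher's distribution, while the adversarial term repairs the blur that a pure regression target would otherwise introduce at one step.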
7.4 Flux.1-schnell: LADD Latent Adversarial Distillation
Black Forest Labs' Flux.1-schnell is currently one of the highest-quality fast generation models in the open-source community. It uses LADD (Latent Adversarial Diffusion Distillation)[15] — an evolution of ADD:
- Operates in latent space: No need to decode to pixel space for adversarial loss computation, making training more stable and efficient
- Supports high resolution: Can directly generate 1024x1024 and higher resolution images at multiple aspect ratios
- 1-4 step generation: Distilled from Flux.1-pro with virtually no quality loss
- Apache 2.0 open source: Commercially usable, and has become a foundational model for the community
# Using Flux.1-schnell (a distilled model)
from diffusers import FluxPipeline
import torch

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
).to("cuda")

# Only 4 steps!
image = pipe(
    prompt="A serene Japanese garden with cherry blossoms",
    num_inference_steps=4,
    guidance_scale=0.0,  # schnell does not need classifier-free guidance
).images[0]
image.save("flux_schnell_result.png")
print("Flux.1-schnell: 4-step generation complete")
7.5 BK-SDM and SnapFusion: Architecture Distillation
Beyond step distillation, another approach is architecture distillation — making the model itself smaller.
BK-SDM[16] (ECCV 2024) removed multiple residual and attention blocks from Stable Diffusion v1.4's U-Net, then used feature distillation to restore quality. The result was a 30-50% parameter reduction with FID scores that were actually better, requiring only 13 A100-days of training — 460x cheaper than the original SD's 6,000+ A100-days.
SnapFusion[17] (NeurIPS 2023) performed both architecture distillation and step distillation simultaneously, ultimately generating 512x512 images on a mobile phone in under 2 seconds, compressing 50 steps down to 8.
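The feature ("hint") distillation that BK-SDM relies on reduces to matching intermediate activations of the teacher and the block-pruned student at corresponding locations. A generic sketch with illustrative names and uniform weighting, not the paper's exact configuration:

```python
import torch
import torch.nn.functional as F

def feature_distillation_loss(student_feats, teacher_feats, out_s, out_t, w_feat=1.0):
    """Schematic BK-SDM-style loss: output distillation plus per-block hints.

    student_feats / teacher_feats are lists of intermediate activations
    captured (e.g. via forward hooks) at matching U-Net blocks; out_s / out_t
    are the final noise predictions. Names and the uniform w_feat weighting
    are illustrative assumptions.
    """
    loss = F.mse_loss(out_s, out_t)  # match the teacher's predicted noise
    for fs, ft in zip(student_feats, teacher_feats):
        loss = loss + w_feat * F.mse_loss(fs, ft)  # match intermediate features
    return loss
```

The per-block hints are what let the pruned student recover quality quickly: they supervise exactly the layers whose neighbors were removed, instead of relying on the output loss alone to propagate that signal.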
| Method | Venue | Distillation Type | Step Compression | Architecture Compression | Training Cost |
|---|---|---|---|---|---|
| Progressive Distillation | ICLR 2022 | Step distillation | 8192 → 4 steps | -- | Medium |
| LCM / LCM-LoRA | 2023 | Consistency distillation | 50 → 2-4 steps | -- | 32 A100 hrs |
| ADD (SDXL Turbo) | ECCV 2024 | Adversarial distillation | 50 → 1 step | -- | High |
| LADD (Flux.1-schnell) | SIGGRAPH Asia 2024 | Latent adversarial distillation | 20-50 → 1-4 steps | -- | High |
| BK-SDM | ECCV 2024 | Feature distillation | -- | 30-50% parameter reduction | 13 A100-days |
| SnapFusion | NeurIPS 2023 | Architecture + step distillation | 50 → 8 steps | Architecture streamlining | Medium |
8. Ecosystem Tools: The Full Landscape
From academic implementations to enterprise-grade platforms, the knowledge distillation tooling ecosystem now covers the complete technology stack:
Foundational Frameworks
- PyTorch Knowledge Distillation Tutorial[18]: Official tutorial with complete Colab notebooks for three distillation strategies. Ideal for getting started and proof-of-concept
- HuggingFace TRL — GKDTrainer[11] (Documentation): Out-of-the-box GKD on-policy distillation supporting forward KL, reverse KL, JSD, and other divergence measures
- HuggingFace Transformers — distilbert-base-uncased (Model Page): Pre-trained DistilBERT, ready for direct fine-tuning on downstream tasks
LLM Distillation
- MiniLLM (GitHub): ICLR 2024, reverse KL divergence LLM distillation
- DistiLLM (GitHub): ICML 2024, Skew KL + adaptive off-policy method
- NVIDIA Minitron / NeMo (GitHub): NeurIPS 2024, complete pruning + distillation pipeline supporting LLaMA / Mistral
- OpenAI Model Distillation API (Documentation): Cloud workflow for distilling from o1/GPT-4o to GPT-4o-mini
Diffusion Model Distillation
- LCM-LoRA (HuggingFace): Plug-and-play step distillation LoRA, compatible with all SD/SDXL models
- Flux.1-schnell (HuggingFace): Apache 2.0 open source, LADD-distilled 1-4 step Flux model
- BK-SDM (GitHub): ECCV 2024, lightweight Stable Diffusion
- Diffusers Library (Documentation): HuggingFace's diffusion model framework with built-in distillation components like LCMScheduler
General-Purpose Platforms
- Intel Neural Compressor (GitHub): Unified pipeline supporting distillation + pruning + quantization
- NVIDIA ModelOpt (GitHub): Integrates quantization, pruning, distillation, and speculative decoding
9. From Technical Metrics to Business Impact
The business value of knowledge distillation extends beyond "making models smaller" — it fundamentally transforms the deployment economics of AI:
- API cost reduction: OpenAI's Model Distillation API directly distills GPT-4o capabilities into GPT-4o-mini, with API unit price differences of 10x or more. Enterprises can develop with flagship models and deploy with distilled ones
- Inference cost and latency: DistilBERT runs roughly 60% faster than BERT on CPU, enabling a large number of NLP tasks that previously required GPUs to run on CPUs, potentially reducing cloud costs by an order of magnitude
- Democratization of image generation: LCM-LoRA enables any community-fine-tuned SD model to generate in 2-4 steps, transforming the user experience from "waiting" to "instant." Flux.1-schnell's Apache 2.0 open-source license gives small studios access to top-tier generation models
- Edge AI and offline inference: BK-SDM generates images on edge devices in 4 seconds, while SnapFusion produces images on mobile phones in 2 seconds. Distillation frees generative AI from cloud dependency
- Democratization of reasoning capabilities: DeepSeek-R1's distillation experiments prove that small models can also possess deep reasoning capabilities. This is highly significant for education, healthcare, and other domains that require local deployment yet demand advanced reasoning
- Sustainable AI: Smaller models mean lower energy consumption. Harvard Business Review[1] notes that model optimization is one of the most direct levers enterprises have for controlling AI's carbon footprint
10. Adoption Roadmap: A Three-Phase Implementation Strategy
- Identify distillation opportunities: Find the models with the highest inference costs and call frequencies. For NLP classification tasks, prioritize ready-made distilled models like DistilBERT / TinyBERT; for image generation, start with plug-and-play LCM-LoRA acceleration
- Start with existing distilled models: HuggingFace already hosts a large number of pre-distilled models (DistilBERT, DistilGPT-2, LCM-LoRA, Flux.1-schnell). Use these off-the-shelf solutions first to validate whether distillation works for your use case
- Advanced custom distillation: When pre-built models cannot meet your requirements, use GKDTrainer (for LLM scenarios) or Diffusers + LCMScheduler (for image scenarios) to train custom distilled models. If budget allows, consider a Minitron-style pruning + distillation combination pipeline
Knowledge distillation is not a new technique — Hinton laid the groundwork in 2015. But over the past two years, from MiniLLM to DeepSeek-R1, from LCM to Flux.1-schnell, distillation technology has undergone a qualitative leap. It is no longer just "compression" — it is an efficient mechanism for inheriting AI capabilities, channeling the intelligence of top-tier models into every edge device, every API call, and every instantly generated image.
If your team is evaluating model compression strategies or needs to find the optimal balance between cost, latency, and capability, we welcome a deep technical conversation. Meta Intelligence's research team can guide you through the entire journey — from teacher model selection to student model deployment.