- The Transformer[1] encoder-decoder architecture consists of six precisely designed components: multi-head attention, feed-forward networks, residual connections, layer normalization[2], positional encoding, and masking mechanisms
- Three major pre-training paradigms, each with distinct advantages, underpin modern large language models: Masked Language Modeling (BERT[5]), Causal Language Modeling (GPT[4]), and Denoising Reconstruction (T5[7]/BART[8])
- Transformers have expanded from NLP to vision (ViT[11]), multimodal (CLIP[13]), protein folding, robotics, and more, becoming the universal computing primitive of AI
- This article includes two Google Colab hands-on labs: building a Transformer machine translation model from scratch, and ViT image classification fine-tuning with feature visualization
1. Why Transformers Changed Everything
In 2017, when Vaswani et al. proposed the Transformer in "Attention Is All You Need"[1], it was merely a model for machine translation. But within just a few years, this architecture became the universal infrastructure of artificial intelligence — GPT-4 uses it to generate text, DALL-E uses it to generate images, AlphaFold uses it to predict protein structures, and RT-2 uses it to control robots.
Why can the Transformer unify such diverse domains? Because it provides an extremely general computational primitive: allowing a set of elements to dynamically communicate with each other through attention mechanisms. Whether these elements are text tokens, image patches, amino acid residues, or robot actions, the Transformer can learn the relationships between them.
This article examines the Transformer as a complete system: not just the attention mechanism (covered in detail in the Complete Guide to Self-Attention), but all the engineering decisions that make it trainable, scalable, and transferable.
2. Architecture Anatomy: Six Core Components
The original Transformer is an Encoder-Decoder architecture, but its core components appear repeatedly in all variants:
Complete Transformer Architecture:

```
Encoder (×N layers):       Decoder (×N layers):
┌─────────────────────┐    ┌─────────────────────┐
│ Multi-Head          │    │ Masked Multi-Head   │
│ Self-Attention      │    │ Self-Attention      │
│ + Residual + LN     │    │ + Residual + LN     │
├─────────────────────┤    ├─────────────────────┤
│ Feed-Forward        │    │ Cross-Attention     │
│ Network (FFN)       │    │ (Q=Dec, K,V=Enc)    │
│ + Residual + LN     │    │ + Residual + LN     │
└─────────────────────┘    ├─────────────────────┤
                           │ Feed-Forward        │
                           │ Network (FFN)       │
                           │ + Residual + LN     │
                           └─────────────────────┘
```
Component 1: Multi-Head Self-Attention
Allows each position to directly attend to all other positions in the sequence. The Encoder uses bidirectional attention; the Decoder uses causal masking (each position can attend only to itself and earlier positions). For detailed mathematical derivations, refer to the Complete Guide to Self-Attention.
Component 2: Feed-Forward Network (FFN)
The attention output at each layer passes through a two-layer fully connected network applied to every position independently (hence "position-wise"). This is the Transformer's "thinking" space:
FFN(x) = max(0, x · W₁ + b₁) · W₂ + b₂
Original config: d_model=512, d_ff=2048 (4x expansion)
Modern config: SwiGLU activation (LLaMA) or GELU (GPT) replacing ReLU
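Interpretability research suggests that FFN layers store much of a model's "factual knowledge," behaving roughly like key-value memories. With the standard 4x expansion, each layer's FFN holds about 8·d_model² weights versus 4·d_model² for the attention projections, so in GPT-style models the FFNs account for roughly two-thirds of the non-embedding parameters and serve as the primary carrier of the model's memorization capacity.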
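A minimal NumPy sketch of the position-wise FFN formula above, alongside a bias-free SwiGLU variant of the kind used in LLaMA, plus a check of the two-thirds figure. The helper names (`ffn_relu`, `ffn_swiglu`, `layer_param_counts`), the random initialization, and the SwiGLU hidden size of 2/3·d_ff (chosen so its three matrices hold about as many weights as the two ReLU matrices) are illustrative assumptions; biases, norms, and embeddings are ignored in the count.

```python
import numpy as np

def ffn_relu(x, W1, b1, W2, b2):
    """Original FFN: max(0, x W1 + b1) W2 + b2, applied to each position independently."""
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

def silu(z):
    return z / (1.0 + np.exp(-z))

def ffn_swiglu(x, W_gate, W_up, W_down):
    """SwiGLU FFN (bias-free): (SiLU(x W_gate) * (x W_up)) W_down."""
    return (silu(x @ W_gate) * (x @ W_up)) @ W_down

def layer_param_counts(d_model, d_ff):
    """Weight-matrix parameters per layer (biases, norms, embeddings ignored)."""
    attn = 4 * d_model * d_model   # W_Q, W_K, W_V, W_O
    ffn = 2 * d_model * d_ff       # W1 and W2
    return attn, ffn

rng = np.random.default_rng(0)
d_model, d_ff, seq_len = 512, 2048, 10
x = rng.standard_normal((seq_len, d_model))
W1 = rng.standard_normal((d_model, d_ff)) * 0.02
W2 = rng.standard_normal((d_ff, d_model)) * 0.02
y = ffn_relu(x, W1, np.zeros(d_ff), W2, np.zeros(d_model))

# SwiGLU with a reduced hidden size (assumed 2/3 * d_ff) to keep the weight count similar
d_hidden = int(2 * d_ff / 3)
W_gate = rng.standard_normal((d_model, d_hidden)) * 0.02
W_up = rng.standard_normal((d_model, d_hidden)) * 0.02
W_down = rng.standard_normal((d_hidden, d_model)) * 0.02
y_swiglu = ffn_swiglu(x, W_gate, W_up, W_down)

attn, ffn = layer_param_counts(d_model, d_ff)
print(y.shape, y_swiglu.shape, ffn / (attn + ffn))  # FFN share of per-layer weights
```

Because each row of `x` is transformed without looking at any other row, the FFN adds no communication between tokens; that job belongs entirely to attention.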
Component 3: Residual Connections
The input to each sub-layer is added to its output: output = sublayer(x) + x. This allows gradients to flow directly across layers and is key to training deep Transformers (100+ layers).
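The effect of the identity path can be checked numerically. The sketch below multiplies per-layer Jacobians through a deep stack; the depth, width, weight scale, and `tanh` sublayer are arbitrary assumptions rather than Transformer settings. For y = x + f(x) each layer's Jacobian is I + ∂f/∂x, so the identity term keeps the end-to-end gradient from shrinking geometrically, while the plain stack y = f(x) has no such term.

```python
import numpy as np

def input_jacobian_norm(depth, residual, d=64, scale=0.5, seed=0):
    """Spectral norm of d(output)/d(input) through `depth` layers of tanh(W x),
    with (x + tanh(W x)) or without (tanh(W x)) a residual connection."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(d)
    J = np.eye(d)                                # accumulated end-to-end Jacobian
    for _ in range(depth):
        W = rng.standard_normal((d, d)) * scale / np.sqrt(d)
        h = np.tanh(W @ x)
        J_layer = (1.0 - h ** 2)[:, None] * W    # Jacobian of tanh(W x) = diag(1 - h^2) @ W
        if residual:
            J_layer = np.eye(d) + J_layer        # identity path: d(x + f(x))/dx = I + df/dx
            x = x + h
        else:
            x = h
        J = J_layer @ J                          # chain rule
    return np.linalg.norm(J, 2)

for residual in (False, True):
    print(f"residual={residual}: |dy/dx| = {input_jacobian_norm(50, residual):.2e}")
```

Without the residual path the gradient reaching the input after 50 layers is vanishingly small; with it, the gradient stays on the order of one.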
Component 4: Layer Normalization
Layer normalization, proposed by Ba et al.[2], normalizes each token's representation to zero mean and unit variance and then applies a learned gain and bias. The original Transformer used Post-LN (normalization after the residual addition), but Xiong et al.[3] showed that Pre-LN (normalization applied to the sublayer input, inside the residual branch) can be trained without learning rate warmup and is more stable. Modern large models almost exclusively use pre-normalization, often with RMSNorm in place of LayerNorm.
| Configuration | Order | Characteristics | Representative Models |
|---|---|---|---|
| Post-LN | Attn → Add → LN | Requires learning rate warmup, but may achieve better final performance | Original Transformer, BERT |
| Pre-LN | LN → Attn → Add | Stable training, no warmup needed | GPT-2, GPT-3 |
| Pre-RMSNorm | RMSNorm → Attn → Add | Rescales by the root mean square only (no mean subtraction), computationally cheaper | LLaMA, PaLM |
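The rows of the table differ in where normalization sits and what it computes. A minimal NumPy sketch, with a hypothetical `tanh` projection standing in for attention or the FFN and the learnable gains fixed to ones for clarity:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the feature axis: subtract mean, divide by std, then scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm: divide by the root mean square only (no mean subtraction, no shift)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return gamma * x / rms

def post_ln_block(x, sublayer, norm):
    # Original Transformer / BERT: normalize AFTER the residual add
    return norm(x + sublayer(x))

def pre_ln_block(x, sublayer, norm):
    # GPT-2/GPT-3 style: normalize the sublayer INPUT; the residual path is untouched
    return x + sublayer(norm(x))

rng = np.random.default_rng(0)
d = 16
x = rng.standard_normal((4, d)) * 3.0 + 1.0
W = rng.standard_normal((d, d)) / np.sqrt(d)

def sublayer(h):
    return np.tanh(h @ W)          # stand-in for attention or the FFN

def ln(h):
    return layer_norm(h, np.ones(d), np.zeros(d))

def rms(h):
    return rms_norm(h, np.ones(d))

post = post_ln_block(x, sublayer, ln)
pre = pre_ln_block(x, sublayer, ln)
pre_rms = pre_ln_block(x, sublayer, rms)   # LLaMA-style pre-norm with RMSNorm
```

In the post-LN block every output is re-normalized, so the residual stream is rescaled at every layer; in the pre-LN blocks the input `x` reaches the output unchanged through the addition, the property associated with stable training without warmup.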
Component 5: Positional Encoding
Injects sequential order information into the sequence. The original Transformer used sinusoidal positional encoding[1], while modern models mostly adopt Rotary Position Embedding (RoPE) or ALiBi.
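RoPE encodes position by rotating pairs of query/key dimensions through position-dependent angles, so the attention score between positions m and n depends only on the offset between them. A minimal NumPy sketch, assuming the interleaved pairing (dimensions 2i and 2i+1) and base 10000 from the RoFormer paper; library implementations differ in how they pair dimensions:

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate consecutive dimension pairs of vector x (even length) by pos * theta_i."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)      # one frequency per pair, shape (d/2,)
    cos, sin = np.cos(pos * theta), np.sin(pos * theta)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin     # 2-D rotation of each (even, odd) pair
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)

# The score depends only on the relative offset (3 here), not on absolute positions
s1 = rope(q, 5) @ rope(k, 2)
s2 = rope(q, 105) @ rope(k, 102)
print(np.isclose(s1, s2))  # True
```

Because rotations preserve vector length, RoPE changes only the angle between queries and keys, which is how it injects order without adding anything to the token embeddings.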
Component 6: Cross-Attention and Masking
Cross-attention in the Decoder allows the generation process to "reference" the Encoder's output: Queries come from the Decoder, while Keys and Values come from the Encoder. Causal masking ensures that the Decoder can only see the first t-1 tokens when generating the t-th token.
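Both mechanisms reduce to the same scaled dot-product attention with different inputs and masks. A minimal NumPy sketch; the learned projections W_Q, W_K, W_V, multiple heads, residuals, and normalization are omitted for brevity, and the shapes are arbitrary assumptions:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    """Scaled dot-product attention; mask[i, j] == False blocks query i from key j."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    weights = softmax(scores)
    return weights @ V, weights

rng = np.random.default_rng(0)
d, T_enc, T_dec = 8, 6, 4
enc_out = rng.standard_normal((T_enc, d))   # encoder output (source sentence)
dec_h = rng.standard_normal((T_dec, d))     # decoder hidden states (target prefix)

# Masked self-attention: position i may attend to positions 0..i only
causal = np.tril(np.ones((T_dec, T_dec), dtype=bool))
self_out, self_w = attention(dec_h, dec_h, dec_h, mask=causal)

# Cross-attention: queries from the decoder, keys and values from the encoder
cross_out, cross_w = attention(self_out, enc_out, enc_out)
print(cross_w.shape)  # (4, 6): one distribution over source positions per target position
```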
3. The Art of Training: Making Transformers Converge
Training Transformers is harder than training traditional models. Several seemingly minor tricks, most of them from the original paper[1], turn out to be critical:
| Technique | Details | Why It Matters |
|---|---|---|
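| Learning Rate Warmup | Linear increase for the first 4000 steps, then decay proportional to the inverse square root of the step number | Prevents unstable, oversized updates early in Post-LN training[3] |
| Label Smoothing | ε = 0.1: the one-hot target is mixed with a uniform distribution, so the correct token gets about 0.9 and the remaining 0.1 is spread over the vocabulary | Improves accuracy and BLEU by discouraging overconfidence, at some cost in perplexity |
| Dropout | P_drop = 0.1 on each sub-layer output (before the residual add) and on the sum of embeddings and positional encodings | Prevents overfitting, provides regularization |
| Adam (β₁=0.9, β₂=0.98, ε=10⁻⁹) | β₂ lowered from the usual 0.999, so the second-moment estimate averages over a shorter window | Tracks changing gradient magnitudes more quickly, helping stabilize training |
| Gradient Clipping | Global gradient norm clipped (typically to 1.0); standard in later large-scale training rather than specified in [1] | Prevents gradient explosion |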
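The warmup schedule and label smoothing from the table can be written in a few lines. The schedule follows the formula in [1], lr = d_model^-0.5 · min(step^-0.5, step · warmup^-1.5); the smoothing helper mixes the one-hot target with a uniform distribution over all classes, which is how PyTorch's `label_smoothing` option is documented to behave. Function names are illustrative:

```python
import numpy as np

def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Schedule from [1]: linear warmup, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)   # avoid 0 ** -0.5 at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

def smooth_labels(target, num_classes, eps=0.1):
    """Mix the one-hot target with a uniform distribution over all classes."""
    dist = np.full(num_classes, eps / num_classes)
    dist[target] += 1.0 - eps
    return dist

for s in (100, 1000, 4000, 16000, 100000):
    print(s, f"{transformer_lr(s):.2e}")
print(smooth_labels(2, 5))
```

The peak learning rate falls at step 4000, and quadrupling the step count afterwards halves the rate, which is what inverse square-root decay means in practice.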
4. Pre-Training Paradigms: Three Diverging Paths
The Transformer's full power emerges in the "pre-train, then fine-tune" paradigm. Three major pre-training strategies each have their own strengths:
Path One: Masked Language Model (MLM) — BERT
BERT[5] uses an Encoder-only architecture and randomly selects 15% of input tokens for the model to predict:
Input: The [MASK] sat on the [MASK]
Target: Predict [MASK] = "cat" and [MASK] = "mat"
Advantage: Bidirectional context — each token can see all neighbors on both sides
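Disadvantage: [MASK] appears during pre-training but never during fine-tuning or inference, causing a pre-train/fine-tune mismatch; BERT mitigates this by replacing only 80% of the selected tokens with [MASK], 10% with a random token, and leaving 10% unchanged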
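The 15% selection and the 80/10/10 replacement rule can be sketched in NumPy. This is a simplification under stated assumptions: it works on arbitrary integer ids, whereas real BERT operates on WordPiece tokens and never selects special tokens such as [CLS] or [SEP]. The `-100` label marks positions excluded from the loss, matching PyTorch's default `ignore_index` for `CrossEntropyLoss`:

```python
import numpy as np

IGNORE = -100  # PyTorch CrossEntropyLoss's default ignore_index

def mlm_mask(token_ids, mask_id, vocab_size, mask_prob=0.15, rng=None):
    """BERT-style corruption: select ~15% of positions; of those, 80% -> [MASK],
    10% -> random token, 10% -> unchanged. Unselected positions get IGNORE labels."""
    rng = np.random.default_rng() if rng is None else rng
    ids = np.asarray(token_ids)
    selected = rng.random(ids.shape) < mask_prob
    labels = np.where(selected, ids, IGNORE)

    roll = rng.random(ids.shape)
    to_mask = selected & (roll < 0.8)
    to_random = selected & (roll >= 0.8) & (roll < 0.9)
    inputs = ids.copy()
    inputs[to_mask] = mask_id
    inputs[to_random] = rng.integers(0, vocab_size, size=int(to_random.sum()))
    return inputs, labels

rng = np.random.default_rng(0)
ids = rng.integers(5, 1000, size=20_000)   # fake token ids; 0-4 reserved for specials
inputs, labels = mlm_mask(ids, mask_id=4, vocab_size=1000, rng=rng)
sel = labels != IGNORE
print(f"selected: {sel.mean():.3f}, of which [MASK]: {(inputs[sel] == 4).mean():.3f}")
```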
Path Two: Causal Language Model (CLM) — GPT
GPT[4] uses a Decoder-only architecture, predicting the next token:
Input: The cat sat on the
Target: Predict the next token = "mat"
Advantage: Naturally suited for generation tasks, no pre-train/fine-tune mismatch
Disadvantage: Unidirectional context — each token can only see tokens to the left
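The single example above is one of several predictions a causal LM learns from the same sentence: with a causal mask, every position predicts its successor in one forward pass. A pure-Python sketch of how the targets line up (word-level tokens are an assumption for readability; real models use subword tokens):

```python
def next_token_pairs(tokens):
    """Causal LM training pairs: at position t, predict tokens[t + 1] from tokens[:t + 1]."""
    inputs, targets = tokens[:-1], tokens[1:]
    return [(inputs[: t + 1], targets[t]) for t in range(len(targets))]

tokens = ["The", "cat", "sat", "on", "the", "mat"]
for context, target in next_token_pairs(tokens):
    print(" ".join(context), "->", target)
# The -> cat
# The cat -> sat
# The cat sat -> on
# The cat sat on -> the
# The cat sat on the -> mat
```

This one-position shift is the same `tgt[:, :-1]` / `tgt[:, 1:]` split used for the decoder in the Lab 1 training loop.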
GPT-3[6] with its 175 billion parameters demonstrated in-context learning — performing tasks with just a few examples in the prompt without fine-tuning. This fundamentally changed how AI is used.
Path Three: Denoising Reconstruction — T5 / BART
T5[7] unified all NLP tasks into a text-to-text format, using a full Encoder-Decoder architecture:
Translation: "translate English to German: That is good" → "Das ist gut"
Summarization: "summarize: long text..." → "short summary"
Classification: "sentiment: This movie is great" → "positive"
Pre-training: span corruption, which replaces random contiguous spans with sentinel tokens; the decoder reconstructs only the dropped spans
BART[8] uses more diverse denoising strategies — deletion, shuffling, masking, rotation — training the Encoder-Decoder to reconstruct the original text from corrupted input.
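T5's span corruption can be sketched in pure Python. In T5 the spans are sampled randomly (about 15% of tokens, mean span length 3); here they are passed in explicitly so the output is easy to follow. The sentinel strings follow the `<extra_id_N>` naming used by the Hugging Face T5 tokenizer, and the function name is illustrative:

```python
def span_corrupt(tokens, spans):
    """T5-style span corruption for given (start, end) spans, end exclusive and sorted.
    Each dropped span becomes one sentinel in the input; the target lists each sentinel
    followed by the tokens it replaced, closed by a final sentinel."""
    inputs, targets, cursor = [], [], 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inputs += tokens[cursor:start] + [sentinel]
        targets += [sentinel] + tokens[start:end]
        cursor = end
    inputs += tokens[cursor:]
    targets.append(f"<extra_id_{len(spans)}>")
    return inputs, targets

tokens = "Thank you for inviting me to your party last week .".split()
inp, tgt = span_corrupt(tokens, [(2, 4), (8, 9)])
print(" ".join(inp))  # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(tgt))  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```

Because the target contains only the dropped spans, it is much shorter than the input, which keeps decoder-side training cheap compared with reconstructing the full sentence.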
| Paradigm | Architecture | Representative | Best Suited For |
|---|---|---|---|
| MLM | Encoder-only | BERT[5] | Classification, NER, QA |
| CLM | Decoder-only | GPT[4], LLaMA[9] | Generation, dialogue, code |
| Denoising | Encoder-Decoder | T5[7], BART[8] | Translation, summarization, structured generation |
5. Key Variant Evolution Map
From the original Transformer in 2017 to frontier models such as GPT-4, Claude, and Gemini, the architecture has undergone dramatic evolution:
| Model | Year | Parameters | Architecture | Key Innovation |
|---|---|---|---|---|
| Transformer[1] | 2017 | 65M | Enc-Dec | Self-Attention replacing RNNs |
| GPT-1[4] | 2018 | 117M | Dec-only | Generative pre-training + discriminative fine-tuning |
| BERT[5] | 2019 | 340M | Enc-only | Masked Language Model, bidirectional context |
| T5[7] | 2020 | 11B | Enc-Dec | Unified text-to-text framework |
| GPT-3[6] | 2020 | 175B | Dec-only | Few-shot in-context learning |
| PaLM[10] | 2022 | 540B | Dec-only | Pathways system, SwiGLU, RoPE |
| LLaMA[9] | 2023 | 65B | Dec-only | Publicly available data only; smaller models trained on more tokens |
A clear trend emerges: Decoder-only architecture has become the mainstream for large language models. Building on the compute-optimal scaling results of Hoffmann et al. (Chinchilla), LLaMA[9] showed that training smaller models on far more tokens (1 to 1.4 trillion) yields models that rival much larger ones while being cheaper to run at inference. LLaMA-13B, trained only on public data, outperformed GPT-3 (175B) on most benchmarks.
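The comparison can be made concrete with the common approximation that training a dense Transformer costs about 6·N·D FLOPs for N parameters and D training tokens (roughly 2ND for the forward pass and 4ND for the backward pass). The token counts below are the published figures for each model; the approximation itself ignores attention-specific terms:

```python
def train_flops(n_params, n_tokens):
    """Common approximation for dense Transformer training compute: C ≈ 6 * N * D FLOPs."""
    return 6 * n_params * n_tokens

runs = {
    "GPT-3 175B (300B tokens)": (175e9, 300e9),
    "LLaMA-13B (1.0T tokens)": (13e9, 1.0e12),
    "LLaMA-65B (1.4T tokens)": (65e9, 1.4e12),
}
for name, (n, d) in runs.items():
    print(f"{name}: {train_flops(n, d):.2e} FLOPs, {d / n:.0f} tokens/parameter")
```

By this estimate LLaMA-13B used about a quarter of GPT-3's training compute, yet saw roughly 77 tokens per parameter versus under 2 for GPT-3.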
6. Transformers Conquering Vision and Multimodality
Vision Transformer (ViT)
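ViT[11] splits an image into 16×16 patches, flattens and linearly projects each patch into a token embedding, prepends a learnable [CLS] token, adds position embeddings, and feeds the sequence into a standard Transformer Encoder. Trained this way, ViT needed very large pre-training datasets (such as JFT-300M) to surpass convolutional networks. DeiT[12] largely removed this requirement by combining strong data augmentation and regularization with a distillation token, reaching competitive accuracy when trained on ImageNet alone.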
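The patch-to-token step can be sketched in NumPy. The (H, W, C) image layout, random weights standing in for learned parameters, and a zero [CLS] vector are assumptions for illustration; real ViT implementations usually compute the projection as a 16×16 convolution with stride 16, which is equivalent:

```python
import numpy as np

def patchify(image, patch=16):
    """Split an (H, W, C) image into non-overlapping patch x patch tiles, each flattened:
    returns (num_patches, patch * patch * C), in row-major order over the patch grid."""
    H, W, C = image.shape
    assert H % patch == 0 and W % patch == 0
    grid = image.reshape(H // patch, patch, W // patch, patch, C)
    grid = grid.transpose(0, 2, 1, 3, 4)               # (rows, cols, patch, patch, C)
    return grid.reshape(-1, patch * patch * C)

rng = np.random.default_rng(0)
img = rng.random((224, 224, 3))
patches = patchify(img)                                # (196, 768)

d_model = 768
W_embed = rng.standard_normal((patches.shape[1], d_model)) * 0.02   # hypothetical projection
cls_token = np.zeros((1, d_model))                                  # hypothetical [CLS] init
pos_embed = rng.standard_normal((patches.shape[0] + 1, d_model)) * 0.02
tokens = np.concatenate([cls_token, patches @ W_embed]) + pos_embed
print(patches.shape, tokens.shape)  # (196, 768) (197, 768)
```

The 197 × 768 token matrix (196 patches plus [CLS]) has the same shape as the position-embedding table inspected in Lab 2.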
CLIP: Unifying Vision and Language
CLIP[13] jointly trains an image encoder (a ResNet or a ViT) and a text Transformer with a contrastive objective, learning aligned vision-language representations. This lets the model "search" images using natural language descriptions, enabling zero-shot image classification and image retrieval; its text encoder also serves as the conditioning encoder for Stable Diffusion.
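The contrastive objective treats each batch as a classification problem in both directions: every image must pick out its own caption among all captions in the batch, and vice versa. A minimal NumPy sketch of this symmetric loss; CLIP learns the temperature, while 0.07 is used here as a fixed value matching its initialization, and the random embeddings are stand-ins for encoder outputs:

```python
import numpy as np

def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def clip_loss(img_emb, txt_emb, temperature=0.07):
    """Symmetric contrastive loss: row i of each batch is a matching image-text pair."""
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)   # cosine similarity
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature                               # (B, B)
    diag = np.arange(len(logits))
    loss_i2t = -log_softmax(logits, axis=1)[diag, diag].mean()      # image -> its text
    loss_t2i = -log_softmax(logits, axis=0)[diag, diag].mean()      # text -> its image
    return (loss_i2t + loss_t2i) / 2

rng = np.random.default_rng(0)
txt = rng.standard_normal((8, 32))
aligned = txt + 0.1 * rng.standard_normal((8, 32))   # images close to their own captions
shuffled = np.roll(aligned, 1, axis=0)               # every pairing broken
print(clip_loss(aligned, txt), clip_loss(shuffled, txt))
```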
Flamingo: Multimodal Few-Shot
Flamingo[14] bridges a frozen visual encoder with a frozen language model through gated cross-attention, achieving multimodal few-shot learning — answering questions about new images with just a few example images and questions.
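The gating is what lets Flamingo add new layers to a frozen language model without disturbing it at the start of training: each new cross-attention output is multiplied by tanh(α), with the learnable scalar α initialized to 0. A minimal NumPy sketch; projections, multiple heads, and the gated feed-forward layer that follows in Flamingo are omitted, and the visual tokens are random stand-ins:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def gated_cross_attention(text_h, visual_tokens, alpha):
    """Residual cross-attention whose contribution is scaled by tanh(alpha)."""
    scores = text_h @ visual_tokens.T / np.sqrt(text_h.shape[-1])
    attended = softmax(scores) @ visual_tokens
    return text_h + np.tanh(alpha) * attended

rng = np.random.default_rng(0)
text_h = rng.standard_normal((5, 16))     # hidden states from the frozen language model
visual = rng.standard_normal((64, 16))    # hypothetical visual tokens

# With alpha = 0 the layer is an identity map, so the frozen LM is unchanged at
# initialization; visual information flows in only as alpha is learned.
print(np.allclose(gated_cross_attention(text_h, visual, alpha=0.0), text_h))  # True
```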
7. Large-Scale Training: The Art of Engineering
Training a hundred-billion parameter Transformer involves engineering challenges far beyond model design itself. Here are the key techniques:
| Technique | Core Idea | Effect |
|---|---|---|
| Data Parallelism | One model replica per GPU, split data | Near-linear throughput scaling, but each GPU must hold the full model states |
| Tensor Parallelism[15] | Split matrix multiplication of a single layer across multiple GPUs | Intra-layer parallelism, low latency |
| Pipeline Parallelism | Place different layers on different GPUs | Reduces per-GPU memory, but introduces bubbles |
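| ZeRO[16] | Partition optimizer states, gradients, and parameters across GPUs | Per-GPU model-state memory drops by up to a factor of N (stage 3), making trillion-parameter models feasible |
| Mixed Precision | Compute in FP16/BF16, keep FP32 master weights and accumulate in FP32 | Roughly halves activation memory; large speedups on hardware with FP16/BF16 tensor cores |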
| Gradient Checkpointing | Save activations for only some layers, recompute during backward pass | Memory O(√n), at the cost of one extra forward pass |
Megatron-LM[15] achieved 76% scaling efficiency on 512 GPUs. ZeRO[16] eliminates memory redundancy in data parallelism: traditionally, each GPU stores a complete copy of the model states, which for mixed-precision Adam amounts to 16 bytes per parameter (2 for FP16 weights, 2 for FP16 gradients, and 12 for the FP32 master weights, momentum, and variance). ZeRO partitions these states across all GPUs, making trillion-parameter models trainable on reasonable hardware.
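The savings follow directly from the 16-bytes-per-parameter accounting. A small sketch of per-GPU model-state memory for each ZeRO stage; the 7.5B-parameter, 64-GPU configuration is an arbitrary example, and activations, temporary buffers, and fragmentation are not counted:

```python
def zero_memory_gb(n_params, n_gpus, stage):
    """Per-GPU model-state memory (GB) for mixed-precision Adam:
    2 B FP16 weights + 2 B FP16 gradients + 12 B optimizer states per parameter."""
    weights, grads, optim = 2 * n_params, 2 * n_params, 12 * n_params
    if stage >= 1:
        optim /= n_gpus      # stage 1: partition optimizer states
    if stage >= 2:
        grads /= n_gpus      # stage 2: also partition gradients
    if stage >= 3:
        weights /= n_gpus    # stage 3: also partition parameters
    return (weights + grads + optim) / 1e9

for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(7.5e9, 64, stage):.2f} GB per GPU")
# stage 0: 120.00, stage 1: 31.41, stage 2: 16.64, stage 3: 1.88 (GB per GPU)
```

Stage 1 alone removes most of the redundancy because optimizer states are three-quarters of the total; only stage 3 reaches the full 1/N reduction.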
8. Frontier Advances: MoE and the Post-Transformer Era
Mixture of Experts (MoE)
Switch Transformer[17] proposed a bold idea: not every token needs to pass through all parameters. MoE replaces FFN layers with multiple "expert" networks and routes each token to a subset of them; Switch Transformer uses the simplest case, sending each token to a single expert:
```
MoE Layer:
  Input token → Router (small classifier) → Select Top-1 expert
    Expert 1: FFN₁(x)   ← processes some tokens
    Expert 2: FFN₂(x)   ← processes other tokens
    ...
    Expert N: FFNₙ(x)
  Result: Trillion parameters, but each token activates only ~1/N of the expert (FFN) parameters
  → Massive parameter count, but per-token compute comparable to a dense model
```
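Mixtral 8x7B is a successful MoE example: each layer has 8 expert FFNs and each token is routed to 2 of them. Because the attention weights are shared rather than replicated, the model has about 47B parameters in total (not 8 × 7B = 56B), of which only about 13B are active per token, yet its authors report that it matches or outperforms the dense Llama 2 70B on most benchmarks.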
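A minimal NumPy sketch of top-1 routing in the spirit of the Switch Transformer. The router weights and tiny ReLU experts are random stand-ins, and the load-balancing loss and per-expert capacity limits used in practice are omitted:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def switch_layer(x, W_router, experts):
    """Top-1 routing: each token goes to its highest-probability expert,
    and that expert's output is scaled by the router probability."""
    probs = softmax(x @ W_router)                # (tokens, n_experts)
    choice = probs.argmax(axis=-1)               # one expert per token
    out = np.zeros_like(x)
    for e, expert in enumerate(experts):
        idx = np.where(choice == e)[0]           # tokens routed to expert e
        if idx.size:
            out[idx] = probs[idx, e:e + 1] * expert(x[idx])
    return out, choice

def make_expert(rng, d, d_ff):
    W1 = rng.standard_normal((d, d_ff)) / np.sqrt(d)
    W2 = rng.standard_normal((d_ff, d)) / np.sqrt(d_ff)
    return lambda h: np.maximum(0.0, h @ W1) @ W2   # a small ReLU FFN

rng = np.random.default_rng(0)
d, n_experts, n_tokens = 16, 4, 32
experts = [make_expert(rng, d, 64) for _ in range(n_experts)]
W_router = rng.standard_normal((d, n_experts))
x = rng.standard_normal((n_tokens, d))
out, choice = switch_layer(x, W_router, experts)
print(out.shape, np.bincount(choice, minlength=n_experts))  # tokens per expert
```

Each token pays for one expert's FFN regardless of how many experts exist, which is why total parameters can grow while per-token compute stays roughly fixed. Top-2 routing as in Mixtral would combine the two highest-scoring experts per token instead.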
Post-Transformer: State Space Models
Mamba[18] challenges the Transformer's dominance with Selective State Space Models (Selective SSMs), which process sequences in time linear in sequence length. Its authors report 5× higher inference throughput than Transformers of similar size, and Mamba-3B matches Transformers twice its size, suggesting that self-attention may not be the only path forward.
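"Selective" means the SSM parameters depend on the current input, while the recurrence still costs O(L) for a sequence of length L. A heavily simplified, single-sequence NumPy sketch of the recurrence h_t = Ā_t h_{t-1} + B̄_t x_t, y_t = C_t h_t. The projection matrices, the negative diagonal initialization of A, and the simplified discretization (Ā = exp(ΔA), B̄ = ΔB) are assumptions for illustration; Mamba itself adds a convolution, gating, and a hardware-aware parallel scan:

```python
import numpy as np

def softplus(z):
    return np.log1p(np.exp(z))

def selective_scan(x, A, W_delta, W_B, W_C):
    """x: (L, D) inputs; A: (D, N) negative diagonal state matrix per channel.
    W_delta (D, D), W_B (D, N), W_C (D, N) are hypothetical input projections that make
    the step size and B, C depend on the current token (the 'selection')."""
    L, D = x.shape
    h = np.zeros_like(A)                    # hidden state, (D, N)
    y = np.empty((L, D))
    for t in range(L):                      # a single pass over the sequence: O(L)
        delta = softplus(x[t] @ W_delta)    # per-channel step size, (D,)
        B = x[t] @ W_B                      # input-dependent input matrix, (N,)
        C = x[t] @ W_C                      # input-dependent output matrix, (N,)
        A_bar = np.exp(delta[:, None] * A)  # discretized decay, in (0, 1)
        B_bar = delta[:, None] * B[None, :] # simplified discretization of B
        h = A_bar * h + B_bar * x[t][:, None]
        y[t] = h @ C
    return y

rng = np.random.default_rng(0)
L, D, N = 12, 4, 8
A = -np.tile(np.arange(1, N + 1, dtype=float), (D, 1))
W_delta = rng.standard_normal((D, D)) * 0.5
W_B = rng.standard_normal((D, N)) * 0.5
W_C = rng.standard_normal((D, N)) * 0.5
x = rng.standard_normal((L, D))
y = selective_scan(x, A, W_delta, W_B, W_C)
print(y.shape)  # (12, 4)
```

Unlike attention, the state `h` has a fixed size, so memory per step does not grow with the sequence length.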
But Transformers are fighting back: hybrid architectures like Jamba alternate Transformer layers with Mamba layers, combining the advantages of both. The future mainstream may not be pure Transformers or pure SSMs, but rather hybrid architectures.
9. Hands-on Lab 1: Building a Transformer Translation Model from Scratch (Google Colab)
The following experiment implements a complete Encoder-Decoder Transformer from scratch and trains it on a small character-level English-German translation task.
```python
# ============================================================
# Lab 1: Building a Transformer from Scratch (English-German Translation Model)
# Environment: Google Colab (GPU)
# ============================================================

# --- 0. Installation ---
!pip install -q datasets tokenizers

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Data Preparation ---
# Using Tatoeba English-German sentence pairs (simple short sentences).
# Check which side of the "deu-eng" configuration holds English: if sourceString is
# German, swap the tuple below so that each pair is (English, German).
dataset = load_dataset("Helsinki-NLP/tatoeba_mt", "deu-eng", split="test")
pairs = [(ex["sourceString"], ex["targetString"]) for ex in dataset]
pairs = [p for p in pairs if len(p[0].split()) <= 15 and len(p[1].split()) <= 15]
pairs = pairs[:8000]
print(f"Sentence pairs: {len(pairs)}")
print(f"Example: '{pairs[0][0]}' → '{pairs[0][1]}'")

# Simple character-level tokenizer
class SimpleTokenizer:
    def __init__(self):
        self.char2idx = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
        self.idx2char = {0: "<pad>", 1: "<bos>", 2: "<eos>", 3: "<unk>"}

    def fit(self, texts):
        for text in texts:
            for ch in text:
                if ch not in self.char2idx:
                    idx = len(self.char2idx)
                    self.char2idx[ch] = idx
                    self.idx2char[idx] = ch

    def encode(self, text, max_len=80):
        # <bos> + characters (truncated to fit) + <eos>
        ids = [1] + [self.char2idx.get(ch, 3) for ch in text[:max_len - 2]] + [2]
        return ids

    def decode(self, ids):
        chars = []
        for idx in ids:
            if idx == 2:  # stop at <eos>
                break
            if idx > 2:   # skip <pad> and <bos>
                chars.append(self.idx2char.get(idx, "?"))
        return "".join(chars)

    @property
    def vocab_size(self):
        return len(self.char2idx)

src_tokenizer = SimpleTokenizer()
tgt_tokenizer = SimpleTokenizer()
src_tokenizer.fit([p[0] for p in pairs])
tgt_tokenizer.fit([p[1] for p in pairs])
print(f"Source vocab: {src_tokenizer.vocab_size}, Target vocab: {tgt_tokenizer.vocab_size}")

MAX_LEN = 80

class TranslationDataset(Dataset):
    def __init__(self, pairs, src_tok, tgt_tok):
        self.pairs = pairs
        self.src_tok = src_tok
        self.tgt_tok = tgt_tok

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        return self.src_tok.encode(src, MAX_LEN), self.tgt_tok.encode(tgt, MAX_LEN)

def collate_fn(batch):
    # Right-pad every sequence in the batch with <pad> (id 0)
    src_batch, tgt_batch = zip(*batch)
    src_max = max(len(s) for s in src_batch)
    tgt_max = max(len(t) for t in tgt_batch)
    src_padded = torch.zeros(len(batch), src_max, dtype=torch.long)
    tgt_padded = torch.zeros(len(batch), tgt_max, dtype=torch.long)
    for i, (s, t) in enumerate(zip(src_batch, tgt_batch)):
        src_padded[i, :len(s)] = torch.tensor(s)
        tgt_padded[i, :len(t)] = torch.tensor(t)
    return src_padded, tgt_padded

# 87.5% train / 12.5% validation (7000 / 1000 when 8000 pairs are available)
split = int(len(pairs) * 0.875)
train_ds = TranslationDataset(pairs[:split], src_tokenizer, tgt_tokenizer)
val_ds = TranslationDataset(pairs[split:], src_tokenizer, tgt_tokenizer)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=64, collate_fn=collate_fn)

# --- 2. Transformer Model (Complete Encoder-Decoder) ---
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding from the original paper."""
    def __init__(self, d_model, max_len=200):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B = Q.size(0)
        # Project, then split into heads: [B, n_heads, seq_len, d_k]
        Q = self.W_Q(Q).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask == 0 marks keys this query must not attend to (padding or future tokens)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        # Merge heads back: [B, seq_len, d_model]
        out = out.transpose(1, 2).contiguous().view(B, -1, self.n_heads * self.d_k)
        return self.W_O(out)

class EncoderLayer(nn.Module):
    """Post-LN encoder layer, as in the original Transformer."""
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        x = self.norm1(x + self.drop(self.self_attn(x, x, x, src_mask)))
        x = self.norm2(x + self.drop(self.ffn(x)))
        return x

class DecoderLayer(nn.Module):
    """Post-LN decoder layer: masked self-attention, cross-attention, FFN."""
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        x = self.norm1(x + self.drop(self.self_attn(x, x, x, tgt_mask)))
        # Cross-attention: queries from the decoder, keys/values from the encoder
        x = self.norm2(x + self.drop(self.cross_attn(x, enc_out, enc_out, src_mask)))
        x = self.norm3(x + self.drop(self.ffn(x)))
        return x

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=128, n_heads=4,
                 n_layers=3, d_ff=256, dropout=0.1):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=0)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=0)
        self.pos_enc = PositionalEncoding(d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.output_proj = nn.Linear(d_model, tgt_vocab)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def make_src_mask(self, src):
        # [B, 1, 1, S]: hide <pad> keys from every query
        return (src != 0).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        # [B, 1, T, T]: hide <pad> keys and future positions
        N = tgt.size(1)
        pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
        causal_mask = torch.tril(torch.ones(N, N, device=tgt.device)).bool()
        return pad_mask & causal_mask.unsqueeze(0).unsqueeze(0)

    def encode(self, src):
        src_mask = self.make_src_mask(src)
        x = self.dropout(self.pos_enc(self.src_emb(src) * math.sqrt(self.d_model)))
        for layer in self.encoder:
            x = layer(x, src_mask)
        return x, src_mask

    def decode(self, tgt, enc_out, src_mask):
        tgt_mask = self.make_tgt_mask(tgt)
        x = self.dropout(self.pos_enc(self.tgt_emb(tgt) * math.sqrt(self.d_model)))
        for layer in self.decoder:
            x = layer(x, enc_out, src_mask, tgt_mask)
        return self.output_proj(x)

    def forward(self, src, tgt):
        enc_out, src_mask = self.encode(src)
        return self.decode(tgt, enc_out, src_mask)

# --- 3. Training ---
model = Transformer(src_tokenizer.vocab_size, tgt_tokenizer.vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

def compute_loss(src, tgt):
    # Teacher forcing: the decoder reads tgt[:, :-1] and predicts tgt[:, 1:]
    logits = model(src, tgt[:, :-1])
    return criterion(logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total = 0.0
    for src, tgt in loader:
        total += compute_loss(src.to(device), tgt.to(device)).item()
    return total / len(loader)

for epoch in range(20):
    model.train()
    total_loss = 0
    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)
        loss = compute_loss(src, tgt)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}: train_loss={total_loss/len(train_loader):.4f}, "
              f"val_loss={evaluate(val_loader):.4f}")

# --- 4. Translation Inference (Greedy Decoding) ---
@torch.no_grad()
def translate(text, model, src_tok, tgt_tok, max_len=80):
    model.eval()
    src = torch.tensor([src_tok.encode(text, max_len)]).to(device)
    enc_out, src_mask = model.encode(src)   # encode the source once
    tgt_ids = [1]  # start with <bos>
    for _ in range(max_len):
        tgt = torch.tensor([tgt_ids]).to(device)
        logits = model.decode(tgt, enc_out, src_mask)
        next_id = logits[0, -1].argmax().item()   # most likely next character
        if next_id == 2:  # <eos>
            break
        tgt_ids.append(next_id)
    return tgt_tok.decode(tgt_ids)

# --- 5. Test Translations ---
print("\n=== Translation Results ===")
test_sentences = [p[0] for p in pairs[split:split + 10]]
for src_text in test_sentences:
    pred = translate(src_text, model, src_tokenizer, tgt_tokenizer)
    print(f"  EN: {src_text}")
    print(f"  DE: {pred}")
    print()
print("Lab 1 Complete!")
```
10. Hands-on Lab 2: ViT Image Classification Fine-Tuning + Feature Visualization (Google Colab)
The following experiment fine-tunes a pre-trained ViT on CIFAR-10 for image classification and visualizes the [CLS] token features, the learned position embeddings, and the per-class prediction confidence.
```python
# ============================================================
# Lab 2: ViT Fine-Tuning on CIFAR-10 + Feature Visualization
# Environment: Google Colab (GPU)
# ============================================================

# --- 0. Installation ---
!pip install -q transformers datasets timm

import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor

warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Data Preparation ---
dataset = load_dataset("cifar10")
train_data = dataset["train"].shuffle(seed=42).select(range(5000))
test_data = dataset["test"].shuffle(seed=42).select(range(1000))
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

model_name = "google/vit-base-patch16-224-in21k"
# Reuse the checkpoint's own normalization statistics instead of hard-coding ImageNet values
processor = ViTImageProcessor.from_pretrained(model_name)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # upsample 32x32 CIFAR images to the ViT input size
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

def collate_fn(batch):
    images = torch.stack([transform(ex["img"]) for ex in batch])
    labels = torch.tensor([ex["label"] for ex in batch])
    return images, labels

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=collate_fn)
print(f"Train: {len(train_data)}, Test: {len(test_data)}")

# --- 2. Load Pre-Trained ViT and Replace Classification Head ---
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True,
).to(device)

# Freeze everything except the last two encoder layers (10 and 11), the classification
# head, and the LayerNorm parameters ("layernorm" matches every LayerNorm in the model)
trainable_keys = ("classifier", "layernorm", "encoder.layer.10.", "encoder.layer.11.")
for name, param in model.named_parameters():
    param.requires_grad = any(key in name for key in trainable_keys)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")

# --- 3. Fine-Tuning Training ---
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=2e-4, weight_decay=0.01,
)
for epoch in range(5):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(pixel_values=images, labels=labels)  # loss computed internally
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        correct += (outputs.logits.argmax(1) == labels).sum().item()
        total += images.size(0)
    print(f"Epoch {epoch+1}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")

# --- 4. Testing ---
model.eval()
correct, total = 0, 0
all_features, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(pixel_values=images, output_hidden_states=True)
        correct += (outputs.logits.argmax(1) == labels).sum().item()
        total += images.size(0)
        # [CLS] token features from the last encoder layer (before the final LayerNorm)
        cls_features = outputs.hidden_states[-1][:, 0]  # [B, 768]
        all_features.append(cls_features.cpu())
        all_labels.append(labels.cpu())
print(f"\nTest Accuracy: {correct/total:.4f}")

# --- 5. [CLS] Token Feature PCA Visualization ---
features = torch.cat(all_features).numpy()
labels = torch.cat(all_labels).numpy()
pca = PCA(n_components=2)
features_2d = pca.fit_transform(features)

plt.figure(figsize=(12, 8))
scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1],
                      c=labels, cmap='tab10', vmin=-0.5, vmax=9.5, alpha=0.6, s=15)
plt.colorbar(scatter, ticks=range(10), label='Class')
# Label class centers
for i in range(10):
    mask = labels == i
    cx, cy = features_2d[mask, 0].mean(), features_2d[mask, 1].mean()
    plt.annotate(class_names[i], (cx, cy), fontsize=10,
                 fontweight='bold', ha='center',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))
plt.title('ViT [CLS] Token Features: PCA Projection (CIFAR-10)', fontsize=14)
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)')
plt.tight_layout()
plt.show()

# --- 6. Position Embedding Visualization ---
pos_embed = model.vit.embeddings.position_embeddings[0].detach().cpu()  # [197, 768]
print(f"\nPosition embeddings shape: {pos_embed.shape}")

# Cosine similarity between every pair of patch position embeddings
patch_pos = F.normalize(pos_embed[1:], dim=-1)  # drop [CLS]: [196, 768], unit-length rows
sim_matrix = patch_pos @ patch_pos.T             # [196, 196]

# Select several patch positions and visualize their similarity with all other positions
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
positions = [0, 7, 48, 97, 112, 140, 175, 195]  # spread over the 14x14 patch grid
for idx, pos in enumerate(positions):
    ax = axes[idx // 4][idx % 4]
    sim = sim_matrix[pos].reshape(14, 14).numpy()
    ax.imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1)
    row, col = pos // 14, pos % 14
    ax.plot(col, row, 'k*', markersize=10)  # mark the query patch
    ax.set_title(f'Patch ({row},{col})', fontsize=10)
    ax.axis('off')
plt.suptitle('Position Embedding Cosine Similarity (ViT)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# --- 7. Per-Class Prediction Confidence Distribution ---
model.eval()
all_probs, all_preds, all_true = [], [], []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        logits = model(pixel_values=images).logits
        probs = F.softmax(logits, dim=-1)
        all_probs.append(probs.cpu())
        all_preds.append(logits.argmax(1).cpu())
        all_true.append(labels)
all_probs = torch.cat(all_probs)
all_preds = torch.cat(all_preds)
all_true = torch.cat(all_true)

fig, axes = plt.subplots(2, 5, figsize=(20, 8))
for i in range(10):
    ax = axes[i // 5][i % 5]
    mask = all_true == i
    correct_conf = all_probs[mask & (all_preds == all_true)].max(dim=1).values
    wrong_conf = all_probs[mask & (all_preds != all_true)].max(dim=1).values
    if len(correct_conf) > 0:
        ax.hist(correct_conf.numpy(), bins=20, alpha=0.7, color='#0077b6', label='Correct')
    if len(wrong_conf) > 0:
        ax.hist(wrong_conf.numpy(), bins=20, alpha=0.7, color='#e63946', label='Wrong')
    ax.set_title(class_names[i], fontsize=11)
    ax.set_xlim(0, 1)
    ax.legend(fontsize=8)
plt.suptitle('Prediction Confidence Distribution per Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
print("\nLab 2 Complete!")
```
11. Conclusion: Transformer as a Universal Computing Primitive
The influence of the Transformer has far exceeded the scope of "a good model architecture." It is becoming the universal computing primitive of artificial intelligence — just as CPUs are to traditional computing and GPUs are to graphics rendering, the Transformer is becoming the core execution unit for intelligent computation.
Looking back at the core threads of this article:
- Architectural Foundations: The precise collaboration of six components (attention, FFN, residual connections, LN, positional encoding, masking)[1]
- Pre-Training Paradigms: MLM[5], CLM[4], and denoising reconstruction[7] — three paths each with distinct strengths
- Modality Expansion: From NLP to vision[11] to multimodal[13], the Transformer proved its cross-domain universality
- Scale Engineering: Tensor parallelism[15] and ZeRO[16] made trillion-parameter models a reality
- Continued Evolution: MoE[17] enables sparse scaling, while SSMs[18] challenge the necessity of attention
The Transformer may one day be surpassed, but the paradigms it established — "pre-train then transfer," "Scaling Laws," and "a universal architecture unifying multimodality" — will shape the direction of AI development for years to come. Understanding the Transformer means understanding the foundational logic of AI in this era.



