Key Findings
  • The Transformer[1] encoder-decoder architecture consists of six precisely designed components: multi-head attention, feed-forward networks, residual connections, layer normalization[2], positional encoding, and masking mechanisms
  • Three major pre-training paradigms — Masked Language Modeling (BERT[5]), Causal Language Modeling (GPT[4]), and Denoising Reconstruction (T5[7]/BART[8]) — each with distinct advantages, gave rise to all modern large language models
  • Transformers have expanded from NLP to vision (ViT[11]), multimodal (CLIP[13]), protein folding, robotics, and more, becoming the universal computing primitive of AI
  • This article includes two Google Colab hands-on labs: building a Transformer machine translation model from scratch, and ViT image classification fine-tuning with feature visualization

1. Why Transformers Changed Everything

In 2017, when Vaswani et al. proposed the Transformer in "Attention Is All You Need"[1], it was merely a model for machine translation. But within just a few years, this architecture became the universal infrastructure of artificial intelligence — GPT-4 uses it to generate text, DALL-E uses it to generate images, AlphaFold uses it to predict protein structures, and RT-2 uses it to control robots.

Why can the Transformer unify such diverse domains? Because it provides an extremely general computational primitive: allowing a set of elements to dynamically communicate with each other through attention mechanisms. Whether these elements are text tokens, image patches, amino acid residues, or robot actions, the Transformer can learn the relationships between them.

This article deeply dissects the Transformer as a complete system — not just the attention mechanism (which is covered in detail in the Complete Guide to Self-Attention), but all the engineering decisions that make it trainable, scalable, and transferable.

2. Architecture Anatomy: Six Core Components

The original Transformer is an Encoder-Decoder architecture, but its core components appear repeatedly in all variants:

Complete Transformer Architecture:

Encoder (×N layers):                 Decoder (×N layers):
┌─────────────────────┐              ┌─────────────────────┐
│ Multi-Head          │              │ Masked Multi-Head   │
│ Self-Attention      │              │ Self-Attention      │
│ + Residual + LN     │              │ + Residual + LN     │
├─────────────────────┤              ├─────────────────────┤
│ Feed-Forward        │              │ Cross-Attention     │
│ Network (FFN)       │              │ (Q=Dec, K,V=Enc)    │
│ + Residual + LN     │              │ + Residual + LN     │
└─────────────────────┘              ├─────────────────────┤
                                     │ Feed-Forward        │
                                     │ Network (FFN)       │
                                     │ + Residual + LN     │
                                     └─────────────────────┘

Component 1: Multi-Head Self-Attention

Allows each position to directly attend to all other positions in the sequence. The Encoder uses bidirectional attention; the Decoder uses causal masking (can only see past tokens). For detailed mathematical derivations, refer to the Complete Guide to Self-Attention.

Component 2: Feed-Forward Network (FFN)

The attention output at each layer passes through a two-layer fully connected network — this is the Transformer's "thinking" space:

FFN(x) = max(0, x · W₁ + b₁) · W₂ + b₂

Original config: d_model=512, d_ff=2048 (4x expansion)
Modern config: SwiGLU activation (LLaMA) or GELU (GPT) replacing ReLU

Research has found that FFN layers store a large amount of "factual knowledge." In models like GPT, the FFN parameters account for roughly 2/3 of the total parameters, serving as the primary carrier of the model's memorization capacity.
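The configs above can be made concrete in a few lines. A minimal sketch, with hypothetical dimensions (the SwiGLU hidden size is shrunk to roughly 2/3 of 4×d_model, as LLaMA does, to offset its third weight matrix):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReLUFFN(nn.Module):
    # Original Transformer FFN: expand 4x, apply ReLU, project back
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.w2(F.relu(self.w1(x)))

class SwiGLUFFN(nn.Module):
    # LLaMA-style FFN: gated SiLU ("swish") activation, three weight matrices
    def __init__(self, d_model=512, d_ff=1365):  # d_ff ~ 2/3 of 2048 keeps parameter count comparable
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

x = torch.randn(2, 10, 512)
print(ReLUFFN()(x).shape, SwiGLUFFN()(x).shape)  # both are (2, 10, 512)
```

Either variant preserves the input shape, so it can drop into the residual stream unchanged.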

Component 3: Residual Connections

The input to each sub-layer is added to its output: output = sublayer(x) + x. This allows gradients to flow directly across layers and is key to training deep Transformers (100+ layers).

Component 4: Layer Normalization

Layer normalization, proposed by Ba et al.[2], normalizes each token's representation to zero mean and unit variance. The original Transformer used Post-LN (normalization after the residual connection), but Xiong et al.[3] demonstrated that Pre-LN (normalization before the sublayer, inside the residual branch) eliminates the need for learning rate warmup and yields more stable training. Modern large models almost exclusively use Pre-LN, often with RMSNorm in place of LayerNorm.

| Configuration | Order | Characteristics | Representative Models |
|---|---|---|---|
| Post-LN | Attn → Add → LN | Requires learning rate warmup, but may achieve better final performance | Original Transformer, BERT |
| Pre-LN | LN → Attn → Add | Stable training, no warmup needed | GPT-2, GPT-3 |
| RMSNorm | Similar to Pre-LN | Only scaling (no centering), computationally faster | LLaMA, PaLM |
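The difference between the two orderings comes down to where the normalization sits relative to the residual add. A one-line sketch of each, using a single Linear as a stand-in for the attention or FFN sublayer:

```python
import torch
import torch.nn as nn

d_model = 64
ln = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for an attention or FFN sublayer
x = torch.randn(3, 5, d_model)

# Post-LN (original Transformer): normalize after the residual add
post_ln = ln(x + sublayer(x))

# Pre-LN (GPT-2 onward): normalize only the sublayer input; the residual path
# stays an identity, so gradients reach early layers unimpeded
pre_ln = x + sublayer(ln(x))
```

In Pre-LN the skip connection never passes through a normalization layer, which is the source of its training stability.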

Component 5: Positional Encoding

Injects sequential order information into the sequence. The original Transformer used sinusoidal positional encoding[1], while modern models mostly adopt Rotary Position Embedding (RoPE) or ALiBi.
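As a sketch of why RoPE encodes relative positions, here is a minimal interleaved-pair implementation (the function and shapes are illustrative; real models apply this per attention head, to Q and K only):

```python
import torch

def rope(x, base=10000.0):
    # Rotate each (even, odd) feature pair of the token at position m
    # by angle m * theta_i, where theta_i = base^(-2i/d).
    b, n, d = x.shape
    pos = torch.arange(n, dtype=torch.float32).unsqueeze(1)            # [n, 1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # [d/2]
    ang = pos * theta                                                  # [n, d/2]
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Key property: dot products between rotated vectors depend only on the
# relative offset between positions, not on absolute position.
x = torch.randn(1, 1, 8).repeat(1, 10, 1)   # the same vector at 10 positions
y = rope(x)[0]                              # [10, 8]
scores = y @ y.t()
print(torch.allclose(scores[3, 1], scores[7, 5], atol=1e-4))  # True (both offset 2)
```

This relative-offset property is what lets RoPE-based models handle positions more gracefully than learned absolute embeddings.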

Component 6: Cross-Attention and Masking

Cross-attention in the Decoder allows the generation process to "reference" the Encoder's output: Queries come from the Decoder, while Keys and Values come from the Encoder. Causal masking ensures that the Decoder can only see the first t-1 tokens when generating the t-th token.

3. The Art of Training: Making Transformers Converge

Training Transformers is more challenging than traditional models. Several seemingly minor tricks from the original paper[1] are actually critical:

| Technique | Details | Why It Matters |
|---|---|---|
| Learning Rate Warmup | Linear increase for the first 4000 steps, then square-root decay | Prevents gradient explosion in early Post-LN training[3] |
| Label Smoothing | Target distribution changes from one-hot to a 0.9/0.1 mixture | Improves generalization, prevents overconfidence |
| Dropout | 0.1 on attention weights + FFN + post-residual | Prevents overfitting, provides regularization |
| Adam with β₂=0.98 | Faster second-moment decay than the default 0.999 | Stabilizes gradients of attention matrices |
| Gradient Clipping | Global gradient norm clipped to 1.0 | Prevents gradient explosion |
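The warmup-then-decay schedule is a one-line formula; a direct transcription with the paper's defaults:

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    # "Attention Is All You Need" schedule: lr rises linearly for warmup_steps,
    # then decays proportionally to the inverse square root of the step number.
    step = max(step, 1)  # guard against step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The schedule peaks exactly at the end of warmup:
print(noam_lr(4000) > noam_lr(2000), noam_lr(4000) > noam_lr(8000))  # True True
```

In practice this is wired into the optimizer via a per-step learning rate scheduler.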

4. Pre-Training Paradigms: Three Diverging Paths

The true power of Transformers emerges in the "pre-training → fine-tuning" paradigm. Three major pre-training strategies each have their own strengths:

Path One: Masked Language Model (MLM) — BERT

BERT[5] uses an Encoder-only architecture, randomly masking 15% of tokens for the model to predict:

Input: The [MASK] sat on the [MASK]
Target: Predict [MASK] = "cat" and [MASK] = "mat"

Advantage: Bidirectional context — each token can see all neighbors on both sides
Disadvantage: [MASK] present during training but absent during inference → pre-train/fine-tune mismatch
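A sketch of BERT-style input corruption (the 80/10/10 split among [MASK], random token, and unchanged follows the BERT paper; the -100 label matches PyTorch's CrossEntropyLoss ignore-index convention, and mask_id=103 in the usage is an arbitrary stand-in):

```python
import torch

def mlm_corrupt(input_ids, mask_id, vocab_size, p=0.15):
    # Select ~15% of positions; of those: 80% -> [MASK], 10% -> random token, 10% -> unchanged.
    labels = input_ids.clone()
    selected = torch.rand(input_ids.shape) < p
    labels[~selected] = -100                      # loss is computed only on selected positions
    r = torch.rand(input_ids.shape)
    corrupted = input_ids.clone()
    corrupted[selected & (r < 0.8)] = mask_id     # 80%: replace with [MASK]
    rand_pos = selected & (r >= 0.8) & (r < 0.9)  # 10%: replace with a random token
    corrupted[rand_pos] = torch.randint(0, vocab_size, (int(rand_pos.sum().item()),))
    return corrupted, labels                      # remaining 10% stay unchanged

ids = torch.randint(5, 100, (2, 16))
corrupted, labels = mlm_corrupt(ids, mask_id=103, vocab_size=100)
```

The 10% random / 10% unchanged cases exist precisely to soften the train/inference mismatch noted above.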

Path Two: Causal Language Model (CLM) — GPT

GPT[4] uses a Decoder-only architecture, predicting the next token:

Input: The cat sat on the
Target: Predict the next token = "mat"

Advantage: Naturally suited for generation tasks, no pre-train/fine-tune mismatch
Disadvantage: Unidirectional context — each token can only see tokens to the left

GPT-3[6] with its 175 billion parameters demonstrated in-context learning — performing tasks with just a few examples in the prompt without fine-tuning. This fundamentally changed how AI is used.

Path Three: Denoising Reconstruction — T5 / BART

T5[7] unified all NLP tasks into a text-to-text format, using a full Encoder-Decoder architecture:

Translation: "translate English to German: That is good" → "Das ist gut"
Summarization: "summarize: long text..." → "short summary"
Classification: "sentiment: This movie is great" → "positive"

Pre-training: span corruption — randomly masking contiguous spans for the model to reconstruct

BART[8] uses more diverse denoising strategies — deletion, shuffling, masking, rotation — training the Encoder-Decoder to reconstruct the original text from corrupted input.
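T5's span corruption can be sketched in pure Python (the sentinel naming follows T5's <extra_id_n> convention; the spans here are hand-picked for illustration rather than sampled):

```python
def span_corrupt(tokens, spans):
    # spans: list of (start, length) to replace with sentinels <extra_id_0>, <extra_id_1>, ...
    # The input keeps the surrounding text; the target lists each sentinel followed
    # by the tokens it replaced, so the decoder reconstructs only the missing spans.
    inp, tgt = [], []
    cursor = 0
    for i, (start, length) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inp += tokens[cursor:start] + [sentinel]
        tgt += [sentinel] + tokens[start:start + length]
        cursor = start + length
    inp += tokens[cursor:]
    return inp, tgt

toks = "The cute dog walks in the park".split()
inp, tgt = span_corrupt(toks, [(1, 2), (5, 1)])
# inp: ['The', '<extra_id_0>', 'walks', 'in', '<extra_id_1>', 'park']
# tgt: ['<extra_id_0>', 'cute', 'dog', '<extra_id_1>', 'the']
```

Because the target contains only the corrupted spans, decoding is much shorter than reconstructing the full sentence.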

| Paradigm | Architecture | Representative | Best Suited For |
|---|---|---|---|
| MLM | Encoder-only | BERT[5] | Classification, NER, QA |
| CLM | Decoder-only | GPT[4], LLaMA[9] | Generation, dialogue, code |
| Denoising | Encoder-Decoder | T5[7], BART[8] | Translation, summarization, structured generation |

5. Key Variant Evolution Map

From the original Transformer in 2017 to GPT-4, Claude, and Gemini in 2024, the architecture has undergone dramatic evolution:

| Model | Year | Parameters | Architecture | Key Innovation |
|---|---|---|---|---|
| Transformer[1] | 2017 | 65M | Enc-Dec | Self-Attention replacing RNNs |
| GPT-1[4] | 2018 | 117M | Dec-only | Generative pre-training + discriminative fine-tuning |
| BERT[5] | 2019 | 340M | Enc-only | Masked Language Model, bidirectional context |
| T5[7] | 2020 | 11B | Enc-Dec | Unified text-to-text framework |
| GPT-3[6] | 2020 | 175B | Dec-only | Few-shot in-context learning |
| PaLM[10] | 2022 | 540B | Dec-only | Pathways system, SwiGLU, RoPE |
| LLaMA[9] | 2023 | 65B | Dec-only | Public data, compute-optimal training |

A clear trend emerges: the Decoder-only architecture has become the mainstream for large language models. LLaMA[9] reinforced an important conclusion: under a fixed compute budget, training a smaller model on more data can beat training a larger model on less data. LLaMA-13B, trained only on public data, outperformed GPT-3 (175B) on most benchmarks.

6. Transformers Conquering Vision and Multimodality

Vision Transformer (ViT)

ViT[11] splits an image into 16×16-pixel patches, treats each patch as a token, and feeds the resulting sequence directly into a Transformer Encoder. DeiT[12] introduced knowledge distillation, addressing ViT's reliance on massive training data (JFT-300M) and achieving competitive performance with ImageNet alone.
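The patch-to-token step is equivalent to a strided convolution; a short sketch with ViT-Base dimensions (16×16 patches, d_model=768, 224×224 input; the zero [CLS] token here stands in for the learned parameter):

```python
import torch
import torch.nn as nn

# Kernel size = stride = patch size, so each 16x16 patch is linearly
# projected to a single 768-dim token with no overlap.
patch_embed = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)

img = torch.randn(1, 3, 224, 224)
tokens = patch_embed(img).flatten(2).transpose(1, 2)  # [1, 196, 768]: 14x14 patches
cls = torch.zeros(1, 1, 768)                          # learnable parameter in the real model
seq = torch.cat([cls, tokens], dim=1)                 # [1, 197, 768], as in Lab 2's pos_embed
print(seq.shape)
```

The 197-token sequence length here matches the position-embedding shape inspected in Lab 2.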

CLIP: Unifying Vision and Language

CLIP[13] simultaneously trains an image Transformer and a text Transformer through contrastive learning, learning aligned vision-language representations. This allows the model to "search" images using natural language descriptions — zero-shot image classification, image retrieval, and even serving as the conditioning encoder for Stable Diffusion.
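CLIP's contrastive objective is a symmetric cross-entropy over all image-text pairs in a batch; a minimal sketch (batch size, embedding width, and the fixed 0.07 temperature are illustrative; CLIP learns the temperature):

```python
import torch
import torch.nn.functional as F

def clip_loss(img_emb, txt_emb, temperature=0.07):
    # Normalize both embedding sets, score every image against every caption;
    # the matched (i, i) pairs on the diagonal are the positives.
    img_emb = F.normalize(img_emb, dim=-1)
    txt_emb = F.normalize(txt_emb, dim=-1)
    logits = img_emb @ txt_emb.t() / temperature        # [B, B] similarity matrix
    targets = torch.arange(logits.size(0))
    # Symmetric: classify the right caption per image, and the right image per caption
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

loss = clip_loss(torch.randn(8, 512), torch.randn(8, 512))
```

Zero-shot classification reuses the same machinery: class names become captions, and the highest-similarity caption wins.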

Flamingo: Multimodal Few-Shot

Flamingo[14] bridges a frozen visual encoder with a frozen language model through gated cross-attention, achieving multimodal few-shot learning — answering questions about new images with just a few example images and questions.

7. Large-Scale Training: The Art of Engineering

Training a hundred-billion parameter Transformer involves engineering challenges far beyond model design itself. Here are the key techniques:

| Technique | Core Idea | Effect |
|---|---|---|
| Data Parallelism | One model replica per GPU, split the data | Linear speedup, but each GPU needs full model memory |
| Tensor Parallelism[15] | Split a single layer's matrix multiplications across GPUs | Intra-layer parallelism, low latency |
| Pipeline Parallelism | Place different layers on different GPUs | Reduces per-GPU memory, but introduces bubbles |
| ZeRO[16] | Partition optimizer states/gradients/parameters across GPUs | Memory reduced to 1/N, trillion-parameter models become feasible |
| Mixed Precision | Compute in FP16/BF16, accumulate in FP32 | Halves memory, roughly doubles speed |
| Gradient Checkpointing | Save activations for only some layers, recompute during the backward pass | Memory O(√n), at the cost of one extra forward pass |
Megatron-LM[15] achieved 76% scaling efficiency on 512 GPUs. ZeRO[16] eliminated memory redundancy in data parallelism: traditionally, each GPU stores a complete copy of the weights, gradients, and optimizer states, which for mixed-precision Adam training adds up to roughly 16 bytes per parameter. ZeRO partitions this state across all GPUs, making trillion-parameter models trainable on reasonable hardware.
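The ZeRO memory arithmetic is easy to sanity-check (the 1024-GPU figure is an arbitrary illustration, and the helper below ignores activations and fragmentation):

```python
def adam_training_gb(n_params, n_gpus=1, partitioned=False):
    # Mixed-precision Adam keeps ~16 bytes per parameter:
    # 2 (fp16 weights) + 2 (fp16 grads) + 12 (fp32 master weights, momentum, variance)
    total_bytes = 16 * n_params
    per_gpu = total_bytes / n_gpus if partitioned else total_bytes
    return per_gpu / 1e9

# GPT-3-scale model (175B parameters):
print(f"{adam_training_gb(175e9):.0f} GB")                                 # full replica: 2800 GB
print(f"{adam_training_gb(175e9, n_gpus=1024, partitioned=True):.1f} GB")  # ZeRO-3 shard: 2.7 GB
```

A full replica is far beyond any single accelerator's memory, which is exactly why partitioning the state is the enabling trick.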

8. Frontier Advances: MoE and the Post-Transformer Era

Mixture of Experts (MoE)

Switch Transformer[17] proposed a bold idea: not every token needs to pass through all parameters. MoE replaces each FFN layer with multiple "expert" networks, routing each token to only one expert:

MoE Layer:
  Input token → Router (small classifier) → Select Top-1 expert
  Expert 1: FFN₁(x)  ← processes some tokens
  Expert 2: FFN₂(x)  ← processes other tokens
  ...
  Expert N: FFNₙ(x)

Result: Trillion parameters, but each token activates only ~1/N of parameters
      → Massive parameter count, but compute comparable to dense models

Mixtral 8x7B is a successful MoE example: a 46.7B-parameter model with eight experts per layer, of which each token activates only about 12.9B parameters (top-2 routing), yet it matches the performance of 70B dense models.
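A Switch-style top-1 layer can be sketched as follows (sizes are toy values, the Python loop over experts stands in for the batched dispatch real implementations use, and the load-balancing auxiliary loss is omitted):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Top1MoE(nn.Module):
    # A small linear router picks exactly one expert FFN per token.
    def __init__(self, d_model=64, d_ff=128, n_experts=4):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)])

    def forward(self, x):                          # x: [tokens, d_model]
        gates = F.softmax(self.router(x), dim=-1)  # routing probabilities
        top_p, top_i = gates.max(dim=-1)           # top-1 expert index per token
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            sel = top_i == e                       # tokens routed to expert e
            if sel.any():
                # Scale by the gate probability so the router stays differentiable
                out[sel] = top_p[sel, None] * expert(x[sel])
        return out

y = Top1MoE()(torch.randn(10, 64))
```

Each token touches one expert's parameters, so compute stays close to a dense FFN while total capacity scales with the number of experts.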

Post-Transformer: State Space Models

Mamba[18] challenges the Transformer's dominance with Selective State Space Models (Selective SSM). It processes sequences in linear time complexity, with 5x the throughput of Transformers at the same scale. Mamba-3B matches the performance of Transformer-6B, suggesting that self-attention may not be the only path forward.

But Transformers are fighting back: hybrid architectures like Jamba alternate Transformer layers with Mamba layers, combining the advantages of both. The future mainstream may not be pure Transformers or pure SSMs, but rather hybrid architectures.

9. Hands-on Lab 1: Building a Transformer Translation Model from Scratch (Google Colab)

The following experiment implements a complete Encoder-Decoder Transformer architecture from scratch for a simple English-German translation task.

# ============================================================
# Lab 1: Building a Transformer from Scratch — English-German Translation Model
# Environment: Google Colab (GPU)
# ============================================================
# --- 0. Installation ---
!pip install -q datasets tokenizers

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Data Preparation ---
# Using Tatoeba English-German sentence pairs (simple short sentences)
dataset = load_dataset("Helsinki-NLP/tatoeba_mt", "deu-eng", split="test")
pairs = [(ex["sourceString"], ex["targetString"]) for ex in dataset]
pairs = [p for p in pairs if len(p[0].split()) <= 15 and len(p[1].split()) <= 15]
pairs = pairs[:8000]
print(f"Sentence pairs: {len(pairs)}")
print(f"Example: '{pairs[0][0]}' → '{pairs[0][1]}'")

# Simple character-level tokenizer
class SimpleTokenizer:
    def __init__(self):
        self.char2idx = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
        self.idx2char = {0: "<pad>", 1: "<bos>", 2: "<eos>", 3: "<unk>"}

    def fit(self, texts):
        for text in texts:
            for ch in text:
                if ch not in self.char2idx:
                    idx = len(self.char2idx)
                    self.char2idx[ch] = idx
                    self.idx2char[idx] = ch

    def encode(self, text, max_len=80):
        ids = [1] + [self.char2idx.get(ch, 3) for ch in text[:max_len-2]] + [2]
        return ids

    def decode(self, ids):
        chars = []
        for idx in ids:
            if idx == 2: break
            if idx > 2: chars.append(self.idx2char.get(idx, "?"))
        return "".join(chars)

    @property
    def vocab_size(self):
        return len(self.char2idx)

src_tokenizer = SimpleTokenizer()
tgt_tokenizer = SimpleTokenizer()
src_tokenizer.fit([p[0] for p in pairs])
tgt_tokenizer.fit([p[1] for p in pairs])
print(f"Source vocab: {src_tokenizer.vocab_size}, Target vocab: {tgt_tokenizer.vocab_size}")

MAX_LEN = 80

class TranslationDataset(Dataset):
    def __init__(self, pairs, src_tok, tgt_tok):
        self.pairs = pairs
        self.src_tok = src_tok
        self.tgt_tok = tgt_tok

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        return self.src_tok.encode(src, MAX_LEN), self.tgt_tok.encode(tgt, MAX_LEN)

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_max = max(len(s) for s in src_batch)
    tgt_max = max(len(t) for t in tgt_batch)
    src_padded = torch.zeros(len(batch), src_max, dtype=torch.long)
    tgt_padded = torch.zeros(len(batch), tgt_max, dtype=torch.long)
    for i, (s, t) in enumerate(zip(src_batch, tgt_batch)):
        src_padded[i, :len(s)] = torch.tensor(s)
        tgt_padded[i, :len(t)] = torch.tensor(t)
    return src_padded, tgt_padded

train_ds = TranslationDataset(pairs[:7000], src_tokenizer, tgt_tokenizer)
val_ds = TranslationDataset(pairs[7000:], src_tokenizer, tgt_tokenizer)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=64, collate_fn=collate_fn)

# --- 2. Transformer Model (Complete Encoder-Decoder) ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=200):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B = Q.size(0)
        Q = self.W_Q(Q).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, -1, self.n_heads * self.d_k)
        return self.W_O(out)

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        x = self.norm1(x + self.drop(self.self_attn(x, x, x, src_mask)))
        x = self.norm2(x + self.drop(self.ffn(x)))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        x = self.norm1(x + self.drop(self.self_attn(x, x, x, tgt_mask)))
        x = self.norm2(x + self.drop(self.cross_attn(x, enc_out, enc_out, src_mask)))
        x = self.norm3(x + self.drop(self.ffn(x)))
        return x

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=128, n_heads=4,
                 n_layers=3, d_ff=256, dropout=0.1):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=0)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=0)
        self.pos_enc = PositionalEncoding(d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.output_proj = nn.Linear(d_model, tgt_vocab)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def make_src_mask(self, src):
        return (src != 0).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        B, N = tgt.shape
        pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
        causal_mask = torch.tril(torch.ones(N, N, device=tgt.device)).bool()
        return pad_mask & causal_mask.unsqueeze(0).unsqueeze(0)

    def encode(self, src):
        src_mask = self.make_src_mask(src)
        x = self.dropout(self.pos_enc(self.src_emb(src) * math.sqrt(self.d_model)))
        for layer in self.encoder:
            x = layer(x, src_mask)
        return x, src_mask

    def decode(self, tgt, enc_out, src_mask):
        tgt_mask = self.make_tgt_mask(tgt)
        x = self.dropout(self.pos_enc(self.tgt_emb(tgt) * math.sqrt(self.d_model)))
        for layer in self.decoder:
            x = layer(x, enc_out, src_mask, tgt_mask)
        return self.output_proj(x)

    def forward(self, src, tgt):
        enc_out, src_mask = self.encode(src)
        return self.decode(tgt, enc_out, src_mask)

# --- 3. Training ---
model = Transformer(src_tokenizer.vocab_size, tgt_tokenizer.vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

for epoch in range(20):
    model.train()
    total_loss = 0
    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)
        tgt_input = tgt[:, :-1]
        tgt_target = tgt[:, 1:]
        logits = model(src, tgt_input)
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}: loss={total_loss/len(train_loader):.4f}")

# --- 4. Translation Inference (Greedy Decoding) ---
@torch.no_grad()  # inference only: skip autograd bookkeeping
def translate(text, model, src_tok, tgt_tok, max_len=80):
    model.eval()
    src = torch.tensor([src_tok.encode(text, max_len)]).to(device)
    enc_out, src_mask = model.encode(src)

    tgt_ids = [1]  # <bos>
    for _ in range(max_len):
        tgt = torch.tensor([tgt_ids]).to(device)
        logits = model.decode(tgt, enc_out, src_mask)
        next_id = logits[0, -1].argmax().item()
        if next_id == 2: break  # <eos>
        tgt_ids.append(next_id)

    return tgt_tok.decode(tgt_ids)

# --- 5. Test Translations ---
print("\n=== Translation Results ===")
test_sentences = [p[0] for p in pairs[7000:7010]]
for src_text in test_sentences:
    pred = translate(src_text, model, src_tokenizer, tgt_tokenizer)
    print(f"  SRC: {src_text}")  # direction follows the dataset's sourceString/targetString fields
    print(f"  TGT: {pred}")
    print()

print("Lab 1 Complete!")

10. Hands-on Lab 2: ViT Image Classification Fine-Tuning + Feature Visualization (Google Colab)

The following experiment fine-tunes a pre-trained ViT on CIFAR-10 for image classification and visualizes the Transformer's learned patch embeddings and [CLS] token features.

# ============================================================
# Lab 2: ViT Fine-Tuning on CIFAR-10 + Feature Visualization
# Environment: Google Colab (GPU)
# ============================================================
# --- 0. Installation ---
!pip install -q transformers datasets timm

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import ViTForImageClassification
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.decomposition import PCA
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Data Preparation ---
dataset = load_dataset("cifar10")
train_data = dataset["train"].shuffle(seed=42).select(range(5000))
test_data = dataset["test"].shuffle(seed=42).select(range(1000))

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def collate_fn(batch):
    images = torch.stack([transform(ex["img"]) for ex in batch])
    labels = torch.tensor([ex["label"] for ex in batch])
    return images, labels

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=collate_fn)
print(f"Train: {len(train_data)}, Test: {len(test_data)}")

# --- 2. Load Pre-Trained ViT and Replace Classification Head ---
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True
).to(device)

# Freeze everything except the last encoder layer, the final layer norm, and the classification head
for name, param in model.named_parameters():
    if "classifier" not in name and "layernorm" not in name and "encoder.layer.11" not in name:
        param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")

# --- 3. Fine-Tuning Training ---
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=2e-4, weight_decay=0.01
)

for epoch in range(5):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        correct += (outputs.logits.argmax(1) == labels).sum().item()
        total += images.size(0)
    print(f"Epoch {epoch+1}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")

# --- 4. Testing ---
model.eval()
correct, total = 0, 0
all_features, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(pixel_values=images, output_hidden_states=True)
        correct += (outputs.logits.argmax(1) == labels).sum().item()
        total += images.size(0)
        # Collect [CLS] token features
        cls_features = outputs.hidden_states[-1][:, 0]  # [B, 768]
        all_features.append(cls_features.cpu())
        all_labels.append(labels.cpu())

print(f"\nTest Accuracy: {correct/total:.4f}")

# --- 5. [CLS] Token Feature PCA Visualization ---
features = torch.cat(all_features).numpy()
labels = torch.cat(all_labels).numpy()

pca = PCA(n_components=2)
features_2d = pca.fit_transform(features)

plt.figure(figsize=(12, 8))
scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1],
                      c=labels, cmap='tab10', alpha=0.6, s=15)
plt.colorbar(scatter, ticks=range(10), label='Class')
plt.clim(-0.5, 9.5)

# Label class centers
for i in range(10):
    mask = labels == i
    cx, cy = features_2d[mask, 0].mean(), features_2d[mask, 1].mean()
    plt.annotate(class_names[i], (cx, cy), fontsize=10,
                 fontweight='bold', ha='center',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

plt.title('ViT [CLS] Token Features — PCA Projection (CIFAR-10)', fontsize=14)
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)')
plt.tight_layout()
plt.show()

# --- 6. Position Embedding Visualization ---
pos_embed = model.vit.embeddings.position_embeddings[0].detach().cpu()  # [197, 768]
print(f"\nPosition embeddings shape: {pos_embed.shape}")

# Compute cosine similarity between patch position embeddings
patch_pos = pos_embed[1:]  # Remove [CLS], [196, 768]
sim_matrix = F.cosine_similarity(
    patch_pos.unsqueeze(0), patch_pos.unsqueeze(1), dim=-1
)  # [196, 196]

# Select several patch positions and visualize their similarity with all other positions
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
positions = [0, 7, 48, 97, 112, 140, 175, 195]  # Different positions
for idx, pos in enumerate(positions):
    ax = axes[idx // 4][idx % 4]
    sim = sim_matrix[pos].reshape(14, 14).numpy()
    im = ax.imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1)
    row, col = pos // 14, pos % 14
    ax.plot(col, row, 'k*', markersize=10)
    ax.set_title(f'Patch ({row},{col})', fontsize=10)
    ax.axis('off')

plt.suptitle('Position Embedding Cosine Similarity (ViT)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# --- 7. Per-Class Prediction Confidence Distribution ---
model.eval()
all_probs, all_preds, all_true = [], [], []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        logits = model(pixel_values=images).logits
        probs = F.softmax(logits, dim=-1)
        all_probs.append(probs.cpu())
        all_preds.append(logits.argmax(1).cpu())
        all_true.append(labels)

all_probs = torch.cat(all_probs)
all_preds = torch.cat(all_preds)
all_true = torch.cat(all_true)

fig, axes = plt.subplots(2, 5, figsize=(20, 8))
for i in range(10):
    ax = axes[i // 5][i % 5]
    mask = all_true == i
    correct_conf = all_probs[mask & (all_preds == all_true)].max(dim=1).values
    wrong_conf = all_probs[mask & (all_preds != all_true)].max(dim=1).values
    if len(correct_conf) > 0:
        ax.hist(correct_conf.numpy(), bins=20, alpha=0.7, color='#0077b6', label='Correct')
    if len(wrong_conf) > 0:
        ax.hist(wrong_conf.numpy(), bins=20, alpha=0.7, color='#e63946', label='Wrong')
    ax.set_title(class_names[i], fontsize=11)
    ax.set_xlim(0, 1)
    ax.legend(fontsize=8)

plt.suptitle('Prediction Confidence Distribution per Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nLab 2 Complete!")

11. Conclusion: Transformer as a Universal Computing Primitive

The influence of the Transformer has far exceeded the scope of "a good model architecture." It is becoming the universal computing primitive of artificial intelligence — just as CPUs are to traditional computing and GPUs are to graphics rendering, the Transformer is becoming the core execution unit for intelligent computation.

Looking back at the core threads of this article:

  • Architecture: six components (attention, FFN, residuals, normalization, positional encoding, masking) that together make deep sequence models trainable
  • Pre-training: the MLM, CLM, and denoising paradigms that turned the architecture into transferable foundation models
  • Scale: the parallelism and memory-optimization engineering that makes hundred-billion-parameter training feasible
  • Reach: the same primitive now powering vision, multimodal, and scientific applications, with MoE and SSM hybrids at the frontier

The Transformer may one day be surpassed, but the paradigms it established — "pre-train then transfer," "Scaling Laws," and "a universal architecture unifying multimodality" — will shape the direction of AI development for years to come. Understanding the Transformer means understanding the foundational logic of AI in this era.