Key Findings
  • Transformer[1] 的編碼器-解碼器架構由六個精密設計的組件構成:多頭注意力、前饋網路、殘差連接、層歸一化[2]、位置編碼、遮罩機制
  • 三大預訓練範式——遮蔽語言模型(BERT[5])、因果語言模型(GPT[4])、去噪重建(T5[7]/BART[8])——各有優勢,催生了當代所有大型語言模型
  • Transformer 已從 NLP 擴展至視覺(ViT[11])、多模態(CLIP[13])、蛋白質摺疊、機器人控制等領域,成為 AI 的通用計算基元
  • 本文附兩個 Google Colab 實作:從零搭建 Transformer 機器翻譯模型、ViT 影像分類微調與特徵視覺化

一、為什麼 Transformer 改變了一切

2017 年,Vaswani 等人在「Attention Is All You Need」[1]中提出 Transformer 時,它只是一個用於機器翻譯的模型。但短短數年內,這個架構成為了人工智慧的通用基礎設施——GPT-4 用它生成文字、DALL-E 用它生成影像、AlphaFold 用它預測蛋白質結構、RT-2 用它控制機器人。

Transformer 為何能統一如此多樣的領域?因為它提供了一個極度通用的計算原語:讓一組元素透過注意力機制動態地互相溝通。無論這些元素是文字 token、影像 patch、氨基酸殘基還是機器人動作,Transformer 都能學到它們之間的關係。

本文深入拆解 Transformer 作為完整系統的設計——不僅是注意力機制(已在自注意力機制完全指南中詳述),更包括所有使其可訓練、可擴展、可遷移的工程決策。

二、架構解剖:六大核心組件

原始 Transformer 是一個 Encoder-Decoder 架構,但其核心組件在所有變體中都反覆出現:

Transformer 完整架構:

Encoder (×N 層):                    Decoder (×N 層):
┌─────────────────────┐            ┌─────────────────────┐
│  Multi-Head          │            │  Masked Multi-Head   │
│  Self-Attention      │            │  Self-Attention       │
│  + Residual + LN     │            │  + Residual + LN     │
├─────────────────────┤            ├─────────────────────┤
│  Feed-Forward        │            │  Cross-Attention     │
│  Network (FFN)       │            │  (Q=Dec, K,V=Enc)   │
│  + Residual + LN     │            │  + Residual + LN     │
└─────────────────────┘            ├─────────────────────┤
                                    │  Feed-Forward        │
                                    │  Network (FFN)       │
                                    │  + Residual + LN     │
                                    └─────────────────────┘

組件 1:多頭自注意力

讓每個位置直接關注序列中的所有其他位置。Encoder 使用雙向注意力;Decoder 使用因果遮罩(只能看到過去的 token)。詳細數學推導請參考自注意力機制完全指南

組件 2:前饋網路(FFN)

每一層的注意力輸出都會經過一個兩層的全連接網路,這是 Transformer 的「思考」空間:

FFN(x) = max(0, x · W₁ + b₁) · W₂ + b₂

原始配置: d_model=512, d_ff=2048 (4 倍擴展)
現代配置: 使用 SwiGLU 激活(LLaMA)或 GELU(GPT)取代 ReLU

研究發現,FFN 層儲存了大量的「事實知識」。在 GPT 等模型中,FFN 的參數量佔總參數的約 2/3,是模型記憶能力的主要載體。

組件 3:殘差連接

每個子層的輸入會與其輸出相加:output = sublayer(x) + x。這讓梯度可以直接跨層流動,是訓練深層 Transformer(100+ 層)的關鍵。

組件 4:層歸一化

Ba 等人[2]提出的層歸一化將每個 token 的表示正規化為零均值、單位方差。原始 Transformer 採用 Post-LN(歸一化放在殘差連接之後),但 Xiong 等人[3]證明 Pre-LN(歸一化放在殘差連接之前)能消除學習率暖身的必要性,訓練更穩定。現代大型模型幾乎都使用 Pre-LN 或 RMSNorm。

配置順序特性代表模型
Post-LNAttn → Add → LN需要學習率暖身,但最終性能可能更好原始 Transformer, BERT
Pre-LNLN → Attn → Add訓練穩定,不需暖身GPT-2, GPT-3
RMSNorm類似 Pre-LN僅做縮放(無中心化),計算更快LLaMA, PaLM

組件 5:位置編碼

為序列注入順序資訊。原始 Transformer 使用正弦波位置編碼[1],現代模型大多採用旋轉位置編碼(RoPE)或 ALiBi。

組件 6:交叉注意力與遮罩

Decoder 中的交叉注意力讓生成過程能「參考」編碼器的輸出:Query 來自 Decoder,Key 和 Value 來自 Encoder。因果遮罩確保 Decoder 在生成第 t 個 token 時只能看到前 t-1 個 token。

三、訓練的藝術:讓 Transformer 收斂

Transformer 的訓練比傳統模型更具挑戰性。原始論文[1]中幾個看似不起眼的技巧實際上至關重要:

技巧細節為什麼重要
學習率暖身前 4000 步線性增長,之後平方根衰減避免訓練初期 Post-LN 的梯度爆炸[3]
Label Smoothing目標分布從 one-hot 變為 0.9/0.1 混合提升泛化能力,防止過度自信
Dropout注意力權重 + FFN + 殘差後各 0.1防止過擬合,提供正則化
Adam + β₂=0.98較高的動量衰減穩定注意力矩陣的梯度
梯度裁剪全域梯度範數裁剪至 1.0防止梯度爆炸

四、預訓練範式:三條路線的分岔

Transformer 的真正力量在「預訓練 → 微調」這個範式中爆發。三種主要的預訓練策略各有千秋:

路線一:遮蔽語言模型(MLM)— BERT

BERT[5] 使用 Encoder-only 架構,隨機遮蔽 15% 的 token 讓模型預測:

輸入: The [MASK] sat on the [MASK]
目標: 預測 [MASK] = "cat" 和 [MASK] = "mat"

優勢: 雙向上下文 — 每個 token 能看到左右所有鄰居
劣勢: 訓練時有 [MASK],推論時沒有 → 預訓練-微調不匹配

路線二:因果語言模型(CLM)— GPT

GPT[4] 使用 Decoder-only 架構,預測下一個 token:

輸入: The cat sat on the
目標: 預測下一個 token = "mat"

優勢: 天然適合生成任務、無預訓練-微調不匹配
劣勢: 單向上下文 — 每個 token 只能看到左邊

GPT-3[6] 的 1750 億參數版本展示了 in-context learning——不需要微調,僅靠 prompt 中的幾個範例就能執行任務。這徹底改變了 AI 的使用方式。

路線三:去噪重建 — T5 / BART

T5[7] 將所有 NLP 任務統一為 text-to-text 格式,使用完整的 Encoder-Decoder 架構:

翻譯: "translate English to German: That is good" → "Das ist gut"
摘要: "summarize: 長文..." → "短摘要"
分類: "sentiment: This movie is great" → "positive"

預訓練: span corruption — 隨機遮蔽連續的 span 讓模型重建

BART[8] 則使用更多樣的去噪策略——刪除、打亂、遮蔽、旋轉——訓練 Encoder-Decoder 從損壞的文本重建原始文本。

範式架構代表最適任務
MLMEncoder-onlyBERT[5]分類、NER、問答
CLMDecoder-onlyGPT[4], LLaMA[9]生成、對話、程式碼
去噪Encoder-DecoderT5[7], BART[8]翻譯、摘要、結構化生成

五、關鍵變體演化圖譜

從 2017 年的原始 Transformer 到 2024 年的 GPT-4、Claude、Gemini,架構經歷了戲劇性的演化:

模型年份參數量架構關鍵創新
Transformer[1]201765MEnc-DecSelf-Attention 取代 RNN
GPT-1[4]2018117MDec-only生成式預訓練 + 判別式微調
BERT[5]2019340MEnc-only遮蔽語言模型、雙向上下文
T5[7]202011BEnc-Dec統一 text-to-text 框架
GPT-3[6]2020175BDec-onlyFew-shot in-context learning
PaLM[10]2022540BDec-onlyPathways 系統、SwiGLU、RoPE
LLaMA[9]202365BDec-only公開資料、計算最優訓練

一個明確的趨勢是:Decoder-only 架構已成為大型語言模型的主流。LLaMA[9] 證明了一個重要結論:在相同的計算預算下,用更多資料訓練較小的模型,比用較少資料訓練巨大的模型更有效。LLaMA-13B 僅用公開資料訓練,卻在大多數基準上超越了 GPT-3(175B)。

六、Transformer 征服視覺與多模態

Vision Transformer (ViT)

ViT[11] 將影像分割為 16×16 patch,每個 patch 視為一個 token,直接送入 Transformer Encoder。DeiT[12] 引入知識蒸餾技術,解決了 ViT 需要海量資料(JFT-300M)才能訓練的問題——僅用 ImageNet 就能達到有競爭力的性能。

CLIP:統一視覺與語言

CLIP[13] 以對比學習同時訓練一個影像 Transformer 和一個文本 Transformer,學習對齊的視覺-語言表示。這讓模型能以自然語言描述「搜尋」影像——零樣本影像分類、影像檢索、甚至作為 Stable Diffusion 的條件編碼器。

Flamingo:多模態 Few-Shot

Flamingo[14] 以交叉注意力門控將凍結的視覺編碼器與凍結的語言模型橋接,實現了多模態的 few-shot 學習——僅靠幾張範例影像和問題,就能回答關於新影像的問題。

七、大規模訓練:工程的藝術

訓練一個千億參數的 Transformer 涉及的工程挑戰遠超模型設計本身。以下是關鍵技術:

技術核心思想效果
資料平行每個 GPU 一份模型副本,分割資料線性加速,但每個 GPU 需完整模型記憶體
張量平行[15]將單層的矩陣乘法分割到多個 GPU單層內平行,低延遲
管線平行不同層放在不同 GPU 上減少每個 GPU 的記憶體,但有氣泡
ZeRO[16]跨 GPU 分割優化器狀態/梯度/參數記憶體降至 1/N,兆級參數可行
混合精度計算用 FP16/BF16,累加用 FP32記憶體減半,速度加倍
梯度檢查點只保存部分層的激活值,反向時重算記憶體 O(√n),代價一次額外前向

Megatron-LM[15] 在 512 個 GPU 上達到 76% 的擴展效率。ZeRO[16] 則消除了資料平行中的記憶體冗餘——傳統資料平行中每個 GPU 都保存完整的優化器狀態(Adam 需要 16 bytes/參數),ZeRO 將其分割到所有 GPU,使兆級參數模型在合理的硬體上可訓練。

八、前沿進展:MoE 與後 Transformer 時代

混合專家模型(Mixture of Experts)

Switch Transformer[17] 提出了一個大膽的想法:不是每個 token 都需要經過所有參數。MoE 將每個 FFN 層替換為多個「專家」網路,每個 token 只路由到一個專家:

MoE Layer:
  輸入 token → Router (小型分類器) → 選擇 Top-1 專家
  專家 1: FFN₁(x)  ← 處理某些 token
  專家 2: FFN₂(x)  ← 處理其他 token
  ...
  專家 N: FFNₙ(x)

結果: 兆級參數,但每個 token 只啟動 ~1/N 的參數
      → 參數量巨大,但計算量與密集模型相當

Mixtral 8x7B 是 MoE 的成功案例——以 8 個 7B 專家組成的 46B 參數模型,實際每個 token 只使用約 12B 參數,性能卻與 70B 密集模型相當。

後 Transformer:狀態空間模型

Mamba[18] 以選擇性狀態空間模型(Selective SSM)挑戰 Transformer 的主導地位。它以線性時間複雜度處理序列,在同等規模下吞吐量是 Transformer 的 5 倍。Mamba-3B 的性能匹配了 Transformer-6B,暗示自注意力或許不是唯一的道路。

但 Transformer 也在反擊:Jamba 等混合架構將 Transformer 層與 Mamba 層交替使用,結合兩者的優勢。未來的主流可能不是純 Transformer 或純 SSM,而是混合架構

九、Hands-on Lab 1:從零搭建 Transformer 翻譯模型(Google Colab)

以下實驗從零實現完整的 Encoder-Decoder Transformer 架構,進行簡單的英德翻譯任務。

# ============================================================
# Lab 1: 從零搭建 Transformer — 英德翻譯模型
# 環境: Google Colab (GPU)
# ============================================================
# --- 0. 安裝 ---
!pip install -q datasets tokenizers

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. 資料準備 ---
# 使用 Tatoeba 英德句對(簡單短句)
dataset = load_dataset("Helsinki-NLP/tatoeba_mt", "deu-eng", split="test")
pairs = [(ex["sourceString"], ex["targetString"]) for ex in dataset]
pairs = [p for p in pairs if len(p[0].split()) <= 15 and len(p[1].split()) <= 15]
pairs = pairs[:8000]
print(f"Sentence pairs: {len(pairs)}")
print(f"Example: '{pairs[0][0]}' → '{pairs[0][1]}'")

# 簡易字元級 tokenizer
class SimpleTokenizer:
    def __init__(self):
        self.char2idx = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
        self.idx2char = {0: "<pad>", 1: "<bos>", 2: "<eos>", 3: "<unk>"}

    def fit(self, texts):
        for text in texts:
            for ch in text:
                if ch not in self.char2idx:
                    idx = len(self.char2idx)
                    self.char2idx[ch] = idx
                    self.idx2char[idx] = ch

    def encode(self, text, max_len=80):
        ids = [1] + [self.char2idx.get(ch, 3) for ch in text[:max_len-2]] + [2]
        return ids

    def decode(self, ids):
        chars = []
        for idx in ids:
            if idx == 2: break
            if idx > 2: chars.append(self.idx2char.get(idx, "?"))
        return "".join(chars)

    @property
    def vocab_size(self):
        return len(self.char2idx)

src_tokenizer = SimpleTokenizer()
tgt_tokenizer = SimpleTokenizer()
src_tokenizer.fit([p[0] for p in pairs])
tgt_tokenizer.fit([p[1] for p in pairs])
print(f"Source vocab: {src_tokenizer.vocab_size}, Target vocab: {tgt_tokenizer.vocab_size}")

MAX_LEN = 80

class TranslationDataset(Dataset):
    def __init__(self, pairs, src_tok, tgt_tok):
        self.pairs = pairs
        self.src_tok = src_tok
        self.tgt_tok = tgt_tok

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        return self.src_tok.encode(src, MAX_LEN), self.tgt_tok.encode(tgt, MAX_LEN)

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_max = max(len(s) for s in src_batch)
    tgt_max = max(len(t) for t in tgt_batch)
    src_padded = torch.zeros(len(batch), src_max, dtype=torch.long)
    tgt_padded = torch.zeros(len(batch), tgt_max, dtype=torch.long)
    for i, (s, t) in enumerate(zip(src_batch, tgt_batch)):
        src_padded[i, :len(s)] = torch.tensor(s)
        tgt_padded[i, :len(t)] = torch.tensor(t)
    return src_padded, tgt_padded

train_ds = TranslationDataset(pairs[:7000], src_tokenizer, tgt_tokenizer)
val_ds = TranslationDataset(pairs[7000:], src_tokenizer, tgt_tokenizer)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=64, collate_fn=collate_fn)

# --- 2. Transformer 模型(完整 Encoder-Decoder)---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=200):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B = Q.size(0)
        Q = self.W_Q(Q).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, -1, self.n_heads * self.d_k)
        return self.W_O(out)

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        x = self.norm1(x + self.drop(self.self_attn(x, x, x, src_mask)))
        x = self.norm2(x + self.drop(self.ffn(x)))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        x = self.norm1(x + self.drop(self.self_attn(x, x, x, tgt_mask)))
        x = self.norm2(x + self.drop(self.cross_attn(x, enc_out, enc_out, src_mask)))
        x = self.norm3(x + self.drop(self.ffn(x)))
        return x

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=128, n_heads=4,
                 n_layers=3, d_ff=256, dropout=0.1):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=0)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=0)
        self.pos_enc = PositionalEncoding(d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.output_proj = nn.Linear(d_model, tgt_vocab)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def make_src_mask(self, src):
        return (src != 0).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        B, N = tgt.shape
        pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
        causal_mask = torch.tril(torch.ones(N, N, device=tgt.device)).bool()
        return pad_mask & causal_mask.unsqueeze(0).unsqueeze(0)

    def encode(self, src):
        src_mask = self.make_src_mask(src)
        x = self.dropout(self.pos_enc(self.src_emb(src) * math.sqrt(self.d_model)))
        for layer in self.encoder:
            x = layer(x, src_mask)
        return x, src_mask

    def decode(self, tgt, enc_out, src_mask):
        tgt_mask = self.make_tgt_mask(tgt)
        x = self.dropout(self.pos_enc(self.tgt_emb(tgt) * math.sqrt(self.d_model)))
        for layer in self.decoder:
            x = layer(x, enc_out, src_mask, tgt_mask)
        return self.output_proj(x)

    def forward(self, src, tgt):
        enc_out, src_mask = self.encode(src)
        return self.decode(tgt, enc_out, src_mask)

# --- 3. 訓練 ---
model = Transformer(src_tokenizer.vocab_size, tgt_tokenizer.vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

for epoch in range(20):
    model.train()
    total_loss = 0
    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)
        tgt_input = tgt[:, :-1]
        tgt_target = tgt[:, 1:]
        logits = model(src, tgt_input)
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}: loss={total_loss/len(train_loader):.4f}")

# --- 4. 翻譯推論(Greedy Decoding)---
def translate(text, model, src_tok, tgt_tok, max_len=80):
    model.eval()
    src = torch.tensor([src_tok.encode(text, max_len)]).to(device)
    enc_out, src_mask = model.encode(src)

    tgt_ids = [1]  # <bos>
    for _ in range(max_len):
        tgt = torch.tensor([tgt_ids]).to(device)
        logits = model.decode(tgt, enc_out, src_mask)
        next_id = logits[0, -1].argmax().item()
        if next_id == 2: break  # <eos>
        tgt_ids.append(next_id)

    return tgt_tok.decode(tgt_ids)

# --- 5. 測試翻譯 ---
print("\n=== Translation Results ===")
test_sentences = [p[0] for p in pairs[7000:7010]]
for src_text in test_sentences:
    pred = translate(src_text, model, src_tokenizer, tgt_tokenizer)
    print(f"  EN: {src_text}")
    print(f"  DE: {pred}")
    print()

print("Lab 1 Complete!")

十、Hands-on Lab 2:ViT 影像分類微調 + 特徵視覺化(Google Colab)

以下實驗微調預訓練的 ViT 在 CIFAR-10 上進行影像分類,並視覺化 Transformer 學到的 patch 嵌入和 [CLS] token 特徵。

# ============================================================
# Lab 2: ViT 微調 CIFAR-10 + 特徵視覺化
# 環境: Google Colab (GPU)
# ============================================================
# --- 0. 安裝 ---
!pip install -q transformers datasets timm

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import ViTForImageClassification, ViTFeatureExtractor
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.decomposition import PCA
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. 資料準備 ---
dataset = load_dataset("cifar10")
train_data = dataset["train"].shuffle(seed=42).select(range(5000))
test_data = dataset["test"].shuffle(seed=42).select(range(1000))

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def collate_fn(batch):
    images = torch.stack([transform(ex["img"]) for ex in batch])
    labels = torch.tensor([ex["label"] for ex in batch])
    return images, labels

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=collate_fn)
print(f"Train: {len(train_data)}, Test: {len(test_data)}")

# --- 2. 載入預訓練 ViT 並替換分類頭 ---
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True
).to(device)

# 凍結除最後兩層和分類頭外的所有參數
for name, param in model.named_parameters():
    if "classifier" not in name and "layernorm" not in name and "encoder.layer.11" not in name:
        param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")

# --- 3. 微調訓練 ---
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=2e-4, weight_decay=0.01
)

for epoch in range(5):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        correct += (outputs.logits.argmax(1) == labels).sum().item()
        total += images.size(0)
    print(f"Epoch {epoch+1}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")

# --- 4. 測試 ---
model.eval()
correct, total = 0, 0
all_features, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(pixel_values=images, output_hidden_states=True)
        correct += (outputs.logits.argmax(1) == labels).sum().item()
        total += images.size(0)
        # 收集 [CLS] token 特徵
        cls_features = outputs.hidden_states[-1][:, 0]  # [B, 768]
        all_features.append(cls_features.cpu())
        all_labels.append(labels.cpu())

print(f"\nTest Accuracy: {correct/total:.4f}")

# --- 5. [CLS] Token 特徵 PCA 視覺化 ---
features = torch.cat(all_features).numpy()
labels = torch.cat(all_labels).numpy()

pca = PCA(n_components=2)
features_2d = pca.fit_transform(features)

plt.figure(figsize=(12, 8))
scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1],
                      c=labels, cmap='tab10', alpha=0.6, s=15)
plt.colorbar(scatter, ticks=range(10), label='Class')
plt.clim(-0.5, 9.5)

# 標註類別中心
for i in range(10):
    mask = labels == i
    cx, cy = features_2d[mask, 0].mean(), features_2d[mask, 1].mean()
    plt.annotate(class_names[i], (cx, cy), fontsize=10,
                 fontweight='bold', ha='center',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

plt.title('ViT [CLS] Token Features — PCA Projection (CIFAR-10)', fontsize=14)
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)')
plt.tight_layout()
plt.show()

# --- 6. Position Embedding 視覺化 ---
pos_embed = model.vit.embeddings.position_embeddings[0].detach().cpu()  # [197, 768]
print(f"\nPosition embeddings shape: {pos_embed.shape}")

# 計算 patch position embeddings 之間的餘弦相似度
patch_pos = pos_embed[1:]  # 去掉 [CLS], [196, 768]
sim_matrix = F.cosine_similarity(
    patch_pos.unsqueeze(0), patch_pos.unsqueeze(1), dim=-1
)  # [196, 196]

# 選擇幾個 patch 位置,視覺化其與所有其他位置的相似度
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
positions = [0, 7, 48, 97, 112, 140, 175, 195]  # 不同位置
for idx, pos in enumerate(positions):
    ax = axes[idx // 4][idx % 4]
    sim = sim_matrix[pos].reshape(14, 14).numpy()
    im = ax.imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1)
    row, col = pos // 14, pos % 14
    ax.plot(col, row, 'k*', markersize=10)
    ax.set_title(f'Patch ({row},{col})', fontsize=10)
    ax.axis('off')

plt.suptitle('Position Embedding Cosine Similarity (ViT)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# --- 7. 每個類別的預測信心分布 ---
model.eval()
all_probs, all_preds, all_true = [], [], []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        logits = model(pixel_values=images).logits
        probs = F.softmax(logits, dim=-1)
        all_probs.append(probs.cpu())
        all_preds.append(logits.argmax(1).cpu())
        all_true.append(labels)

all_probs = torch.cat(all_probs)
all_preds = torch.cat(all_preds)
all_true = torch.cat(all_true)

fig, axes = plt.subplots(2, 5, figsize=(20, 8))
for i in range(10):
    ax = axes[i // 5][i % 5]
    mask = all_true == i
    correct_conf = all_probs[mask & (all_preds == all_true)].max(dim=1).values
    wrong_conf = all_probs[mask & (all_preds != all_true)].max(dim=1).values
    if len(correct_conf) > 0:
        ax.hist(correct_conf.numpy(), bins=20, alpha=0.7, color='#0077b6', label='Correct')
    if len(wrong_conf) > 0:
        ax.hist(wrong_conf.numpy(), bins=20, alpha=0.7, color='#e63946', label='Wrong')
    ax.set_title(class_names[i], fontsize=11)
    ax.set_xlim(0, 1)
    ax.legend(fontsize=8)

plt.suptitle('Prediction Confidence Distribution per Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nLab 2 Complete!")

十一、結語:Transformer 作為通用計算基元

Transformer 的影響力已遠超「一個好的模型架構」的範疇。它正在成為人工智慧的通用計算基元——就像 CPU 之於傳統計算、GPU 之於圖形渲染,Transformer 正在成為智慧計算的核心執行單元。

回顧本文的核心脈絡:

Transformer 或許終有一天會被超越,但它所確立的「預訓練 → 遷移」範式、「Scaling Laws」思想、以及「通用架構統一多模態」的願景,將長期塑造 AI 的發展方向。理解 Transformer,就是理解這個時代 AI 的底層邏輯。