Key Findings
  • Der Self-Attention-Mechanismus ermöglicht es jedem Element einer Sequenz, direkt mit allen anderen Elementen zu interagieren[1] — er beseitigt damit das Engpassproblem der Langstreckenabhängigkeiten von rekurrenten neuronalen Netzen und ermöglicht vollständige Parallelisierung der Berechnung
  • Die Transformer-Architektur[1] ist zur einheitlichen Grundarchitektur für NLP (BERT[3], GPT[4]) und Computer Vision (ViT[6], Swin[7]) geworden
  • Technologien wie FlashAttention[11] und Linformer[9] durchbrechen den O(n²)-Engpass der Self-Attention in Richtung nahezu linearer Komplexität und ermöglichen Kontextfenster mit Millionen von Token
  • Dieser Artikel enthält zwei Google Colab-Praxisübungen: Transformer-Textsentimentklassifikation (eigene Implementierung der Multi-Head Attention) und ViT-Bildklassifikation (Visualisierung von Attention-Heatmaps)

I. Von Attention zu Self-Attention: Ein Paradigmenwechsel

Im Jahr 2017 stellte ein Google-Team in einer Veröffentlichung mit dem Titel „Attention Is All You Need"[1] eine scheinbar kühne These auf: Man kann das leistungsfähigste Sequenzmodell allein mit dem Attention-Mechanismus konstruieren — ganz ohne Rekurrenz und Faltung. Die daraus hervorgegangene Transformer-Architektur hat in nur wenigen Jahren das gesamte Feld der künstlichen Intelligenz grundlegend umgestaltet.

Vor dem Transformer hatte der Attention-Mechanismus bereits in der Forschung von Bahdanau et al.[2] vielversprechende Ergebnisse gezeigt — als Hilfsmodul für RNNs, das dem Decoder half, die relevantesten Teile der Encoder-Ausgabe „zu beachten". Vaswani et al. gingen jedoch weiter: Sie ermöglichten es jedem Element der Sequenz, direkt mit allen anderen Elementen zu interagieren, ohne die schrittweise Weitergabe eines RNN zu benötigen. Dies ist der Kern der Self-Attention.

Die Bedeutung dieser Transformation im Überblick:

EigenschaftRNNSelf-Attention
LangstreckenabhängigkeitO(n) Schritte, um Anfang und Ende zu verbindenO(1) direkte Verbindung beliebiger Positionen
ParallelisierungMuss schrittweise berechnet werden, keine ParallelisierungAlle Positionen werden gleichzeitig berechnet
RechenkomplexitätO(n · d²) schrittweiseO(n² · d) global
SpeicherengpassFeste Größe des Hidden StateDynamische Attention-Gewichtsmatrix

II. Scaled Dot-Product Attention: Der mathematische Kern

Die Berechnung der Self-Attention lässt sich auf drei Matrixoperationen verdichten: Query (Abfrage), Key (Schlüssel), Value (Wert). Jedes Eingabe-Token wird in drei Vektoren projiziert, und dann wird über das Skalarprodukt der „Grad der Relevanz" zueinander berechnet.

Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V

Dabei gilt:
  Q = X · W_Q    (Query-Matrix, Form [n, d_k])
  K = X · W_K    (Key-Matrix, Form [n, d_k])
  V = X · W_V    (Value-Matrix, Form [n, d_v])
  d_k = Dimension des Key-Vektors
  √d_k = Skalierungsfaktor, verhindert, dass zu große Skalarprodukte die Softmax-Sättigung verursachen

Intuitive Erklärung:

Beispiel: Im Satz „The cat sat on the mat because it was tired" bildet der Query-Vektor von „it" das Skalarprodukt mit den Key-Vektoren aller anderen Token. Im Idealfall erzeugt der Key von „cat" den höchsten Wert mit dem Query von „it", sodass das Modell korrekt erkennt, dass „it" sich auf „cat" bezieht.

III. Multi-Head Attention: Parallele Multiperspektiven-Beobachtung

Ein einzelner Satz von Q, K, V kann nur eine Art von Beziehung erfassen. Sprachliche Beziehungen sind jedoch vielschichtig — syntaktische Abhängigkeiten, semantische Ähnlichkeiten, Referenzbeziehungen, Tempuskonsistenz… Die Lösung des Transformers ist Multi-Head Attention:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O

Dabei gilt: head_i = Attention(Q · W_Q^i, K · W_K^i, V · W_V^i)

Am Beispiel d_model=512, h=8:
  Jeder Kopf hat d_k = d_v = 512 / 8 = 64
  8 Köpfe führen jeweils unabhängig 64-dimensionale Attention durch
  Die Ergebnisse werden konkateniert und zurück auf 512 Dimensionen projiziert

Die Forschung von Voita et al.[12] hat gezeigt, dass verschiedene Attention-Köpfe tatsächlich unterschiedliche „Rollen" erlernen: Einige konzentrieren sich auf Positionsbeziehungen, andere verfolgen syntaktische Strukturen, wieder andere verarbeiten seltene Wörter. Interessanterweise stellten Michel et al.[17] fest, dass eine große Anzahl von Köpfen entfernt werden kann, ohne die Leistung wesentlich zu beeinträchtigen — was darauf hindeutet, dass der Multi-Head-Mechanismus eine vorteilhafte Redundanz und Regularisierung bietet.

IV. Positionskodierung: Attention das Verständnis von „Reihenfolge" beibringen

Self-Attention ist von Natur aus permutationsinvariant — wenn die Eingabereihenfolge vertauscht wird, folgt die Ausgabe, aber die Repräsentation jedes Tokens ändert sich nicht. Das bedeutet, dass „Der Hund beißt den Mann" und „Der Mann beißt den Hund" aus Sicht der reinen Self-Attention identisch aussehen. Deshalb benötigt der Transformer eine zusätzliche Positionskodierung (Positional Encoding), um Reihenfolge-Informationen einzuspeisen.

Der ursprüngliche Transformer verwendet sinusförmige Positionskodierung[1]:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

Dabei ist pos die Token-Position und i der Dimensionsindex

In den letzten Jahren hat die Positionskodierung eine bedeutende Weiterentwicklung erfahren:

MethodeKernideeVorteilRepräsentatives Modell
Sinusförmig[1]Feste trigonometrische FunktionenKein Training erforderlichUrsprünglicher Transformer
LernbarEin trainierbarer Vektor pro PositionAdaptivBERT, GPT
RoPE[13]Rotationsmatrix kodiert relative PositionLängenextrapolation, relative WahrnehmungLLaMA, PaLM
ALiBi[14]Linearer Distanz-Bias zu Attention-WertenNull Parameter, starke ExtrapolationBLOOM, MPT

RoPE[13] ist zum Standard aktueller großer Sprachmodelle geworden. Es kodiert Positionsinformationen als Rotation im Vektorraum — der Attention-Wert zwischen zwei Token hängt nur von ihrer relativen Distanz ab und bietet gute Fähigkeiten zur Längenextrapolation.

V. Transformer-Architektur im Überblick

Der vollständige Transformer besteht aus Encoder und Decoder, wobei verschiedene Anwendungen unterschiedliche Kombinationen wählen:

ArchitekturtypStrukturRepräsentatives ModellTypische Aufgabe
Encoder-onlyBidirektionale Self-AttentionBERT[3]Klassifikation, NER, Frage-Antwort
Decoder-onlyKausale maskierte Self-AttentionGPT[4][5]Textgenerierung, Dialog
Encoder-DecoderEncoder + Cross-Attention + DecoderT5, BARTÜbersetzung, Zusammenfassung

Jede Transformer-Schicht (Block) enthält zwei Teilschichten, jeweils mit Residualverbindung und Schichtnormalisierung:

Transformer Block:
  1. Multi-Head Self-Attention
     → LayerNorm(x + MultiHead(x, x, x))
  2. Feed-Forward Network (FFN)
     → LayerNorm(x + FFN(x))
     → FFN(x) = max(0, x·W_1 + b_1) · W_2 + b_2

Der Decoder enthält zusätzlich:
  - Causal Mask: Maskierung zukünftiger Positionen für autoregressive Generierung
  - Cross-Attention: Query kommt vom Decoder, Key/Value vom Encoder

VI. BERT und GPT: Zwei Wege, eine Weggabelung

Die Transformer-Architektur brachte zwei große Pre-Training-Paradigmen hervor, die sich im „Sichtbarkeitsbereich" der Self-Attention unterscheiden:

BERT (Bidirectional)[3]: Verwendet die Encoder-Architektur, wobei jedes Token alle anderen Token in der Sequenz sehen kann (einschließlich des vorherigen und nachfolgenden Kontexts). Das Pre-Training erfolgt über ein Masked Language Model (MLM) — 15 % der Token werden zufällig maskiert, und das Modell soll die maskierten Wörter vorhersagen. Dieser bidirektionale Kontext macht BERT besonders stark bei Verständnisaufgaben.

GPT (Autoregressive)[4]: Verwendet die Decoder-Architektur, wobei jedes Token nur die vorangehenden Token sehen kann (kausale Maskierung). Das Pre-Training erfolgt über Next Token Prediction. GPT-3[5] skalierte die Parameter auf 175 Milliarden und demonstrierte beeindruckende Few-Shot-Lernfähigkeiten — ohne Fine-Tuning konnte es allein durch Prompts Aufgaben wie Übersetzung, Frage-Antwort und Codegenerierung ausführen.

Mit zunehmender Skalierung entdeckten Kaplan et al.[15] die berühmten Scaling Laws — es existieren stabile Potenzgesetz-Beziehungen zwischen Modellleistung (Loss) und Parameterzahl, Datenmenge sowie Rechenleistung. Noch überraschender dokumentierten Wei et al.[16] das Phänomen der emergenten Fähigkeiten (Emergent Abilities): Bestimmte Fähigkeiten (wie Chain-of-Thought-Reasoning) existieren in kleinen Modellen überhaupt nicht, treten aber plötzlich auf, wenn das Modell eine bestimmte Skalierungsschwelle erreicht.

VII. Vision Transformer: Self-Attention erobert die visuelle Domäne

Lange Zeit wurde Computer Vision von Convolutional Neural Networks dominiert. Im Jahr 2021 stellten Dosovitskiy et al.[6] den Vision Transformer (ViT) vor und bewiesen, dass eine reine Self-Attention-Architektur bei großskaligen Daten mit den stärksten CNNs mithalten oder diese sogar übertreffen kann.

Die Kernidee von ViT ist elegant in ihrer Einfachheit:

ViT-Prozess:
1. Ein 224×224-Bild wird in 16×16-Patches aufgeteilt → 196 Patches
2. Jeder Patch wird zu einem 768-dimensionalen Vektor abgeflacht (16×16×3 = 768)
3. Hinzufügen lernbarer Positionskodierung + [CLS]-Token
4. Eingabe in einen Standard-Transformer-Encoder
5. Verwendung der [CLS]-Token-Ausgabe für die Klassifikation

„Ein Bild ist 16×16 Wörter wert" — Bildpatches werden als Token in einer Sequenz behandelt

Der Erfolg von ViT löste eine Explosion von Vision Transformern aus. Swin Transformer[7] reduzierte mit einer Sliding-Window-Strategie die Attention-Komplexität von O(n²) auf O(n) und wurde zu einem universellen visuellen Backbone. DETR[8] realisierte End-to-End-Objekterkennung mit Transformern und eliminierte die traditionellen Anchor Boxes und NMS-Nachverarbeitung.

VIII. Effizienzdurchbrüche: Self-Attention jenseits von O(n²)

Die O(n²)-Zeit- und Speicherkomplexität der Self-Attention ist ihr größter Engpass. Für eine Sequenz der Länge n erfordert die Attention-Matrix n² Rechenoperationen und Speicherplatz. Bei n = 100.000 bedeutet das 10^10 Operationen — ingenieurstechnisch inakzeptabel.

Wissenschaft und Industrie haben verschiedene Durchbruchstrategien vorgeschlagen:

MethodeKernstrategieKomplexitätGenauigkeit
Linformer[9]Key/Value Low-Rank-ProjektionO(n)Approximiert
Performer[10]Zufällige Feature-Maps (FAVOR+)O(n)Approximiert
FlashAttention[11]IO-bewusstes Blocking + Kernel-FusionO(n²) aber 2–4× schnellerExakt
Sliding WindowNur lokales Fenster beachtenO(n · w)Lokal exakt
Sparse AttentionSparse Attention-MusterO(n√n)Approximiert

Der Beitrag von FlashAttention[11] ist besonders bedeutsam: Es verändert nicht die Mathematik — das Ergebnis ist identisch mit Standard-Attention — sondern erreicht eine 2–4-fache Beschleunigung durch geschicktes GPU-Speichermanagement (blockweise Berechnung, Vermeidung von HBM-SRAM-Roundtrips). Es ist zu einer Standardkomponente des modernen LLM-Trainings geworden.

IX. Hands-on Lab 1: Transformer-Textklassifikator von Grund auf implementieren (Google Colab)

Das folgende Experiment implementiert eine Multi-Head Self-Attention-Schicht von Grund auf, baut einen Mini-Transformer für die IMDb-Filmrezensionen-Sentimentklassifikation und visualisiert die Attention-Gewichte.

# ============================================================
# Lab 1: Transformer von Grund auf — IMDb-Sentimentklassifikation + Attention-Visualisierung
# Umgebung: Google Colab (GPU)
# ============================================================
# --- 0. Installation ---
!pip install -q torchtext datasets

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from torch.utils.data import DataLoader
from collections import Counter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Datenvorbereitung ---
dataset = load_dataset("imdb")
train_data = dataset["train"].shuffle(seed=42).select(range(10000))
test_data = dataset["test"].shuffle(seed=42).select(range(2000))

# Einfacher Tokenizer
def simple_tokenizer(text):
    return text.lower().split()

# Vokabular erstellen
counter = Counter()
for example in train_data:
    counter.update(simple_tokenizer(example["text"]))
vocab = {"<pad>": 0, "<unk>": 1}
for word, count in counter.most_common(20000):
    if count >= 3:
        vocab[word] = len(vocab)
vocab_size = len(vocab)
print(f"Vocab size: {vocab_size}")

MAX_LEN = 256

def encode(text):
    tokens = simple_tokenizer(text)[:MAX_LEN]
    ids = [vocab.get(t, 1) for t in tokens]
    return ids

def collate_fn(batch):
    texts = [encode(ex["text"]) for ex in batch]
    labels = torch.tensor([ex["label"] for ex in batch])
    max_len = min(max(len(t) for t in texts), MAX_LEN)
    padded = torch.zeros(len(texts), max_len, dtype=torch.long)
    for i, t in enumerate(texts):
        padded[i, :len(t)] = torch.tensor(t)
    return padded, labels

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=collate_fn)

# --- 2. Eigene Implementierung der Multi-Head Self-Attention ---
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        self.attn_weights = None  # Attention-Gewichte zur Visualisierung speichern

    def forward(self, x, mask=None):
        B, N, _ = x.shape

        # Projektion + Aufteilung in Köpfe: [B, N, d_model] → [B, h, N, d_k]
        Q = self.W_Q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        self.attn_weights = attn.detach()

        out = torch.matmul(attn, V)  # [B, h, N, d_k]
        out = out.transpose(1, 2).contiguous().view(B, N, self.d_model)
        return self.W_O(out)

# --- 3. Transformer Encoder Block ---
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.dropout(self.attn(x, mask)))
        x = self.norm2(x + self.ffn(x))
        return x

# --- 4. Vollständiges Klassifikationsmodell ---
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_heads=4,
                 n_layers=3, d_ff=256, max_len=256, n_classes=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.classifier = nn.Linear(d_model, n_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N = x.shape
        positions = torch.arange(N, device=x.device).unsqueeze(0).expand(B, N)
        mask = (x != 0).unsqueeze(1).unsqueeze(2)  # [B, 1, 1, N]

        x = self.dropout(self.embedding(x) + self.pos_encoding(positions))
        for block in self.blocks:
            x = block(x, mask)

        # Globales Average Pooling (Padding ignorieren)
        mask_float = (x != 0).any(dim=-1, keepdim=True).float()
        x = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1).clamp(min=1)
        return self.classifier(x)

# --- 5. Training ---
model = TransformerClassifier(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

for epoch in range(6):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item() * xb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
    print(f"Epoch {epoch+1}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")

# --- 6. Test ---
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
print(f"\nTest Accuracy: {correct/total:.4f}")

# --- 7. Attention-Visualisierung ---
def visualize_attention(text, model):
    model.eval()
    tokens = simple_tokenizer(text)[:50]
    ids = torch.tensor([[vocab.get(t, 1) for t in tokens]]).to(device)

    with torch.no_grad():
        _ = model(ids)

    # Attention-Gewichte der letzten Schicht
    attn = model.blocks[-1].attn.attn_weights[0]  # [h, N, N]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for i in range(4):
        ax = axes[i]
        im = ax.imshow(attn[i, :len(tokens), :len(tokens)].cpu(),
                       cmap='Blues', aspect='auto')
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=90, fontsize=7)
        ax.set_yticklabels(tokens, fontsize=7)
        ax.set_title(f'Head {i+1}')
        plt.colorbar(im, ax=ax, fraction=0.046)
    plt.suptitle('Multi-Head Self-Attention Weights (Last Layer)', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualisierungsbeispiele
visualize_attention("this movie was absolutely wonderful and the acting was superb", model)
visualize_attention("terrible film with awful dialogue and boring plot", model)
print("Lab 1 Complete!")

X. Hands-on Lab 2: Vision Transformer Bildklassifikation + Attention-Heatmaps (Google Colab)

Das folgende Experiment verwendet ein vortrainiertes ViT-Modell zur Bildklassifikation und extrahiert Self-Attention-Gewichte zur Erzeugung von Attention-Heatmaps, die intuitiv zeigen, „was das Modell sieht".

# ============================================================
# Lab 2: Vision Transformer — Bildklassifikation + Attention-Heatmap-Visualisierung
# Umgebung: Google Colab (GPU oder CPU)
# ============================================================
# --- 0. Installation ---
!pip install -q transformers timm pillow matplotlib

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
import requests
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- 1. Vortrainiertes ViT laden ---
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(
    model_name, output_attentions=True
).to(device).eval()
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
print(f"Model: {model_name}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Patch size: 16x16, Image size: 224x224 → 196 patches + 1 [CLS]")

# --- 2. Bildladen und Vorverarbeitung ---
def load_image(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    return img

# Klassische Testbilder verwenden
urls = {
    "Golden Retriever": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/bd/Golden_Retriever_Dukedestination.jpg/800px-Golden_Retriever_Dukedestination.jpg",
    "Tabby Cat": "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/800px-Cat_November_2010-1a.jpg",
    "Bald Eagle": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1a/About_to_Launch_%2826075320352%29.jpg/800px-About_to_Launch_%2826075320352%29.jpg",
}

images = {}
for name, url in urls.items():
    try:
        images[name] = load_image(url)
        print(f"Loaded: {name} ({images[name].size})")
    except Exception as e:
        print(f"Failed to load {name}: {e}")

# --- 3. Inferenz + Attention extrahieren ---
def predict_with_attention(img, model, feature_extractor):
    inputs = feature_extractor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    probs = F.softmax(logits, dim=-1)
    top5 = torch.topk(probs, 5)

    # Attention-Gewichte: Liste von [B, heads, N, N] für jede Schicht
    attentions = outputs.attentions  # 12 Schichten
    return top5, attentions

# --- 4. Attention-Heatmap-Erzeugung ---
def get_attention_map(attentions, layer=-1):
    """[CLS]-Token-Attention der angegebenen Schicht extrahieren → 14x14 Heatmap"""
    # Angegebene Schicht, Durchschnitt aller Köpfe
    attn = attentions[layer][0]  # [heads, N, N]
    attn_avg = attn.mean(dim=0)  # [N, N]

    # Attention des [CLS]-Tokens (Position 0) auf alle Patches
    cls_attn = attn_avg[0, 1:]  # [CLS] selbst entfernen, [196]
    cls_attn = cls_attn.reshape(14, 14).cpu().numpy()

    # Normalisierung
    cls_attn = (cls_attn - cls_attn.min()) / (cls_attn.max() - cls_attn.min() + 1e-8)
    return cls_attn

def get_rollout_attention(attentions):
    """Attention Rollout: Kumulation der Attention über alle Schichten"""
    result = torch.eye(attentions[0].size(-1)).to(device)
    for attn_layer in attentions:
        attn = attn_layer[0].mean(dim=0)  # [N, N] Durchschnitt aller Köpfe
        attn = attn + torch.eye(attn.size(0)).to(device)  # Residualverbindung
        attn = attn / attn.sum(dim=-1, keepdim=True)  # Normalisierung
        result = torch.matmul(attn, result)

    cls_attn = result[0, 1:].reshape(14, 14).cpu().numpy()
    cls_attn = (cls_attn - cls_attn.min()) / (cls_attn.max() - cls_attn.min() + 1e-8)
    return cls_attn

# --- 5. Visualisierung ---
def visualize_vit(name, img, model, feature_extractor):
    top5, attentions = predict_with_attention(img, model, feature_extractor)

    print(f"\n{'='*50}")
    print(f"Image: {name}")
    print(f"{'='*50}")
    for i in range(5):
        idx = top5.indices[0][i].item()
        prob = top5.values[0][i].item()
        label = model.config.id2label[idx]
        print(f"  {i+1}. {label}: {prob:.4f}")

    # Attention verschiedener Schichten extrahieren
    attn_first = get_attention_map(attentions, layer=0)
    attn_mid = get_attention_map(attentions, layer=5)
    attn_last = get_attention_map(attentions, layer=-1)
    attn_rollout = get_rollout_attention(attentions)

    # Bild zu numpy
    img_np = np.array(img.resize((224, 224)))

    fig, axes = plt.subplots(1, 5, figsize=(25, 5))

    axes[0].imshow(img_np)
    axes[0].set_title(f"Original\n{model.config.id2label[top5.indices[0][0].item()]}", fontsize=12)
    axes[0].axis('off')

    titles = ['Layer 1 Attention', 'Layer 6 Attention',
              'Layer 12 Attention', 'Attention Rollout']
    maps = [attn_first, attn_mid, attn_last, attn_rollout]

    for i, (title, attn_map) in enumerate(zip(titles, maps)):
        ax = axes[i + 1]
        ax.imshow(img_np)
        attn_resized = np.array(Image.fromarray(
            (attn_map * 255).astype(np.uint8)
        ).resize((224, 224), Image.BICUBIC)) / 255.0
        ax.imshow(attn_resized, alpha=0.6, cmap='jet')
        ax.set_title(title, fontsize=12)
        ax.axis('off')

    plt.suptitle(f'ViT Attention Maps — {name}', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# --- 6. Visualisierung ausführen ---
for name, img in images.items():
    visualize_vit(name, img, model, feature_extractor)

# --- 7. Multi-Head Attention-Vergleich ---
def visualize_heads(name, img, model, feature_extractor, layer=-1):
    """Attention-Muster verschiedener Köpfe einer Schicht visualisieren"""
    inputs = feature_extractor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    attn = outputs.attentions[layer][0]  # [12, 197, 197]
    img_np = np.array(img.resize((224, 224)))

    fig, axes = plt.subplots(2, 6, figsize=(24, 8))
    for head_idx in range(12):
        ax = axes[head_idx // 6][head_idx % 6]
        head_attn = attn[head_idx, 0, 1:].reshape(14, 14).cpu().numpy()
        head_attn = (head_attn - head_attn.min()) / (head_attn.max() - head_attn.min() + 1e-8)
        attn_resized = np.array(Image.fromarray(
            (head_attn * 255).astype(np.uint8)
        ).resize((224, 224), Image.BICUBIC)) / 255.0

        ax.imshow(img_np)
        ax.imshow(attn_resized, alpha=0.6, cmap='inferno')
        ax.set_title(f'Head {head_idx+1}', fontsize=10)
        ax.axis('off')

    plt.suptitle(f'12 Attention Heads (Last Layer) — {name}', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# 12 Köpfe des ersten Bildes visualisieren
first_name = list(images.keys())[0]
visualize_heads(first_name, images[first_name], model, feature_extractor)

print("\nLab 2 Complete!")

XI. Entscheidungsrahmen: Wie Unternehmen die richtige Attention-Architektur wählen

SzenarioEmpfohlene ArchitekturBegründung
Textklassifikation / NERBERT-base Fine-TuningBidirektionaler Kontext, optimal für Klassifikation
Textgenerierung / DialogGPT-Serie / LLaMAAutoregressive Generierung, kausale Attention
BildklassifikationViT / Swin TransformerÜbertrifft CNN bei großskaligen Daten
ObjekterkennungDETR / Swin + FPNEnd-to-End, keine Anchor Boxes erforderlich
Extrem lange SequenzenFlashAttention + RoPE128K+ Token-Kontextfenster
Edge-GeräteDestilliertes ViT / MobileViTLightweight Self-Attention
MultimodalCross-Attention TransformerEinheitliche Bild-Text-Repräsentation

XII. Fazit und Ausblick

Der Self-Attention-Mechanismus ist einer der bedeutendsten Durchbrüche der letzten zehn Jahre im Bereich der künstlichen Intelligenz. Von „Attention Is All You Need"[1] im Jahr 2017 bis hin zu GPT-4, Claude und Gemini heute — Self-Attention ist zum Rechenkern nahezu aller modernen KI-Systeme geworden.

Ein Rückblick auf diese Revolution:

Mit Blick auf die Zukunft umfassen die Entwicklungsrichtungen der Self-Attention: State Space Models (SSM / Mamba) als sub-lineare Alternativen, Mixture of Experts (MoE) für sparsam aktivierte Riesenmodelle sowie einheitliche multimodale Architekturen, die Text, Bild, Audio und Video im selben Attention-Raum zusammenführen. Unabhängig davon, wie sich die konkrete Form weiterentwickelt, wird die Kernidee — „jedem Element zu ermöglichen, dynamisch auf alle relevanten Informationen zu achten" — die KI auch im nächsten Jahrzehnt anführen.