- 聯邦學習(Federated Learning)讓多個機構在不共享原始資料的前提下協作訓練 AI 模型[1],是 GDPR[13] 等隱私法規時代最重要的分散式學習範式
- FedAvg[1] 是聯邦學習的基石演算法,但面對非 IID(Non-IID)資料分佈時會顯著退化——FedProx[3] 與 FedBN[15] 分別從優化約束與特徵歸一化角度解決此問題
- 差分隱私(Differential Privacy)[12]與安全聚合(Secure Aggregation)[5]構成聯邦學習的雙重隱私防線——前者在數學上保證個體資料不可被推斷,後者確保伺服器無法看到任何單一客戶端的模型更新
- 本文附兩個 Google Colab 實作:使用 Flower 框架[6]模擬多客戶端聯邦訓練(CIFAR-10 影像分類),以及差分隱私聯邦學習實驗(Opacus 整合),可直接在瀏覽器中執行
一、資料不能出門:為何聯邦學習是隱私法規時代的必然選擇
傳統的機器學習假設很簡單:把所有資料集中到一台伺服器上,然後訓練模型。這個假設在 2010 年代或許可行,但在今天正被三股力量同時瓦解:
法規壓力。歐盟 GDPR[13] 明確要求資料最小化(Data Minimization)和目的限制(Purpose Limitation),任何跨機構的資料傳輸都需要合法基礎。美國各州的隱私法(CCPA/CPRA)、中國的《個人資訊保護法》、台灣的《個人資料保護法》也在收緊資料流通的閘門。對醫療機構而言,HIPAA 更直接禁止將病患資料傳輸到外部伺服器。
商業競爭。即使法規允許,企業也不願意將核心資料交給第三方。一家銀行不會把客戶交易紀錄交給另一家銀行來聯合訓練反詐騙模型——即使雙方都知道聯合訓練的效果更好。資料是競爭壁壘,沒有企業願意拆掉這道牆。
物理限制。數十億台行動裝置每天產生海量資料,但將所有手機上的打字紀錄上傳到雲端既不實際(頻寬成本)也不安全(隱私風險)。Google 在 2017 年面對的正是這個問題:如何利用數億台 Android 手機的打字數據改善 Gboard 的下一詞預測,卻不收集任何一個用戶的打字內容[10]?
聯邦學習(Federated Learning)正是在這個背景下誕生的。Google 的 McMahan 等人在 2017 年提出 FedAvg 演算法[1],其核心思想革命性地簡單:不移動資料,只移動模型。
聯邦學習的基本流程:
Round t:
1. 伺服器 → 客戶端: 廣播全域模型 w_t
2. 每個客戶端 k:
- 在本地資料上訓練 E 個 epoch
- 得到本地模型 w_t^k
3. 客戶端 → 伺服器: 上傳模型更新 Δw_t^k = w_t^k - w_t
4. 伺服器: 聚合所有更新 → w_{t+1}
5. 重複直到收斂
關鍵特性:
- 原始資料永遠不離開客戶端
- 伺服器只看到模型參數的更新(梯度)
- 通訊量遠小於傳輸原始資料
根據 Yang 等人的分類[4],聯邦學習可依據資料分區方式分為三種類型:
| 類型 | 資料分佈特徵 | 典型場景 | 範例 |
|---|---|---|---|
| 橫向聯邦學習(Horizontal FL) | 各方擁有相同特徵、不同樣本 | 同業協作 | 多家醫院各自擁有不同病患的相同類型醫療影像 |
| 縱向聯邦學習(Vertical FL) | 各方擁有相同樣本、不同特徵 | 跨業協作 | 銀行有客戶金融資料,電商有同一批客戶的消費紀錄 |
| 聯邦遷移學習(Federated Transfer Learning) | 各方的樣本與特徵都不同 | 跨域協作 | 不同國家的醫院,病患群體和檢測設備都不同 |
本文聚焦於最常見的橫向聯邦學習,即多個客戶端擁有相同的特徵空間但各自擁有不同的樣本。這也是 FedAvg、FedProx 等經典演算法的主要適用場景。
二、FedAvg:聯邦學習的基石演算法
Federated Averaging(FedAvg)[1]是 McMahan 等人在 2017 年提出的第一個實用聯邦學習演算法,至今仍是大多數聯邦學習系統的基礎。它的設計目標是在通訊效率和模型品質之間取得平衡。
FedAvg 的核心概念是:與其每個 mini-batch 都與伺服器同步梯度(像傳統的分散式 SGD),不如讓每個客戶端在本地多訓練幾個 epoch,然後再將完整的模型參數上傳。這大幅減少了通訊次數。
FedAvg 演算法(虛擬碼):
ServerUpdate:
初始化全域模型 w_0
for each round t = 1, 2, ..., T:
S_t ← 從 K 個客戶端中隨機選取 m = max(C·K, 1) 個
for each 客戶端 k ∈ S_t (可平行):
w_{t+1}^k ← ClientUpdate(k, w_t)
w_{t+1} ← Σ_k (n_k / n) · w_{t+1}^k # 加權平均
ClientUpdate(k, w):
B ← 將本地資料分成 batch size B 的 mini-batches
for each local epoch e = 1, ..., E:
for each batch b ∈ B:
w ← w - η · ∇L(w; b)
return w
超參數:
C = 客戶端選取比例(例如 0.1 表示每輪選 10%)
E = 本地訓練 epoch 數
B = 本地 batch size
η = 本地學習率
FedAvg 的聚合公式是按樣本數加權平均:如果客戶端 k 擁有 n_k 筆資料,佔總資料量 n 的比例為 n_k/n,則其模型更新的權重就是 n_k/n。這確保了擁有更多資料的客戶端對全域模型有更大的影響力。
FedAvg 的通訊效率來自兩個關鍵設計:(1)客戶端選取——每輪只需要一部分客戶端參與,而非全部;(2)多步本地更新——增加本地 epoch 數 E 可以減少通訊輪數,但可能犧牲收斂速度。McMahan 等人的實驗顯示,E=5、C=0.1 是一個良好的起點[1]。
然而,FedAvg 有一個關鍵假設:各客戶端的資料分佈是獨立同分布的(IID)。在現實世界中,這個假設幾乎總是被違反。不同醫院的病患群體不同、不同地區的用戶行為不同、不同銀行的客戶結構不同——這就是所謂的「非 IID」(Non-IID)問題,也是聯邦學習的最大挑戰。
三、非 IID 資料挑戰:FedProx、FedBN 與 SCAFFOLD
當各客戶端的資料分佈存在顯著差異時(即 Non-IID 情境),FedAvg 的表現會急劇惡化。Kairouz 等人的綜述[2]將非 IID 分為五種類型:
- 標籤分佈偏斜(Label Distribution Skew):某些客戶端只有特定類別的資料。例如一家皮膚科診所幾乎沒有骨科影像。
- 特徵分佈偏斜(Feature Distribution Skew):相同標籤的資料在不同客戶端有不同的特徵分佈。例如不同手機型號拍攝的影像色溫不同。
- 樣本數量不均(Quantity Skew):不同客戶端的資料量差異巨大。大醫院可能有百萬筆紀錄,小診所只有幾千筆。
- 概念偏移(Concept Shift):相同特徵在不同客戶端對應不同標籤。例如同一張胸部 X 光,不同醫師可能給出不同的診斷標註。
- 時間偏移(Temporal Shift):資料分佈隨時間變化。用戶的興趣、市場的交易模式都會演變。
為了解決非 IID 問題,研究社群提出了多種改進方案:
3.1 FedProx:加入近端約束
Li 等人提出的 FedProx[3] 是對 FedAvg 最直接的改進。它在每個客戶端的本地損失函數中加入一個近端項(proximal term),限制本地模型不要偏離全域模型太遠:
FedProx 的本地目標函數:
min_w L_k(w) + (μ/2) · ‖w - w_t‖²
其中:
L_k(w) = 客戶端 k 的原始損失函數
w_t = 當前輪的全域模型
μ = 近端係數(超參數,控制約束強度)
直覺:
- μ = 0 時退化為 FedAvg
- μ 越大,本地模型越接近全域模型(更穩定但學得更慢)
- μ 越小,本地模型有更大的自由度(學得更快但可能偏離)
典型設定: μ ∈ {0.001, 0.01, 0.1, 1.0}
FedProx 的另一個優勢是它能容忍系統異質性(Systems Heterogeneity):不同客戶端可以完成不同數量的本地更新步。慢速設備可以只做部分更新然後提前上傳,而不是被完全排除。
3.2 FedBN:特徵歸一化層的局部化
Li 等人提出的 FedBN[15] 採用了一個優雅的策略:既然不同客戶端的特徵分佈不同,那就讓每個客戶端保留自己的 Batch Normalization 層,只聚合其他層的參數。
FedBN 策略:
標準 FedAvg:
所有參數(包括 BN 層)都參與聚合
FedBN:
卷積層、全連接層 → 正常聚合(全域共享)
BatchNorm 的 γ, β, running_mean, running_var → 保留在本地(不聚合)
效果:
- BN 層自動適應每個客戶端的本地特徵分佈
- 其他層學習通用的特徵提取能力
- 在特徵分佈偏斜的場景下效果顯著
FedBN 的工程實現極為簡單——只需要在聚合步驟中排除 BN 相關的參數。這使得它成為處理特徵偏斜問題最容易採用的方案之一。
3.3 非 IID 方法對比
| 方法 | 核心策略 | 主要解決的 Non-IID 類型 | 額外通訊成本 | 實作難度 |
|---|---|---|---|---|
| FedAvg[1] | 加權平均 | (基線,IID 最佳) | 無 | 低 |
| FedProx[3] | 近端約束 | 標籤偏斜、系統異質性 | 無 | 低 |
| FedBN[15] | BN 層局部化 | 特徵分佈偏斜 | 略減(BN 不傳) | 極低 |
| SCAFFOLD | 控制變量修正梯度漂移 | 標籤偏斜(收斂速度) | 2x(需傳控制變量) | 中 |
| FedMA | 層匹配後聚合 | 模型異質性 | 增加 | 高 |
實務建議:在大多數場景中,先嘗試 FedAvg 作為基線。如果效果不佳,FedProx(加一行正則項即可)和 FedBN(排除 BN 參數即可)是成本最低的改進。只有在嚴重的非 IID 場景且通訊頻寬充裕時,才考慮 SCAFFOLD 等更複雜的方法。
四、隱私保護機制:差分隱私與安全聚合
聯邦學習「不共享原始資料」的設計提供了基本的隱私保護,但這並不足夠。研究顯示[14],即使只觀察模型更新(梯度),攻擊者仍然可以推斷出訓練資料的敏感資訊。主要的攻擊方式包括:
- 梯度反轉攻擊(Gradient Inversion Attack):從上傳的梯度反推出原始訓練樣本的近似重建
- 成員推斷攻擊(Membership Inference Attack):判斷某筆特定資料是否被用於訓練
- 模型中毒攻擊(Model Poisoning Attack):惡意客戶端上傳被操控的模型更新,影響全域模型
因此,實用的聯邦學習系統需要額外的隱私保護層。兩個最重要的機制是差分隱私和安全聚合。
4.1 差分隱私(Differential Privacy)
差分隱私[12]提供了一個數學上可證明的隱私保證:無論攻擊者的計算能力多強、擁有多少背景知識,都無法從模型輸出中高度確信地推斷出任何單一個體的資料是否被使用。
差分隱私的定義:
一個隨機機制 M 滿足 (ε, δ)-差分隱私,若對任意兩個
只差一筆記錄的相鄰資料集 D 和 D',以及任意輸出集合 S:
P[M(D) ∈ S] ≤ e^ε · P[M(D') ∈ S] + δ
直覺:
- ε(隱私預算)越小,隱私保護越強
- ε = 0: 完美隱私(但模型無法學到任何東西)
- ε = ∞: 無隱私保護
- 實務上 ε ∈ [1, 10] 被認為是合理範圍
在聯邦學習中的實現(DP-FedAvg):
1. 每個客戶端計算本地模型更新 Δw
2. 梯度裁剪: Δw ← Δw · min(1, C/‖Δw‖) # 限制敏感度
3. 加入雜訊: Δw ← Δw + N(0, σ²C²I) # 高斯雜訊
4. 上傳加雜訊後的更新
σ 的選擇由 (ε, δ) 和訓練輪數 T 共同決定
差分隱私在聯邦學習中的應用可分為兩個層次[11]:
客戶端層級差分隱私(Client-Level DP):保證任何單一客戶端的整批資料不會被洩露。這是聯邦學習中最常用的形式,因為在跨機構場景中,我們關心的是保護每個機構的全部資料。
記錄層級差分隱私(Record-Level DP):保證任何單筆記錄不會被洩露。保護更細緻,但通常需要更大的雜訊,導致模型品質下降更多。
4.2 安全聚合(Secure Aggregation)
安全聚合[5]是另一條防線。它透過密碼學協議確保伺服器只能看到所有客戶端更新的聚合結果,而無法看到任何單一客戶端的模型更新。
安全聚合的基本原理(基於秘密共享):
準備階段:
每對客戶端 (i, j) 協商一個隨機遮罩 r_{i,j}
且 r_{i,j} = -r_{j,i}(遮罩互為相反數)
上傳階段:
客戶端 i 上傳: w_i + Σ_{j≠i} r_{i,j} (加遮罩的模型更新)
聚合階段:
伺服器計算: Σ_i (w_i + Σ_{j≠i} r_{i,j})
= Σ_i w_i + Σ_i Σ_{j≠i} r_{i,j}
= Σ_i w_i + 0 (遮罩互消)
= 正確的聚合結果
結果: 伺服器得到正確的聚合模型,但無法得知任何單一客戶端的更新
Google 在其生產級聯邦學習系統中,同時使用了差分隱私和安全聚合[5]。兩者的功能互補:安全聚合防止伺服器窺探個別更新,差分隱私防止從聚合結果推斷個體資訊。
| 保護機制 | 保護對象 | 防禦的攻擊 | 代價 |
|---|---|---|---|
| 差分隱私 | 個體資料不可推斷 | 成員推斷、梯度反轉 | 模型精度下降(雜訊) |
| 安全聚合 | 個別模型更新不可見 | 好奇伺服器(Honest-but-Curious) | 通訊量增加、有掉線處理複雜度 |
| 同態加密 | 密文上直接計算 | 伺服器端全部攻擊 | 計算量極大(10-100x 慢) |
五、產業應用:醫療、金融與行動裝置
5.1 醫療:跨院協作不共享病歷
醫療是聯邦學習最具影響力的應用領域[8]。單一醫院的病例數量往往不足以訓練高品質的 AI 模型——尤其是罕見疾病。但醫療資料受到最嚴格的隱私法規保護(HIPAA、GDPR),跨院共享病歷幾乎不可能。
Sheller 等人[9]在腦腫瘤分割任務上的研究顯示,10 家醫院使用聯邦學習協作訓練的模型,效果接近將所有資料集中訓練的模型,且顯著優於任何單一醫院獨立訓練的模型。這項研究證明了聯邦學習在醫療場景中的實用性。
目前已有多個醫療聯邦學習平台進入臨床應用:NVIDIA Clara FL 用於多院協作的醫學影像分析,Intel OpenFL 支援跨國的藥物發現協作,HealthChain 項目在歐洲多國實現了乳腺癌 AI 的聯邦訓練。
5.2 金融:反洗錢與信用評分
金融機構擁有豐富的交易資料,但受到嚴格的法規限制,無法直接共享客戶資訊。聯邦學習讓多家銀行可以聯合訓練反洗錢(AML)模型,每家銀行只能看到自己客戶的交易,但模型能學到跨行的洗錢模式[4]。
在信用評分場景中,縱向聯邦學習特別有價值:銀行擁有客戶的金融紀錄,電商平台擁有消費行為資料,電信公司擁有通訊模式。這三類互補的特徵可以透過縱向聯邦學習融合,建構更精準的信用模型,而不需要任何一方向其他方揭露原始資料。
5.3 行動裝置:Google Gboard 的實戰案例
Google 在 Gboard(行動鍵盤)中部署的聯邦學習[10]是最大規模的生產級應用之一。數億台 Android 手機各自在本地利用使用者的打字資料訓練下一詞預測模型,只將加密的模型更新上傳到 Google 伺服器進行聚合。
這個系統面臨的挑戰是典型的 Cross-Device 聯邦學習問題[5]:客戶端數量極大(數億台)、每台設備的資料量極小(個人打字紀錄)、設備隨時可能離線、計算和通訊資源有限。Google 的解決方案包括:只在設備充電且連接 Wi-Fi 時訓練、使用差分隱私和安全聚合保護隱私、每輪只選取數千台設備參與。
5.4 Cross-Silo vs. Cross-Device
| 特性 | Cross-Silo(跨機構) | Cross-Device(跨設備) |
|---|---|---|
| 客戶端數量 | 2–100(醫院、銀行) | 10^6–10^10(手機、IoT) |
| 每客戶端資料量 | 大(百萬筆以上) | 極小(數百筆) |
| 客戶端穩定性 | 穩定在線 | 隨時離線 |
| 資料異質性 | 中度 | 極端 |
| 主要挑戰 | 法規合規、機構互信 | 通訊效率、設備異質 |
| 代表應用 | 醫療、金融 | 鍵盤預測、推薦系統 |
六、Hands-on Lab 1:使用 Flower 框架模擬聯邦學習(影像分類)
在本實作中,我們使用 Flower[6]——目前最友善的聯邦學習框架——在單機上模擬 3 個客戶端進行 CIFAR-10 影像分類的聯邦訓練。Flower 的設計哲學是「框架無關」(framework-agnostic),支援 PyTorch、TensorFlow、JAX 等任何機器學習框架。
目標:(1)理解 Flower 的客戶端-伺服器架構;(2)模擬非 IID 資料分佈下的聯邦訓練;(3)比較聯邦訓練與集中式訓練的準確率。
環境需求:Google Colab(免費版即可,CPU 或 T4 GPU)。
# ============================================================
# Hands-on Lab 1: Flower 框架聯邦學習 — CIFAR-10 影像分類
# 環境: Google Colab (CPU or GPU)
# 目標: 模擬 3 個客戶端的 FedAvg 聯邦訓練
# ============================================================
# --- 0. 安裝依賴 ---
# !pip install flwr[simulation] torch torchvision matplotlib -q
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import warnings
warnings.filterwarnings("ignore")
# Flower imports
import flwr as fl
from flwr.client import NumPyClient, ClientApp
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
print(f"Flower version: {fl.__version__}")
print(f"PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- 1. 資料準備:CIFAR-10 分割為 3 個非 IID 客戶端 ---
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2470, 0.2435, 0.2616)),
])
trainset = torchvision.datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform
)
testset = torchvision.datasets.CIFAR10(
root="./data", train=False, download=True, transform=transform
)
NUM_CLIENTS = 3
def partition_non_iid(dataset, num_clients, alpha=0.5):
"""
使用 Dirichlet 分佈模擬非 IID 資料分割。
alpha 越小,非 IID 程度越高。
alpha=0.5: 中度非 IID
alpha=0.1: 嚴重非 IID
alpha=100: 近似 IID
"""
labels = np.array([dataset[i][1] for i in range(len(dataset))])
num_classes = len(np.unique(labels))
client_indices = [[] for _ in range(num_clients)]
for c in range(num_classes):
class_indices = np.where(labels == c)[0]
np.random.shuffle(class_indices)
# Dirichlet 分佈決定每個客戶端分到多少該類別的樣本
proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
# 按比例切割
splits = (proportions * len(class_indices)).astype(int)
# 確保總數正確
splits[-1] = len(class_indices) - splits[:-1].sum()
start = 0
for k in range(num_clients):
client_indices[k].extend(
class_indices[start:start + splits[k]].tolist()
)
start += splits[k]
return client_indices
np.random.seed(42)
client_indices = partition_non_iid(trainset, NUM_CLIENTS, alpha=0.5)
# 視覺化各客戶端的標籤分佈
fig, axes = plt.subplots(1, NUM_CLIENTS, figsize=(15, 4))
class_names = trainset.classes
for k in range(NUM_CLIENTS):
labels_k = [trainset[i][1] for i in client_indices[k]]
counts = np.bincount(labels_k, minlength=10)
axes[k].bar(range(10), counts, color='steelblue')
axes[k].set_title(f"Client {k+1} ({len(labels_k)} samples)")
axes[k].set_xticks(range(10))
axes[k].set_xticklabels(class_names, rotation=45, fontsize=7)
axes[k].set_ylabel("Count")
fig.suptitle("Non-IID Data Distribution (Dirichlet α=0.5)", fontsize=14)
plt.tight_layout()
plt.show()
# --- 2. 定義 CNN 模型 ---
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(128 * 4 * 4, 256)
self.fc2 = nn.Linear(256, 10)
self.dropout = nn.Dropout(0.3)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
x = x.view(-1, 128 * 4 * 4)
x = self.dropout(F.relu(self.fc1(x)))
x = self.fc2(x)
return x
# --- 3. 定義 Flower 客戶端 ---
def get_params(model):
"""取得模型參數為 NumPy 陣列列表"""
return [val.cpu().numpy() for _, val in model.state_dict().items()]
def set_params(model, params):
"""將 NumPy 陣列列表設定為模型參數"""
params_dict = zip(model.state_dict().keys(), params)
state_dict = OrderedDict(
{k: torch.tensor(v) for k, v in params_dict}
)
model.load_state_dict(state_dict, strict=True)
def train_local(model, trainloader, epochs, lr=0.001):
"""本地訓練"""
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for _ in range(epochs):
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
loss = F.cross_entropy(model(images), labels)
loss.backward()
optimizer.step()
def evaluate_model(model, testloader):
"""評估模型"""
model.to(device)
model.eval()
correct, total, total_loss = 0, 0, 0.0
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
total_loss += F.cross_entropy(outputs, labels).item()
correct += (outputs.argmax(1) == labels).sum().item()
total += labels.size(0)
return total_loss / len(testloader), correct / total
# 建立客戶端的 DataLoader
client_loaders = []
for k in range(NUM_CLIENTS):
subset = Subset(trainset, client_indices[k])
loader = DataLoader(subset, batch_size=32, shuffle=True)
client_loaders.append(loader)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
# --- 4. 手動模擬 FedAvg(教學用,清晰展示每一步) ---
def fedavg_manual(num_rounds=10, local_epochs=2, lr=0.001):
"""
手動實現 FedAvg,便於理解每個步驟。
"""
# 初始化全域模型
global_model = SimpleCNN()
history = {"round": [], "loss": [], "accuracy": []}
print("=" * 60)
print("FedAvg Federated Training (Manual Implementation)")
print(f"Clients: {NUM_CLIENTS}, Rounds: {num_rounds}, "
f"Local Epochs: {local_epochs}")
print("=" * 60)
for rnd in range(1, num_rounds + 1):
# Step 1: 廣播全域模型參數給所有客戶端
global_params = get_params(global_model)
client_params_list = []
client_sizes = []
for k in range(NUM_CLIENTS):
# Step 2: 每個客戶端從全域模型開始本地訓練
local_model = SimpleCNN()
set_params(local_model, global_params)
train_local(local_model, client_loaders[k],
epochs=local_epochs, lr=lr)
# Step 3: 收集本地模型參數
client_params_list.append(get_params(local_model))
client_sizes.append(len(client_indices[k]))
# Step 4: FedAvg 加權聚合
total_size = sum(client_sizes)
new_params = []
for param_idx in range(len(global_params)):
weighted_sum = sum(
client_params_list[k][param_idx] *
(client_sizes[k] / total_size)
for k in range(NUM_CLIENTS)
)
new_params.append(weighted_sum)
# Step 5: 更新全域模型
set_params(global_model, new_params)
# 評估全域模型
loss, accuracy = evaluate_model(global_model, testloader)
history["round"].append(rnd)
history["loss"].append(loss)
history["accuracy"].append(accuracy)
print(f"Round {rnd:2d} | Loss: {loss:.4f} | "
f"Accuracy: {accuracy:.4f}")
return global_model, history
# 執行聯邦訓練
fed_model, fed_history = fedavg_manual(
num_rounds=10, local_epochs=2, lr=0.001
)
# --- 5. 集中式訓練對照組 ---
def centralized_training(epochs=20, lr=0.001):
"""集中式訓練(所有資料在一起)作為上界"""
model = SimpleCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
full_loader = DataLoader(trainset, batch_size=64, shuffle=True)
history = {"epoch": [], "loss": [], "accuracy": []}
print("\n" + "=" * 60)
print("Centralized Training (Upper Bound)")
print("=" * 60)
for epoch in range(1, epochs + 1):
model.train()
for images, labels in full_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
loss = F.cross_entropy(model(images), labels)
loss.backward()
optimizer.step()
loss, accuracy = evaluate_model(model, testloader)
history["epoch"].append(epoch)
history["loss"].append(loss)
history["accuracy"].append(accuracy)
if epoch % 5 == 0:
print(f"Epoch {epoch:2d} | Loss: {loss:.4f} | "
f"Accuracy: {accuracy:.4f}")
return model, history
central_model, central_history = centralized_training(
epochs=20, lr=0.001
)
# --- 6. 結果視覺化 ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Accuracy 比較
axes[0].plot(fed_history["round"], fed_history["accuracy"],
'o-', label="Federated (FedAvg)", color='#0077b6',
linewidth=2, markersize=5)
axes[0].plot(central_history["epoch"], central_history["accuracy"],
's-', label="Centralized", color='#b8922e',
linewidth=2, markersize=4)
axes[0].set_xlabel("Round / Epoch")
axes[0].set_ylabel("Test Accuracy")
axes[0].set_title("Federated vs. Centralized: Accuracy")
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Loss 比較
axes[1].plot(fed_history["round"], fed_history["loss"],
'o-', label="Federated (FedAvg)", color='#0077b6',
linewidth=2, markersize=5)
axes[1].plot(central_history["epoch"], central_history["loss"],
's-', label="Centralized", color='#b8922e',
linewidth=2, markersize=4)
axes[1].set_xlabel("Round / Epoch")
axes[1].set_ylabel("Test Loss")
axes[1].set_title("Federated vs. Centralized: Loss")
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# 最終比較
print("\n" + "=" * 60)
print("Final Results Summary")
print("=" * 60)
print(f"Federated (FedAvg, 10 rounds): "
f"Accuracy = {fed_history['accuracy'][-1]:.4f}")
print(f"Centralized (20 epochs): "
f"Accuracy = {central_history['accuracy'][-1]:.4f}")
gap = central_history['accuracy'][-1] - fed_history['accuracy'][-1]
print(f"Accuracy Gap: {gap:.4f}")
print(f"\nNote: Federated training preserves data privacy while")
print(f"achieving competitive accuracy with centralized training.")
預期結果:在 CIFAR-10 上,FedAvg 經過 10 輪聯邦訓練(每輪 2 個本地 epoch)通常可以達到約 65-72% 的準確率,而集中式訓練 20 個 epoch 可以達到約 73-78%。聯邦訓練在保護隱私的前提下,只損失了約 3-8% 的準確率。如果將 Dirichlet 的 alpha 降低(例如 0.1),非 IID 程度加劇,聯邦訓練的準確率會進一步下降,這正是 FedProx 等改進方法的必要性所在。
七、Hands-on Lab 2:差分隱私聯邦學習實驗
在本實作中,我們將差分隱私(Differential Privacy)整合到聯邦學習流程中。使用 Opacus——Meta 開發的 PyTorch 差分隱私函式庫——為每個客戶端的本地訓練加入 DP 保護,並觀察隱私預算 epsilon 對模型準確率的影響。
目標:(1)理解 DP-SGD 的梯度裁剪與雜訊機制;(2)實驗不同隱私預算 epsilon 的精度-隱私權衡;(3)視覺化隱私預算消耗曲線。
環境需求:Google Colab(免費版即可)。
# ============================================================
# Hands-on Lab 2: 差分隱私聯邦學習實驗
# 環境: Google Colab (CPU or GPU)
# 目標: 以 Opacus 實現 DP-FedAvg,觀察 ε 對精度的影響
# ============================================================
# --- 0. 安裝依賴 ---
# !pip install opacus torch torchvision matplotlib -q
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import copy
import warnings
warnings.filterwarnings("ignore")
# Opacus imports
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
print(f"PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- 1. 資料準備(MNIST,DP 實驗用較小資料集加速) ---
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform
)
NUM_CLIENTS = 3
def partition_iid(dataset, num_clients):
"""IID 分割(DP 實驗專注於隱私,使用 IID 排除 non-IID 干擾)"""
indices = np.random.permutation(len(dataset))
splits = np.array_split(indices, num_clients)
return [s.tolist() for s in splits]
np.random.seed(42)
client_indices = partition_iid(trainset, NUM_CLIENTS)
print(f"Client data sizes: "
f"{[len(idx) for idx in client_indices]}")
# --- 2. 定義符合 Opacus 的 CNN 模型 ---
# Opacus 要求模型不能使用某些不相容的層
# 例如 nn.BatchNorm 需要替換為 nn.GroupNorm
class DPCNN(nn.Module):
"""Opacus-compatible CNN (uses GroupNorm instead of BatchNorm)"""
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
self.gn1 = nn.GroupNorm(4, 16)
self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
self.gn2 = nn.GroupNorm(4, 32)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(F.relu(self.gn1(self.conv1(x))))
x = self.pool(F.relu(self.gn2(self.conv2(x))))
x = x.view(-1, 32 * 7 * 7)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# 驗證模型與 Opacus 相容
sample_model = DPCNN()
errors = ModuleValidator.validate(sample_model, strict=False)
if errors:
print(f"Model validation errors: {errors}")
sample_model = ModuleValidator.fix(sample_model)
print("Model fixed for Opacus compatibility.")
else:
print("Model is Opacus-compatible.")
# --- 3. 工具函數 ---
def get_params(model):
return [val.cpu().detach().numpy()
for _, val in model.state_dict().items()]
def set_params(model, params):
params_dict = zip(model.state_dict().keys(), params)
state_dict = OrderedDict(
{k: torch.tensor(v) for k, v in params_dict}
)
model.load_state_dict(state_dict, strict=True)
def evaluate_model(model, testloader):
model.to(device)
model.eval()
correct, total = 0, 0
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
correct += (outputs.argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
testloader = DataLoader(testset, batch_size=256, shuffle=False)
# --- 4. DP-FedAvg 實現 ---
def train_local_with_dp(model, trainloader, epochs, lr,
target_epsilon, target_delta, max_grad_norm):
"""
使用 Opacus 進行差分隱私本地訓練。
Args:
target_epsilon: 目標隱私預算
target_delta: delta 參數(通常 1/n)
max_grad_norm: 每筆樣本的梯度範數裁剪上界
"""
model = copy.deepcopy(model)
model.to(device)
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# 建立 PrivacyEngine
privacy_engine = PrivacyEngine()
model, optimizer, trainloader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=trainloader,
epochs=epochs,
target_epsilon=target_epsilon,
target_delta=target_delta,
max_grad_norm=max_grad_norm,
)
for epoch in range(epochs):
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
output = model(images)
loss = F.cross_entropy(output, labels)
loss.backward()
optimizer.step()
# 取得實際消耗的隱私預算
actual_epsilon = privacy_engine.get_epsilon(delta=target_delta)
# 回傳未包裝的模型參數
raw_model = model._module if hasattr(model, '_module') else model
return get_params(raw_model), actual_epsilon
def train_local_no_dp(model, trainloader, epochs, lr):
"""無差分隱私的本地訓練(對照組)"""
model = copy.deepcopy(model)
model.to(device)
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
for epoch in range(epochs):
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
output = model(images)
loss = F.cross_entropy(output, labels)
loss.backward()
optimizer.step()
return get_params(model)
def fedavg_aggregate(global_params, client_params_list, client_sizes):
"""FedAvg 加權聚合"""
total = sum(client_sizes)
new_params = []
for i in range(len(global_params)):
weighted = sum(
client_params_list[k][i] * (client_sizes[k] / total)
for k in range(len(client_params_list))
)
new_params.append(weighted)
return new_params
# --- 5. 實驗:不同隱私預算的比較 ---
EPSILON_VALUES = [1.0, 3.0, 8.0] # 不同隱私預算
NUM_ROUNDS = 8
LOCAL_EPOCHS = 1
LR = 0.05
DELTA = 1e-5
MAX_GRAD_NORM = 1.0
BATCH_SIZE = 64
results = {}
# 實驗 1: 無差分隱私的 FedAvg(對照組)
print("=" * 60)
print("Experiment: FedAvg WITHOUT Differential Privacy")
print("=" * 60)
global_model = DPCNN()
history_no_dp = []
for rnd in range(1, NUM_ROUNDS + 1):
global_params = get_params(global_model)
client_params_list = []
client_sizes = []
for k in range(NUM_CLIENTS):
loader = DataLoader(
Subset(trainset, client_indices[k]),
batch_size=BATCH_SIZE, shuffle=True
)
local_model = DPCNN()
set_params(local_model, global_params)
local_params = train_local_no_dp(
local_model, loader, LOCAL_EPOCHS, LR
)
client_params_list.append(local_params)
client_sizes.append(len(client_indices[k]))
new_params = fedavg_aggregate(
global_params, client_params_list, client_sizes
)
set_params(global_model, new_params)
acc = evaluate_model(global_model, testloader)
history_no_dp.append(acc)
print(f"Round {rnd} | Accuracy: {acc:.4f}")
results["No DP"] = history_no_dp
# 實驗 2-4: 不同 ε 值的 DP-FedAvg
for target_eps in EPSILON_VALUES:
print(f"\n{'=' * 60}")
print(f"Experiment: DP-FedAvg with ε = {target_eps}")
print("=" * 60)
global_model = DPCNN()
history = []
epsilons_consumed = []
for rnd in range(1, NUM_ROUNDS + 1):
global_params = get_params(global_model)
client_params_list = []
client_sizes = []
round_epsilons = []
for k in range(NUM_CLIENTS):
loader = DataLoader(
Subset(trainset, client_indices[k]),
batch_size=BATCH_SIZE, shuffle=True
)
local_model = DPCNN()
set_params(local_model, global_params)
local_params, actual_eps = train_local_with_dp(
local_model, loader, LOCAL_EPOCHS, LR,
target_epsilon=target_eps,
target_delta=DELTA,
max_grad_norm=MAX_GRAD_NORM,
)
client_params_list.append(local_params)
client_sizes.append(len(client_indices[k]))
round_epsilons.append(actual_eps)
new_params = fedavg_aggregate(
global_params, client_params_list, client_sizes
)
set_params(global_model, new_params)
acc = evaluate_model(global_model, testloader)
avg_eps = np.mean(round_epsilons)
history.append(acc)
epsilons_consumed.append(avg_eps)
print(f"Round {rnd} | Accuracy: {acc:.4f} | "
f"ε consumed: {avg_eps:.2f}")
results[f"ε={target_eps}"] = history
# --- 6. 結果視覺化 ---
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
# 圖 1: 準確率比較
colors = {'No DP': '#2d3436', 'ε=1.0': '#d63031',
'ε=3.0': '#e17055', 'ε=8.0': '#0984e3'}
for label, hist in results.items():
axes[0].plot(range(1, NUM_ROUNDS + 1), hist,
'o-', label=label, color=colors[label],
linewidth=2, markersize=5)
axes[0].set_xlabel("Communication Round", fontsize=12)
axes[0].set_ylabel("Test Accuracy", fontsize=12)
axes[0].set_title("Privacy-Accuracy Tradeoff in DP-FedAvg",
fontsize=13)
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)
# 圖 2: 最終準確率 vs 隱私預算
final_accs = [results[k][-1] for k in results]
labels = list(results.keys())
bar_colors = [colors[k] for k in labels]
bars = axes[1].bar(labels, final_accs, color=bar_colors, edgecolor='white')
axes[1].set_ylabel("Final Test Accuracy", fontsize=12)
axes[1].set_title("Final Accuracy at Different Privacy Levels",
fontsize=13)
axes[1].set_ylim(0, 1.0)
for bar, acc in zip(bars, final_accs):
axes[1].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
f"{acc:.3f}", ha='center', fontsize=11, fontweight='bold')
plt.tight_layout()
plt.show()
# --- 7. 隱私-精度權衡分析 ---
print("\n" + "=" * 60)
print("Privacy-Accuracy Tradeoff Summary")
print("=" * 60)
print(f"{'Setting':<15} {'Final Accuracy':<18} {'Privacy Level'}")
print("-" * 55)
for label, hist in results.items():
if label == "No DP":
privacy = "None (baseline)"
elif "1.0" in label:
privacy = "Strong (ε=1)"
elif "3.0" in label:
privacy = "Moderate (ε=3)"
else:
privacy = "Relaxed (ε=8)"
print(f"{label:<15} {hist[-1]:<18.4f} {privacy}")
print("\nKey Takeaways:")
print("1. ε=8 (relaxed DP) retains most accuracy, suitable for "
"low-sensitivity data")
print("2. ε=1 (strong DP) causes noticeable accuracy drop, "
"but provides mathematical privacy guarantee")
print("3. The accuracy gap can be reduced by: more clients, "
"more communication rounds, or larger local datasets")
print("4. In practice, ε ∈ [3, 10] is the sweet spot for most "
"enterprise applications")
預期結果:在 MNIST 上,無 DP 的 FedAvg 可以達到約 97-98% 的準確率。加入差分隱私後:epsilon=8(寬鬆隱私)約 95-97%,epsilon=3(中度隱私)約 92-95%,epsilon=1(嚴格隱私)約 85-92%。這清楚展示了隱私-精度的權衡關係——更強的隱私保護需要更多的雜訊,而更多的雜訊必然導致精度下降。在實務中,可以透過增加客戶端數量、增加通訊輪數或增大本地資料集來緩解這個權衡。
八、聯邦學習框架選型:Flower vs PySyft vs FATE
選擇正確的聯邦學習框架是工程落地的第一步。目前主流的開源框架各有定位[6][7]:
| 框架 | 主導方 | 核心定位 | 支援 ML 框架 | 隱私機制 | 適用場景 |
|---|---|---|---|---|---|
| Flower | Flower Labs | 框架無關、研究友善 | PyTorch、TF、JAX、任意 | DP(透過 Opacus/TF-Privacy)、SecAgg | 研究原型、Cross-Silo 與 Cross-Device |
| PySyft | OpenMined | 隱私優先、可驗證計算 | PyTorch 為主 | DP、SMPC、同態加密 | 高隱私需求場景(醫療、金融) |
| FATE | 微眾銀行 | 企業級生產系統 | 自有框架 + PyTorch | 同態加密、SecAgg | 金融業跨機構協作 |
| NVIDIA FLARE | NVIDIA | 企業級、醫療為主 | PyTorch、TF | DP、同態加密 | 醫療影像、大規模 GPU 叢集 |
| TFF | 研究用模擬框架 | TensorFlow | DP(TF-Privacy) | 聯邦學習演算法研究 | |
| OpenFL | Intel | 跨組織協作 | PyTorch、TF | DP | 醫療、製藥協作 |
選型建議:
- 學術研究或概念驗證:Flower 是最佳起點。API 極為簡潔(定義一個客戶端類別即可),支援所有主流 ML 框架,模擬模式讓你在筆電上就能實驗上百個客戶端。
- 醫療或高隱私場景:PySyft 提供最完整的隱私計算堆疊(DP + SMPC + 同態加密);NVIDIA FLARE 和 OpenFL 有現成的醫療影像聯邦訓練流水線。
- 中國金融業:FATE 是事實上的標準,有完整的中文文件、管理介面和合規支援。
- 從原型到生產:Flower 的 Simulation 模式可以快速驗證,確認可行後切換到分散式部署模式(gRPC),無需重寫客戶端程式碼。
九、決策框架與企業導入建議
導入聯邦學習不僅是技術決策,更涉及組織、法務與商業層面。以下是一個結構化的決策框架:
9.1 是否需要聯邦學習?
聯邦學習需求評估決策樹:
Q1: 資料能否集中?
├── 可以 → 使用傳統集中式訓練(更簡單、效果更好)
└── 不可以 → 繼續
Q2: 為何不能集中?
├── 法規限制(GDPR/HIPAA)→ 強需求,聯邦學習 + DP
├── 商業競爭(不願共享資料)→ 中需求,聯邦學習 + SecAgg
└── 物理限制(資料太大/設備太多)→ 中需求,Cross-Device FL
Q3: 資料分佈類型?
├── 各方有相同特徵、不同樣本 → 橫向聯邦學習
├── 各方有相同樣本、不同特徵 → 縱向聯邦學習
└── 兩者皆不同 → 聯邦遷移學習
Q4: 客戶端數量?
├── 2-100 家機構 → Cross-Silo(穩定、可靠)
└── 數千至數億設備 → Cross-Device(需特殊系統設計)
9.2 導入路線圖
| 階段 | 活動 | 交付物 | 時程 |
|---|---|---|---|
| 1. 可行性評估 | 資料審計、法規分析、利害關係人訪談 | 可行性報告、ROI 估算 | 2-4 週 |
| 2. 概念驗證 | 單機模擬(Flower Simulation)、基線比較 | 技術可行性報告、精度對比 | 4-6 週 |
| 3. 試點部署 | 2-3 個真實節點、真實資料、端到端測試 | 系統架構、隱私影響評估 | 2-3 個月 |
| 4. 生產上線 | 完整部署、監控、自動化管線 | SLA、營運手冊、合規文件 | 3-6 個月 |
| 5. 持續優化 | 新客戶端接入、模型版本管理、效能調優 | 定期效能報告、模型更新策略 | 持續 |
9.3 常見陷阱與應對
- 陷阱 1:低估通訊成本。模型參數的上下傳是聯邦學習的瓶頸,尤其對大型模型而言。應對:使用梯度壓縮(Top-K Sparsification)、模型量化、知識蒸餾減小傳輸量。
- 陷阱 2:忽略系統異質性。不同客戶端的硬體能力差異巨大。慢速節點(stragglers)會拖慢整個系統。應對:使用非同步聯邦學習或 FedProx 的彈性本地更新。
- 陷阱 3:將聯邦學習等同於隱私保護。原始的 FedAvg 沒有任何正式的隱私保證。必須加入差分隱私或安全聚合才能聲稱「隱私保護」。應對:從設計之初就整合隱私機制,並量化隱私預算。
- 陷阱 4:忽略資料品質差異。某些客戶端的資料品質可能很差(標註錯誤、噪音大),但 FedAvg 會給予它們與高品質客戶端相同的影響力。應對:使用基於信譽的聚合權重或異常偵測排除低品質更新。
- 陷阱 5:缺乏模型驗證機制。在隱私約束下,伺服器無法直接在客戶端資料上驗證模型。應對:保留一個共享的驗證集(脫敏後),或使用聯邦式的模型評估協議。
十、結語
聯邦學習不是一項單純的技術創新——它代表了 AI 產業從「資料集中化」到「計算分散化」的典範轉移[2]。在 GDPR、HIPAA 等法規日益嚴格的今天,能夠在不集中資料的前提下訓練高品質模型,已經從「錦上添花」變成了「剛性需求」。
技術層面,FedAvg[1] 作為基石演算法已經被廣泛驗證,而 FedProx[3]、FedBN[15] 等後續改進有效解決了非 IID 資料的挑戰。差分隱私[12]和安全聚合[5]提供了數學上可證明的隱私保證,將「不共享資料」從口號變成了可驗證的承諾。
產業層面,聯邦學習已經在醫療[8][9]、金融[4]、行動裝置[10]等領域證明了其實用價值。隨著 Flower[6] 等框架的成熟,導入門檻正在快速降低。
對企業決策者而言,現在是評估聯邦學習的最佳時機。不必等到法規迫使你行動——主動擁抱隱私保護 AI,不僅能降低合規風險,更能開啟原本因為資料隱私而無法實現的跨機構協作機會。從 Flower 模擬開始,用本文的兩個 Colab 實作驗證你的場景,然後以結構化的方式推進到試點和生產部署。
資料不能出門的時代已經來臨。聯邦學習讓 AI 模型走出去,替代資料的遷移——這不僅是技術上的突破,更是對「資料所有權」的一次根本性重新定義。