Federated Learning
Foundations
Federated learning (FL) trains machine learning models across multiple decentralized data holders (clients) without sharing raw data. Each client keeps its data local; only model updates (gradients or parameters) are communicated to a central server (or exchanged peer-to-peer). This paradigm addresses privacy, regulatory, and practical constraints that prevent data centralization.
Motivation
- Privacy: Medical records, financial data, personal device data cannot be centralized due to legal constraints (GDPR, HIPAA) and user expectations.
- Communication: Edge devices generate massive data volumes; transmitting all data to a central server is impractical.
- Heterogeneity: Data distributions across clients differ (non-IID), reflecting real-world variation.
System Model
- Clients: K participants, each with local dataset D_k. Datasets may differ in size, distribution, and feature space.
- Server: Coordinates training, aggregates model updates. Does not access raw data.
- Communication rounds: In each round, the server sends the current global model to (a subset of) clients; clients perform local training; clients send updates back; the server aggregates.
FedAvg (Federated Averaging)
McMahan et al. (2017) introduced the foundational FL algorithm:
Per round:
- Server selects a random subset of clients: a fraction C of the K clients (typically C = 0.1-0.3).
- Server broadcasts global model w_t to selected clients.
- Each client k initializes w_k^{t+1} = w_t and runs E epochs of SGD on its local data, repeatedly applying w_k^{t+1} ← w_k^{t+1} - η ∇L_k(w_k^{t+1}).
- Server aggregates: w_{t+1} = Σ_k (n_k/n) w_k^{t+1} (weighted by local dataset size n_k).
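The per-round procedure above can be sketched in NumPy. The quadratic toy loss L_k(w) = ½||w - c_k||², the client data, and the hyperparameters below are illustrative choices, not part of the original algorithm's specification:

```python
import numpy as np

def local_sgd(w, data, lr=0.1, epochs=1):
    """One client's local training: E epochs of gradient steps on a
    toy quadratic loss L_k(w) = 0.5 * ||w - c_k||^2, c_k = data mean."""
    c_k = data.mean(axis=0)
    for _ in range(epochs):
        grad = w - c_k              # gradient of the toy loss
        w = w - lr * grad
    return w

def fedavg_round(w_global, client_data, lr=0.1, epochs=5):
    """One FedAvg round: broadcast, local training, size-weighted average."""
    sizes = np.array([len(d) for d in client_data], dtype=float)
    updates = [local_sgd(w_global.copy(), d, lr, epochs) for d in client_data]
    weights = sizes / sizes.sum()   # n_k / n
    return sum(p * u for p, u in zip(weights, updates))

# toy example: two clients with different local optima (non-IID)
rng = np.random.default_rng(0)
clients = [rng.normal(loc=0.0, size=(100, 2)), rng.normal(loc=4.0, size=(50, 2))]
w = np.zeros(2)
for _ in range(50):
    w = fedavg_round(w, clients)
```

On this quadratic toy problem the global model converges to the size-weighted average of the client means, which is exactly the centralized optimum; with non-convex losses and many local epochs, client drift breaks this equivalence.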
Multiple local epochs: FedAvg performs E > 1 local epochs before communication. This dramatically reduces communication (fewer rounds needed) but introduces client drift — local models diverge from each other when training on heterogeneous data.
Convergence: For IID data and convex objectives, FedAvg converges. For non-IID data, convergence is slower and the solution may be suboptimal. The error bound has terms proportional to the data heterogeneity and the number of local steps.
Communication Efficiency
Communication is the bottleneck in FL: model updates can be large (millions of parameters), and client bandwidth is limited (especially for mobile devices).
Gradient Compression
- Quantization: Reduce the precision of communicated values. 1-bit SGD (sign of each gradient component) reduces communication by 32x. QSGD (Alistarh et al., 2017) provides tunable quantization with convergence guarantees.
- Sparsification: Transmit only the top-k% of gradient components by magnitude. Random-k: Random subset with appropriate scaling. Gradient dropping: Accumulate residual errors locally to avoid information loss (error feedback/memory).
- Compression: Apply lossless or lossy compression (entropy coding, sketching) to gradient vectors.
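Top-k sparsification with error feedback can be sketched as follows (the function name and toy gradient are illustrative):

```python
import numpy as np

def topk_with_error_feedback(grad, residual, k):
    """Keep only the k largest-magnitude components of (grad + residual);
    the suppressed mass is carried over in the residual (error feedback),
    so dropped information is transmitted in later rounds."""
    corrected = grad + residual
    idx = np.argsort(np.abs(corrected))[-k:]    # indices to transmit
    sparse = np.zeros_like(corrected)
    sparse[idx] = corrected[idx]
    new_residual = corrected - sparse           # accumulate what was dropped
    return sparse, new_residual

g = np.array([0.5, -3.0, 0.1, 2.0, -0.2])
r = np.zeros_like(g)
sparse, r = topk_with_error_feedback(g, r, k=2)
# only the two largest-magnitude entries (-3.0 and 2.0) are sent;
# the rest accumulates in r
```

Note that the sent values plus the residual always reconstruct the corrected gradient, which is why error feedback avoids systematic bias.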
Reducing Communication Rounds
- More local computation: Increase E (local epochs). Trades communication for computation but increases client drift.
- FedProx (Li et al., 2020): Add a proximal term μ/2 ||w - w_t||² to the local objective, penalizing deviation from the global model. Controls client drift while allowing multiple local steps.
- SCAFFOLD (Karimireddy et al., 2020): Use control variates to correct for client drift. Each client maintains a correction term that estimates the difference between the local and global gradient directions. Provably reduces the effect of data heterogeneity.
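The effect of the FedProx proximal term can be seen in a one-dimensional sketch (the toy loss and constants are illustrative, not from the paper):

```python
import numpy as np

def fedprox_local_step(w, w_global, grad_fn, lr=0.1, mu=0.5):
    """One local step on the FedProx objective
    F_k(w) = L_k(w) + (mu/2) * ||w - w_global||^2.
    The extra gradient term mu * (w - w_global) pulls the local
    model back toward the current global model, limiting drift."""
    grad = grad_fn(w) + mu * (w - w_global)
    return w - lr * grad

# toy client whose local loss has its optimum at 4, far from the global model at 0
grad_fn = lambda w: w - 4.0
w_global = np.array([0.0])
w = w_global.copy()
for _ in range(1000):
    w = fedprox_local_step(w, w_global, grad_fn, lr=0.1, mu=0.5)
# fixed point solves (w - 4) + 0.5 * w = 0, i.e. w = 8/3 ≈ 2.67:
# between the local optimum (4) and the global model (0)
```

With mu = 0 the client would run all the way to its local optimum; larger mu keeps it closer to w_global.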
Non-IID Challenges
Types of Heterogeneity
Label distribution skew: Different clients have different proportions of classes. Extreme case: each client has data from only 1-2 classes. This is the most studied and problematic form.
Feature distribution skew: Same labels but different feature distributions (e.g., photos taken in different lighting, demographics).
Concept shift: Same features map to different labels across clients (different labeling conventions or genuinely different relationships).
Quantity skew: Different clients have vastly different amounts of data.
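A common way to simulate label distribution skew in experiments is Dirichlet partitioning, used by many FL benchmarks; a small concentration parameter alpha produces near single-class clients, while a large alpha approaches IID. A sketch (function name and toy labels are illustrative):

```python
import numpy as np

def dirichlet_partition(labels, n_clients, alpha, seed=0):
    """Split sample indices across clients with label-distribution skew.
    For each class, the class's samples are divided among clients
    according to proportions drawn from Dirichlet(alpha)."""
    rng = np.random.default_rng(seed)
    client_idx = [[] for _ in range(n_clients)]
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        props = rng.dirichlet(alpha * np.ones(n_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_idx[client].extend(part.tolist())
    return client_idx

labels = np.repeat(np.arange(10), 100)   # 10 classes, 100 samples each
parts = dirichlet_partition(labels, n_clients=5, alpha=0.1)
```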
Impact on FedAvg
Non-IID data causes client drift: local models optimize for their local distribution, diverging from each other. Averaging diverged models produces a global model that is suboptimal for all clients. Empirically, accuracy can drop 10-50% compared to centralized training on the same data.
Mitigation Strategies
- Data sharing: Share a small global dataset across clients to anchor the local distributions. Effective but partially defeats the purpose of FL.
- FedProx: Proximal regularization limits local divergence.
- FedNova (Wang et al., 2020): Normalize local updates by the number of local steps, addressing objective inconsistency.
- FedDyn (Acar et al., 2021): Dynamic regularization with per-client regularization terms that align local and global objectives.
Personalization
A single global model may perform poorly for clients with distinct data distributions. Personalized federated learning tailors models to individual clients while still benefiting from collaborative training.
Approaches
Fine-tuning: Train a global model via FedAvg, then fine-tune locally on each client's data. Simple and effective for moderate heterogeneity.
Per-FedAvg (Fallah et al., 2020): Apply MAML (meta-learning) to FL. The global model serves as an initialization optimized for fast local adaptation. Each client performs a few gradient steps; the meta-objective ensures the initialization adapts well.
Local adaptation layers: Share lower layers (feature extractor) globally; keep upper layers (classifier head) local. Captures shared low-level features while personalizing high-level decisions. FedPer (Arivazhagan et al., 2019), FedRep (Collins et al., 2021).
Clustered FL: Group clients with similar distributions and train separate models per cluster. IFCA (Ghosh et al., 2020): Iteratively cluster clients by model loss and train cluster-specific models.
Mixture of global and local models: pFedHN (Shamsian et al., 2021): a hypernetwork generates personalized models from client embeddings. Ditto (Li et al., 2021): Each client maintains a personalized model regularized toward the global model.
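The FedPer/FedRep-style layer split can be sketched as aggregation that touches only the shared parameters; the parameter names ("base", "head") and dict representation below are illustrative:

```python
import numpy as np

def aggregate_shared(client_models, shared_keys, sizes):
    """FedPer/FedRep-style aggregation: average only the shared
    (feature-extractor) parameters; each client keeps its own head."""
    weights = np.array(sizes, dtype=float)
    weights /= weights.sum()
    global_shared = {
        k: sum(p * m[k] for p, m in zip(weights, client_models))
        for k in shared_keys
    }
    # each client receives the shared part, retains its personal head
    return [{**m, **global_shared} for m in client_models]

clients = [
    {"base": np.array([1.0, 1.0]), "head": np.array([0.0])},
    {"base": np.array([3.0, 3.0]), "head": np.array([9.0])},
]
updated = aggregate_shared(clients, shared_keys=["base"], sizes=[1, 1])
# both clients now share base = [2, 2]; heads stay personal
```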
Privacy in Federated Learning
FL does not inherently guarantee privacy — model updates can leak information about training data.
Attacks
Gradient inversion: Reconstruct training samples from gradients. Zhu et al. (2019) showed that individual images can be reconstructed with near pixel-perfect fidelity from a single gradient update, at least for small batches and shallow networks. R-GAP and GradInversion improve reconstruction quality.
Membership inference: Determine whether a specific data point was in a client's training set.
Model inversion: Reconstruct representative samples of a class from the trained model.
Differential Privacy (DP)
Add calibrated noise to model updates to provide formal privacy guarantees.
DP-SGD (Abadi et al., 2016): Clip per-sample gradients to bound sensitivity, then add Gaussian noise before averaging: g̃ = (1/B) (Σ_i clip(g_i, C) + N(0, σ²C²I)). The privacy budget (ε, δ) is tracked via the moments accountant (RDP composition).
Client-level DP: Clip and noise entire client model updates (not per-sample gradients). Protects participation of individual clients. Requires clipping the entire model update vector.
Tradeoffs: Stronger privacy (smaller ε) requires more noise, degrading model quality. Large cohorts (many clients per round) amortize the noise. Google's deployed FL systems use DP with ε ≈ 5-20.
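Client-level DP aggregation (clip each client's whole update, then add noise calibrated to the clipping norm) can be sketched as follows; the function name and constants are illustrative, and a real deployment would also track the (ε, δ) budget with an accountant:

```python
import numpy as np

def dp_aggregate(updates, clip_norm, noise_mult, seed=0):
    """Client-level DP: clip each client's update to L2 norm <= clip_norm,
    sum, then add Gaussian noise with std noise_mult * clip_norm
    (the L2 sensitivity of the sum is clip_norm)."""
    rng = np.random.default_rng(seed)
    clipped = []
    for u in updates:
        norm = np.linalg.norm(u)
        clipped.append(u * min(1.0, clip_norm / max(norm, 1e-12)))
    total = np.sum(clipped, axis=0)
    noise = rng.normal(0.0, noise_mult * clip_norm, size=total.shape)
    return (total + noise) / len(updates)

updates = [np.array([3.0, 4.0]), np.array([0.6, 0.8])]
noisy_mean = dp_aggregate(updates, clip_norm=1.0, noise_mult=0.5)
```

Because the noise standard deviation is fixed while the denominator grows with the cohort size, averaging over more clients per round shrinks the per-coordinate noise, which is the amortization effect mentioned above.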
Secure Aggregation
Cryptographic protocols ensure the server sees only the aggregated model update, not individual client updates.
Masking-based protocols (Bonawitz et al., 2017): Each pair of clients agrees on a random mask. Client k adds mask to its update; masks cancel in the sum. The server computes the aggregate but cannot see individual updates. Handles client dropout via secret sharing (Shamir's scheme).
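The mask-cancellation idea can be demonstrated in a few lines (in the real protocol the pairwise masks come from key agreement plus a PRG, and dropout is handled with secret sharing; here they are just shared random vectors):

```python
import numpy as np

def pairwise_masks(n_clients, dim, seed=0):
    """Pairwise masking as in Bonawitz et al.: each pair (i, j), i < j,
    shares a random mask m_ij; client i adds +m_ij, client j adds -m_ij,
    so all masks cancel in the sum."""
    rng = np.random.default_rng(seed)
    masks = [np.zeros(dim) for _ in range(n_clients)]
    for i in range(n_clients):
        for j in range(i + 1, n_clients):
            m = rng.normal(size=dim)   # shared secret of the pair (i, j)
            masks[i] += m
            masks[j] -= m
    return masks

updates = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
masks = pairwise_masks(3, dim=2)
masked = [u + m for u, m in zip(updates, masks)]   # what the server sees
aggregate = np.sum(masked, axis=0)                 # masks cancel: sum of updates
```

Each masked vector individually looks random to the server, yet their sum equals the true aggregate.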
Homomorphic encryption (HE): Clients encrypt updates; the server aggregates ciphertexts without decryption. The decrypted aggregate equals the sum of individual updates. Computationally expensive but provides strong guarantees.
Trusted execution environments (TEE): Perform aggregation inside a hardware enclave (Intel SGX, ARM TrustZone). The server's own code cannot access individual updates.
Vertical and Horizontal Federated Learning
Horizontal FL
Clients have the same features but different samples (e.g., hospitals with patient records using the same schema but different patients). This is the standard FL setting described above.
Vertical FL
Clients have different features for the same samples (e.g., a bank and an e-commerce company have different data about the same users). Features must be aligned by entity (using private set intersection for ID matching).
Split learning: The model is split across clients. Each client computes forward pass on its features up to a cut layer; intermediate representations (smashed data) are sent to the server or other party for the remainder of the forward pass. Gradients flow back through the cut.
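A minimal split-learning sketch, assuming a two-layer toy network with the cut after the client's tanh layer (architecture, shapes, and learning rate are all illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# client holds the layers up to the cut; server holds the rest
W_client = rng.normal(size=(4, 3)) * 0.1   # client-side weights
W_server = rng.normal(size=(3, 1)) * 0.1   # server-side weights

def split_forward(x):
    """Client computes up to the cut layer; only the 'smashed'
    activations h (not the raw features x) cross the network."""
    h = np.tanh(x @ W_client)   # client side, sent to the server
    y = h @ W_server            # server side
    return h, y

def split_backward(x, h, grad_y, lr=0.05):
    """Server backpropagates to the cut and sends grad_h back;
    the client finishes backprop through its own layers."""
    global W_client, W_server
    grad_h = grad_y @ W_server.T            # gradient returned to the client
    W_server -= lr * h.T @ grad_y
    grad_pre = grad_h * (1 - h ** 2)        # through the tanh at the cut
    W_client -= lr * x.T @ grad_pre

# toy regression: predict the first feature from all four
x = rng.normal(size=(8, 4))
target = x[:, :1]
losses = []
for _ in range(300):
    h, y = split_forward(x)
    losses.append(float(np.mean((y - target) ** 2)))
    grad_y = (y - target) / len(x)          # MSE gradient
    split_backward(x, h, grad_y)
```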
Challenges: Requires entity alignment (determining which samples correspond across clients). The intermediate representations can leak information about the raw features. Communication scales per sample or batch rather than per round, which can dominate costs.
Federated Transfer Learning
Clients have different features and different samples with partial overlap. Combines techniques from transfer learning and FL.
Incentive Mechanisms
Rational clients may not participate in FL without incentives, especially if participation incurs computation and communication costs.
Data Valuation
Shapley value: Assign each client a value proportional to its marginal contribution to the global model. Computing exact Shapley values requires training on all subsets of clients (2^K) — approximations (Monte Carlo sampling, truncated Shapley) are used in practice.
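The Monte Carlo approximation samples random client orderings and credits each client with its marginal gain on joining; a sketch with a toy additive value function (the function names and data sizes are illustrative):

```python
import numpy as np

def monte_carlo_shapley(clients, value_fn, n_permutations=200, seed=0):
    """Approximate each client's Shapley value by sampling random
    orderings: a client's contribution in one ordering is the marginal
    gain in value_fn when it joins the coalition of clients before it."""
    rng = np.random.default_rng(seed)
    phi = {c: 0.0 for c in clients}
    for _ in range(n_permutations):
        order = list(clients)
        rng.shuffle(order)
        coalition, prev = set(), value_fn(frozenset())
        for c in order:
            coalition.add(c)
            cur = value_fn(frozenset(coalition))
            phi[c] += cur - prev
            prev = cur
    return {c: v / n_permutations for c, v in phi.items()}

# toy value function: model "quality" is additive in contributed data
data_sizes = {"A": 100, "B": 100, "C": 300}
value_fn = lambda S: sum(data_sizes[c] for c in S)
phi = monte_carlo_shapley(list(data_sizes), value_fn)
```

For an additive value function every ordering yields the same marginals, so the estimate is exact; for a real value function (e.g. validation accuracy of a model retrained on the coalition) only the sampled orderings are evaluated, avoiding the 2^K subsets.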
Contribution-based rewards: Clients contributing more valuable data (improving the global model more) receive greater rewards. FedCoin: Blockchain-based incentive mechanism with Shapley-based payments.
Free-Rider Attacks
Clients may submit random or stale updates to receive the global model without contributing genuine computation. Detection via consistency checks, reputation systems, or verifiable computation.
Systems and Deployment
Production Systems
Google's FL system: Deployed on Android for keyboard next-word prediction (Gboard), emoji prediction, and on-device ranking. Uses secure aggregation, DP, and on-device training during idle/charging periods.
Apple: On-device ML with FL for Siri, QuickType, and photo features. Strong privacy stance — heavy use of DP.
FATE (Federated AI Technology Enabler): WeBank's open-source FL framework for financial applications. Supports horizontal, vertical, and transfer FL.
Frameworks
- Flower: Framework-agnostic FL framework. Supports PyTorch, TensorFlow, JAX. Simulation and production deployment.
- PySyft (OpenMined): Privacy-focused FL framework with DP and secure computation support.
- TensorFlow Federated (TFF): Google's FL framework. Functional programming model for FL algorithms.
- NVIDIA FLARE: Enterprise FL for healthcare and financial applications.
Practical Challenges
- Stragglers: Slow clients delay synchronous rounds. Solutions: asynchronous FL (FedAsync), deadline-based aggregation, over-selection.
- Client availability: Mobile devices are available intermittently. Must handle partial participation and dropout gracefully.
- Model size: Large models (LLMs) are impractical for on-device training. Federated fine-tuning with LoRA reduces the parameter footprint.
- Data quality: No central authority can inspect or clean client data. Robust aggregation (Krum, trimmed mean, median) defends against noisy or adversarial updates.
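The coordinate-wise trimmed mean mentioned above can be sketched as follows (client updates and trim level are illustrative):

```python
import numpy as np

def trimmed_mean(updates, trim_k):
    """Coordinate-wise trimmed mean: for each coordinate, drop the
    trim_k largest and trim_k smallest client values and average the
    rest. Tolerates up to trim_k adversarial clients per coordinate."""
    stacked = np.sort(np.stack(updates), axis=0)
    return stacked[trim_k : len(updates) - trim_k].mean(axis=0)

honest = [np.array([1.0, 1.0]), np.array([1.1, 0.9]), np.array([0.9, 1.1])]
byzantine = [np.array([100.0, -100.0])]            # malicious update
robust = trimmed_mean(honest + byzantine, trim_k=1)
# the outlier is trimmed away in every coordinate
```

A plain mean of the same four updates would be dominated by the malicious client; the trimmed mean stays near the honest updates.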