Training at scale is a memory and communication problem before it's a math problem. You spend VRAM on weights + gradients + optimizer states + activations, and you spend wall-clock on moving gradients/params between GPUs. Every technique here buys back one of those two. This sheet covers the memory math, parallelism, ZeRO/FSDP, stability, fine-tuning, alignment, data, checkpointing, and the health metrics that tell you it's working.
1. Where the VRAM goes
| Consumer | Cost (Adam, mixed precision) | Notes |
|---|---|---|
| Weights | 2 B/param (BF16) | The model itself. |
| Gradients | 2 B/param | One per weight. |
| Optimizer states | ~12 B/param | Adam: FP32 momentum (4) + variance (4) + FP32 master weights (4). |
| Activations | ∝ batch × seq × layers | Often the biggest, and the most controllable (checkpointing). |
So ~16 B/param just for model states with Adam. A 7B model ≈ 112 GB before activations — which is why a single 80 GB GPU can't full-fine-tune 7B without sharding/offload, and why ZeRO/FSDP exist.
the "16 bytes/param" rule
Memorize it: BF16 weights (2) + BF16 grads (2) + Adam states (12) ≈ 16 B/param for model state.
Multiply by params for a first-order VRAM estimate, then add activations (cut with checkpointing) and
the framework's workspace. 8-bit Adam roughly halves the optimizer term.
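The rule above can be turned into a quick estimator. This is a minimal sketch, not a profiler: the function name `model_state_gb` is made up for illustration, it uses 1 GB = 1e9 bytes, and it assumes 8-bit Adam halves the 12 B/param term to roughly 6 B/param.

```python
# Bytes per parameter for model state, following the table above.
BYTES_WEIGHTS_BF16 = 2
BYTES_GRADS_BF16 = 2
BYTES_ADAM_FP32 = 12  # momentum (4) + variance (4) + FP32 master weights (4)


def model_state_gb(n_params: float, optimizer_bytes: float = BYTES_ADAM_FP32) -> float:
    """First-order model-state estimate in GB (1 GB = 1e9 bytes); excludes activations."""
    per_param = BYTES_WEIGHTS_BF16 + BYTES_GRADS_BF16 + optimizer_bytes
    return n_params * per_param / 1e9


full_adam = model_state_gb(7e9)                     # 7B params, standard Adam
eight_bit = model_state_gb(7e9, optimizer_bytes=6)  # assumption: 8-bit Adam ~halves the 12 B term
print(f"7B, Adam:       {full_adam:.0f} GB")
print(f"7B, 8-bit Adam: {eight_bit:.0f} GB")
```

Activations and framework workspace come on top of these numbers, so treat the result as a lower bound.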
2. Parallelism strategies
| Strategy | Splits | Comms | When |
|---|---|---|---|
| Data parallel (DDP) | The batch; full model replicated | All-reduce grads / step | Model fits on one GPU; scale throughput |
| ZeRO / FSDP | Optimizer → grads → params | All-gather params, reduce-scatter grads | Model too big to replicate |
| Tensor parallel | Each layer's matmuls | All-reduce activations / layer | Huge layers; NVLink intra-node |
| Pipeline parallel | Layers into stages | Activations stage→stage | Cross-node; micro-batch to fill bubble |
| Sequence/context parallel | The sequence dimension | Attention comms | Very long context |
| Expert parallel | MoE experts | All-to-all routing | Mixture-of-Experts |
3D parallelism = TP × PP × DP, mapped to topology: TP within a node (fast NVLink), PP across nodes, DP (often ZeRO) on top. FSDP alone covers most sub-100B fine-tuning.
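One way to picture the topology mapping is a rank-to-coordinate function. The sketch below assumes a hypothetical layout where TP is the fastest-varying axis, so each TP group is a contiguous block of ranks that lands on a single node; real frameworks may order the axes differently, and `rank_to_coord` / `tp_group` are illustrative names, not a library API.

```python
from typing import NamedTuple


class Coord(NamedTuple):
    dp: int  # data-parallel replica index
    pp: int  # pipeline stage index
    tp: int  # tensor-parallel shard index


def rank_to_coord(rank: int, tp: int, pp: int, dp: int) -> Coord:
    """Map a global rank to (dp, pp, tp) with TP as the fastest-varying axis."""
    assert 0 <= rank < tp * pp * dp, "rank outside the TP x PP x DP grid"
    return Coord(dp=rank // (tp * pp), pp=(rank // tp) % pp, tp=rank % tp)


def tp_group(rank: int, tp: int) -> list[int]:
    """Ranks sharing this rank's tensor-parallel group (a contiguous block of size tp)."""
    start = (rank // tp) * tp
    return list(range(start, start + tp))


TP, PP, DP = 8, 2, 2   # 32 GPUs total (hypothetical cluster)
GPUS_PER_NODE = 8      # assumption: 8 GPUs per NVLink node

group = tp_group(rank=13, tp=TP)
nodes = {r // GPUS_PER_NODE for r in group}
print(rank_to_coord(13, TP, PP, DP), "TP group:", group, "nodes:", nodes)
```

With TP equal to the GPUs per node, every TP group stays inside one node, which keeps the per-layer all-reduce on NVLink.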
3. ZeRO / FSDP stages
- Stage 1 — shard optimizer states across the DP group. Cheapest; biggest single win for Adam (the 12 B/param term).
- Stage 2 — also shard gradients (reduce-scatter instead of all-reduce).
- Stage 3 (FSDP full shard) — also shard parameters; each GPU holds a slice and all-gathers a layer's full params just-in-time for compute, frees after. Max memory savings, most comms.
- Offload — push states/params/grads to CPU (ZeRO-Offload) or NVMe (ZeRO-Infinity) when sharding still doesn't fit. Big slowdown; last resort.
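# PyTorch FSDP sketch (assumes torch.distributed is initialized and `model` is already built)
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# `Block` is a placeholder for your model's transformer layer class
wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
             mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
             auto_wrap_policy=wrap_policy)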
sharding trades memory for comms
ZeRO-3 fits enormous models but all-gathers params every layer. On slow interconnect (PCIe, cross-node
without IB/RoCE) comms dominates and GPUs idle. Match the stage to your hardware; prefer the lowest
stage that still fits, and keep TP within a NVLink node.
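To pick the lowest stage that fits, it helps to estimate per-GPU model state for each stage. This sketch applies the 2 + 2 + 12 B/param breakdown to the stage definitions above; `zero_bytes_per_param` is an illustrative helper, and it ignores activations and the temporary buffers that stage 3 needs while a layer's parameters are gathered.

```python
def zero_bytes_per_param(stage: int, dp_size: int) -> float:
    """Per-GPU model-state bytes per parameter for ZeRO stage 0-3 (0 = plain DDP)."""
    weights, grads, optim = 2.0, 2.0, 12.0  # BF16 weights, BF16 grads, Adam states
    if stage >= 1:
        optim /= dp_size    # stage 1: shard optimizer states
    if stage >= 2:
        grads /= dp_size    # stage 2: also shard gradients
    if stage >= 3:
        weights /= dp_size  # stage 3: also shard parameters
    return weights + grads + optim


N_PARAMS = 7e9  # 7B model
DP_SIZE = 8     # hypothetical data-parallel group size
for stage in range(4):
    gb = N_PARAMS * zero_bytes_per_param(stage, DP_SIZE) / 1e9
    print(f"stage {stage}: {gb:6.1f} GB per GPU")
```

For 7B on 8 GPUs this gives roughly 112, 38.5, 26.25, and 14 GB per GPU for stages 0 through 3, which shows why stage 1 alone is the biggest single step.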
4. Memory-saving techniques
- Mixed precision (BF16) — halves weight/grad/activation bytes; BF16 avoids loss scaling.
- Gradient (activation) checkpointing — keep activations only at checkpoint boundaries and recompute the rest in backward. ~√N activation memory (N = number of layers) for ~33% more compute. The biggest lever for long sequences.
- Gradient accumulation — large effective batch from several micro-batches before step().
- 8-bit optimizer (bitsandbytes) — quantize Adam states; ~halves the 12 B/param term.
- FlashAttention — removes O(N²) attention activation memory.
- CPU/NVMe offload — capacity at the cost of speed.
- Fused optimizer, fused layernorm — fewer kernels, slightly less memory traffic.
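Gradient accumulation from the list above is easy to get subtly wrong: each micro-batch loss must be divided by the number of accumulation steps so the accumulated gradient is a mean, not a sum. The toy below checks that on a one-parameter least-squares model in plain Python; the model, data, and `mse_grad` helper are made up for illustration, and it assumes equal-sized micro-batches.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

full_batch_grad = mse_grad(w, xs, ys)

micro_batch_size = 2
accum_steps = len(xs) // micro_batch_size
accumulated = 0.0
for i in range(accum_steps):
    lo, hi = i * micro_batch_size, (i + 1) * micro_batch_size
    # Divide by accum_steps so the sum of micro-batch gradients is a mean, not a sum.
    accumulated += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps
# The optimizer step would happen here, once per accum_steps micro-batches.

print(full_batch_grad, accumulated)
```

Only one micro-batch of activations is alive at a time, which is where the memory saving comes from.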
5. Training stability
- Learning-rate warmup — ramp LR from ~0 over the first steps; skipping it is a classic early-divergence cause. Follow with cosine (or linear) decay.
- Gradient clipping — clip global grad-norm (e.g. 1.0) so a spike doesn't blow up the run.
- BF16 over FP16 — wider exponent range → no loss scaling, far fewer NaNs. FP16 needs a dynamic loss scaler.
- Loss scaling (FP16) — multiply loss before backward to lift small grads out of underflow; unscale before clip/step; dynamic scaler backs off on overflow.
- Stabilizers — z-loss (penalize logit magnitude), QK-norm, careful init (scaled by depth), embedding/LM-head tying, RMSNorm placement.
- Watch grad-norm — it's the early-warning signal; a spike precedes a NaN.
loss → NaN: the usual suspects
LR too high, no warmup, no grad clipping, FP16 overflow (switch to BF16 / fix scaler), bad data
(corrupt/inf sample, empty or over-long sequence, label out of range), or division-by-zero in a custom
layer. Bisect: lower LR, clip, switch to BF16, inspect the exact batch at the failing step.
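Two of the mechanisms above, global grad-norm clipping and dynamic loss scaling, fit in a few lines. This is a framework-free sketch on nested lists rather than tensors; `clip_global_norm` and `DynamicLossScaler` are illustrative names, and the initial scale and growth interval are assumed values. In a real PyTorch run you would use the library's own clipping and scaler utilities.

```python
import math


def clip_global_norm(grads: list[list[float]], max_norm: float = 1.0) -> float:
    """Scale all gradients in place so their joint L2 norm is at most max_norm.

    Returns the norm measured before clipping, which is the value worth logging.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for tensor in grads:
            for i in range(len(tensor)):
                tensor[i] *= scale
    return total_norm


class DynamicLossScaler:
    """Toy FP16-style scaler: back off on overflow, grow after a clean streak."""

    def __init__(self, scale: float = 65536.0, growth_interval: int = 2000):
        self.scale = scale
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def update(self, grads: list[list[float]]) -> bool:
        """Return True if the step should be applied, False if it must be skipped."""
        overflow = any(not math.isfinite(g) for tensor in grads for g in tensor)
        if overflow:
            self.scale /= 2.0   # back off and skip this step
            self.clean_steps = 0
            return False
        self.clean_steps += 1
        if self.clean_steps >= self.growth_interval:
            self.scale *= 2.0   # try a larger scale again
            self.clean_steps = 0
        return True


grads = [[3.0, 4.0], [12.0]]  # joint norm = sqrt(9 + 16 + 144) = 13
norm_before = clip_global_norm(grads, max_norm=1.0)
print(norm_before, grads)

scaler = DynamicLossScaler()
print(scaler.update([[float("inf")]]), scaler.scale)
```

Logging the pre-clip norm returned here gives you the grad-norm early-warning signal for free.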
6. Optimizers & schedules
- AdamW — the default; decoupled weight decay. β₁≈0.9, β₂≈0.95 for LLMs (lower β₂ than vision), ε≈1e-8.
- Fused / 8-bit Adam — faster / lower memory variants.
- Schedules — warmup + cosine decay is standard; WSD (warmup-stable-decay) for continuable runs; linear for fine-tunes.
- Weight decay — typically ~0.1; usually excluded from norms/biases/embeddings.
- Batch size & LR — scale LR with batch size (roughly linearly, or with its square root); use the largest stable batch your throughput allows, and use gradient accumulation to reach it.
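The warmup + cosine schedule mentioned above can be written as a pure function of the step. This is a minimal sketch: `lr_at` is an illustrative name, and the peak LR, warmup length, and total steps are hypothetical values, not recommendations.

```python
import math


def lr_at(step: int, peak_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Linear warmup from ~0 to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr if training runs past total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


PEAK, WARMUP, TOTAL = 3e-4, 100, 1000  # hypothetical values
for s in (0, 50, 99, 100, 550, 1000):
    print(s, f"{lr_at(s, PEAK, WARMUP, TOTAL):.2e}")
```

Logging this value every step is the cheapest way to confirm the schedule is doing what you think.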
7. Fine-tuning (PEFT)
| Method | What | Trade |
|---|---|---|
| Full fine-tune | Update all weights | Best quality, most VRAM, forgetting risk |
| LoRA | Freeze base, train low-rank B·A adapters (ΔW=(α/r)BA) | ~0.1–1% params; mergeable; cheap |
| QLoRA | 4-bit (NF4) frozen base + LoRA, double-quant, paged optimizer | Fine-tune big models on one GPU |
| Adapters / prefix tuning / (IA)³ | Other PEFT inserts | Variants for different needs |
| DoRA | Weight-decomposed LoRA | Closer to full FT quality |
LoRA rank r (typically 8–64) sets adapter capacity and alpha scales the update (effective scale α/r); target the attention (and often
MLP) projections. Merge adapters for zero-overhead inference, or keep them separate for multi-LoRA serving.
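The LoRA update ΔW = (α/r)·BA and the merge step can be checked on a toy matrix. The sketch below uses nested lists instead of tensors; all sizes and values are hypothetical, and `matmul` / `lora_param_fraction` are illustrative helpers rather than any library's API.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Plain nested-list matrix product, enough for a toy example."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_param_fraction(d_out: int, d_in: int, r: int) -> float:
    """Trainable LoRA params (B: d_out x r, A: r x d_in) relative to the full matrix."""
    return r * (d_out + d_in) / (d_out * d_in)


# Toy sizes: d_out = 2, d_in = 3, rank r = 1, alpha = 2 (all hypothetical)
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]  # frozen base weight, d_out x d_in
B = [[0.5], [1.0]]                      # d_out x r
A = [[1.0, 2.0, 3.0]]                   # r x d_in
alpha, r = 2.0, 1
scale = alpha / r

x = [[1.0], [1.0], [1.0]]               # column vector, d_in x 1

# Unmerged: base path plus scaled adapter path.
base = matmul(W, x)
adapter = matmul(B, matmul(A, x))
y_unmerged = [[b[0] + scale * a[0]] for b, a in zip(base, adapter)]

# Merged: fold the adapter into the weight once, then do a single matmul.
BA = matmul(B, A)
W_merged = [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]
y_merged = matmul(W_merged, x)

print(y_unmerged, y_merged)
print(f"{lora_param_fraction(4096, 4096, 16):.4%} of a 4096x4096 projection at r=16")
```

Both paths give the same output, which is why merging costs nothing at inference. At r=16 on a 4096×4096 projection the adapter is under 1% of the matrix.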
8. Post-training / alignment
- SFT — supervised fine-tune on instruction→response pairs (mask the prompt tokens in the loss). The foundation of a chat model.
- RLHF (PPO) — train a reward model from human preference pairs, then optimize the policy against it with PPO + a KL penalty to the SFT reference. Powerful but complex/unstable (reward hacking; four models in memory: policy, reference, reward, value).
- DPO — Direct Preference Optimization: a classification-style loss directly on preference pairs that optimizes the same KL-regularized RLHF objective (under a Bradley-Terry preference model) without a separate reward model or RL loop; it still needs a frozen reference model. Simpler, stable, popular default.
- Others — ORPO (no reference model), KTO (pointwise feedback), GRPO/RLVR (verifiable rewards for reasoning), constitutional/RLAIF (AI feedback).
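The DPO loss for a single preference pair is short enough to write out. This sketch takes summed sequence log-probabilities as plain floats; the function name, argument names, and the β = 0.1 default are assumptions for illustration, and a real implementation would work on batched tensors.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """DPO loss for one preference pair, from summed sequence log-probabilities."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)) written as softplus(-margin) for numerical stability.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Policy identical to the reference: margin is 0, loss is log(2).
at_init = dpo_loss(-11.0, -11.0, -11.0, -11.0)
# Policy has moved toward the chosen response and away from the rejected one.
improved = dpo_loss(-10.0, -12.0, -11.0, -11.0)
print(at_init, improved)
```

The loss starts at log 2 when the policy equals the reference and falls as the policy raises the chosen response's log-ratio relative to the rejected one.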
9. Data & tokenization
- Data is the product. Dedup (exact + near, MinHash), quality-filter (classifiers, heuristics), decontaminate against eval sets, and mix domains deliberately (curriculum / weighting).
- Tokenization — BPE/SentencePiece; vocab size affects embedding/LM-head memory and Tensor-Core alignment (pad it). Watch for whitespace/special-token handling bugs.
- Sequence packing — concatenate samples to fill the context and avoid padding waste; mask cross-document attention.
- SFT label masking — only compute loss on the response tokens, not the prompt.
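Label masking and packing from the list above can be sketched on plain token-id lists. The helper names are illustrative, the packing is a simple greedy strategy (one of several possible), and the sketch assumes the loss code ignores label `-100` (the default `ignore_index` of PyTorch's `CrossEntropyLoss`) and handles the next-token shift itself.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def mask_prompt(prompt_ids: list[int], response_ids: list[int]) -> tuple[list[int], list[int]]:
    """Build (input_ids, labels) where prompt positions carry no loss."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


def pack(samples: list[list[int]], max_len: int) -> list[tuple[list[int], list[int]]]:
    """Greedily concatenate samples into rows of at most max_len tokens.

    Each row comes with per-token document ids so attention can be restricted
    to tokens of the same document.
    """
    rows, tokens, doc_ids = [], [], []
    for sample in samples:
        sample = sample[:max_len]  # truncate over-long samples (a simplification)
        if tokens and len(tokens) + len(sample) > max_len:
            rows.append((tokens, doc_ids))
            tokens, doc_ids = [], []
        doc = doc_ids[-1] + 1 if doc_ids else 0
        tokens = tokens + sample
        doc_ids = doc_ids + [doc] * len(sample)
    if tokens:
        rows.append((tokens, doc_ids))
    return rows


def can_attend(doc_ids: list[int], q: int, k: int) -> bool:
    """Causal attention that never crosses a document boundary."""
    return k <= q and doc_ids[q] == doc_ids[k]


input_ids, labels = mask_prompt([101, 102], [7, 8, 9])
rows = pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_len=6)
print(labels)
print(rows)
```

In the first packed row, token 3 (start of the second document) cannot attend to token 2 (end of the first), which is the cross-document masking the list calls for.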
10. Checkpointing & resume
- Long runs will hit node failures — checkpoint frequently and test the resume path before you depend on it.
- Save weights + optimizer state + LR-scheduler step + RNG seeds + data-iterator/sampler position + step/token count. Weights-only resume restarts the optimizer cold → loss bump.
- For sharded training use distributed/sharded checkpointing; use async checkpointing to avoid stalling the run.
- Keep rolling + best checkpoints; verify integrity (a half-written checkpoint after a crash is worse than none).
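A half-written checkpoint is usually avoided by writing to a temporary file and renaming it into place. The sketch below shows that pattern with a toy state dict and `pickle`; the state contents and file name are placeholders, and a real run would store model, optimizer, and scheduler state dicts with the framework's own (possibly sharded) checkpoint API.

```python
import os
import pickle
import random
import tempfile


def save_checkpoint(path: str, state: dict) -> None:
    """Write to a temp file in the same directory, then rename into place."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
            f.flush()
            os.fsync(f.fileno())    # make sure bytes hit disk before the rename
        os.replace(tmp_path, path)  # atomic on POSIX within one filesystem
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_checkpoint(path: str) -> dict:
    with open(path, "rb") as f:
        return pickle.load(f)


# Toy "training state"; stand-ins for real model/optimizer/scheduler state.
rng = random.Random(0)
state = {
    "step": 1200,
    "tokens_seen": 1200 * 4096,
    "weights": [0.1, -0.2],
    "optimizer": {"m": [0.0, 0.0], "v": [0.0, 0.0]},
    "data_position": 1200,  # sampler / iterator offset
    "rng_state": rng.getstate(),
}

ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "step_1200.pkl")
save_checkpoint(ckpt_path, state)
expected_next = rng.random()  # what the uninterrupted run would draw next

restored = load_checkpoint(ckpt_path)
resumed_rng = random.Random()
resumed_rng.setstate(restored["rng_state"])
print(restored["step"], resumed_rng.random() == expected_next)
```

Restoring the RNG state makes the resumed run draw the same next value as the uninterrupted one, which is a cheap way to test the resume path.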
11. Health metrics to log
| Metric | Tells you |
|---|---|
| train/eval loss | Learning + overfitting (gap) |
| grad-norm | Stability; spike → incoming NaN |
| learning rate | Schedule is doing what you think |
| tokens/sec | Raw throughput |
| MFU | Hardware efficiency (achieved ÷ peak FLOPs); ~40–55% is good |
| GPU mem | Headroom before OOM |
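MFU in the table above can be estimated from tokens/sec with the common approximation of about 6 FLOPs per parameter per token for a dense transformer (forward plus backward). The sketch below assumes that approximation, which ignores attention-score FLOPs; the throughput figure is hypothetical, and 312 TFLOP/s is the A100's dense BF16 peak.

```python
def mfu(n_params: float, tokens_per_sec_per_gpu: float, peak_flops_per_gpu: float) -> float:
    """Model FLOPs utilization using the ~6 * N FLOPs-per-token approximation.

    6N counts forward + backward matmul FLOPs for a dense transformer and
    ignores attention-score FLOPs, so it underestimates at long context.
    """
    achieved = 6.0 * n_params * tokens_per_sec_per_gpu
    return achieved / peak_flops_per_gpu


A100_BF16_PEAK = 312e12  # dense BF16 peak FLOP/s for an A100
value = mfu(n_params=7e9, tokens_per_sec_per_gpu=3000, peak_flops_per_gpu=A100_BF16_PEAK)
print(f"MFU = {value:.1%}")
```

Under these assumptions a 7B model at 3,000 tokens/sec per A100 sits at roughly 40% MFU, the low end of the "good" range in the table.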
12. Quick reference
- VRAM ≈ params × 16 B (BF16 + grad + Adam) + activations
- Fit bigger: BF16 + checkpointing + accumulation + FSDP/ZeRO (+ offload) + 8-bit Adam
- Stability: warmup + cosine, clip grad-norm 1.0, BF16, z-loss; watch grad-norm
- Fine-tune cheap: QLoRA (4-bit NF4 base + LoRA, paged optimizer)
- Align: SFT → DPO (or PPO/GRPO)
- Data: dedup + quality filter + DECONTAMINATE + packing + SFT label masking
- Health: loss, grad-norm, lr, tokens/sec, MFU
- Resume: weights + optimizer + scheduler + RNG + data position; checkpoint often
13. Interview Q&A
- Where does training memory go? Weights + grads + optimizer states (Adam ≈ 12 B/param) + activations. ~16 B/param model state in mixed precision before activations — why we shard and checkpoint.
- DDP vs FSDP/ZeRO? DDP replicates the full model per GPU and all-reduces grads (model must fit). FSDP/ZeRO shards optimizer/grads/params so each GPU holds a slice — fits bigger models at more comms.
- What do the ZeRO stages shard? 1: optimizer states. 2: + gradients. 3: + parameters (all-gather per layer). Higher = less memory, more comms.
- How does gradient checkpointing help? Discards activations in forward, recomputes in backward — ~√N memory for ~33% more compute. Essential for long sequences.
- BF16 vs FP16? BF16 = FP32 exponent range, no loss scaling, fewer NaNs. FP16 has finer precision (more mantissa bits) but needs dynamic loss scaling. Default BF16 on Ampere+.
- Loss → NaN, how to debug? Check LR/warmup, add clip 1.0, switch to BF16, inspect grad-norm before the spike + the failing batch. Bisect by lowering LR.
- LoRA vs full fine-tuning? LoRA/QLoRA: cheap, fast, mergeable, low forgetting, low VRAM needs. Full FT: max quality, more compute.
- How does QLoRA fit big models on one GPU? 4-bit NF4 frozen base + BF16 LoRA on top, double quantization, paged optimizer. Slashes weight memory; trains only small adapters.
- DPO vs RLHF/PPO? PPO needs a reward model + RL loop (powerful, unstable). DPO optimizes preferences directly via a simple loss with a reference model — simpler, stable default.
- What is MFU and why track it over GPU-util? Achieved FLOPs ÷ peak. GPU-util only means a kernel is resident; MFU exposes starvation/comms/no-Tensor-Core inefficiency. Target ~40–55%.
- Why checkpoint optimizer state? Adam momentum/variance define the trajectory; weights-only resume restarts the optimizer cold → loss bump. Save optimizer + scheduler + RNG + data position too.
- Why mask prompt tokens in SFT? You want the model to learn to produce the response, not to model the prompt. Compute loss only on response tokens.