
CHEATSHEET · GENAI · MODEL TRAINING

LLM Training — The Deep Scaling Cheatsheet.

Training at scale is a memory and communication problem before it's a math problem. You spend VRAM on weights + gradients + optimizer states + activations, and you spend wall-clock on moving gradients/params between GPUs. Every technique here buys back one of those two. This sheet covers the memory math, parallelism, ZeRO/FSDP, stability, fine-tuning, alignment, data, checkpointing, and the health metrics that tell you it's working.

1. Where the VRAM goes

Consumer | Cost (Adam, mixed precision) | Notes
Weights | 2 B/param (BF16) | The model itself.
Gradients | 2 B/param | One per weight.
Optimizer states | ~12 B/param | Adam: FP32 momentum (4) + variance (4) + FP32 master weights (4).
Activations | ∝ batch × seq × layers | Often the biggest, and the most controllable (checkpointing).

So ~16 B/param just for model states with Adam. A 7B model ≈ 112 GB before activations — which is why a single 80 GB GPU can't full-fine-tune 7B without sharding/offload, and why ZeRO/FSDP exist.

The "16 bytes/param" rule: memorize it. BF16 weights (2) + BF16 grads (2) + Adam states (12) ≈ 16 B/param for model state. Multiply by params for a first-order VRAM estimate, then add activations (cut with checkpointing) and the framework's workspace. 8-bit Adam roughly halves the optimizer term.
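The rule above is simple enough to turn into a back-of-envelope calculator. The sketch below is a minimal illustration, not a profiler: the function name `model_state_gb` and the choice of 1 GB = 1e9 bytes are assumptions made here, and the 6 B/param figure for 8-bit Adam is just "roughly half of 12" taken from the text.

```python
# Bytes per parameter for model state, taken from the table in section 1.
BYTES_WEIGHTS_BF16 = 2
BYTES_GRADS_BF16 = 2
BYTES_ADAM_FP32 = 12  # momentum (4) + variance (4) + FP32 master weights (4)


def model_state_gb(n_params: float, optimizer_bytes: float = BYTES_ADAM_FP32) -> float:
    """First-order model-state estimate in GB (1 GB = 1e9 bytes).

    Activations and framework workspace are NOT included.
    """
    bytes_per_param = BYTES_WEIGHTS_BF16 + BYTES_GRADS_BF16 + optimizer_bytes
    return n_params * bytes_per_param / 1e9


if __name__ == "__main__":
    # 7B parameters with standard Adam.
    print(model_state_gb(7e9))
    # Same model with 8-bit Adam, assumed to cost about half of the 12 B term.
    print(model_state_gb(7e9, optimizer_bytes=6))
```

Treat the result as a lower bound when sizing hardware, since activations often dominate at long sequence lengths.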

2. Parallelism strategies

Strategy | Splits | Comms | When
Data parallel (DDP) | The batch; full model replicated | All-reduce grads / step | Model fits on one GPU; scale throughput
ZeRO / FSDP | Optimizer → grads → params | All-gather params, reduce-scatter grads | Model too big to replicate
Tensor parallel | Each layer's matmuls | All-reduce activations / layer | Huge layers; NVLink intra-node
Pipeline parallel | Layers into stages | Activations stage→stage | Cross-node; micro-batch to fill bubble
Sequence/context parallel | The sequence dimension | Attention comms | Very long context
Expert parallel | MoE experts | All-to-all routing | Mixture-of-Experts

3D parallelism = TP × PP × DP, mapped to topology: TP within a node (fast NVLink), PP across nodes, DP (often ZeRO) on top. FSDP alone covers most sub-100B fine-tuning.
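As a rough illustration of the TP × PP × DP factorization, the sketch below derives the DP degree from the world size and maps a global rank to grid coordinates. The helper names and the rank ordering (TP varies fastest, so TP peers are adjacent ranks and tend to share a node) are assumptions for this example; real frameworks each define their own layout.

```python
def data_parallel_degree(world_size: int, tp: int, pp: int) -> int:
    """Return the DP degree implied by world_size = TP x PP x DP."""
    if tp < 1 or pp < 1 or world_size % (tp * pp) != 0:
        raise ValueError("world_size must be a positive multiple of tp * pp")
    return world_size // (tp * pp)


def rank_to_coords(rank: int, tp: int, pp: int) -> tuple:
    """Map a global rank to (dp_idx, pp_idx, tp_idx).

    Assumed ordering: TP index varies fastest, then PP, then DP.
    """
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx


if __name__ == "__main__":
    # 64 GPUs, TP=8 inside each 8-GPU node, PP=2 across nodes.
    print(data_parallel_degree(64, tp=8, pp=2))
    print(rank_to_coords(9, tp=8, pp=2))
```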

3. ZeRO / FSDP stages

  • Stage 1 — shard optimizer states across the DP group. Cheapest; biggest single win for Adam (the 12 B/param term).
  • Stage 2 — also shard gradients (reduce-scatter instead of all-reduce).
  • Stage 3 (FSDP full shard) — also shard parameters; each GPU holds a slice and all-gathers a layer's full params just-in-time for compute, frees after. Max memory savings, most comms.
  • Offload — push states/params/grads to CPU (ZeRO-Offload) or NVMe (ZeRO-Infinity) when sharding still doesn't fit. Big slowdown; last resort.
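# PyTorch FSDP sketch. Assumes torch.distributed is initialized, `model` exists,
# and `Block` stands in for your model's transformer layer class.
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
             mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
             auto_wrap_policy=wrap_policy)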
Sharding trades memory for comms: ZeRO-3 fits enormous models but all-gathers params every layer. On slow interconnect (PCIe, cross-node without IB/RoCE) comms dominates and GPUs idle. Match the stage to your hardware; prefer the lowest stage that still fits, and keep TP within an NVLink node.
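To see what each stage buys, the per-GPU model-state footprint can be estimated from the same 2 + 2 + 12 bytes/param breakdown used earlier. This is a minimal sketch under idealized assumptions (perfectly even sharding, no buffers, no activations); `zero_model_state_gb` is a name chosen here, not a library function.

```python
def zero_model_state_gb(n_params: float, dp_size: int, stage: int) -> float:
    """Idealized per-GPU model state in GB (1e9 bytes) for ZeRO stage 0-3.

    Stage 0 is plain DDP (everything replicated). Each higher stage shards
    one more term across the dp_size ranks.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2 or 3")
    weights, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter
    if stage >= 1:
        optim /= dp_size    # stage 1: shard optimizer states
    if stage >= 2:
        grads /= dp_size    # stage 2: also shard gradients
    if stage >= 3:
        weights /= dp_size  # stage 3: also shard parameters
    return n_params * (weights + grads + optim) / 1e9


if __name__ == "__main__":
    for s in range(4):
        print(s, zero_model_state_gb(7e9, dp_size=8, stage=s))
```

For a 7B model on 8 GPUs the jump from stage 0 to stage 1 is the largest single drop, which matches the "biggest single win" note on Stage 1.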

4. Memory-saving techniques

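  • Mixed precision (BF16): halves the working weight/grad/activation bytes relative to FP32, and BF16 avoids loss scaling. With FP32 master weights kept in the optimizer, total model state stays near 16 B/param, so the memory win is mostly in activations.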
  • Gradient (activation) checkpointing — don't store activations; recompute in backward. ~√N memory for ~33% more compute. The biggest lever for long sequences.
  • Gradient accumulation — large effective batch from several micro-batches before step().
  • 8-bit optimizer (bitsandbytes) — quantize Adam states; ~halves the 12 B/param term.
  • FlashAttention — removes O(N²) attention activation memory.
  • CPU/NVMe offload — capacity at the cost of speed.
  • Fused optimizer, fused layernorm — fewer kernels, slightly less memory traffic.
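The ~√N figure for checkpointing falls out of a simple counting argument, sketched below under a deliberately simplified model: every layer's activation costs the same, one checkpoint is kept per segment, and only one segment is rematerialized at a time during backward. The function names are made up for this example. The compute overhead is roughly one extra forward pass, which is about a third of a forward+backward step when backward costs about twice the forward.

```python
import math


def peak_stored_activations(n_layers: int, segment: int) -> int:
    """Layer activations held at peak when checkpointing every `segment` layers.

    Simplified model: one saved checkpoint per segment, plus the activations
    of the single segment currently being recomputed in backward.
    """
    n_checkpoints = math.ceil(n_layers / segment)
    return n_checkpoints + segment


def best_segment(n_layers: int) -> int:
    """Segment length that minimizes the peak under the model above."""
    return min(range(1, n_layers + 1),
               key=lambda s: peak_stored_activations(n_layers, s))


if __name__ == "__main__":
    n = 64
    s = best_segment(n)
    # Without checkpointing all n layer activations are stored.
    print(n, s, peak_stored_activations(n, s))
```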

5. Training stability

  • Learning-rate warmup — ramp LR from ~0 over the first steps; skipping it is a classic early-divergence cause. Follow with cosine (or linear) decay.
  • Gradient clipping — clip global grad-norm (e.g. 1.0) so a spike doesn't blow up the run.
  • BF16 over FP16 — wider exponent range → no loss scaling, far fewer NaNs. FP16 needs a dynamic loss scaler.
  • Loss scaling (FP16) — multiply loss before backward to lift small grads out of underflow; unscale before clip/step; dynamic scaler backs off on overflow.
  • Stabilizers — z-loss (penalize logit magnitude), QK-norm, careful init (scaled by depth), embedding/LM-head tying, RMSNorm placement.
  • Watch grad-norm — it's the early-warning signal; a spike precedes a NaN.
Loss → NaN, the usual suspects: LR too high, no warmup, no grad clipping, FP16 overflow (switch to BF16 / fix scaler), bad data (corrupt/inf sample, empty or over-long sequence, label out of range), or division-by-zero in a custom layer. Bisect: lower LR, clip, switch to BF16, inspect the exact batch at the failing step.
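Global grad-norm clipping, as listed above, can be sketched in plain Python to show the mechanics. It follows the same idea as PyTorch's `torch.nn.utils.clip_grad_norm_`, but this version works on nested lists and the name `clip_global_norm` is an assumption for illustration. Returning the pre-clip norm makes it easy to log the early-warning signal.

```python
import math


def clip_global_norm(grads, max_norm: float = 1.0, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their combined L2 norm is <= max_norm.

    `grads` is a list of lists of floats (one inner list per parameter tensor).
    Returns the norm measured BEFORE clipping. If that norm is not finite,
    the gradients are left untouched so the caller can skip the step.
    """
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if not math.isfinite(total):
        return total
    scale = min(1.0, max_norm / (total + eps))
    for tensor in grads:
        for i in range(len(tensor)):
            tensor[i] *= scale
    return total


if __name__ == "__main__":
    grads = [[3.0], [4.0]]          # global norm is 5.0
    norm_before = clip_global_norm(grads, max_norm=1.0)
    print(norm_before, grads)
```

A non-finite returned norm is a cue to skip the optimizer step and inspect the batch rather than apply the update.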

6. Optimizers & schedules

  • AdamW — the default; decoupled weight decay. β₁≈0.9, β₂≈0.95 for LLMs (lower β₂ than vision), ε≈1e-8.
  • Fused / 8-bit Adam — faster / lower memory variants.
  • Schedules — warmup + cosine decay is standard; WSD (warmup-stable-decay) for continuable runs; linear for fine-tunes.
  • Weight decay — typically ~0.1; usually excluded from norms/biases/embeddings.
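  • Batch size & LR: raise LR as the batch grows (linear scaling is the classic rule; square-root scaling is a common heuristic for Adam-family optimizers), pick the largest batch that stays stable and keeps throughput high, and use gradient accumulation to reach it.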
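The warmup + cosine schedule mentioned above can be written as one small function. This is a sketch with assumed conventions: linear warmup that reaches `max_lr` on the last warmup step, and a cosine decay to `min_lr` at `total_steps`. The name `lr_at` and its parameters are chosen for this example.

```python
import math


def lr_at(step: int, max_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Learning rate at a 0-indexed step: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Ramp from max_lr / warmup_steps up to max_lr.
        return max_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for s in (0, 99, 100, 550, 1000):
        print(s, lr_at(s, max_lr=3e-4, warmup_steps=100, total_steps=1000))
```

Logging this value every step is a cheap check that the schedule is doing what you think.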

7. Fine-tuning (PEFT)

Method | What | Trade
Full fine-tune | Update all weights | Best quality, most VRAM, forgetting risk
LoRA | Freeze base, train low-rank A·B adapters (ΔW=(α/r)BA) | ~0.1–1% params; mergeable; cheap
QLoRA | 4-bit (NF4) frozen base + LoRA, double-quant, paged optimizer | Fine-tune big models on one GPU
Adapters / prefix tuning / (IA)³ | Other PEFT inserts | Variants for different needs
DoRA | Weight-decomposed LoRA | Closer to full FT quality

LoRA rank r (8–64) and alpha set capacity; target the attention (and often MLP) projections. Merge adapters for zero-overhead inference, or keep them separate for multi-LoRA serving.
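A small worked example makes the rank/alpha trade-off concrete. The sketch below counts trainable parameters for one adapted matrix and computes the merged update ΔW = (α/r)·B·A on plain lists. The shapes follow the usual convention (A is r × d_in, B is d_out × r); the function names are assumptions for illustration.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in A (rank x d_in) plus B (d_out x rank)."""
    return rank * (d_in + d_out)


def lora_fraction(d_in: int, d_out: int, rank: int) -> float:
    """Adapter size relative to the frozen d_out x d_in weight."""
    return lora_trainable_params(d_in, d_out, rank) / (d_in * d_out)


def lora_delta(A, B, alpha: float):
    """Return dW = (alpha / r) * B @ A as a d_out x d_in nested list."""
    r, d_in, d_out = len(A), len(A[0]), len(B)
    scale = alpha / r
    return [[scale * sum(B[i][k] * A[k][j] for k in range(r))
             for j in range(d_in)]
            for i in range(d_out)]


if __name__ == "__main__":
    # One 4096 x 4096 projection with rank 16.
    print(lora_trainable_params(4096, 4096, 16), lora_fraction(4096, 4096, 16))
    # Tiny merge example: r=1, d_in=2, d_out=2, alpha=2.
    print(lora_delta([[1.0, 2.0]], [[3.0], [4.0]], alpha=2.0))
```

Merging adds `lora_delta` into the frozen weight once, which is why a merged adapter has no inference overhead.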

8. Post-training / alignment

  • SFT — supervised fine-tune on instruction→response pairs (mask the prompt tokens in the loss). The foundation of a chat model.
  • RLHF (PPO) — train a reward model from human preference pairs, then optimize the policy against it with PPO + a KL penalty to the SFT reference. Powerful but complex/unstable (reward hacking, 4 models in memory).
  • DPO: Direct Preference Optimization. A classification-style loss on preference pairs that targets the same KL-regularized objective as RLHF (under a Bradley-Terry preference model) without a separate reward model or RL loop; it still needs a frozen reference model. Simpler, stable, popular default.
  • Others — ORPO (no reference model), KTO (pointwise feedback), GRPO/RLVR (verifiable rewards for reasoning), constitutional/RLAIF (AI feedback).
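For a single preference pair, the DPO loss reduces to a logistic loss on a log-probability margin between policy and reference. The sketch below assumes the inputs are summed sequence log-probabilities of the chosen and rejected responses, already computed elsewhere; the argument names and the default β = 0.1 are illustrative choices, not values from the text.

```python
import math


def dpo_loss(pi_chosen: float, pi_rejected: float,
             ref_chosen: float, ref_rejected: float,
             beta: float = 0.1) -> float:
    """DPO loss for one pair: -log(sigmoid(beta * margin)).

    margin = (log pi(chosen) - log ref(chosen))
           - (log pi(rejected) - log ref(rejected))
    """
    margin = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
    x = beta * margin
    # Numerically stable -log(sigmoid(x)) = log(1 + exp(-x)).
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


if __name__ == "__main__":
    # Policy identical to the reference: margin is 0.
    print(dpo_loss(-10.0, -12.0, -10.0, -12.0))
    # Policy has moved toward the chosen response: loss drops.
    print(dpo_loss(-9.0, -13.0, -10.0, -12.0))
```

When the policy equals the reference the loss is log 2, so any value below that means the policy already prefers the chosen response more than the reference does.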

9. Data & tokenization

  • Data is the product. Dedup (exact + near, MinHash), quality-filter (classifiers, heuristics), decontaminate against eval sets, and mix domains deliberately (curriculum / weighting).
  • Tokenization — BPE/SentencePiece; vocab size affects embedding/LM-head memory and Tensor-Core alignment (pad it). Watch for whitespace/special-token handling bugs.
  • Sequence packing — concatenate samples to fill the context and avoid padding waste; mask cross-document attention.
  • SFT label masking — only compute loss on the response tokens, not the prompt.

10. Checkpointing & resume

  • Long runs will hit node failures — checkpoint frequently and test the resume path before you depend on it.
  • Save weights + optimizer state + LR-scheduler step + RNG seeds + data-iterator/sampler position + step/token count. Weights-only resume restarts the optimizer cold → loss bump.
  • For sharded training use distributed/sharded checkpointing; use async checkpointing to avoid stalling the run.
  • Keep rolling + best checkpoints; verify integrity (a half-written checkpoint after a crash is worse than none).

11. Health metrics to log

Metric | Tells you
train/eval loss | Learning + overfitting (gap)
grad-norm | Stability; spike → incoming NaN
learning rate | Schedule is doing what you think
tokens/sec | Raw throughput
MFU | Hardware efficiency (achieved ÷ peak FLOPs); ~40–55% is good
GPU mem | Headroom before OOM
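MFU can be estimated from numbers you already log. The sketch below uses the common approximation of about 6 FLOPs per parameter per token for a dense transformer's forward+backward pass, which ignores attention FLOPs and any activation recompute; that approximation, the function name, and the example throughput are assumptions here. The 312 TFLOP/s figure is the A100's dense BF16 peak, used only as an example.

```python
def mfu(n_params: float, tokens_per_sec: float, n_gpus: int,
        peak_flops_per_gpu: float) -> float:
    """Model FLOPs utilization: achieved FLOP/s divided by peak FLOP/s.

    Achieved FLOP/s is approximated as 6 * n_params * tokens_per_sec.
    """
    achieved = 6.0 * n_params * tokens_per_sec
    return achieved / (n_gpus * peak_flops_per_gpu)


if __name__ == "__main__":
    # Hypothetical run: 7B params, 25k tokens/sec across 8 GPUs.
    value = mfu(7e9, tokens_per_sec=25_000, n_gpus=8, peak_flops_per_gpu=312e12)
    print(f"{value:.1%}")
```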

12. Quick reference

VRAM ≈ params × 16 bytes (BF16 weights + grads + Adam states) + activations
Fit bigger: BF16 + checkpointing + accumulation + FSDP/ZeRO(+offload) + 8-bit Adam
Stability: warmup + cosine, clip grad-norm 1.0, BF16, z-loss; watch grad-norm
Fine-tune cheap: QLoRA (4-bit NF4 base + LoRA, paged optimizer)
Align: SFT → DPO (or PPO/GRPO)
Data: dedup + quality filter + DECONTAMINATE + packing + SFT label masking
Health: loss, grad-norm, lr, tokens/sec, MFU
Resume: weights + optimizer + scheduler + RNG + data position; checkpoint often

13. Interview Q&A

  • Where does training memory go? Weights + grads + optimizer states (Adam ≈ 12 B/param) + activations. That is ~16 B/param of model state in mixed precision before activations, which is why we shard and checkpoint.
  • DDP vs FSDP/ZeRO? DDP replicates the full model per GPU and all-reduces grads (model must fit). FSDP/ZeRO shards optimizer/grads/params so each GPU holds a slice; it fits bigger models at the cost of more comms.
  • What do the ZeRO stages shard? 1: optimizer states. 2: + gradients. 3: + parameters (all-gather per layer). Higher = less memory, more comms.
  • How does gradient checkpointing help? Discards activations in forward, recomputes in backward: ~√N memory for ~33% more compute. Essential for long sequences.
  • BF16 vs FP16? BF16 = FP32 exponent range, no loss scaling, fewer NaNs. FP16 has finer precision (10 mantissa bits vs 7) but needs dynamic loss scaling. Default BF16 on Ampere+.
  • Loss → NaN, how to debug? Check LR/warmup, add clip 1.0, switch BF16, inspect grad-norm before the spike + the failing batch. Bisect by lowering LR.
  • LoRA vs full fine-tuning? LoRA/QLoRA: cheap, fast, mergeable, low forgetting, fits in limited VRAM. Full FT: max quality, more compute.
  • How does QLoRA fit big models on one GPU? 4-bit NF4 frozen base + BF16 LoRA on top, double quantization, paged optimizer. Slashes weight memory; trains only small adapters.
  • DPO vs RLHF/PPO? PPO needs a reward model + RL loop (powerful, unstable). DPO optimizes preferences directly via a simple loss with a reference model; it is the simpler, stable default.
  • What is MFU and why track it over GPU-util? Achieved FLOPs ÷ peak. GPU-util only means a kernel is resident; MFU exposes starvation/comms/no-Tensor-Core inefficiency. Target ~40–55%.
  • Why checkpoint optimizer state? Adam momentum/variance define the trajectory; weights-only resume restarts the optimizer cold → loss bump. Save optimizer + scheduler + RNG + data position too.
  • Why mask prompt tokens in SFT? You want the model to learn to produce the response, not to model the prompt. Compute loss only on response tokens.
© cvam — written in plaintext, served warm