Training failures fall into five families: loss is wrong (NaN / diverging / not decreasing), OOM, too slow (low MFU), distributed hangs, and can't reproduce / resume. Watch loss, grad-norm, LR, and tokens/sec — they tell you which family before you dig.
Loss → NaN / Inf
Symptom. Loss becomes nan/inf, often after a grad-norm spike.
- LR too high / no warmup → add warmup, lower peak LR. Most common cause.
- No gradient clipping → clip global grad-norm (1.0). A spike precedes the NaN.
- FP16 overflow → switch to BF16 (wider range, no loss scaling), or fix the dynamic loss scaler.
- Bad data → a corrupt sample, inf/NaN in inputs, empty sequence, or label out of range. Inspect the exact batch at the failing step.
- Numerics in custom layers → log(0), divide-by-zero, sqrt of negative, unstable softmax; add epsilons.
torch.autograd.set_detect_anomaly(True)  # find the op that makes NaN (slow)

# log the early-warning signal
print(step, loss.item(), torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item())
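The warmup fix from the list above can be sketched as a schedule function. This is a minimal sketch assuming linear warmup followed by cosine decay; the name `lr_at` and every default value (peak LR, warmup length, total steps, floor) are illustrative assumptions, not values from this guide.

```python
import math

def lr_at(step, peak_lr=3e-4, warmup_steps=1000, total_steps=100_000, min_lr=3e-5):
    """Linear warmup to peak_lr, then cosine decay to min_lr (hypothetical defaults)."""
    if step < warmup_steps:
        # ramp from peak_lr / warmup_steps up to peak_lr
        return peak_lr * (step + 1) / warmup_steps
    # fraction of the decay phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

One way to use it is through `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base LR by the returned factor: create the optimizer with `lr=peak_lr` and pass `lambda s: lr_at(s) / peak_lr`.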
grad-norm is your smoke alarm
Plot grad-norm. A sudden spike one step before the NaN tells you it's an optimization blow-up
(LR/clip/precision), not a data bug. A NaN with flat grad-norm points at a bad sample or a layer
numerics bug instead.
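The triage rule above can be written down as a small helper over the logged grad-norm history. This is a rough sketch, not a calibrated detector: the function name `classify_nan`, the 50-step window, and the 5x spike factor are all assumptions chosen for illustration.

```python
from statistics import median

def classify_nan(grad_norms, nan_step, window=50, spike_factor=5.0):
    """Guess the failure family from grad-norm history before a NaN.

    grad_norms: per-step global grad-norm values (index = step).
    nan_step:   first step at which the loss was NaN/Inf.
    window and spike_factor are illustrative thresholds, not tuned values.
    """
    history = grad_norms[max(0, nan_step - window):nan_step]
    if len(history) < 2:
        return "unknown"
    baseline = median(history[:-1])   # typical grad-norm before the last step
    last = history[-1]                # grad-norm one step before the NaN
    spiked = last != last or last == float("inf") or last > spike_factor * baseline
    if spiked:
        return "optimization blow-up (LR / clipping / precision)"
    return "bad sample or layer numerics"
```

For a flat history the helper points at data or numerics; a history ending in a large jump points at the optimizer settings.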
Loss not decreasing
- LR too low / too high → run an LR range test; too low crawls, too high bounces.
- Data/label bug → shuffled labels, wrong loss masking (e.g. not masking prompt tokens in SFT), targets shifted wrong. Overfit a single batch first — if it can't, it's a bug, not tuning.
- Frozen params → check requires_grad; verify the optimizer actually has the params (common with LoRA/PEFT wiring).
- Bad init / no normalization → check init scheme, layernorm placement.
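The frozen-params check can be automated with a few lines. The sketch below assumes a plain PyTorch model and optimizer; `untrained_params` is a hypothetical helper name, and the toy model is deliberately miswired to show both failure modes.

```python
import torch
from torch import nn

def untrained_params(model, optimizer):
    """Return (name, reason) for every parameter that will never be updated."""
    in_optimizer = {id(p) for group in optimizer.param_groups for p in group["params"]}
    problems = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            problems.append((name, "requires_grad=False"))
        elif id(p) not in in_optimizer:
            problems.append((name, "not in optimizer"))
    return problems

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model[0].bias.requires_grad_(False)                             # deliberately frozen
optimizer = torch.optim.AdamW(model[0].parameters(), lr=1e-3)   # forgets model[2]
report = untrained_params(model, optimizer)
for name, reason in report:
    print(f"{name}: {reason}")
```

With LoRA/PEFT the frozen base weights are expected to show up in the report; the thing to confirm is that no adapter parameter does.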
overfit one batch — the fastest sanity check
Before scaling, train on a single batch repeatedly. Loss should drop to ~0. If it can't memorize one
batch, the bug is in the model/loss/data wiring — no hyperparameter will save it.
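A minimal sketch of the single-batch test, assuming a standard PyTorch training step. The helper name `overfit_one_batch`, the toy classifier, the step count, and the learning rate are illustrative; substitute your own model, loss, and one real batch.

```python
import torch
from torch import nn

def overfit_one_batch(model, loss_fn, batch, steps=500, lr=1e-2):
    """Train on a single fixed batch; return (initial_loss, final_loss)."""
    inputs, targets = batch
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses[0], losses[-1]

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
batch = (torch.randn(8, 16), torch.randint(0, 4, (8,)))  # 8 random samples, 4 classes
initial, final = overfit_one_batch(model, nn.CrossEntropyLoss(), batch)
print(f"loss {initial:.3f} -> {final:.5f}")
```

If the final loss stays near the initial one, suspect the wiring (targets, masking, frozen parameters) before touching hyperparameters.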
OOM during training
- Activations dominate → gradient checkpointing; reduce batch/seq; BF16.
- Optimizer states (Adam ≈ 12 B/param in mixed precision: FP32 master weights, momentum, and variance) → FSDP/ZeRO sharding; 8-bit Adam; CPU offload.
- OOM only in backward / later step → fragmentation or accumulation holding graphs; set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, don't retain the graph.
- OOM at validation → wrap eval in torch.no_grad().
- Effective big batch without VRAM → gradient accumulation.
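A back-of-the-envelope estimate shows why sharding the optimizer state matters. The sketch assumes the 12 B/param figure covers FP32 master weights plus Adam's two moments, with BF16 weights and gradients at 2 B/param each; the 7B model size, the 8-way sharding, and the function name are illustrative, and activations are left out entirely.

```python
def training_state_gib(n_params, weight_bytes=2, grad_bytes=2, optim_bytes=12, shards=1):
    """Rough per-GPU memory for weights + grads + optimizer state, in GiB.

    Activations, buffers, and fragmentation are NOT included.
    shards > 1 models full sharding (FSDP / ZeRO-3 style) of all three.
    """
    total_bytes = n_params * (weight_bytes + grad_bytes + optim_bytes)
    return total_bytes / shards / 2**30

full = training_state_gib(7e9)                # one GPU, nothing sharded
sharded = training_state_gib(7e9, shards=8)   # fully sharded across 8 GPUs
print(f"{full:.1f} GiB unsharded, {sharded:.1f} GiB per GPU when sharded 8 ways")
```

Under these assumptions the unsharded state alone exceeds a single 80 GB card before any activations are stored, which is why sharding or offload comes before batch-size tuning for models of this size.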
Training too slow (low MFU)
- Data starvation → more DataLoader workers, prefetch, pin_memory, sequence packing to kill padding waste.
- Host↔device syncs → remove .item()/prints from the hot loop; log every N steps.
- No fast path → BF16/TF32, torch.compile, FlashAttention, fused optimizer; pad dims to multiples of 8/16.
- Comms-bound (multi-GPU) → ZeRO stage too high for the interconnect, or no compute/comms overlap. Lower the stage if it still fits; check NVLink vs PCIe with nvidia-smi topo -m.
- Profile a few steps with nsys; look for gaps (starvation), syncs, and NCCL not overlapping.
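To know whether a run is actually slow, it helps to turn tokens/sec into MFU. The sketch below uses the common approximation of about 6 FLOPs per parameter per token for a dense transformer's forward plus backward pass; that approximation, the function name, and all the numbers are assumptions for illustration (312 TFLOP/s is the A100 dense BF16 peak).

```python
def mfu(n_params, tokens_per_sec, peak_flops_per_gpu, n_gpus=1):
    """Model FLOPs utilization using the ~6 * N FLOPs-per-token approximation.

    The factor 6 (2 forward + 4 backward per parameter per token) ignores
    attention FLOPs, so it underestimates for long sequences.
    """
    achieved = 6 * n_params * tokens_per_sec
    return achieved / (peak_flops_per_gpu * n_gpus)

# Hypothetical run: 7B params, 24k tokens/sec total, 8 GPUs at 312 TFLOP/s peak each.
util = mfu(n_params=7e9, tokens_per_sec=24_000, peak_flops_per_gpu=312e12, n_gpus=8)
print(f"MFU = {util:.1%}")
```

Tracking this number per run makes regressions visible: a drop after a config change points back at the list above (data loading, syncs, kernels, or communication).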
Distributed hang
Symptom. Run freezes (often at start or first all-reduce); no progress, no crash.
export NCCL_DEBUG=INFO
export TORCH_DISTRIBUTED_DEBUG=DETAIL  # flags collective mismatches
- Collective mismatch → ranks run different code paths (a conditional, uneven batches, logging only on rank 0 inside a collective). All ranks must call the same collectives in the same order.
- Uneven data → last batch differs per rank → one rank does an extra/short step. Use DistributedSampler; drop_last or the join() context.
- A rank died silently → an OOM/exception on one rank leaves the rest at the barrier until timeout. The real traceback is in that rank's log, so search all of them.
- Init hang → wrong MASTER_ADDR/PORT, world size, or a blocked rendezvous port.
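The uneven-data case can be checked offline, without launching a distributed job. This sketch assumes PyTorch's `DistributedSampler` and a toy 10-sample dataset on 4 ranks; passing `num_replicas` and `rank` explicitly means no process group is needed.

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))   # 10 samples do not divide evenly across 4 ranks
world_size = 4

def samples_per_rank(drop_last):
    # num_replicas and rank are passed explicitly, so no process group is needed here
    return [
        len(DistributedSampler(dataset, num_replicas=world_size, rank=r,
                               shuffle=False, drop_last=drop_last))
        for r in range(world_size)
    ]

padded = samples_per_rank(drop_last=False)   # evens out by repeating samples
trimmed = samples_per_rank(drop_last=True)   # evens out by dropping the tail
print(padded, trimmed)
```

Either setting gives every rank the same sample count, so with a shared batch size every rank runs the same number of steps. Hand-rolled sharding (for example, one file list per rank) carries no such guarantee and is worth checking the same way.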
Can't reproduce / resume
- Resume restarts the optimizer → you saved weights but not optimizer state (Adam momentum/variance) or the LR-scheduler/step. Checkpoint all of them.
- Data position lost → save the sampler/data iterator state too, or you re-see the same early data.
- Nondeterminism → set seeds and torch.use_deterministic_algorithms(True); ops without a deterministic implementation then raise an error, and small differences across GPU types or library versions are still expected.
- Checkpoint cadence → long runs hit node failures; checkpoint often and test the resume path before you need it.
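Testing the resume path can be as simple as comparing an interrupted run against an uninterrupted one. The sketch below is a toy setup with assumed names (`make_run`, `batch_for`, `train`); an in-memory buffer stands in for the checkpoint file, and the data position is recovered from the saved step because each batch is a pure function of the step.

```python
import io
import torch
from torch import nn

def make_run():
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
    return model, optimizer, scheduler

def batch_for(step):
    # Data position is a pure function of the step here; a real pipeline
    # would save and restore its sampler / iterator state instead.
    g = torch.Generator().manual_seed(step)
    x = torch.randn(8, 4, generator=g)
    return x, x.sum(dim=1, keepdim=True)

def train(model, optimizer, scheduler, start, stop):
    for step in range(start, stop):
        x, y = batch_for(step)
        optimizer.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()
        scheduler.step()

# Reference: 10 uninterrupted steps.
ref_model, ref_opt, ref_sched = make_run()
train(ref_model, ref_opt, ref_sched, 0, 10)

# Interrupted: 5 steps, checkpoint everything, restore into fresh objects, 5 more.
model, opt, sched = make_run()
train(model, opt, sched, 0, 5)
buffer = io.BytesIO()  # stands in for a checkpoint file
torch.save({"model": model.state_dict(), "optimizer": opt.state_dict(),
            "scheduler": sched.state_dict(), "step": 5}, buffer)
buffer.seek(0)

ckpt = torch.load(buffer)
model, opt, sched = make_run()
model.load_state_dict(ckpt["model"])
opt.load_state_dict(ckpt["optimizer"])
sched.load_state_dict(ckpt["scheduler"])
train(model, opt, sched, ckpt["step"], 10)

resumed_matches = all(torch.allclose(a, b) for a, b in
                      zip(model.parameters(), ref_model.parameters()))
print("resume matches uninterrupted run:", resumed_matches)
```

Dropping the optimizer or scheduler entry from the checkpoint is a quick way to see the mismatch this test is meant to catch. A real run that uses dropout or shuffling would likely also need to store RNG state, for example via `torch.get_rng_state()`.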
Quick reference
# health signals to log every run
loss, grad_norm, lr, tokens/sec, MFU, eval_loss

# NaN → warmup + clip 1.0 + BF16 + inspect failing batch
torch.autograd.set_detect_anomaly(True)

# not learning → overfit one batch; check masking/requires_grad
# OOM → checkpointing + accumulation + FSDP/ZeRO + no_grad eval
# slow → workers/prefetch, kill .item() syncs, torch.compile, BF16
# hang → NCCL_DEBUG=INFO, TORCH_DISTRIBUTED_DEBUG=DETAIL, grep all ranks
# resume → save weights+optimizer+scheduler+sampler state