
DEBUG GUIDE · GENAI · TRAINING PLAYBOOK

Debugging LLM & Model Training.

Training failures fall into five families: loss is wrong (NaN / diverging / not decreasing), OOM, too slow (low MFU), distributed hangs, and can't reproduce / resume. Watch loss, grad-norm, LR, and tokens/sec — they tell you which family before you dig.

Loss → NaN / Inf

Symptom. Loss becomes nan/inf, often after a grad-norm spike.

  • LR too high / no warmup → add warmup, lower peak LR. Most common cause.
  • No gradient clipping → clip the global grad-norm (1.0 is a common default). A grad-norm spike usually precedes the NaN.
  • FP16 overflow → switch to BF16 (wider range, no loss scaling), or fix the dynamic loss scaler.
  • Bad data → a corrupt sample, inf/NaN in inputs, empty sequence, or label out of range. Inspect the exact batch at the failing step.
  • Numerics in custom layers → log(0), divide-by-zero, sqrt of negative, unstable softmax; add epsilons.
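Here is a minimal sketch of the warmup fix from the first bullet: linear warmup to the peak LR, then cosine decay. The cosine tail and the numbers in the usage loop are assumptions. Cosine is a common choice, but this playbook doesn't prescribe it. lr_at is a hypothetical helper. In PyTorch you could pass it to torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's base LR by the factor the lambda returns, so return lr_at(...) / peak_lr there.

```python
import math


def lr_at(step, peak_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr (decay shape assumed)."""
    if step < warmup_steps:
        # Ramp from peak_lr / warmup_steps up to peak_lr so early, noisy updates stay small.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped so steps past the end stay at min_lr.
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for s in (0, 99, 550, 1000):
        print(s, f"{lr_at(s, peak_lr=3e-4, warmup_steps=100, total_steps=1000):.2e}")
```

If loss still diverges with this schedule, lengthen warmup_steps or lower peak_lr before touching anything else.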
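import torch
torch.autograd.set_detect_anomaly(True)   # debug only (slow): backward raises at the first NaN-producing op, with its forward traceback
# in the training loop, after loss.backward() and before optimizer.step():
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clips in place; returns the pre-clip total norm
if step % log_every == 0:                 # log_every: your logging interval; avoids a host sync every step
    print(step, loss.item(), grad_norm.item())  # the early-warning signal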
grad-norm is your smoke alarm. Plot grad-norm. A sudden spike one step before the NaN means an optimization blow-up (LR / clipping / precision), not a data bug. A NaN with flat grad-norm instead points at a bad sample or a numerics bug in a layer.
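You can automate the callout's triage rule over your logged history. This is a minimal sketch that assumes you keep per-step lists of loss and grad-norm. diagnose_nan, the 20-step window, and the 10× spike threshold are illustrative assumptions, not tuned values.

```python
import math
from statistics import median


def diagnose_nan(losses, grad_norms, window=20, spike_factor=10.0):
    """Classify the first non-finite loss using the grad-norm history before it.

    Returns (step, verdict), or None if every loss is finite. window and
    spike_factor are illustrative thresholds, not tuned values.
    """
    for step, loss in enumerate(losses):
        if math.isfinite(loss):
            continue
        prev = step - 1
        # Baseline: median of finite grad-norms in the window before the suspect step.
        history = [g for g in grad_norms[max(0, prev - window):max(0, prev)] if math.isfinite(g)]
        if not history:
            return step, "NaN with no grad-norm history: inspect the failing batch"
        suspect = grad_norms[prev]
        if not math.isfinite(suspect) or suspect > spike_factor * median(history):
            return step, "grad-norm spike before NaN: optimization blow-up (LR / clipping / precision)"
        return step, "flat grad-norm before NaN: suspect a bad sample or layer numerics"
    return None


if __name__ == "__main__":
    nan = float("nan")
    print(diagnose_nan([2.0] * 10 + [nan], [1.0] * 9 + [50.0, nan]))
    print(diagnose_nan([2.0] * 10 + [nan], [1.0] * 10 + [nan]))
```

The median makes the baseline robust to an occasional earlier spike, which a plain mean would absorb.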

Loss not decreasing

  • LR too low / too high → run an LR range test; too low crawls, too high bounces.
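  • Data/label bug → shuffled labels, wrong loss masking (e.g. not masking prompt tokens in SFT), or targets shifted by the wrong amount for next-token prediction (not at all, or twice when the model already shifts internally). Overfit a single batch first: if the model can't, it's a bug, not tuning.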
  • Frozen params → check requires_grad; verify the optimizer actually has the params (common with LoRA/PEFT wiring).
  • Bad init / no normalization → check init scheme, layernorm placement.
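The following is a minimal sketch of SFT label construction on plain token-id lists, aimed at the masking and shifting bugs. It assumes PyTorch CrossEntropyLoss's default ignore_index of -100. The helper names are hypothetical. If your model already shifts labels internally (Hugging Face causal-LM models do when you pass labels), skip the manual shift. Otherwise you hit the shifted-twice bug.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_sft_labels(prompt_ids, response_ids):
    """Concatenate prompt + response; mask the prompt so loss covers only the response."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


def shift_for_next_token(input_ids, labels):
    """Position t predicts token t+1: drop the last input and the first label (do this once)."""
    return input_ids[:-1], labels[1:]


def check_labels(labels, vocab_size):
    """Return the number of supervised positions, or raise on the classic wiring bugs."""
    supervised = [t for t in labels if t != IGNORE_INDEX]
    if not supervised:
        # Mean cross-entropy over zero targets is 0/0, i.e. NaN.
        raise ValueError("every position is masked")
    out_of_range = [t for t in supervised if not 0 <= t < vocab_size]
    if out_of_range:
        raise ValueError(f"label ids out of range: {out_of_range[:5]}")
    return len(supervised)


if __name__ == "__main__":
    inputs, labels = build_sft_labels([1, 2, 3], [7, 8, 9])
    x, y = shift_for_next_token(inputs, labels)
    print(x)  # [1, 2, 3, 7, 8]
    print(y)  # [-100, -100, 7, 8, 9]
    print(check_labels(y, vocab_size=32))  # 3
```

After the shift, the last prompt token (3) is trained to predict the first response token (7). That is the one prompt position that should carry loss.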
overfit one batch: the fastest sanity check. Before scaling, train on a single batch repeatedly; loss should drop to ~0. If the model can't memorize one batch, the bug is in the model/loss/data wiring, and no hyperparameter will save it.
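Here is the overfit-one-batch check, sketched with a toy stand-in model so it runs without a GPU. The model is full-batch gradient descent on a tiny linear regression with more parameters than samples, so zero loss is reachable. make_linear_step and overfit_one_batch are hypothetical helpers. With a real model, step_fn would run one zero_grad / forward / backward / optimizer.step on the same cached batch. Setting update=False simulates the "params never reached the optimizer" bug from the list above.

```python
def make_linear_step(X, y, lr=0.5, update=True):
    """Toy stand-in for a model: one full-batch GD step on 0.5 * mean squared error.

    update=False simulates a wiring bug such as params never reaching the optimizer.
    """
    w = [0.0] * len(X[0])
    n = len(X)

    def step():
        preds = [sum(wj * xj for wj, xj in zip(w, row)) for row in X]
        residuals = [p - t for p, t in zip(preds, y)]
        loss = 0.5 * sum(r * r for r in residuals) / n
        if update:
            grads = [sum(r * row[j] for r, row in zip(residuals, X)) / n for j in range(len(w))]
            for j, g in enumerate(grads):
                w[j] -= lr * g
        return loss  # loss measured before this step's update

    return step


def overfit_one_batch(step_fn, steps=500, target=1e-3):
    """Train on the same batch repeatedly; return (first_loss, last_loss, passed)."""
    first = last = step_fn()
    for _ in range(steps - 1):
        last = step_fn()
    return first, last, last < target


if __name__ == "__main__":
    X = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]   # 2 samples: 2 features + a bias column
    y = [1.0, -1.0]
    print(overfit_one_batch(make_linear_step(X, y))[2])             # True
    print(overfit_one_batch(make_linear_step(X, y, update=False)))  # (0.5, 0.5, False)
```

A flat loss at exactly its starting value, as in the second call, is the signature of parameters that are never updated. In a real run, check requires_grad and optimizer.param_groups.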

OOM during training

  • Activations dominate → gradient checkpointing; reduce batch/seq; BF16.
  • Optimizer states (mixed-precision Adam ≈ 12 B/param: FP32 master weights + momentum + variance) → FSDP/ZeRO sharding; 8-bit Adam; CPU offload.
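  • OOM only in backward / at a later step → allocator fragmentation, or something keeping autograd graphs alive across steps (e.g. total_loss += loss instead of total_loss += loss.detach()). Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, and don't pass retain_graph=True unless you need it.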
  • OOM at validation → wrap eval in torch.no_grad().
  • Effective big batch without VRAM → gradient accumulation.
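This back-of-envelope sketch shows why sharding helps. It assumes mixed-precision Adam: 2 B/param of half-precision weights, 2 B/param of gradients, and the 12 B/param of optimizer state from the bullet above, sharded per ZeRO stage across n_gpus. It ignores activations, buffers, and fragmentation, so treat the result as a floor. training_state_gb is a hypothetical helper. k_opt is exposed so you can plug in a smaller figure for an 8-bit optimizer.

```python
def training_state_gb(n_params, n_gpus=1, zero_stage=0, k_opt=12):
    """Per-GPU GB for weights + grads + optimizer states under mixed-precision Adam.

    Assumes 2 B/param half-precision weights, 2 B/param grads, and k_opt B/param of
    optimizer state (FP32 master weights + momentum + variance = 12). Activations,
    buffers, and fragmentation are ignored.
    """
    weights, grads, optim = 2 * n_params, 2 * n_params, k_opt * n_params
    if zero_stage >= 1:
        optim /= n_gpus      # stage 1: shard optimizer states
    if zero_stage >= 2:
        grads /= n_gpus      # stage 2: also shard gradients
    if zero_stage >= 3:
        weights /= n_gpus    # stage 3: also shard parameters
    return (weights + grads + optim) / 1e9


if __name__ == "__main__":
    for stage in range(4):
        print(stage, training_state_gb(7e9, n_gpus=8, zero_stage=stage))
```

For a 7B-parameter model on 8 GPUs, this gives 112 GB per GPU unsharded, 38.5 GB at stage 1, 26.25 GB at stage 2, and 14 GB at stage 3. FSDP's full sharding corresponds roughly to stage 3.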

Training too slow (low MFU)

  • Data starvation → more DataLoader workers, prefetch, pin_memory, sequence packing to kill padding waste.
  • Host↔device syncs → remove .item()/prints from the hot loop; log every N steps.
  • No fast path → BF16/TF32, torch.compile, FlashAttention, fused optimizer; pad dims to multiples of 8/16.
  • Comms-bound (multi-GPU) → ZeRO stage too high for the interconnect, or no compute/comms overlap. Lower the stage if it still fits; check NVLink vs PCIe with nvidia-smi topo -m.
  • Profile a few steps with nsys — look for gaps (starvation), syncs, and NCCL not overlapping.
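MFU itself is easy to compute from tokens/sec. This minimal sketch uses the common approximation of 6·N training FLOPs per token (2N forward, 4N backward) for an N-parameter model. It ignores attention FLOPs, so it slightly understates MFU at long sequence lengths. The throughput in the usage line is made up for illustration. 312e12 is the A100's published dense BF16 peak; substitute your own hardware's figure.

```python
def mfu(n_params, tokens_per_sec, n_gpus, peak_flops_per_gpu):
    """Model FLOPs utilization from the ~6 * N training FLOPs-per-token estimate.

    tokens_per_sec is aggregate throughput across all n_gpus. Attention FLOPs are
    not counted, so long-sequence runs come out slightly low.
    """
    achieved = 6 * n_params * tokens_per_sec          # 2N forward + 4N backward per token
    return achieved / (n_gpus * peak_flops_per_gpu)


if __name__ == "__main__":
    A100_BF16_DENSE = 312e12                           # published peak; use your GPU's spec
    print(f"{mfu(7e9, 24_000, 8, A100_BF16_DENSE):.1%}")  # 40.4%
```

Log this next to tokens/sec. A drop in MFU with unchanged model size points at one of the bullets above rather than at the model.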

Distributed hang

Symptom. Run freezes (often at start or first all-reduce); no progress, no crash.

export NCCL_DEBUG=INFO
export TORCH_DISTRIBUTED_DEBUG=DETAIL    # flags collective mismatches
  • Collective mismatch → ranks run different code paths (a conditional, uneven batches, logging only on rank 0 inside a collective). All ranks must call the same collectives in the same order.
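  • Uneven data → the last batch differs per rank, so one rank does an extra/short step. Use DistributedSampler (it pads, or trims with drop_last=True, so every rank gets the same sample count), or DDP's join() context manager for genuinely uneven inputs.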
  • A rank died silently → an OOM/exception on one rank leaves the rest at the barrier until timeout. The real traceback is in that rank's log — search all of them.
  • Init hang → wrong MASTER_ADDR/PORT, world size, or a blocked rendezvous port.
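This minimal pure-Python sketch shows why uneven data hangs by counting optimizer steps per rank. naive_shard and padded_shard are hypothetical helpers. padded_shard mirrors how DistributedSampler pads by repeating indices when drop_last=False, minus the shuffle.

```python
import math


def naive_shard(n_samples, rank, world_size):
    """Strided split without padding: ranks can get different sample counts."""
    return list(range(rank, n_samples, world_size))


def padded_shard(n_samples, rank, world_size):
    """Repeat leading indices until the total divides evenly, then stride.

    Mirrors DistributedSampler with drop_last=False, minus the shuffle.
    """
    total = math.ceil(n_samples / world_size) * world_size
    indices = [i % n_samples for i in range(total)]   # wrap around to pad
    return indices[rank:total:world_size]


def steps_per_rank(shard_fn, n_samples, world_size, batch_size):
    """Optimizer steps each rank takes in one epoch (last partial batch kept)."""
    return [math.ceil(len(shard_fn(n_samples, r, world_size)) / batch_size)
            for r in range(world_size)]


if __name__ == "__main__":
    print(steps_per_rank(naive_shard, 10, 4, 2))   # [2, 2, 1, 1]
    print(steps_per_rank(padded_shard, 10, 4, 2))  # [2, 2, 2, 2]
```

Take 10 samples, 4 ranks, and batch size 2. Naive sharding gives ranks 0 and 1 one more step than ranks 2 and 3. That extra step's all-reduce has no partner, so ranks 0 and 1 block until the timeout. Padding gives every rank two steps, at the cost of seeing two samples twice per epoch.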

Can't reproduce / resume

  • Resume restarts the optimizer → you saved weights but not optimizer state (Adam momentum/variance) or the LR-scheduler/step. Checkpoint all of them.
  • Data position lost → save the sampler/data iterator state too, or you re-see the same early data.
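  • Nondeterminism → set seeds (Python, NumPy, torch) and torch.use_deterministic_algorithms(True). It raises an error on ops with no deterministic implementation (warn_only=True downgrades that to a warning), and cuBLAS also needs CUBLAS_WORKSPACE_CONFIG=:4096:8. Bitwise reproducibility across PyTorch versions or hardware still isn't guaranteed; that's expected.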
  • Checkpoint cadence → long runs hit node failures; checkpoint often and test the resume path before you need it.
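Saving data position can be as small as (seed, epoch, position) if the shuffle order is a pure function of seed and epoch. This is a minimal sketch. ResumableSampler is a hypothetical class, not a PyTorch API. In a real run, its state_dict() would go into the same torch.save checkpoint as model.state_dict(), optimizer.state_dict(), scheduler.state_dict(), and the global step.

```python
import random


class ResumableSampler:
    """Hypothetical sampler: seeded per-epoch shuffle with a checkpointable position."""

    def __init__(self, n_samples, seed=0):
        self.n_samples = n_samples
        self.seed = seed
        self.epoch = 0
        self.position = 0          # index into the current epoch's permutation
        self._cache_key = None
        self._order = []

    def _permutation(self):
        # The order is a pure function of (seed, epoch), so it can be rebuilt on resume.
        key = (self.seed, self.epoch)
        if key != self._cache_key:
            self._order = list(range(self.n_samples))
            random.Random(self.seed + self.epoch).shuffle(self._order)
            self._cache_key = key
        return self._order

    def next_index(self):
        if self.position == self.n_samples:   # epoch finished: advance and reshuffle
            self.epoch += 1
            self.position = 0
        idx = self._permutation()[self.position]
        self.position += 1
        return idx

    def state_dict(self):
        return {"seed": self.seed, "epoch": self.epoch, "position": self.position}

    def load_state_dict(self, state):
        self.seed = state["seed"]
        self.epoch = state["epoch"]
        self.position = state["position"]


if __name__ == "__main__":
    uninterrupted = ResumableSampler(10, seed=3)
    reference = [uninterrupted.next_index() for _ in range(25)]

    before = ResumableSampler(10, seed=3)
    head = [before.next_index() for _ in range(13)]
    ckpt = before.state_dict()              # saved alongside model/optimizer/scheduler state

    after = ResumableSampler(10)            # fresh process after a crash
    after.load_state_dict(ckpt)
    tail = [after.next_index() for _ in range(12)]
    print(head + tail == reference)         # True
```

Resuming mid-epoch continues the same permutation instead of restarting it. Without this, every restart re-sees the same early data.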

Quick reference

# health signals to log every run
loss, grad_norm, lr, tokens/sec, MFU, eval_loss
# NaN → warmup + clip 1.0 + BF16 + inspect failing batch
torch.autograd.set_detect_anomaly(True)
# not learning → overfit one batch; check masking/requires_grad
# OOM → checkpointing + accumulation + FSDP/ZeRO + no_grad eval
# slow → workers/prefetch, kill .item() syncs, torch.compile, BF16
# hang → NCCL_DEBUG=INFO, TORCH_DISTRIBUTED_DEBUG=DETAIL, grep all ranks
# resume → save weights+optimizer+scheduler+sampler state
© cvam — written in plaintext, served warm