May 14, 2026 · paperjuice · 15 min read · 3300 words

FlashAttention-4 — When Tensor Cores Got Too Fast for Everything Else.

paperjuice ml attention gpu-optimization flashattention-series

Imagine you're building a factory. Every year you double the speed of the main assembly line — the welding robots, the painting arms, the precision cutters. But the conveyor belts? They run at the same speed. The forklift drivers? Same pace. The QA inspectors? Still doing things by hand.

Eventually, your assembly line is so fast that it finishes a product and then just… waits. For the conveyor. For the forklift. For the inspector. Your $50 million robot is idle 60% of the time because the $500 conveyor belt can't keep up.

That's exactly what happened with NVIDIA's Blackwell GPUs. And FlashAttention-4 is the paper that fixes it.

Part 4 of the FlashAttention evolution series — the finale. We've gone from memory hierarchy awareness to work partitioning to asynchronous pipelining. Now we face the hardest problem yet: the hardware itself is lopsided.

The problem: asymmetric hardware scaling

NVIDIA's Blackwell B200 doubled the tensor core throughput compared to the H100. 2.25 petaFLOPs of matmul power. That's 2,250,000,000,000,000 floating-point operations per second for matrix multiplication.

But shared memory bandwidth? Unchanged. The exponential unit (used for softmax)? Unchanged. Integer ALUs? Unchanged.

The paper calls this asymmetric hardware scaling. Tensor cores scale aggressively because they're the most important for AI workloads. Everything else gets left behind because silicon is expensive and power budgets are fixed.

The consequence for FlashAttention-3: running it on Blackwell, the shared memory traffic and exponential operations now take 25–60% longer than the matmul computation. The tensor cores finish their matrix multiply and then sit there, tapping their feet, waiting for softmax to catch up.

On Hopper, tensor cores were the bottleneck. On Blackwell, tensor cores are the fastest thing in the room — and everything else is the bottleneck.

FlashAttention-4's big idea: co-design for imbalance

If the hardware is lopsided, the algorithm must be lopsided too. FlashAttention-4 doesn't treat the GPU as a uniform compute resource. It explicitly identifies the slow parts and engineers around them.

1. Software-emulated exponentials — faking math to go faster

This is the wildest trick in the paper. The exponential unit on B200 can do 16 operations per clock cycle per SM. Tensor cores can do 8,192. That's a 512:1 ratio. The exp unit is hopelessly outgunned.

FlashAttention-4's solution: don't use the exponential unit for all exponentials. Instead, fake some of them with a polynomial approximation evaluated on the FMA (fused multiply-add) units. Each SM has far more FMA lanes than exp lanes, so their aggregate throughput is much higher, and they run in parallel with the exp unit.

The approximation uses a degree-3 polynomial. Is it perfectly precise? No — at FP32 it has 600× more error than the hardware exponential. But here's the key insight: the output gets rounded to BF16 anyway, and BF16's quantization error is so large that it completely drowns out the polynomial's error. After rounding, the approximate and exact results are indistinguishable on 99% of inputs.

They don't even fake all the exponentials — just 10–25% of them. Enough to eliminate the bottleneck without introducing meaningful error. It's like replacing one out of four QA inspectors with a slightly less precise but much faster automated scanner. The overall quality doesn't change, but the line stops backing up.
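Here's a minimal NumPy sketch of the trick, under a few assumptions of my own: attention kernels typically work in base 2 (folding log₂e into the softmax scale), so it approximates 2^x by splitting off the integer part, running a cubic polynomial on the fractional part (three chained FMAs in Horner form), and stitching the integer part back in through the exponent bits. The coefficients are a generic least-squares fit, not the paper's, and bf16_round / approx_exp2 are just helpers for this demo.

    import numpy as np

    def bf16_round(x):
        """Round FP32 values to bfloat16 precision (round-to-nearest-even, 8 mantissa bits)."""
        bits = np.asarray(x, dtype=np.float32).view(np.uint32)
        bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
        return bits.view(np.float32)

    # Fit a cubic to 2^f on the fractional range [0, 1). In a kernel this is three
    # chained FMAs; these least-squares coefficients are illustrative, not the paper's.
    grid = np.linspace(0.0, 1.0, 4096)
    c3, c2, c1, c0 = np.polyfit(grid, np.exp2(grid), deg=3)

    def approx_exp2(x):
        """2**x via integer/fraction split plus a cubic polynomial (Horner form)."""
        n = np.floor(x)
        f = x - n                                # fractional part, in [0, 1)
        p = ((c3 * f + c2) * f + c1) * f + c0    # three FMA-shaped steps
        return np.ldexp(p, n.astype(np.int32))   # multiply by 2**n through the exponent bits

    # Softmax inputs are <= 0 after max subtraction, so sample there.
    x = np.random.uniform(-20.0, 0.0, 100_000).astype(np.float32)
    exact, approx = np.exp2(x), approx_exp2(x)
    print("max relative error in FP32:", np.abs(approx / exact - 1.0).max())
    print("fraction identical after BF16 rounding:",
          (bf16_round(approx) == bf16_round(exact)).mean())

The polynomial is visibly less accurate in FP32, but once both paths are rounded to BF16 they agree on almost every input, which is exactly the effect the paper leans on.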

2. Conditional softmax rescaling — skipping work that doesn't matter

Remember the online softmax from FlashAttention-1? Every time you process a new block, you check if the new maximum score is bigger than the old one. If it is, you have to rescale all previous results. This rescaling involves multiplying an entire output tile by a scaling factor — expensive.

FlashAttention-4 notices that most of the time, rescaling barely matters. If the new maximum is only slightly larger than the old one, the rescaling factor is close to 1.0, and skipping it introduces negligible error.

So it sets a threshold: only rescale when the difference exceeds log₂(256) = 8.0. Below that threshold, skip the rescale and just keep accumulating. At the very end, apply one final correction to get the exact result. This eliminates most rescaling operations entirely.

It's like checking your bank balance after every purchase. Sure, you could — but if you bought a $3 coffee, it probably doesn't affect your budget. Check only when the purchase is big. Same final balance at the end of the month.
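To make the mechanics concrete, here's a minimal NumPy sketch of deferred rescaling for a single query row, processed block by block, with the threshold above playing its role. The block size, the exp2-space scaling, and the bookkeeping details are assumptions for illustration, not the kernel's actual code:

    import numpy as np

    def attention_row_deferred_rescale(q, K, V, block=128, threshold=8.0):
        """One query row of attention with conditional (deferred) softmax rescaling.
        Works in exp2 space; `threshold` plays the role of the log2(256) = 8.0 cutoff."""
        scale = np.log2(np.e) / np.sqrt(q.shape[-1])   # fold log2(e) into the softmax scale
        m_acc = -np.inf               # the max the accumulator is currently referenced against
        m_run = -np.inf               # the true running max of all scores seen so far
        l = 0.0                       # running normalizer, relative to m_acc
        acc = np.zeros(V.shape[-1])   # running output accumulator, relative to m_acc

        for i in range(0, K.shape[0], block):
            s = (K[i:i + block] @ q) * scale
            m_run = max(m_run, s.max())
            # Rescale only when the stale reference max has fallen too far behind.
            if m_run - m_acc > threshold:
                corr = np.exp2(m_acc - m_run)   # 0.0 on the first block (exp2 of -inf)
                acc *= corr
                l *= corr
                m_acc = m_run
            p = np.exp2(s - m_acc)              # bounded by 2**threshold = 256, no overflow
            l += p.sum()
            acc += p @ V[i:i + block]

        return acc / l   # the one final correction: normalize once at the very end

    # Sanity check against a plain softmax reference.
    rng = np.random.default_rng(0)
    q, K, V = rng.standard_normal(64), rng.standard_normal((1024, 64)), rng.standard_normal((1024, 32))
    s = (K @ q) / np.sqrt(64)
    ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
    print(np.abs(attention_row_deferred_rescale(q, K, V) - ref).max())   # agrees to float rounding

In exact arithmetic the skipped rescales cancel in the final division; the threshold just bounds how large the un-rescaled intermediates can grow (at most 2⁸ = 256 here), which is what keeps the floating-point error negligible.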

The best optimization isn't making something faster — it's realizing you don't need to do it at all.

3. Tensor memory (TMEM) + 2-CTA MMA — less shared memory traffic

Blackwell introduces a new type of on-chip memory called tensor memory (TMEM) — 256 KB per SM, specifically designed to hold intermediate tensor core results. Unlike shared memory, TMEM is directly wired to the tensor cores. Matmul outputs go straight to TMEM without passing through shared memory or registers.

FlashAttention-4 stores its attention score matrices (S, P) and accumulator (O) in TMEM instead of registers. This frees up register space (reducing the horrific "register pressure" that plagued Hopper kernels) and eliminates shared memory traffic for intermediate results.

For the backward pass, shared memory is the dominant bottleneck — 30% more cycles than matmul. FlashAttention-4 uses Blackwell's 2-CTA MMA mode: two CTAs (cooperative thread arrays) pair up to execute a single larger matrix multiply. Each CTA loads only half the operand, halving shared memory traffic. And because the dQ computation gets split across the pair, the number of global atomic reductions is also halved.
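For a rough sense of scale, here's a back-of-envelope sketch of what those intermediates occupy. The tile sizes (128 queries, 128 keys, head dim 128) and the FP32/BF16 choices are my assumptions for illustration, not the paper's exact configuration:

    # Back-of-envelope: do one tile's attention intermediates fit in TMEM?
    # Tile sizes and dtypes are illustrative assumptions, not the paper's exact config.
    TMEM_BYTES = 256 * 1024            # 256 KB of tensor memory per SM on Blackwell

    q_tile, kv_tile, head_dim = 128, 128, 128

    S_bytes = q_tile * kv_tile * 4     # attention scores S = Q @ K^T, FP32
    P_bytes = q_tile * kv_tile * 2     # softmax probabilities P, BF16
    O_bytes = q_tile * head_dim * 4    # output accumulator O, FP32

    used = S_bytes + P_bytes + O_bytes
    print(f"S={S_bytes // 1024} KB, P={P_bytes // 1024} KB, O={O_bytes // 1024} KB; "
          f"{used // 1024} KB of {TMEM_BYTES // 1024} KB TMEM")
    # Keeping S, P, and O here means they never occupy registers or bounce through
    # shared memory between the two matmuls of an attention step.

Under these assumed sizes the three tiles take 160 of the 256 KB, so they fit in TMEM with room to spare.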

4. Written entirely in Python — 30× faster to compile

Here's a fact that would make any CUDA developer's jaw drop. FlashAttention 1, 2, and 3 were all written in C++ with CUDA template metaprogramming. Hundreds of template instantiations. Compile times measured in minutes.

FlashAttention-4 is written entirely in CuTe-DSL, embedded in Python. Full GPU kernel expressivity — direct PTX access, warp specialization, tensor memory management — but in Python syntax with JIT compilation.

The result: compile time dropped from 55 seconds to 2.5 seconds for a single kernel. That's a 22× speedup in compile time. For the backward pass kernel: 45 seconds → 1.4 seconds (32×).

This isn't just a convenience. FlashAttention kernels need to be precompiled for hundreds of attention variants (different head dims, causal vs non-causal, different precisions). Cutting compile time by 30× means researchers can iterate on kernel designs in seconds instead of waiting through coffee-length compile cycles.

    FA-1   2022   ~120 TF/s   A100   IO-aware tiling                      solved: memory bandwidth
    FA-2   2023    225 TF/s   A100   work partitioning, seq parallelism   solved: GPU utilization
    FA-3   2024    740 TF/s   H100   warp specialization, FP8 + async     solved: sync overhead
    FA-4   2026  1,613 TF/s   B200   fake exp + TMEM, Python DSL          solved: non-matmul units

Fig 1 — The FlashAttention timeline: from 120 TFLOPs/s to 1,613 TFLOPs/s across four papers, four bottlenecks, four solutions; each unlocked the next ceiling.

Does it actually work?

  • Up to 1.3× faster than cuDNN 9.13 — NVIDIA's own optimized library
  • Up to 2.7× faster than Triton on B200
  • 1,613 TFLOPs/s in BF16 — 71% of theoretical max on B200
  • 20–30× faster compile times vs FlashAttention-3 (CuTe-DSL vs C++)
  • Deterministic backward pass with only ~25% performance overhead — crucial for RL training reproducibility

1,613 TFLOPs/s. Let me put that in context. FlashAttention-1, four years ago, did about 120. That's a 13× improvement. Yes, much of that is faster hardware: the B200's raw tensor core throughput is roughly 7× the A100's. But the rest came from software. FlashAttention-1 sustained under 40% of the A100's peak; FlashAttention-4 sustains 71% of the B200's, on a chip whose non-matmul units barely moved. That near-doubling of utilization came from four iterations of increasingly clever engineering that squeeze out every last drop of performance.

The surprise: NVIDIA adopted the techniques

Here's something remarkable buried in the appendix. After FlashAttention-4 was released, NVIDIA incorporated many of its techniques into cuDNN (their official library). Later cuDNN versions (9.14+) now perform similarly to FlashAttention-4 because they are FlashAttention-4, essentially.

An academic project didn't just beat the GPU vendor's optimized library — the vendor adopted the academic's techniques. That's the ultimate validation.

The evolution: looking back at all four papers

The FlashAttention series tells a beautiful story about the cat-and-mouse game between software and hardware:

  1. FA-1 (2022): The bottleneck was memory bandwidth. Solution: stop shuffling data between HBM and SRAM. Result: 2–4× speedup.
  2. FA-2 (2023): The bottleneck was GPU utilization. Solution: better work partitioning across threads. Result: 2× over FA-1.
  3. FA-3 (2024): The bottleneck was synchronous execution. Solution: overlap everything with warp specialization and pipelining. Result: 1.5–2× over FA-2.
  4. FA-4 (2026): The bottleneck was non-matmul units. Solution: fake exponentials, skip rescaling, use new memory types. Result: ~2× over FA-3.

Each paper solved one bottleneck and revealed the next. That's how performance engineering works: you're never done, you just find the next ceiling.

Why should you care?

  1. The age of "write once, run anywhere" GPU code is over. FlashAttention-4 is written specifically for Blackwell's architectural quirks — TMEM, 2-CTA MMA, asymmetric scaling. Code that ran great on H100 leaves performance on the table on B200.
  2. Python GPU kernels are real. CuTe-DSL proves you can write peak-performance GPU kernels in Python. The 30× compile time improvement isn't a toy — it fundamentally changes how fast researchers can iterate.
  3. The bottleneck will keep shifting. Today it's the exponential unit. Tomorrow it might be register bandwidth or L2 cache pressure. The lesson from four FlashAttention papers: the only constant is that something else will always be too slow.

The one-paragraph version

Blackwell GPUs doubled tensor core speed but left everything else the same — making shared memory and exponential operations the new bottleneck. FlashAttention-4 fixes this with software-emulated exponentials (polynomial approximation on FMA units), conditional softmax rescaling that skips unnecessary work, tensor memory to reduce shared memory traffic, and 2-CTA MMA to halve backward pass overhead. Written entirely in Python via CuTe-DSL (30× faster compile times), it reaches 1,613 TFLOPs/s on B200 — a 13× improvement over the original FlashAttention from 2022.

The napkin takeaway

The full FlashAttention saga in one metaphor — building a faster kitchen:

  • FA-1 = stopped walking to the pantry for every ingredient (tiling, kernel fusion)
  • FA-2 = gave every chef their own station instead of sharing one counter (work partitioning)
  • FA-3 = installed a conveyor belt so prep and cooking happen simultaneously (async pipelining)
  • FA-4 = the oven is now so fast that you need faster timers, bigger trays, and a shortcut for measuring spices (asymmetric scaling workarounds)

Four years. Four papers. One idea — make attention respect the hardware it runs on. From 120 TFLOPs/s to 1,613. The math never changed. The engineering changed everything.

Paper: "FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling" — Ted Zadouri, Markus Hoehnerbach, Jay Shah, Timmy Liu, Vijay Thakkar, Tri Dao. Princeton / Meta / Colfax / NVIDIA / Georgia Tech / Together AI. 2026.
