May 11, 2026 · paperjuice · 14 min read · 3100 words

FlashAttention — What If Your GPU Has Been Reading Memory Wrong This Whole Time?

paperjuice ml attention gpu-optimization flashattention-series

You just trained a transformer. It took 3 days and $4,000 in GPU time, and there's a good chance a big share of that time (maybe 40%) went not to math but to shuffling numbers between two types of memory on your GPU. You never even saw it happening.

A Stanford PhD student named Tri Dao knew. In 2022, he and his coauthors published a paper that made attention run dramatically faster on GPUs without changing a single thing about the math. Same inputs, same outputs: exact attention, not an approximation. Just… faster. And using way less memory.

This is the paper that started a revolution. Four versions in, FlashAttention sits inside every major ML framework. Let me squeeze the juice out of where it all began.

This is Part 1 of a 4-part series covering the entire FlashAttention evolution — from the original 2022 insight to the 2026 Blackwell-optimized FlashAttention-4.

The problem: attention treats your GPU like a single box

Standard attention does this: compute Q×Kᵀ to get scores (an N×N matrix, where N is the sequence length), apply softmax to each row, multiply by V. Simple math. The problem isn't the math; it's where the math happens.
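To make those three steps concrete, here is a minimal plain-Python sketch of standard attention for a single head. The function name `naive_attention` is hypothetical, and the usual 1/√d scaling of the scores is left out to match the description. Note that `S` and `P` are both full N×N lists of lists: that materialized matrix is exactly what FlashAttention avoids.

```python
import math

def naive_attention(Q, K, V):
    """Standard attention for one head on plain Python lists.

    Q is N x d, K is N x d, V is N x d_v. Builds the full N x N score
    matrix S and the full N x N softmax matrix P before touching V.
    (The usual 1/sqrt(d) scaling is omitted to match the text.)
    """
    # Step 1: S = Q x K^T, one dot product per (query, key) pair
    S = [[sum(qc * kc for qc, kc in zip(q, k)) for k in K] for q in Q]
    # Step 2: softmax along each row (subtract the row max for stability)
    P = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    # Step 3: O = P x V
    d_v = len(V[0])
    return [[sum(p_ij * v[c] for p_ij, v in zip(p_row, V)) for c in range(d_v)]
            for p_row in P]
```

If all scores in a row are equal, softmax is uniform and each output row is just the average of the rows of V, which makes a handy sanity check.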

Your GPU has two types of memory. HBM (High Bandwidth Memory) is the big one: 40–80 GB on an A100. It's where your tensors live. Then there's on-chip SRAM: tiny (192 KB per streaming multiprocessor, about 20 MB total across the A100's 108 SMs), but roughly 10x faster (around 19 TB/s of bandwidth versus 1.5–2.0 TB/s for HBM).
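A quick back-of-envelope calculation shows why the N×N score matrix ends up in HBM. This sketch assumes the A100 figures from the FlashAttention paper (108 SMs with 192 KB of SRAM each) and fp16 scores (2 bytes per element); the helper name `score_matrix_bytes` is hypothetical.

```python
# Can one N x N score matrix live in on-chip SRAM?
# Assumed A100 figures (from the FlashAttention paper): 108 SMs x 192 KB SRAM.
NUM_SMS = 108
SRAM_PER_SM_BYTES = 192 * 1024
SRAM_TOTAL_BYTES = NUM_SMS * SRAM_PER_SM_BYTES  # 21,233,664 bytes, about 20 MB

BYTES_PER_FP16 = 2

def score_matrix_bytes(n, bytes_per_elem=BYTES_PER_FP16):
    """Size of one N x N attention score matrix (one head, one sequence)."""
    return n * n * bytes_per_elem

for n in (512, 1024, 2048, 4096):
    size = score_matrix_bytes(n)
    print(f"N={n:5d}: {size / 2**20:6.1f} MiB per head, "
          f"{size / SRAM_TOTAL_BYTES:5.2f}x of total SRAM")
```

At N = 4096 a single head's matrix (32 MiB) is already larger than all the SRAM on the chip combined, and each SM can only see its own 192 KB slice, so even N = 512 (512 KiB) does not fit in one SM.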

Standard attention computes the full N×N score matrix, writes it to HBM, reads it back for softmax, writes the softmax result to HBM, reads it back again for the V multiply. That's like writing a letter, mailing it to yourself, waiting for the mailman, opening it, then writing another letter with your reply.
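To put a number on that write/read/write/read pattern, here is a rough traffic count for one head. It is a hypothetical simple model, not a measurement: each step reads its inputs from HBM and writes its output back exactly once, masking and dropout are ignored, and elements are fp16.

```python
def standard_attention_hbm_elems(n, d):
    """Element-level HBM traffic for one head of standard attention under a
    hypothetical simple model: each step reads its inputs from HBM and writes
    its output back, exactly once, with no caching."""
    read_qkv = 3 * n * d   # Q and K for the scores, V for the final multiply
    write_s = n * n        # scores S = Q x K^T go out to HBM...
    read_s = n * n         # ...and come back for softmax
    write_p = n * n        # softmax result P goes out...
    read_p = n * n         # ...and comes back for the V multiply
    write_o = n * d        # final output O
    return read_qkv + write_s + read_s + write_p + read_p + write_o

def hbm_bytes(n, d, bytes_per_elem=2):
    """Same count in bytes, assuming fp16 (2 bytes per element)."""
    return standard_attention_hbm_elems(n, d) * bytes_per_elem

def nxn_share(n, d):
    """Fraction of that traffic spent on the N x N intermediates S and P."""
    return 4 * n * n / standard_attention_hbm_elems(n, d)

print(f"{hbm_bytes(4096, 64) / 2**20:.0f} MiB per head, "
      f"{nxn_share(4096, 64):.1%} of it on S and P")
```

For N = 4096 and d = 64 this model gives about 130 MiB of traffic per head for one forward pass, and roughly 98% of it is the N×N intermediates going out and coming back, not the actual inputs and outputs.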

You live in the same house. Just walk to the other room.

The bottleneck in attention isn't computation. It's the constant round-trips between fast memory and slow memory. FlashAttention cuts out most of those trips.

FlashAttention's big idea: tiling + IO-awareness

The core insight is almost embarrassingly simple: don't materialize the N×N attention matrix. Instead, compute attention in small tiles that fit entirely in SRAM, and never write the intermediate results to slow HBM.

Three ideas make this work:

1. Tiling — chunking the unchunkable

Matrix multiplication is easy to tile. You can multiply blocks of A with blocks of B and combine the results. But attention has softmax in the middle — and softmax needs to see all the scores in a row to compute the denominator. You can't just split it into blocks… right?

Wrong. This is where the "online softmax" trick comes in. Instead of computing softmax over the entire row at once, you do it incrementally: process one block at a time, keeping a running maximum and a running sum of exponentials. Whenever a new block brings a larger maximum, you rescale the running sum by e^(old max − new max) so all the terms stay on the same scale. After processing all blocks, you have exactly the same result as if you'd computed it all at once.

Think of it like calculating your class average. You don't need everyone's grade at once. You can keep a running total and a count, updating as each student turns in their exam. Same final answer, but you never need the full list in front of you.
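Here is the running-max, running-sum idea as a minimal Python sketch for a single row of scores. The names `online_softmax` and `direct_softmax` are hypothetical. Only two numbers, `m` and `l`, survive from one block to the next; the rescaling line is the one step the class-average analogy glosses over.

```python
import math

def online_softmax(scores, block_size=2):
    """Exact softmax of one row, computed by streaming over blocks.

    Only two numbers are carried between blocks: the running max m and the
    running sum l of exp(s - m). When a block raises the max, the old sum is
    rescaled by exp(m_old - m_new) so every term is relative to the same max.
    """
    m = float("-inf")  # running maximum (exp(-inf) == 0.0 on the first block)
    l = 0.0            # running sum of exponentials, relative to m
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    # With the final (m, l) in hand, each probability is exp(s - m) / l.
    return [math.exp(s - m) / l for s in scores]

def direct_softmax(scores):
    """Reference: the usual all-at-once softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

This sketch still makes a second pass over the scores to turn `(m, l)` into probabilities. FlashAttention avoids even that by accumulating the output against V alongside `m` and `l`, so the normalized scores never need to exist as a full row.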

2. IO-awareness — respecting the memory hierarchy

Most algorithms are designed to minimize FLOPs — the number of arithmetic operations. FlashAttention is designed to minimize IO — the number of bytes shuffled between HBM and SRAM.

This is the heretical part. FlashAttention actually does more total computation than standard attention: instead of storing the N×N attention matrix for the backward pass, it keeps only the output and the per-row softmax statistics, then recomputes the attention blocks on the fly during backprop. But it does far fewer memory accesses. And for attention on modern GPUs, memory access is the bottleneck, not arithmetic.

It's like choosing to drive 5 extra miles on the highway instead of sitting in 2 miles of stop-and-go traffic. More total distance, less total time.

FlashAttention's key heresy: doing more math can be faster than doing less, if the extra math avoids slow memory access.
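The recomputation trick can be sketched in a few lines. This is a simplified illustration, not the paper's kernel: it folds the two per-row statistics (max and sum) into a single logsumexp value L = m + log(l), the way FlashAttention-2 does, whereas the 2022 paper stores m and l separately. The helpers `dot`, `row_logsumexp`, and `recompute_probs` are hypothetical names.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def row_logsumexp(q, K):
    """Forward pass: compute one row's scores but save only a single float,
    L = m + log(sum(exp(s - m))), instead of the whole softmax row."""
    scores = [dot(q, k) for k in K]
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))

def recompute_probs(q, K, lse):
    """Backward pass: rebuild the softmax row from q, K and the saved L.
    exp(s - L) = exp(s - m) / sum(exp(s' - m)), i.e. exactly softmax(s)."""
    return [math.exp(dot(q, k) - lse) for k in K]

# Storage per head for N = 4096: the full P matrix vs one float per row.
N = 4096
print(f"store P: {N * N:,} floats; store L: {N:,} floats")
```

The extra cost is redoing the q·k dot products in the backward pass; the saving is never writing, or reading back, an N×N matrix.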

3. Kernel fusion — one trip, not five

In standard attention, every operation (the Q×Kᵀ matmul, masking, softmax, dropout, and the final multiply by V) is a separate GPU "kernel" (a function launched on the GPU). Each kernel reads from HBM, does its thing, and writes back to HBM. Five operations means five round-trips.

FlashAttention fuses everything into a single kernel. Load Q, K, V blocks into SRAM once. Compute scores, apply mask, do softmax, apply dropout, multiply by V — all in SRAM. Write the final output to HBM once. One trip instead of five.
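Putting tiling and the online softmax together gives the shape of the fused computation. Below is a minimal plain-Python sketch, assuming no masking, no dropout, and no 1/√d scaling; `flash_attention_forward` is a hypothetical name, and it processes one query row at a time for readability where a real kernel handles a block of query rows in SRAM. The loop order (queries outer, key/value blocks inner) is also a readability choice and matches FlashAttention-2; the 2022 paper's algorithm puts the K/V blocks in the outer loop.

```python
import math

def flash_attention_forward(Q, K, V, block_size=2):
    """Tiled, exact attention that never builds the N x N matrix.

    For each query row we stream over key/value blocks, keeping a running
    max m, a running sum l, and an unnormalized output accumulator acc.
    At most block_size scores exist at any moment.
    """
    n_keys, d_v = len(K), len(V[0])
    O = []
    for q in Q:
        m = float("-inf")
        l = 0.0
        acc = [0.0] * d_v
        for start in range(0, n_keys, block_size):
            K_blk = K[start:start + block_size]
            V_blk = V[start:start + block_size]
            # Scores for this block only.
            s_blk = [sum(a * b for a, b in zip(q, k)) for k in K_blk]
            m_new = max(m, max(s_blk))
            scale = math.exp(m - m_new)  # rescale old state to the new max
            p_blk = [math.exp(s - m_new) for s in s_blk]
            l = l * scale + sum(p_blk)
            # Rescale the old accumulator, then add this block's p x V.
            acc = [a * scale + sum(p * v[c] for p, v in zip(p_blk, V_blk))
                   for c, a in enumerate(acc)]
            m = m_new
        O.append([a / l for a in acc])  # normalize once at the very end
    return O
```

Because every intermediate is rescaled to the latest running max, the final division by `l` yields the same output as materializing the full softmax matrix, for any block size.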

Fig 1: Standard attention vs FlashAttention memory access pattern. Standard attention writes and reads the Q×Kᵀ scores and the softmax result through HBM, 5 HBM round-trips per attention layer; FlashAttention fuses score + mask + softmax + dropout + V multiply in SRAM and makes 1. Same result.

Does it actually work?

The numbers don't whisper. They shout.

  • 15% faster training on BERT-large (seq length 512), beating the MLPerf 1.1 speed record
  • 3× speedup on GPT-2 (seq length 1K) vs the HuggingFace baseline (about 1.7× vs Megatron-LM)
  • 2.4× speedup on Long-Range Arena (seq length 1K–4K)
  • Memory: O(N) instead of O(N²) — linear in sequence length, not quadratic
  • 5–20× fewer HBM accesses than standard attention
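The O(N) vs O(N²) line is easiest to feel with actual sequence lengths. This sketch assumes an fp16 attention matrix and fp32 per-row statistics (a max and a sum per row), both per head and per sequence; the helper names are hypothetical.

```python
def attention_matrix_bytes(n, bytes_per_elem=2):
    """Bytes to materialize one N x N attention matrix (one head), fp16 assumed."""
    return n * n * bytes_per_elem

def row_stats_bytes(n, bytes_per_elem=4):
    """Bytes for per-row softmax statistics (a max and a sum), fp32 assumed."""
    return 2 * n * bytes_per_elem

for n in (1024, 16 * 1024, 64 * 1024):
    print(f"N={n:>6}: N x N matrix {attention_matrix_bytes(n) / 2**30:.3f} GiB, "
          f"row stats {row_stats_bytes(n) / 2**10:.0f} KiB")
```

At N = 64K the quadratic matrix alone is 8 GiB for a single head of a single sequence, before multiplying by the number of heads and the batch size, while the per-row statistics stay at 512 KiB.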

But the most stunning result? FlashAttention enabled sequence lengths that were previously out of reach. On the Path-X challenge (seq length 16K), no transformer had ever done better than chance. FlashAttention got 61.4% accuracy. On Path-256 (seq length 64K), its block-sparse variant got 63.1%. These were the first transformers ever to beat chance on these benchmarks.

The surprise: faster AND better models

Here's what I didn't expect. FlashAttention doesn't just make the same models faster — it makes better models possible.

Because it reduces memory from O(N²) to O(N), you can train with much longer context windows. Longer context means the model sees more of the document, more of the conversation, more of the codebase. GPT-2 with FlashAttention reached 0.7 lower (better) perplexity just from being able to use longer sequences. Long-document classification improved by 6.4 points.

It's like upgrading from reading one page at a time to reading entire chapters. Of course you understand the book better.

Why should you care?

  1. Every major framework uses this now. PyTorch, JAX, Hugging Face — they all ship FlashAttention. If you've trained a model recently, you've probably already benefited from this paper without knowing it.
  2. Sequence length is no longer a hard wall. Before FlashAttention, 2K tokens was painful and 16K was out of reach for most setups. This paper is a big part of why you can now chat with models that remember your entire conversation.
  3. The insight generalizes beyond attention. "Respect the memory hierarchy" is a principle that applies to any GPU-bound computation. If your code is slow, the bottleneck might not be math — it might be memory traffic.
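One way to check whether your own code is limited by memory traffic is the arithmetic-intensity test the FlashAttention paper uses in its background: compare FLOPs per byte moved against the GPU's ratio of peak compute to memory bandwidth. This sketch plugs in assumed A100 numbers (about 312 TFLOP/s dense FP16 tensor-core peak, 2.0 TB/s HBM) and rough, assumed FLOP counts; treat the outputs as ballpark classifications, not benchmarks.

```python
# Roofline-style check: is an operation limited by memory or by math?
# Assumed A100 numbers (inputs for this sketch, not measurements):
#   ~312 TFLOP/s peak dense FP16 tensor-core throughput
#   ~2.0 TB/s HBM bandwidth (top of the 1.5-2.0 TB/s range)
PEAK_FLOPS = 312e12
HBM_BYTES_PER_S = 2.0e12
RIDGE = PEAK_FLOPS / HBM_BYTES_PER_S  # FLOPs per byte needed to keep the math units busy

def arithmetic_intensity(flops, bytes_moved):
    """FLOPs performed per byte moved to or from HBM."""
    return flops / bytes_moved

def is_memory_bound(flops, bytes_moved):
    return arithmetic_intensity(flops, bytes_moved) < RIDGE

n, d = 4096, 64
# Row-wise softmax over an N x N fp16 matrix: assume ~5 FLOPs per element
# (max, subtract, exp, sum, divide), 2 bytes read + 2 bytes written per element.
softmax = (5 * n * n, 4 * n * n)
# Q x K^T in fp16 with head dimension d: 2*N*N*d FLOPs; read Q and K, write S.
qk_matmul = (2 * n * n * d, 2 * (2 * n * d) + 2 * n * n)
# A large square fp16 matmul (N x N times N x N) for contrast.
big_matmul = (2 * n ** 3, 3 * 2 * n * n)

for name, (flops, nbytes) in [("softmax", softmax), ("Q x K^T", qk_matmul),
                              ("big matmul", big_matmul)]:
    kind = "memory-bound" if is_memory_bound(flops, nbytes) else "compute-bound"
    print(f"{name:>10}: {arithmetic_intensity(flops, nbytes):7.1f} FLOPs/byte -> {kind}")
```

Under these assumptions the softmax sits near 1 FLOP per byte, far below the ridge of 156, and even Q×Kᵀ with d = 64 lands below it because writing the N×N result dominates the bytes. Only the big square matmul does enough math per byte to be compute-bound.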

The one-paragraph version

Standard attention wastes most of its time shuffling data between fast and slow GPU memory. FlashAttention fixes this by computing attention in small tiles that fit entirely in fast SRAM, using an online softmax trick to get exact results without ever materializing the full N×N attention matrix. The result: up to 3× faster training, linear memory instead of quadratic, and the ability to handle sequence lengths (16K–64K) that were previously out of reach for transformers.

The napkin takeaway

If attention is a library:

  • Standard attention = carrying one book at a time between the shelf and your desk, making a trip for every single page
  • Approximate attention methods = reading the CliffsNotes instead of the actual books (faster, but you miss things)
  • FlashAttention = bringing a small stack of books to your desk, reading them all there, and only making one trip back to the shelf when you're done — same books, same words, zero shortcuts

Same library. Same books. Wildly fewer trips.

Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" — Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré. Stanford / University at Buffalo. NeurIPS 2022.

Next up: FlashAttention-2 — same guy, working alone, doubles the speed by fixing how GPUs split work between threads.

© cvam — written in plaintext, served warm