In 2022, Tri Dao and four co-authors published FlashAttention, and it changed how every transformer on earth runs attention. One year later, Tri Dao published the sequel.
Alone. Single author. No co-authors.
And he doubled the speed.
The paper is called FlashAttention-2, and it's a masterclass in a specific kind of engineering: the algorithm didn't change much. The math is the same. What changed is how work is split between GPU threads. It turns out FlashAttention-1 was leaving half the GPU's performance on the table — not because of the algorithm, but because of how it scheduled work on the hardware.
Part 2 of the FlashAttention evolution series. Read Part 1 first if you haven't — this one builds directly on it.
The problem: FlashAttention-1 only used 25–40% of the GPU
FlashAttention was a breakthrough in reducing memory access. But here's the thing about GPUs: they're not one processor. An A100 has 108 streaming multiprocessors (SMs), each running dozens of warps (groups of 32 threads). Getting all of them to do useful work at the same time is an art.
FlashAttention-1 reached only 25–40% of the A100's theoretical maximum FLOPs/s. For context, a well-optimized matrix multiplication (GEMM) hits 80–90%. That gap is enormous — it's like having a 10-lane highway where only 3 lanes are open.
The bottleneck wasn't the tiling algorithm or the online softmax. It was two mundane-sounding problems: work partitioning and non-matmul FLOPs.
The big idea: make attention as efficient as matrix multiply
FlashAttention-2 asks: why can't attention be as fast as GEMM? Both are fundamentally matrix operations. GEMM gets 80–90% utilization. Attention was stuck at 25–40%. The difference is scheduling, not math.
Three changes close the gap:
1. Reduce non-matmul FLOPs
GPU tensor cores are absurdly fast at matrix multiplication — but they can only do matrix multiplication. Everything else (rescaling, exponentiation, comparisons for the online softmax) runs on slower, general-purpose cores.
In FlashAttention-1, these non-matmul operations took a surprising amount of time. It's like having a Formula 1 car but stopping at every intersection to check a paper map. The engine is fast. The navigation is slow.
FlashAttention-2 tweaks the algorithm to minimize rescaling operations. In the online softmax, every time you process a new block, you might need to rescale all previous results because a new maximum was found. FlashAttention-2 restructures the computation: instead of dividing the running output by the softmax denominator after every block, it keeps an unnormalized accumulator and performs a single division at the very end, and it stores one logsumexp per row for the backward pass instead of both the max and the sum.
On modern GPUs, the ratio of matmul to non-matmul speed is roughly 16:1 (an A100 delivers 312 TFLOPs/s of FP16 tensor-core matmul but only 19.5 TFLOPs/s of general-purpose FP32 math). Every non-matmul FLOP costs 16× as much wall-clock time. FlashAttention-2 obsesses over eliminating them.
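To make that concrete, here's a minimal NumPy sketch of the online softmax for a single query row (the real kernel processes tiles of queries in CUDA; the function name and block size are illustrative). The FA-2 trick is on the accumulator line: the output stays unnormalized inside the loop, and the division by the denominator happens exactly once, at the end.

```python
import numpy as np

def attention_row_fa2(q, K, V, block=4):
    """Online softmax for ONE query row, FA-2 style: the output
    accumulator stays un-normalized, and the divide by the softmax
    denominator happens once at the end. (FA-1 kept the accumulator
    normalized, paying an extra rescale on every block.)"""
    d = q.shape[0]
    m = -np.inf               # running max of scores seen so far
    l = 0.0                   # running softmax denominator
    o = np.zeros(V.shape[1])  # running UN-normalized output
    for i in range(0, K.shape[0], block):
        s = K[i:i+block] @ q / np.sqrt(d)  # scores for this block
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)           # correction when the max moves
        p = np.exp(s - m_new)
        l = corr * l + p.sum()
        o = corr * o + p @ V[i:i+block]    # note: no divide-by-l here
        m = m_new
    return o / l                           # the only normalization

# Sanity check against plain softmax(q @ K.T / sqrt(d)) @ V
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=(8,)), rng.normal(size=(16, 8)), rng.normal(size=(16, 8))
s = K @ q / np.sqrt(8)
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
assert np.allclose(attention_row_fa2(q, K, V), ref)
```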
2. Parallelize across sequence length, not just batch × heads
FlashAttention-1 parallelized work across the batch size and number of attention heads. One thread block per (batch, head) pair. Simple, clean, and wasteful when batch sizes are small.
Imagine you have 8 batches × 12 heads = 96 thread blocks, but 108 SMs. Twelve SMs are sitting idle. And with long sequences (where FlashAttention matters most), batch sizes tend to be small because sequences eat so much memory.
FlashAttention-2 adds parallelism along the sequence length dimension. It splits the query sequence into blocks and runs them on different thread blocks. Now even a single batch with a single head can saturate all 108 SMs.
Think of it like a restaurant. FlashAttention-1 assigns one waiter per table. If you only have 3 tables, most of your waitstaff is standing around. FlashAttention-2 says: each table is big enough that multiple waiters can serve different sections simultaneously.
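The arithmetic behind that idle-SM example is worth spelling out. A rough sketch of the two grid sizes (the 128-row query tile is an assumed block size, not a number from the paper):

```python
# Thread-block counts under the two schedules (illustrative numbers).
num_sms = 108                 # streaming multiprocessors on an A100
batch, heads = 8, 12
seq_len, q_block = 8192, 128  # q_block: assumed query-tile size

blocks_fa1 = batch * heads                         # one block per (batch, head)
blocks_fa2 = batch * heads * (seq_len // q_block)  # also split the query rows

print(blocks_fa1)  # 96   -> 12 of the 108 SMs sit idle
print(blocks_fa2)  # 6144 -> many full waves of work for every SM
```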
3. Fix the warp-level work partitioning
This is the most subtle change, and arguably the most impactful.
Inside each thread block, there are multiple warps (groups of 32 threads). FlashAttention-1 split K and V across warps and shared Q (the "split-K" scheme). This meant every warp computed a partial result over its own slice of the keys, and all the warps had to communicate through shared memory to combine those partials for the online softmax.
FlashAttention-2 flips this: split Q across warps, and let each warp see the full K and V blocks. Now each warp independently owns its own output rows of the attention, and there's no need for communication between warps to merge partial results. Less communication = less waiting = faster.
It's the difference between four chefs sharing one cutting board versus each chef having their own. Same kitchen, same recipe, dramatically less elbow-bumping.
Fig 1 — FA-1 splits K/V across warps (requires a cross-warp reduction). FA-2 splits Q (each warp finishes its output rows independently).
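Here's a small NumPy demonstration of why the split matters; four plain Python loop iterations stand in for four warps, so this is a sketch of the data flow, not actual warp-level code. Splitting Q lets each "warp" finish its own output rows outright, while splitting K/V leaves each one holding partial softmax statistics that must be merged.

```python
import numpy as np

rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))   # 4 query rows, head dim 8
K = rng.normal(size=(16, 8))
V = rng.normal(size=(16, 8))

def softmax_rows(S):
    e = np.exp(S - S.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

ref = softmax_rows(Q @ K.T) @ V   # reference attention (unscaled, for brevity)

# FA-2 "split Q": each of 4 pretend warps owns one query row.
# Every warp's result is already final; nothing to combine.
split_q = np.concatenate([softmax_rows(Q[i:i+1] @ K.T) @ V for i in range(4)])
assert np.allclose(split_q, ref)

# FA-1 "split K/V": each warp sees all of Q but only 4 keys/values,
# so its softmax statistics are partial and must be merged afterwards.
parts = []
for w in range(4):
    Kw, Vw = K[w*4:(w+1)*4], V[w*4:(w+1)*4]
    S = Q @ Kw.T
    m = S.max(axis=1, keepdims=True)   # per-warp max
    p = np.exp(S - m)
    parts.append((m, p.sum(axis=1, keepdims=True), p @ Vw))

# Three cross-warp reductions: exactly the communication FA-2 avoids.
m_all = np.max([m for m, _, _ in parts], axis=0)
l_all = sum(np.exp(m - m_all) * l for m, l, _ in parts)
o_all = sum(np.exp(m - m_all) * o for m, _, o in parts)
assert np.allclose(o_all / l_all, ref)
```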
Does it actually work?
Yes. Emphatically.
- Roughly 2× faster than FlashAttention-1 — across sequence lengths and head dimensions
- 50–73% of theoretical maximum FLOPs/s on A100 (up from 25–40%)
- 225 TFLOPs/s per A100 GPU for GPT-style model training (72% model FLOPs utilization)
- Approaching GEMM efficiency — attention is finally almost as fast as pure matrix multiply
Read that third number again. 225 TFLOPs/s. That's not a microbenchmark — that's end-to-end training of a GPT-style model, and 225 out of the A100's 312 peak FP16 TFLOPs/s is exactly the quoted 72% utilization. For reference, FlashAttention-1 on the same hardware hit about 120–150 TFLOPs/s.
The surprise: the algorithm barely changed
Here's the thing that blew my mind. If you read both papers back-to-back, the core algorithm is almost identical. Same tiling. Same online softmax. Same kernel fusion. The difference is entirely in how the work is mapped to hardware.
This is a profound lesson. In GPU computing, having the right algorithm is maybe half the battle. The other half is understanding thread blocks, warps, shared memory banks, register pressure, occupancy — all the gritty hardware details that turn a theoretically fast algorithm into an actually fast one.
It's like having the world's best recipe but cooking it in a disorganized kitchen. FlashAttention-2 didn't change the recipe. It reorganized the kitchen.
Why should you care?
- This is the version you're actually using. When you call `F.scaled_dot_product_attention()` in PyTorch, it dispatches to a FlashAttention-2 kernel under the hood (on supported GPUs and dtypes; see the snippet after this list). The original FlashAttention was the proof of concept. This is the production version.
- Single-author papers can move the world. Tri Dao, alone, doubled the speed of the most critical computation in all of modern AI. Sometimes deep hardware understanding from one person beats a large team.
- Utilization is everything. Your A100 costs $2/hour whether you use 25% of it or 73% of it. FlashAttention-2 nearly tripled the bang-per-buck of attention computation.
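A quick way to confirm you're getting the FlashAttention kernel rather than a silent fallback (PyTorch 2.3+; the tensor shapes are illustrative and a CUDA GPU is required):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch 2.3+

# (batch, heads, seq_len, head_dim) in fp16 on GPU: shapes the flash
# backend supports; these particular sizes are just for illustration.
q, k, v = (torch.randn(1, 12, 4096, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Pin the backend: this raises an error if FlashAttention can't run
# (wrong dtype, unsupported head_dim, ...) instead of silently falling
# back to the slower math or mem-efficient kernels.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```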
The one-paragraph version
FlashAttention-2 doesn't change the core algorithm — it changes how work is distributed across GPU threads. By reducing non-matmul operations, parallelizing across the sequence length dimension (not just batch × heads), and flipping warp partitioning from splitting K/V to splitting Q, it doubles the speed of FlashAttention-1, reaching 50–73% of the A100's theoretical maximum. A single-author paper that made attention nearly as efficient as raw matrix multiply.
The napkin takeaway
If FlashAttention-1 was about which road to take, FlashAttention-2 is about how to drive:
- FlashAttention-1 = found the fast route (tiling, online softmax, kernel fusion)
- FlashAttention-2 = learned to use all lanes, shift at the right RPM, and stop braking unnecessarily
- Result = same car, same road, double the average speed
The best optimization isn't always a new algorithm. Sometimes it's just learning to use the hardware you already have.
Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" — Tri Dao. Princeton/Stanford. 2023.
Next up: FlashAttention-3 — NVIDIA's new H100 GPU arrives, and FlashAttention learns asynchrony and FP8.