Part 9
Completed
FlashAttention Comparison
What I Built
Implemented FlashAttention and FlashAttention-2 from scratch, working through the tiling, online softmax rescaling, and memory access patterns that make them IO-aware. Benchmarked both against standard attention across sequence lengths, batch sizes, and hardware configurations.
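To make the tiling and softmax rescaling concrete, here is a minimal single-head sketch in pure Python, assuming no masking and no dropout. It is not the GPU kernel described here, and the names `naive_attention`, `tiled_attention`, `block_q`, and `block_k` are hypothetical. The tiled version never holds more than a `block_q x block_k` tile of scores; each query row keeps a running max, a running denominator, and an unnormalized output that are rescaled whenever a new tile raises the max. The loop order (query blocks outside, key/value blocks inside, one normalization at the end) follows the FlashAttention-2 arrangement; FlashAttention-1 puts the key/value blocks in the outer loop.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference softmax(Q K^T / sqrt(d)) V that builds every score row in full."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_q=4, block_k=4):
    """Tiled forward pass with online softmax: only a block_q x block_k tile
    of scores exists at any time (block sizes here are illustrative)."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for qs in range(0, len(Q), block_q):          # outer loop: query blocks
        q_blk = Q[qs:qs + block_q]
        row_max = [-math.inf] * len(q_blk)        # running max per query row
        row_sum = [0.0] * len(q_blk)              # running softmax denominator
        acc = [[0.0] * dv for _ in q_blk]         # unnormalized output rows
        for ks in range(0, len(K), block_k):      # inner loop: key/value blocks
            k_blk, v_blk = K[ks:ks + block_k], V[ks:ks + block_k]
            for i, q in enumerate(q_blk):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
                m_new = max(row_max[i], max(s))
                alpha = math.exp(row_max[i] - m_new)  # rescale old partial sums
                p = [math.exp(sj - m_new) for sj in s]
                row_sum[i] = alpha * row_sum[i] + sum(p)
                acc[i] = [alpha * acc[i][c] + sum(pj * v[c] for pj, v in zip(p, v_blk))
                          for c in range(dv)]
                row_max[i] = m_new
        # Normalize once per query block, after all key/value tiles are seen
        out.extend([a / row_sum[i] for a in acc[i]] for i in range(len(q_blk)))
    return out


random.seed(0)
N, D = 10, 8
Q, K, V = ([[random.gauss(0, 1) for _ in range(D)] for _ in range(N)] for _ in range(3))
ref = naive_attention(Q, K, V)
got = tiled_attention(Q, K, V, block_q=4, block_k=3)  # ragged last tiles on purpose
err = max(abs(a - b) for r, g in zip(ref, got) for a, b in zip(r, g))
print(f"max abs error: {err:.2e}")
```

The two functions agree up to floating-point rounding; only the order of summation differs. A real kernel computes each tile as a small matrix multiply held in on-chip SRAM instead of looping row by row.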
Key Concepts
- FlashAttention
- Tiling
- Softmax Rescaling
- IO-Awareness
- Memory Access Patterns
- Kernel Fusion
Architecture
1. Tiled Attention Kernel
2. Softmax Online Normalizer
3. Memory Scheduler
4. Benchmark Suite
Results
2.4x speedup and 10x memory reduction over standard attention at sequence length 8k. The speedup grows with sequence length, reaching 7.8x at 32k.
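One reason the gap widens with sequence length is that standard attention stores an N x N score matrix, which grows quadratically, while a tiled kernel only needs one tile at a time. The arithmetic below is a rough illustration under stated assumptions: fp16 scores (2 bytes), a hypothetical 128 x 128 tile, and only the score matrix counted (not Q, K, V, outputs, or the softmax probabilities). It is not a reconstruction of the measured 10x figure, which depends on the full benchmark setup; `score_matrix_bytes` and `tile_bytes` are illustrative helpers.

```python
MIB = 1024 ** 2


def score_matrix_bytes(seq_len, batch=1, heads=1, bytes_per_elem=2):
    """Bytes needed to materialize the full seq_len x seq_len score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def tile_bytes(block_q, block_k, bytes_per_elem=2):
    """Bytes for a single block_q x block_k score tile."""
    return block_q * block_k * bytes_per_elem


for n in (8192, 32768):
    print(f"N={n}: full score matrix = {score_matrix_bytes(n) / MIB:.0f} MiB per head")
print(f"128x128 tile = {tile_bytes(128, 128) / 1024:.0f} KiB")
# N=8192: full score matrix = 128 MiB per head
# N=32768: full score matrix = 2048 MiB per head
# 128x128 tile = 32 KiB
```

Quadrupling the sequence length multiplies the full matrix by 16, while the tile size stays fixed.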
Key Learnings
- Memory access patterns matter more than FLOPs
- Online softmax normalization is the key algorithmic insight
- Hardware-aware design beats naive algorithmic improvements
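The online normalizer can be shown in isolation. The sketch below (hypothetical helper names, plain Python floats) tracks a running max `m` and denominator `l = sum(exp(x - m))` in a single pass, and merges the states of two disjoint chunks by rescaling each to the larger max. The merge step is the same rescaling a tiled kernel applies when it moves to the next key/value tile. Subtracting the max also keeps `exp` from overflowing on large scores.

```python
import math


def online_softmax_stats(xs):
    """Single pass: running max m and normalizer l = sum(exp(x - m))."""
    m, l = -math.inf, 0.0
    for x in xs:
        m_new = max(m, x)
        l = l * math.exp(m - m_new) + math.exp(x - m_new)  # rescale, then add
        m = m_new
    return m, l


def merge_stats(a, b):
    """Combine (m, l) states computed independently on two disjoint chunks."""
    (m1, l1), (m2, l2) = a, b
    m = max(m1, m2)
    return m, l1 * math.exp(m1 - m) + l2 * math.exp(m2 - m)


def softmax_from_stats(xs, m, l):
    return [math.exp(x - m) / l for x in xs]


xs = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
m, l = online_softmax_stats(xs)
merged = merge_stats(online_softmax_stats(xs[:3]), online_softmax_stats(xs[3:]))
probs = softmax_from_stats(xs, m, l)

big = [1000.0, 1001.0, 999.0]  # math.exp(1000.0) alone would raise OverflowError
big_probs = softmax_from_stats(big, *online_softmax_stats(big))
```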
Challenges
- Implementing efficient tiling without materializing full attention matrix
- Handling edge cases in block sizes
- Profiling memory bandwidth vs. compute utilization
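For the block-size edge cases, a common pattern is to cover the sequence with half-open tile ranges where the last tile may be short, and, when a kernel always processes full-width tiles, to pad the missing key slots with `-inf` so they contribute `exp(-inf) = 0` to the softmax. The helpers below (`tile_ranges`, `pad_scores`) are a hypothetical sketch of that idea, not the project's scheduler.

```python
import math


def tile_ranges(n, block):
    """Half-open (start, end) ranges covering [0, n); the last one may be short."""
    if block <= 0:
        raise ValueError("block must be positive")
    return [(s, min(s + block, n)) for s in range(0, n, block)]


def pad_scores(row, block):
    """Pad a ragged score row to full tile width with -inf (zero softmax weight)."""
    if len(row) > block:
        raise ValueError("row is wider than the tile")
    return row + [-math.inf] * (block - len(row))


ranges = tile_ranges(10, 4)
padded = pad_scores([1.0, 2.0], 4)
```

A ragged tile still contains at least one real score, so its max stays finite. A row whose scores are all `-inf` (for example, fully masked by a causal mask) is different: the rescaling factor becomes `exp(-inf - (-inf))`, which is NaN, so the kernel needs an explicit guard for that case.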