Advanced ⏱️ 5 min

🎓 What is Flash Attention?

Memory-efficient attention algorithm that enables longer context and faster inference

What is Flash Attention?

Flash Attention is an optimized attention algorithm that computes exact attention with O(N) memory instead of O(N²), enabling longer context lengths and 2-4x faster training and inference on modern GPUs.

The Attention Bottleneck

Standard attention has two major issues:

  1. Memory: Stores the full N×N attention matrix (quadratic memory)
  2. Speed: Memory bandwidth becomes the bottleneck, not compute

For a 32K-context model, the attention matrix alone would require about 4 GB of memory per layer (single head, fp32 scores)!
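That figure is easy to verify with back-of-the-envelope arithmetic (assuming a single head, batch size 1, and 4-byte fp32 scores):

```python
# Memory needed to materialize the full N x N attention score matrix.
seq_len = 32 * 1024          # 32K context
bytes_per_score = 4          # fp32
matrix_bytes = seq_len * seq_len * bytes_per_score
print(matrix_bytes / 2**30)  # 4.0 GiB per layer (single head, batch 1)
```

In fp16 the matrix halves to 2 GiB, but it still grows quadratically with context length, which is why the algorithmic fix matters more than the data type.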

How Flash Attention Works

Flash Attention uses tiling and recomputation:

  1. Tiling: Process attention in blocks that fit in GPU SRAM
  2. Kernel Fusion: Combine operations to minimize memory transfers
  3. Recomputation: Recompute values during backward pass instead of storing

This achieves the same mathematical result with dramatically less memory.
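The tiling and online-softmax idea can be sketched in plain Python for a single query vector. This is an illustrative toy, not the fused CUDA kernel, and the function names are made up for this sketch:

```python
import math

def naive_attention(q, k, v):
    """Reference: softmax(q . K^T) @ V, materializing all scores."""
    scores = [sum(qi * ki for qi, ki in zip(q, kj)) for kj in k]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    d = len(v[0])
    return [sum(exps[j] * v[j][i] for j in range(len(v))) / z
            for i in range(d)]

def flash_attention_1query(q, k, v, block=2):
    """Process K/V in blocks, keeping a running max (m), running
    normalizer (z), and unnormalized output (o) -- the 'online
    softmax' trick. Only one block of scores exists at a time."""
    d = len(v[0])
    m, z, o = float("-inf"), 0.0, [0.0] * d
    for start in range(0, len(k), block):
        kb, vb = k[start:start + block], v[start:start + block]
        scores = [sum(qi * ki for qi, ki in zip(q, kj)) for kj in kb]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)      # rescale old accumulators
        z *= scale
        o = [oi * scale for oi in o]
        for s, vj in zip(scores, vb):
            w = math.exp(s - m_new)
            z += w
            o = [oi + w * vji for oi, vji in zip(o, vj)]
        m = m_new
    return [oi / z for oi in o]
```

Running both on the same inputs gives numerically identical results, which is the point: Flash Attention changes the memory access pattern, not the math.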

Performance Impact

| Metric | Standard Attention | Flash Attention |
|---|---|---|
| Memory | O(N²) | O(N) |
| Speed | 1x | 2-4x faster |
| Max Context | ~8K | 128K+ |
| Training Throughput | Baseline | 2-3x higher |

Flash Attention Versions

| Version | Key Features |
|---|---|
| v1 | Tiling, online softmax |
| v2 | Better parallelism, 2x faster |
| v3 | Hopper GPU optimizations (H100) |

Enabling Flash Attention

Most modern frameworks support Flash Attention:

```python
# Hugging Face Transformers (requires the flash-attn package to be installed)
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "model_name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,  # Flash Attention requires fp16 or bf16
)
```

Requirements

  • GPU: NVIDIA Ampere (A100) or newer recommended
  • CUDA: 11.6+
  • PyTorch: 2.0+

Related Techniques

| Technique | Description |
|---|---|
| PagedAttention | Used in vLLM for serving |
| Ring Attention | Distributes attention across GPUs |
| Sliding Window | Limits attention to local context |
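As an illustration of the sliding-window idea, the attention mask can be sketched in a few lines (the function name here is hypothetical):

```python
def sliding_window_mask(seq_len, window):
    """Causal sliding-window mask: query position i may attend only
    to key positions j in [i - window + 1, i]. mask[i][j] is True
    when attention from i to j is allowed."""
    return [[max(0, i - window + 1) <= j <= i for j in range(seq_len)]
            for i in range(seq_len)]

# Each row has at most `window` True entries, so the score matrix
# costs O(N * window) instead of O(N^2).
mask = sliding_window_mask(5, 2)
```

With a window of 2 over 5 positions, position 3 attends only to positions 2 and 3, which is what caps the memory cost at O(N·window).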

πŸ•ΈοΈ Knowledge Mesh