https://github.com/dao-ailab/flash-attention
# FlashAttention

FlashAttention is a fast, memory-efficient, IO-aware implementation of exact scaled dot-product attention for deep learning. Standard attention materializes the full N×N attention matrix. FlashAttention instead tiles the computation into blocks that fit in on-chip SRAM, so its memory use grows linearly with sequence length rather than quadratically. Version 2.8.4 supports NVIDIA CUDA (Ampere, Ada, Hopper, Blackwell) and AMD ROCm GPUs, fp16/bf16 data types, and head dimensions up to 256, and it integrates with `torch.compile`.

The library exposes a family of attention functions covering the most common training and inference patterns:

- packed QKV, which is faster when Q/K/V are already stacked in one tensor;
- separate Q/K/V, which supports MQA/GQA;
- variable-length batches, where sequences of different lengths are packed without padding;
- a KV-cache variant for autoregressive decoding, which can optionally apply rotary embeddings inside the kernel.

It also ships higher-level `nn.Module` wrappers (`FlashSelfAttention`, `FlashCrossAttention`, `MHA`), fused layer norm, fused dense layers, rotary positional embeddings, and an optimized cross-entropy loss. Together these form a complete toolkit for efficient transformer training.

---

## Installation

```bash
# Install from PyPI (recommended); setup first tries to download a matching
# pre-built wheel and compiles from source if none matches
pip install flash-attn --no-build-isolation

# Limit parallel compile jobs on memory-constrained machines
MAX_JOBS=4 pip install flash-attn --no-build-isolation

# AMD ROCm (Triton backend), run from a clone of the repository
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .

# From source, run from a clone of the repository
python setup.py install
```

---

## `flash_attn_func` — Standard separate Q/K/V attention

This is the primary API for training. It accepts separate Q, K, V tensors and supports MQA/GQA when K/V have fewer heads than Q. The number of Q heads must be divisible by the number of KV heads. When `causal=True` and `seqlen_q != seqlen_k`, the causal mask is aligned to the bottom-right corner of the attention matrix.
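The two rules above are bottom-right causal alignment and GQA head sharing. The pure-Python sketch below makes both concrete by listing which key positions a given query row may attend to. It is not part of flash_attn, and the helper names are hypothetical.

It assumes the window rule stated in the FlashAttention README:

- query `i` attends to keys in `[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]`;
- `-1` means unbounded;
- `causal=True` forces the right edge to 0.

```python
def allowed_keys(i, seqlen_q, seqlen_k, causal=False, window_size=(-1, -1)):
    """Key indices visible to query row i (hypothetical helper mirroring the mask rules)."""
    shift = seqlen_k - seqlen_q  # bottom-right alignment: last query row lines up with last key
    left, right = window_size
    if causal:
        right = 0  # causal = no look-ahead past the aligned diagonal
    lo = 0 if left < 0 else max(0, i + shift - left)
    hi = seqlen_k - 1 if right < 0 else min(seqlen_k - 1, i + shift + right)
    return list(range(lo, hi + 1))  # empty if the row is fully masked


def kv_head_for(q_head, nheads_q, nheads_kv):
    """GQA/MQA: consecutive groups of Q heads share one K/V head (hypothetical helper)."""
    assert nheads_q % nheads_kv == 0, "nheads_q must be divisible by nheads_kv"
    return q_head // (nheads_q // nheads_kv)


if __name__ == "__main__":
    # seqlen_q=2, seqlen_k=5, causal: the mask hugs the bottom-right corner
    for i in range(2):
        print(i, allowed_keys(i, 2, 5, causal=True))
    # 16 Q heads sharing 4 KV heads
    print([kv_head_for(h, 16, 4) for h in range(16)])
```

With `seqlen_q > seqlen_k`, the top rows can end up with no visible keys at all. That is the fully masked case.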
```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, nheads_kv, headdim = 2, 1024, 16, 4, 64
device = "cuda"

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, seqlen, nheads_kv, headdim, device=device, dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, seqlen, nheads_kv, headdim, device=device, dtype=torch.float16, requires_grad=True)

# Forward: causal GQA (4 KV heads shared by 16 Q heads)
out = flash_attn_func(
    q, k, v,
    dropout_p=0.0,          # set > 0 only during training
    softmax_scale=None,     # defaults to 1/sqrt(headdim)
    causal=True,            # autoregressive mask
    window_size=(-1, -1),   # (-1, -1) = full context; (128, 0) = sliding window
    softcap=0.0,            # > 0 activates Gemma-2-style softcapping
    alibi_slopes=None,      # (nheads,) or (batch, nheads) fp32 tensor for ALiBi
    deterministic=False,    # True = deterministic backward, slightly slower
)
# out: (batch, seqlen, nheads, headdim)
assert out.shape == (batch, seqlen, nheads, headdim)

# Backward works automatically
loss = out.sum()
loss.backward()

# Sliding window (local) attention, Mistral-style: each query sees itself and the 512 previous keys
out_local = flash_attn_func(q, k, v, causal=True, window_size=(512, 0))

# ALiBi: standard geometric slopes 2^(-8i/nheads) for i = 1..nheads (nheads a power of 2)
slopes = torch.tensor([2 ** (-8 * (i + 1) / nheads) for i in range(nheads)], device=device, dtype=torch.float32)
out_alibi = flash_attn_func(q, k, v, alibi_slopes=slopes)
```

---

## `flash_attn_qkvpacked_func` — Packed QKV attention

Use this when Q, K, V are already stacked into one tensor, which implies they share the same sequence length and head count. It is faster than calling `flash_attn_func` on separate Q, K, V because the backward pass avoids explicitly concatenating the gradients of Q, K, and V.
```python
import torch
from flash_attn import flash_attn_qkvpacked_func

batch, seqlen, nheads, headdim = 4, 512, 8, 128
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.1,             # applied during training
    softmax_scale=None,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,   # if True, also returns softmax_lse and S_dmask (for testing)
)
# out: (batch, seqlen, nheads, headdim)

# Retrieve attention probabilities for debugging/testing only. They may use a different scaling.
# S_dmask is only populated when dropout_p > 0; negative entries mark dropped positions.
out, softmax_lse, S_dmask = flash_attn_qkvpacked_func(
    qkv, dropout_p=0.1, causal=True, return_attn_probs=True
)
# softmax_lse: (batch, nheads, seqlen)
# S_dmask:     (batch, nheads, seqlen, seqlen)
```

---

## `flash_attn_kvpacked_func` — Packed KV attention (cross-attention / GQA)

Use this when K and V are already stacked into one tensor, as is common in cross-attention. The backward pass then avoids explicitly concatenating the K and V gradients. It supports GQA/MQA, i.e. fewer KV heads than Q heads.

```python
import torch
from flash_attn import flash_attn_kvpacked_func

batch, seqlen_q, seqlen_k = 2, 128, 1024
nheads_q, nheads_kv, headdim = 16, 2, 64

q = torch.randn(batch, seqlen_q, nheads_q, headdim, device="cuda", dtype=torch.float16)
kv = torch.randn(batch, seqlen_k, 2, nheads_kv, headdim, device="cuda", dtype=torch.float16)

out = flash_attn_kvpacked_func(
    q, kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,          # cross-attention is typically non-causal
    window_size=(-1, -1),
)
# out: (batch, seqlen_q, nheads_q, headdim)
```

---

## `flash_attn_varlen_func` — Variable-length (unpadded) attention

Use this for batches in which sequences of different lengths are packed into a single tensor without padding. It requires two inputs:

- cumulative sequence-length arrays (`cu_seqlens`, int32, shape `(batch + 1,)`);
- the maximum sequence length in the batch.
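The `cu_seqlens` arrays used by the varlen functions are exclusive prefix sums of the per-sequence lengths. The pure-Python sketch below builds them for the lengths used in the next example and shows how a packed token row maps back to its sequence. The helper names are hypothetical, and the real API expects an int32 CUDA tensor rather than a list. Attention is computed only within each `[cu_seqlens[b], cu_seqlens[b + 1])` slice, so tokens never attend across sequence boundaries.

```python
from bisect import bisect_right
from itertools import accumulate


def make_cu_seqlens(seqlens):
    """Exclusive prefix sum: [0, len0, len0 + len1, ...], length batch + 1 (hypothetical helper)."""
    return [0] + list(accumulate(seqlens))


def sequence_of_token(cu_seqlens, token_row):
    """Index of the sequence that owns a row of the packed (total_tokens, ...) tensor."""
    return bisect_right(cu_seqlens, token_row) - 1


seqlens = [128, 256, 192]
cu_seqlens = make_cu_seqlens(seqlens)   # [0, 128, 384, 576]
max_seqlen = max(seqlens)               # 256, passed as max_seqlen_q / max_seqlen_k
total_tokens = cu_seqlens[-1]           # 576 rows in the packed tensor

# Each sequence occupies the half-open row range [cu_seqlens[b], cu_seqlens[b + 1])
slices = [(cu_seqlens[b], cu_seqlens[b + 1]) for b in range(len(seqlens))]
print(cu_seqlens, slices)
```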
```python
import torch
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import unpad_input, pad_input

batch, nheads, headdim = 3, 8, 64
seqlens = [128, 256, 192]  # different lengths per sample
max_seqlen = max(seqlens)
device = "cuda"

# Simulate padded input and attention mask
hidden = torch.randn(batch, max_seqlen, nheads * headdim, device=device, dtype=torch.float16)
attention_mask = torch.zeros(batch, max_seqlen, dtype=torch.bool, device=device)
for i, s in enumerate(seqlens):
    attention_mask[i, :s] = True

# Unpad to remove padding tokens
hidden_unpadded, indices, cu_seqlens, max_s, _ = unpad_input(hidden, attention_mask)
# hidden_unpadded: (total_tokens, nheads * headdim)
total_tokens = hidden_unpadded.shape[0]

# In a real model, q/k/v would be projections of hidden_unpadded
q = torch.randn(total_tokens, nheads, headdim, device=device, dtype=torch.float16)
k = torch.randn(total_tokens, nheads, headdim, device=device, dtype=torch.float16)
v = torch.randn(total_tokens, nheads, headdim, device=device, dtype=torch.float16)

# cu_seqlens: (batch + 1,) int32, here [0, 128, 384, 576]
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_s,
    max_seqlen_k=max_s,
    dropout_p=0.0,
    causal=True,
)
# out: (total_tokens, nheads, headdim)

# Re-pad output back to (batch, max_seqlen, nheads, headdim)
out_padded = pad_input(out, indices, batch, max_seqlen)
```

---

## `flash_attn_varlen_qkvpacked_func` — Variable-length packed QKV

This works like `flash_attn_varlen_func` but takes Q, K, V pre-stacked in one tensor. That makes the backward pass more efficient, because there is no explicit concatenation of the Q, K, V gradients.
```python
import torch
from flash_attn import flash_attn_varlen_qkvpacked_func

nheads, headdim = 8, 64
seqlens = [128, 256, 192]      # three sequences packed back to back
total_tokens = sum(seqlens)    # 576
cu_seqlens = torch.tensor([0, 128, 384, 576], dtype=torch.int32, device="cuda")
max_seqlen = max(seqlens)      # 256

qkv = torch.randn(total_tokens, 3, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)

out = flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens=cu_seqlens,
    max_seqlen=max_seqlen,
    dropout_p=0.0,
    causal=True,
)
# out: (total_tokens, nheads, headdim)
out.sum().backward()  # gradients flow back to qkv
```

---

## `flash_attn_with_kvcache` — Inference with KV cache

This function is optimized for autoregressive decoding. It appends the new keys/values to the KV cache in place and can apply rotary embeddings inside the kernel. It also supports paged (PagedAttention-style) KV caches via `block_table`.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, nheads_kv, headdim = 4, 16, 2, 64
seqlen_cache = 2048  # pre-allocated cache length
seqlen_new = 1       # single new token per step (typical decoding)
device = "cuda"

# Pre-allocate KV cache (the last dimension must be contiguous)
k_cache = torch.zeros(batch, seqlen_cache, nheads_kv, headdim, device=device, dtype=torch.float16)
v_cache = torch.zeros(batch, seqlen_cache, nheads_kv, headdim, device=device, dtype=torch.float16)

# Number of tokens already stored in each cache entry
cache_seqlens = torch.tensor([100, 200, 150, 50], dtype=torch.int32, device=device)

# New query/key/value for the current decoding step
q = torch.randn(batch, seqlen_new, nheads, headdim, device=device, dtype=torch.float16)
k = torch.randn(batch, seqlen_new, nheads_kv, headdim, device=device, dtype=torch.float16)
v = torch.randn(batch, seqlen_new, nheads_kv, headdim, device=device, dtype=torch.float16)

# One kernel call: write k/v into the cache at cache_seqlens, then attend over the full cached sequence
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k, v=v,                      # new tokens appended at cache_seqlens positions
    cache_seqlens=cache_seqlens,
    causal=True,
    softmax_scale=None,
)
# out: (batch, seqlen_new, nheads, headdim); k_cache/v_cache are updated in place.
# cache_seqlens itself is not advanced: increment it (cache_seqlens += seqlen_new) before the next step.

# --- Paged KV cache (PagedAttention) ---
page_size = 256  # page (block) size must be a multiple of 256
num_blocks = 64
k_cache_paged = torch.zeros(num_blocks, page_size, nheads_kv, headdim, device=device, dtype=torch.float16)
v_cache_paged = torch.zeros(num_blocks, page_size, nheads_kv, headdim, device=device, dtype=torch.float16)
# block_table[b, j] = physical page that holds logical page j of sequence b.
# Give every sequence its own 8 pages (8 * 256 = 2048 tokens) so no two sequences write to the same page.
block_table = torch.arange(batch * 8, dtype=torch.int32, device=device).reshape(batch, 8)

out_paged = flash_attn_with_kvcache(
    q, k_cache_paged, v_cache_paged,
    k=k, v=v,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)

# --- With rotary embedding applied inside the kernel ---
rotary_dim = headdim  # must be divisible by 16
# Standard RoPE angles: cos/sin of shape (seqlen_cache, rotary_dim // 2), in q's dtype
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(torch.arange(seqlen_cache, device=device, dtype=torch.float32), inv_freq)
cos, sin = freqs.cos().to(torch.float16), freqs.sin().to(torch.float16)

out_rotary = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k, v=v,
    cache_seqlens=cache_seqlens,
    rotary_cos=cos,
    rotary_sin=sin,
    rotary_interleaved=False,  # True = GPT-J style, False = GPT-NeoX style
    causal=True,
)
```

---

## `unpad_input` / `pad_input` — Variable-length packing utilities

These convert between padded `(batch, seqlen, ...)` tensors and the unpadded `(total_tokens, ...)` layout required by the varlen attention functions.
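Conceptually, unpadding is a gather of the rows where the mask is true, and padding is the matching scatter back into a zero-filled buffer. The pure-Python sketch below mirrors that idea on nested lists. It is an illustration with hypothetical helper names, not the library's tensor implementation.

```python
def unpad(rows, mask):
    """rows, mask: batch x seqlen nested lists. Returns (packed, indices, cu_seqlens, max_seqlen)."""
    flat_rows = [r for seq in rows for r in seq]
    flat_mask = [m for seq in mask for m in seq]
    indices = [i for i, keep in enumerate(flat_mask) if keep]  # positions in the flattened (batch * seqlen) layout
    seqlens = [sum(seq) for seq in mask]                       # real tokens per sequence
    cu_seqlens = [0]
    for s in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + s)
    return [flat_rows[i] for i in indices], indices, cu_seqlens, max(seqlens)


def pad(packed, indices, batch, seqlen, fill=0):
    """Scatter packed rows back to their original positions; padding slots get `fill`."""
    flat = [fill] * (batch * seqlen)
    for value, i in zip(packed, indices):
        flat[i] = value
    return [flat[b * seqlen:(b + 1) * seqlen] for b in range(batch)]


rows = [[1, 2, 3], [4, 5, 6]]
mask = [[True, True, False], [True, False, False]]
packed, indices, cu_seqlens, max_s = unpad(rows, mask)
restored = pad(packed, indices, batch=2, seqlen=3)
print(packed, indices, cu_seqlens, max_s, restored)
```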
```python
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, max_seqlen, hidden_dim = 4, 256, 512
device = "cuda"

hidden = torch.randn(batch, max_seqlen, hidden_dim, device=device, dtype=torch.float16)

# attention_mask: True = real token, False = padding
attention_mask = torch.ones(batch, max_seqlen, dtype=torch.bool, device=device)
attention_mask[0, 200:] = False  # seq 0 has length 200
attention_mask[2, 100:] = False  # seq 2 has length 100

hidden_unpadded, indices, cu_seqlens, max_s, seqused = unpad_input(hidden, attention_mask)
# hidden_unpadded: (total_real_tokens, hidden_dim), no padding rows
# indices:         (total_real_tokens,), original positions in the flattened (batch * max_seqlen) layout
# cu_seqlens:      (batch + 1,) int32, here [0, 200, 456, 556, 812]
# max_s:           Python int, 256 here since seqs 1 and 3 are full
# seqused:         (batch,) int32, actual per-sequence lengths
print(hidden_unpadded.shape)  # torch.Size([812, 512]), i.e. (200 + 256 + 100 + 256, hidden_dim)

# After attention, re-insert results into the padded layout
out_unpadded = torch.randn_like(hidden_unpadded)
out_padded = pad_input(out_unpadded, indices, batch, max_seqlen)
# out_padded: (batch, max_seqlen, hidden_dim); padding positions are zero
```

---

## `FlashSelfAttention` — nn.Module wrapper for self-attention

A drop-in `nn.Module` that takes packed QKV. It calls `flash_attn_qkvpacked_func` for padded input, or `flash_attn_varlen_qkvpacked_func` when `cu_seqlens` and `max_seqlen` are passed.
```python
import torch
from flash_attn.modules.mha import FlashSelfAttention

batch, seqlen, nheads, headdim = 2, 512, 8, 64
device = "cuda"

self_attn = FlashSelfAttention(
    causal=True,
    softmax_scale=None,
    attention_dropout=0.1,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
).to(device)
self_attn.eval()  # disables dropout

qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device=device, dtype=torch.float16)

# Padded path
out = self_attn(qkv)
# out: (batch, seqlen, nheads, headdim)

# Unpadded path: pass cu_seqlens and max_seqlen (here every sequence has the full length seqlen)
total_tokens = batch * seqlen
qkv_unpadded = torch.randn(total_tokens, 3, nheads, headdim, device=device, dtype=torch.float16)
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device=device)  # [0, 512, 1024]
out_unpadded = self_attn(qkv_unpadded, cu_seqlens=cu_seqlens, max_seqlen=seqlen)
# out_unpadded: (total_tokens, nheads, headdim)
```

---

## `FlashCrossAttention` — nn.Module wrapper for cross-attention

Wraps `flash_attn_kvpacked_func` / `flash_attn_varlen_kvpacked_func` for encoder-decoder cross-attention.

```python
import torch
from flash_attn.modules.mha import FlashCrossAttention

batch, seqlen_q, seqlen_k = 2, 64, 512
nheads_q, nheads_kv, headdim = 8, 2, 64
device = "cuda"

cross_attn = FlashCrossAttention(
    causal=False,
    softmax_scale=None,
    attention_dropout=0.0,
).to(device)
cross_attn.eval()

q = torch.randn(batch, seqlen_q, nheads_q, headdim, device=device, dtype=torch.float16)
kv = torch.randn(batch, seqlen_k, 2, nheads_kv, headdim, device=device, dtype=torch.float16)

out = cross_attn(q, kv)
# out: (batch, seqlen_q, nheads_q, headdim)
```

---

## `MHA` — Full multi-head attention module with projections

A complete multi-head attention layer: QKV projection, optional rotary embeddings, optional ALiBi, and output projection. It supports GQA/MQA via `num_heads_kv` and incremental decoding with a KV cache via `inference_params`. Tensor parallelism is provided by the separate `ParallelMHA` class.
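```python
import torch
from flash_attn.modules.mha import MHA
from flash_attn.utils.generation import InferenceParams

batch, seqlen, embed_dim = 2, 256, 512
nheads = 8  # head dim = 512 / 8 = 64
device = "cuda"

mha = MHA(
    embed_dim=embed_dim,
    num_heads=nheads,
    num_heads_kv=None,        # None = standard MHA; set to a divisor of nheads for GQA/MQA
    causal=True,
    dropout=0.0,
    rotary_emb_dim=64,        # apply rotary embedding to the first 64 dims of each head
    rotary_emb_base=10000.0,
    use_alibi=False,          # True = add ALiBi positional bias
    window_size=(-1, -1),
    use_flash_attn=True,
    qkv_proj_bias=True,
    out_proj_bias=True,
    layer_idx=0,              # required when running with a KV cache
).to(device).half()

x = torch.randn(batch, seqlen, embed_dim, device=device, dtype=torch.float16)

# Training forward: returns a single tensor (a tuple only when return_residual=True)
out = mha(x)
# out: (batch, seqlen, embed_dim)

# Inference with a KV cache: prefill the prompt, then decode one token at a time
mha.eval()
inference_params = InferenceParams(max_seqlen=seqlen + 4, max_batch_size=batch)
with torch.no_grad():
    out_prefill = mha(x, inference_params=inference_params)  # fills the cache for layer_idx=0
    inference_params.seqlen_offset += x.shape[1]
    for step in range(4):
        tok = torch.randn(batch, 1, embed_dim, device=device, dtype=torch.float16)
        out_step = mha(tok, inference_params=inference_params)  # (batch, 1, embed_dim)
        inference_params.seqlen_offset += 1
```

---

## `RotaryEmbedding` — Rotary positional embeddings

An efficient Triton-based rotary embedding module. It supports GPT-NeoX (split halves) and GPT-J (interleaved pairs) rotation styles, XPos scaling, and a sequence offset for KV-cache inference.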
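The GPT-NeoX and GPT-J styles mentioned above differ only in which coordinates are paired before rotating:

- GPT-NeoX style (`interleaved=False`) pairs dimension `j` with `j + d/2`;
- GPT-J style (`interleaved=True`) pairs `2j` with `2j + 1`.

In both, pair `j` at position `pos` is rotated by the angle `pos * base**(-2j/d)`. The pure-Python sketch below uses a hypothetical `rotate` helper and omits XPos scaling. It checks the property that makes rotary embeddings useful: the Q·K score depends only on the relative offset between positions. That is also why KV-cache decoding only needs to shift positions by the cache offset.

```python
import math


def rotate(x, pos, base=10000.0, interleaved=False):
    """Apply a rotary embedding to one head vector x (even length d) at position pos."""
    d = len(x)
    assert d % 2 == 0, "head dimension must be even"
    half = d // 2
    out = list(x)
    for j in range(half):
        theta = pos * base ** (-2 * j / d)  # rotation frequency decreases with j
        c, s = math.cos(theta), math.sin(theta)
        a, b = (2 * j, 2 * j + 1) if interleaved else (j, j + half)  # GPT-J vs GPT-NeoX pairing
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.5, 2.0, 0.1, -0.7, 1.1, 0.4]
k = [1.0, 0.2, -0.4, 0.9, -1.5, 0.6, 0.0, 0.8]
for style in (False, True):
    # Same relative offset (2) at different absolute positions gives the same score
    s1 = dot(rotate(q, 5, interleaved=style), rotate(k, 3, interleaved=style))
    s2 = dot(rotate(q, 12, interleaved=style), rotate(k, 10, interleaved=style))
    print("interleaved" if style else "split halves", round(s1, 6), round(s2, 6))
```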
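```python
import torch
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb

device = "cuda"
headdim, seqlen, batch, nheads = 128, 512, 2, 8

# Module-based usage
rotary_emb = RotaryEmbedding(
    dim=headdim,
    base=10000.0,
    interleaved=False,   # False = GPT-NeoX (split halves); True = GPT-J (interleaved pairs)
    scale_base=None,     # set (e.g. to 512) to enable XPos
).to(device)

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)

# The module takes packed QKV and rotates Q and K in place (V is untouched);
# its cos/sin cache is rebuilt automatically when seqlen grows
qkv = torch.stack([q, k, v], dim=2)      # (batch, seqlen, 3, nheads, headdim)
qkv = rotary_emb(qkv, seqlen_offset=0)   # seqlen_offset shifts positions for KV-cache inference

# Alternatively pass Q with packed KV: returns (q, kv), both rotated in place
q_rot, kv_rot = rotary_emb(q.clone(), kv=torch.stack([k, v], dim=2))

# Functional API: apply precomputed cos/sin of shape (seqlen, headdim // 2), same dtype as the input
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, headdim, 2, device=device, dtype=torch.float32) / headdim))
freqs = torch.outer(torch.arange(seqlen, device=device, dtype=torch.float32), inv_freq)
cos, sin = freqs.cos().to(q.dtype), freqs.sin().to(q.dtype)
q_rot = apply_rotary_emb(
    q, cos, sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,   # shift for KV-cache inference
)
# q_rot: (batch, seqlen, nheads, headdim)
```

---

## `CrossEntropyLoss` — Fused Triton cross-entropy loss

A memory-efficient cross-entropy that avoids materializing the full probability distribution. It supports label smoothing, z-loss regularization, and vocabularies sharded across tensor-parallel ranks.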
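As a reference for what the fused kernel computes per row, here is a pure-Python sketch of cross-entropy with label smoothing and a z-loss term. It makes three assumptions:

- label smoothing follows PyTorch's convention;
- the z-loss coefficient `lse_square_scale` used in the example below multiplies `logsumexp(logits)**2` and is added to the loss;
- ignored rows are skipped, and the mean is taken over non-ignored rows only.

The helper names are hypothetical. The sketch only shows the math, not how the kernel avoids materializing the softmax.

```python
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def cross_entropy_row(logits, target, label_smoothing=0.0, lse_square_scale=0.0):
    """Return (loss, z_loss) for one row; loss already includes z_loss (hypothetical helper)."""
    lse = logsumexp(logits)
    nll = lse - logits[target]                 # -log softmax(logits)[target]
    smooth = lse - sum(logits) / len(logits)   # mean of -log softmax over all classes
    z_loss = lse_square_scale * lse ** 2
    loss = (1.0 - label_smoothing) * nll + label_smoothing * smooth + z_loss
    return loss, z_loss


def cross_entropy_mean(batch_logits, targets, ignore_index=-100, **kwargs):
    """Mean over rows whose target is not ignore_index."""
    rows = [cross_entropy_row(lg, t, **kwargs) for lg, t in zip(batch_logits, targets) if t != ignore_index]
    n = len(rows)
    return sum(r[0] for r in rows) / n, sum(r[1] for r in rows) / n


logits = [[2.0, 0.0, -1.0], [0.5, 0.5, 0.5], [1.0, 3.0, 0.0]]
targets = [0, -100, 1]
loss, z_loss = cross_entropy_mean(logits, targets, label_smoothing=0.1, lse_square_scale=1e-4)
print(loss, z_loss)
```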
```python
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

batch, vocab_size = 8 * 1024, 50257   # tokens x vocab, a typical LM scenario
device = "cuda"

loss_fn = CrossEntropyLoss(
    ignore_index=-100,
    reduction="mean",
    label_smoothing=0.0,
    lse_square_scale=1e-4,    # z-loss coefficient: adds lse_square_scale * logsumexp(logits)**2 to the loss
    inplace_backward=True,    # saves memory by writing the gradient into the logits buffer during backward
    process_group=None,       # set for tensor-parallel vocab shards
    return_z_loss=True,
)

logits = torch.randn(batch, vocab_size, device=device, dtype=torch.float16, requires_grad=True)
targets = torch.randint(0, vocab_size, (batch,), device=device)
targets[100:110] = -100  # ignored positions

loss, z_loss = loss_fn(logits, targets)
# loss:   scalar, mean over non-ignored tokens (already includes the z-loss term)
# z_loss: scalar, the z-loss term alone, returned for logging (no gradient)
loss.backward()
```

---

## `dropout_add_layer_norm` — Fused dropout + residual + layer norm

A single-kernel operation that combines dropout, residual addition, and layer normalization. It reads and writes memory once instead of three times, which substantially reduces memory traffic compared with three sequential ops.
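The fused op computes the same thing as an unfused sequence of three steps:

1. apply dropout to `x0`;
2. add the residual to get the new residual stream;
3. layer-normalize that stream.

The pure-Python sketch below does this for a single row with a hypothetical helper and makes both outputs explicit. The optional per-token and per-feature scales are omitted. It assumes inverted dropout, which scales kept values by `1 / (1 - p)`, and the biased variance that LayerNorm uses.

```python
import math
import random


def dropout_add_layer_norm_ref(x0, residual, weight, bias, dropout_p, eps, rng=None):
    """Unfused reference for one row: returns (out, residual_out) (hypothetical helper)."""
    rng = rng or random.Random(0)
    if dropout_p > 0.0:
        keep = 1.0 - dropout_p
        x0 = [x / keep if rng.random() < keep else 0.0 for x in x0]  # inverted dropout
    residual_out = [x + r for x, r in zip(x0, residual)]   # new residual stream
    n = len(residual_out)
    mean = sum(residual_out) / n
    var = sum((h - mean) ** 2 for h in residual_out) / n   # biased variance, as in LayerNorm
    inv_std = 1.0 / math.sqrt(var + eps)
    out = [(h - mean) * inv_std * w + b for h, w, b in zip(residual_out, weight, bias)]
    return out, residual_out


out, residual_out = dropout_add_layer_norm_ref(
    [1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 0.5, -0.5], [1.0] * 4, [0.0] * 4, dropout_p=0.0, eps=1e-5
)
print(out, residual_out)
```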
```python
import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

batch_tokens, hidden_dim = 4096, 2048
device = "cuda"
dtype = torch.float16

x0 = torch.randn(batch_tokens, hidden_dim, device=device, dtype=dtype, requires_grad=True)
residual = torch.randn(batch_tokens, hidden_dim, device=device, dtype=dtype)
weight = torch.ones(hidden_dim, device=device, dtype=dtype)
bias = torch.zeros(hidden_dim, device=device, dtype=dtype)

out, residual_out = dropout_add_layer_norm(
    x0, residual, weight, bias,
    dropout_p=0.1,
    epsilon=1e-5,
    rowscale=None,            # optional per-token scale, shape (batch_tokens,)
    layerscale=None,          # optional per-feature scale, shape (hidden_dim,)
    prenorm=True,             # also return the pre-norm residual stream
    residual_in_fp32=False,
    return_dropout_mask=False,
)
# out:          (batch_tokens, hidden_dim), LayerNorm(dropout(x0) + residual)
# residual_out: (batch_tokens, hidden_dim), dropout(x0) + residual (pre-norm, fed to the next layer)
# For RMSNorm, the analogous dropout_add_rms_norm lives in flash_attn.ops.rms_norm
```

---

## `FusedDense` — Fused linear layer

A dense (linear) layer that fuses the bias addition into the cuBLASLt matrix multiply. It is intended as a drop-in replacement for `nn.Linear` in performance-critical paths.

```python
import torch
from flash_attn.ops.fused_dense import FusedDense

batch, in_features, out_features = 1024, 2048, 4096
device = "cuda"

linear = FusedDense(in_features, out_features, bias=True).to(device).half()

x = torch.randn(batch, in_features, device=device, dtype=torch.float16)
y = linear(x)
# y: (batch, out_features)
```

---

## Using with 🤗 `kernels` Library

FlashAttention 2 and 3 can also be loaded from the Hugging Face Hub through the `kernels` library, without compiling the extension locally.
```python
# pip install kernels
from kernels import get_kernel

# q, k, v: CUDA fp16/bf16 tensors of shape (batch, seqlen, nheads, headdim),
# e.g. those from the flash_attn_func example above

# FlashAttention-2
fa2 = get_kernel("kernels-community/flash-attn2", version=1)
out = fa2.flash_attn_func(q, k, v, causal=True)

# FlashAttention-3 (Hopper GPUs)
fa3 = get_kernel("kernels-community/flash-attn3", version=1)
out = fa3.flash_attn_func(q, k, v, causal=True)
```

---

## Summary

FlashAttention's primary use case is accelerating transformer training and inference. It replaces standard attention implementations that materialize the full N×N score matrix. Its advantage grows with sequence length: the upstream README reports roughly 10× memory savings at sequence length 2K and 20× at 4K. Those savings allow larger batch sizes, longer contexts, and models that would otherwise run out of memory. `torch.nn.functional.scaled_dot_product_attention` can itself dispatch to a FlashAttention-based backend, so gains measured against it are smaller.

The varlen functions are especially useful when training on datasets with highly variable sequence lengths, because no computation is spent on padding tokens. The KV-cache function is the recommended kernel for single-token autoregressive decoding. It handles cache updates, rotary embeddings, and paged memory management in one fused call.

For integration, the typical pattern is to swap the attention operation in an existing transformer by replacing the `softmax(q @ k.T * scale) @ v` block with `flash_attn_func`. Note the layout difference: `flash_attn_func` takes `(batch, seqlen, nheads, headdim)` tensors, whereas `scaled_dot_product_attention` uses `(batch, nheads, seqlen, headdim)`. For further kernel fusion, layer norms can optionally be replaced with `dropout_add_layer_norm` and linear layers with `FusedDense`. The `MHA` module provides a self-contained replacement for the entire multi-head attention block, including projections.

The library works with:

- PyTorch autograd;
- `torch.compile` (PyTorch ≥ 2.4);
- mixed-precision training via `torch.cuda.amp`;
- tensor-parallel training through `ColumnParallelLinear` / `RowParallelLinear` in `flash_attn.ops.fused_dense`.
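The claim that FlashAttention is exact, not approximate, can be checked without a GPU. The pure-Python sketch below compares two ways of computing attention:

- naive attention, which builds the full row of scores before the softmax;
- an online-softmax version, which visits keys one block at a time and keeps a running maximum, a running denominator, and a rescaled output accumulator.

The helper names are hypothetical. The sketch models only the arithmetic, not SRAM tiling, GPU parallelism, or the backward pass.

```python
import math


def naive_attention(q, k, v, scale):
    """Reference: full score row, softmax, weighted sum of values."""
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / z for d in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, scale, block=2):
    """Online softmax: visit keys/values block by block, never holding the full score row."""
    out = []
    for qi in q:
        m = -math.inf              # running max of scores seen so far
        denom = 0.0                # running softmax denominator (relative to m)
        acc = [0.0] * len(v[0])    # running unnormalized output (relative to m)
        for start in range(0, len(k), block):
            kb, vb = k[start:start + block], v[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            m_new = max(m, max(scores))
            correction = math.exp(m - m_new)  # rescales old statistics; exp(-inf) = 0 on the first block
            p = [math.exp(s - m_new) for s in scores]
            denom = denom * correction + sum(p)
            acc = [a * correction + sum(pj * vj[d] for pj, vj in zip(p, vb)) for d, a in enumerate(acc)]
            m = m_new
        out.append([a / denom for a in acc])
    return out


q = [[0.1, 0.2], [0.5, -0.3], [1.0, 1.0]]
k = [[0.3, -0.1], [0.0, 0.4], [-0.2, 0.2], [0.9, 0.1], [0.5, 0.5]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
scale = 1.0 / math.sqrt(len(q[0]))  # the default softmax_scale, 1/sqrt(headdim)
reference = naive_attention(q, k, v, scale)
tiled = tiled_attention(q, k, v, scale, block=2)
print(reference)
print(tiled)
```

Both functions agree for any block size. Only the memory needed per step changes: one block of scores instead of the full row.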