FFPA: Yet another Faster Flash Prefill Attention with O(1) GPU SRAM complexity for large headdim
FFPA (Split-D): Yet another Faster Flash Prefill Attention with a Split-D strategy that achieves O(1) SRAM complexity and O(d/4) register complexity for large headdim (> 256), running 1.5~3x faster than SDPA. The core features:
| Self Attn | GQA/MQA | Cross Attn | Causal/Mask | Dropout | Headdim | Fwd/Bwd |
|---|---|---|---|---|---|---|
| ✔️ (Nq=Nkv) | ✔️ (Hq!=Hkv) | ✔️ (Nq!=Nkv) | ✔️ (attn_mask) | ✔️ (p>0) | 320~1024 | 1.5~3x |
Quick Start
First, install the prebuilt package from PyPI or build ffpa-attn from source:
```bash
# First, install the prebuilt package from PyPI
pip3 install -U ffpa-attn # CUDA 13.0+, PyTorch 2.11+
# Or, build ffpa-attn from source: clone the repo
git clone https://github.com/xlite-dev/ffpa-attn.git
# Then, build and install the package in editable mode (Triton + CuTeDSL backends)
cd ffpa-attn && pip3 install -e . --no-build-isolation
# Optional: install ffpa-attn w/ CUDA backend (forward only)
ENABLE_FFPA_CUDA_IMPL=1 MAX_JOBS=32 pip3 install -e .
```
Then, accelerate attention for large headdim with a single line of code:
```python
>>> import torch.nn.functional as F
>>> from ffpa_attn import ffpa_attn_func
>>> # Monkey-patch SDPA to point to FFPA. Everything that FFPA does not
>>> # support automatically falls back to SDPA (e.g. D <= 256).
>>> F.scaled_dot_product_attention = ffpa_attn_func # one line of code
```
For more advanced features, please refer to our online docs at ffpa-attn.io.
Note
FFPA supports cross-attention where the query seqlen Nq may differ from the key/value seqlen Nkv, GQA / MQA attention where Q has Nh_q heads and K/V have Nh_kv heads (requires Nh_q % Nh_kv == 0; group size = Nh_q / Nh_kv), and causal attention (pass is_causal=True; queries are aligned to the KV tail, i.e. Q row r attends to k <= r + (Nkv - Nq), which requires Nkv >= Nq). K/V must share the same Nh_kv and Nkv. enable_gqa now defaults to False to match SDPA exactly, so GQA/MQA usage must pass enable_gqa=True explicitly.
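To make the tail-aligned causal rule concrete, the NumPy sketch below builds the boolean mask it implies (True means query row r may attend to key k). `tail_aligned_causal_mask` is a hypothetical helper written for illustration, not part of ffpa-attn. For comparison, PyTorch documents `scaled_dot_product_attention(..., is_causal=True)` as top-left aligned for non-square inputs, so the two conventions only agree when Nq == Nkv.

```python
import numpy as np


def tail_aligned_causal_mask(nq: int, nkv: int) -> np.ndarray:
    """Hypothetical helper: (nq, nkv) bool mask, True where row r may attend to key k.

    Encodes the rule from the note: k <= r + (nkv - nq), which requires nkv >= nq.
    """
    if nkv < nq:
        raise ValueError("tail-aligned causal attention requires nkv >= nq")
    r = np.arange(nq)[:, None]   # query row indices, shape (nq, 1)
    k = np.arange(nkv)[None, :]  # key column indices, shape (1, nkv)
    return k <= r + (nkv - nq)


# Nq == Nkv reduces to the standard lower-triangular causal mask.
square = tail_aligned_causal_mask(4, 4)
print(np.array_equal(square, np.tril(np.ones((4, 4), dtype=bool))))  # True

# Nq < Nkv: the last query row sees every key, earlier rows see fewer.
print(tail_aligned_causal_mask(2, 5).astype(int))
# [[1 1 1 1 0]
#  [1 1 1 1 1]]
```

The same mask can be written as `np.tril(np.ones((nq, nkv), dtype=bool), k=nkv - nq)`, since `np.tril` keeps entries whose column index is at most row index plus `k`.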
Minimal usage example for self-attention (B=1, H=32, N=8192, D=512):
```python
import torch
import torch.nn.functional as F
from ffpa_attn import ffpa_attn_func
# D: 32, 64, ..., 320, ..., 1024 (FA-2 <= 256, FFPA supports up to 1024).
B, H, N, D = 1, 32, 8192, 512 # batch_size, num_heads, seq_len, head_dim
q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")
# FFPA self attention; layout follows SDPA: (B, H, N, D).
out = ffpa_attn_func(q, k, v) # -> torch.Tensor of shape (B, H, N, D)
print(out.shape, out.dtype)
ref = F.scaled_dot_product_attention(q, k, v)
print(f"vs SDPA max_abs_err={(out - ref).abs().max().item():.4e}")
```
Cross-Attention or Decoding-Attention example (short query, long KV cache; Nq != Nkv):
```python
import torch
import torch.nn.functional as F
from ffpa_attn import ffpa_attn_func
# Short-query / long-KV, e.g. incremental decoding or cross-attention:
# Q: [B, H, Nq, D], K/V: [B, H, Nkv, D]; Nq can differ from Nkv, but K and V must have the same length.
B, H, D = 1, 8, 512
Nq, Nkv = 128, 8192
q = torch.randn(B, H, Nq, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, Nkv, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, H, Nkv, D, dtype=torch.bfloat16, device="cuda")
out = ffpa_attn_func(q, k, v) # -> (B, H, Nq, D) = (1, 8, 128, 512)
print(out.shape, out.dtype)
ref = F.scaled_dot_product_attention(q, k, v)
print(f"vs SDPA max_abs_err={(out - ref).abs().max().item():.4e}")
```
Grouped-Query / Multi-Query Attention example (Q has more heads than K/V):
```python
import torch
import torch.nn.functional as F
from ffpa_attn import ffpa_attn_func
# GQA: Q has Nh_q heads, K/V share Nh_kv heads; group_size = Nh_q / Nh_kv.
# Typical Llama-3-style 32/8 ratio; MQA is the Nh_kv==1 special case.
# FFPA targets large headdim so we use D=512 here (FA-2 tops out at D=256).
# enable_gqa defaults to False, so opt into GQA semantics explicitly.
B, D, Nq, Nkv = 1, 512, 1024, 4096
Nh_q, Nh_kv = 32, 8 # group_size = 4
q = torch.randn(B, Nh_q, Nq, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, Nh_kv, Nkv, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, Nh_kv, Nkv, D, dtype=torch.bfloat16, device="cuda")
out = ffpa_attn_func(q, k, v, enable_gqa=True) # -> (B, Nh_q, Nq, D) = (1, 32, 1024, 512)
print(out.shape, out.dtype)
# Reference: replicate K/V along head dim to match Q's head count.
group_size = Nh_q // Nh_kv
k_ref = k.repeat_interleave(group_size, dim=1)
v_ref = v.repeat_interleave(group_size, dim=1)
ref = F.scaled_dot_product_attention(q, k_ref, v_ref)
print(f"vs SDPA max_abs_err={(out - ref).abs().max().item():.4e}")
```
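The reference above uses `repeat_interleave` rather than `repeat`. A small NumPy check shows why: with `np.repeat` standing in for `torch.repeat_interleave` and `np.tile` for `Tensor.repeat`, only the interleaved layout pairs query head h with K/V head `h // group_size`. This assumes FFPA groups heads the same way as the SDPA reference it is compared against, which is what the comparison above implies.

```python
import numpy as np

# Tiny GQA shapes in the (B, Nh, N, D) layout used above.
B, Nh_q, Nh_kv, N, D = 1, 8, 2, 3, 4
group_size = Nh_q // Nh_kv  # 4 query heads share each K/V head

rng = np.random.default_rng(0)
k = rng.standard_normal((B, Nh_kv, N, D))

# np.repeat along axis=1 behaves like torch.repeat_interleave(group_size, dim=1):
# head order becomes [kv0, kv0, kv0, kv0, kv1, kv1, kv1, kv1].
k_interleaved = np.repeat(k, group_size, axis=1)
for h in range(Nh_q):
    assert np.array_equal(k_interleaved[:, h], k[:, h // group_size])

# np.tile behaves like Tensor.repeat(1, group_size, 1, 1): [kv0, kv1, kv0, kv1, ...],
# which pairs query heads with the wrong K/V heads.
k_tiled = np.tile(k, (1, group_size, 1, 1))
print(np.array_equal(k_tiled, k_interleaved))  # False
```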
Causal attention example (causal self-attention; chunked / decoding prefill with Nkv > Nq is also supported):
```python
import torch
import torch.nn.functional as F
from ffpa_attn import ffpa_attn_func
# Causal self-attention: Q row r attends to k <= r (standard triangular mask).
# FFPA is tuned for large headdim, so we keep D=512 as in the self-attn example.
B, H, N, D = 1, 8, 4096, 512
q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda")
out = ffpa_attn_func(q, k, v, is_causal=True)
print(out.shape, out.dtype)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(f"vs SDPA max_abs_err={(out - ref).abs().max().item():.4e}")
# Chunked / decoding prefill: Nq < Nkv, queries aligned to the KV tail
# so Q row r attends to k <= r + (Nkv - Nq). Requires Nkv >= Nq.
# This example keeps D=512 so it stays on the FFPA large-D path. For D <= 256,
# ffpa_attn_func forwards the inputs directly to SDPA without synthesizing a
# causal cross-attention mask for you.
Nq, Nkv = 128, 8192
q = torch.randn(B, H, Nq, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, Nkv, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, H, Nkv, D, dtype=torch.bfloat16, device="cuda")
out = ffpa_attn_func(q, k, v, is_causal=True)
print(out.shape, out.dtype) # (1, 8, 128, 512)
```
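The chunked-prefill branch above has no SDPA comparison. As a sketch of what the tail-aligned rule guarantees, the NumPy code below shows that running only the last Nq queries against the full KV cache reproduces the last Nq rows of full causal self-attention, while a top-left aligned mask (the convention PyTorch documents for `is_causal=True` with non-square inputs) does not. The `attention` helper is a plain reference written for this sketch, not an ffpa-attn function.

```python
import numpy as np


def attention(q, k, v, mask):
    """Plain softmax attention on (N, D) arrays; mask is True where allowed."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    s = np.where(mask, s, -np.inf)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


rng = np.random.default_rng(0)
Nq, Nkv, D = 4, 16, 8
q_all = rng.standard_normal((Nkv, D))  # queries for every position in the sequence
k = rng.standard_normal((Nkv, D))
v = rng.standard_normal((Nkv, D))

# Full causal self-attention over all Nkv positions.
full = attention(q_all, k, v, np.tril(np.ones((Nkv, Nkv), dtype=bool)))

# Chunked prefill: only the last Nq queries against the whole KV cache,
# with the tail-aligned rule k <= r + (Nkv - Nq).
tail = attention(q_all[-Nq:], k, v, np.tril(np.ones((Nq, Nkv), dtype=bool), k=Nkv - Nq))

# Top-left alignment (k <= r), the convention SDPA documents for is_causal=True
# with non-square inputs.
top_left = attention(q_all[-Nq:], k, v, np.tril(np.ones((Nq, Nkv), dtype=bool)))

print(np.allclose(tail, full[-Nq:]))      # True
print(np.allclose(top_left, full[-Nq:]))  # False
```

This is why a D <= 256 call with Nq < Nkv, which goes straight to SDPA, needs an explicit mask if tail-aligned semantics are wanted.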
Backward Pass example (compare dQ / dK / dV against SDPA):
```python
import math
import torch
import torch.nn.functional as F
from ffpa_attn import ffpa_attn_func
# Focus on a large-headdim case where FFPA is typically used.
B, H, N, D = 1, 32, 8192, 512
scale = 1.0 / math.sqrt(D)
q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda", requires_grad=True)
out = ffpa_attn_func(q, k, v, scale=scale)
out.sum().backward()
dq = q.grad.detach().clone()
dk = k.grad.detach().clone()
dv = v.grad.detach().clone()
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
out_ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, scale=scale)
out_ref.sum().backward()
print(f"dQ vs SDPA dQ max_abs_err={(dq - q_ref.grad).abs().max().item():.4e}")
print(f"dK vs SDPA dK max_abs_err={(dk - k_ref.grad).abs().max().item():.4e}")
print(f"dV vs SDPA dV max_abs_err={(dv - v_ref.grad).abs().max().item():.4e}")
```
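For reference, the quantities compared above are the standard softmax-attention gradients. The NumPy sketch below writes them out, including the rowsum(dO ∘ O) shortcut that FlashAttention-2's backward also uses. Because `out.sum().backward()` seeds dO with ones, the sketch does the same and checks one entry of each gradient against a central finite difference. It is a textbook reference on small float64 arrays, not FFPA's kernel.

```python
import numpy as np


def attn_fwd(q, k, v, scale):
    """Reference forward: returns O = softmax(scale * Q K^T) V and P."""
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v, p


def attn_bwd(q, k, v, do, scale):
    """Reference backward: returns (dQ, dK, dV) for upstream gradient dO."""
    o, p = attn_fwd(q, k, v, scale)
    dv = p.T @ do                                 # dV = P^T dO
    dp = do @ v.T                                 # dP = dO V^T
    delta = (do * o).sum(axis=-1, keepdims=True)  # rowsum(dO * O) == rowsum(P * dP)
    ds = p * (dp - delta)                         # softmax backward, row by row
    dq = (ds @ k) * scale                         # dQ = scale * dS K
    dk = (ds.T @ q) * scale                       # dK = scale * dS^T Q
    return dq, dk, dv


rng = np.random.default_rng(0)
N, D = 6, 8
scale = 1.0 / np.sqrt(D)
q, k, v = (rng.standard_normal((N, D)) for _ in range(3))
do = np.ones((N, D))  # out.sum().backward() seeds dO with ones

grads = dict(zip("qkv", attn_bwd(q, k, v, do, scale)))


def loss(q, k, v):
    return attn_fwd(q, k, v, scale)[0].sum()


# Central finite difference on entry (2, 3) of each input.
eps = 1e-6
inputs = {"q": q, "k": k, "v": v}
for name in inputs:
    plus = {n: a.copy() for n, a in inputs.items()}
    minus = {n: a.copy() for n, a in inputs.items()}
    plus[name][2, 3] += eps
    minus[name][2, 3] -= eps
    fd = (loss(**plus) - loss(**minus)) / (2 * eps)
    print(name, np.isclose(grads[name][2, 3], fd, rtol=1e-5, atol=1e-6))
```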
Split-D
We extend FlashAttention to support large headdim ($D>256$) via fine-grained tiling at the MMA level for the $QK^\top$ and $PV$ matrix multiplications, which we refer to as Split-D. This design keeps SRAM usage fixed at $B_r \times 16$ (with $B_r=B_c$) for Q, K and V, yielding constant SRAM complexity $O(B_r \times 16) \approx O(1)$ and register complexity $O(d/4)$.
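The NumPy sketch below mimics the Split-D loop structure on the CPU under simplifying assumptions: a single (N, D) head, tiles of B_r = B_c = 64 rows, and D split into 16-wide slices for both $QK^\top$ and $PV$. `split_d_attention` is a hypothetical illustration, not the FFPA kernel; it only shows that accumulating S over D-slices and producing O in D-slices, combined with FlashAttention's online softmax, still gives exact attention. In this sketch only the Q/K/V working slices are fixed at width 16; the running output accumulator still spans all D columns.

```python
import numpy as np


def split_d_attention(q, k, v, br=64, bc=64, kd=16):
    """Hypothetical CPU sketch of FlashAttention tiling with Split-D on (N, D) arrays."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, d))
    for i in range(0, n, br):
        qi = q[i:i + br]
        rows = qi.shape[0]
        m = np.full(rows, -np.inf)  # running row max
        denom = np.zeros(rows)      # running softmax denominator
        acc = np.zeros((rows, d))   # un-normalized output, full width d
        for j in range(0, k.shape[0], bc):
            kj, vj = k[j:j + bc], v[j:j + bc]
            # S = Q K^T accumulated over D in kd-wide slices: only a (br x kd)
            # slice of Q and a (bc x kd) slice of K are needed per step.
            s = np.zeros((rows, kj.shape[0]))
            for c in range(0, d, kd):
                s += qi[:, c:c + kd] @ kj[:, c:c + kd].T
            s *= scale
            # Online softmax update, as in FlashAttention.
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            alpha = np.exp(m - m_new)
            denom = alpha * denom + p.sum(axis=1)
            acc *= alpha[:, None]
            # O += P V, produced kd output columns at a time: only a
            # (bc x kd) slice of V is needed per step.
            for c in range(0, d, kd):
                acc[:, c:c + kd] += p @ vj[:, c:c + kd]
            m = m_new
        out[i:i + br] = acc / denom[:, None]
    return out


rng = np.random.default_rng(0)
N, D = 256, 512
q, k, v = (rng.standard_normal((N, D)) for _ in range(3))

# Dense reference: softmax(Q K^T / sqrt(D)) V.
s = q @ k.T / np.sqrt(D)
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ v

print(np.abs(split_d_attention(q, k, v) - ref).max())
```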
This lets FFPA support headdim > 256 while running 1.5~3x faster than standard SDPA.
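A back-of-the-envelope calculation of why the fixed $B_r \times 16$ tile matters, assuming bf16/fp16 (2 bytes per element), an illustrative $B_r = B_c = 64$ (not necessarily FFPA's tuned tile size), and ignoring double buffering: full-width Q, K and V tiles need 96, 192 and 384 KiB at D = 256, 512 and 1024, while Split-D tiles stay at 6 KiB regardless of D.

```python
def tile_bytes(rows: int, cols: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one (rows x cols) tile; 2 bytes per element for fp16/bf16."""
    return rows * cols * bytes_per_elem


BR = 64  # illustrative B_r = B_c, an assumption rather than FFPA's tuned value

for d in (256, 512, 1024):
    full_d = 3 * tile_bytes(BR, d)    # Q, K and V tiles spanning the whole headdim
    split_d = 3 * tile_bytes(BR, 16)  # Q, K and V tiles of width 16 (Split-D)
    print(f"D={d}: full-D {full_d // 1024} KiB, Split-D {split_d // 1024} KiB")
# D=256: full-D 96 KiB, Split-D 6 KiB
# D=512: full-D 192 KiB, Split-D 6 KiB
# D=1024: full-D 384 KiB, Split-D 6 KiB
```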
Note
FFPA has been tested on Ampere, Ada, Hopper, and Blackwell architectures (e.g., A30, L20, 4090, H200, 5090) and achieves a 1.5~3× speedup over SDPA. FFPA is mainly designed for prefill and large headdim, and may not be faster than SDPA for short sequences (N < 512) or small headdim (D <= 256).
Benchmark
Runnable benchmarks are provided under bench, together with performance results for large headdims on the NVIDIA L20 (Ada), NVIDIA GeForce RTX 5090 (Blackwell), NVIDIA H800 PCIe (Hopper) and NVIDIA H200 SXM (Hopper, CuTeDSL backend, up to 427 TFLOPS).
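When reading or reproducing TFLOPS numbers, a common convention counts 4·B·H·Nq·Nkv·D FLOPs for the attention forward pass (two matmuls at 2 FLOPs per multiply-add), halved for square causal attention. Whether the scripts under bench use exactly this convention is an assumption here, so check them before comparing numbers. The helpers below are hypothetical.

```python
def attn_fwd_flops(b: int, h: int, nq: int, nkv: int, d: int, causal: bool = False) -> int:
    """FLOPs for S = Q K^T and O = P V, counting 2 FLOPs per multiply-add."""
    flops = 2 * (2 * b * h * nq * nkv * d)
    # Convention for square (Nq == Nkv) causal attention: about half the work is masked out.
    return flops // 2 if causal else flops


def tflops(flops: int, ms: float) -> float:
    """Convert a FLOP count and a latency in milliseconds to TFLOPS."""
    return flops / (ms * 1e-3) / 1e12


# Shape of the self-attention example above: B=1, H=32, N=8192, D=512.
f = attn_fwd_flops(1, 32, 8192, 8192, 512)
print(f"{f / 1e12:.2f} TFLOP per forward call")  # 4.40 TFLOP per forward call
print(f"{tflops(f, 10.0):.1f} TFLOPS at a hypothetical 10 ms latency")
```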



Backends
FFPA supports multiple backends for the forward and backward passes: SDPA (baseline), CUDA (forward only), Triton, and CuTeDSL. The CuTeDSL backend is still at an early stage and has some constraints, but it reaches up to 427 TFLOPS on H200. Stay tuned for future updates.
| Backend | Arch | Fwd | Bwd | Headdim | Autotune | Speedup | Recommend |
|---|---|---|---|---|---|---|---|
| SDPA | sm>=75 | β | β | All | β | 1.0x | sm>=75 |
| CUDA | sm>=80 | β | β | 320~1024 | β | 1.5x~3x | sm80~89,120 |
| Triton | sm>=80 | β | β | 320~1024 | β | 1.5x~5x | sm>=80 |
| CuTeDSL | sm>=80 | β | β | 320~1024 | β | 1.5x~2x | sm80~89,120 |
| CuTeDSL | sm90 | β | β | 320~512 | β | 3x~6x | sm90 |
Special thanks to Butterfingrz for contributing the CuTeDSL backend. Awesome work!
To choose a backend for your own scenario, pass a Backend config (SDPABackend, CUDABackend, TritonBackend or CuTeDSLBackend) to ffpa_attn_func via the backend argument, for example:
```python
>>> from ffpa_attn import ffpa_attn_func, CuTeDSLBackend
>>> # CuTeDSL backend, D=512 scenario, fastest on H200
>>> o = ffpa_attn_func(q, k, v, backend=CuTeDSLBackend())
```
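To pick a backend programmatically, one option is a small helper that encodes the Arch, Headdim and Recommend columns of the table above. `pick_backend_name` is hypothetical: its thresholds are copied from the table, not queried from ffpa-attn, and it returns one of the Backend config names listed in the text. Looking the class up with `getattr` assumes all four configs are exported from `ffpa_attn`, as `CuTeDSLBackend` is in the example.

```python
def pick_backend_name(cc_major: int, cc_minor: int, head_dim: int) -> str:
    """Hypothetical helper: choose a Backend config name from the table above."""
    sm = cc_major * 10 + cc_minor
    if sm == 90 and 320 <= head_dim <= 512:
        return "CuTeDSLBackend"  # sm90 row: headdim 320~512, fastest on H200 per the text
    if sm >= 80 and 320 <= head_dim <= 1024:
        return "TritonBackend"   # recommended for sm>=80, headdim 320~1024
    return "SDPABackend"         # baseline for everything else


print(pick_backend_name(9, 0, 512))   # CuTeDSLBackend
print(pick_backend_name(8, 9, 1024))  # TritonBackend
print(pick_backend_name(8, 0, 128))   # SDPABackend

# On a GPU machine this could be wired up as follows (not run here):
#   import torch, ffpa_attn
#   name = pick_backend_name(*torch.cuda.get_device_capability(), q.shape[-1])
#   o = ffpa_attn.ffpa_attn_func(q, k, v, backend=getattr(ffpa_attn, name)())
```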
License
Apache License 2.0
Citations
```bibtex
@misc{ffpa-attn@2025,
title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
url={https://github.com/xlite-dev/ffpa-attn.git},
note={Open-source software available at https://github.com/xlite-dev/ffpa-attn.git},
author={DefTruth, Butterfingrz},
year={2025}
}
```