Flash Attention - Fast Memory-Efficient Attention


Best use case

Flash Attention - Fast Memory-Efficient Attention is best used when you need a repeatable AI agent workflow instead of a one-off prompt. Teams using it should expect more consistent output, faster repeated execution, and less prompt rewriting.

When to use this skill

  • You want a reusable workflow that can be run more than once with consistent structure.

When not to use this skill

  • You only need a quick one-off answer and do not need a reusable workflow.
  • You cannot install or maintain the underlying files, dependencies, or repository context.

Installation

Claude Code / Cursor / Codex

curl -o ~/.claude/skills/flash-attention/SKILL.md --create-dirs "https://raw.githubusercontent.com/ComeOnOliver/skillshub/main/skills/Orchestra-Research/AI-Research-SKILLs/flash-attention/SKILL.md"

Manual Installation

  1. Download SKILL.md from GitHub
  2. Place it in .claude/skills/flash-attention/SKILL.md inside your project
  3. Restart your AI agent — it will auto-discover the skill


Frequently Asked Questions

What does this skill do?

It packages workflows for enabling Flash Attention, which speeds up transformer attention roughly 2-4x and cuts its memory use through IO-aware tiling and recomputation, in PyTorch models.

Where can I find the source code?

The SKILL.md source lives in the ComeOnOliver/skillshub repository on GitHub; see the raw URL in the installation command above.

SKILL.md Source

# Flash Attention - Fast Memory-Efficient Attention

## Quick start

Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.

**PyTorch native (easiest, PyTorch 2.2+)**:
```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```

**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
```

```python
from flash_attn import flash_attn_func

# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```
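
To see what "IO-aware tiling and recomputation" buys, here is a minimal illustrative PyTorch sketch of the online-softmax tiling that Flash Attention fuses into a single kernel (forward pass only, no masking; written for clarity, not speed):

```python
import torch

def tiled_attention(q, k, v, block_size=128):
    # Illustrative only: the online-softmax tiling behind Flash Attention,
    # written as a plain PyTorch loop over K/V blocks. The real kernel fuses
    # this into on-chip tiles and recomputes attention in the backward pass.
    # q, k, v: [batch, heads, seq, dim]
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((*q.shape[:-1], 1), float('-inf'),
                         device=q.device, dtype=q.dtype)
    row_sum = torch.zeros((*q.shape[:-1], 1), device=q.device, dtype=q.dtype)
    for start in range(0, k.shape[2], block_size):
        k_blk = k[:, :, start:start + block_size]
        v_blk = v[:, :, start:start + block_size]
        scores = (q @ k_blk.transpose(-2, -1)) * scale      # [b, h, seq, blk]
        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)           # rescale old partials
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max
    return out / row_sum
```

Its output should match `F.scaled_dot_product_attention(q, k, v)` up to floating-point error; the fused kernel differs only in never materializing the full seq×seq score matrix in GPU memory.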

## Common workflows

### Workflow 1: Enable in existing PyTorch model

Copy this checklist:

```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```

**Step 1: Check PyTorch version**

```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```

If <2.2, upgrade:
```bash
pip install --upgrade torch
```

**Step 2: Enable Flash Attention backend**

Replace standard attention:
```python
# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v

# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v)
# Note: prefer is_causal=True over an explicit attn_mask tensor; arbitrary
# masks make SDPA fall back to a non-flash backend.
```

Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```
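
Note that newer PyTorch releases deprecate `torch.backends.cuda.sdp_kernel` in favor of `torch.nn.attention.sdpa_kernel` (added around PyTorch 2.3); a sketch of the equivalent, assuming that version:

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the flash backend only (newer-API equivalent of the
# sdp_kernel block above).
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```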

**Step 3: Verify speedup with profiling**

```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

def test_attention(use_flash):
    q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

    if use_flash:
        # Disable the fallback backends so the flash kernel is actually used
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)  # 8.0 = sqrt(head_dim=64)
        return attn @ v

# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())

print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```

Expected: 2-4x speedup for sequences >512 tokens.

**Step 4: Test accuracy matches baseline**

```python
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)  # 8.0 = sqrt(head_dim=64)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```
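
For an automated pass/fail check, `torch.testing.assert_close` compares with tolerances; the thresholds below are illustrative for float16:

```python
import torch

# Raises AssertionError if the outputs diverge beyond the tolerances.
torch.testing.assert_close(out_flash, out_standard, atol=1e-3, rtol=1e-3)
```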

### Workflow 2: Use flash-attn library for advanced features

For multi-query attention, sliding window, or H100 FP8.

Copy this checklist:

```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```

**Step 1: Install flash-attn library**

```bash
# NVIDIA GPUs (CUDA 11.8+, Ampere or newer)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```

**Step 2: Modify attention code**

```python
from flash_attn import flash_attn_func

# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2)  # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,
    causal=True,  # For autoregressive models
    window_size=(-1, -1),  # No sliding window
    softmax_scale=None  # None → defaults to 1/sqrt(head_dim)
)

out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]
```
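
For model integration, one common pattern is wrapping the call in a module that owns the projections and layout; a hypothetical sketch (the module name, sizes, and defaults are our assumptions, not part of the flash-attn API):

```python
import torch
import torch.nn as nn
from flash_attn import flash_attn_func

class FlashSelfAttention(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int, dropout_p: float = 0.0):
        super().__init__()
        assert hidden_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout_p = dropout_p

    def forward(self, x):  # x: [batch, seq, hidden], float16/bfloat16
        b, s, _ = x.shape
        qkv = self.qkv(x).view(b, s, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each [batch, seq, heads, dim]
        out = flash_attn_func(
            q, k, v,
            dropout_p=self.dropout_p if self.training else 0.0,  # no dropout at eval
            causal=True,
        )
        return self.proj(out.reshape(b, s, -1))
```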

**Step 3: Enable advanced features**

Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func

# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim]  # Fewer KV heads
out = flash_attn_func(q, k, v)  # Automatically handles MQA
```

Sliding window attention (local attention):
```python
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
    q, k, v,
    window_size=(256, 256),  # (left, right) window
    causal=True
)
```

**Step 4: Benchmark performance**

```python
import torch
from flash_attn import flash_attn_func
import time

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup
for _ in range(10):
    _ = flash_attn_func(q, k, v)

# Benchmark
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()  # so the memory report reflects this run only
start = time.time()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()  # wait for all kernels before stopping the clock
end = time.time()

print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
```

### Workflow 3: H100 FP8 optimization (FlashAttention-3)

For maximum performance on H100 GPUs.

```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install FlashAttention-3
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```

**Step 1: Verify H100 GPU**

```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```
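
The same check from Python; Hopper parts report compute capability (9, 0):

```python
import torch

# H100/H800 (Hopper) report compute capability (9, 0).
print("Hopper detected:", torch.cuda.get_device_capability() >= (9, 0))
```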

**Step 2: Install FlashAttention-3**

```bash
# FlashAttention-3 (which provides the FP8 path) is currently distributed
# separately from the flash-attn package: build it from the hopper/
# directory of the flash-attention repository.
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention/hopper
python setup.py install
```

**Step 3: Convert inputs to FP8**

```python
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```
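
A direct cast can overflow e4m3's representable range (roughly ±448), so FP8 pipelines usually scale per tensor first; a minimal sketch (the 448.0 bound and the helper name are our illustrative assumptions):

```python
import torch

def to_fp8_e4m3(t: torch.Tensor):
    # Scale values into the e4m3 range before casting, and return the scale
    # so the caller can de-scale downstream results.
    amax = t.abs().max().clamp(min=1e-12)
    scale = 448.0 / amax
    return (t * scale).to(torch.float8_e4m3fn), scale

q_fp8, q_scale = to_fp8_e4m3(q)
```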

**Step 4: Run with FP8 attention**

```python
# FlashAttention-3 ships its own interface module from the hopper build in
# Step 2; the flash_attn package used earlier does not run FP8 kernels.
from flash_attn_interface import flash_attn_func

out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Reported result: ~1.2 PFLOPS on H100 with FP8, roughly 1.5-2x faster than FP16
```

## When to use vs alternatives

**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn

**Use alternatives instead:**
- **Standard attention**: Sequences <256 tokens, where Flash Attention's overhead outweighs its savings
- **xFormers**: You need a wider menu of attention variants, not just speed
- **PyTorch's memory-efficient or math backends**: GPUs the flash kernel rejects, or CPU inference (Flash Attention itself requires a supported GPU); see the sketch below
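
If you end up on PyTorch's memory-efficient backend, you can select it explicitly with the same context manager used earlier; a minimal sketch, assuming `q`, `k`, `v` are defined as in the examples above:

```python
import torch
import torch.nn.functional as F

# Disable the flash and math backends so SDPA must use the
# memory-efficient (xFormers-style) kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False,
    enable_math=False,
    enable_mem_efficient=True,
):
    out = F.scaled_dot_product_attention(q, k, v)
```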

## Common issues

**Issue: ImportError: cannot import flash_attn**

Install with no-build-isolation flag:
```bash
pip install flash-attn --no-build-isolation
```

Or install CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```

**Issue: Slower than expected (no speedup)**

Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup

Check sequence length is sufficient.
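
A quick way to check is to time SDPA at a few lengths on your own hardware; the shapes below are illustrative:

```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

# Time scaled_dot_product_attention across sequence lengths to see where
# the flash kernel starts to pay off on this GPU.
for seq_len in (256, 512, 2048, 8192):
    q, k, v = [torch.randn(1, 8, seq_len, 64, device='cuda', dtype=torch.float16)
               for _ in range(3)]
    t = benchmark.Timer(
        stmt='F.scaled_dot_product_attention(q, k, v)',
        globals={'F': F, 'q': q, 'k': k, 'v': v},
    )
    print(f"seq={seq_len}: {t.timeit(50).mean * 1e3:.2f} ms")
```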

**Issue: RuntimeError: CUDA error**

Verify GPU supports Flash Attention:
```python
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(8, 0) for Ampere or newer
```

Flash Attention requires:
- Hopper (H100, H800): ✅ Full support (FlashAttention-3)
- Ampere / Ada (A100, A10, RTX 3090/4090): ✅ Full support
- Turing (T4) / Volta (V100): ❌ Not supported by FlashAttention-2

**Issue: Accuracy degradation**

Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16)  # Or torch.bfloat16
```

Flash Attention kernels run only in float16/bfloat16: flash_attn_func raises an error on float32 inputs, while PyTorch's SDPA silently falls back to a slower backend.
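
If the surrounding model runs in float32, `torch.autocast` can down-cast just the attention inputs (SDPA participates in autocast); a minimal sketch:

```python
import torch
import torch.nn.functional as F

q, k, v = [torch.randn(1, 8, 512, 64, device='cuda') for _ in range(3)]  # float32

# Inside autocast, SDPA runs in float16 and can hit the flash kernel.
with torch.autocast(device_type='cuda', dtype=torch.float16):
    out = F.scaled_dot_product_attention(q, k, v)
```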

## Advanced topics

**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.

**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.

**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.

**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.

## Hardware requirements

- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
- **CUDA**: 11.8 minimum (12.x recommended)
- **PyTorch**: 2.2+ for native support

**Not supported**: Turing (T4), Volta (V100), CPU inference
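
A small runtime guard matching these requirements; the helper name is ours, and the threshold assumes Ampere's compute capability 8.0 (NVIDIA only, ROCm detection not covered):

```python
import torch

def flash_attention_supported() -> bool:
    # Heuristic: Flash Attention needs CUDA plus an Ampere-or-newer GPU
    # (compute capability >= 8.0).
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 8
```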

## Resources

- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
