⚡ Flash Attentionで注意機構を2〜4倍速
メモリ10〜20倍削減・速度2〜4倍のFlash Attention 最適化Skill。
📺 まず動画で見る(YouTube)
▶ 【最新版】Claude(クロード)完全解説!20以上の便利機能をこの動画1本で全て解説 ↗
※ jpskill.com 編集部が参考用に選んだ動画です。動画の内容と Skill の挙動は厳密には一致しないことがあります。
📜 元の英語説明(参考)
Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
🇯🇵 日本人クリエイター向け解説
メモリ10〜20倍削減・速度2〜4倍のFlash Attention 最適化Skill。
※ jpskill.com 編集部が日本のビジネス現場向けに補足した解説です。Skill本体の挙動とは独立した参考情報です。
下記のコマンドをコピーしてターミナル(Mac/Linux)または PowerShell(Windows)に貼り付けてください。 ダウンロード → 解凍 → 配置まで全自動。
mkdir -p ~/.claude/skills && cd ~/.claude/skills && curl -L -o optimizing-attention-flash.zip https://jpskill.com/download/124.zip && unzip -o optimizing-attention-flash.zip && rm optimizing-attention-flash.zip
$d = "$env:USERPROFILE\.claude\skills"; ni -Force -ItemType Directory $d | Out-Null; iwr https://jpskill.com/download/124.zip -OutFile "$d\optimizing-attention-flash.zip"; Expand-Archive "$d\optimizing-attention-flash.zip" -DestinationPath $d -Force; ri "$d\optimizing-attention-flash.zip"
完了後、Claude Code を再起動 → 普通に「動画プロンプト作って」のように話しかけるだけで自動発動します。
💾 手動でダウンロードしたい(コマンドが難しい人向け)
- 1. 下の青いボタンを押して
optimizing-attention-flash.zipをダウンロード - 2. ZIPファイルをダブルクリックで解凍 →
optimizing-attention-flashフォルダができる - 3. そのフォルダを
C:\Users\あなたの名前\.claude\skills\(Win)または~/.claude/skills/(Mac)へ移動 - 4. Claude Code を再起動
⚠️ ダウンロード・利用は自己責任でお願いします。当サイトは内容・動作・安全性について責任を負いません。
🎯 このSkillでできること
下記の説明文を読むと、このSkillがあなたに何をしてくれるかが分かります。Claudeにこの分野の依頼をすると、自動で発動します。
📦 インストール方法 (3ステップ)
- 1. 上の「ダウンロード」ボタンを押して .skill ファイルを取得
- 2. ファイル名の拡張子を .skill から .zip に変えて展開(macは自動展開可)
- 3. 展開してできたフォルダを、ホームフォルダの
.claude/skills/に置く- · macOS / Linux:
~/.claude/skills/ - · Windows:
%USERPROFILE%\.claude\skills\
- · macOS / Linux:
Claude Code を再起動すれば完了。「このSkillを使って…」と話しかけなくても、関連する依頼で自動的に呼び出されます。
詳しい使い方ガイドを見る →- 最終更新
- 2026-05-17
- 取得日時
- 2026-05-17
- 同梱ファイル
- 3
💬 こう話しかけるだけ — サンプルプロンプト
- › Flash Attentionで注意機構を2〜4倍速 を使って、最小構成のサンプルコードを示して
- › Flash Attentionで注意機構を2〜4倍速 の主な使い方と注意点を教えて
- › Flash Attentionで注意機構を2〜4倍速 を既存プロジェクトに組み込む方法を教えて
これをClaude Code に貼るだけで、このSkillが自動発動します。
📖 Skill本文(日本語訳)
※ 原文(英語/中国語)を Gemini で日本語化したものです。Claude 自身は原文を読みます。誤訳がある場合は原文をご確認ください。
Flash Attention - 高速でメモリ効率の良いAttention
クイックスタート
Flash Attentionは、IOを考慮したタイリングと再計算により、TransformerのAttentionにおいて2〜4倍の高速化と10〜20倍のメモリ削減を実現します。
PyTorchネイティブ (最も簡単、PyTorch 2.2以降):
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
flash-attnライブラリ (より多くの機能):
pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
一般的なワークフロー
ワークフロー1: 既存のPyTorchモデルで有効にする
このチェックリストをコピーしてください:
Flash Attention統合:
- [ ] ステップ1: PyTorchのバージョンを確認する (≥2.2)
- [ ] ステップ2: Flash Attentionバックエンドを有効にする
- [ ] ステップ3: プロファイリングで高速化を確認する
- [ ] ステップ4: 精度がベースラインと一致するかテストする
ステップ1: PyTorchのバージョンを確認する
python -c "import torch; print(torch.__version__)"
# 2.2.0以上である必要があります
2.2未満の場合は、アップグレードしてください:
pip install --upgrade torch
ステップ2: Flash Attentionバックエンドを有効にする
標準のAttentionを置き換えます:
# Before (標準のAttention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v
# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
Flash Attentionバックエンドを強制的に有効にします:
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v)
ステップ3: プロファイリングで高速化を確認する
import torch.utils.benchmark as benchmark
def test_attention(use_flash):
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
if use_flash:
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(q, k, v)
else:
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
return attn @ v
# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
期待される結果: シーケンス長が512トークンを超える場合、2〜4倍の高速化。
ステップ4: 精度がベースラインと一致するかテストする
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype='torch.float16') for _ in range(3)]
# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)
# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v
# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# float16の場合、1e-3未満である必要があります
ワークフロー2: 高度な機能のためにflash-attnライブラリを使用する
マルチクエリAttention、スライディングウィンドウ、またはH100 FP8の場合。
このチェックリストをコピーしてください:
flash-attnライブラリのセットアップ:
- [ ] ステップ1: flash-attnライブラリをインストールする
- [ ] ステップ2: Attentionコードを修正する
- [ ] ステップ3: 高度な機能を有効にする
- [ ] ステップ4: パフォーマンスをベンチマークする
ステップ1: flash-attnライブラリをインストールする
# NVIDIA GPU (CUDA 12.0以降)
pip install flash-attn --no-build-isolation
# インストールを確認する
python -c "from flash_attn import flash_attn_func; print('Success')"
ステップ2: Attentionコードを修正する
from flash_attn import flash_attn_func
# Input: [batch_size, seq_len, num_heads, head_dim]
# 必要に応じて [batch, heads, seq, dim] から転置する
q = q.transpose(1, 2) # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(
q, k, v,
dropout_p=0.1,
causal=True, # 自己回帰モデルの場合
window_size=(-1, -1), # スライディングウィンドウなし
softmax_scale=None # 自動スケーリング
)
out = out.transpose(1, 2) # [batch, heads, seq, dim] に戻す
ステップ3: 高度な機能を有効にする
マルチクエリAttention (ヘッド間でK/Vを共有):
from flash_attn import flash_attn_func
# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim] # KVヘッド数が少ない
out = flash_attn_func(q, k, v) # MQAを自動的に処理します
スライディングウィンドウAttention (ローカルAttention):
# 前後256トークンのウィンドウのみにAttentionを適用する
out = flash_attn_func(
q, k, v,
window_size=(256, 256), # (左, 右) ウィンドウ
causal=True
)
ステップ4: パフォーマンスをベンチマークする
import torch
from flash_attn import flash_attn_func
import time
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# ウォームアップ
for _ in range(10):
_ = flash_attn_func(q, k, v)
# ベンチマーク
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
ワークフロー3: H100 FP8最適化 (FlashAttention-3)
H100 GPUで最高のパフォーマンスを実現します。
FP8セットアップ:
- [ ] ステップ1: H100 GPUが利用可能であることを確認する
- [ ] ステップ2: FP8サポート付きのflash-attnをインストールする
- [ ] ステップ3: 入力をFP8に変換する
- [ ] ステップ4: FP8 Attentionで実行する
ステップ1: H100 GPUを確認する
nvidia-smi --query-gpu=name --format=csv
# "H100"または"H800"と表示されるはずです
ステップ2: FP8サポート付きのflash-attnをインストールする
pip install flash-attn --no-build-isolation
# H100にはFP8サポートが含まれています
ステップ3: 入力をFP8に変換する
import torch
q = torch.randn(2, 4096, 32, 64 📜 原文 SKILL.md(Claudeが読む英語/中国語)を展開
Flash Attention - Fast Memory-Efficient Attention
Quick start
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
PyTorch native (easiest, PyTorch 2.2+):
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
flash-attn library (more features):
pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
Common workflows
Workflow 1: Enable in existing PyTorch model
Copy this checklist:
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
Step 1: Check PyTorch version
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
If <2.2, upgrade:
pip install --upgrade torch
Step 2: Enable Flash Attention backend
Replace standard attention:
# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v
# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
Force Flash Attention backend:
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v)
Step 3: Verify speedup with profiling
import torch.utils.benchmark as benchmark
def test_attention(use_flash):
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
if use_flash:
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(q, k, v)
else:
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
return attn @ v
# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
Expected: 2-4x speedup for sequences >512 tokens.
Step 4: Test accuracy matches baseline
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)
# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v
# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
Workflow 2: Use flash-attn library for advanced features
For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
Step 1: Install flash-attn library
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation
# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
Step 2: Modify attention code
from flash_attn import flash_attn_func
# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2) # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(
q, k, v,
dropout_p=0.1,
causal=True, # For autoregressive models
window_size=(-1, -1), # No sliding window
softmax_scale=None # Auto-scale
)
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
Step 3: Enable advanced features
Multi-query attention (shared K/V across heads):
from flash_attn import flash_attn_func
# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
out = flash_attn_func(q, k, v) # Automatically handles MQA
Sliding window attention (local attention):
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
q, k, v,
window_size=(256, 256), # (left, right) window
causal=True
)
Step 4: Benchmark performance
import torch
from flash_attn import flash_attn_func
import time
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Warmup
for _ in range(10):
_ = flash_attn_func(q, k, v)
# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
Workflow 3: H100 FP8 optimization (FlashAttention-3)
For maximum performance on H100 GPUs.
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
Step 1: Verify H100 GPU
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
Step 2: Install flash-attn with FP8 support
pip install flash-attn --no-build-isolation
# FP8 support included for H100
Step 3: Convert inputs to FP8
import torch
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
Step 4: Run with FP8 attention
from flash_attn import flash_attn_func
# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
When to use vs alternatives
Use Flash Attention when:
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn
Use alternatives instead:
- Standard attention: Sequences <256 tokens (overhead not worth it)
- xFormers: Need more attention variants (not just speed)
- Memory-efficient attention: CPU inference (Flash Attention needs GPU)
Common issues
Issue: ImportError: cannot import flash_attn
Install with no-build-isolation flag:
pip install flash-attn --no-build-isolation
Or install CUDA toolkit first:
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
Issue: Slower than expected (no speedup)
Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
-
2K tokens: 3-4x speedup
Check sequence length is sufficient.
Issue: RuntimeError: CUDA error
Verify GPU supports Flash Attention:
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported
Issue: Accuracy degradation
Check dtype is float16 or bfloat16 (not float32):
q = q.to(torch.float16) # Or torch.bfloat16
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
Advanced topics
Integration with HuggingFace Transformers: See references/transformers-integration.md for enabling Flash Attention in BERT, GPT, Llama models.
Performance benchmarks: See references/benchmarks.md for detailed speed and memory comparisons across GPUs and sequence lengths.
Algorithm details: See references/algorithm.md for tiling strategy, recomputation, and IO complexity analysis.
Advanced features: See references/advanced-features.md for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
Hardware requirements
- GPU: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- VRAM: Same as standard attention (Flash Attention doesn't increase memory)
- CUDA: 12.0+ (11.8 minimum)
- PyTorch: 2.2+ for native support
Not supported: V100 (Volta), CPU inference
Resources
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
同梱ファイル
※ ZIPに含まれるファイル一覧。`SKILL.md` 本体に加え、参考資料・サンプル・スクリプトが入っている場合があります。
- 📄 SKILL.md (10,201 bytes)
- 📎 references/benchmarks.md (7,129 bytes)
- 📎 references/transformers-integration.md (7,427 bytes)