FlashAttention-3 — FlashAttention-3
FlashAttention-3は、トランスフォーマーニューラルネットワークにおけるアテンション機構(attention)を最適化するためのアルゴリズムであり、特にNVIDIA Hopperアーキテクチャ(H100)GPUのハードウェア能力を最大限に活用するために設計されました[1]。このアルゴリズムは2024年に、Colfax Research、Meta、NVIDIA、ジョージア工科大学、プリンストン大学、およびTogether AIの研究者グループによって発表されました。この研究はNeurIPS 2024カンファレンスに採択され、spotlight論文として注目されました[2]。
FlashAttention-3は、FlashAttention(2022年)およびFlashAttention-2(2023年)に続くアルゴリズムファミリーの第3世代です。その主な目的は、計算の精度を維持しながら、大規模言語モデル(LLM)の学習と推論を大幅に高速化することです。
導入と背景
アテンション機構の課題
トランスフォーマーの重要な構成要素は自己アテンション機構(self-attention)ですが、その計算量とメモリ消費量は入力シーケンス長(n)の増加に伴い二乗(O(n²))で増大します[1]。これは深刻な「ボトルネック」を生み出します。なぜなら、現代のGPUは高速な行列乗算に最適化されていますが、指数関数(例:Softmax)の計算は桁違いに遅いためです。さらに、ナイーブな実装では、GPUメモリに巨大な中間アテンションテンソルを保持する必要があり、モデルのスケーラビリティが制限されます。
FlashAttentionとFlashAttention-2
この問題を解決するために、2022年にFlashAttentionが提案されました。これは、2つの技術によって低速なグローバルメモリ(HBM)へのアクセス量を削減しました。
- ブロック処理(タイリング): 計算をブロック(タイル)に分割し、高速なオンチップメモリ(SRAM)で処理します。
- 演算の融合: すべての演算(行列乗算、Softmax)は、中間結果をグローバルメモリに書き込むことなく、単一のGPUカーネルで実行されます。
これにより、メモリ計算量は二次から線形に削減され、計算は2~4倍高速化されました。
2023年には、計算の並列化を最適化した改良版であるFlashAttention-2が発表されました。NVIDIA Ampereアーキテクチャ(A100)のGPUでは、理論上のピーク性能の約70%を達成しました[3]。しかし、より新しいNVIDIA Hopperアーキテクチャ(H100)では、その効率は約35%とはるかに低くなりました[1]。これは、アルゴリズムがHopperの新しいハードウェア機能を活用していなかったためであり、これがFlashAttention-3開発のきっかけとなりました。
Hopper GPU (H100) の新しいハードウェア機能
NVIDIA Hopperアーキテクチャは、FlashAttention-3が最大限のパフォーマンスを達成するために活用する、いくつかの新機能を提供しました[4]。
- WGMMA (Warpgroup Matrix Multiply-Accumulate): テンソルコア向けの新しい命令タイプで、Ampereアーキテクチャと比較してほぼ2倍の性能向上で行列乗算を実行します。
- TMA (Tensor Memory Accelerator): グローバルメモリ(HBM)と共有メモリ(shared memory)間のデータ転送を高速化するハードウェアモジュールです。TMAはアドレス計算を自動的に実行し、計算コアの負荷を軽減します。
- FP8フォーマット: 8ビット浮動小数点データフォーマットのハードウェアサポートです。FP16と比較して理論上の性能を2倍にしますが、動的範囲が限られているため精度が低下するリスクがあります。
FlashAttention-3の技術的革新
このアルゴリズムは、Hopperアーキテクチャのために特別に設計された3つの主要な最適化手法を実装しています[4]。
1. 非同期実行とワープの専門化
FlashAttention-3はwarp-specialization(ワープの専門化)の原則を採用しており、GPU上の異なるスレッドグループ(warps)がそれぞれ異なるタスクに特化します。
- Producer warps: TMAを使用してグローバルメモリからデータをロードします。
- Consumer warps: テンソルコアで行列乗算を実行します。
Hopperのハードウェア非同期性により、これらの操作は時間的にオーバーラップします。一方のワープグループが計算を実行している間に、もう一方のグループが次のブロックのデータを並行してロードします。この「ピンポン」方式(ping-pong scheduling)で構成されたパイプラインアプローチ(pipeline)により、低速な演算(例:Softmax)の遅延を隠蔽し、GPUのすべての機能モジュールを最大限に活用することができます。
2. メモリアクセスの最小化
このアルゴリズムは、以前のバージョンからtilingの思想を継承しつつ、TMAを積極的に活用して、現在の計算と並行して次のデータブロックを非同期にロードします。低速なHBMから高速なSRAMへのデータ転送は、実質的に主要な計算の「影」で実行されるため、GPUがデータを待ってアイドル状態になる時間が短縮されます。
3. 量子化誤差を低減した低精度(FP8)
FP8への移行は速度を2倍にしますが、量子化により大幅な精度低下を引き起こす可能性があります。これに対処するため、開発者はincoherent processingという手法を導入しました[4]。その本質は次の通りです。
- アテンションを計算する前に、特徴ベクトル(クエリQとキーK)にランダムな直交行列(例:アダマール行列)を乗算します。
- この変換により、異常に大きな値(外れ値)がすべての座標に「拡散」され、その分布が均一化されます。
- その後、FP8への量子化が実行されますが、この時点では誤差がより小さくなっています。
- この変換は直交であるため、行列の効果は乗算時に相殺されるので、最終的なアテンションの結果(QKᵀ)を歪めることはありません。
この技術により、変換なしでFP8を標準的に適用する場合と比較して、FP8でのアテンション計算誤差を約2.6倍削減することができました[4]。
パフォーマンスと意義
これらの技術を適用することで、FlashAttention-3はH100 GPU上で以前のバージョンを大幅に上回る性能を達成しました。
- FlashAttention-2と比較して1.5~2倍の高速化。
- 高いGPU使用率: H100の理論上の最大性能の約75~85%に到達。
- スループット:
- 半精度(FP16/BF16)で最大740~840 TFLOPS。
- 8ビット精度(FP8)を使用した場合、最大1.2~1.3 PFLOPS(ペタフロップス)[2]。
FlashAttention-3の高い効率性は、LLMの開発と応用に直接的な影響を与えます。
- 学習時間の短縮: アテンションが75~100%高速化されることで、数週間から数ヶ月かかることもあるモデルの学習時間が大幅に短縮されます。
- コンテキストウィンドウの拡大: モデルはより長いシーケンス(数十万トークン)を効率的に処理できるようになり、これは長文のドキュメントやコードの分析において重要です[1]。
- リソースの効率的な利用: より少ないGPUで同等のパフォーマンスを達成したり、同じハードウェアでより高い速度を得たりすることが可能になり、モデルのデプロイコストが削減されます。
利用可能性と統合
開発者たちはFlashAttention-3のソースコードをオープンソースライセンスの下でGitHubに公開しています[4]。PyTorchやHugging Face Transformersライブラリなどの主要な深層学習フレームワークへの統合が期待されており、これにより幅広い開発者や研究者がこの技術を利用できるようになります。以前のバージョンはすでに業界のデファクトスタンダードとなっており、FlashAttention-3もこの傾向を引き継ぐ可能性が高いです。
外部リンク
参考文献
- Shah, J. et al. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08608.
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
- Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv:2205.14135.
- Kwon, W. et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. arXiv:2309.06180.
- Ye, Z. et al. (2025). FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving. arXiv:2501.01005.
- Chen, Y. et al. (2023). FlashDecoding++: Faster Large Language Model Inference on GPUs. arXiv:2311.01282.
- Liu, Y. et al. (2024). FastAttention: Extending FlashAttention-2 to NPUs and Low-Resource GPUs. OpenReview: 76NYyOrnfk.
- Dege, P. et al. (2025). FlashMLA-ETAP: Efficient Transpose Attention Pipeline for Accelerating MLA Inference on NVIDIA H20 GPUs. arXiv:2506.01969.
- Wang, G. et al. (2024). FlashMask: Efficient and Rich Mask Extension of FlashAttention. arXiv:2410.01359.
- Abbott, V.; Zardini, G. (2025). FlashAttention on a Napkin: A Diagrammatic Approach to Deep Learning IO-Awareness. arXiv:2412.03317.
脚注
- ↑ 1.0 1.1 1.2 1.3 "FlashAttention-3 unleashes the power of H100 GPUs for LLMs". VentureBeat. [1]
- ↑ 2.0 2.1 Shah, Jay, et al. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision". OpenReview. [2]
- ↑ Shah, Jay, et al. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision". arXiv:2407.08608v2 [cs.LG], 15 July 2024. [3]
- ↑ 4.0 4.1 4.2 4.3 4.4 Shah, Jay, et al. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision". Together AI Blog. [4]