FlashAttention
FlashAttention — это революционный алгоритм вычисления механизма внимания (attention), разработанный для значительного ускорения обучения и инференса больших языковых моделей (LLM) при сохранении полной точности вычислений. Алгоритм был впервые представлен в 2022 году командой исследователей из Стэнфордского университета под руководством Три Дао (Tri Dao)[1].
Ключевая идея FlashAttention заключается в реорганизации вычислений с учётом иерархии памяти GPU, что позволяет минимизировать количество обращений к медленной памяти и устранить главное узкое место стандартного механизма внимания.
Проблематика стандартного внимания
Стандартный механизм самовнимания в трансформерах вычисляется по формуле: где Q, K, V — это матрицы запросов, ключей и значений.
Основная проблема этого подхода — квадратичная сложность по времени и памяти (O(N²)) относительно длины последовательности N[1]. При наивной реализации необходимо вычислять и хранить в памяти GPU полную матрицу внимания S размером N×N, что приводит к двум критическим проблемам:
- Большое потребление памяти: Хранение матрицы N×N становится невозможным при работе с длинными контекстами.
- Операции ввода-вывода (IO): Главным узким местом является не количество арифметических операций, а постоянные обращения к медленной памяти GPU.
Иерархия памяти GPU
Для понимания проблемы важно различать два типа памяти в GPU (на примере NVIDIA A100):
- SRAM (статическая память): Быстрая внутрикристаллическая память малого объёма (~20 МБ) с огромной пропускной способностью (до 19 ТБ/с).
- HBM (высокопропускная память): Медленная память большого объёма (40–80 ГБ) с гораздо меньшей пропускной способностью (около 1.5 ТБ/с)[2].
Эта асимметрия делает стандартный алгоритм внимания ограниченным пропускной способностью памяти (memory-bound), так как он постоянно читает и записывает большие матрицы из медленной HBM, что и является главным источником задержек.
Ключевые инновации FlashAttention
FlashAttention является IO-осведомлённым (IO-aware) алгоритмом, который решает проблему путём минимизации обращений к HBM. Это достигается за счёт трёх основных техник.
Тайлинг и блочная обработка
Вместо обработки всей матрицы целиком, FlashAttention разбивает входные матрицы Q, K, V на небольшие блоки (тайлы), которые помещаются в быструю SRAM. Алгоритм последовательно загружает эти блоки, выполняет для них все вычисления внимания и обновляет конечный результат, не сохраняя полную матрицу внимания в медленной HBM[1].
Онлайн-вычисление Softmax
Ключевым техническим прорывом стало "онлайн" вычисление Softmax. Стандартный Softmax требует знания всех элементов входного вектора для нормализации. FlashAttention использует модифицированный алгоритм, который позволяет вычислять Softmax по частям. Он поддерживает два промежуточных значения (текущий максимум и сумму экспонент), которые обновляются по мере обработки новых блоков, что позволяет получить точный результат без доступа ко всей матрице сразу[2].
Слияние операций в одно ядро CUDA
Все операции внимания (матричное умножение QKᵀ, маскирование, Softmax, умножение на V) объединены в единое слитое ядро CUDA (fused kernel). Это кардинально сокращает количество операций чтения/записи в HBM: вместо многократных проходов по всей матрице алгоритм загружает блок в SRAM один раз, выполняет все вычисления и записывает только конечный результат.
Теоретическая и практическая эффективность
Сложность и оптимальность
FlashAttention снижает потребление памяти с O(N²) до O(N), что обеспечивает линейное масштабирование. Было доказано, что IO-сложность алгоритма является теоретически оптимальной для вычисления внимания в двухуровневой иерархии памяти, то есть быстрее выполнить точное внимание невозможно без изменения аппаратной части[3].
Эмпирические результаты
Первая версия FlashAttention продемонстрировала значительные улучшения:
- Ускорение:
- Экономия памяти: До 20-кратной экономии памяти по сравнению с точными базовыми реализациями.
- Улучшение качества моделей: Благодаря возможности работать с более длинными контекстами, FlashAttention не только не теряет, но и улучшает качество моделей. Например, перплексия GPT-2 улучшилась на 0.7 пункта, а точность в задачах классификации длинных документов выросла на 6.4 пункта[1].
Эволюция и дальнейшие разработки
Успех FlashAttention положил начало целой серии аппаратно-ориентированных алгоритмов.
FlashAttention-2 (2023)
Вторая версия была нацелена на более полное использование ресурсов GPU. В оригинальном FlashAttention эффективность на NVIDIA A100 составляла лишь 25–40% от максимума. FlashAttention-2 ввела улучшения в параллелизации вычислений, что позволило[4]:
- Достичь двукратного ускорения по сравнению с первой версией.
- Увеличить утилизацию GPU до 50–73% от теоретического максимума.
- Расширить поддержку до голов внимания размером 256, а также для архитектур Multi-Query Attention (MQA).
FlashAttention-3 (2024)
Третья версия была оптимизирована специально для архитектуры GPU NVIDIA Hopper (H100)[5]. Она использует новые аппаратные возможности, такие как асинхронность Tensor Cores и поддержку FP8, что позволило:
- Достичь ещё 1.5–2-кратного ускорения по сравнению с FlashAttention-2.
- Достичь производительности до 740 TFLOPS на FP16 и близко к 1.2 PFLOPS на FP8.
Специализированные решения
Идеи FlashAttention были развиты в других проектах:
- FlashInfer (2025): Настраиваемый движок внимания, оптимизированный специально для задач инференса LLM. Он фокусируется на эффективной работе с KV-кэшем в режиме потоковой генерации[6].
- FlashMLA (2024): Реализация внимания со сжатием контекстного кэша (latent attention), позволяющая экономить память на очень длинных последовательностях с минимальной потерей информации[7].
Влияние на индустрию и экосистему
FlashAttention стал фундаментальным прорывом и быстро превратился в стандарт индустрии для эффективного обучения и инференса LLM. Он был интегрирован в ключевые библиотеки, такие как PyTorch и Hugging Face, и используется в большинстве крупных языковых моделей (LLaMA, MPT, Falcon, Claude и др.).
Именно FlashAttention и его последующие версии сыграли решающую роль в увеличении контекстных окон языковых моделей: с 2–4 тыс. токенов (GPT-3) до 128 тыс. токенов (GPT-4) и даже до миллионов токенов в экспериментальных моделях[8]. Алгоритм устранил одно из главных препятствий на пути масштабирования трансформеров, открыв новые возможности для приложений ИИ, от анализа длинных документов до мультимодального понимания.
Ссылки
Литература
- Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv:2205.14135.
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
- Shah, J. et al. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision. arXiv:2407.08608.
- Kwon, W. et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. arXiv:2309.06180.
- Hong, K. et al. (2023). FlashDecoding++: Faster Large Language Model Inference on GPUs. arXiv:2311.01282.
- Ye, Z. et al. (2025). FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving. arXiv:2501.01005.
- Dege, P. et al. (2025). FlashMLA-ETAP: Efficient Transpose Attention Pipeline for Accelerating MLA Inference on NVIDIA H20 GPUs. arXiv:2506.01969.
- Wang, G. et al. (2025). FlashMask: Efficient and Rich Mask Extension of FlashAttention. OpenReview wUtXB43Chi.
- Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (OpenReview version). OpenReview H4DqfPSibmx.
- Gholami, A. et al. (2024). FlashAttention on a Napkin: A Diagrammatic Approach to Deep Learning IO-Awareness. OpenReview pF2ukh7HxA.
Примечания
- ↑ 1,0 1,1 1,2 1,3 1,4 Дао, Три, и др. «FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness». arXiv:2205.14135 [cs.LG], 28 мая 2022 г. [1]
- ↑ 2,0 2,1 Дао, Три, и др. «FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness». OpenReview. [2]
- ↑ «We're Training AI Twice as Fast This Year as Last». IEEE Spectrum. [3]
- ↑ Дао, Три. «FlashAttention-2». tridao.me. [4]
- ↑ «FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision». PyTorch Blog. [5]
- ↑ «[2501.01005] FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving». arXiv. [6]
- ↑ «GitHub - deepseek-ai/FlashMLA: FlashMLA: Efficient MLA decoding kernels». GitHub. [7]
- ↑ «The Evolution of Flash Attention: Revolutionizing Transformer Efficiency». Medium. [8]