FlashAttention-2 (ES)
FlashAttention-2 es un algoritmo avanzado diseñado para calcular el mecanismo de atención (attention) en grandes modelos de lenguaje (LLM). El algoritmo fue desarrollado por Tri Dao e investigadores de la Universidad de Stanford y presentado en julio de 2023[1]. Su objetivo clave es acelerar significativamente el entrenamiento y la inferencia de los modelos transformer mediante un uso más eficiente de los recursos de hardware de la GPU, preservando al mismo tiempo la identidad computacional completa con el mecanismo de atención estándar, es decir, sin pérdida de precisión.
FlashAttention-2 es la continuación lógica del algoritmo FlashAttention, presentado por el mismo equipo en 2022. La nueva versión resuelve el problema de la utilización incompleta de la GPU que se observaba en su predecesor y logra un aumento de velocidad de casi el doble en comparación con la primera versión.
Antecedentes: el problema de la atención en los transformers
El mecanismo estándar de autoatención (self-attention) es un cuello de botella al trabajar con secuencias de texto largas en los transformers. Su complejidad computacional y consumo de memoria crecen de forma cuadrática (O(N²)) en función de la longitud de la secuencia (N), lo que impone serias limitaciones a la longitud máxima del contexto y a la escalabilidad de los LLM[1].
Para resolver este problema, en 2022 se presentó el algoritmo FlashAttention[2]. Sus ideas clave son:
- Conciencia de la jerarquía de memoria de la GPU (IO-awareness): El algoritmo minimiza las costosas operaciones de lectura/escritura entre la memoria lenta de la GPU (HBM) y la memoria estática rápida (SRAM) en el chip.
- Procesamiento en bloques (tiling): Los cálculos se dividen en pequeños bloques (tiles), que se procesan en la SRAM rápida, lo que permite evitar la materialización de la matriz de atención completa en la memoria.
Esto permitió alcanzar un crecimiento lineal del consumo de memoria (O(N)) y una aceleración de 2 a 4 veces en comparación con las implementaciones estándar[2]. FlashAttention se popularizó ampliamente y contribuyó a la aparición de modelos con un contexto significativamente mayor, por ejemplo, de 2-4 mil tokens (GPT-3) a 128 mil (GPT-4) y más[3]. Así, en el modelo Falcon-40B, el uso de FlashAttention aceleró la inferencia 3 veces y el rendimiento general de la generación 5 veces en comparación con GPT-3[4].
Desarrollo y objetivos de FlashAttention-2
A pesar de su éxito, la primera versión de FlashAttention no utilizaba completamente los recursos computacionales de la GPU. En las tarjetas de video NVIDIA A100, el rendimiento solo alcanzaba el 25-40% del máximo teórico (FLOPs/s)[1]. La razón principal era la carga subóptima de los multiprocesadores de streaming (Streaming Multiprocessors) y las operaciones redundantes con la memoria compartida[5].
El objetivo de FlashAttention-2 fue acelerar aún más los cálculos mediante una paralelización más eficiente del trabajo y la minimización de operaciones auxiliares. El algoritmo fue completamente reescrito utilizando primitivas de bajo nivel de la biblioteca NVIDIA CUTLASS 3.x para lograr el máximo rendimiento[6].
Arquitectura técnica y principios de funcionamiento
FlashAttention-2 introduce tres mejoras clave para aumentar el paralelismo y la eficiencia[1]:
1. Minimización de operaciones no matriciales
El algoritmo reduce la cantidad de operaciones auxiliares de punto flotante que no son multiplicaciones de matrices (non-matmul FLOPs). Dado que los núcleos tensoriales de la GPU están optimizados específicamente para operaciones matriciales (GEMM) y las ejecutan hasta 16 veces más rápido, este cambio permite utilizar la mayor parte del tiempo los bloques más productivos de la GPU.
2. Paralelismo mejorado
En el FlashAttention original, el trabajo sobre una sola "cabeza" de atención no se paralelizó, lo que provocaba tiempos de inactividad con secuencias largas y tamaños de lote pequeños. FlashAttention-2 introduce el paralelismo entre bloques: ahora los cálculos para una sola cabeza de atención se distribuyen entre diferentes multiprocesadores de streaming de la GPU, lo que aumenta significativamente su carga.
3. División optimizada del trabajo dentro de un bloque
A nivel de un solo bloque computacional, el trabajo fue redistribuido entre grupos de hilos (warps) para reducir el intercambio de datos a través de la memoria compartida (shared memory). Esto reduce la cantidad de operaciones de lectura/escritura redundantes necesarias para la normalización de Softmax.
Rendimiento y eficiencia
Gracias a las mejoras arquitectónicas, FlashAttention-2 demuestra un aumento significativo del rendimiento:
- Aceleración doble: El algoritmo funciona aproximadamente 2 veces más rápido en comparación con la primera versión de FlashAttention[1].
- Alta utilización de la GPU: En la GPU NVIDIA A100 se alcanza el 50-73% del rendimiento máximo teórico (TFLOPs), lo que se acerca a la eficiencia de las operaciones optimizadas de multiplicación de matrices (GEMM)[1].
- Velocidad de cálculo récord:
- En la GPU A100 se alcanza una velocidad de hasta 225 TFLOP/s en el ciclo de entrenamiento de extremo a extremo de un modelo tipo GPT, lo que corresponde a un 72% de utilización de los bloques computacionales. En comparación, la atención estándar en las mismas condiciones cargaba la GPU a menos de 100 TFLOP/s[7].
- En la GPU H100, el rendimiento alcanza los 335 TFLOP/s[7].
Este aumento de rendimiento permite, por ejemplo, entrenar un modelo con una ventana de contexto de 16k tokens en el mismo tiempo que antes se requería para una ventana de 8k tokens[5]. Es importante destacar que el algoritmo sigue siendo exacto y determinista, por lo que su aplicación no afecta la calidad de las predicciones del modelo[8].
Aplicación e integración en el ecosistema
FlashAttention-2 se convirtió rápidamente en una herramienta estándar en el ecosistema de los LLM. Está integrado en muchos frameworks y bibliotecas populares:
- PyTorch: Soporte nativo.
- Hugging Face Transformers: El soporte se activa con el parámetro `attn_implementation="flash_attention_2"` al cargar el modelo[9]. Es compatible con docenas de arquitecturas (GPT, Llama, Falcon, BERT, etc.)[10].
- TensorRT-LLM, xFormers y Triton: El algoritmo está implementado para estas plataformas, lo que asegura una amplia aplicación[7].
La integración permite combinar fácilmente FlashAttention-2 con otros métodos de optimización, como la cuantización (GPTQ, QLoRA) y el ajuste fino eficiente (PEFT)[9].
Comparación con versiones posteriores
FlashAttention-3
La investigación en el campo de la optimización de la atención continúa. En julio de 2024, Tri Dao presentó FlashAttention-3, orientado a aprovechar las capacidades de la arquitectura de GPU NVIDIA Hopper (H100/H200). Las innovaciones clave son[3]:
- Soporte para FP8: Utiliza cálculos de punto flotante de 8 bits para una mayor aceleración.
- Operaciones asíncronas: Utiliza de manera más eficiente las capacidades asíncronas de la GPU.
FlashAttention-3 proporciona una aceleración de 1.5 a 2 veces en comparación con FlashAttention-2 en la GPU H100, alcanzando un rendimiento de hasta 740 TFLOP/s (75% del máximo teórico)[11].
Literatura
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
- Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv:2205.14135.
- Shah, J. et al. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08608.
- Kwon, W. et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. arXiv:2309.06180.
- Ye, Z. et al. (2025). FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving. arXiv:2501.01005.
- Chen, Y. et al. (2023). FlashDecoding++: Faster Large Language Model Inference on GPUs. arXiv:2311.01282.
- Liu, Y. et al. (2024). FastAttention: Extending FlashAttention-2 to NPUs and Low-Resource GPUs. OpenReview: 76NYyOrnfk.
- Dege, P. et al. (2025). FlashMLA-ETAP: Efficient Transpose Attention Pipeline for Accelerating MLA Inference on NVIDIA H20 GPUs. arXiv:2506.01969.
- Wang, G. et al. (2024). FlashMask: Efficient and Rich Mask Extension of FlashAttention. OpenReview: rog0J435OO.
- Abbott, V.; Zardini, G. (2025). FlashAttention on a Napkin: A Diagrammatic Approach to Deep Learning IO-Awareness. arXiv:2412.03317.
Notas
- ↑ 1.0 1.1 1.2 1.3 1.4 1.5 Dao, Tri. «FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning». arXiv:2307.08691 [cs.LG], 17 de julio de 2023. [1]
- ↑ 2.0 2.1 «Optimizing LLMs for Speed and Memory». Hugging Face Documentation. [2]
- ↑ 3.0 3.1 Dao, Tri. «FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision». Tri Dao's Blog. [3]
- ↑ «FlashAttention vs FlashAttention-2 - an Analysis». E2E Networks Blog. [4]
- ↑ 5.0 5.1 «FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning». OpenReview. [5]
- ↑ «FlashAttention-2». Hazy Research, Stanford University. [6]
- ↑ 7.0 7.1 7.2 Dao, Tri. «FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning» (PDF). arXiv:2307.08691. [7]
- ↑ Raschka, Sebastian. «Llama 2 and FlashAttention 2». Ahead of AI Magazine. [8]
- ↑ 9.0 9.1 Belkada, Younes. «Faster and more memory efficient models with Flash Attention 2!». LinkedIn. [9]
- ↑ «GPU inference». Hugging Face Documentation. [10]
- ↑ Dao, Tri, et al. «FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision». arXiv:2407.08608 [cs.LG], 11 de julio de 2024. [11]