LLM - Improving Attention

Posted by MakiNaruto on Thu, Feb 6, 2025

背景: 当输入序列(sequence length)较长时, Transformer的计算过程缓慢且耗费内存,即计算的矩阵会变得很大, 这是因为self-attention的计算时间内存存取复杂度会随着输入序列的增加成二次增长。因此业界提出了几种加速方案.

FlashAttention

Attention标准实现没有考虑到对内存频繁的IO操作, 它基本上将HBM加载/存储操作视为0成本。因此FlashAttention的优化方案是通过“split attention”的方式, 将多个操作融合在一起, 只从HBM加载一次,然后将结果写回来。减少了内存带宽的通信开销,并且采用了高效的GPU实现, 极大地提高了效率。
核心:用分块softmax等价替代传统softmax。
优点:节约HBM,高效利用SRAM,省显存,提速度。

相关内容补充

内存不是一个单一的工件,它在本质上是分层的,一般的规则是:内存越快,越昂贵,容量越小。因此和木桶原理类似, 需要考虑到每个模块的瓶颈。

内存速率图

  • SRAM(Static Random Access Memory)是一种高速缓存内存,通常用于CPU和GPU的缓存层。它具有较低的访问延迟和较高的带宽,但成本较高,容量较小。SRAM的访问速度非常快,适合频繁访问的数据存储和计算操作。

  • HBM(High Bandwidth Memory)是一种高带宽的内存技术,通常用于GPU等高性能计算设备。它通过将多个DRAM芯片垂直堆叠,并使用高速接口与处理器直接连接,提供了极高的带宽和较大的容量。HBM适合存储大量数据,但访问速度相对较慢,适合大规模数据存储和传输。

这就导致了传统的Attention计算过程需要频繁地在SRAM和HBM之间进行数据传输, 这会带来显著的性能瓶颈。FlashAttention通过优化计算流程,减少了这种数据传输的次数,从而提高了整体的计算效率。因此FlashAttention的核心优化点在于减少内存访问次数,充分利用SRAM的高速性能,同时降低对HBM的依赖,从而实现更高效的Attention计算。

传统分块计算过程

例如原本的$QK^{T}$一次计算过程进行拆分, 分别将$Q$和$K^{T}$划分为为$m$,和$n$个小块, 然后依次将$m_{i}$和$n_{i}$小块计算的结果放置到指定的区域. 当然, 这样操作会带来额外的通讯次数的开销, 变成m * n, 但对于存储架构来说, SRAM与HBM的通信速率是非常快的, 在这里的通讯次数开销是可以接受的.

传统分块计算过程

通信过程, 整个过程需要6次通信, 3次写入到SRAM, 3次到HBM中.

  1. 将矩阵 Q 和K从 HBM 分块加载到 SRAM 中
  2. 逐块计算 $S_{ij} = Q_{i}K_{j}^{T}$, 并将每个子矩阵计算得出的 $S_{ij}$ 从SRAM 写入HBM。
  3. 从 HBM 中加载所需的子矩阵 $S_{ij}$ 到 SRAM 中,为后续 softmax计算做准备。
  4. 对每个子矩阵 $S_{ij}$ 计算 softmax,得到$P_{ij}= softmax( S_{ij})$,并将每个子矩阵 $P_{ij}$从 SRAM 写入 HBM。
  5. 将矩阵 P和V从 HBM 分块加载到 SRAM 中。
  6. 将P和V分成较小的块,逐块计算 $O_{ij} =P_{i}V_{j}$,并将每个子矩阵 $O_{ij}$ 从 SRAM 写入 HBM.

FlashAttention的改进

FlashAttention改进了计算过程, 所有计算过程统一在SRAM中计算, 将最终的计算结果返回给HBM, 只进行一次读写. 其过程如下.

FlashAttention对内存读写的改进

online-softmax 分块计算原理

原本softmax 需要先计算出所有的$S_{ij}$, 然后再进行softmax计算, 但是FlashAttention的online-softmax算法, 通过分块计算的方式, 在每个块计算完之后就进行softmax计算, 并且在计算过程中维护一个全局的最大值和指数和, 来保证数值稳定性. 具体来说, 在每个块计算完之后, 会更新全局的最大值和指数和, 以便在下一个块计算时使用. 这样做不仅可以节约内存带宽,还可以提高计算效率. 核心公式如下:

$𝑚_{𝑛𝑒𝑤}=max⁡(𝑚_{𝑜𝑙𝑑}−𝑚_{c})$

$𝑙_{𝑛𝑒𝑤}=𝑙_{𝑜𝑙𝑑} ·𝑒^{(𝑚_{𝑜𝑙𝑑}−𝑚_{𝑛𝑒𝑤} )}+𝑙_{c} ·𝑒^{(𝑚_{c}−𝑚_{𝑛𝑒𝑤} )}$

其中:

  • $𝑚_{𝑐}$: 当前块最大值
  • $𝑚_{𝑜𝑙𝑑}$: 当前全局最大值
  • $𝑙_{𝑐}$: 当前块局部指数和
  • $𝑙_{𝑜𝑙𝑑}$: 当前全局分母 指数和

详细计算原理示例及讲解: 小红书: 图解Flash Attention核心原理

online_softmax_example

不同的Attention结构

多个Head共享使用1组KV,将原来每个Head一个KV,变成1组Head一个KV,来压缩KV的存储。代表方法:GQA,MQA等

MHA, MQA, GQA, MLA

Multi-Head Attention

图1, 每一层的所有Head都独立拥有自己的KQV权重矩阵, 计算时各自使用自己的权重计算.

Multi-Query Attention

图2, 每一层的所有Head,按照数量分组, 一组的成员, 在计算Attention时, KV权重矩阵是共享的, 只有Q权重矩阵是独立的.

例如每个Head中的Q权重矩阵是独立的, 但每组的Q权重共享一组KV权重矩阵, 这样就减少了KV的存储, 当然也会损失一定的性能. 比如原来MHA的Head有8个KQV权重矩阵, 现在进行分组后, 每两个Q权重矩阵共享一组KV权重矩阵, 则现在每个Head8个Q权重矩阵, 但只有4个KV权重矩阵, KV的存储就减少了一半.

Group-Query Attention

图3, 每个Head的Q权重矩阵是独立的, 但所有Head共享一组KV权重矩阵, 这样就进一步减少了KV的存储, 当然也会损失更多的性能. 比如原来MHA的Head有8个KQV权重矩阵, 现在进行分组后, 每个Head8个Q权重矩阵, 但只有1个KV权重矩阵, KV的存储就减少了7/8.

输入$X$在一个Head得到的矩阵为$Q, K, V$, 假设Head数量为4, 分组数量G为2, 现在我们得到的矩阵如下:

$Q$权重矩阵$K$权重矩阵$V$权重矩阵
$Q_1, Q_2, Q_3, Q_4$$K_1, K_1, K_2, K_2$$V_1, V_1, V_2, V_2$

而实际在multi-head attention中, 多个head的 $Q$ 矩阵, 其得到可以由一次计算得出

$Q=\left[Q_1, Q_2, Q_3, Q_4\right]=XW^{Q}, W^{Q} \in R^{d \times\left(d_k \times h\right)}$

其中K的计算会稍微复杂一些, 因为每个Head的K权重矩阵是共享的, 但不同组的K权重矩阵又不共享, 因此需要先计算出每组的K矩阵, 然后再将其复制到对应的Head中. 其中:

$$K_{\downarrow}=\left[K_1, K_2\right]=XW^{K}, W^{K} \in R^{d \times\left(d_k \times n\right)}$$

因此$XW^{K}$权重矩阵的维度是$d \times\left(d_k \times n\right)$, 其中n是分组数量, 每组对应一个K权重矩阵. 然后将每组的K矩阵复制到对应的Head中, 得到最终的K矩阵为:

$$ K=\left[K_1, K_1, K_2, K_2\right]=\left[K_1, K_2\right]\left[\begin{array}{rrrr} I_{d_k} & I_{d_k} & 0 & 0 \\ 0 & 0 & I_{d_k} & I_{d_k} \end{array}\right] $$

整个计算过程表达式可以表示为:

$$ K=XW^{K} \left[\begin{array}{rrrr} I_{d_k} & I_{d_k} & 0 & 0 \\ 0 & 0 & I_{d_k} & I_{d_k} \end{array}\right]\\ =XW^{K}_{\uparrow} $$

$V$的计算过程和K一样, 整理一下, 我们现在得到了:

$$K=XW^{K}_{\uparrow}, V=XW^{V}_{\uparrow}$$

这部分也可以合并起来一次计算, 我们将KV其权重矩阵拼接得到:

$$ XW^{KV} = [K_{\downarrow}, V_{\downarrow}] = \left[K_1, K_2, V_1, V_2\right] W^{KV} , W^{KV} \in R^{d \times\left(d_k n_g+d_v n_g\right)}$$

同理, 再乘以一个拼接的复制矩阵, 得到最终的K和V矩阵:

$$XW_{\uparrow}^{KV} = \left[K_1, K_2, V_1, V_2\right]\left[\begin{array}{rrrr} I_{d_k} & I_{d_k} & 0 & 0 \\ 0 & 0 & I_{d_k} & I_{d_k} \\ I_{d_v} & I_{d_v} & 0 & 0 \\ 0 & 0 & I_{d_v} & I_{d_v} \\ \end{array}\right]\\$$

Multi-Head Latent Attention

MLA由GQA和GQA的基础上发展而来, 其核心思想是将每个Transformer层的KV权重矩阵分解成两个低秩矩阵, 其中一个矩阵是输入序列的线性变换, 另一个矩阵是一个小的可学习参数矩阵. 这样做的好处是可以大幅减少KV权重矩阵的存储需求, 同时保持较好的性能.

$$ 令 C^{KV} = XW^{KV}$$

通过分解得到 $C^{KV} = C^{K}C^{V}$, 其中$C^{K}$是输入序列的线性变换矩阵, $C^{V}$是一个小的可学习参数矩阵. 这样就将原来的KV权重矩阵分解成了两个低秩矩阵, 从而大幅减少了存储需求.

MLA

⭐️推荐观看, 可视化讲解, https://www.bilibili.com/video/BV17QSpBDEHG

Page Attention

由于传统申请连续内存的方式在处理长序列时会导致内存碎片化和性能下降, Page Attention通过将内存划分为固定大小的页面(pages), 并使用虚拟地址空间来管理这些页面的访问和存储, 这种机制允许模型在处理长序列时, 可以更灵活地管理内存资源, 从而提高计算效率和性能.

具体来说, Page Attention通过将输入序列划分为多个页面(pages), 每个页面可以独立地进行Attention计算, 并且通过虚拟地址空间来管理这些页面的访问和存储. 这样做不仅可以减少内存的使用, 还可以提高模型在处理长序列时的效率和性能.

示例及详解 https://zhuanlan.zhihu.com/p/9632325957

参考地址

[1] deepseek-v2
[2] deepseek-v3
[3] deepseek技术解读(1)-彻底理解MLA(Multi-Head Latent Attention)
[4] 知乎: FlashAttention算法详解