显存占用分类
首先先看看模型计算过程中, 哪些过程需要被存储下来.
推理阶段
$$M_{KV} = 2 × L × N_{𝑘𝑣}× D × S × B (bytes)$$$$M_{per\_token} = 2 × L × N_{𝑘𝑣} × D (bytes/token)$$- N_{𝑘𝑣} 为 K V矩阵的数量, 一个head 2 个(K + V各一个)
- D 为每个head的维度
- S 为输入序列长度
- B 为batch size
- L 为层数
训练阶段
存储主要分为两大块:静态显存 + 动态显存
M_total = 模型参数 + 梯度 + 优化器状态 + Activation
Model States 指和模型本身息息相关的,必须存储的内容,具体包括:
- parameters(固定):模型参数
- gradients(固定):梯度
- optimizer states(固定):优化器状态,例如 Adam 优化器的 momentum 和 variance 等。
- activation(固定):激活值。虽然它不是必须存储的,但在训练过程中会额外产生的内容.
下面是一个表格,列出了每个模型状态的精度和占用的字节数:
| Model States | 精度 | Bytes |
|---|---|---|
| Model Weights | FP16 | 2 |
| Gradients | FP16 | 2 |
| Master Weights | FP32 | 4 |
| Adam Optimizer (m/v) | FP32 | 8 |
| Total / 总计 | - | 16 Bytes |
以一个7B的模型为例, 若不做任何优化, 其显存占用可以估算如下:
Total = 模型 + 梯度 + 优化器状态 + activation
= 7B * (2bytes + 2bytes + 12bytes) + activation
= 7B * 16 bytes + activation
= 112 B bytes / (1024 ** 3) + activation
≈ 104GB + activation
Activation
$$ M_{act} ≈ C × L × B × S × H × bytes $$$$ M_{act}≈L⋅B⋅S⋅H⋅(3(QKV)+1(attn_{out})+4(MLP_{expand})+4(MLP_{intermediate} + residual))⋅bytes $$$$ M_{act} ≈ 12⋅L⋅B⋅S⋅H⋅bytes $$📊 显存优化效果对比
无 Checkpoint : C≈10~15 (显存占用极高)
开启 Checkpoint : C≈2~4 (显存降低3~5倍)
FlashAttention : C≈1.5 (显存占用最低)
Checkpoint的核心策略, 用计算换取显存, 反向传播时重算部分激活值,大幅降低峰值.