LLM - GPU显存占用

Posted by MakiNaruto on Wed, Jan 15, 2025

存储分类

首先先看看模型计算过程中, 哪些过程需要被存储下来.
存储主要分为两大块:静态显存 + 动态显存
M_total = 模型参数 + 梯度 + 优化器状态 + Activation

Model States 指和模型本身息息相关的,必须存储的内容,具体包括:

  • parameters(固定):模型参数
  • gradients(固定):梯度
  • optimizer states(固定):优化器状态,例如 Adam 优化器的 momentum 和 variance 等。
  • activation(固定):激活值。虽然它不是必须存储的,但在训练过程中会额外产生的内容.
Model States精度Bytes
Model WeightsFP162
GradientsFP162
Master WeightsFP324
Adam Optimizer (m/v)FP328
Total / 总计-16 Bytes

Activation

M_act ≈ C × L × B × S × H × bytes

📊 显存优化效果对比
无 Checkpoint : C≈10~15 (显存占用极高)
开启 Checkpoint : C≈2~4 (显存降低3~5倍)
FlashAttention : C≈1.5 (显存占用最低)
Checkpoint的核心策略, 用计算换取显存, 反向传播时重算部分激活值,大幅降低峰值.