
好问题这是理解 ZeRO 为什么有效的前提。让我从 Adam 优化器的数学原理讲起。可以看一下—神经网络中优化器的作用Adam 为什么要保存这么多状态先回忆梯度下降的本质参数沿着梯度方向更新。但直接用梯度更新有很多问题——梯度噪声大、不同参数尺度不同、容易陷入局部最优。Adam 通过维护参数的历史梯度统计信息来解决这些问题。Adam 的更新公式对每个参数θ i \theta_iθiAdam 维护三个东西第 1 个FP32 参数副本 θ_i → 因为训练用 FP16但优化器计算需要 FP32 精度否则会溢出/下溢 第 2 个一阶动量 m_i梯度的指数移动平均 m_i β₁ × m_i (1 - β₁) × g_i → 记录梯度大致朝哪个方向平滑噪声 第 3 个二阶动量 v_i梯度平方的指数移动平均 v_i β₂ × v_i (1 - β₂) × g_i² → 记录梯度的变化幅度自动调节学习率更新参数时m̂_i m_i / (1 - β₁ᵗ) ← 偏差修正 v̂_i v_i / (1 - β₂ᵗ) ← 偏差修正 θ_i θ_i - lr × m̂_i / (√v̂_i ε)显存怎么算的每个参数需要保存 4 个数值状态精度每参数字节7B 模型总计FP16 参数模型本身FP162 bytes14 GBFP16 梯度FP162 bytes14 GBFP32 参数副本FP324 bytes28 GB一阶动量 mFP324 bytes28 GB二阶动量 vFP324 bytes28 GB合计16 bytes/参数112 GB其中优化器状态FP32 副本 m v 12 bytes/参数 × 7B 84 GB。这就是 84 GB 的来源3 个 FP32 变量 × 4 bytes × 70 亿参数。为什么需要 FP32 副本训练时模型参数用 FP16省显存、加速计算但 FP16 的范围太小最大值 65504精度只有 ~3 位有效数字。优化器的更新量通常很小比如 1e-5如果用 FP16 做θ - lr × m̂ / √v̂FP16: θ 1.234 lr × m̂ / √v̂ 0.00001 θ_new 1.234 - 0.00001 1.234 ← 更新被截断了完全没变化 FP32: θ 1.234567 lr × m̂ / √v̂ 0.00001 θ_new 1.234557 ← 更新被正确保留所以 Adam 内部必须用 FP32 做计算和保存状态更新完后再截断回 FP16 给模型用。不同优化器的显存对比优化器每参数额外状态额外 bytes/参数7B 模型额外显存SGD无动量无00 GBSGD Momentumm动量4 (FP32)28 GBAdaGradv梯度平方累积4 (FP32)28 GBAdamFP32副本 m v12 (FP32)84 GBAdamW同 Adam12 (FP32)84 GB8-bit AdamFP32副本 m(8bit) v(8bit)411 642 GBAdafactor分解的 m, v行列分解~2~14 GBAdam 是显存消耗最大的主流优化器。这就是为什么大模型训练领域有这么多显存优化技术的原因。显存优化的几种思路1. 8-bit Adambitsandbytes把 m 和 v 从 FP32 量化到 INT81 byte节省 6 bytes/参数原始 Adam: FP32副本(4) m(4) v(4) 12 bytes/参数 8-bit Adam: FP32副本(4) m_INT8(1) v_INT8(1) 6 bytes/参数 → 节省 50%原理动量值 m 和 v 的分布比较集中可以用动态缩放量化到 INT8 而不损失太多精度。2. Adafactor利用 Transformer 的结构特点——参数大多是矩阵m×n把 m 和 v分解成行向量和列向量标准 Adam: v 是 m×n 矩阵 → m×n 个值 Adafactor: v 分解为 row_factor(m) col_factor(n) → mn 个值 例一个 4096×4096 的权重矩阵 Adam: 4096 × 4096 16,777,216 个 v 值 Adafactor: 4096 4096 8,192 个值 → 节省 2000 倍代价精度有损失某些模型收敛变慢。3. 去掉 FP32 副本纯 FP16 优化一些研究尝试完全用 FP16 做优化器计算如 BF16 优化器省掉 4 bytes/参数的 FP32 副本。BF16 因为指数位和 FP32 一样8 bit范围够大可以在某些场景替代 FP32。4. ZeRO 切分本课文内容不改优化器本身而是把状态分散到多卡上。4 卡 ZeRO-1/2/3 时每卡只保存 1/4 的优化器状态 → 84/4 21 GB。一张图总结7B 模型训练显存构成FP16 混合精度 Adam 112 GB 总计 ┌──────────────────────────────────────────────┐ │ │ │ 模型参数 14GB │ 梯度 14GB │ 优化器状态 84GB │ │ (FP16) │ (FP16) │ │ │ ████ │ ████ │ ████████████████████ │ │ │ │ FP32副本 m v │ │ │ │ 28GB 28GB 28GB │ └──────────────────────────────────────────────┘ ↑ ↑ 真正需要的 68% 是冗余的 (不同数据产生 (每张卡都一样) 不同梯度) ZeRO-1 切分优化器状态后4 卡: ┌───────────────────────────────┐ │ 参数 14GB │ 梯度 14GB │ 优化器 21GB │ ← 每卡 49 GB └───────────────────────────────┘ ZeRO-3 全切分后4 卡: ┌────────────────┐ │ 参数 3.5 │ 梯度 3.5 │ 优化器 21GB │ ← 每卡 28 GB不含临时参数 └────────────────┘