当前位置: 首页 > news >正文

大模型显示优化之ZeRO-1/ZeRO-2/ZeRO-3

1. 简介

zero-1、zero-2、zero-3 是deepspeed的配置方法,对应megatron也有相应的方法,Megatron-LM 的实现方式:Distributed Optimizer(分布式优化器)。等效于 ZeRO-1,Megatron 的 Distributed Optimizer 默认行为就是将优化器状态(Optimizer States)均匀地切分并分布在数据并行(DP)组的所有 GPU 上。等效于 ZeRO-2,由于 Megatron 通常结合混合精度训练,它在计算完梯度后,会通过Reduce-Scatter操作直接将梯度同步并切分到各卡上,不再保留全量梯度。这在效果上完全等同于 ZeRO-2。zero-3 将参数也拆分卡来存,但后续实际反向梯度更新时操作时还是需要all-gather,参数显存还是会全量缓存,再一个Megatron针对参数拆分,更多使用的是TP/PP拆分,所以业界megatron架构使用zero-3不多, 所以本文不做重点分析。

Zero架构说的是DP并行域GPU之间

阶段优化对象核心原理效果
ZeRO-1优化器状态 (OS)将优化器状态切分并分布到各个 GPU 上,每个 GPU 只负责更新自己那一块。显存占用降低约为原来的 1/4(以 Adam 为例)。
ZeRO-2OS + 梯度 (G)在 ZeRO-1 基础上,进一步将梯度也进行切分。每个 GPU 只保留对应参数的梯度。进一步降低显存占用,是目前最常用的平衡配置。
ZeRO-3OS + G + 参数 (W)最彻底的切分。模型参数在平时也分布在不同 GPU 上,只有在正向/反向传播需要时才临时同步。显存占用理论上随 GPU 数量线性下降,支持训练超大规模模型。

实际官方Megatron实现中,ZeRO-2 反向不只是对梯度进行切分,还对参数在back阶段进行了小段时间的切分,后面AllGather回收,是一个技术操作。这样好处:

1. 节省显存

2. 避免冗余计算

3. 最后的AllGather可以和后续的layer forward 做overlap

纯 DPZeRO-2
Forward各 rank 用完整 W各 rank 用完整 W(相同)
Backward 后通信AllReduce梯度(每人拿完整梯度)ReduceScatter梯度(每人只拿 1/DP 梯度,显存也只存1/DP属于自己的梯度)
Optimizer step各自完整更新 W(结果一致,冗余计算)各自只更新 W 的 1/DP 段(此更新过程比较复杂)
Step 后无需额外通信(W 天然一致)需要AllGather W恢复完整参数
显存节省梯度 + 优化器状态各节省 1/DP

注意:

AdamW全局grad_norm

路径通信方式时机
标准路径all_reduceon model parallel group(TP × DP)optimizer.step() 内,clip grad 前
PP bypass 路径TP 内all_reduce+ PP 间send/recv逐 stage 累加pre_step 阶段,流水线化减少同步 barrier

AdamW 的step()中确实有一次全局 grad norm 的all_reduce通信,用于计算全局 L2 norm 以确定clip_coeff(梯度裁剪系数)。这是每一步更新都必须做的集合通信,会引入跨所有 model parallel rank 的同步点。

2. 显存与通信量分析

为了让 ZeRO-1 和 ZeRO-2 的区别更加直观,我把之前流程图里的抽象内容,具体化成了4 张 GPU 卡在不同阶段的显存状态

这样你可以像看“快照”一样,清晰地看到每张卡上到底存了什么。

设定:假设模型有4个参数块:[P0, P1, P2, P3]。4 张 GPU 卡训练。FP16训练的模型为例,参数量为

  1. 参数 (Weights):字节。

  2. 梯度 (Grads):字节。

  3. 优化器Adam 状态:

    • FP32 权重副本(为了精度):

    • Momentum(动量):

    • Variance(方差):


场景一:ZeRO-1 (只切分优化器状态)

核心特征:每张卡都有完整的参数完整的梯度,但只负责更新1/4的优化器状态

GPU 卡前向/反向计算时梯度通信后 (All-Reduce)参数更新后
GPU 0参数:[P0, P1, P2, P3]
梯度:[G0, G1, G2, G3]
优化器状态:[O0](只负责P0)
梯度:[G_avg0, G_avg1, G_avg2, G_avg3]
(已同步为平均梯度)
*G_avg0更新O0, 计算出P0_new
然后拼出完整参数[P0_new, P1_new, P2_new....]
GPU 1参数:[P0, P1, P2, P3]
梯度:[G0, G1, G2, G3]
优化器状态:[O1](只负责P1)
梯度:[G_avg0, G_avg1, G_avg2, G_avg3]*G_avg1更新O1, 计算出P1_new
然后拼出完整参数[P0_new, P1_new, P2_new....]

显存占用

  • 。因为每张卡都要存下4份参数 + 4份梯度

  • 冗余度高P0被同时存在了 4 张卡上。

场景二:ZeRO-2 (切分梯度 + 优化器状态)

核心特征:每张卡有完整的参数,但只保留1/4的梯度,并只更新对应的1/4优化器状态。

GPU 卡前向/反向计算时 (初始)梯度通信后 (Reduce-Scatter)参数更新后
GPU 0参数:[P0, P1, P2, P3]
梯度(原始):[G0, G1, G2, G3]
优化器状态:[O0]
梯度(保留):[G_avg0]
梯度(丢弃):[G_avg1, G_avg2, G_avg3]✔️ 丢弃
G_avg0更新O0, 计算出P0_new
然后通过 All-Gather 从其他卡获取 P1~P3 的更新。
GPU 1参数:[P0, P1, P2, P3]
梯度(原始):[G0, G1, G2, G3]
优化器状态:[O1]
梯度(保留):[G_avg1]
梯度(丢弃):[G_avg0, G_avg2, G_avg3]✔️ 丢弃
G_avg1更新O1, 计算出P1_new
然后通过 All-Gather 从其他卡获取 P0, P2, P3 的更新。

显存占用

  • 中等。每张卡存4份参数 + 1份梯度

  • 显存优化:相比 ZeRO-1,节省了 3 份梯度的存储空间


两张图的对比总结
特征ZeRO-1 (图里场景)ZeRO-2 (图里场景)
每张卡上的参数全部[P0, P1, P2, P3]全部[P0, P1, P2, P3]
每张卡上的梯度全部[G_avg0...G_avg3](All-Reduce后)只有1块[G_avg0](Reduce-Scatter后)
优化器状态分片[O0]分片[O0]
参数更新方式各卡独立计算出完整参数各卡计算部分参数,再互相广播合并
主要节省不节省梯度节省了3/4的梯度显存

通过这两张“快照”,你应该能清晰地看到:ZeRO-2 的本质,就是用梯度通信后的一个“丢弃”动作,换来了大量的显存空间。

通信量总结

维度ZeRO-1ZeRO-2ZeRO-3
参数存储完整 (每卡都有)完整 (每卡都有)切分(每卡1/DP)
梯度存储完整 (每卡都有)切分(每卡1/DP)切分(每卡1/DP)
优化器状态切分 (每卡1/DP)切分 (每卡1/DP)切分 (每卡1/DP)
单卡模型状态显存2Ψ + 2Ψ + 12Ψ/DP2Ψ + 2Ψ/DP + 12Ψ/DP(2Ψ+2Ψ+12Ψ)/DP
主要通信All-Reduce (梯度)Reduce-Scatter + All-GatherAll-Gather ×2 + Reduce-Scatter
通信量2×Ψ(最小)2×Ψ3×Ψ(最大)
显存节省仅优化器状态优化器+梯度全部
3. Megatron ZeRO配置
Stage分片内容Megatron对应参数
ZeRO-1优化器状态分片(m,v)--user-distributed-optimizer
ZeRO-2优化器分片+梯度分片

--user-distributed-optimizer+

--overlap-grad-reduce

ZeRO-3优化器分片+梯度+参数需要单独搞
4. ZeRO2架构 backward过程

计算梯度和更新参数的过程

http://www.cnnetsun.cn/news/2587376.html

相关文章:

  • 关于大学专业课如何去正确学习
  • 阿里云个人测试SSL证书申请及部署
  • Android系统中的AI融合技术:架构设计与实践
  • Prompt工程×前端渲染×实时协同,Lovable写作助手开发全流程解析,含GitHub可运行代码库
  • 三相异步电动机定子磁动势的谐波分析与抑制策略
  • AI Agent上云到底卡在哪?揭秘92%团队在K8s调度Agent时忽略的4个Operator级配置漏洞
  • 科研党福音:手把手教你搞定Matlab+Gurobi学术版安装(附IP验证避坑指南)
  • cartopy 绘制中国地图:从基础边界到南海诸岛与十段线的完整实践
  • 5分钟学会B站缓存视频转换:永久保存你收藏的珍贵内容
  • Linux---进程(概念,PCB,进程属性,标示符,fork)
  • RAG 高级技术与调优实战手册
  • 自治系统失控:从故障模式到抗错设计的工程实践
  • 构建稳健AI应用:隔离、容错与可观测性架构设计实践
  • pypto:用Python直接写NPU算子,门槛有多低?
  • 保姆级教程:用RDPWrap解锁Win10/11家庭版远程桌面,还能多人同时登录
  • 告别混乱状态机!用UE4行为树+黑板实现智能敌人AI(实战案例解析)
  • Unity 2022.3.3 LTS + Visual Studio 2022:手把手教你复刻《吸血鬼幸存者》核心战斗(附完整源码)
  • Taotoken模型广场首发更新Qwen与Gemini等旗舰模型体验
  • 模型评测为什么一上对抗攻击测试就开始高分低防御:从 Adversarial Prompt 到 Robustness Budget 的工程实战
  • 淘宝任务自动化终极指南:5分钟解放双手的免费淘金币脚本
  • “襄阳造”打磨车出口毛里塔尼亚
  • 贝叶斯双重机器学习:高维因果推断的去偏与不确定性量化
  • Claude Code VS Code扩展:AI编程代理的工程化实践
  • TikTok 短视频生成工具哪家好?爆款视频复刻工具实用推荐
  • Godot PCK文件结构解析与安全解包实战指南
  • sqlmap原理深度解析:从DVWA靶场看SQL注入本质
  • 机器学习辅助高通量筛选:uMLIP与迁移学习加速功能材料发现
  • GBase 8s数据库常见问题排查及解决方法简述
  • 机器学习与模拟退火优化布尔特征集变量排序,加速密码分析计算
  • Unity Hub安装Android组件失败的真相与三步修复法