当前位置: 首页 > news >正文

067、混合精度训练 autocast 源码:前向 FP16到Loss Scale到反向 FP32 的完整机制

067、混合精度训练 autocast 源码:前向 FP16到Loss Scale到反向 FP32 的完整机制

一、从一次显存爆炸说起

去年有个项目,训练YOLOv8-L,batch size设到32,RTX 4090 24G显存直接爆了。同事说“降batch size呗”,我说“试试混合精度”。结果改了四行代码,batch size拉到48,显存占用反而降了30%,训练速度还快了近一倍。但第二天发现loss曲线在某个epoch后突然变成NaN,模型直接废了。

这就是混合精度训练的典型坑:FP16精度不够,梯度下溢,模型崩了。后来我翻autocast源码才彻底搞明白,这东西不是简单地把float32转成float16就完事了,背后有一套完整的动态Loss Scale机制在兜底。

二、autocast到底干了什么

很多人以为with torch.cuda.amp.autocast():就是自动把模型输入转成FP16。错。它做的是选择性精度转换——哪些操作用FP16算,哪些必须用FP32保精度,autocast内部有个白名单。

看PyTorch 2.0的源码,autocast核心逻辑在torch/cuda/amp/autocast.py。它维护了一个_cast字典,里面记录了每个op的精度偏好:

# 源码简化版,实际在torch/amp/autocast_mode.py_CASTS={torch.addmm:['fp16','fp16','fp16'],# 三个参数都转fp16torch.mm:['fp16','fp16'],torch.bmm:['fp16','fp16'],torch.conv2d:['fp16','fp16','fp32'],# 权重和输入转fp16,bias保持fp32torch.softmax:['fp32'],# softmax必须fp32,否则梯度爆炸torch.layer_norm:['fp32'],# layer norm也是}

这里踩过坑:softmax和layer norm在FP16下精度损失极大,尤其是YOLO的检测头里用了softmax做分类,如果autocast没把它保护成FP32,训练到一半loss直接飞掉。PyTorch官方已经把这些op写死了,但如果你自己写了个自定义op,记得手动加装饰器@torch.cuda.amp.custom_fwd

三、前向传播:FP16的“偷懒”哲学

前向时,autocast的工作流程是这样的:

  1. 拦截op调用:每个torch操作被_cast函数包裹,检查输入类型
  2. 精度转换:如果op在白名单里,把float32 tensor转成float16(半精度)
  3. 执行计算:用FP16做矩阵乘法、卷积等计算,速度翻倍
  4. 结果缓存:输出保持FP16,但autocast会记录哪些tensor是“关键节点”

别这样写:x = x.half()手动转。autocast会自动处理,你手动转反而可能破坏它的精度选择逻辑。我见过有人把输入手动转成FP16,结果softmax也变成FP16算,loss直接NaN。

关键点:autocast只在with块内生效,块外的操作不受影响。所以训练循环里,前向和loss计算要包在autocast里,反向传播和优化器更新在外面。

四、Loss Scale:那个救命的缩放因子

FP16的数值范围是[-65504, 65504],但梯度通常很小(比如1e-5),在FP16下直接变成0,这就是下溢。Loss Scale就是把这个梯度放大,算完再缩回去。

PyTorch的GradScaler源码在torch/cuda/amp/grad_scaler.py,核心逻辑:

classGradScaler:def__init__(self,init_scale=2.**16,growth_factor=2.0,backoff_factor=0.5,growth_interval=2000):self._scale=torch.tensor(init_scale,dtype=torch.float32)self._growth_factor=growth_factor self._backoff_factor=backoff_factor self._growth_interval=growth_interval self._growth_tracker=0# 记录连续无溢出步数defscale(self,loss):# 把loss放大,避免梯度下溢returnloss*self._scaledefunscale_(self,optimizer):# 反向传播后,把梯度缩回来forgroupinoptimizer.param_groups:forpingroup['params']:ifp.gradisnotNone:p.grad.data.div_(self._scale)defstep(self,optimizer):# 检查是否有梯度溢出(NaN或Inf)ifself._has_inf_or_nan(optimizer):self._scale*=self._backoff_factor# 发现溢出,缩小scaleself._growth_tracker=0optimizer.zero_grad()# 跳过这步更新else:self._scale*=self._growth_factor# 连续无溢出,放大scaleself._growth_tracker+=1ifself._growth_tracker>=self._growth_interval:self._scale*=self._growth_factor self._growth_tracker=0

这里踩过坑:unscale_必须在step之前调用。如果你先调了optimizer.step()再unscale,梯度已经被更新了,scale白做了。PyTorch官方推荐写法:

scaler.scale(loss).backward()# 前向+反向scaler.unscale_(optimizer)# 先缩梯度scaler.step(optimizer)# 再更新参数scaler.update()# 最后调整scale

别这样写:scaler.step(optimizer)之后才调unscale_。我debug过一整天,发现loss曲线震荡,就是因为顺序搞反了。

五、反向传播:FP32的“救场”机制

反向传播时,autocast会自动把梯度转回FP32。为什么?因为梯度计算需要高精度。FP16的梯度更新参数,相当于用一把尺子量头发丝,误差太大。

看反向传播的源码逻辑:

# 在autocast模式下,每个op的反向函数被包装classAutocastFunction(torch.autograd.Function):@staticmethoddefforward(ctx,input,weight,bias=None):# 前向用FP16input_fp16=input.half()weight_fp16=weight.half()output=torch.conv2d(input_fp16,weight_fp16,bias)ctx.save_for_backward(input_fp16,weight_fp16,bias)returnoutput@staticmethoddefbackward(ctx,grad_output):# 反向时,梯度自动转回FP32input_fp16,weight_fp16,bias=ctx.saved_tensors grad_input=torch.conv2d_backward(grad_output.float(),# 这里转成FP32input_fp16.float(),weight_fp16.float(),bias)returngrad_input

关键点:梯度计算用FP32,但参数更新时又转回FP16。所以模型权重在内存里是FP16,但梯度计算时临时转成FP32,算完再转回去。这解释了为什么混合精度能省显存——权重存FP16,但计算时临时用FP32,用完就释放。

六、YOLO实战中的坑与优化

在YOLOv8的训练代码里,混合精度配置是这样的:

scaler=torch.cuda.amp.GradScaler(enabled=args.amp)forbatchindataloader:withtorch.cuda.amp.autocast(enabled=args.amp):preds=model(images)loss=compute_loss(preds,targets)scaler.scale(loss).backward()scaler.unscale_(optimizer)torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=10.0)scaler.step(optimizer)scaler.update()optimizer.zero_grad()

这里有个隐藏坑:clip_grad_norm_必须在unscale_之后。因为梯度被scale放大了,直接clip会剪掉正常梯度。我见过有人把clip放在unscale前面,结果梯度被剪成0,模型不收敛。

另一个坑:YOLO的检测头里用了nn.SiLU激活函数,它在FP16下表现不稳定。解决方案是在autocast块外手动把检测头的输入转成FP32:

classDetect(nn.Module):defforward(self,x):withtorch.cuda.amp.autocast(enabled=False):# 关闭autocastx=x.float()# 强制FP32# 检测头计算...

别这样写:整个模型都包在autocast里,然后指望它自动处理所有层。YOLO的检测头对精度敏感,必须手动干预。

七、个人经验:什么时候该用,什么时候别用

该用的情况

  • 显存不够,batch size上不去
  • 模型大(YOLOv8-L以上),训练速度慢
  • 梯度值在1e-3到1e-5之间,不会下溢

别用的情况

  • 模型很小(YOLOv5-n),显存充足,用FP32更快
  • 任务对精度极度敏感(比如医学图像检测),FP16的误差不可接受
  • 自定义op太多,且没有注册精度偏好

调试技巧

  1. 第一次训练先用FP32跑10个epoch,记录loss范围
  2. 如果loss在1e-4以下,说明梯度太小,需要调高init_scale(比如2^20)
  3. 如果频繁出现infnan,检查是否有op没被autocast保护
  4. torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16)试试BF16,数值范围更大,不容易溢出

最后说一句:混合精度不是银弹。我见过有人为了省显存,强行用FP16训练YOLOv8-x,结果精度掉了2个点,还不如降batch size用FP32。工具是死的,人是活的,根据你的硬件和任务灵活选择。

http://www.cnnetsun.cn/news/2843571.html

相关文章:

  • RAG 知识库增量更新与版本管理:从全量重建到实时生效
  • TypeScript 编程中 Jest 单元测试的类型 Mock 与 Spy 详解
  • 15分钟搭建个人游戏云:Sunshine开源串流服务器完全指南
  • 终极Windows热键侦探:3步快速定位快捷键冲突根源
  • 【鸿蒙原生开发会议随记 Pro】用 NavPathStack 收拢会议页面跳转和返回刷新
  • 3步掌握抖音内容高效采集:从单条视频到批量资源的完整解决方案
  • 大模型+Skills=MCP?深度解析智能体核心组件,告别概念混乱!
  • Python+OpenCV多目标跟踪实战:鼠标框选目标、KCF算法实时跟踪、含完整实验文档与测试视频
  • 网盘下载速度慢?这个开源工具帮你一键获取高速直链下载地址![特殊字符]
  • 别再让标题和摘要拖后腿!SCI/SSCI论文投稿前必看的5个自查清单(附实例)
  • 从用户体验出发:聊聊Vue项目中Loading动画设计的那些‘坑’与最佳实践
  • 论Web服务技术的应用与发展
  • IEEE论文投稿不求人:手把手教你用BibTeX和Mathtype高效管理参考文献与公式
  • 有哪些高效的NOI省选专题题目解题技巧
  • 【论文复现】基于行波理论的输电线路故障诊断方法研究附Simulink仿真
  • SAP 物料主数据计划变更实战,如何让 Material Master 在未来某一天生效
  • COM3D2.MaidFiddler:3分钟上手的游戏实时编辑器完全指南
  • 双喜临门|腾视科技杭州总部及深圳子公司乔迁新址,以全新姿态奔赴新征程!
  • 重大升级|大家反映配置最复杂的“会务报名”也变成“点哪儿改哪儿”啦!
  • 终极指南:三步免费解锁WeMod专业版所有高级功能
  • 6字符内CRC32碰撞生成器:输入校验值或明文,秒出多组不同字符串但相同CRC结果
  • Beyond Compare 5密钥生成终极指南:三种方案深度解析与实战应用
  • 16MB大存储版,ESP32-S3-WROOM-1-N16适合哪些AIoT项目?
  • VRM-Addon-for-Blender终极指南:从模型创建到VR应用集成的深度解析
  • 大规模MIMO能效优化仿真工具:一键跑通功率与天线数联合寻优全流程
  • Python图像处理实战:电商主图光照校正与主体分割
  • 三步掌握微信数据库解密:轻松访问你的聊天记录
  • 解锁专业工作流:3分钟掌握Adobe插件智能安装方案
  • STM32F103搭配AD7616实现16路电压同步采集的可运行工程(含串口上传与波形示例)
  • 2048-AI:揭秘高效期望最大化算法在经典数字游戏中的实战应用