067、混合精度训练 autocast 源码:前向 FP16到Loss Scale到反向 FP32 的完整机制
067、混合精度训练 autocast 源码:前向 FP16到Loss Scale到反向 FP32 的完整机制
一、从一次显存爆炸说起
去年有个项目,训练YOLOv8-L,batch size设到32,RTX 4090 24G显存直接爆了。同事说“降batch size呗”,我说“试试混合精度”。结果改了四行代码,batch size拉到48,显存占用反而降了30%,训练速度还快了近一倍。但第二天发现loss曲线在某个epoch后突然变成NaN,模型直接废了。
这就是混合精度训练的典型坑:FP16精度不够,梯度下溢,模型崩了。后来我翻autocast源码才彻底搞明白,这东西不是简单地把float32转成float16就完事了,背后有一套完整的动态Loss Scale机制在兜底。
二、autocast到底干了什么
很多人以为with torch.cuda.amp.autocast():就是自动把模型输入转成FP16。错。它做的是选择性精度转换——哪些操作用FP16算,哪些必须用FP32保精度,autocast内部有个白名单。
看PyTorch 2.0的源码,autocast核心逻辑在torch/cuda/amp/autocast.py。它维护了一个_cast字典,里面记录了每个op的精度偏好:
# 源码简化版,实际在torch/amp/autocast_mode.py_CASTS={torch.addmm:['fp16','fp16','fp16'],# 三个参数都转fp16torch.mm:['fp16','fp16'],torch.bmm:['fp16','fp16'],torch.conv2d:['fp16','fp16','fp32'],# 权重和输入转fp16,bias保持fp32torch.softmax:['fp32'],# softmax必须fp32,否则梯度爆炸torch.layer_norm:['fp32'],# layer norm也是}这里踩过坑:softmax和layer norm在FP16下精度损失极大,尤其是YOLO的检测头里用了softmax做分类,如果autocast没把它保护成FP32,训练到一半loss直接飞掉。PyTorch官方已经把这些op写死了,但如果你自己写了个自定义op,记得手动加装饰器@torch.cuda.amp.custom_fwd。
三、前向传播:FP16的“偷懒”哲学
前向时,autocast的工作流程是这样的:
- 拦截op调用:每个torch操作被
_cast函数包裹,检查输入类型 - 精度转换:如果op在白名单里,把float32 tensor转成float16(半精度)
- 执行计算:用FP16做矩阵乘法、卷积等计算,速度翻倍
- 结果缓存:输出保持FP16,但autocast会记录哪些tensor是“关键节点”
别这样写:x = x.half()手动转。autocast会自动处理,你手动转反而可能破坏它的精度选择逻辑。我见过有人把输入手动转成FP16,结果softmax也变成FP16算,loss直接NaN。
关键点:autocast只在with块内生效,块外的操作不受影响。所以训练循环里,前向和loss计算要包在autocast里,反向传播和优化器更新在外面。
四、Loss Scale:那个救命的缩放因子
FP16的数值范围是[-65504, 65504],但梯度通常很小(比如1e-5),在FP16下直接变成0,这就是下溢。Loss Scale就是把这个梯度放大,算完再缩回去。
PyTorch的GradScaler源码在torch/cuda/amp/grad_scaler.py,核心逻辑:
classGradScaler:def__init__(self,init_scale=2.**16,growth_factor=2.0,backoff_factor=0.5,growth_interval=2000):self._scale=torch.tensor(init_scale,dtype=torch.float32)self._growth_factor=growth_factor self._backoff_factor=backoff_factor self._growth_interval=growth_interval self._growth_tracker=0# 记录连续无溢出步数defscale(self,loss):# 把loss放大,避免梯度下溢returnloss*self._scaledefunscale_(self,optimizer):# 反向传播后,把梯度缩回来forgroupinoptimizer.param_groups:forpingroup['params']:ifp.gradisnotNone:p.grad.data.div_(self._scale)defstep(self,optimizer):# 检查是否有梯度溢出(NaN或Inf)ifself._has_inf_or_nan(optimizer):self._scale*=self._backoff_factor# 发现溢出,缩小scaleself._growth_tracker=0optimizer.zero_grad()# 跳过这步更新else:self._scale*=self._growth_factor# 连续无溢出,放大scaleself._growth_tracker+=1ifself._growth_tracker>=self._growth_interval:self._scale*=self._growth_factor self._growth_tracker=0这里踩过坑:unscale_必须在step之前调用。如果你先调了optimizer.step()再unscale,梯度已经被更新了,scale白做了。PyTorch官方推荐写法:
scaler.scale(loss).backward()# 前向+反向scaler.unscale_(optimizer)# 先缩梯度scaler.step(optimizer)# 再更新参数scaler.update()# 最后调整scale别这样写:scaler.step(optimizer)之后才调unscale_。我debug过一整天,发现loss曲线震荡,就是因为顺序搞反了。
五、反向传播:FP32的“救场”机制
反向传播时,autocast会自动把梯度转回FP32。为什么?因为梯度计算需要高精度。FP16的梯度更新参数,相当于用一把尺子量头发丝,误差太大。
看反向传播的源码逻辑:
# 在autocast模式下,每个op的反向函数被包装classAutocastFunction(torch.autograd.Function):@staticmethoddefforward(ctx,input,weight,bias=None):# 前向用FP16input_fp16=input.half()weight_fp16=weight.half()output=torch.conv2d(input_fp16,weight_fp16,bias)ctx.save_for_backward(input_fp16,weight_fp16,bias)returnoutput@staticmethoddefbackward(ctx,grad_output):# 反向时,梯度自动转回FP32input_fp16,weight_fp16,bias=ctx.saved_tensors grad_input=torch.conv2d_backward(grad_output.float(),# 这里转成FP32input_fp16.float(),weight_fp16.float(),bias)returngrad_input关键点:梯度计算用FP32,但参数更新时又转回FP16。所以模型权重在内存里是FP16,但梯度计算时临时转成FP32,算完再转回去。这解释了为什么混合精度能省显存——权重存FP16,但计算时临时用FP32,用完就释放。
六、YOLO实战中的坑与优化
在YOLOv8的训练代码里,混合精度配置是这样的:
scaler=torch.cuda.amp.GradScaler(enabled=args.amp)forbatchindataloader:withtorch.cuda.amp.autocast(enabled=args.amp):preds=model(images)loss=compute_loss(preds,targets)scaler.scale(loss).backward()scaler.unscale_(optimizer)torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=10.0)scaler.step(optimizer)scaler.update()optimizer.zero_grad()这里有个隐藏坑:clip_grad_norm_必须在unscale_之后。因为梯度被scale放大了,直接clip会剪掉正常梯度。我见过有人把clip放在unscale前面,结果梯度被剪成0,模型不收敛。
另一个坑:YOLO的检测头里用了nn.SiLU激活函数,它在FP16下表现不稳定。解决方案是在autocast块外手动把检测头的输入转成FP32:
classDetect(nn.Module):defforward(self,x):withtorch.cuda.amp.autocast(enabled=False):# 关闭autocastx=x.float()# 强制FP32# 检测头计算...别这样写:整个模型都包在autocast里,然后指望它自动处理所有层。YOLO的检测头对精度敏感,必须手动干预。
七、个人经验:什么时候该用,什么时候别用
该用的情况:
- 显存不够,batch size上不去
- 模型大(YOLOv8-L以上),训练速度慢
- 梯度值在1e-3到1e-5之间,不会下溢
别用的情况:
- 模型很小(YOLOv5-n),显存充足,用FP32更快
- 任务对精度极度敏感(比如医学图像检测),FP16的误差不可接受
- 自定义op太多,且没有注册精度偏好
调试技巧:
- 第一次训练先用FP32跑10个epoch,记录loss范围
- 如果loss在1e-4以下,说明梯度太小,需要调高
init_scale(比如2^20) - 如果频繁出现
inf或nan,检查是否有op没被autocast保护 - 用
torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16)试试BF16,数值范围更大,不容易溢出
最后说一句:混合精度不是银弹。我见过有人为了省显存,强行用FP16训练YOLOv8-x,结果精度掉了2个点,还不如降batch size用FP32。工具是死的,人是活的,根据你的硬件和任务灵活选择。
