057、BaseTrainer初始化源码精读:模型、数据、优化器、调度器的初始化全流程
057、BaseTrainer初始化源码精读:模型、数据、优化器、调度器的初始化全流程
上周帮一个学员调YOLOv8的训练脚本,他卡在一个诡异的bug上——模型跑着跑着loss突然变成NaN,但同样的配置在另一台机器上就正常。排查了半天,最后发现是优化器初始化时参数组没对齐,学习率调度器把warmup阶段的lr直接推到了负值。这种问题,说白了就是没吃透BaseTrainer的初始化流程。
今天咱们就把ultralytics的BaseTrainer初始化源码扒开,一行一行过。我用的版本是ultralytics 8.0.200,不同版本可能有细微差异,但核心逻辑没变过。
入口:__init__方法
classBaseTrainer:def__init__(self,cfg=DEFAULT_CFG,overrides=None,_callbacks=None):这里有个坑——cfg参数默认是DEFAULT_CFG,这是个全局配置字典。如果你直接传一个修改过的字典进去,小心引用传递的问题。我习惯这样写:
# 别这样写,会污染全局配置cfg=DEFAULT_CFG cfg['lr0']=0.01# 正确做法cfg=deepcopy(DEFAULT_CFG)cfg['lr0']=0.01接下来是overrides参数,这个设计得很巧妙。它允许你在不修改原始配置文件的情况下,临时覆盖某些参数。比如你想快速测试不同的batch size:
trainer=BaseTrainer(overrides={'batch':16,'epochs':50})模型初始化:从配置到网络
self.model=self.get_model(cfg)get_model方法在子类中实现,YOLOv8的DetectionTrainer里是这样写的:
defget_model(self,cfg,weights=None,verbose=True):model=Model(cfg.model,ch=3,nc=self.data['nc'],verbose=verbose)returnmodel注意这里传的是cfg.model,不是整个cfg。cfg.model是一个字符串,指向模型配置文件(比如yolov8n.yaml)。Model类会解析这个yaml文件,构建出完整的网络结构。
这里踩过坑——如果你自定义了模型结构,一定要确保yaml文件里的nc(类别数)和你的数据集一致。否则模型最后一层的输出维度会不对,训练时loss直接起飞。
数据加载:Dataset和DataLoader的初始化
self.train_loader=self.get_dataloader(self.trainset,batch_size=batch_size,rank=rank,mode='train')self.test_loader=self.get_dataloader(self.testset,batch_size=batch_size*2,rank=rank,mode='val')get_dataloader方法里有个细节——batch_size在分布式训练时会被除以world size。如果你手动设置batch size,记得考虑这个因素。
defget_dataloader(self,dataset,batch_size,rank=0,mode='train'):# 分布式训练时,每个进程的batch size = 总batch size / 进程数batch_size=min(batch_size,len(dataset))nd=torch.cuda.device_count()nw=min([os.cpu_count()//max(nd,1),batch_sizeifbatch_size>1else0,8])# 别这样写:nw = 8,会吃满CPUloader=DataLoader(dataset,batch_size=batch_size,shuffle=(mode=='train'),num_workers=nw,pin_memory=True,collate_fn=dataset.collate_fn)returnloadernum_workers的计算逻辑值得注意——它取了三个值的最小值:CPU核心数除以GPU数、batch size、8。这样设计是为了避免数据加载成为瓶颈,同时防止worker数过多导致内存溢出。
优化器初始化:参数分组的艺术
self.optimizer=self.build_optimizer(model,lr=cfg.lr0,momentum=cfg.momentum,decay=cfg.weight_decay)build_optimizer是初始化中最容易出问题的环节。YOLO的优化器对参数做了分组:
defbuild_optimizer(self,model,lr=0.001,momentum=0.937,decay=5e-4):g=[],[],[]# optimizer parameter groupsbn=tuple(vfork,vinnn.__dict__.items()if'Norm'ink)# normalization layersforvinmodel.modules():ifhasattr(v,'bias')andisinstance(v.bias,nn.Parameter):g[2].append(v.bias)# biasesifisinstance(v,bn):g[1].append(v.weight)# no decayelifhasattr(v,'weight')andisinstance(v.weight,nn.Parameter):g[0].append(v.weight)# apply decay这里把参数分成了三组:
- g[0]:卷积层、线性层的权重,应用weight decay
- g[1]:BN层的权重,不应用weight decay
- g[2]:所有bias,不应用weight decay
为什么这样分?因为BN层的weight和bias本身就有正则化作用,再加weight decay会过拟合。bias同理。
optimizer=SGD(g[2],lr=lr,momentum=momentum,nesterov=True)optimizer.add_param_group({'params':g[0],'weight_decay':decay})optimizer.add_param_group({'params':g[1],'weight_decay':0.0})returnoptimizer注意这里先初始化了bias组,再添加其他组。这样bias组会使用默认的weight_decay=0,而其他组可以单独设置。如果你用AdamW,记得把weight_decay参数传对,AdamW的weight_decay实现方式和SGD不一样。
学习率调度器:warmup和余弦退火
self.scheduler=self.build_scheduler(optimizer,epochs=cfg.epochs,lr0=cfg.lr0,lrf=cfg.lrf)build_scheduler里实现了YOLO经典的warmup + 余弦退火策略:
defbuild_scheduler(self,optimizer,epochs=100,lr0=0.01,lrf=0.01):lf=lambdax:((1+math.cos(x*math.pi/epochs))/2)*(1-lrf)+lrf scheduler=lr_scheduler.LambdaLR(optimizer,lr_lambda=lf)returnscheduler这个lambda函数实现了从lr0到lrf的余弦衰减。但注意,warmup阶段是在训练循环里手动实现的,不在调度器里:
# 在训练循环中ifself.epoch<self.warmup_epochs:# 线性增加lr从0到lr0lr=[lr0*(self.epoch+1)/self.warmup_epochs]forparam_groupinoptimizer.param_groups:param_group['lr']=lr[0]这里有个容易忽略的点——warmup阶段结束后,调度器才开始生效。如果你在warmup期间也调用了scheduler.step(),学习率会从0开始余弦衰减,导致warmup白做了。
损失函数和评估指标
self.criterion=self.get_criterion()self.metrics=self.get_metrics()YOLOv8的损失函数是复合的,包含分类损失、回归损失和DFL损失。get_criterion方法返回一个v8DetectionLoss实例,它内部维护了三个损失计算模块。
评估指标用的是ConfusionMatrix和Metric类,这些在验证阶段才会用到。初始化时只是创建了空对象。
断点续训:从checkpoint恢复
ifcfg.resume:self.resume=Trueself.last=Path(cfg.resume)ifisinstance(cfg.resume,str)elseget_latest_run()self.ckpt=torch.load(self.last)断点续训时,会加载checkpoint中的模型权重、优化器状态、调度器状态、epoch数等。这里有个坑——如果你改了模型结构,直接加载旧checkpoint会报错。建议在resume时保持模型结构不变,或者用strict=False加载。
# 加载模型权重self.model.load_state_dict(self.ckpt['model'],strict=False)# 加载优化器状态self.optimizer.load_state_dict(self.ckpt['optimizer'])# 恢复epochself.start_epoch=self.ckpt['epoch']+1个人经验建议
调试时先固定随机种子:YOLO的初始化涉及很多随机操作(数据增强、dropout等),不固定种子的话,同样的代码跑两次结果不一样,很难排查问题。
监控优化器参数组:训练初期打印一下
optimizer.param_groups,确认每个组的lr和weight_decay是否符合预期。我见过太多因为参数组顺序搞错导致训练失败的案例。warmup和调度器的衔接:如果你自定义了学习率策略,一定要在warmup结束后再开始调度器。可以在训练循环里加个判断:
if epoch >= warmup_epochs: scheduler.step()断点续训的陷阱:resume时不仅要恢复模型和优化器,还要恢复数据加载器的状态(比如shuffle的种子)。否则同一个epoch的数据顺序变了,影响训练稳定性。
内存泄漏排查:如果训练过程中内存持续增长,检查DataLoader的num_workers是否设置过大,或者模型里有没有未释放的中间变量。YOLO的损失函数里有个
bbox_loss会缓存一些计算结果,记得在epoch结束时清空。
最后说一句,源码读到这里,你已经超越了90%的YOLO使用者。剩下的10%,是在实际调试中积累的经验。下次遇到训练异常,先检查初始化流程,八成能找到问题。
