PyTorch SGD优化器报错怎么办?教你一招避坑
💓 博客主页:瑕疵的CSDN主页
📝 Gitee主页:瑕疵的gitee主页
⏩ 文章专栏:《热点资讯》
被SGD坑了,参数传错的深夜血泪史
目录
昨晚写CNN模型,SGD优化器突然报错。我盯着屏幕,心想:学习率0.01,应该没问题啊?结果弹出报错:TypeError: SGD() argument 'params' must be an iterable of Tensors, but got <class 'torch.nn.Module'>。
核心根源:
SGD优化器要的是参数列表,不是模型本身。我直接传了model,PyTorch直接崩了。model.parameters()才是返回可训练参数的正确方式。
我测试过10次,错误原因就这一个——新手总以为传模型就行。
错误示范(我踩过的坑):
# 错误:直接传入model,不是参数optimizer=optim.SGD(model,lr=0.01)# 运行就报错!# 为什么?model是整个神经网络模块,不是参数列表正确姿势(改完直接跑通):
# 正确:必须传model.parameters()optimizer=optim.SGD(model.parameters(),lr=0.01)# 重点!# model.parameters()返回可训练参数的迭代器
(截图里清楚显示"got
