当前位置: 首页 > news >正文

【图像分类】实战ResNet——从零构建到CIFAR-10分类(Pytorch)

1. 初识ResNet:为什么它能解决深度神经网络的瓶颈问题

第一次接触ResNet是在处理一个图像分类项目时,当时我遇到了所有深度学习工程师都会面临的经典问题:随着网络层数增加,模型性能不升反降。这就像给小孩子叠积木,叠得越高反而越容易倒塌。ResNet的提出者何恺明团队用"残差学习"的概念完美解决了这个问题。

残差块(Residual Block)的设计其实非常巧妙。想象你在学习骑自行车,如果直接学习完整的骑行动作很困难,但如果你已经会骑三轮车,现在只需要学习"保持平衡"这个差异部分,学习难度就大大降低了。ResNet正是采用了这种思想,通过shortcut connection让网络只需要学习当前输出与输入之间的残差(差异部分)。

我常用的ResNet-18和ResNet-34都采用基础残差块(BasicBlock),它们的结构对比如下:

组件ResNet-18ResNet-34
卷积层总数1834
残差块类型BasicBlockBasicBlock
参数量(M)11.721.8
ImageNet Top1准确率69.76%73.30%

在实际项目中,当计算资源有限时,我通常会先尝试ResNet-18。它的参数量只有11.7M,在CIFAR-10这种小规模数据集上训练速度很快,而且准确率也能达到不错的效果。记得第一次在Colab上跑ResNet-18时,只用15分钟就完成了训练,测试准确率轻松突破85%,这让我深刻体会到好模型不在于复杂,而在于设计巧妙。

2. 环境准备与数据加载:打造高效的PyTorch工作流

搭建环境就像准备厨房,工具齐全才能做出好菜。我习惯用conda创建独立环境,避免包版本冲突。以下是完整的安装步骤:

conda create -n resnet python=3.8 conda activate resnet pip install torch torchvision torchaudio pip install matplotlib tqdm numpy

加载CIFAR-10数据集时,有几个细节需要特别注意。第一次使用时我犯了个错误:直接下载的图片没有做归一化,导致模型难以收敛。正确的做法是使用torchvision提供的标准化参数:

from torchvision import transforms from torchvision.datasets import CIFAR10 transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) train_set = CIFAR10(root='./data', train=True, download=True, transform=transform_train) test_set = CIFAR10(root='./data', train=False, download=True, transform=transform_test)

数据增强是提升模型泛化能力的关键。在CIFAR-10这种小数据集上,我通常会加入随机水平翻转和随机裁剪。曾经对比过使用和不使用数据增强的效果,在ResNet-18上准确率相差近5个百分点!

创建数据加载器时,batch_size的设置很有讲究。在我的RTX 3060显卡上,32-64是比较理想的范围。太大会导致显存不足,太小则无法充分利用GPU并行计算能力:

from torch.utils.data import DataLoader train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4) test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)

3. 构建ResNet模型:从残差块到完整网络

实现ResNet的核心在于正确构建残差块。我第一次实现时犯了个典型错误:忘记在shortcut连接中添加1x1卷积当维度不匹配时。这导致模型根本无法训练,损失值居高不下。正确的BasicBlock实现应该是这样的:

import torch.nn as nn class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(out_channels) self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(out_channels) self.shortcut = nn.Sequential() if stride != 1 or in_channels != self.expansion * out_channels: self.shortcut = nn.Sequential( nn.Conv2d( in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(self.expansion * out_channels) ) def forward(self, x): out = nn.ReLU()(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) out = nn.ReLU()(out) return out

完整ResNet的构建需要特别注意层与层之间的通道数变化。ResNet-18的结构可以分为以下几个部分:

  1. 初始卷积层:7x7卷积 -> 这个在CIFAR-10上我改成了3x3,因为图像尺寸较小
  2. 四个残差层:每层包含多个BasicBlock
  3. 全局平均池化和全连接层
class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=10): super().__init__() self.in_channels = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: layers.append(block(self.in_channels, out_channels, stride)) self.in_channels = out_channels * block.expansion return nn.Sequential(*layers) def forward(self, x): x = nn.ReLU()(self.bn1(self.conv1(x))) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.fc(x) return x

对于CIFAR-10,我们需要调整网络输入部分,因为原始ResNet是为ImageNet设计的。主要修改包括:

  • 将第一个7x7卷积改为3x3卷积
  • 去掉初始的max pooling层
  • 最后的平均池化改为自适应池化

4. 模型训练与调优:从基础训练到高级技巧

训练神经网络就像教小朋友学习,既要有耐心又要讲究方法。我总结了一套有效的训练流程:

基础训练配置

import torch.optim as optim model = ResNet(BasicBlock, [2, 2, 2, 2]).to(device) # ResNet-18 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

学习率设置很关键,我习惯用学习率预热策略。初期用小学习率(如0.01),5个epoch后再调到0.1:

for epoch in range(5): # Warmup train(..., lr=0.01) for epoch in range(5, 200): train(..., lr=0.1)

训练过程中的重要技巧

  1. 混合精度训练:可以显著减少显存占用,加快训练速度
scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs = model(inputs) loss = criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
  1. 标签平滑:缓解过拟合,提升模型泛化能力
class LabelSmoothingCrossEntropy(nn.Module): def __init__(self, epsilon=0.1): super().__init__() self.epsilon = epsilon def forward(self, outputs, targets): n_classes = outputs.size(-1) log_preds = F.log_softmax(outputs, dim=-1) loss = -log_preds.mean() return loss * self.epsilon + (1 - self.epsilon) * F.nll_loss(log_preds, targets)
  1. 模型EMA:保持模型参数的滑动平均,提升最终性能
class ModelEMA: def __init__(self, model, decay=0.999): self.ema = deepcopy(model).eval() self.decay = decay def update(self, model): with torch.no_grad(): for ema_p, model_p in zip(self.ema.parameters(), model.parameters()): ema_p.mul_(self.decay).add_(model_p, alpha=1 - self.decay)

训练监控

我习惯用TensorBoard记录训练过程,关键指标包括:

  • 训练/验证损失曲线
  • 学习率变化
  • 分类准确率
  • 参数分布直方图
from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter() for epoch in range(200): train_loss, train_acc = train(...) val_loss, val_acc = validate(...) writer.add_scalar('Loss/train', train_loss, epoch) writer.add_scalar('Loss/val', val_loss, epoch) writer.add_scalar('Accuracy/train', train_acc, epoch) writer.add_scalar('Accuracy/val', val_acc, epoch)

5. 模型评估与可视化:深入理解模型行为

训练完成后,我们需要全面评估模型性能。基础的准确率指标远远不够,我通常会从以下几个维度进行分析:

1. 混淆矩阵分析

from sklearn.metrics import confusion_matrix import seaborn as sns conf_mat = confusion_matrix(all_labels, all_preds) plt.figure(figsize=(10, 8)) sns.heatmap(conf_mat, annot=True, fmt='d', xticklabels=classes, yticklabels=classes) plt.xlabel('Predicted') plt.ylabel('Actual')

2. 特征可视化

使用t-SNE降维可视化最后一层特征:

from sklearn.manifold import TSNE features = [] # 收集模型最后一层前的特征 labels = [] with torch.no_grad(): for data, target in test_loader: data = data.to(device) feature = model.conv1(data) # ... 通过所有层直到最后一层前 features.append(feature.cpu()) labels.append(target) features = torch.cat(features).numpy() labels = torch.cat(labels).numpy() tsne = TSNE(n_components=2, random_state=42) features_2d = tsne.fit_transform(features) plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels, alpha=0.6) plt.colorbar()

3. 类激活图(CAM)可视化

理解模型关注哪些区域:

class CamExtractor: def __init__(self, model): self.model = model self.gradients = None def save_gradient(self, grad): self.gradients = grad def forward_pass(self, x): conv_output = None for name, module in self.model.named_children(): x = module(x) if name == 'layer4': # 最后一个卷积层 x.register_hook(self.save_gradient) conv_output = x return conv_output, x # 使用示例 extractor = CamExtractor(model) conv_output, model_output = extractor.forward_pass(input_img) model_output = model_output[:, target_class] conv_output.backward() gradients = extractor.gradients pooled_gradients = torch.mean(gradients, dim=[0, 2, 3]) for i in range(conv_output.shape[1]): conv_output[:, i, :, :] *= pooled_gradients[i] heatmap = torch.mean(conv_output, dim=1).squeeze().cpu().numpy() heatmap = np.maximum(heatmap, 0) heatmap /= np.max(heatmap) # 叠加到原图 img = input_img.squeeze().permute(1, 2, 0).cpu().numpy() heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) heatmap = np.uint8(255 * heatmap) heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) superimposed_img = heatmap * 0.4 + img * 255

4. 错误分析

收集模型预测错误的样本,分析共同特征:

error_indices = np.where(all_preds != all_labels)[0] error_samples = test_set.data[error_indices] error_preds = all_preds[error_indices] error_labels = all_labels[error_indices] plt.figure(figsize=(15, 10)) for i in range(25): plt.subplot(5, 5, i+1) plt.imshow(error_samples[i]) plt.title(f'P:{classes[error_preds[i]]} A:{classes[error_labels[i]]}') plt.axis('off')

通过这些分析,我们可以发现模型在哪些类别上容易混淆,哪些特征被过度关注等问题。比如在CIFAR-10上,猫和狗、汽车和卡车常常是模型容易混淆的类别。针对这些问题,可以采取数据增强、类别重加权等改进措施。

http://www.cnnetsun.cn/news/3148791.html

相关文章:

  • Agent记忆系统设计与实现
  • 别把知识图谱做成高级文档库——定制化做企业级知识图谱
  • 【面板数据模型实战】从理论到Stata/R/Python实现与选择
  • 【机器人】基于缓冲的不确定性感知沃罗诺伊单元多机器人碰撞规避附Matlab代码
  • Rmarkdown动态文档创作与数据科学报告实战指南
  • 【HarmonyOS NEXT】error: failed to install bundle. code:9568322...
  • 多接地配电系统的基于PMU的系统状态估计附Matlab代码
  • Linux /etc/fstab 配置详解:5个关键参数避免重启后挂载回退只读
  • 普推黑体(PUTUI)1.202,更适合商标及标题文字!
  • 用C语言的<wchar.h>宽字节库实现好玩的逐字输出效果(模拟打字)
  • 鸿蒙新特性——Badge 徽章组件详解
  • Linux 用户管理知识与应用实践(二:用户相关命令与示例)
  • 高速 ADC 与 FPGA LVDS 接口设计:5 项 PCB 布线规则与 IDELAY 时序校准实战
  • 远控横评:向日葵、ToDesk、UU 远程,远程玩游戏差距有多大
  • Transformers自动化训练全流程优化实战
  • 助睿实验7-3:可视化探索
  • 基于51单片机的教室智能照明灯控制系统光控人数检测定做定制电子13(设计源文件+万字报告+讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码
  • kotlin-basic-blog
  • 89个公共Tracker如何让BT下载告别“孤岛困境“?
  • 剧云推出分镜大师:让剧本更快变成可拍摄的镜头方案
  • Deceive:终极游戏隐身指南 - 如何在英雄联盟、VALORANT和符文大地传说中保持隐身状态
  • 《鸿蒙原生应用从0-1构建:项目工程结构与核心配置全景解析》
  • ExplorerPatcher深度解析:重塑Windows界面体验的高效工具
  • Node.js 插件沙箱:开放扩展之前先限制能力
  • Go 泛型的运行时性能:单态化、接口装箱与编译器优化的基准分析
  • OBS美颜文章_终极指南
  • 别再手写Bug了!用Python+LangGraph实现AI自修复代码的完整指南
  • AI机器学习高级数学与优化
  • SSTI攻击链构造手册(带WAF绕过)
  • 创客指南:oDrive X2212电机从零到闭环的完整配置流程