当前位置: 首页 > news >正文

神经网络实战:ResNet 医学影像分类全流程解析

前言

在医学影像领域,处理高分辨率图像往往耗时耗力。本次项目采用 MedMNIST 风格的简化数据集,即28×28像素的小尺寸医学图像,重点完成医学影像的多分类任务,并拆解深度学习中非常经典的网络结构——ResNet,也就是深度残差网络。

一、环境准备与数据加载

医学数据集通常包含多种类别,例如结肠癌切片、皮肤病变、肺部 X 光等。由于不同数据集的类别数、通道数和样本数量可能不同,因此项目中通过一个 JSON 配置文件来统一管理数据集的基本信息。

1.1 数据预处理代码

dataset.py中,我们需要定义数据的读取方式以及标准化操作。

import torch from torch.utils.data import Dataset from torchvision import transforms class MedicalDataset(Dataset): def __init__(self, data_array, labels, transform=None): self.data = data_array self.labels = labels self.transform = transform def __getitem__(self, index): # 1. 提取图像和标签 img, target = self.data[index], self.labels[index] # 2. 预处理:ToTensor 转为 Tensor,Normalize 进行标准化 if self.transform is not None: img = self.transform(img) return img, target def __len__(self): return len(self.data) # 定义预处理流程 data_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize( mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] ) ])

这里需要注意,如果当前数据集是三通道 RGB 图像,可以使用:

mean=[0.5, 0.5, 0.5] std=[0.5, 0.5, 0.5]

如果是单通道灰度图像,则应改为:

mean=[0.5] std=[0.5]

二、核心原理:残差块 BasicBlock

ResNet 的核心思想是解决深层网络的退化问题。普通网络层数不断加深后,训练效果不一定变好,甚至可能变差。ResNet 通过 Shortcut Connection,也就是短路连接,让输入信息可以直接传到后面,从而缓解这个问题。

简单来说,残差结构可以表示为:

out = F(x) + x

其中,F(x)表示卷积层学习到的特征,x表示原始输入。这样一来,如果新增层学得好,就能提升效果;如果新增层学得不好,shortcut 也能保留原始信息,避免模型效果明显下降。

2.1 BasicBlock 代码实现

下面是 ResNet 的最小构建单元 BasicBlock。

import torch.nn as nn class BasicBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1): super(BasicBlock, self).__init__() # 第一层卷积:负责特征提取或降采样 self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) # 第二层卷积 self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(out_channels) # Shortcut 路径 self.shortcut = nn.Sequential() # 如果维度不匹配,需要用 1×1 卷积调整 if stride != 1 or in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(out_channels) ) def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) # 核心步骤:主分支输出与 shortcut 分支相加 out += self.shortcut(identity) out = self.relu(out) return out

这里最关键的是 shortcut 路径。

如果输入和输出维度一致,shortcut 不需要做任何操作,直接相加即可;如果输入和输出维度不一致,例如通道数变化或特征图尺寸变化,就需要使用1×1卷积进行调整。

可以简单理解为:

shape 一样,直接相加;shape 不一样,先用1×1卷积调整后再相加。

三、搭建 ResNet 网络架构

通过_make_layer函数,我们可以像搭积木一样重复使用 BasicBlock,从而搭建完整的 ResNet 网络。

3.1 完整网络代码

class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=9, input_channels=3): super(ResNet, self).__init__() self.in_channels = 64 # 第一步:初始卷积层 self.conv1 = nn.Conv2d( input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) # 第二步:堆叠 4 个 layer self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) # 第三步:平均池化 + 全连接分类 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): # 第一个 block 可能需要降采样,其余 block 保持 stride=1 strides = [stride] + [1] * (num_blocks - 1) layers = [] for s in strides: layers.append(block(self.in_channels, out_channels, s)) self.in_channels = out_channels return nn.Sequential(*layers) def forward(self, x): # 初始输入:[Batch, 3, 28, 28] out = self.relu(self.bn1(self.conv1(x))) out = self.layer1(out) # [Batch, 64, 28, 28] out = self.layer2(out) # [Batch, 128, 14, 14] out = self.layer3(out) # [Batch, 256, 7, 7] out = self.layer4(out) # [Batch, 512, 4, 4] out = self.avgpool(out) # [Batch, 512, 1, 1] out = out.view(out.size(0), -1) # [Batch, 512] out = self.fc(out) # [Batch, num_classes] return out def resnet18(num_classes, input_channels): return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, input_channels)

这里的[2, 2, 2, 2]表示 ResNet-18 中四个 layer 分别包含 2 个 BasicBlock。

整体 shape 变化可以理解为:

[Batch, 3, 28, 28] → [Batch, 64, 28, 28] → [Batch, 128, 14, 14] → [Batch, 256, 7, 7] → [Batch, 512, 4, 4] → [Batch, 512] → [Batch, num_classes]

随着网络逐渐加深,特征图的通道数不断增加,高和宽逐渐减小,模型提取到的特征也越来越抽象。

四、训练与评估

最后,将数据喂入模型,完成前向传播、损失计算和反向传播。

4.1 训练循环代码

# 初始化模型、优化器和损失函数 model = resnet18(num_classes=9, input_channels=3).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() def train(model, train_loader, optimizer, criterion, epoch): model.train() for batch_idx, (inputs, targets) in enumerate(train_loader): inputs, targets = inputs.to(device), targets.to(device).long() # 1. 梯度清零 optimizer.zero_grad() # 2. 前向传播 outputs = model(inputs) # 3. 计算损失 loss = criterion(outputs, targets) # 4. 反向传播 loss.backward() # 5. 更新参数 optimizer.step() if batch_idx % 10 == 0: print( f"Epoch: {epoch} | " f"Batch: {batch_idx} | " f"Loss: {loss.item():.4f}" ) # 训练启动 for epoch in range(1, 101): train(model, train_loader, optimizer, criterion, epoch)

训练流程可以概括为:

读取 batch 数据 → 输入模型 → 得到预测结果 → 计算交叉熵损失 → 反向传播 → 更新参数

验证阶段和训练阶段类似,但不需要反向传播和参数更新,只需要计算 loss、accuracy、AUC 等指标。

五、文章总结

本项目基于简化版医学影像数据集,完成了一个 ResNet-18 图像分类流程。通过将医学图像统一处理为28×28,可以降低训练成本,方便快速跑通整个深度学习实验。

ResNet 的核心在于残差连接。它通过 shortcut 保留原始输入信息,从而缓解深层网络退化问题。尤其是在 layer 切换时,如果输入输出维度不一致,代码中会通过stride=2进行下采样,同时利用1×1卷积在 shortcut 路径上调整通道数和特征图尺寸,保证两条分支可以正常相加。

在实际调试时,建议在forward()中打印每一层的out.shape,观察张量从[Batch, 3, 28, 28]逐步变为[Batch, 512]特征向量,再变为[Batch, num_classes]分类结果的过程。

通过以上流程,我们完成了从医学图像输入、ResNet 特征提取,到最终分类输出的完整实战项目。这个项目的重点不在于追求最高准确率,而在于理解 ResNet 的网络结构、shortcut 的作用,以及深度学习分类任务的完整训练流程。

http://www.cnnetsun.cn/news/2192487.html

相关文章:

  • 使用Python和Taotoken实现一个简单的多模型自动降级调用策略
  • AutoResearch:基于LLM的自动化研究流水线架构与实战指南
  • 多模态大模型在文档智能处理中的技术实践
  • Nginx SSL证书加载失败?除了.pem,你还需要检查证书格式和权限
  • SQL视图查询结果正确性校验_对比物理表数据与视图
  • 抖音内容下载难题怎么破?douyin-downloader 批量下载神器完全指南
  • 终极指南:如何在S905L2-B电视盒上快速部署Armbian系统
  • 无监督图像编辑:基于GAN与特征解耦的创新方法
  • Y语言-Y++全中文可视化编程语言
  • 大语言模型在数学奥赛解题中的应用与实践
  • 3分钟完成B站视频转文字:bili2text完整指南
  • YimMenu终极指南:如何在GTA5在线模式中建立你的数字堡垒
  • CyberEngineTweaks架构解析:赛博朋克2077性能调优与脚本框架深度优化
  • 别再混淆了!一文讲透scATAC-seq、Bulk ATAC-seq和scRNA-seq的应用场景与选择逻辑
  • 利用 Taotoken 模型广场为 AIGC 内容生成项目挑选合适的大模型
  • 抖音下载终极指南:轻松获取无水印视频的完整解决方案
  • 五一前夕DeepSeek发布多模态模型:解决指代鸿沟,拓扑推理大幅超越GPT-5.4等模型
  • Claude Code 工具 详解
  • 利用 Taotoken 为团队知识库构建智能问答机器人应用场景
  • 从数学建模到工程实践:用MATLAB复现多波束测线优化(附贪心算法与模拟退火代码)
  • 别再混淆MIPI-DSI的命令包了!0x29和0x39到底怎么选?附SPRD/Rockchip实例解析
  • 跨平台项目中QString 与 非Qt 跨平台动态库在字符集上的一个实用的互操作约定.
  • 喜马拉雅VIP音频下载终极指南:3步实现付费内容本地化
  • 对比直连与通过 Taotoken 调用在容灾体验上的不同
  • 终极免费d2s-editor:暗黑破坏神2存档修改完全指南
  • 【LLM推理优化与部署工程⑧】模型部署了,但没人知道它在干什么——出事了你都不知道
  • 终极魔兽争霸3优化指南:告别卡顿,畅享144Hz流畅体验
  • 中兴光猫解锁终极指南:5分钟获取完整root权限的完整教程
  • 八大网盘直链解析技术深度解析:架构设计与性能优化指南
  • PySpice终极指南:如何用Python轻松完成专业级电路仿真