当前位置: 首页 > news >正文

CIFAR-10/100 数据集 20 类粗粒度标签实战:PyTorch 加载与分层分类

CIFAR-100粗粒度分类实战:PyTorch双标签加载与分层模型设计

1. 理解CIFAR-100的层次化标签体系

CIFAR-100数据集最显著的特征是其双重标签系统。每张32x32的彩色图像不仅包含100个细粒度类别标签(如"苹果"、"蘑菇"),还关联着20个粗粒度的大类标签(如"水果和蔬菜")。这种层次结构为计算机视觉研究提供了独特的实验场景:

  • 粗粒度分类(20类):识别高级语义类别
  • 细粒度分类(100类):区分更具体的子类别
  • 层次关系:每个粗粒度类别包含5个细粒度类别(如"水果和蔬菜"包含苹果、蘑菇、橙子等)
# CIFAR-100标签结构示例 coarse_labels = [ '水生哺乳动物', '鱼类', '花卉', '食品容器', '水果和蔬菜', '家用电器', '家具', '昆虫', '大型食肉动物', '人造户外物品', '自然户外场景', '大型杂食动物', '中型哺乳动物', '无脊椎动物', '人物', '爬行动物', '小型哺乳动物', '树木', '交通工具1', '交通工具2' ] fine_labels = { '水果和蔬菜': ['苹果', '蘑菇', '橙子', '梨', '甜椒'], '家用电器': ['钟表', '电脑键盘', '台灯', '电话', '电视机'] # 其他大类省略... }

这种结构特别适合研究:

  • 层次化分类模型
  • 知识迁移(从粗粒度到细粒度)
  • 多任务学习(同时预测粗细标签)

提示:粗粒度标签在数据量不足时能提供更强的监督信号,而细粒度标签适合需要高精度的场景。

2. PyTorch数据加载器实现

我们需要自定义Dataset类来同时加载两种标签。关键点在于正确处理CIFAR-100的二进制文件格式:

2.1 数据集目录结构

cifar-100-python/ ├── train # 训练集 ├── test # 测试集 ├── meta # 标签名称元数据

2.2 自定义Dataset类

import torch from torch.utils.data import Dataset import pickle import numpy as np class CIFAR100WithCoarse(Dataset): def __init__(self, root, train=True, transform=None): self.transform = transform self.data = [] self.fine_labels = [] self.coarse_labels = [] # 加载数据文件 file = 'train' if train else 'test' with open(f'{root}/cifar-100-python/{file}', 'rb') as fo: dict = pickle.load(fo, encoding='bytes') # 转换数据格式 self.data = dict[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1) self.fine_labels = dict[b'fine_labels'] self.coarse_labels = dict[b'coarse_labels'] # 加载标签名称 with open(f'{root}/cifar-100-python/meta', 'rb') as fo: meta = pickle.load(fo, encoding='bytes') self.fine_label_names = [t.decode('utf8') for t in meta[b'fine_label_names']] self.coarse_label_names = [t.decode('utf8') for t in meta[b'coarse_label_names']] def __len__(self): return len(self.data) def __getitem__(self, idx): img = self.data[idx] fine_label = self.fine_labels[idx] coarse_label = self.coarse_labels[idx] if self.transform: img = self.transform(img) return img, (coarse_label, fine_label)

2.3 数据增强策略

对于32x32的小尺寸图像,推荐使用以下增强组合:

from torchvision import transforms train_transform = transforms.Compose([ transforms.ToPILImage(), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) ]) test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) ])

3. 分层分类模型架构设计

3.1 基础特征提取网络

我们使用改进的ResNet-18作为基础架构:

import torch.nn as nn import torchvision.models as models class HierarchicalResNet(nn.Module): def __init__(self): super().__init__() # 加载预训练ResNet并修改输入层 resnet = models.resnet18(pretrained=True) resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) resnet.maxpool = nn.Identity() # 移除初始下采样 # 特征提取部分 self.features = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4, resnet.avgpool ) # 分类头 self.coarse_head = nn.Linear(512, 20) # 粗粒度分类 self.fine_head = nn.Linear(512, 100) # 细粒度分类 def forward(self, x): x = self.features(x) x = x.view(x.size(0), -1) return self.coarse_head(x), self.fine_head(x)

3.2 层次感知损失函数

设计考虑标签层次结构的损失函数:

class HierarchicalLoss(nn.Module): def __init__(self, alpha=0.3): super().__init__() self.alpha = alpha # 粗粒度损失权重 self.ce_coarse = nn.CrossEntropyLoss() self.ce_fine = nn.CrossEntropyLoss() def forward(self, outputs, targets): coarse_out, fine_out = outputs coarse_target, fine_target = targets # 计算两种损失 loss_coarse = self.ce_coarse(coarse_out, coarse_target) loss_fine = self.ce_fine(fine_out, fine_target) # 组合损失 return self.alpha * loss_coarse + (1 - self.alpha) * loss_fine

4. 训练策略与评估

4.1 分层训练流程

def train_model(model, criterion, dataloaders, optimizer, num_epochs=100): for epoch in range(num_epochs): # 每个epoch包含训练和验证阶段 for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() running_loss = 0.0 coarse_correct = 0 fine_correct = 0 total = 0 for inputs, (coarse_labels, fine_labels) in dataloaders[phase]: inputs = inputs.to(device) coarse_labels = coarse_labels.to(device) fine_labels = fine_labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): coarse_outputs, fine_outputs = model(inputs) loss = criterion((coarse_outputs, fine_outputs), (coarse_labels, fine_labels)) if phase == 'train': loss.backward() optimizer.step() # 统计指标 running_loss += loss.item() * inputs.size(0) _, coarse_preds = torch.max(coarse_outputs, 1) _, fine_preds = torch.max(fine_outputs, 1) coarse_correct += torch.sum(coarse_preds == coarse_labels) fine_correct += torch.sum(fine_preds == fine_labels) total += inputs.size(0) epoch_loss = running_loss / total coarse_acc = coarse_correct.double() / total fine_acc = fine_correct.double() / total print(f'{phase} Epoch {epoch}: Loss={epoch_loss:.4f}, ' f'Coarse Acc={coarse_acc:.4f}, Fine Acc={fine_acc:.4f}')

4.2 性能评估指标

除了常规的准确率,我们还应该关注:

指标计算公式意义
分层准确率粗/细粒度分类正确率评估不同层次性能
一致性误差细粒度预测与粗粒度不一致的比例评估层次一致性
混淆矩阵分析细粒度类别在粗粒度类别内的分布发现困难样本
def evaluate_hierarchy(model, dataloader): model.eval() confusion = np.zeros((20, 5, 5)) # 20个粗类,每个粗类5个细类 with torch.no_grad(): for inputs, (coarse_labels, fine_labels) in dataloader: inputs = inputs.to(device) coarse_labels = coarse_labels.cpu().numpy() fine_labels = fine_labels.cpu().numpy() coarse_out, fine_out = model(inputs) _, fine_preds = torch.max(fine_out, 1) fine_preds = fine_preds.cpu().numpy() for c, f_true, f_pred in zip(coarse_labels, fine_labels, fine_preds): f_true_in_c = f_true % 5 # 粗类内的相对索引 f_pred_in_c = f_pred % 5 confusion[c, f_true_in_c, f_pred_in_c] += 1 # 计算每个粗类内部的分类准确率 class_acc = [] for c in range(20): class_acc.append(np.diag(confusion[c]).sum() / confusion[c].sum()) return confusion, class_acc

5. 进阶技巧与优化方向

5.1 知识蒸馏应用

利用粗粒度标签指导细粒度分类:

class HierarchicalDistillationLoss(nn.Module): def __init__(self, temp=2.0, alpha=0.7): super().__init__() self.temp = temp self.alpha = alpha self.ce = nn.CrossEntropyLoss() self.kl = nn.KLDivLoss(reduction='batchmean') def forward(self, outputs, targets): coarse_out, fine_out = outputs coarse_target, fine_target = targets # 标准交叉熵损失 loss_fine = self.ce(fine_out, fine_target) # 知识蒸馏损失 with torch.no_grad(): coarse_probs = torch.softmax(coarse_out / self.temp, dim=1) # 将粗粒度概率映射到细粒度空间 fine_from_coarse = self._map_coarse_to_fine(coarse_probs) fine_student = torch.log_softmax(fine_out / self.temp, dim=1) loss_distill = self.kl(fine_student, fine_from_coarse) * (self.temp**2) return self.alpha * loss_fine + (1 - self.alpha) * loss_distill def _map_coarse_to_fine(self, coarse_probs): # 创建从粗粒度到细粒度的映射矩阵 mapping = torch.zeros(20, 100) for c in range(20): mapping[c, c*5:(c+1)*5] = 1/5 mapping = mapping.to(coarse_probs.device) return torch.matmul(coarse_probs, mapping)

5.2 模型轻量化策略

针对嵌入式设备部署的优化方案:

  1. 通道剪枝:基于粗粒度重要性的卷积核剪枝
  2. 量化感知训练:8整数量化
  3. 注意力机制:添加轻量级SE模块
class SELayer(nn.Module): def __init__(self, channel, reduction=16): super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y

6. 实际应用案例

6.1 零售商品分级系统

利用CIFAR-100的层次结构构建商品分类系统:

零售商品分类体系 ├── 食品 │ ├── 水果 │ ├── 零食 │ └── 饮料 ├── 电子产品 │ ├── 手机 │ └── 电脑 └── 家居用品 ├── 清洁用品 └── 厨具

6.2 多粒度图像检索

基于层次标签构建检索系统:

class HierarchicalRetriever: def __init__(self, model, database): self.model = model self.database = database # (paths, coarse_labels, fine_labels) def query(self, image, topk=5, level='fine'): with torch.no_grad(): features = self.model.features(image) coarse, fine = self.model.coarse_head(features), self.model.fine_head(features) if level == 'coarse': _, pred = torch.max(coarse, 1) mask = (self.database['coarse'] == pred.item()) else: _, pred = torch.max(fine, 1) mask = (self.database['fine'] == pred.item()) # 返回同类别的topk最相似图像 distances = compute_similarity(features, self.database['features'][mask]) indices = np.argsort(distances)[:topk] return self.database['paths'][mask][indices]

在医疗影像分析中,类似的层次结构也很有价值——比如先区分影像模态(X光/CT/MRI),再识别具体病症。

http://www.cnnetsun.cn/news/3124398.html

相关文章:

  • Unity性能优化:Draw Call与SetPass Call实战解析
  • UMG自发光效果快速实现与优化技巧
  • Pygame入门:从零开发2D游戏《飞机大战》实战指南
  • 游戏3D模型面数优化与UE5实战技巧
  • Godot 2D游戏开发入门:从环境搭建到角色控制
  • 数据分析师速成指南:Excel、SQL、Python与PowerBI实战路径
  • Cocos游戏集成Android原生隐私弹窗开发指南
  • SSL RC4漏洞修复实战:从原理到配置,全面加固TLS安全
  • MAX9744与PIC18LF25K50在音频功放系统中的应用与优化
  • UE5 PeerStream公网部署实战:WebRTC像素流送全链路配置指南
  • SAT碰撞检测优化:Burst与SIMD实战
  • 机械设计公差与配合实战指南:从图纸到装配的精准控制
  • YOLO目标检测算法:从原理到实战的100集系统学习指南
  • 虚幻引擎动画蓝图:混合空间与状态机实战解析
  • Unity游戏服务端开发:TCP通信与状态同步实战
  • QueryExcel:3分钟搞定100个Excel文件的批量查询神器
  • 轻量化YOLOv8船舶检测:99.1%精度背后的跨模态优化与工程部署实践
  • 西门子交换机环网冗余设置(理论篇)
  • 从真人舞蹈到虚拟偶像:OpenMMD如何用AI让动作捕捉平民化
  • 基于改进YOLOv8的无人机航拍小目标检测实战:电动自行车违规行为识别
  • UE像素流送实战:从原理到双向通信的完整部署指南
  • Unity安卓游戏手柄支持实战:从输入原理到完整实现
  • 自我管理书籍推荐,学会用更科学的方式管理自己
  • ComfyUI-to-Python:5分钟掌握从可视化AI工作流到Python代码的智能转换
  • AI增强生态模型:PLUS-InVEST技术融合与实战指南
  • STM32F103 外部晶振电路设计:8MHz与32.768KHz 双时钟源 PCB 布局 5 要点
  • Few Shot场景下的Agent开发与上下文处理实战
  • AI工具助力毕业论文写作:10款实用工具全流程指南
  • 随机森林实战:从原理到调优全解析
  • AI Agent技术架构解析与n8n平台实战指南