CIFAR-10/100 数据集 20 类粗粒度标签实战:PyTorch 加载与分层分类
CIFAR-100粗粒度分类实战:PyTorch双标签加载与分层模型设计
1. 理解CIFAR-100的层次化标签体系
CIFAR-100数据集最显著的特征是其双重标签系统。每张32x32的彩色图像不仅包含100个细粒度类别标签(如"苹果"、"蘑菇"),还关联着20个粗粒度的大类标签(如"水果和蔬菜")。这种层次结构为计算机视觉研究提供了独特的实验场景:
- 粗粒度分类(20类):识别高级语义类别
- 细粒度分类(100类):区分更具体的子类别
- 层次关系:每个粗粒度类别包含5个细粒度类别(如"水果和蔬菜"包含苹果、蘑菇、橙子等)
# CIFAR-100标签结构示例 coarse_labels = [ '水生哺乳动物', '鱼类', '花卉', '食品容器', '水果和蔬菜', '家用电器', '家具', '昆虫', '大型食肉动物', '人造户外物品', '自然户外场景', '大型杂食动物', '中型哺乳动物', '无脊椎动物', '人物', '爬行动物', '小型哺乳动物', '树木', '交通工具1', '交通工具2' ] fine_labels = { '水果和蔬菜': ['苹果', '蘑菇', '橙子', '梨', '甜椒'], '家用电器': ['钟表', '电脑键盘', '台灯', '电话', '电视机'] # 其他大类省略... }这种结构特别适合研究:
- 层次化分类模型
- 知识迁移(从粗粒度到细粒度)
- 多任务学习(同时预测粗细标签)
提示:粗粒度标签在数据量不足时能提供更强的监督信号,而细粒度标签适合需要高精度的场景。
2. PyTorch数据加载器实现
我们需要自定义Dataset类来同时加载两种标签。关键点在于正确处理CIFAR-100的二进制文件格式:
2.1 数据集目录结构
cifar-100-python/ ├── train # 训练集 ├── test # 测试集 ├── meta # 标签名称元数据2.2 自定义Dataset类
import torch from torch.utils.data import Dataset import pickle import numpy as np class CIFAR100WithCoarse(Dataset): def __init__(self, root, train=True, transform=None): self.transform = transform self.data = [] self.fine_labels = [] self.coarse_labels = [] # 加载数据文件 file = 'train' if train else 'test' with open(f'{root}/cifar-100-python/{file}', 'rb') as fo: dict = pickle.load(fo, encoding='bytes') # 转换数据格式 self.data = dict[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1) self.fine_labels = dict[b'fine_labels'] self.coarse_labels = dict[b'coarse_labels'] # 加载标签名称 with open(f'{root}/cifar-100-python/meta', 'rb') as fo: meta = pickle.load(fo, encoding='bytes') self.fine_label_names = [t.decode('utf8') for t in meta[b'fine_label_names']] self.coarse_label_names = [t.decode('utf8') for t in meta[b'coarse_label_names']] def __len__(self): return len(self.data) def __getitem__(self, idx): img = self.data[idx] fine_label = self.fine_labels[idx] coarse_label = self.coarse_labels[idx] if self.transform: img = self.transform(img) return img, (coarse_label, fine_label)2.3 数据增强策略
对于32x32的小尺寸图像,推荐使用以下增强组合:
from torchvision import transforms train_transform = transforms.Compose([ transforms.ToPILImage(), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) ]) test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) ])3. 分层分类模型架构设计
3.1 基础特征提取网络
我们使用改进的ResNet-18作为基础架构:
import torch.nn as nn import torchvision.models as models class HierarchicalResNet(nn.Module): def __init__(self): super().__init__() # 加载预训练ResNet并修改输入层 resnet = models.resnet18(pretrained=True) resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) resnet.maxpool = nn.Identity() # 移除初始下采样 # 特征提取部分 self.features = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4, resnet.avgpool ) # 分类头 self.coarse_head = nn.Linear(512, 20) # 粗粒度分类 self.fine_head = nn.Linear(512, 100) # 细粒度分类 def forward(self, x): x = self.features(x) x = x.view(x.size(0), -1) return self.coarse_head(x), self.fine_head(x)3.2 层次感知损失函数
设计考虑标签层次结构的损失函数:
class HierarchicalLoss(nn.Module): def __init__(self, alpha=0.3): super().__init__() self.alpha = alpha # 粗粒度损失权重 self.ce_coarse = nn.CrossEntropyLoss() self.ce_fine = nn.CrossEntropyLoss() def forward(self, outputs, targets): coarse_out, fine_out = outputs coarse_target, fine_target = targets # 计算两种损失 loss_coarse = self.ce_coarse(coarse_out, coarse_target) loss_fine = self.ce_fine(fine_out, fine_target) # 组合损失 return self.alpha * loss_coarse + (1 - self.alpha) * loss_fine4. 训练策略与评估
4.1 分层训练流程
def train_model(model, criterion, dataloaders, optimizer, num_epochs=100): for epoch in range(num_epochs): # 每个epoch包含训练和验证阶段 for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() running_loss = 0.0 coarse_correct = 0 fine_correct = 0 total = 0 for inputs, (coarse_labels, fine_labels) in dataloaders[phase]: inputs = inputs.to(device) coarse_labels = coarse_labels.to(device) fine_labels = fine_labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): coarse_outputs, fine_outputs = model(inputs) loss = criterion((coarse_outputs, fine_outputs), (coarse_labels, fine_labels)) if phase == 'train': loss.backward() optimizer.step() # 统计指标 running_loss += loss.item() * inputs.size(0) _, coarse_preds = torch.max(coarse_outputs, 1) _, fine_preds = torch.max(fine_outputs, 1) coarse_correct += torch.sum(coarse_preds == coarse_labels) fine_correct += torch.sum(fine_preds == fine_labels) total += inputs.size(0) epoch_loss = running_loss / total coarse_acc = coarse_correct.double() / total fine_acc = fine_correct.double() / total print(f'{phase} Epoch {epoch}: Loss={epoch_loss:.4f}, ' f'Coarse Acc={coarse_acc:.4f}, Fine Acc={fine_acc:.4f}')4.2 性能评估指标
除了常规的准确率,我们还应该关注:
| 指标 | 计算公式 | 意义 |
|---|---|---|
| 分层准确率 | 粗/细粒度分类正确率 | 评估不同层次性能 |
| 一致性误差 | 细粒度预测与粗粒度不一致的比例 | 评估层次一致性 |
| 混淆矩阵分析 | 细粒度类别在粗粒度类别内的分布 | 发现困难样本 |
def evaluate_hierarchy(model, dataloader): model.eval() confusion = np.zeros((20, 5, 5)) # 20个粗类,每个粗类5个细类 with torch.no_grad(): for inputs, (coarse_labels, fine_labels) in dataloader: inputs = inputs.to(device) coarse_labels = coarse_labels.cpu().numpy() fine_labels = fine_labels.cpu().numpy() coarse_out, fine_out = model(inputs) _, fine_preds = torch.max(fine_out, 1) fine_preds = fine_preds.cpu().numpy() for c, f_true, f_pred in zip(coarse_labels, fine_labels, fine_preds): f_true_in_c = f_true % 5 # 粗类内的相对索引 f_pred_in_c = f_pred % 5 confusion[c, f_true_in_c, f_pred_in_c] += 1 # 计算每个粗类内部的分类准确率 class_acc = [] for c in range(20): class_acc.append(np.diag(confusion[c]).sum() / confusion[c].sum()) return confusion, class_acc5. 进阶技巧与优化方向
5.1 知识蒸馏应用
利用粗粒度标签指导细粒度分类:
class HierarchicalDistillationLoss(nn.Module): def __init__(self, temp=2.0, alpha=0.7): super().__init__() self.temp = temp self.alpha = alpha self.ce = nn.CrossEntropyLoss() self.kl = nn.KLDivLoss(reduction='batchmean') def forward(self, outputs, targets): coarse_out, fine_out = outputs coarse_target, fine_target = targets # 标准交叉熵损失 loss_fine = self.ce(fine_out, fine_target) # 知识蒸馏损失 with torch.no_grad(): coarse_probs = torch.softmax(coarse_out / self.temp, dim=1) # 将粗粒度概率映射到细粒度空间 fine_from_coarse = self._map_coarse_to_fine(coarse_probs) fine_student = torch.log_softmax(fine_out / self.temp, dim=1) loss_distill = self.kl(fine_student, fine_from_coarse) * (self.temp**2) return self.alpha * loss_fine + (1 - self.alpha) * loss_distill def _map_coarse_to_fine(self, coarse_probs): # 创建从粗粒度到细粒度的映射矩阵 mapping = torch.zeros(20, 100) for c in range(20): mapping[c, c*5:(c+1)*5] = 1/5 mapping = mapping.to(coarse_probs.device) return torch.matmul(coarse_probs, mapping)5.2 模型轻量化策略
针对嵌入式设备部署的优化方案:
- 通道剪枝:基于粗粒度重要性的卷积核剪枝
- 量化感知训练:8整数量化
- 注意力机制:添加轻量级SE模块
class SELayer(nn.Module): def __init__(self, channel, reduction=16): super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y6. 实际应用案例
6.1 零售商品分级系统
利用CIFAR-100的层次结构构建商品分类系统:
零售商品分类体系 ├── 食品 │ ├── 水果 │ ├── 零食 │ └── 饮料 ├── 电子产品 │ ├── 手机 │ └── 电脑 └── 家居用品 ├── 清洁用品 └── 厨具6.2 多粒度图像检索
基于层次标签构建检索系统:
class HierarchicalRetriever: def __init__(self, model, database): self.model = model self.database = database # (paths, coarse_labels, fine_labels) def query(self, image, topk=5, level='fine'): with torch.no_grad(): features = self.model.features(image) coarse, fine = self.model.coarse_head(features), self.model.fine_head(features) if level == 'coarse': _, pred = torch.max(coarse, 1) mask = (self.database['coarse'] == pred.item()) else: _, pred = torch.max(fine, 1) mask = (self.database['fine'] == pred.item()) # 返回同类别的topk最相似图像 distances = compute_similarity(features, self.database['features'][mask]) indices = np.argsort(distances)[:topk] return self.database['paths'][mask][indices]在医疗影像分析中,类似的层次结构也很有价值——比如先区分影像模态(X光/CT/MRI),再识别具体病症。
