保姆级教程:在MMSegmentation中从零搭建并训练你自己的SegFormer模型(B0-B5全系列)
从零构建SegFormer语义分割模型的实战指南
1. 环境配置与MMSegmentation框架搭建
在开始构建SegFormer模型之前,我们需要先搭建好开发环境。MMSegmentation是一个基于PyTorch的开源语义分割工具包,它提供了丰富的预训练模型和灵活的配置系统。
基础环境要求:
- Python 3.7+
- PyTorch 1.8+
- CUDA 11.1+
- MMSegmentation 0.20.0+
安装MMSegmentation的推荐方式是通过pip安装:
pip install torch torchvision pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html pip install mmsegmentation其中{cu_version}和{torch_version}需要替换为你的CUDA和PyTorch版本,例如:
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html提示:建议使用conda创建独立的Python环境,避免依赖冲突
验证安装是否成功:
import mmseg print(mmseg.__version__)2. SegFormer模型架构解析
SegFormer是一种基于Transformer的语义分割架构,它结合了分层Transformer编码器和轻量级MLP解码器。相比传统CNN-based方法,SegFormer在保持高效的同时,能够捕捉更大范围的上下文信息。
2.1 分层Transformer编码器
SegFormer的编码器采用分层设计,包含四个阶段(B0-B5不同版本层数不同),每个阶段处理不同尺度的特征图:
| 阶段 | 特征图尺寸 | 通道数(B0) | 通道数(B5) |
|---|---|---|---|
| 1 | H/4 × W/4 | 32 | 64 |
| 2 | H/8 × W/8 | 64 | 128 |
| 3 | H/16 × W/16 | 160 | 320 |
| 4 | H/32 × W/32 | 256 | 512 |
每个阶段由多个Transformer块组成,每个块包含:
- 高效自注意力机制(Efficient Self-Attention)
- Mix-FFN前馈网络
2.2 轻量级MLP解码器
SegFormer的解码器设计非常简洁,仅包含以下几个步骤:
- 将四个阶段的特征图统一投影到相同维度
- 上采样到1/4原始尺寸并拼接
- 通过MLP层融合特征
- 最终分类预测
这种设计得益于Transformer编码器产生的大感受野特征,不需要复杂的上下文聚合模块。
3. 自定义SegFormer模型实现
3.1 模型配置文件
在MMSegmentation中,我们通过配置文件定义模型结构。以下是一个SegFormer-B0的配置示例:
# model settings model = dict( type='EncoderDecoder', pretrained='pretrain/mit_b0.pth', backbone=dict( type='mit_b0', style='pytorch'), decode_head=dict( type='SegFormerHead', in_channels=[32, 64, 160, 256], in_index=[0, 1, 2, 3], feature_strides=[4, 8, 16, 32], channels=256, dropout_ratio=0.1, num_classes=19, align_corners=False, decoder_params=dict(embed_dim=256), loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole'))3.2 关键参数解析
SegFormer有几个关键参数需要特别注意:
sr_ratios: 自注意力机制的缩放因子,控制计算复杂度mlp_ratios: MLP中间层的扩展系数embed_dims: 各阶段的特征维度num_heads: 多头注意力的头数
对于不同规模的SegFormer模型,这些参数的典型设置如下:
| 参数 | B0 | B1 | B2 | B3 | B4 | B5 |
|---|---|---|---|---|---|---|
| embed_dims | [32,64,160,256] | [64,128,320,512] | [64,128,320,512] | [64,128,320,512] | [64,128,320,512] | [64,128,320,512] |
| num_heads | [1,2,5,8] | [1,2,5,8] | [1,2,5,8] | [1,2,5,8] | [1,2,5,8] | [1,2,5,8] |
| mlp_ratios | [4,4,4,4] | [4,4,4,4] | [4,4,4,4] | [4,4,4,4] | [4,4,4,4] | [4,4,4,4] |
| sr_ratios | [8,4,2,1] | [8,4,2,1] | [8,4,2,1] | [8,4,2,1] | [8,4,2,1] | [8,4,2,1] |
| depths | [2,2,2,2] | [2,2,2,2] | [3,4,6,3] | [3,4,18,3] | [3,8,27,3] | [3,6,40,3] |
3.3 自定义数据集适配
要使SegFormer适应你的数据集,需要修改以下几个部分:
- 数据管道配置:
train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ]- 修改模型输出类别数:
num_classes = 5 # 你的数据集类别数 model['decode_head']['num_classes'] = num_classes4. 模型训练与优化
4.1 训练策略配置
SegFormer通常采用以下训练策略:
optimizer = dict( type='AdamW', lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01) lr_config = dict( policy='poly', warmup='linear', warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, by_epoch=False) runner = dict(type='IterBasedRunner', max_iters=160000) checkpoint_config = dict(by_epoch=False, interval=16000) evaluation = dict(interval=16000, metric='mIoU')4.2 分布式训练启动
使用MMSegmentation的tools/train.py脚本启动训练:
./tools/dist_train.sh configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py 8 \ --work-dir work_dirs/segformer_mit-b0_512x512_160k_ade20k \ --seed 0 \ --deterministic其中8表示使用8个GPU,可以根据实际情况调整。
4.3 训练监控与调优
训练过程中需要关注以下指标:
- 损失曲线:确保训练损失稳定下降
- 验证集mIoU:监控模型泛化性能
- 学习率变化:确认学习率调度正常
常见的调优策略包括:
- 调整初始学习率(通常在1e-5到6e-5之间)
- 修改权重衰减系数(0.01-0.05)
- 增加数据增强强度
- 调整批次大小
5. 模型评估与推理
5.1 评估指标计算
MMSegmentation提供了多种评估指标:
./tools/dist_test.sh configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py \ work_dirs/segformer_mit-b0_512x512_160k_ade20k/latest.pth \ 8 \ --eval mIoU mAcc aAcc主要评估指标包括:
- mIoU:平均交并比
- mAcc:平均准确率
- aAcc:整体像素准确率
5.2 单张图像推理
使用训练好的模型进行预测:
from mmseg.apis import inference_segmentor, init_segmentor import mmcv config_file = 'configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py' checkpoint_file = 'work_dirs/segformer_mit-b0_512x512_160k_ade20k/latest.pth' # 初始化模型 model = init_segmentor(config_file, checkpoint_file, device='cuda:0') # 加载测试图像 img = 'test.jpg' # 执行推理 result = inference_segmentor(model, img) # 可视化结果 model.show_result(img, result, out_file='result.jpg', opacity=0.5)5.3 模型部署优化
为了提升推理速度,可以考虑以下优化:
- 半精度推理:
model.half() # 转换为半精度- TensorRT加速:
python tools/deployment/pytorch2onnx.py \ configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py \ work_dirs/segformer_mit-b0_512x512_160k_ade20k/latest.pth \ --output-file segformer.onnx \ --shape 512 512- 模型剪枝:减少模型参数和计算量
6. 常见问题与解决方案
6.1 训练不收敛
可能原因:
- 学习率设置不当
- 数据标注存在问题
- 损失函数配置错误
解决方案:
- 尝试降低学习率
- 检查数据标注质量
- 验证损失函数实现
6.2 显存不足
优化策略:
- 减小批次大小
- 使用梯度累积
- 启用混合精度训练
配置示例:
fp16 = dict(loss_scale=dict(init_scale=512))6.3 推理速度慢
加速方法:
- 使用更小的模型(如B0而不是B5)
- 减小输入图像尺寸
- 启用ONNX或TensorRT推理
6.4 类别不平衡
处理方法:
- 在损失函数中添加类别权重
loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, class_weight=[1.0, 2.0, 1.5, 3.0, 1.2]) # 根据你的类别调整- 采用OHEM策略
- 数据重采样
7. 进阶技巧与最佳实践
7.1 自定义Backbone
如果你想修改SegFormer的Backbone结构,可以继承MixVisionTransformer类:
@BACKBONES.register_module() class CustomSegFormer(MixVisionTransformer): def __init__(self, **kwargs): super().__init__( patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], # 其他自定义参数 )7.2 知识蒸馏
使用更大的SegFormer模型(如B5)作为教师模型,蒸馏到更小的模型(如B0):
distiller = dict( type='SegFormerDistiller', teacher_cfg='configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py', teacher_ckpt='checkpoints/segformer_mit-b5_512x512_160k_ade20k.pth', student_cfg='configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py', distill_losses=dict( loss_feat=dict(type='FeatureLoss', criterion='L2', weight=1.0), loss_logits=dict(type='KLDivLoss', temperature=1, weight=0.5) ))7.3 多任务学习
SegFormer可以扩展为多任务模型,例如同时进行分割和深度估计:
model = dict( type='MultiTaskEncoderDecoder', backbone=dict( type='mit_b0', style='pytorch'), decode_heads=[ dict( type='SegFormerHead', # 分割任务配置 ), dict( type='DepthHead', # 深度估计任务配置 ) ], # 其他配置 )7.4 模型量化
对训练好的模型进行8位量化,减少模型大小并提升推理速度:
quant_config = dict( quantization_type='INT8', quantizer=dict( type='TensorRTQuantizer', calib_dataset=dict( type='CityscapesDataset', data_root='data/cityscapes/', pipeline=test_pipeline), calib_steps=100, calib_batch_size=8))