Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer解决高分辨率去噪问题
Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer解决高分辨率去噪问题
一、问题场景:Transformer效果好,但高分辨率图片跑不动
上一篇我们实现了一个简化版 SwinIR,用 Transformer 思路提升复杂纹理恢复能力。
但很快就会遇到一个真实工程问题:
Transformer去噪效果不错,但图像稍微大一点显存就爆。
比如输入从 128x128 提升到 256x256,显存占用会明显上升。
如果处理真实业务图片,比如 1024x1024,普通全局注意力几乎不可用。
这就是图像恢复任务里非常关键的问题:
如何在高分辨率图像上使用Transformer?
Restormer就是为这类图像恢复任务设计的代表模型之一。
二、Restormer解决什么问题?
Restormer的核心目标是:
在保持Transformer建模能力的同时,降低高分辨率图像恢复的计算压力。
它的关键思想包括:
- 使用卷积保留局部结构
- 使用通道维度注意力降低复杂度
- 使用门控前馈网络增强表达
- 适合去噪、去雨、去模糊等图像恢复任务
这一篇我们实现一个简化版 Restormer Block,用于图像去噪实战。
三、为什么普通Self-Attention不适合高分辨率图像?
普通 Self-Attention 的复杂度和 token 数量平方相关。
如果图像大小是 H x W,token数是:
N = H * W注意力复杂度接近:
N²当输入为 512x512 时,N 非常大,计算基本不可接受。
Restormer采用更适合图像恢复的设计,避免直接做巨大空间注意力。
四、工程目录结构
restormer_denoise/ ├── data/ │ ├── train/ │ └── val/ ├── models/ │ └── mini_restormer.py ├── dataset.py ├── train.py ├── eval.py └── utils.py五、核心模块一:LayerNorm2d
Transformer通常用 LayerNorm,但图像特征是 BCHW 格式。
这里实现一个适合图像的 LayerNorm2d。
importtorchimporttorch.nnasnnclassLayerNorm2d(nn.Module):def__init__(self,channels,eps=1e-6):super().__init__()self.weight=nn.Parameter(torch.ones(channels))self.bias=nn.Parameter(torch.zeros(channels))self.eps=epsdefforward(self,x):mean=x.mean(dim=1,keepdim=True)var=x.var(dim=1,keepdim=True,unbiased=False)x=(x-mean)/torch.sqrt(var+self.eps)weight=self.weight.view(1,-1,1,1)bias=self.bias.view(1,-1,1,1)returnx*weight+bias六、核心模块二:通道注意力
这里实现一个简化版通道注意力,用来让模型判断哪些特征通道更重要。
classChannelAttention(nn.Module):def__init__(self,channels):super().__init__()self.pool=nn.AdaptiveAvgPool2d(1)self.net=nn.Sequential(nn.Conv2d(channels,channels//4,1),nn.ReLU(inplace=True),nn.Conv2d(channels//4,channels,1),nn.Sigmoid())defforward(self,x):weight=self.pool(x)weight=self.net(weight)returnx*weight七、核心模块三:门控前馈网络
Restormer里很重要的一个思想是 Gated Feed Forward。
简单理解:
不是所有特征都直接通过,而是通过门控机制筛选。
classGatedFeedForward(nn.Module):def__init__(self,channels):super().__init__()hidden=channels*2self.project_in=nn.Conv2d(channels,hidden*2,1)self.depthwise=nn.Conv2d(hidden*2,hidden*2,3,padding=1,groups=hidden*2)self.project_out=nn.Conv2d(hidden,channels,1)defforward(self,x):x=self.project_in(x)x=self.depthwise(x)x1,x2=x.chunk(2,dim=1)x=torch.nn.functional.gelu(x1)*x2returnself.project_out(x)八、Mini Restormer Block
把 LayerNorm、Attention、GatedFFN 组合起来。
classRestormerBlock(nn.Module):def__init__(self,channels):super().__init__()self.norm1=LayerNorm2d(channels)self.attn=ChannelAttention(channels)self.norm2=LayerNorm2d(channels)self.ffn=GatedFeedForward(channels)defforward(self,x):x=x+self.attn(self.norm1(x))x=x+self.ffn(self.norm2(x))returnx九、完整Mini Restormer模型
models/mini_restormer.py
importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassLayerNorm2d(nn.Module):def__init__(self,channels,eps=1e-6):super().__init__()self.weight=nn.Parameter(torch.ones(channels))self.bias=nn.Parameter(torch.zeros(channels))self.eps=epsdefforward(self,x):mean=x.mean(dim=1,keepdim=True)var=x.var(dim=1,keepdim=True,unbiased=False)x=(x-mean)/torch.sqrt(var+self.eps)returnx*self.weight.view(1,-1,1,1)+self.bias.view(1,-1,1,1)classChannelAttention(nn.Module):def__init__(self,channels):super().__init__()self.pool=nn.AdaptiveAvgPool2d(1)self.net=nn.Sequential(nn.Conv2d(channels,channels//4,1),nn.ReLU(inplace=True),nn.Conv2d(channels//4,channels,1),nn.Sigmoid())defforward(self,x):weight=self.net(self.pool(x))returnx*weightclassGatedFeedForward(nn.Module):def__init__(self,channels):super().__init__()hidden=channels*2self.project_in=nn.Conv2d(channels,hidden*2,1)self.depthwise=nn.Conv2d(hidden*2,hidden*2,3,padding=1,groups=hidden*2)self.project_out=nn.Conv2d(hidden,channels,1)defforward(self,x):x=self.project_in(x)x=self.depthwise(x)x1,x2=x.chunk(2,dim=1)x=F.gelu(x1)*x2returnself.project_out(x)classRestormerBlock(nn.Module):def__init__(self,channels):super().__init__()self.norm1=LayerNorm2d(channels)self.attn=ChannelAttention(channels)self.norm2=LayerNorm2d(channels)self.ffn=GatedFeedForward(channels)defforward(self,x):x=x+self.attn(self.norm1(x))x=x+self.ffn(self.norm2(x))returnxclassMiniRestormerDenoise(nn.Module):def__init__(self,in_channels=1,channels=64,num_blocks=6):super().__init__()self.head=nn.Conv2d(in_channels,channels,3,padding=1)self.body=nn.Sequential(*[RestormerBlock(channels)for_inrange(num_blocks)])self.tail=nn.Conv2d(channels,in_channels,3,padding=1)defforward(self,x):residual=x feat=self.head(x)feat=self.body(feat)noise=self.tail(feat)returnresidual-noise十、训练代码
importtorchfromtorch.utils.dataimportDataLoaderfromdatasetimportDenoiseDatasetfrommodels.mini_restormerimportMiniRestormerDenoisedeftrain():device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")dataset=DenoiseDataset("data/train",patch_size=128)loader=DataLoader(dataset,batch_size=8,shuffle=True,num_workers=4)model=MiniRestormerDenoise().to(device)optimizer=torch.optim.AdamW(model.parameters(),lr=2e-4,weight_decay=1e-4)criterion=torch.nn.L1Loss()forepochinrange(1,101):model.train()total_loss=0fornoisy,cleaninloader:noisy=noisy.to(device)clean=clean.to(device)pred=model(noisy)loss=criterion(pred,clean)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)optimizer.step()total_loss+=loss.item()print(f"Epoch{epoch}, Loss:{total_loss/len(loader):.6f}")ifepoch%10==0:torch.save(model.state_dict(),f"mini_restormer_epoch_{epoch}.pth")if__name__=="__main__":train()十一、数据集代码
importosimportrandomimporttorchfromPILimportImagefromtorch.utils.dataimportDatasetimporttorchvision.transformsastransformsclassDenoiseDataset(Dataset):def__init__(self,root_dir,patch_size=128):self.paths=[os.path.join(root_dir,name)fornameinos.listdir(root_dir)ifname.lower().endswith((".jpg",".png",".jpeg"))]self.patch_size=patch_size self.to_tensor=transforms.ToTensor()def__len__(self):returnlen(self.paths)def__getitem__(self,idx):img=Image.open(self.paths[idx]).convert("L")w,h=img.sizeifw>=self.patch_sizeandh>=self.patch_size:x=random.randint(0,w-self.patch_size)y=random.randint(0,h-self.patch_size)img=img.crop((x,y,x+self.patch_size,y+self.patch_size))else:img=img.resize((self.patch_size,self.patch_size))clean=self.to_tensor(img)sigma=random.choice([15,25,35,50])noisy=torch.clamp(clean+torch.randn_like(clean)*sigma/255.0,0,1)returnnoisy,clean十二、为什么Restormer比普通Transformer更适合图像恢复?
普通Transformer直接对空间token做全局注意力,代价非常高。
Restormer类结构更重视:
- 局部卷积
- 通道交互
- 门控特征
- 残差学习
这使它更适合高分辨率恢复任务。
本文实现的是简化版,重点是理解结构思想,而不是完全复现论文细节。
十三、踩坑记录
坑1:LayerNorm维度写错
图像是 BCHW,不是 NLP 中的 BNC。
如果直接用 nn.LayerNorm(channels),很容易维度不匹配。
本文用的是自定义 LayerNorm2d。
坑2:GatedFFN通道数对不上
project_in 输出 hidden * 2,后面要 chunk 成两半。
如果通道设置错误,会报维度错误。
坑3:训练初期输出偏暗
原因可能是残差预测不稳定。
解决:
- 降低学习率
- 使用梯度裁剪
- 输出后 clamp
- 使用L1Loss
十四、效果验证
MiniRestormer相比普通UNet,主要优势是:
- 纹理保留更好
- 高噪声下更稳
- 平坦区域更自然
- 不容易出现过度平滑
| 模型 | 高分辨率适应性 | 纹理恢复 | 训练成本 |
|---|---|---|---|
| UNet | 好 | 中等 | 低 |
| MiniSwinIR | 一般 | 好 | 高 |
| MiniRestormer | 较好 | 好 | 中高 |
十五、适合收藏总结
MiniRestormer流程
- Conv提取浅层特征
- RestormerBlock建模
- Channel Attention筛选特征
- GatedFFN增强表达
- 预测噪声残差
- noisy - noise得到结果
避坑清单
- LayerNorm要适配BCHW
- GatedFFN通道要算对
- 学习率不要太大
- 高分辨率建议patch训练
- 推理结果必须clamp
十六、优化建议
可以继续升级:
- 多尺度Encoder-Decoder结构
- 更接近原版MDTA注意力
- 加像素shuffle上采样
- 加真实噪声数据微调
- 支持彩色RGB图像
结尾总结
Restormer代表的是图像恢复模型的一个重要方向:
不再简单套用NLP Transformer,而是针对图像恢复任务重新设计高效结构。
如果你已经掌握了 UNet 和 SwinIR,Restormer是非常值得继续深入的模型。
下一篇预告
Pytorch图像去噪实战(十一):Diffusion扩散模型图像去噪入门,从噪声预测理解生成式去噪
