当前位置: 首页 > news >正文

Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer解决高分辨率去噪问题

Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer解决高分辨率去噪问题


一、问题场景:Transformer效果好,但高分辨率图片跑不动

上一篇我们实现了一个简化版 SwinIR,用 Transformer 思路提升复杂纹理恢复能力。
但很快就会遇到一个真实工程问题:

Transformer去噪效果不错,但图像稍微大一点显存就爆。

比如输入从 128x128 提升到 256x256,显存占用会明显上升。
如果处理真实业务图片,比如 1024x1024,普通全局注意力几乎不可用。

这就是图像恢复任务里非常关键的问题:

如何在高分辨率图像上使用Transformer?

Restormer就是为这类图像恢复任务设计的代表模型之一。


二、Restormer解决什么问题?

Restormer的核心目标是:

在保持Transformer建模能力的同时,降低高分辨率图像恢复的计算压力。

它的关键思想包括:

  • 使用卷积保留局部结构
  • 使用通道维度注意力降低复杂度
  • 使用门控前馈网络增强表达
  • 适合去噪、去雨、去模糊等图像恢复任务

这一篇我们实现一个简化版 Restormer Block,用于图像去噪实战。


三、为什么普通Self-Attention不适合高分辨率图像?

普通 Self-Attention 的复杂度和 token 数量平方相关。

如果图像大小是 H x W,token数是:

N = H * W

注意力复杂度接近:

当输入为 512x512 时,N 非常大,计算基本不可接受。

Restormer采用更适合图像恢复的设计,避免直接做巨大空间注意力。


四、工程目录结构

restormer_denoise/ ├── data/ │ ├── train/ │ └── val/ ├── models/ │ └── mini_restormer.py ├── dataset.py ├── train.py ├── eval.py └── utils.py

五、核心模块一:LayerNorm2d

Transformer通常用 LayerNorm,但图像特征是 BCHW 格式。
这里实现一个适合图像的 LayerNorm2d。

importtorchimporttorch.nnasnnclassLayerNorm2d(nn.Module):def__init__(self,channels,eps=1e-6):super().__init__()self.weight=nn.Parameter(torch.ones(channels))self.bias=nn.Parameter(torch.zeros(channels))self.eps=epsdefforward(self,x):mean=x.mean(dim=1,keepdim=True)var=x.var(dim=1,keepdim=True,unbiased=False)x=(x-mean)/torch.sqrt(var+self.eps)weight=self.weight.view(1,-1,1,1)bias=self.bias.view(1,-1,1,1)returnx*weight+bias

六、核心模块二:通道注意力

这里实现一个简化版通道注意力,用来让模型判断哪些特征通道更重要。

classChannelAttention(nn.Module):def__init__(self,channels):super().__init__()self.pool=nn.AdaptiveAvgPool2d(1)self.net=nn.Sequential(nn.Conv2d(channels,channels//4,1),nn.ReLU(inplace=True),nn.Conv2d(channels//4,channels,1),nn.Sigmoid())defforward(self,x):weight=self.pool(x)weight=self.net(weight)returnx*weight

七、核心模块三:门控前馈网络

Restormer里很重要的一个思想是 Gated Feed Forward。
简单理解:

不是所有特征都直接通过,而是通过门控机制筛选。

classGatedFeedForward(nn.Module):def__init__(self,channels):super().__init__()hidden=channels*2self.project_in=nn.Conv2d(channels,hidden*2,1)self.depthwise=nn.Conv2d(hidden*2,hidden*2,3,padding=1,groups=hidden*2)self.project_out=nn.Conv2d(hidden,channels,1)defforward(self,x):x=self.project_in(x)x=self.depthwise(x)x1,x2=x.chunk(2,dim=1)x=torch.nn.functional.gelu(x1)*x2returnself.project_out(x)

八、Mini Restormer Block

把 LayerNorm、Attention、GatedFFN 组合起来。

classRestormerBlock(nn.Module):def__init__(self,channels):super().__init__()self.norm1=LayerNorm2d(channels)self.attn=ChannelAttention(channels)self.norm2=LayerNorm2d(channels)self.ffn=GatedFeedForward(channels)defforward(self,x):x=x+self.attn(self.norm1(x))x=x+self.ffn(self.norm2(x))returnx

九、完整Mini Restormer模型

models/mini_restormer.py

importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassLayerNorm2d(nn.Module):def__init__(self,channels,eps=1e-6):super().__init__()self.weight=nn.Parameter(torch.ones(channels))self.bias=nn.Parameter(torch.zeros(channels))self.eps=epsdefforward(self,x):mean=x.mean(dim=1,keepdim=True)var=x.var(dim=1,keepdim=True,unbiased=False)x=(x-mean)/torch.sqrt(var+self.eps)returnx*self.weight.view(1,-1,1,1)+self.bias.view(1,-1,1,1)classChannelAttention(nn.Module):def__init__(self,channels):super().__init__()self.pool=nn.AdaptiveAvgPool2d(1)self.net=nn.Sequential(nn.Conv2d(channels,channels//4,1),nn.ReLU(inplace=True),nn.Conv2d(channels//4,channels,1),nn.Sigmoid())defforward(self,x):weight=self.net(self.pool(x))returnx*weightclassGatedFeedForward(nn.Module):def__init__(self,channels):super().__init__()hidden=channels*2self.project_in=nn.Conv2d(channels,hidden*2,1)self.depthwise=nn.Conv2d(hidden*2,hidden*2,3,padding=1,groups=hidden*2)self.project_out=nn.Conv2d(hidden,channels,1)defforward(self,x):x=self.project_in(x)x=self.depthwise(x)x1,x2=x.chunk(2,dim=1)x=F.gelu(x1)*x2returnself.project_out(x)classRestormerBlock(nn.Module):def__init__(self,channels):super().__init__()self.norm1=LayerNorm2d(channels)self.attn=ChannelAttention(channels)self.norm2=LayerNorm2d(channels)self.ffn=GatedFeedForward(channels)defforward(self,x):x=x+self.attn(self.norm1(x))x=x+self.ffn(self.norm2(x))returnxclassMiniRestormerDenoise(nn.Module):def__init__(self,in_channels=1,channels=64,num_blocks=6):super().__init__()self.head=nn.Conv2d(in_channels,channels,3,padding=1)self.body=nn.Sequential(*[RestormerBlock(channels)for_inrange(num_blocks)])self.tail=nn.Conv2d(channels,in_channels,3,padding=1)defforward(self,x):residual=x feat=self.head(x)feat=self.body(feat)noise=self.tail(feat)returnresidual-noise

十、训练代码

importtorchfromtorch.utils.dataimportDataLoaderfromdatasetimportDenoiseDatasetfrommodels.mini_restormerimportMiniRestormerDenoisedeftrain():device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")dataset=DenoiseDataset("data/train",patch_size=128)loader=DataLoader(dataset,batch_size=8,shuffle=True,num_workers=4)model=MiniRestormerDenoise().to(device)optimizer=torch.optim.AdamW(model.parameters(),lr=2e-4,weight_decay=1e-4)criterion=torch.nn.L1Loss()forepochinrange(1,101):model.train()total_loss=0fornoisy,cleaninloader:noisy=noisy.to(device)clean=clean.to(device)pred=model(noisy)loss=criterion(pred,clean)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)optimizer.step()total_loss+=loss.item()print(f"Epoch{epoch}, Loss:{total_loss/len(loader):.6f}")ifepoch%10==0:torch.save(model.state_dict(),f"mini_restormer_epoch_{epoch}.pth")if__name__=="__main__":train()

十一、数据集代码

importosimportrandomimporttorchfromPILimportImagefromtorch.utils.dataimportDatasetimporttorchvision.transformsastransformsclassDenoiseDataset(Dataset):def__init__(self,root_dir,patch_size=128):self.paths=[os.path.join(root_dir,name)fornameinos.listdir(root_dir)ifname.lower().endswith((".jpg",".png",".jpeg"))]self.patch_size=patch_size self.to_tensor=transforms.ToTensor()def__len__(self):returnlen(self.paths)def__getitem__(self,idx):img=Image.open(self.paths[idx]).convert("L")w,h=img.sizeifw>=self.patch_sizeandh>=self.patch_size:x=random.randint(0,w-self.patch_size)y=random.randint(0,h-self.patch_size)img=img.crop((x,y,x+self.patch_size,y+self.patch_size))else:img=img.resize((self.patch_size,self.patch_size))clean=self.to_tensor(img)sigma=random.choice([15,25,35,50])noisy=torch.clamp(clean+torch.randn_like(clean)*sigma/255.0,0,1)returnnoisy,clean

十二、为什么Restormer比普通Transformer更适合图像恢复?

普通Transformer直接对空间token做全局注意力,代价非常高。

Restormer类结构更重视:

  • 局部卷积
  • 通道交互
  • 门控特征
  • 残差学习

这使它更适合高分辨率恢复任务。

本文实现的是简化版,重点是理解结构思想,而不是完全复现论文细节。


十三、踩坑记录

坑1:LayerNorm维度写错

图像是 BCHW,不是 NLP 中的 BNC。
如果直接用 nn.LayerNorm(channels),很容易维度不匹配。

本文用的是自定义 LayerNorm2d。


坑2:GatedFFN通道数对不上

project_in 输出 hidden * 2,后面要 chunk 成两半。

如果通道设置错误,会报维度错误。


坑3:训练初期输出偏暗

原因可能是残差预测不稳定。

解决:

  • 降低学习率
  • 使用梯度裁剪
  • 输出后 clamp
  • 使用L1Loss

十四、效果验证

MiniRestormer相比普通UNet,主要优势是:

  • 纹理保留更好
  • 高噪声下更稳
  • 平坦区域更自然
  • 不容易出现过度平滑
模型高分辨率适应性纹理恢复训练成本
UNet中等
MiniSwinIR一般
MiniRestormer较好中高

十五、适合收藏总结

MiniRestormer流程

  1. Conv提取浅层特征
  2. RestormerBlock建模
  3. Channel Attention筛选特征
  4. GatedFFN增强表达
  5. 预测噪声残差
  6. noisy - noise得到结果

避坑清单

  • LayerNorm要适配BCHW
  • GatedFFN通道要算对
  • 学习率不要太大
  • 高分辨率建议patch训练
  • 推理结果必须clamp

十六、优化建议

可以继续升级:

  • 多尺度Encoder-Decoder结构
  • 更接近原版MDTA注意力
  • 加像素shuffle上采样
  • 加真实噪声数据微调
  • 支持彩色RGB图像

结尾总结

Restormer代表的是图像恢复模型的一个重要方向:

不再简单套用NLP Transformer,而是针对图像恢复任务重新设计高效结构。

如果你已经掌握了 UNet 和 SwinIR,Restormer是非常值得继续深入的模型。


下一篇预告

Pytorch图像去噪实战(十一):Diffusion扩散模型图像去噪入门,从噪声预测理解生成式去噪

http://www.cnnetsun.cn/news/2147490.html

相关文章:

  • Flowframes终极指南:免费AI视频插帧工具让普通视频秒变流畅大片
  • 别再手动排期了!用Microsoft Project 2007三步搞定你的第一个项目计划(附WBS实战)
  • 终极指南:如何用Deep3D免费将2D视频秒变沉浸式3D立体影像
  • 氛!某插件肆意搜集信息,吾爱论坛站长打造完美替代品来救场
  • 如何用BiliTools跨平台工具箱轻松下载B站视频:完整指南
  • BepInEx Unity插件框架架构演进:从Mono到IL2CPP的技术突破与性能优化路径
  • 【仅限持牌机构技术负责人可见】:某头部支付平台PHP国密迁移内部白皮书节选(含性能损耗压测数据:TPS下降≤3.7%,密钥轮换耗时<86ms)
  • CircuitJS1 Desktop Mod:零基础入门电子电路仿真的完整指南
  • 当ISO镜像不再需要实体光驱:WinCDEmu的驱动级虚拟化方案
  • **超融合架构下的自动化运维:基于Python的容器化部署与监控实战**在现代数据中心演进中,**超融合架构(Hyper-Converg
  • YooAsset:企业级Unity资源管理框架的架构设计与实施指南
  • 如何快速掌握Charticulator:零代码图表设计的完整入门指南
  • 模型选型背后的成本工程:DeepSeek-V4、GPT-5.5与中国大模型API成本全解析
  • 绝地求生罗技鼠标宏压枪脚本:5分钟从新手到精准射击高手
  • AJ-Captcha行为验证码技术架构深度解析:构建智能人机识别系统的实践指南
  • 告别打包烦恼:用Auto.js Pro 9.0.0 + VSCode插件高效开发手机自动化脚本(附Scrcpy投屏技巧)
  • 任务分配的底层逻辑:告别 “能者多劳”,让每个人都 “物尽其用”
  • GLM-4.1V-9B-Base保姆级教程:Web界面UI功能分区与交互逻辑详解
  • Win11Debloat:Windows 11终极优化工具,5分钟还你一个干净高效的系统
  • 免费Switch模拟器Ryujinx:在PC上畅玩任天堂游戏的终极指南
  • 英雄联盟国服换肤神器:R3nzSkin免费解锁全皮肤完整教程
  • 29000+ 个 AI Skill 怎么选?这个工具帮你 30 秒找到最佳选择(附方法论)
  • 从MES到ERP:一份简历讲透你的技术栈演进,让猎头主动找上门
  • 别再只改主干网络了!YOLOv5模型轻量化避坑指南:从MobileNetV3、ShuffleNetV2到GhostNet的全面对比实验
  • 如何永久免费使用IDM?开源激活脚本完整指南
  • 终极Windows注册表取证分析:RegRipper3.0专业指南
  • 别再手动拼接字符串了!用Qt的QDateTime轻松搞定日志时间戳(附完整代码)
  • 如何用Autoticket大麦网自动抢票工具3倍提升抢票成功率?终极实战指南
  • 基于Java开发的制造业MES生产管理系统源码(含ERP集成模块)
  • cpp-httplib vs. 原生socket:手把手教你用C++写个高性能HTTP客户端(含连接池思路)