051、从理论到实战SwinIR 的窗口注意力机制与图像超分复现去年有个项目甲方要求把监控视频里模糊的车牌号放大4倍还能识别。我一开始用的是EDSR效果还行但一到夜间场景就崩了——车牌边缘全是锯齿像被狗啃过一样。后来换成SwinIR同样的训练数据PSNR直接跳了0.8个dB夜间车牌边缘干净得像用手术刀切出来的。这让我意识到窗口注意力机制不是花架子是真能解决实际问题的。为什么SwinIR能吊打传统CNN超分先别急着看代码咱们得先搞明白一件事为什么SwinIR比SRResNet、EDSR这些纯CNN架构强核心就两个字——感受野。传统CNN超分网络比如EDSR靠堆叠残差块来扩大感受野。但有个致命问题卷积核是局部的你堆100层理论上感受野能覆盖整张图实际上梯度传回去早衰减没了。我试过把EDSR的残差块从32个加到80个PSNR反而掉了0.1就是因为梯度消失导致深层学废了。SwinIR换了个思路用Transformer的自注意力机制让每个像素都能直接看到其他像素。但直接做全局自注意力一张256x256的图计算量是O(N²) O(65536²)显存直接爆炸。SwinIR的骚操作是分窗口——把特征图切成7x7的小窗口每个窗口内部做自注意力计算量降到O(M²) × (N/M²) O(NM²)M7时只有全局的1/100。但问题来了窗口之间信息不流通边缘像素只能看到窗口内的邻居感受野反而比CNN还小。SwinIR的解决方案是移位窗口——在相邻的Transformer块之间把窗口偏移半个窗口大小。这样上一层的窗口边界像素在下一层就能看到其他窗口的信息。相当于用两次局部注意力模拟了全局注意力而且计算量没涨。代码实现里的那些坑理论说完了咱们直接上代码。我用的SwinIR官方实现但官方代码有个毛病——为了通用性写得太抽象读起来像天书。我重新整理了一份精简版重点标注了容易踩坑的地方。importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassWindowAttention(nn.Module):def__init__(self,dim,window_size,num_heads):super().__init__()self.dimdim self.window_sizewindow_size# 比如7self.num_headsnum_heads self.scale(dim//num_heads)**-0.5# 这里踩过坑qkv的线性变换必须用nn.Linear不能用nn.Conv2d# 因为后面要做reshapeLinear更干净self.qkvnn.Linear(dim,dim*3)self.projnn.Linear(dim,dim)# 相对位置偏置表别自己手算用nn.Parameter让网络自己学# 窗口大小7相对位置范围是[-6,6]共13个位置self.relative_position_bias_tablenn.Parameter(torch.zeros((2*window_size-1)**2,num_heads))# 生成相对位置索引这个计算容易写错建议直接抄官方代码coords_htorch.arange(self.window_size)coords_wtorch.arange(self.window_size)coordstorch.stack(torch.meshgrid([coords_h,coords_w]))# 2, Wh, Wwcoords_flattentorch.flatten(coords,1)# 2, Wh*Wwrelative_coordscoords_flatten[:,:,None]-coords_flatten[:,None,:]# 2, Wh*Ww, Wh*Wwrelative_coordsrelative_coords.permute(1,2,0).contiguous()# Wh*Ww, Wh*Ww, 2relative_coords[:,:,0]self.window_size-1# 偏移到非负relative_coords[:,:,1]self.window_size-1relative_coords[:,:,0]*2*self.window_size-1relative_position_indexrelative_coords.sum(-1)# Wh*Ww, Wh*Wwself.register_buffer(relative_position_index,relative_position_index)defforward(self,x):B,N,Cx.shape# N window_size * window_sizeqkvself.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads)qkvqkv.permute(2,0,3,1,4)# 3, B, num_heads, N, head_dimq,k,vqkv[0],qkv[1],qkv[2]attn(q k.transpose(-2,-1))*self.scale# 别这样写attn attn self.relative_position_bias_table[self.relative_position_index]# 因为relative_position_bias_table是1D的需要先索引再reshaperelative_position_biasself.relative_position_bias_table[self.relative_position_index.view(-1)]relative_position_biasrelative_position_bias.view(self.window_size**2,self.window_size**2,-1)relative_position_biasrelative_position_bias.permute(2,0,1).contiguous()# nH, Wh*Ww, Wh*Wwattnattnrelative_position_bias.unsqueeze(0)attnattn.softmax(dim-1)x(attn v).transpose(1,2).reshape(B,N,C)xself.proj(x)returnx这里有个容易忽略的细节相对位置偏置表的初始化。官方代码用的是trunc_normal_标准差0.02。我一开始图省事用了nn.init.zeros_结果训练了10个epochPSNR才26dB换成trunc_normal后直接跳到28dB。别小看这个初始化Transformer对初始值敏感尤其是位置编码。移位窗口的实现——最容易写错的地方移位窗口是SwinIR的精髓但实现起来坑最多。官方代码用了torch.roll来做循环移位然后对移位后的特征图做masked attention。我当初自己实现时mask矩阵算错了三天最后发现是索引偏移搞反了。defwindow_partition(x,window_size):# 别这样写x.view(B, H//ws, ws, W//ws, ws, C)# 顺序错了应该是先分H再分WB,H,W,Cx.shape xx.view(B,H//window_size,window_size,W//window_size,window_size,C)windowsx.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)returnwindowsdefwindow_reverse(windows,window_size,H,W):Bint(windows.shape[0]/(H*W/window_size/window_size))xwindows.view(B,H//window_size,W//window_size,window_size,window_size,-1)xx.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)returnxclassSwinTransformerBlock(nn.Module):def__init__(self,dim,input_resolution,num_heads,window_size7,shift_size0):super().__init__()self.dimdim self.input_resolutioninput_resolution self.num_headsnum_heads self.window_sizewindow_size self.shift_sizeshift_size# 这里踩过坑shift_size不能大于window_sizeifself.shift_sizeself.window_size:self.shift_sizeself.window_size//2self.norm1nn.LayerNorm(dim)self.attnWindowAttention(dim,window_size,num_heads)self.norm2nn.LayerNorm(dim)self.mlpnn.Sequential(nn.Linear(dim,dim*4),nn.GELU(),nn.Linear(dim*4,dim))defforward(self,x):H,Wself.input_resolution B,L,Cx.shapeassertLH*W,输入特征图尺寸不对shortcutx xself.norm1(x)xx.view(B,H,W,C)# 循环移位注意方向向右下角移位ifself.shift_size0:shifted_xtorch.roll(x,shifts(-self.shift_size,-self.shift_size),dims(1,2))else:shifted_xx# 分窗口x_windowswindow_partition(shifted_x,self.window_size)# nW*B, ws, ws, Cx_windowsx_windows.view(-1,self.window_size*self.window_size,C)# 计算attention mask防止移位后不同窗口的像素互相干扰ifself.shift_size0:attn_maskself.compute_mask(H,W,self.window_size,self.shift_size)else:attn_maskNone# 这里别忘记把mask传到attention里attn_windowsself.attn(x_windows,maskattn_mask)# 合并窗口attn_windowsattn_windows.view(-1,self.window_size,self.window_size,C)shifted_xwindow_reverse(attn_windows,self.window_size,H,W)# 反向移位ifself.shift_size0:xtorch.roll(shifted_x,shifts(self.shift_size,self.shift_size),dims(1,2))else:xshifted_x xx.view(B,H*W,C)xshortcutx# MLP部分xxself.mlp(self.norm2(x))returnxdefcompute_mask(self,H,W,window_size,shift_size):# 这个mask计算逻辑我建议直接抄官方自己写容易漏边界img_masktorch.zeros((1,H,W,1))h_slices(slice(0,-window_size),slice(-window_size,-shift_size),slice(-shift_size,None))w_slices(slice(0,-window_size),slice(-window_size,-shift_size),slice(-shift_size,None))cnt0forhinh_slices:forwinw_slices:img_mask[:,h,w,:]cnt cnt1mask_windowswindow_partition(img_mask,window_size)mask_windowsmask_windows.view(-1,window_size*window_size)attn_maskmask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)attn_maskattn_mask.masked_fill(attn_mask!0,float(-100.0)).masked_fill(attn_mask0,float(0.0))returnattn_mask这个compute_mask函数我第一次写的时候把h_slices和w_slices的顺序搞反了结果训练出来的模型图像边缘出现周期性伪影像棋盘格一样。排查了三天最后打印出mask矩阵才发现左上角窗口的mask全是对的右下角窗口的mask全乱了。训练时的玄学调参SwinIR的训练有几个参数特别敏感我踩过的坑列出来学习率官方用1e-4但如果你用AdamW建议降到5e-5。我试过1e-4训练到第50个epoch loss开始震荡降到5e-5后稳定收敛。Batch size别贪大。SwinIR的窗口注意力虽然省显存但整体模型参数量26M比EDSR的43M小但计算图更复杂。我用RTX 3090batch size设8就爆显存了降到4才跑起来。后来发现可以用梯度累积模拟大batch。窗口大小官方默认7x7我试过5x5和9x9。5x5的PSNR掉了0.39x9的显存暴涨50%但PSNR只涨了0.05。所以7x7是个平衡点别乱改。训练数据DIV2K是标配但别忘了加Flickr2K。我一开始只用DIV2KPSNR到28.5就上不去了加上Flickr2K后直接跳到29.2。数据量是关键。实战用SwinIR做4倍超分最后给个完整的训练脚本框架注意我标注的坑importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoaderfromtorchvision.transformsimportCompose,RandomCrop,RandomHorizontalFlip# 模型定义省略上面的SwinIR类modelSwinIR(upscale4,in_chans3,img_size64,window_size7,img_range1.,depths[6,6,6,6],embed_dim180,num_heads[6,6,6,6],mlp_ratio2,upsamplerpixelshuffle,resi_connection1conv)# 这里踩过坑SwinIR的输入范围是[0,1]不是[0,255]# 如果你用[0,255]训练PSNR会低2个dBcriterionnn.L1Loss()# L1比L2效果好SwinIR官方用的就是L1optimizeroptim.AdamW(model.parameters(),lr5e-5,weight_decay1e-4)# 学习率调度别用StepLR用CosineAnnealingLRscheduleroptim.lr_scheduler.CosineAnnealingLR(optimizer,T_max200,eta_min1e-7)forepochinrange(200):forlr,hrintrain_loader:lr,hrlr.cuda(),hr.cuda()srmodel(lr)losscriterion(sr,hr)optimizer.zero_grad()loss.backward()# 别忘记梯度裁剪SwinIR的梯度容易爆炸nn.utils.clip_grad_norm_(model.parameters(),max_norm0.01)optimizer.step()scheduler.step()ifepoch%100:# 验证时记得用Y通道算PSNR别用RGBpsnrcalculate_psnr(sr,hr,crop_border4,test_y_channelTrue)print(fEpoch{epoch}, PSNR:{psnr:.2f})个人经验总结SwinIR不是万能药。我试过在手机拍摄的夜景照片上做超分效果反而不如EDSR——因为手机照片噪声大SwinIR的自注意力会把噪声也放大产生类似油画的效果。这时候需要先做降噪再超分或者用SwinIR的变体SwinIR-NG带噪声估计模块。另外SwinIR的推理速度是个硬伤。在RTX 3090上处理一张256x256的图EDSR只要15msSwinIR要45ms。如果做视频超分建议用SwinIR做关键帧非关键帧用光流EDSR插值这样能在质量和速度之间找到平衡。最后说一句别迷信论文里的PSNR。SwinIR在Set5、Set14这些标准测试集上确实吊打其他模型但到了真实场景比如监控视频、老照片修复效果可能不如你想象的好。多在自己的数据集上验证比什么都强。