031、Transformer降临超分:SwinIR的窗口注意力机制详解与源码走读

发布时间:2026/7/2 23:43:11
031、Transformer降临超分:SwinIR的窗口注意力机制详解与源码走读 031、Transformer降临超分SwinIR的窗口注意力机制详解与源码走读上个月调一个4倍超分模型训练到第80个epoch突然loss炸了从0.003跳到0.8。排查了三天最后发现是注意力计算时QK^T的维度没对齐导致梯度爆炸。这个坑让我重新翻出SwinIR的源码才发现自己之前对窗口注意力机制的理解全是“知其然不知其所以然”。今天就把这块硬骨头彻底啃透。为什么是SwinIR而不是ViT2021年SwinIR出来的时候很多人觉得就是Swin Transformer搬过来做超分。但真正跑过实验就知道直接套用ViT做超分256x256的输入全局注意力计算复杂度是O(N^2)N65536一张卡显存直接爆掉。SwinIR的核心贡献在于用窗口注意力替代全局注意力同时用交叉窗口连接弥补信息割裂。窗口大小默认8x8每个窗口内64个token计算复杂度降到O(M^2)其中M64显存占用直接降了两个数量级。代价是窗口之间的信息被切断了——这就是为什么需要后面的Shifted Window操作。窗口注意力到底怎么算的看源码之前先理解SwinIR里窗口注意力的三个关键设计1. 相对位置编码不是ViT那种绝对位置编码而是计算query和key之间的相对偏移。比如窗口内第3个token和第7个token相对位置是(3-7, 3-7)(-4,-4)。这样做的好处是模型学到的是“两个像素之间距离多远”的关系而不是“这个像素在绝对坐标第几行”。超分任务里纹理的局部相关性远比绝对位置重要。2. 窗口划分与mask机制输入特征图[H,W,C]先reshape成[num_windows, window_size*window_size, C]。每个窗口独立计算注意力。Shifted Window时特征图先做循环移位再重新划分窗口这样原本在边界处的像素就能和相邻窗口的像素交互。但移位后窗口内会混入来自不同区域的像素需要用mask把不该交互的位置屏蔽掉。3. 多头注意力残差连接和标准Transformer一样但SwinIR在注意力后加了两个卷积层做特征融合而不是MLP。这个设计很巧妙——卷积能更好地保持空间连续性适合图像任务。源码走读从入口到核心计算直接看SwinIR的forward函数核心模块是RSTBResidual Swin Transformer Block里面嵌套了SwinTransformerLayer。classSwinTransformerLayer(nn.Module):def__init__(self,dim,num_heads,window_size8,shift_size0):super().__init__()self.window_sizewindow_size self.shift_sizeshift_size# 0表示普通窗口window_size//2表示移位窗口self.norm1nn.LayerNorm(dim)self.attnWindowAttention(dim,num_heads,window_size)# 这里踩过坑LayerNorm一定要放在attention前面否则梯度不稳定重点看forward里的窗口划分逻辑defforward(self,x):B,C,H,Wx.shape shortcutx xself.norm1(x.permute(0,2,3,1).reshape(B*H*W,-1)).reshape(B,H,W,C).permute(0,3,1,2)# 别这样写直接对4D tensor做norm会破坏通道维度关系ifself.shift_size0:# 循环移位把左上角区域移到右下角shifted_xtorch.roll(x,shifts(-self.shift_size,-self.shift_size),dims(2,3))else:shifted_xx# 划分窗口把[H,W] reshape成 [H//ws, ws, W//ws, ws]# 然后转置成 [num_windows, ws*ws, C]x_windowswindow_partition(shifted_x,self.window_size)# x_windows shape: [B*num_windows, ws*ws, C]attn_windowsself.attn(x_windows)# 这里有个细节attn内部做了相对位置编码的bias计算# 还原回原始尺寸xwindow_reverse(attn_windows,self.window_size,H,W)ifself.shift_size0:# 反向移位把之前移走的区域移回来xtorch.roll(x,shifts(self.shift_size,self.shift_size),dims(2,3))returnxshortcut# 残差连接WindowAttention内部相对位置编码的坑WindowAttention的forward里QKV计算很简单关键是相对位置编码的生成classWindowAttention(nn.Module):def__init__(self,dim,num_heads,window_size):super().__init__()self.num_headsnum_heads self.window_sizewindow_size self.scale(dim//num_heads)**-0.5# 生成相对位置索引表coords_htorch.arange(window_size)coords_wtorch.arange(window_size)coordstorch.stack(torch.meshgrid([coords_h,coords_w]))# [2, ws, ws]coords_flattencoords.flatten(1)# [2, ws*ws]# 计算相对坐标每个位置减去所有位置relative_coordscoords_flatten[:,:,None]-coords_flatten[:,None,:]# [2, ws*ws, ws*ws]relative_coordsrelative_coords.permute(1,2,0).contiguous()# [ws*ws, ws*ws, 2]# 关键步骤把负坐标映射到正数方便查表relative_coords[:,:,0]window_size-1relative_coords[:,:,1]window_size-1relative_coords[:,:,0]*2*window_size-1# 这里踩过坑乘的是(2ws-1)不是wsrelative_position_indexrelative_coords.sum(-1)# [ws*ws, ws*ws]self.register_buffer(relative_position_index,relative_position_index)# 可学习的相对位置偏置表self.relative_position_bias_tablenn.Parameter(torch.zeros((2*window_size-1)*(2*window_size-1),num_heads))forward里查表时defforward(self,x,maskNone):B_,N,Cx.shape qkvself.qkv(x).reshape(B_,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)q,k,vqkv[0],qkv[1],qkv[2]attn(q k.transpose(-2,-1))*self.scale# 查相对位置偏置relative_position_biasself.relative_position_bias_table[self.relative_position_index.view(-1)].view(N,N,-1).permute(2,0,1).unsqueeze(0)# [1, num_heads, N, N]attnattnrelative_position_biasifmaskisnotNone:attnattn.masked_fill(mask0,float(-inf))attnattn.softmax(dim-1)x(attn v).transpose(1,2).reshape(B_,N,C)returnx这里有个容易忽略的点relative_position_index是固定的但relative_position_bias_table是可学习的。这意味着模型在训练过程中会调整“不同相对位置之间的注意力权重”但“哪些位置算作相对位置”是预先定义好的。窗口大小固定后相对位置的范围就固定了所以SwinIR不支持动态窗口大小。Shifted Window的mask实现Shifted Window的难点在于循环移位后窗口内可能包含来自不同区域的像素这些像素之间不应该有注意力交互。SwinIR的做法是生成一个mask矩阵在softmax前把非法位置的注意力值设为负无穷。defcreate_mask(self,x,H,W):# 生成一个和特征图同尺寸的索引图img_masktorch.zeros(1,H,W,1)h_slices(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))w_slices(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))cnt0forhinh_slices:forwinw_slices:img_mask[:,h,w,:]cnt cnt1# 划分窗口后每个窗口内的像素来自不同区域不同cnt值mask_windowswindow_partition(img_mask,self.window_size)mask_windowsmask_windows.view(-1,self.window_size*self.window_size)# 计算mask如果两个像素来自不同区域mask0attn_maskmask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)attn_maskattn_mask.masked_fill(attn_mask!0,float(-inf)).masked_fill(attn_mask0,0.0)returnattn_mask这个mask在attention forward里和relative_position_bias一起加到attn上。注意mask是在forward里动态生成的因为每次输入尺寸可能不同超分任务经常需要处理不同分辨率。实战经验调参踩坑记录窗口大小选8还是16我试过16x16的窗口显存直接翻倍PSNR只涨了0.02dB。8x8是性价比最高的选择。如果输入分辨率超过512x512建议用8x8并配合梯度检查点。shift_size设成多少官方代码里是window_size//2也就是4。别改成其他值否则窗口之间的信息交互会不均匀。我试过shift_size2结果模型在纹理区域出现网格伪影。多头注意力的头数SwinIR默认6个头每个头64维。如果显存紧张可以减到4个头但PSNR会掉0.1dB左右。头数太少相对位置编码的表达能力会下降。训练时梯度爆炸如果遇到loss突然飙升先检查LayerNorm的位置。SwinIR的LayerNorm是在attention之前如果放在attention之后梯度会变得非常不稳定。另外学习率超过2e-4基本必炸建议用1e-4配合warmup。推理时速度优化SwinIR的窗口划分和还原操作涉及大量reshape和permute在PyTorch里这些操作会打断CUDA kernel的连续性。实测把window_partition和window_reverse写成C extension可以提速30%。如果不想写C至少保证输入尺寸是window_size的整数倍避免padding带来的额外计算。个人经验性建议SwinIR的成功不是因为它用了Transformer而是因为它用窗口注意力解决了计算复杂度问题同时用Shifted Window解决了信息割裂问题。这个设计思路比ViT更适合图像任务。如果你现在要做一个新的超分模型别直接抄SwinIR。考虑两个改进方向一是用可变形窗口替代固定窗口让模型自己学习在哪里划分窗口二是把窗口注意力和通道注意力结合SwinIR只做了空间维度的注意力通道维度还是靠卷积这里还有提升空间。最后说一句别迷信SwinIR的官方实现。它的代码为了通用性做了很多冗余操作比如每次forward都重新生成mask。实际部署时把mask和relative_position_index缓存起来能省掉不少计算。我自己的项目里把这两项预计算后推理速度提升了15%。