056、BasicVSR 视频超分:双向传播与光流对齐的核心技术解析
056、BasicVSR 视频超分双向传播与光流对齐的核心技术解析从一次诡异的训练崩溃说起去年冬天调BasicVSR跑REDS数据集batch size设成8V100上吭哧吭哧跑了三天。第四天凌晨三点loss突然从0.003飙到NaN。我第一反应是梯度爆炸查了梯度范数正常。再查学习率也没问题。最后定位到光流估计模块——某个视频帧序列里出现了大面积的纯黑帧光流网络直接输出了一堆inf。这问题在单帧超分里根本不会出现但视频超分里帧与帧之间的运动估计一旦崩了整个双向传播链就全完蛋。后来我加了个光流mask的clamp操作问题解决。但这件事让我意识到BasicVSR看似结构简单真正落地时光流对齐和时序传播这两个环节任何一个细节没处理好模型就跟你玩罢工。双向传播不是简单的“前向后向”很多人看BasicVSR论文觉得双向传播不就是把EDVR的时序注意力换成光流对齐然后前向传播一次、后向传播一次最后融合。这么理解没错但实现起来有坑。BasicVSR的传播结构是先做一次后向传播从最后一帧往前再做一次前向传播从第一帧往后每次传播都通过光流把前一帧的特征warp到当前帧然后和当前帧的特征做残差连接。注意这里不是简单的特征拼接而是用了一个残差块来融合。我踩过的第一个坑传播顺序搞反了。论文里说“backward then forward”但有人觉得应该先前向再后向理由是“先看过去再看未来”。实际上BasicVSR的设计逻辑是后向传播先建立全局的时序依赖前向传播再在此基础上细化。如果你先做前向后向传播时看到的特征就已经包含了未来的信息反而会导致信息泄露。这跟视频编解码里的B帧预测逻辑有点像但又不完全一样。代码实现时我习惯这样写# 后向传播从T-1到0fortinrange(T-2,-1,-1):# warp前一帧的特征到当前帧warped_featflow_warp(feat_backward[t1],flow_backward[t])# 这里踩过坑直接加warped_feat和当前特征梯度会爆炸# 正确做法先过一层卷积再残差连接feat_backward[t]residual_block(torch.cat([feat_backward[t],warped_feat],dim1))别这样写feat_backward[t] feat_backward[t] warped_feat。光流warp后的特征和原始特征分布差异很大直接相加会让特征空间扭曲训练初期loss降不下去。光流对齐SPyNet的“糙快猛”哲学BasicVSR用的光流网络是SPyNet不是FlowNet或者RAFT。为什么选它因为SPyNet轻量参数量只有FlowNet的十分之一推理速度快。但代价是精度不如RAFT。SPyNet的核心思想是金字塔粗到细。输入两张图先下采样到低分辨率估计一个粗糙的光流然后上采样在高分辨率层做refine。每个金字塔层只输出一个残差光流最后累加得到最终光流。这里有个细节容易被忽略SPyNet的输入是RGB图像不是特征图。BasicVSR里光流估计是在原始图像上做的而不是在特征空间。这意味着光流网络和超分网络是解耦的光流网络可以单独预训练然后冻结住。我试过把光流网络也一起训练结果超分网络学歪了——它开始依赖光流网络的误差来补偿超分效果而不是真正学会对齐。实际部署时我建议用预训练的SPyNet权重然后冻结。如果你非要微调记得把光流网络的学习率设成超分网络的十分之一否则光流会漂移。传播链中的梯度问题为什么你的模型训不动BasicVSR的传播链很长后向传播T帧前向传播T帧总共2T步。T30时梯度要回传60步。这还不算光流warp操作里的双线性插值梯度。我遇到过最典型的问题训练到一半loss不再下降但也没发散。检查梯度发现后向传播的梯度几乎为0前向传播的梯度正常。这是因为后向传播的起始帧最后一帧的特征没有经过任何传播梯度直接从它开始回传越往前梯度越小。解决方案有两个使用梯度裁剪但阈值不能太小我一般设max_norm0.1。在传播链中插入残差连接让梯度有短路路径。BasicVSR的原始设计里每个时间步的特征都会和warp后的特征做残差连接这已经缓解了梯度消失。但如果你的视频序列特别长比如超过60帧建议把序列切分成多个子序列每个子序列独立传播然后做overlap融合。光流warp的边界效应一个容易被忽视的细节光流warp时边界像素会映射到图像外部导致warp后的特征在边界处出现空洞。BasicVSR用了一个简单的padding策略在warp之前对特征图做replicate padding。但这样不够。我自己的做法在warp之后对边界像素做mask。具体来说计算每个像素的warp坐标是否在有效范围内生成一个0-1 mask然后和warp后的特征相乘。这样边界处的无效特征就不会污染后续的传播。defflow_warp_with_mask(x,flow):# x: [B, C, H, W], flow: [B, 2, H, W]gridmake_grid(H,W)flow# 生成有效区域maskmask_x(grid[:,0:1,:,:]-1)(grid[:,0:1,:,:]1)mask_y(grid[:,1:2,:,:]-1)(grid[:,1:2,:,:]1)mask(mask_xmask_y).float()# warpwarpedF.grid_sample(x,grid.permute(0,2,3,1),modebilinear,padding_modezeros)returnwarped*mask别这样写直接用padding_modeborder。border padding会让边界像素重复导致光流估计出现假阳性。训练策略从REDS到真实场景REDS数据集是BasicVSR的标准benchmark但真实场景和REDS差异很大。REDS的视频是固定帧率、固定分辨率、运动平滑。真实场景里视频可能有跳帧、运动模糊、光照突变。我踩过的一个大坑在REDS上训好的模型直接拿到监控视频上测试效果惨不忍睹。原因是监控视频的帧率不固定光流估计不准。解决方案在训练时加入帧率扰动随机跳过一些帧让模型适应不同的运动速度。另一个经验BasicVSR对输入分辨率敏感。训练时用128x128的patch测试时直接上1080p效果会下降。建议在测试时用overlap的滑动窗口每个窗口独立推理然后融合。窗口大小设成128x128步长64融合时用高斯权重。个人经验性建议光流网络的选择SPyNet够用但如果你追求极致效果换成RAFT。代价是显存翻倍训练时间翻三倍。我自己的项目里用RAFT替换SPyNet后PSNR提升了0.15dB但推理速度从30fps降到了8fps。看你的应用场景取舍。传播长度不是越长越好。T30时效果最好但T60时反而下降。因为长序列的累积误差会抵消时序信息带来的增益。我一般设T20-30超过40帧的视频就切段。梯度检查训练初期每隔100个iteration打印一次梯度范数。如果后向传播的梯度比前向传播小两个数量级说明传播链有问题。这时候检查光流warp的梯度是否正常。数据增强视频超分的数据增强和单帧不同。除了常规的翻转、旋转还要加时序上的增强随机跳帧、随机裁剪时间片段、模拟运动模糊。我试过在训练时随机丢弃30%的帧模型反而更鲁棒。部署优化BasicVSR的推理速度瓶颈在光流warp。如果你用PyTorch的grid_sample建议用CUDA的warp操作替代速度能快3倍。TensorRT里也有对应的plugin。最后说一句BasicVSR是视频超分领域的一个里程碑但它不是终点。它的核心思想——双向传播光流对齐——启发了后来的很多工作比如IconVSR、BasicVSR。理解BasicVSR就等于理解了视频超分的一半。另一半留给那些在光流和传播之间做文章的新模型。