064、模型轻量化:知识蒸馏与剪枝技术在超分网络中的应用实践
064、模型轻量化知识蒸馏与剪枝技术在超分网络中的应用实践上周在部署一个EDSR变体到移动端时我遇到了一个让人头疼的问题模型参数量接近40M在骁龙8Gen2上跑一张720p图像需要将近800ms。客户说“这速度还不如直接插值”我无言以对。于是开始认真研究超分网络的轻量化今天聊聊知识蒸馏和剪枝这两个方向都是我在实际项目中踩过坑、流过泪后总结出来的。从“大模型教小模型”说起知识蒸馏这个概念说白了就是让一个复杂的教师网络去指导一个简单的学生网络。在超分任务里教师网络通常是参数量大、性能好的模型比如RCAN、EDSR学生网络则是我们真正要部署的轻量模型比如FSRCNN、IDN的变体。我第一次尝试蒸馏时犯了个低级错误——直接让学生网络去拟合教师网络的输出像素值。结果PSNR确实提升了一点但视觉效果很生硬细节纹理像被“磨皮”了一样。后来才意识到超分任务中教师网络学到的不仅仅是像素值还有高频细节的分布规律。正确的做法是让学生网络同时学习教师网络的中间特征和最终输出。我在代码里这样实现# 这里踩过坑只算输出层的MSE损失效果很差# 别这样写# loss nn.MSELoss()(student_out, teacher_out)# 正确做法多层级特征对齐defdistillation_loss(student_feats,teacher_feats,student_out,teacher_out,gt):# 特征层对齐让学生的中间特征模仿教师的feat_loss0fors_feat,t_featinzip(student_feats,teacher_feats):# 注意教师和学生特征图尺寸可能不同需要上采样或下采样对齐# 我在这里吃过亏忘记调整尺寸导致loss爆炸s_feat_resizedF.interpolate(s_feat,sizet_feat.shape[2:],modebilinear)feat_lossnn.MSELoss()(s_feat_resized,t_feat.detach())# detach教师特征不让梯度回传# 输出层对齐用L1损失代替MSE对边缘更友好out_lossnn.L1Loss()(student_out,teacher_out.detach())# 加上GT监督防止学生完全依赖教师gt_lossnn.L1Loss()(student_out,gt)# 权重分配特征损失占0.3输出损失占0.3GT损失占0.4# 这个比例我调了大概两周不同数据集表现不一样return0.3*feat_loss0.3*out_loss0.4*gt_loss这里有个细节容易被忽略教师网络在蒸馏过程中必须冻结参数。我刚开始没注意结果教师网络也跟着更新两个模型一起“跑偏”PSNR直接掉了0.5dB。剪枝不是简单砍掉权重剪枝听起来很暴力——把不重要的参数删掉。但实际操作起来比想象中复杂得多。超分网络不同于分类网络它对每个像素的贡献都很敏感盲目剪枝会导致图像出现“棋盘格”伪影。我尝试过两种剪枝策略结构化剪枝和非结构化剪枝。非结构化剪枝把权重矩阵中绝对值小的元素置零。优点是压缩率高缺点是需要专门的硬件或库支持稀疏计算。我在NVIDIA Jetson上试过稀疏矩阵运算效率并不高因为硬件对稠密矩阵做了大量优化。结构化剪枝直接剪掉整个卷积核或通道。这种方式更实用但需要仔细评估每个通道的重要性。我写过一个基于BN层缩放因子的剪枝方法在超分任务上效果还不错# 在训练过程中给每个卷积层后面加BN层并添加L1正则化到缩放因子gamma上# 注意超分网络通常不加BN因为会破坏图像统计特性# 但为了剪枝我硬加上了训练完再移除BN层deftrain_with_pruning(model,dataloader,sparsity1e-4):# 只对BN层的gamma参数施加L1正则# 别这样写对所有参数加正则会导致图像模糊bn_params[]other_params[]forname,paraminmodel.named_parameters():ifbn.weightinname:# gamma参数bn_params.append(param)else:other_params.append(param)optimizertorch.optim.Adam([{params:other_params},{params:bn_params,weight_decay:sparsity}# 对gamma加L2正则],lr1e-4)# 训练完成后统计所有BN层的gamma值gamma_values[]forname,moduleinmodel.named_modules():ifisinstance(module,nn.BatchNorm2d):gamma_values.extend(module.weight.data.cpu().numpy().tolist())# 设定剪枝阈值保留前70%的通道# 这个比例需要根据模型和数据集调整我试过剪掉50%还能保持95%的PSNRthresholdnp.percentile(gamma_values,30)# 生成剪枝掩码pruned_channels{}forname,moduleinmodel.named_modules():ifisinstance(module,nn.BatchNorm2d):maskmodule.weight.datathreshold pruned_channels[name]maskreturnpruned_channels剪枝后一定要做微调。我第一次剪完直接部署PSNR从32.1掉到28.5图像全是噪点。微调时有个技巧先用小学习率1e-5恢复几个epoch再逐渐增大到正常学习率。这就像手术后先让病人静养再慢慢恢复运动。蒸馏剪枝的组合拳单独用蒸馏或剪枝效果提升有限。我试过组合方案先用知识蒸馏训练一个轻量学生网络再对这个学生网络进行结构化剪枝最后微调。具体流程是这样的训练一个大的教师网络比如RCAN40层残差块设计一个轻量学生网络比如只有8个残差块用蒸馏损失训练学生网络直到收敛对学生网络进行通道剪枝剪掉30%的通道用蒸馏损失GT损失微调剪枝后的网络这个流程跑下来模型大小从40M压缩到4.5M推理速度从800ms降到120msPSNR只掉了0.3dB。客户终于满意了。有个坑必须提醒蒸馏和剪枝的顺序不能颠倒。我试过先剪枝再蒸馏结果学生网络太弱学不到教师网络的知识PSNR反而比直接训练还低。实战中的经验教训不要迷信理论压缩率。论文里说能压缩90%实际部署时可能只压缩了60%。因为超分网络的残差连接对剪枝特别敏感剪掉一个通道可能影响多个残差块的输出。蒸馏温度是个玄学。分类任务中常用的温度参数T在超分任务中效果不明显。我试过T1, 2, 4, 8PSNR差异不超过0.1dB。后来干脆不用温度缩放直接用L1损失。剪枝后的模型需要重新调整学习率。剪枝后网络结构变了原来的学习率可能太大。我一般把学习率降到原来的1/10然后使用余弦退火调度。注意BN层的统计量。剪枝后BN层的running_mean和running_stat需要重新计算。我踩过这个坑剪完枝直接部署前几张图全是黑色因为BN统计量没更新。硬件适配很重要。同样的剪枝策略在GPU上可能加速不明显但在NPU上效果显著。建议先了解目标硬件的特性再决定剪枝策略。最后说一句模型轻量化不是目的目的是在有限资源下获得尽可能好的超分效果。有时候一个精心设计的轻量网络比如FSRCNN的变体加上适当的蒸馏比把一个大型网络剪枝到极致更实用。毕竟工程落地追求的是性价比不是理论极限。