This article is based on the horizon_plugin_pytorch quantization toolchain. It describes the quantization configuration strategy for multi-output networks, the use of HistogramObserver, mixed-precision settings, and the calibration flow.

1. Model structure design

1.1 Multi-output network example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub


# -------------------------
# Backbone
# -------------------------
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x


# -------------------------
# Head A: classification head
# -------------------------
class HeadA(nn.Module):
    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# -------------------------
# Head B: regression head
# -------------------------
class HeadB(nn.Module):
    def __init__(self, in_channels, output_dim=4):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_channels * 8 * 8, output_dim)

    def forward(self, x):
        x = F.adaptive_avg_pool2d(x, (8, 8))
        x = self.flatten(x)
        x = self.fc(x)
        return x


# -------------------------
# Full network (two outputs)
# -------------------------
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.headA = HeadA(in_channels=64, num_classes=10)
        self.headB = HeadB(in_channels=64, output_dim=4)
        # Quantization entry and exits
        self.quant = QuantStub()
        self.dequant_A = DeQuantStub()
        self.dequant_B = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)  # quantization entry
        feat = self.backbone(x)
        outA = self.headA(feat)
        outB = self.headB(feat)
        # Dequantize each output separately
        return self.dequant_A(outA), self.dequant_B(outB)
```

1.2 Key design notes

| Component | Purpose |
| --- | --- |
| QuantStub | Quantization entry: converts the FP32 input into the quantized domain |
| DeQuantStub | Quantization exit: restores quantized values to FP32 output |
| Multiple DeQuantStub | A multi-output network needs an independent dequantization node for each output |

2. QConfig configuration strategy

2.1 Observer selection

HistogramObserver is recommended, for the following reasons:

| Feature | MinMaxObserver | HistogramObserver |
| --- | --- | --- |
| Statistics | Records only min/max | Builds a full histogram |
| Distribution awareness | No | Yes (full distribution) |
| Multiple methods | No | Yes (mse/percentile/kl, etc.) |
| Outlier handling | Sensitive | Handled automatically |

Core advantage: HistogramObserver separates statistics collection from scale computation. As long as the network structure, the weights, and the calibration data are unchanged, one calibration run is enough. Afterwards, reset_scale can switch between computation methods without rerunning calibration.

2.2 Configuration example: HistogramObserver for activations + MinMaxObserver for weights

```python
import torch
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, HistogramObserver
from horizon_plugin_pytorch.quantization import qint8, get_qconfig

## Method 1
## See the __init__ methods of MinMaxObserver and HistogramObserver for the available parameters
qconfig = QConfig(
    weight=FakeQuantize.with_args(
        observer=MinMaxObserver,
        averaging_constant=0.01,  # moving-average coefficient
        dtype=qint8,
        qscheme=torch.per_channel_symmetric,  # per-channel quantization; only weights support it
        ch_axis=0,  # Conv weight shape is (out_channels, in_channels, H, W); dim 0 is the output channel
    ),
    output=FakeQuantize.with_args(
        observer=HistogramObserver,  # activations use HistogramObserver
        dtype=qint8,
        qscheme=torch.per_tensor_symmetric,  # activations use per-tensor quantization
        ch_axis=-1,  # has no effect in per-tensor mode; any negative value works
    ),
)

## Method 2
qconfig = get_qconfig(observer=HistogramObserver)
```

Applicable scenario: general quantization tasks. Weight distributions are stable, so MinMaxObserver is sufficient. Activation distributions are more complex and benefit from the finer handling of HistogramObserver.

Note: when sensitive nodes are later switched to int16, recalibration is not needed, because the per-channel weight min/max values and the histogram information are both recorded.

3. QconfigSetter and Template configuration

3.1 Complete QconfigSetter configuration

```python
from horizon_plugin_pytorch.quantization import QconfigSetter, qint8, qint16
from horizon_plugin_pytorch.quantization.qconfig_setter import (
    ModuleNameTemplate,
    ConvDtypeTemplate,
    MatmulDtypeTemplate,
    SensitivityTemplate,
)

# Sensitive node list (layers that need int16)
out_weight_int16 = ["backbone.conv1", "backbone.conv2"]
sensitive_list = []
for name in out_weight_int16:
    sensitive_list.append((name, "output"))
    sensitive_list.append((name, "weight"))

qconfig_setter = QconfigSetter(
    reference_qconfig=qconfig,  # baseline QConfig from section 2.2
    templates=[
        # 1. Default dtype
        ModuleNameTemplate({"": qint8}),
        # 2. Conv-specific configuration
        ConvDtypeTemplate(input_dtype=qint8, weight_dtype=qint8),
        # 3. Matmul-specific configuration
        MatmulDtypeTemplate(input_dtypes=qint8),
        # 4. Sensitive nodes use int16
        SensitivityTemplate(sensitive_list, 1.0),
    ],
)
```

3.2 Template execution order and priority

```
Priority: a Template configured later has higher priority

Execution order:
ModuleNameTemplate -> ConvDtypeTemplate -> MatmulDtypeTemplate -> SensitivityTemplate
        ↓                    ↓                     ↓                      ↓
  global default       Conv override         Matmul override    sensitive-node override
```

3.3 SensitivityTemplate parameters

```python
SensitivityTemplate(sensitive_list, ratio)
# ratio parameter:
# - 1.0 (float): 100%, every node in the list is configured as int16
# - 1 (int): top 1, only the single most sensitive node is configured
# - 0.5 (float): 50%, the top 50% most sensitive nodes are configured
```

4. Calibration flow

4.1 Standard calibration flow

```python
import torch
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.march import set_march, March
from horizon_plugin_pytorch.quantization.hbdk4 import export

# model, input_tensor and qconfig_setter come from the previous sections.

# 1. Set the chip architecture
set_march(March.NASH_M)

# 2. Prepare the calibration model
calib_net = prepare(model, (input_tensor,), qconfig_setter=qconfig_setter)
calib_net.eval()

# 3. Switch to calibration mode
set_fake_quantize(calib_net, FakeQuantState.CALIBRATION)

# 4. Run calibration (collect quantization statistics)
with torch.no_grad():
    calib_net(input_tensor)

# 5. Switch to validation mode
set_fake_quantize(calib_net, FakeQuantState.VALIDATION)

# 6. Export the quantized model
qat_bc = export(calib_net, input_tensor)
```

4.2 FakeQuantState states

| State | Effect | Stage |
| --- | --- | --- |
| CALIBRATION | Collects activation statistics; fake quantization is not applied | Calibration |
| VALIDATION | Enables fake quantization to simulate quantized inference | Validation/export |
| QAT | Enables fake quantization with gradient backpropagation | QAT training |

5. HistogramObserver advanced usage

5.1 Switching the computation method after calibration

The core advantage of HistogramObserver: after calibration, the computation method can be switched without rerunning the calibration data.

Important: after reset_scale recomputes the scale, the new scale is stored in the state_dict. If you need to save the quantized model, save the state_dict again after the scale has been updated.

```python
HistogramObserver.reset_scale(calib_net, method="percentile", dtype=qint16)
torch.save(calib_net.state_dict(), "calib_model.pth")  # save again
```

```python
from horizon_plugin_pytorch.quantization.observer_v2 import HistogramObserver

# After calibration, switch to the mse method
HistogramObserver.reset_scale(
    calib_net,
    method="mse",
    dtype=qint8,
)

# Or switch to the percentile method (suited to long-tailed distributions).
# The smaller the percentile, the more aggressive the clipping.
HistogramObserver.reset_scale(
    calib_net,
    method="percentile",
    method_kwargs={"percentile": 0.999999},
    dtype=qint8,
)

# int16 layers use the full range without clipping.
# int16 has enough precision, so there is no need to trade range for precision.
HistogramObserver.reset_scale(
    calib_net,
    method="percentile",
    method_kwargs={"percentile": 1.0},  # 1.0 keeps the full range
    dtype=qint16,
)
```

5.2 Supported computation methods

| Method | Description | Applicable scenario |
| --- | --- | --- |
| mse | Minimizes quantization error | Normal distributions; recommended default |
| percentile | Percentile clipping | Long-tailed distributions, outliers present |
| kl | Minimizes KL divergence | Scenarios sensitive to distribution differences |

5.3 Comparing methods

```python
# Compare the effect of different methods.
# evaluate() is a user-supplied accuracy function; it is not defined in this article.
methods = ["mse", "percentile", "kl"]
for method in methods:
    HistogramObserver.reset_scale(calib_net, method)
    # Evaluate accuracy
    acc = evaluate(calib_net)
    print(f"Method: {method}, Accuracy: {acc}")
```

6. Mixed-precision configuration

Important: reset_dtype modifies dtype/scale, so the state_dict must be saved again afterwards.

```python
# Save again after reset_dtype
for name, mod in calib_net.named_modules():
    if "headB" in name:
        ...
torch.save(calib_net.state_dict(), "calib_model.pth")
```

6.1 Configuring a whole module as int16

```python
from horizon_plugin_pytorch.quantization import qint16

# Configure every layer in headB as int16 for both weights and activations
for name, mod in calib_net.named_modules():
    if "headB" in name:
        weight_fake_quant = getattr(mod, "weight_fake_quant", None)
        if weight_fake_quant is not None:
            weight_fake_quant.reset_dtype(qint16)
        activation_post_process = getattr(mod, "activation_post_process", None)
        if activation_post_process is not None:
            activation_post_process.reset_dtype(qint16)
```

6.2 Configuring a single node as int16

```python
# Configure int16 for individual sensitive nodes
calib_net.headA.fc1.weight_fake_quant.reset_dtype(qint16)
calib_net.headA._generated_relu_0.activation_post_process.reset_dtype(qint16)

# _generated_* layers are quantization nodes generated automatically by prepare,
# e.g. _generated_relu_0, _generated_adaptive_avg_pool2d_0
```

7. Complete runnable example

7.1 Example: HistogramObserver for activations + MinMaxObserver for weights

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
from horizon_plugin_pytorch.quantization import (
    set_fake_quantize, FakeQuantState, prepare, QconfigSetter, qint8, qint16
)
from horizon_plugin_pytorch.quantization.qconfig import QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, HistogramObserver
from horizon_plugin_pytorch.quantization.qconfig_setter import (
    ModuleNameTemplate, ConvDtypeTemplate, MatmulDtypeTemplate, SensitivityTemplate
)
from horizon_plugin_pytorch.march import set_march, March
from horizon_plugin_pytorch.quantization.hbdk4 import export


# Model definition
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x


class HeadA(nn.Module):
    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class HeadB(nn.Module):
    def __init__(self, in_channels, output_dim=4):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_channels * 8 * 8, output_dim)

    def forward(self, x):
        x = F.adaptive_avg_pool2d(x, (8, 8))
        x = self.flatten(x)
        x = self.fc(x)
        return x


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.headA = HeadA(in_channels=64, num_classes=10)
        self.headB = HeadB(in_channels=64, output_dim=4)
        self.quant = QuantStub()
        self.dequant_A = DeQuantStub()
        self.dequant_B = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        feat = self.backbone(x)
        outA = self.headA(feat)
        outB = self.headB(feat)
        return self.dequant_A(outA), self.dequant_B(outB)


# Main flow
if __name__ == "__main__":
    # 1. Create the model and the input
    model = MyNet()
    input_tensor = torch.randn(1, 3, 32, 32) * 1000

    # 2. Set the chip architecture
    set_march(March.NASH_M)

    # 3. Configure sensitive nodes (backbone uses int16)
    out_weight_int16 = ["backbone.conv1", "backbone.conv2"]
    sensitive_list = []
    for name in out_weight_int16:
        sensitive_list.append((name, "output"))
        sensitive_list.append((name, "weight"))

    # 4. Configure QConfig (activations: HistogramObserver, weights: MinMaxObserver)
    qconfig = QConfig(
        weight=FakeQuantize.with_args(
            observer=MinMaxObserver,
            averaging_constant=0.01,
            dtype=qint8,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        ),
        output=FakeQuantize.with_args(
            observer=HistogramObserver,
            dtype=qint8,
            qscheme=torch.per_tensor_symmetric,
            ch_axis=-1,
        ),
    )

    # 5. Configure QconfigSetter
    qconfig_setter = QconfigSetter(
        reference_qconfig=qconfig,
        templates=[
            ModuleNameTemplate({"": qint8}),
            ConvDtypeTemplate(input_dtype=qint8, weight_dtype=qint8),
            MatmulDtypeTemplate(input_dtypes=qint8),
            SensitivityTemplate(sensitive_list, 1.0),
        ],
    )

    # 6. Prepare the calibration model
    calib_net = prepare(model, (input_tensor,), qconfig_setter=qconfig_setter)
    calib_net.eval()

    # 7. Run calibration
    set_fake_quantize(calib_net, FakeQuantState.CALIBRATION)
    with torch.no_grad():
        calib_net(input_tensor)

    # 8. (Optional) Switch the HistogramObserver computation method; the default is mse.
    # To switch to the percentile method (suited to long-tailed distributions):
    # HistogramObserver.reset_scale(
    #     calib_net,
    #     method="percentile",
    #     method_kwargs={"percentile": 0.999999},
    #     dtype=qint8,
    # )

    # 9. Mixed precision: headB uses int16 throughout
    for name, mod in calib_net.named_modules():
        if "headB" in name:
            if hasattr(mod, "weight_fake_quant") and mod.weight_fake_quant is not None:
                mod.weight_fake_quant.reset_dtype(qint16)
            if hasattr(mod, "activation_post_process") and mod.activation_post_process is not None:
                mod.activation_post_process.reset_dtype(qint16)

    # After reset_dtype, call reset_scale again to obtain the best scale
    HistogramObserver.reset_scale(
        calib_net,
        method="mse",
        dtype=qint16,
    )

    # 10. Switch to validation mode and export
    calib_net.eval()
    set_fake_quantize(calib_net, FakeQuantState.VALIDATION)
    qat_bc = export(calib_net, input_tensor)
    print("Example: quantized model exported successfully")
```

7.2 Common problems and solutions

Q1: AttributeError: 'NoneType' object has no attribute 'reset_dtype'

Cause: for some modules (such as ReLU and BatchNorm), activation_post_process is None.

Solution:

```python
if hasattr(mod, "activation_post_process") and mod.activation_post_process is not None:
    mod.activation_post_process.reset_dtype(qint16)
```

Q2: The sensitive-node configuration does not take effect

Cause: a Template priority problem, or the node names do not match.

Solution:

```python
# Make sure SensitivityTemplate is the last entry in the templates list
templates=[
    ModuleNameTemplate({"": qint8}),
    ConvDtypeTemplate(...),
    SensitivityTemplate(sensitive_list, 1.0)  # configured last, highest priority
]
```

Q3: HistogramObserver does not require repeated calibration

Analysis: if the floating-point structure and weights are unmodified and the calibration data is the same, calibration does not need to be repeated.

Usage:

- Use HistogramObserver for activations and MinMaxObserver for weights.
- After calibration, try different methods through reset_scale.
- After calibration, configure sensitive layers as int16 through reset_dtype.

8. Summary

This article walked through the quantization configuration and calibration flow on the Journey 6 (征程 6) platform.

| Stage | Key configuration |
| --- | --- |
| QConfig | HistogramObserver (recommended) |
| QconfigSetter | Template priority, SensitivityTemplate |
| Calibration flow | CALIBRATION -> forward -> reset_scale -> VALIDATION -> export |
| Mixed precision | reset_dtype for batch or single-node int16 configuration, no recalibration needed |

Recommended practice:

- Prefer HistogramObserver during calibration.
- Use reset_scale to compare computation methods (mse/percentile/kl).
- Use reset_dtype to configure sensitive nodes as int16.
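Supplementary sketch for the Observer comparison in section 2.1. The following standalone Python snippet is an illustration only: it does not call horizon_plugin_pytorch and is not how the toolchain computes scales internally. It uses synthetic data and hypothetical helper names (`symmetric_scale`, `fake_quantize`, `percentile_abs`) to show why a min/max based scale is sensitive to outliers while a percentile based scale preserves resolution for the bulk of the values, at the cost of clipping the outliers.

```python
import random


def symmetric_scale(abs_max, n_bits=8):
    """Scale for symmetric signed quantization (qmax is 127 for 8 bits)."""
    qmax = 2 ** (n_bits - 1) - 1
    return abs_max / qmax


def fake_quantize(values, scale, n_bits=8):
    """Round to the signed integer grid, clamp, then map back to float."""
    qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    return [max(qmin, min(qmax, round(v / scale))) * scale for v in values]


def mse(a, b):
    """Mean squared error between two equally long sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def percentile_abs(values, p):
    """Absolute value at fraction p of the sorted magnitudes (p=1.0 gives the max)."""
    ordered = sorted(abs(v) for v in values)
    index = min(len(ordered) - 1, int(p * len(ordered)))
    return ordered[index]


random.seed(0)
# Synthetic long-tailed activations: a Gaussian bulk plus a few large outliers.
bulk = [random.gauss(0.0, 1.0) for _ in range(10_000)]
outliers = [80.0, -95.0, 120.0]
acts = bulk + outliers

# Min/max style: the scale is set by the single largest magnitude.
scale_minmax = symmetric_scale(max(abs(v) for v in acts))
# Percentile style: the scale ignores the top 0.1% of magnitudes.
scale_pct = symmetric_scale(percentile_abs(acts, 0.999))

bulk_err_minmax = mse(bulk, fake_quantize(bulk, scale_minmax))
bulk_err_pct = mse(bulk, fake_quantize(bulk, scale_pct))
outlier_err_minmax = mse(outliers, fake_quantize(outliers, scale_minmax))
outlier_err_pct = mse(outliers, fake_quantize(outliers, scale_pct))

print(f"min/max:    scale={scale_minmax:.4f} bulk MSE={bulk_err_minmax:.6f} outlier MSE={outlier_err_minmax:.2f}")
print(f"percentile: scale={scale_pct:.4f} bulk MSE={bulk_err_pct:.6f} outlier MSE={outlier_err_pct:.2f}")
```

The percentile scale is much smaller, so the bulk of the distribution is represented far more precisely, while the three outliers are clipped and carry a large error. Which trade-off is better depends on how much the outliers matter for the final accuracy, which is why comparing methods on a real evaluation metric (section 5.3) is more reliable than looking at quantization error alone.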