## 1. 概述

`ObserverBase` 是 horizon_plugin_pytorch 量化框架中所有 Observer 的抽象基类。它定义了量化校准器的统一接口和核心功能，为各种量化策略（MinMax、MSE、KL 等）提供了基础架构。

## 2. ABCMeta 深度解析

### 2.1 Python 元类机制

在 Python 中，**类也是对象**：类是由元类（metaclass）创建的，默认情况下所有类都由元类 `type` 创建。当指定 `metaclass=ABCMeta` 时，类的创建过程由 `ABCMeta` 控制。示例如下：

```python
from abc import ABCMeta, abstractmethod

class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self, x):
        pass
```

### 2.2 `@abstractmethod` 装饰器

```python
def abstractmethod(funcobj):
    """标记方法为抽象方法"""
    funcobj.__isabstractmethod__ = True  # 仅设置标志位
    return funcobj
```

装饰器本身只负责打标记，真正起约束作用的是 `ABCMeta`。它在创建类时，收集所有仍带有 `__isabstractmethod__ = True` 标记的方法名，存入类属性 `__abstractmethods__`。只要该集合非空，实例化这个类就会抛出 `TypeError`。

### 2.3 ObserverBase 中的应用

```python
# 基类定义抽象方法
class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self, x):
        pass

# ObserverBase.__abstractmethods__ == frozenset({'forward'})

# 子类实现
class MinMaxObserver(ObserverBase):
    def forward(self, x_orig):
        return x_orig

# MinMaxObserver.__abstractmethods__ == frozenset()  → 可实例化
```

## 3. ObserverBase 完整源码

```python
class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
    r"""Base observer Module.

    Any observer implementation should derive from this class.

    Concrete observers should follow the same API. In forward, they will update
    the statistics of the observed Tensor. And they should provide a
    calculate_qparams function that computes the quantization parameters given
    the collected statistics.

    Args:
        averaging_constant: Averaging constant for min/max.
        ch_axis: Channel axis.
        dtype: Quantized data type.
        qscheme: Quantization scheme to be used.
        quant_min: Min quantization value. Will follow dtype if unspecified.
        quant_max: Max quantization value. Will follow dtype if unspecified.
        is_sync_quantize: If sync statistics when training with multiple
            devices.
        factory_kwargs: kwargs which are passed to factory functions for
            min_val and max_val.
    """

    _version = 3

    eps: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor

    is_sync_quantize: Optional[bool] = True

    @typechecked
    def __init__(
        self,
        averaging_constant: float = 0.01,
        ch_axis: int = -1,
        dtype: Union[torch.dtype, QuantDType] = qint8,
        qscheme: torch.qscheme = torch.per_tensor_symmetric,
        quant_min: int = None,
        quant_max: int = None,
        is_sync_quantize: Optional[bool] = None,
        factory_kwargs: Dict = None,
        compute_scale_strategy=ComputeScaleStrategy.STATISTIC,
    ):
        super(ObserverBase, self).__init__()
        if qscheme == torch.per_channel_symmetric:
            assert (
                ch_axis >= 0
            ), "ch_axis should be non-negative when using per_channel_symmetric qcsheme"
        else:
            assert (
                ch_axis < 0
            ), "ch_axis should be negative when using per_tensor_symmetric qcsheme"
        dtype = get_horizon_quant_dtype(dtype)
        assert qscheme in (
            torch.per_tensor_symmetric,
            torch.per_channel_symmetric,
        ), (
            "only support per_tensor_symmetric and per_channel_symmetric qscheme"
        )
        self.averaging_constant = averaging_constant
        self.ch_axis = ch_axis
        self.dtype = dtype
        self.qscheme = qscheme
        self._set_quant_min_max(self.dtype, quant_min, quant_max)
        if is_sync_quantize is not None:
            self.is_sync_quantize = is_sync_quantize
        self.compute_scale_strategy = compute_scale_strategy

        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer(
            "eps",
            torch.tensor([torch.finfo(torch.float32).eps], **factory_kwargs),
        )
        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))

    def _set_quant_min_max(
        self,
        dtype,
        quant_min=None,
        quant_max=None,
    ):
        if (quant_min is not None) and (quant_max is not None):
            assert quant_min < quant_max, (
                "qmin must be strictly less than qmax for user-specified quantization range."
            )
            assert (
                quant_min <= 0 <= quant_max
            ), "Used-specified quantization range must include 0."
            assert qinfo(dtype).min <= quant_min, "quant_min out of bound"
            assert quant_max <= qinfo(dtype).max, "quant_max out of bound"
            self.quant_min, self.quant_max = quant_min, quant_max
        else:
            self.quant_min, self.quant_max = (
                qinfo(self.dtype).min,
                qinfo(self.dtype).max,
            )

    def reset_dtype(self, dtype):
        dtype = get_horizon_quant_dtype(dtype)
        if dtype == self.dtype:
            return
        self.dtype = dtype
        self._set_quant_min_max(self.dtype)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # buffers has been renamed from min/max_vals to min/max_val
        buffer_name_mapping = {"min_vals": "min_val", "max_vals": "max_val"}
        for old_name in buffer_name_mapping:
            k = prefix + old_name
            if k in state_dict:
                v = state_dict.pop(k)
                state_dict[prefix + buffer_name_mapping[old_name]] = v

        eps_key = prefix + "eps"
        if eps_key not in state_dict:
            # eps was moved to a buffer in version 2
            eps = torch.tensor([torch.finfo(torch.float32).eps])
            state_dict[eps_key] = eps

        local_state = ["min_val", "max_val"]
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                # if ndim=0, make it ndim=1
                state_dict[key] = state_dict[key].reshape(-1)

                val = state_dict[key]
                # Custom handling to allow loading min_val or max_val
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == "min_val" and hasattr(self, "min_val"):
                    self.min_val.resize_(val.shape)
                elif hasattr(self, "max_val"):
                    self.max_val.resize_(val.shape)

        super(ObserverBase, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def _load_from_state_dict_script(
        self,
        state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
        prefix: str,
        local_metadata: Dict[str, torch.Tensor],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        self._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def sync_minmax(self, min_val, max_val):
        if dist.is_initialized() and min_val.is_cuda:
            dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
            dist.all_reduce(max_val, op=dist.ReduceOp.MAX)

    def calculate_qparams(self):
        r"""Calculate the quantization parameters.

        Returns:
            scales: Scales tensor of shape (#channels,)
            zero_points: Zero points tensor of shape (#channels,)
        """
        if self.min_val.numel() == 0 or self.max_val.numel() == 0:
            warnings.warn(
                "Must run observer before calling calculate_qparams. "
                "Returning default scale and zero point. "
                "This is an expected behavior if you use KLObserver "
                "and set 1 < update_interval <= total steps. ",
            )
            return torch.tensor(
                [1.0], device=self.min_val.device
            ), torch.tensor([0], device=self.min_val.device)

        scale = _compute_scale_symmetric(
            self.min_val,
            self.max_val,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.compute_scale_strategy,
        )

        return scale, None

    def repr_msgs(self):
        msges = []
        # only print minmax value for per tensor
        if hasattr(self, "min_val") and self.min_val.numel() == 1:
            msges.append("min_val={}".format(self.min_val.item()))
        if hasattr(self, "max_val") and self.max_val.numel() == 1:
            msges.append("max_val={}".format(self.max_val.item()))
        return msges

    def extra_repr(self):
        return ",".join(self.repr_msgs())

    @abstractmethod
    def forward(self, x):
        pass

    with_args = classmethod(_with_args)
```

> 注：上面 `warnings.warn` 提示中的比较符号（`1 < update_interval <= total steps`）是根据语义补全的，请以实际源码为准。

## 4. 核心属性详解

### 4.1 量化配置属性

```python
# 基础量化参数
self.averaging_constant: float   # 移动平均系数
self.ch_axis: int                # 通道轴 (per_channel 量化时使用)
self.dtype: QuantDType           # 量化数据类型 (qint8, qint4 等)
self.qscheme: torch.qscheme      # 量化方案 (per_tensor/per_channel)
self.quant_min: int              # 量化最小值
self.quant_max: int              # 量化最大值
self.is_sync_quantize: bool      # 多卡同步统计量
self.compute_scale_strategy      # scale 计算策略 (STATISTIC/POT/FP16/KPOT 等)
```

### 4.2 统计量缓冲区

```python
self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))
self.register_buffer("min_val", torch.tensor([]))
self.register_buffer("max_val", torch.tensor([]))
```

使用 `register_buffer` 注册的原因：

- **不参与梯度计算**：统计量不是模型参数；
- **随模型迁移设备**：调用 `model.cuda()` 时会自动迁移；
- **可保存到 state_dict**：校准结果可以持久化。

`min_val` / `max_val` 初始化为空张量（`numel() == 0`）。`calculate_qparams` 正是据此判断 observer 是否尚未运行。

## 5. 核心方法详解

### 5.1 `__init__` - 初始化

参数说明：

| 参数 | 默认值 | 说明 |
| --- | --- | --- |
| averaging_constant | 0.01 | 移动平均系数，值越大，当前 batch 的权重越高 |
| ch_axis | -1 | 通道轴：负数表示 per_tensor，非负表示 per_channel |
| dtype | qint8 | 量化数据类型 |
| qscheme | per_tensor_symmetric | 量化方案 |
| quant_min / quant_max | None | 自定义量化范围。仅当二者**同时**指定时生效，否则根据 dtype 自动设置 |
| is_sync_quantize | None（不传时沿用类属性默认值 `True`） | 多卡训练时是否同步统计量 |
| compute_scale_strategy | `ComputeScaleStrategy.STATISTIC` | scale 计算策略（见 5.3） |

关键校验逻辑：

```python
# per_channel 必须指定有效的 ch_axis
if qscheme == torch.per_channel_symmetric:
    assert ch_axis >= 0, "ch_axis should be non-negative"
else:
    assert ch_axis < 0, "ch_axis should be negative for per_tensor"

# 仅支持对称量化
assert qscheme in (
    torch.per_tensor_symmetric,
    torch.per_channel_symmetric,
)
```

### 5.2 forward - 更新统计信息（抽象方法）

设计意图：

- 子类必须实现此方法（由 ABCMeta 强制）；
- 在校准阶段，每次 forward 都会收集激活值的统计信息；
- 返回原始输入，不修改数据流。

典型实现模式（`compute_statistics`、`update_statistics` 为示意用的占位函数）：

```python
def forward(self, x_orig):
    # 1. 计算当前 batch 的统计量
    min_val_cur, max_val_cur = compute_statistics(x_orig)
    # 2. 多卡同步（可选）
    if self.is_sync_quantize:
        self.sync_minmax(min_val_cur, max_val_cur)
    # 3. 更新累计统计量（移动平均）
    self.min_val = update_statistics(self.min_val, min_val_cur)
    self.max_val = update_statistics(self.max_val, max_val_cur)
    return x_orig  # 原样返回，不干扰前向传播
```

### 5.3 calculate_qparams - 计算量化参数

正常路径下，`calculate_qparams` 返回 `(scale, None)`，因为对称量化不需要 zero_point。只有当 `min_val` / `max_val` 仍为空（observer 尚未运行）时，它才会发出警告，并返回默认值 `scale=[1.0]`、`zero_point=[0]`。注意，这与其 docstring 中 "zero_points: Zero points tensor" 的描述并不完全一致。

核心计算逻辑 `_compute_scale_symmetric` 中，`quant_range = quant_max - quant_min`。结果最后经过 `clamp_min(eps)`，保证 scale 不会为 0：

```python
def _compute_scale_symmetric(min_val, max_val, quant_min, quant_max, eps, strategy):
    # 对称量化公式：scale = max(|min|, |max|) / (quant_range / 2)
    scale = (
        torch.max(-min_val, max_val)
        .clamp_min(0)
        .div(float(quant_max - quant_min) / 2)
        .clamp_min(eps)
    )
    # 可选的 scale 约束策略
    if strategy == ComputeScaleStrategy.KPOT:
        # K-POT (可训练 POT)
        scale = k_pot_scale(scale)
    elif strategy == ComputeScaleStrategy.POT:
        # Power-of-Two
        scale = 2 ** torch.ceil(torch.log2(scale))
    elif strategy == ComputeScaleStrategy.FP16:
        # FP16 精度
        scale = _get_fp16_scale(scale)
    return scale
```

### 5.4 sync_minmax - 多卡同步

```python
def sync_minmax(self, min_val, max_val):
    if dist.is_initialized() and min_val.is_cuda:
        dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
        dist.all_reduce(max_val, op=dist.ReduceOp.MAX)
```

原理：使用 `all_reduce` 聚合多卡的统计量。

- MIN 操作取所有卡上的最小值；
- MAX 操作取所有卡上的最大值；
- 从而确保多卡训练时校准结果一致。

注意：只有在分布式进程组已初始化（`dist.is_initialized()`）且统计量位于 CUDA 设备上时才会同步，否则直接跳过。`all_reduce` 是原地操作，会直接修改传入的张量。

### 5.5 _load_from_state_dict - 状态加载

关键功能：

- **版本兼容处理**：将旧版名称 `min_vals` / `max_vals` 重命名为 `min_val` / `max_val`；旧 checkpoint 中缺少 `eps` 时，自动补上默认值（eps 从 version 2 起才改为 buffer）；
- **形状规整**：0 维的 `min_val` / `max_val` 会被 reshape 为 1 维；
- **动态调整 buffer 大小**：允许把尺寸为 N 的 `min_val` / `max_val` 加载进尚未初始化（尺寸为 0）的 buffer，例如从校准后的模型加载参数到新建的 QAT 模型。

## 6. 类继承体系

```
ObserverBase (抽象基类)
│
├── MinMaxObserver          # 移动平均 min/max 统计
│   │
│   └── ClipObserver        # 带截断的 min/max 统计
│
├── FixedScaleObserver      # 固定 scale，不统计
│
├── PercentileObserver      # 百分位统计
│
├── MSEObserver             # 最小化 MSE 搜索最优 scale
│
├── KLObserver              # KL 散度校准
│
├── MixObserver             # 混合多种方法
│
└── HistogramObserver       # 直方图统计，支持多种度量
```

## 7. 设计亮点

- **统一接口**：所有 Observer 遵循相同的 API，便于替换和扩展；
- **抽象基类约束**：通过 ABCMeta 强制子类实现 `forward` 方法；
- **状态持久化**：统计量作为 buffer 保存，支持校准结果复用；
- **分布式支持**：内置多卡同步机制；
- **版本兼容**：`_load_from_state_dict` 负责兼容历史版本；
- **灵活配置**：支持 per-tensor / per-channel 两种对称量化方案，以及多种数据类型和 scale 策略。

## 8. 与 PyTorch 原生 Observer 的对比

| 特性 | PyTorch ObserverBase | Horizon ObserverBase |
| --- | --- | --- |
| 量化方案 | 支持非对称量化 | 仅支持对称量化 |
| scale 约束 | 无 | POT / FP16 / KPOT 策略 |
| 分布式同步 | 需自行实现 | 内置 `sync_minmax` |
| 数据类型 | 标准 `torch.dtype` | 扩展 `QuantDType`（qint4 等） |
| 版本管理 | `ObserverBase` 本身没有；`UniformQuantizationObserverBase` 同样定义了 `_version = 3`，并带有 eps / `min_vals` 迁移逻辑 | `_version = 3`，在 `_load_from_state_dict` 中处理迁移 |