ConvNeXt（facebookresearch 版本）的模型实现和推理
"""ConvNeXt model definition (facebookresearch reference version).

Earlier posts by the author (flyfish) covered the torchvision implementations;
this module follows the facebookresearch/ConvNeXt code. It provides:

* ``DropPath``      - stochastic depth for residual branches,
* ``Block``         - the ConvNeXt block,
* ``LayerNorm``     - layer norm for channels_last or channels_first tensors,
* ``ConvNeXt``      - the full network,
* ``convnext_base`` - factory for the ConvNeXt-Base configuration.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# PyTorch's built-in truncated-normal initializer (used by ConvNeXt._init_weights).
from torch.nn.init import trunc_normal_


class DropPath(nn.Module):
    """Stochastic depth: randomly drop the residual branch of whole samples.

    Intended to be equivalent to ``timm.models.layers.DropPath``. In training
    mode each sample in the batch (dim 0) is zeroed with probability
    ``drop_prob``. In eval mode, or when ``drop_prob == 0``, the input is
    returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample's branch.
        scale_by_keep: If True, kept samples are divided by ``1 - drop_prob``
            so the expected output equals the input.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every remaining dim.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        # Guard keep_prob == 0 (drop_prob == 1): dividing would yield 0/0 = NaN.
        if keep_prob > 0. and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor


class Block(nn.Module):
    """ConvNeXt block with a residual connection.

    The block applies, in order:

    1. 7x7 depthwise conv;
    2. LayerNorm;
    3. Linear(dim -> 4*dim), GELU, Linear(4*dim -> dim);
    4. optional layer scale ``gamma``;
    5. DropPath.

    The two Linear layers act as pointwise convs on a channels_last view,
    hence the permutes around them.

    Args:
        dim: Number of input/output channels.
        drop_path: Stochastic-depth probability for this block.
        layer_scale_init_value: Initial value of ``gamma``. If <= 0, no layer
            scale is used.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0 else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(x)


class LayerNorm(nn.Module):
    """LayerNorm over the channel dim for channels_last or channels_first input.

    * ``channels_last``: channels are the last dim; delegates to
      ``F.layer_norm``.
    * ``channels_first``: channels are dim 1. Mean and biased variance are
      computed over dim 1, and the affine parameters are broadcast over the
      two trailing dims.

    Args:
        normalized_shape: Number of channels (an int).
        eps: Value added to the variance for numerical stability.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: If ``data_format`` is not one of the two values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        # Fail fast: an unknown format previously made forward() return None.
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"Unsupported data_format: {data_format!r}")
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalize over dim 1 by hand.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


class ConvNeXt(nn.Module):
    """ConvNeXt image classifier.

    Structure: a stem (4x4 conv, stride 4), then four stages separated by
    downsampling layers (LayerNorm + 2x2 conv, stride 2). After the stages come
    global average pooling, a final LayerNorm, and a linear head.

    Args:
        in_chans: Number of input image channels.
        num_classes: Size of the classification head output.
        depths: Number of blocks in each of the 4 stages.
        dims: Channel width of each of the 4 stages.
        drop_path_rate: Maximum stochastic-depth rate. It rises linearly from
            0 across all blocks.
        layer_scale_init_value: Initial layer-scale value passed to each Block.
        head_init_scale: Multiplier applied to the head's initial weight and
            bias.
    """

    def __init__(self, in_chans=3, num_classes=1000,
                 depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),
                 drop_path_rate=0., layer_scale_init_value=1e-6,
                 head_init_scale=1.):
        super().__init__()

        # Stem plus three intermediate downsampling layers.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # Four stages of residual blocks, with linearly increasing drop-path rates.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(*[
                Block(dim=dims[i], drop_path=dp_rates[cur + j],
                      layer_scale_init_value=layer_scale_init_value)
                for j in range(depths[i])
            ])
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for Conv2d/Linear."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:  # layers built with bias=False have no bias
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run the stem/downsampling layers and stages; return pooled, normalized features."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))
        # Global average pooling (N, C, H, W) -> (N, C), then the final LayerNorm.
        return self.norm(x.mean([-2, -1]))

    def forward(self, x):
        x = self.forward_features(x)
        return self.head(x)


# ============================ Factory ============================
def convnext_base(num_classes=1000, **kwargs):
    """Build ConvNeXt-Base: depths [3, 3, 27, 3], dims [128, 256, 512, 1024].

    Extra keyword arguments are forwarded to ``ConvNeXt``.
    """
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024],
                     num_classes=num_classes, **kwargs)
    return model


# ---------------------------------------------------------------------------
# Inference code: a separate script that imports ``convnext_base`` from this
# model file and uses torch, os, PIL.Image and torchvision.transforms.
# ---------------------------------------------------------------------------
"""ConvNeXt-Base inference script.

The script:

1. loads ``convnext_base`` from the local model file;
2. restores local weights;
3. preprocesses one image;
4. prints the Top-K predictions.
"""
import os

import torch
from PIL import Image
from torchvision import transforms

# Import the custom ConvNeXt model; the model file must sit in the same directory.
from convnetxt_base import convnext_base

# ============================ Configuration ============================
# Local weight file (.pth / .pt / .pth.tar).
MODEL_WEIGHT_PATH = "convnext_base_22k_224.pth"
# Test image path.
TEST_IMAGE_PATH = "test.jpg"
# Number of classes; must equal the num_classes the weights were trained with.
NUM_CLASSES = 21841
# How many Top-K predictions to print.
TOPK = 5
# Use the GPU when available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_convnext_preprocess():
    """Return the eval transform: Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize.

    The transform should match the preprocessing used when the weights were
    trained/evaluated.
    """
    preprocess = transforms.Compose([
        # Resize the shorter side to 256 pixels.
        transforms.Resize(256),
        # Center-crop to 224x224, the input size of the "_224" checkpoint.
        transforms.CenterCrop(224),
        # PIL image -> float tensor [C, H, W] with values in [0, 1].
        transforms.ToTensor(),
        # ImageNet mean/std normalization.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess


def load_local_model(model_path: str, num_classes: int, strict: bool = True):
    """Build ConvNeXt-Base, load local weights, and return it in eval mode on DEVICE.

    :param model_path: Path to the local weight file.
    :param num_classes: Number of classes of the classification head.
    :param strict: Passed to ``load_state_dict``. Use False, for example, when
        the head was changed and some keys do not match.
    :return: The model in eval mode.
    :raises FileNotFoundError: If ``model_path`` does not exist.
    """
    # Check the file first so a bad path fails before building the large model.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型权重不存在，路径：{model_path}")

    model = convnext_base(num_classes=num_classes)

    # NOTE(security): torch.load unpickles the file. Only load checkpoints you
    # trust, or pass weights_only=True on PyTorch versions that support it.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    # facebookresearch checkpoints wrap the weights as {"model": state_dict};
    # also accept a bare state_dict.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict, strict=strict)

    # Move to the device and switch to eval mode (disables dropout/drop-path).
    model = model.to(DEVICE)
    model.eval()
    return model


def image_preprocess(image_path: str, transform):
    """Load an image, convert it to RGB, transform it, add a batch dim, and move it to DEVICE.

    :return: ``(pil_image, tensor)``, where tensor has shape [1, C, H, W].
    :raises FileNotFoundError: If ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"测试图像不存在，路径：{image_path}")
    # Force RGB so grayscale or RGBA inputs still produce 3 channels.
    image = Image.open(image_path).convert("RGB")
    tensor_image = transform(image)            # [C, H, W]
    tensor_image = tensor_image.unsqueeze(0)   # [1, C, H, W]; the model expects a batch
    tensor_image = tensor_image.to(DEVICE)
    return image, tensor_image


def model_infer(model, input_tensor):
    """Run the model without gradients and return the Top-K probabilities and class indices.

    Logits are converted to probabilities with softmax. ``k`` is clamped to
    the number of classes so small heads do not make ``torch.topk`` fail.

    :return: ``(topk_probs, topk_indices)`` as NumPy arrays for the first
        sample in the batch.
    """
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        k = min(TOPK, probabilities.size(1))
        topk_probs, topk_indices = torch.topk(probabilities, k=k)
    topk_probs = topk_probs.cpu().numpy()[0]
    topk_indices = topk_indices.cpu().numpy()[0]
    return topk_probs, topk_indices


def print_result(topk_probs, topk_indices, class_names=None):
    """Print the Top-K predictions.

    Uses ``class_names[idx]`` when names are given, otherwise a generic
    "类别_<idx>" label.
    """
    print("\n" + "=" * 50)
    print(f"图像分类预测结果（Top-{len(topk_indices)}）")
    print("=" * 50)
    for rank, (idx, prob) in enumerate(zip(topk_indices, topk_probs), start=1):
        # Build the fallback label lazily, without materializing NUM_CLASSES placeholders.
        label = class_names[idx] if class_names is not None else f"类别_{idx}"
        print(f"Top{rank}：{label:10} | 置信度：{prob:.4f} ({prob * 100:.2f}%)")
    print("=" * 50)


def main():
    """Preprocess config -> load model -> preprocess image -> infer -> print."""
    # 1. Build the preprocessing transform.
    transform = get_convnext_preprocess()
    print("图像预处理配置完成")

    # 2. Load the local model.
    print(f"正在加载本地模型：{MODEL_WEIGHT_PATH}")
    model = load_local_model(MODEL_WEIGHT_PATH, NUM_CLASSES)
    print("模型加载完成，已切换到评估模式")

    # 3. Preprocess the image.
    print(f"正在处理图像：{TEST_IMAGE_PATH}")
    _raw_image, input_tensor = image_preprocess(TEST_IMAGE_PATH, transform)
    print("图像预处理完成")

    # 4. Run inference.
    print("开始模型推理...")
    topk_probs, topk_indices = model_infer(model, input_tensor)
    print("推理完成")

    # 5. Print results. Replace None with real class names,
    # e.g. custom_classes = ["cat", "dog", "car", "tree"].
    custom_classes = None
    print_result(topk_probs, topk_indices, custom_classes)


if __name__ == "__main__":
    main()

# Sample output (convnext_base_22k_224.pth on test.jpg):
#   图像预处理配置完成
#   正在加载本地模型：convnext_base_22k_224.pth
#   模型加载完成，已切换到评估模式
#   正在处理图像：test.jpg
#   图像预处理完成
#   开始模型推理...
#   推理完成
#
#   ==================================================
#   图像分类预测结果（Top-5）
#   ==================================================
#   Top1：类别_9454    | 置信度：0.2625 (26.25%)
#   Top2：类别_2266    | 置信度：0.1490 (14.90%)
#   Top3：类别_2298    | 置信度：0.1391 (13.91%)
#   Top4：类别_2265    | 置信度：0.0401 (4.01%)
#   Top5：类别_216     | 置信度：0.0379 (3.79%)
#   ==================================================