拒绝算法黑盒！XGBoost + SHAP 一键生成 10 张出版级模型解释图

现在跑机器学习，大家最头疼的往往不是怎么把 R² 刷高，而是被导师或者业务方灵魂拷问：“你这个模型为什么会得出这个结果？哪个特征起决定性作用？”像 XGBoost、随机森林这种集成树模型，虽然精度吊打传统回归，但“黑盒”属性太强。这时候，SHAP (SHapley Additive exPlanations) 就是我们最好的破局利器。

今天分享一套我压箱底的 Python 自动化脚本。它不仅能完成 XGBoost 的训练与评估，更核心的是，它能一口气生成 10 张高颜值、高分辨率（400 DPI）的 SHAP 可视化图表，包括小提琴图、热力图、瀑布图、依赖图等，直接满足发 Paper 或做汇报的全部需求。

核心代码逻辑拆解

这套脚本主打一个“端到端”：从数据塞进去到美图吐出来，一气呵成。为了方便理解，我们把核心操作拆解开来看看。

1. 中文映射与模型训练

在很多实际业务中（比如做城市规划、经济地理分析），我们的特征变量通常是中文（如“人均GDP”、“交通可达性”）。为了防止绘图时出现乱码，脚本里内置了字段映射字典，并在训练前完成了数据清洗和 XGBoost 拟合：

```python
# 核心特征中文化映射
FEATURE_CN_MAP = {
    "Feat1": "人均GDP",
    "Feat2": "专利/万人",
    # ... 其他特征
}
TARGET_CN_NAME = "韧性指数"

# XGBoost 模型训练
model = xgb.XGBRegressor(
    n_estimators=1000,
    learning_rate=0.05,
    max_depth=6,
    random_state=42,
)
model.fit(X_train, y_train)
```

2. SHAP 值的核心计算

模型算完了，接下来就是把模型喂给 SHAP 解释器。这一步是所有可视化的基础，它会计算出每个样本、每个特征对最终预测结果的贡献度（SHAP Value）。

```python
# 实例化 SHAP 解释器并计算测试集的 SHAP 值
explainer = shap.Explainer(model)
shap_values_test = explainer(X_test)
shap_mat = shap_values_test.values

# 顺手把特征按照重要性（SHAP 绝对值均值）排个序，方便后续画图
feature_order = np.argsort(np.abs(shap_mat).mean(axis=0))[::-1]
```

3. 出版级图表定制（以热力图为例）

很多直接调 shap.plots 画出来的默认图表颜色比较暗淡。脚本里我对 Matplotlib 进行了深度的客制化，统一使用了极简、明亮的配色（如 #5DADE2 亮蓝、#C0392B 亮红），并去除了冗余的边框，非常适合直接放进论文里。以“全样本 SHAP 热力图”的生成为例：

```python
# 提取排序后的数据并设置颜色阈值
heat_data = shap_mat[sample_order][:, top_idx].T
vmax = np.percentile(np.abs(heat_data), 98)

# 自定义高颜值热力图
fig_h, ax_h = plt.subplots(figsize=(16, 9), dpi=150)
# 使用 RdBu_r 红蓝渐变色带，清晰对比正负贡献
im_h = ax_h.imshow(heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
ax_h.set_title("SHAP Values Heatmap", fontsize=20, pad=12, fontweight="bold")
# 去除多余的网格线，保持画面干净
# ...
```
# -*- coding: utf-8 -*-
"""Train an XGBoost regressor and export publication-style SHAP figures.

Features: high-resolution output (400 DPI), a bright minimalist palette, and
a mapping from raw column names to Chinese display labels.

Usage: install ``xgboost``, ``shap``, ``pandas``, ``matplotlib`` and
``scikit-learn``; point ``FILE_PATH`` at your own CSV and adjust
``FEATURE_CN_MAP``.  Running the script writes the PNG figures plus
``model_test_metrics.csv`` (R2, RMSE, MAE) into the current working
directory.

NOTE(review): this listing was reconstructed from a copy in which quotes,
``=``, ``+`` and ``>`` characters and all line breaks had been stripped;
string literals such as the text-box labels were restored from context and
should be confirmed against the author's original.
"""

import importlib
import os
import sys

import numpy as np
import pandas as pd

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
FIG_DPI = 400                 # resolution used for every saved figure
TOP_N = 15                    # maximum number of features shown per figure
RANDOM_HEATMAP_SAMPLES = 20   # rows drawn for the random-sample heatmap

# Data location (replace with your own CSV) and the column to predict.
FILE_PATH = "./data/dataset.csv"
TARGET_COL = "XHM"

# Raw column name -> display label used on the figures.
FEATURE_CN_MAP = {
    "Feat1": "人均GDP",
    "Feat2": "专利/万人",
    "Feat3": "对外开放度",
    "Feat4": "产业高级化",
    "Feat5": "科技支出占比",
    "Feat6": "交通可达性",
    "Feat7": "普惠金融指数",
}
TARGET_CN_NAME = "韧性指数"


def load_dataset(path=FILE_PATH, target_col=TARGET_COL):
    """Read the CSV and return the cleaned ``(X, y)`` pair.

    Cleaning steps, in order: keep numeric columns only, drop columns with a
    single distinct value, turn +/-inf into NaN, drop every row that contains
    a NaN.  Feature columns are renamed through ``FEATURE_CN_MAP`` and the
    target series is renamed to ``TARGET_CN_NAME``.

    Args:
        path: Anything ``pandas.read_csv`` accepts (path or file-like).
        target_col: Name of the column to predict.

    Returns:
        Tuple ``(X, y)`` of feature DataFrame and target Series.

    Raises:
        ValueError: If ``target_col`` is missing after cleaning (absent from
            the file, non-numeric, or constant).
    """
    df = pd.read_csv(path)
    numeric = df.select_dtypes(include=[np.number])
    # Constant columns carry no signal for the model.
    numeric = numeric.loc[:, numeric.nunique(dropna=True) > 1]
    numeric = numeric.replace([np.inf, -np.inf], np.nan).dropna()

    if target_col not in numeric.columns:
        raise ValueError(f"目标列 {target_col} 不在数据中。")

    X = numeric.drop(columns=[target_col], errors="ignore")
    y = numeric[target_col]
    return X.rename(columns=FEATURE_CN_MAP), y.rename(TARGET_CN_NAME)


def rank_features(shap_mat):
    """Rank feature columns by mean absolute SHAP value.

    Args:
        shap_mat: 2-D array indexed as ``[sample, feature]``.

    Returns:
        Tuple ``(order, mean_abs)``: column indices sorted from most to least
        important, and the per-column mean of ``|shap_mat|``.
    """
    mean_abs = np.abs(shap_mat).mean(axis=0)
    return np.argsort(mean_abs)[::-1], mean_abs


def symmetric_color_limit(values, percentile=98, fallback=1.0):
    """Return a positive bound ``v`` for a symmetric ``[-v, v]`` colour scale.

    The bound is the given percentile of ``|values|``.  When that percentile
    is not positive (e.g. most entries are zero) the maximum of ``|values|``
    is used instead, and ``fallback`` when the data is empty or all zero, so
    the colour scale never collapses to a zero-width range.
    """
    magnitudes = np.abs(np.asarray(values, dtype=float))
    if magnitudes.size == 0:
        return fallback
    limit = float(np.percentile(magnitudes, percentile))
    if not limit > 0:
        limit = float(magnitudes.max())
    return limit if limit > 0 else fallback


def contribution_share(importance):
    """Convert importances to percentage shares that sum to 100.

    Returns an all-zero array when the importances sum to zero instead of
    dividing by zero.
    """
    importance = np.asarray(importance, dtype=float)
    total = importance.sum()
    if not total > 0:
        return np.zeros_like(importance)
    return importance / total * 100


def main():
    """Run the pipeline: load data, fit XGBoost, compute SHAP, save figures."""
    # Heavy third-party packages are imported here so that the helpers above
    # stay importable with only numpy/pandas installed.
    import matplotlib.pyplot as plt
    import xgboost as xgb
    from sklearn.metrics import (
        mean_absolute_error,
        mean_squared_error,
        r2_score,
    )
    from sklearn.model_selection import train_test_split

    shap = importlib.import_module("shap")

    # Fonts able to render the Chinese labels; keep ASCII minus signs.
    plt.rcParams["font.family"] = ["SimSun", "DejaVu Sans"]
    plt.rcParams["axes.unicode_minus"] = False

    def save_figure(fig, filename):
        """Write ``fig`` with the shared export settings, then close it."""
        fig.savefig(filename, dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
        plt.close(fig)

    # ----------------------------------------------------------------- #
    # 1. Data loading and cleaning
    # ----------------------------------------------------------------- #
    X, y = load_dataset(FILE_PATH, TARGET_COL)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=None
    )

    # ----------------------------------------------------------------- #
    # 2. Model training and evaluation
    # ----------------------------------------------------------------- #
    model = xgb.XGBRegressor(
        n_estimators=1000,
        learning_rate=0.05,
        max_depth=6,
        random_state=42,
    )
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    test_r2 = r2_score(y_test, y_pred)
    test_rmse = np.sqrt(mean_squared_error(y_test, y_pred))
    test_mae = mean_absolute_error(y_test, y_pred)

    metrics_df = pd.DataFrame(
        {
            "Metric": ["R2", "RMSE", "MAE"],
            "Value": [test_r2, test_rmse, test_mae],
        }
    )
    metrics_df.to_csv("model_test_metrics.csv", index=False, encoding="utf-8-sig")

    # ----------------------------------------------------------------- #
    # 3. SHAP values on the test split
    # ----------------------------------------------------------------- #
    explainer = shap.Explainer(model)
    shap_values_test = explainer(X_test)
    shap_mat = shap_values_test.values

    feature_order, mean_abs_shap = rank_features(shap_mat)
    top_n = min(TOP_N, X_test.shape[1])
    top_idx = feature_order[:top_n]
    top_feature_names = [X_test.columns[i] for i in top_idx]
    top_importance = mean_abs_shap[top_idx]

    # ----------------------------------------------------------------- #
    # Figures
    # ----------------------------------------------------------------- #
    # 1) Violin plot
    plt.figure(figsize=(12, 8), dpi=150)
    shap.summary_plot(
        shap_values_test,
        X_test,
        plot_type="violin",
        max_display=top_n,
        color="#5DADE2",
        show=False,
    )
    ax_v = plt.gca()
    ax_v.set_title(
        "SHAP Value Distribution (Violin Plot)",
        fontsize=20, pad=14, fontweight="bold",
    )
    ax_v.set_xlabel("SHAP Value", fontsize=16)
    ax_v.set_ylabel("")
    ax_v.tick_params(axis="both", labelsize=13)
    ax_v.grid(axis="x", linestyle="--", alpha=0.2)
    plt.tight_layout()
    save_figure(plt.gcf(), "shap_violin.png")

    # 2) Heatmap over all test samples (features as rows, samples as
    #    columns, samples ordered by total |SHAP|)
    sample_order = np.argsort(np.abs(shap_mat).sum(axis=1))[::-1]
    heat_data = shap_mat[sample_order][:, top_idx].T
    vmax = symmetric_color_limit(heat_data, 98)

    fig_h, ax_h = plt.subplots(figsize=(16, 9), dpi=150)
    im_h = ax_h.imshow(
        heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax
    )
    ax_h.set_title("SHAP Values Heatmap", fontsize=20, pad=12, fontweight="bold")
    ax_h.set_ylabel("Feature", fontsize=14)
    ax_h.set_xlabel("Sample Index (sorted by total |SHAP|)", fontsize=14)
    ax_h.set_yticks(np.arange(len(top_feature_names)))
    ax_h.set_yticklabels(top_feature_names, fontsize=11)
    # At most six evenly spaced ticks along the sample axis.
    ax_h.set_xticks(
        np.linspace(
            0, heat_data.shape[1] - 1, min(6, heat_data.shape[1])
        ).astype(int)
    )
    ax_h.tick_params(axis="x", labelsize=10)
    cbar_h = fig_h.colorbar(im_h, ax=ax_h, fraction=0.03, pad=0.02)
    cbar_h.set_label("SHAP Value", fontsize=13)
    cbar_h.ax.tick_params(labelsize=10)
    fig_h.tight_layout()
    save_figure(fig_h, "shap_heatmap.png")

    # 3) Heatmap for a reproducible random subset of test samples
    rng = np.random.default_rng(42)
    sample_count = min(RANDOM_HEATMAP_SAMPLES, X_test.shape[0])
    rand_idx = np.sort(
        rng.choice(X_test.shape[0], size=sample_count, replace=False)
    )
    top12_idx = feature_order[: min(12, len(feature_order))]
    random_heat_data = shap_mat[rand_idx][:, top12_idx]
    random_feature_labels = [X_test.columns[i] for i in top12_idx]
    random_sample_labels = [f"样本 {i}" for i in rand_idx]
    vmax2 = symmetric_color_limit(random_heat_data, 98)

    fig_r, ax_r = plt.subplots(figsize=(13, 10), dpi=150)
    im_r = ax_r.imshow(
        random_heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax2, vmax=vmax2
    )
    # The title reports how many samples were actually drawn, which is
    # smaller than RANDOM_HEATMAP_SAMPLES on a small test split.
    ax_r.set_title(
        f"SHAP Heatmap - {sample_count} Random Samples",
        fontsize=20, pad=12, fontweight="bold",
    )
    ax_r.set_xlabel("Features", fontsize=14, fontweight="bold")
    ax_r.set_ylabel("Samples", fontsize=14, fontweight="bold")
    ax_r.set_xticks(np.arange(len(random_feature_labels)))
    ax_r.set_xticklabels(
        random_feature_labels, rotation=40, ha="right", fontsize=11
    )
    ax_r.set_yticks(np.arange(len(random_sample_labels)))
    ax_r.set_yticklabels(random_sample_labels, fontsize=10)
    cbar_r = fig_r.colorbar(im_r, ax=ax_r, fraction=0.036, pad=0.04)
    cbar_r.set_label("SHAP Value", fontsize=13, fontweight="bold")
    cbar_r.ax.tick_params(labelsize=10)
    fig_r.tight_layout()
    save_figure(fig_r, "shap_heatmap_20samples.png")

    # 4) Waterfall plot explaining a single sample
    waterfall_idx = min(5, len(shap_values_test) - 1)
    plt.figure(figsize=(12, 9), dpi=150)
    shap.plots.waterfall(
        shap_values_test[waterfall_idx], max_display=10, show=False
    )
    ax_w = plt.gca()
    ax_w.set_title(
        f"SHAP Waterfall Plot - Sample {waterfall_idx}",
        fontsize=20, pad=14, fontweight="bold",
    )
    ax_w.tick_params(axis="both", labelsize=12)
    plt.tight_layout()
    # The file name follows the plotted index so it stays truthful when the
    # test split has fewer than six rows.
    save_figure(plt.gcf(), f"shap_waterfall_sample{waterfall_idx}.png")

    # 5) Predicted vs. actual scatter on the test split
    fig_p, ax_p = plt.subplots(figsize=(8, 8), dpi=150)
    ax_p.scatter(
        y_test, y_pred,
        s=86, color="#2E86C1", alpha=0.82, edgecolor="white", linewidth=0.8,
    )
    min_value = min(y_test.min(), y_pred.min())
    max_value = max(y_test.max(), y_pred.max())
    # 8% margin around the data; fixed margin when all values coincide.
    padding = (max_value - min_value) * 0.08 if max_value > min_value else 0.05
    line_min = min_value - padding
    line_max = max_value + padding
    ax_p.plot(
        [line_min, line_max], [line_min, line_max],
        color="#C0392B", linewidth=2.0, linestyle="--",
    )
    ax_p.set_xlim(line_min, line_max)
    ax_p.set_ylim(line_min, line_max)
    ax_p.set_title(
        "Test Set Prediction Performance",
        fontsize=20, pad=14, fontweight="bold",
    )
    ax_p.set_xlabel(f"Actual {TARGET_CN_NAME}", fontsize=14)
    ax_p.set_ylabel(f"Predicted {TARGET_CN_NAME}", fontsize=14)
    ax_p.grid(linestyle="--", alpha=0.25)
    ax_p.text(
        0.05,
        0.95,
        f"R² = {test_r2:.4f}\nRMSE = {test_rmse:.4f}\nMAE = {test_mae:.4f}",
        transform=ax_p.transAxes,
        va="top",
        fontsize=13,
        bbox={
            "boxstyle": "round,pad=0.35",
            "facecolor": "white",
            "edgecolor": "#D0D3D4",
            "alpha": 0.92,
        },
    )
    fig_p.tight_layout()
    save_figure(fig_p, "model_prediction_performance.png")

    # 6) Horizontal bar chart of mean |SHAP| (reversed so the most
    #    important feature ends up at the top of the axis)
    bar_order = top_idx[::-1]
    bar_names = [X_test.columns[i] for i in bar_order]
    bar_values = mean_abs_shap[bar_order]

    fig_b, ax_b = plt.subplots(figsize=(12, 8), dpi=150)
    colors = plt.cm.Blues(np.linspace(0.35, 0.95, len(bar_values)))
    ax_b.barh(bar_names, bar_values, color=colors, edgecolor="white", linewidth=1.0)
    ax_b.set_title(
        "Mean |SHAP| Feature Importance",
        fontsize=20, pad=14, fontweight="bold",
    )
    ax_b.set_xlabel("Mean Absolute SHAP Value", fontsize=14)
    ax_b.tick_params(axis="both", labelsize=12)
    ax_b.grid(axis="x", linestyle="--", alpha=0.25)
    for spine in ["top", "right", "left"]:
        ax_b.spines[spine].set_visible(False)
    for value, name in zip(bar_values, bar_names):
        ax_b.text(value, name, f" {value:.4f}", va="center", fontsize=10)
    fig_b.tight_layout()
    save_figure(fig_b, "shap_feature_importance_bar.png")

    # 7) Beeswarm summary
    plt.figure(figsize=(12, 8), dpi=150)
    shap.summary_plot(
        shap_values_test,
        X_test,
        plot_type="dot",
        max_display=top_n,
        show=False,
    )
    ax_s = plt.gca()
    ax_s.set_title("SHAP Beeswarm Summary", fontsize=20, pad=14, fontweight="bold")
    ax_s.set_xlabel("SHAP Value", fontsize=16)
    ax_s.tick_params(axis="both", labelsize=12)
    plt.tight_layout()
    save_figure(plt.gcf(), "shap_beeswarm.png")

    # 8-9) Dependence plots for the two most important features
    def save_dependence_plot(feature_idx, rank):
        """Scatter one feature's values against its SHAP values and save it.

        Args:
            feature_idx: Column position of the feature in ``X_test``.
            rank: 1-based importance rank, used in the output file name.
        """
        feature_name = X_test.columns[feature_idx]
        feature_values = X_test.iloc[:, feature_idx]
        feature_shap_values = shap_mat[:, feature_idx]

        fig_d, ax_d = plt.subplots(figsize=(10, 7), dpi=150)
        scatter = ax_d.scatter(
            feature_values,
            feature_shap_values,
            c=feature_values,
            cmap="coolwarm",
            s=78,
            alpha=0.85,
            edgecolor="white",
            linewidth=0.7,
        )
        ax_d.axhline(0, color="#777777", linewidth=1.2, linestyle="--", alpha=0.7)
        ax_d.set_title(
            f"Dependence Plot - {feature_name}",
            fontsize=18, pad=12, fontweight="bold",
        )
        ax_d.set_xlabel(feature_name, fontsize=14)
        ax_d.set_ylabel("SHAP Value", fontsize=14)
        ax_d.tick_params(axis="both", labelsize=11)
        ax_d.grid(linestyle="--", alpha=0.22)
        cbar_d = fig_d.colorbar(scatter, ax=ax_d, fraction=0.045, pad=0.04)
        cbar_d.set_label("Feature Value", fontsize=12)
        cbar_d.ax.tick_params(labelsize=10)
        fig_d.tight_layout()
        save_figure(fig_d, f"shap_dependence_top{rank}.png")

    leading = feature_order[: min(2, len(feature_order))]
    for rank, feature_idx in enumerate(leading, start=1):
        save_dependence_plot(feature_idx, rank)

    # 10) Per-feature contribution share with cumulative curve
    importance_pct = contribution_share(top_importance)
    cumulative_pct = np.cumsum(importance_pct)

    fig_c, ax_c = plt.subplots(figsize=(13, 7), dpi=150)
    x_pos = np.arange(len(top_feature_names))
    ax_c.bar(x_pos, importance_pct, color="#5DADE2", edgecolor="white", linewidth=1.0)
    ax_c.set_title("SHAP Contribution Share", fontsize=20, pad=14, fontweight="bold")
    ax_c.set_ylabel("Contribution Share (%)", fontsize=14)
    ax_c.set_xticks(x_pos)
    ax_c.set_xticklabels(top_feature_names, rotation=35, ha="right", fontsize=11)
    ax_c.tick_params(axis="y", labelsize=11)
    ax_c.grid(axis="y", linestyle="--", alpha=0.25)

    ax_c2 = ax_c.twinx()
    ax_c2.plot(x_pos, cumulative_pct, color="#D35400", marker="o", linewidth=2.6)
    ax_c2.set_ylabel("Cumulative Share (%)", fontsize=14)
    ax_c2.set_ylim(0, 105)
    ax_c2.tick_params(axis="y", labelsize=11)

    for spine in ["top"]:
        ax_c.spines[spine].set_visible(False)
        ax_c2.spines[spine].set_visible(False)

    fig_c.tight_layout()
    save_figure(fig_c, "shap_contribution_share.png")

    print("模型与图表生成完毕，查看本地文件即可")


if __name__ == "__main__":
    # Remove the script's own directory from the import path before ``shap``
    # is imported -- presumably so a local file with a clashing name cannot
    # shadow the installed package (NOTE(review): confirm intent).
    current_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path = [p for p in sys.path if os.path.abspath(p or ".") != current_dir]
    main()