使用神经网络解决二分类问题从回归到分类在前面的神经网络学习中我们通常用模型解决回归问题例如预测房价、销量、温度等连续数值。现在要进入另一个非常常见的机器学习任务分类问题。分类任务与回归任务在网络结构上有很多相似之处例如仍然可以使用全连接层、激活函数、优化器和早停机制。但二者最大的区别在于输出层希望得到的结果不同损失函数不同评估指标不同本文以二分类 Binary Classification为例介绍神经网络如何完成分类任务。1. 什么是二分类问题二分类指的是将样本划分到两个类别之一。常见例子包括判断客户是否会购买商品判断信用卡交易是否为欺诈判断雷达信号是否探测到目标判断医学检测结果是否显示患病判断酒店订单是否会被取消这些问题的共同点是结果只有两种可能。例如购买 / 不购买 欺诈 / 正常 有病 / 无病 取消 / 不取消 猫 / 狗在原始数据中类别可能是字符串例如Yes / No Dog / Cat good / bad但神经网络不能直接处理这些文本标签因此需要将它们转换成数字标签。例如df[Class]df[Class].map({good:0,bad:1})也就是good - 0 bad - 1这样模型才能学习输入特征与类别标签之间的关系。2. 分类问题中的准确率 Accuracy在分类任务中最直观的评估指标是准确率 Accuracy。准确率定义为accuracy 预测正确的样本数 / 总样本数如果模型全部预测正确则准确率为accuracy 1.0例如有 100 个样本模型预测对了 88 个那么准确率就是accuracy 88 / 100 0.88也就是 88%。准确率适合用在类别比较均衡的数据集中。例如正样本和负样本数量差不多时准确率通常是一个比较合理的指标。但是如果类别极度不均衡准确率可能会产生误导。例如某疾病检测数据中99% 的人没有病 1% 的人有病如果模型永远预测“没有病”它也能达到 99% 的准确率但这个模型显然没有实际价值。所以在实际项目中除了准确率还经常关注Precision 精确率Recall 召回率F1-scoreROC-AUCPR-AUC混淆矩阵 Confusion Matrix3. 为什么不能直接用准确率作为损失函数训练神经网络时需要一个损失函数 Loss Function来告诉模型当前预测有多差。在回归问题中我们常使用MAE平均绝对误差MSE均方误差这些损失函数是连续、平滑变化的适合梯度下降算法优化。但准确率不适合作为损失函数。原因是准确率是一个“计数比例”它的变化是不连续的。例如模型输出概率从0.51 - 0.99如果真实标签是 1这两个预测最终都被判定为类别 1因此准确率没有变化。但显然0.99 比 0.51 更有信心也更接近理想预测。反过来如果模型输出从0.49 - 0.51只变化了一点点但分类结果却从 0 变成了 1准确率发生突变。这种“不平滑”的特性不适合梯度下降算法。因此分类任务通常使用另一种损失函数交叉熵 Cross-Entropy。4. 交叉熵 Cross-Entropy对于分类任务我们希望模型输出的是某个类别的概率。例如对于二分类问题模型可以输出0.8表示模型认为该样本属于类别 1 的概率是 80%。如果真实标签是 1那么这个预测比较好。如果真实标签是 0那么这个预测就很差。交叉熵的核心思想是当模型给正确类别分配的概率越高损失越小当模型给正确类别分配的概率越低损失越大。对于二分类问题常用损失函数是binary_crossentropy其数学形式可以写成Loss−[y⋅log⁡(p)(1−y)⋅log⁡(1−p)] \text{Loss} -\big[ y \cdot \log(p) (1 - y) \cdot \log(1 - p) \big]Loss−[y⋅log(p)(1−y)⋅log(1−p)]其中yyy是真实标签取值为 0 或 1ppp是模型预测为类别 1 的概率log⁡\loglog是对数函数通常指自然对数如果真实标签是 1希望ppp越接近 1 越好。如果真实标签是 0希望ppp越接近 0 越好。举个例子真实标签为 1预测概率 p 0.99 - 损失很小 预测概率 p 0.60 - 损失较大 预测概率 p 0.01 - 损失非常大这就很好地惩罚了“自信但错误”的预测。5. Sigmoid 函数把输出变成概率普通神经网络的 Dense 层输出可以是任意实数例如-10, -2.5, 0, 3.7, 20但二分类任务需要的是概率也就是范围在[0, 1]之间的数。因此在二分类模型的最后一层通常使用sigmoid 激活函数。Sigmoid 函数可以将任意实数映射到 0 到 1 之间sigmoid(x)11e−x \text{sigmoid}(x) \frac{1}{1 e^{-x}}sigmoid(x)1e−x1​它的特点是输入xxx很大时输出接近 1输入xxx很小时输出接近 0输入xxx为 0 时输出为 0.5因此它非常适合用来表示二分类概率。例如layers.Dense(1,activationsigmoid)这里的含义是输出层只有 1 个神经元该神经元输出类别 1 的概率使用 sigmoid 将输出限制在 0 到 1 之间6. 如何从概率得到最终类别模型输出的是概率而不是直接输出类别。例如0.82表示模型认为样本属于类别 1 的概率是 82%。通常会设置一个阈值 threshold例如0.5判断规则为预测概率 0.5 - 类别 0 预测概率 0.5 - 类别 1例如0.12 - 0 0.48 - 0 0.51 - 1 0.93 - 1Keras 中的binary_accuracy默认也使用 0.5 作为分类阈值。不过在实际业务中阈值不一定非要是 0.5。例如在医学检测、欺诈检测等场景中漏判代价很高可能会降低阈值来提高召回率。7. 示例Ionosphere 数据集二分类Ionosphere 数据集包含来自雷达信号的特征任务是判断信号是否显示存在某种物体还是只是空信号。原始数据中的类别是good bad需要先映射为数字标签df[Class]df[Class].map({good:0,bad:1})8. 数据划分与归一化示例中将数据划分为训练集和验证集df_traindf.sample(frac0.7,random_state0)df_validdf.drop(df_train.index)含义是70% 数据作为训练集剩余 30% 数据作为验证集random_state0保证结果可复现然后进行归一化max_df_train.max(axis0)min_df_train.min(axis0)df_train(df_train-min_)/(max_-min_)df_valid(df_valid-min_)/(max_-min_)归一化后特征值被缩放到大致 0 到 1 之间。这样做的好处是加快模型收敛避免某些数值范围过大的特征主导训练提高梯度下降的稳定性需要注意的是验证集归一化时使用的是训练集的min_和max_而不是验证集自己的统计值。这是为了避免数据泄漏。9. 构建二分类神经网络模型结构如下fromtensorflowimportkerasfromtensorflow.kerasimportlayers modelkeras.Sequential([layers.Dense(4,activationrelu,input_shape[33]),layers.Dense(4,activationrelu),layers.Dense(1,activationsigmoid),])这个模型包含输入层33 个特征隐藏层 14 个神经元ReLU 激活隐藏层 24 个神经元ReLU 激活输出层1 个神经元Sigmoid 激活这里输出层使用layers.Dense(1,activationsigmoid)是二分类模型的关键。如果是回归任务最后一层通常不使用 sigmoid。如果是多分类任务最后一层通常会使用 softmax。10. 编译模型模型编译代码如下model.compile(optimizeradam,lossbinary_crossentropy,metrics[binary_accuracy],)各参数含义如下optimizer‘adam’使用 Adam 优化器。Adam 是深度学习中非常常用的优化算法通常比普通 SGD 更容易训练收敛速度也更稳定。loss‘binary_crossentropy’使用二分类交叉熵作为损失函数。这是二分类问题中最常见的选择。metrics[‘binary_accuracy’]训练过程中同时记录二分类准确率。注意loss用于模型优化metrics用于观察模型表现模型真正优化的是binary_crossentropy不是binary_accuracy。11. 使用 Early Stopping 防止过拟合训练代码中使用了早停机制early_stoppingkeras.callbacks.EarlyStopping(patience10,min_delta0.001,restore_best_weightsTrue,)参数含义如下patience10如果验证集指标连续 10 个 epoch 没有明显改善就停止训练。min_delta0.001只有改善幅度大于 0.001才认为是真的改善。restore_best_weightsTrue训练结束后恢复到验证集表现最好的那一轮模型参数。这个参数非常重要。因为模型最后一轮的权重不一定是最好的可能已经开始过拟合。启用该参数后可以自动保留验证集表现最佳的模型。12. 训练模型训练代码如下historymodel.fit(X_train,y_train,validation_data(X_valid,y_valid),batch_size512,epochs1000,callbacks[early_stopping],verbose0,)这里设置了最多训练 1000 个 epoch但由于使用了 Early Stopping模型通常不会真的训练满 1000 轮。参数解释X_train, y_train训练数据validation_data验证数据batch_size512每次用 512 个样本更新模型epochs1000最多训练 1000 轮callbacks[early_stopping]启用早停verbose0不显示训练过程日志13. 查看训练曲线训练完成后可以将history.history转换为 DataFramehistory_dfpd.DataFrame(history.history)然后绘制损失曲线history_df.loc[5:,[loss,val_loss]].plot()以及准确率曲线history_df.loc[5:,[binary_accuracy,val_binary_accuracy]].plot()这里从第 5 个 epoch 开始画图是为了避开训练初期波动较大的阶段让趋势更清晰。输出示例Best Validation Loss: 0.3534 Best Validation Accuracy: 0.8857表示验证集上最好的结果为最佳验证损失0.3534最佳验证准确率0.8857也就是模型在验证集上的准确率约为 88.57%。14. 二分类模型的标准配置对于神经网络二分类任务常见配置可以总结如下modelkeras.Sequential([layers.Dense(若干神经元,activationrelu,input_shape[特征数量]),layers.Dense(若干神经元,activationrelu),layers.Dense(1,activationsigmoid),])model.compile(optimizeradam,lossbinary_crossentropy,metrics[binary_accuracy],)核心要点是任务类型输出层激活函数损失函数回归1 个或多个神经元通常无激活MAE / MSE二分类1 个神经元sigmoidbinary_crossentropy多分类类别数量个神经元softmaxcategorical_crossentropy / sparse_categorical_crossentropy15. Sigmoid 与 Softmax 的区别二分类通常使用Dense(1,activationsigmoid)多分类通常使用Dense(num_classes,activationsoftmax)二者区别如下Sigmoid适合二分类或多标签分类。输出每个类别独立成立的概率。例如多标签任务一张图片可以同时包含猫、狗、车每个标签可以独立为真。Softmax适合单标签多分类。输出所有类别的概率分布且概率和为 1。例如一张图片只能是猫 / 狗 / 鸟 中的一类16. 实践中的注意事项1. 类别标签必须是数字神经网络不能直接使用字符串类别需要先编码good-0bad-12. 特征需要归一化对于神经网络数值归一化通常很重要可以提升训练稳定性。3. 验证集不能参与训练统计归一化参数应该从训练集计算再应用到验证集和测试集。4. 准确率不是万能指标类别不平衡时准确率可能误导判断。此时应该结合 Precision、Recall、F1、AUC 等指标。5. 阈值可以根据业务调整默认阈值是 0.5但实际项目中可以根据业务需求调整。例如想减少漏检降低阈值想减少误报提高阈值17. 总结本文介绍了如何使用神经网络解决二分类问题。核心知识点如下二分类问题的目标是将样本分为两个类别之一原始类别标签需要转换为 0 和 1准确率可以用于评估模型但不适合作为损失函数二分类常用损失函数是binary_crossentropy输出层通常使用sigmoid激活函数Sigmoid 可以将模型输出转换为 0 到 1 之间的概率默认情况下概率大于等于 0.5 判为类别 1否则判为类别 0Adam 优化器同样适用于分类任务Early Stopping 可以减少过拟合并节省训练时间一句话概括二分类神经网络的典型组合是sigmoid输出概率binary_crossentropy作为损失函数binary_accuracy作为评估指标。完整流程可以概括为准备数据 - 标签编码 - 划分训练集和验证集 - 特征归一化 - 构建神经网络 - 输出层使用 sigmoid - 使用 binary_crossentropy 编译 - 训练模型 - 查看 loss 和 accuracy 曲线 - 根据验证集表现评估模型掌握了这一流程之后就可以将类似方法应用到更多实际二分类任务中例如客户流失预测、订单取消预测、欺诈检测和医学辅助诊断等场景。