from sklearn.metrics import confusion_matrix 是从 scikit-learn 库中导入的一个函数,用于计算分类模型的混淆矩阵。混淆矩阵是一个重要的工具,用于评估分类模型的性能,通过总结预测值与真实值之间的关系,直观地展示模型在每个类别上的表现。


1. 什么是混淆矩阵?

混淆矩阵是一个方阵,其中:

  • 行表示真实的类别(True Labels)。
  • 列表示预测的类别(Predicted Labels)。

对于二分类问题,混淆矩阵通常是一个 2 × 2 2 \times 2 2×2 的矩阵,形式如下:

实际\预测 预测为正类 (Positive) 预测为负类 (Negative)
实际为正类 (Positive) True Positive (TP) False Negative (FN)
实际为负类 (Negative) False Positive (FP) True Negative (TN)

对于多分类问题,混淆矩阵会扩展为 C × C C \times C C×C的矩阵,其中 C C C是类别的数量。


2. 混淆矩阵的计算公式

假设有以下符号:

  • (TP):预测为正类且实际为正类的样本数。
  • (FN):预测为负类但实际为正类的样本数。
  • (FP):预测为正类但实际为负类的样本数。
  • (TN):预测为负类且实际为负类的样本数。

混淆矩阵统计如下:

  • True Positive (TP):模型正确预测为正类的样本数量。
  • False Negative (FN):模型错误预测为负类的正类样本数量。
  • False Positive (FP):模型错误预测为正类的负类样本数量。
  • True Negative (TN):模型正确预测为负类的样本数量。

3. 函数签名和参数

confusion_matrix(y_true, y_pred, labels=None, sample_weight=None, normalize=None)
参数
  1. y_true
    • 实际的类别标签。
    • 类型:array-like
  2. y_pred
    • 模型预测的类别标签。
    • 类型:array-like
  3. labels(可选):
    • 指定类别标签的顺序。如果未指定,则按标签出现的顺序排序。
  4. sample_weight(可选):
    • 每个样本的权重,用于加权计算混淆矩阵。
  5. normalize(可选):
    • 是否归一化混淆矩阵。
      • None(默认):返回样本数量。
      • 'true':按每一行的总和归一化。
      • 'pred':按每一列的总和归一化。
      • 'all':按整体样本总数归一化。

4. 返回值

返回一个矩阵:

  • 对于二分类问题,返回形状为 2 × 2 2 \times 2 2×2的数组。
  • 对于多分类问题,返回形状为 C × C C \times C C×C的数组, C C C是类别数。

5. 示例代码

二分类问题
from sklearn.metrics import confusion_matrix

# 实际标签和预测标签
y_true = [0, 1, 0, 1, 0, 1, 1, 0]
y_pred = [0, 0, 0, 1, 0, 1, 1, 1]

# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print("Confusion Matrix:")
print(cm)

输出

Confusion Matrix:
[[3 1]
 [1 3]]

解释:

  • True Negative (TN) = 3 \text{True Negative (TN)} = 3 True Negative (TN)=3:实际为 0 且预测为 0 的样本。
  • False Positive (FP) = 1 \text{False Positive (FP)} = 1 False Positive (FP)=1:实际为 0 但预测为 1 的样本。
  • False Negative (FN) = 1 \text{False Negative (FN)} = 1 False Negative (FN)=1:实际为 1 但预测为 0 的样本。
  • True Positive (TP) = 3 \text{True Positive (TP)} = 3 True Positive (TP)=3:实际为 1 且预测为 1 的样本。

多分类问题
from sklearn.metrics import confusion_matrix

# 实际标签和预测标签
y_true = [0, 1, 2, 2, 0]
y_pred = [0, 0, 2, 2, 1]

# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print("Confusion Matrix:")
print(cm)

输出

Confusion Matrix:
[[1 1 0]
 [1 0 0]
 [0 0 2]]

解释:

  • 行表示真实标签,列表示预测标签。
  • 每个位置的值表示对应类别的样本数。

6. 应用场景

  • 二分类问题:用于计算精准率、召回率、F1 分数等性能指标。
  • 多分类问题:评估模型对每个类别的区分能力。
  • 不平衡数据集:观察模型对小类别的预测效果。

7. 使用混淆矩阵计算其他指标

通过混淆矩阵可以计算多种评估指标:

  1. 精准率(Precision)
    Precision = T P T P + F P \text{Precision} = \frac{TP}{TP + FP} Precision=TP+FPTP
  2. 召回率(Recall)
    Recall = T P T P + F N \text{Recall} = \frac{TP}{TP + FN} Recall=TP+FNTP
  3. F1 分数
    F1 = 2 ⋅ Precision ⋅ Recall Precision + Recall \text{F1} = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} F1=2Precision+RecallPrecisionRecall
  4. 准确率(Accuracy)
    Accuracy = T P + T N T P + T N + F P + F N \text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} Accuracy=TP+TN+FP+FNTP+TN

At Last

confusion_matrix 是一个直观的工具,用于了解分类模型的预测表现。通过观察混淆矩阵中的值,可以快速定位模型在哪些类别上的表现不足,并结合其他指标(如精准率、召回率等)进行综合分析,优化模型性能。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐