机器学习——代价敏感错误率与代价曲线

news2024/11/18 1:26:34

文章目录

  • 代价敏感错误率
  • 实现
  • 代价曲线
  • 例子

代价敏感错误率

指在分类问题中,不同类别的错误分类所造成的代价不同。在某些应用场景下,不同类别的错误分类可能会产生不同的代价。例如,在医学诊断中,将疾病患者错误地分类为健康人可能会带来严重的后果,而将健康人错误地分类为疾病患者的后果可能相对较轻。

传统的分类算法通常假设所有的错误分类代价是相同的,但在实际应用中,这种情况并不常见。代价敏感错误率的目标是通过考虑不同类别的错误分类代价,使得模型在测试阶段能够最小化总体代价。

在代价敏感错误率中,通常会引入代价矩阵(cost matrix)来描述不同类别之间的代价关系。代价矩阵是一个二维矩阵,其中的元素c(i, j)表示将真实类别为i的样本误分类为类别j的代价。通过使用代价矩阵,可以在模型训练和评估过程中更好地考虑不同类别的错误分类代价。

在这里插入图片描述

在实际应用中,代价敏感错误率通常用于处理那些不同类别错误分类代价差异较大的问题,例如欺诈检测、医学诊断等领域。在训练分类模型时,引入代价矩阵可以帮助模型更好地适应不同类别的代价,从而提高模型在实际应用中的性能。

在这里插入图片描述

实现

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from skmultilearn.problem_transform import ClassifierChain
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import RandomForestClassifier
import numpy as np

class CostSensitiveClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, base_estimator=None, cost_matrix=None):
        self.base_estimator = base_estimator
        self.cost_matrix = cost_matrix

    def fit(self, X, y):
        self.base_estimator.fit(X, y)
        return self

    def predict(self, X):
        return self.base_estimator.predict(X)

    def predict_proba(self, X):
        return self.base_estimator.predict_proba(X)

# 以加载iris数据集为例
iris = load_iris()
X = iris.data
y = iris.target

# 将数据集分成训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 定义成本矩阵
# 首先有三种分类所有这个矩阵就是3x3,可以看上面表2.2.的代价矩阵就可以理解了
cost_matrix = np.array([[0, 100, 0],
                        [0, 0, 1000],
                        [0, 0, 0]])

# 使用RandomForestClassifier创建成本敏感分类器作为基本估计器
cost_sensitive_classifier = CostSensitiveClassifier(base_estimator=RandomForestClassifier(random_state=42), cost_matrix=cost_matrix)

# Fit the model
cost_sensitive_classifier.fit(X_train, y_train)

# Make predictions
predictions = cost_sensitive_classifier.predict(X_test)

# Evaluate the model using confusion matrix
conf_matrix = confusion_matrix(y_test, predictions)
print("Confusion Matrix:")
print(conf_matrix)


代价曲线

代价曲线(Cost Curve)是一个用于评估代价敏感学习算法性能的工具。在许多实际问题中,不同类别的错误分类可能导致不同的代价。例如,在医疗诊断中,将一个健康患者误诊为患病可能会导致不同于将一个患病患者误诊为健康的代价。代价曲线帮助我们在考虑不同错误分类代价的情况下,选择一个适当的分类阈值。

代价曲线的意义和应用包括以下几点:

  1. 选择最优分类阈值: 代价曲线可以帮助你选择一个最适合的分类阈值,使得在考虑不同类别错误分类代价的情况下,总代价最小化。通过观察代价曲线,你可以选择使得总代价最小的分类阈值,从而在实际应用中取得更好的性能。

  2. 模型选择和比较: 代价曲线可以用于比较不同模型在代价敏感情况下的性能。通过比较不同模型在代价曲线上的表现,你可以选择最适合你问题的模型。

  3. 业务决策支持: 在许多业务场景中,不同类型的错误可能具有不同的后果。例如,在信用卡欺诈检测中,将一个正常交易误判为欺诈交易的代价可能比将一个欺诈交易误判为正常交易的代价大得多。代价曲线可以帮助业务决策者理解模型的性能,并根据业务需求调整模型的阈值。

在这里插入图片描述
在这里插入图片描述

绘制代价曲线的步骤通常包括以下几个部分:

  1. 定义代价矩阵: 首先,你需要明确定义每种错误分类的代价。代价矩阵(cost matrix)是一个二维矩阵,其中的每个元素表示将真实类别i的样本错误地分类为类别j的代价。这个矩阵应该根据你的问题领域和需求来定义。

  2. 模型预测概率: 使用你的分类模型(比如随机森林、逻辑回归等)对测试数据集进行预测,并获取每个样本属于各个类别的概率。

  3. 计算总代价: 对于每个可能的分类阈值,使用代价矩阵和模型的预测概率计算出混淆矩阵(confusion matrix),然后根据代价矩阵计算总代价。

  4. 绘制代价曲线: 将不同分类阈值下的总代价绘制成图表,横轴表示阈值,纵轴表示总代价。通过观察代价曲线,你可以选择使得总代价最小的分类阈值。

以下是在Python中绘制代价曲线的基本步骤示例:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# 1. 定义代价矩阵
cost_matrix = np.array([[0, 1, 1],
                        [10, 0, 5],
                        [5, 5, 0]])

# 2. 模型预测概率
# 这里假设你已经有了模型的预测概率,存储在变量 predicted_probs 中

# 3. 计算总代价
thresholds = np.linspace(0, 1, 100)
total_costs = []

for threshold in thresholds:
    y_pred_thresholded = np.argmax(predicted_probs, axis=1)
    y_pred_thresholded = np.where(predicted_probs.max(axis=1) > threshold, y_pred_thresholded, -1)
    conf_matrix = confusion_matrix(true_labels, y_pred_thresholded, labels=[0, 1, 2])
    total_cost = np.sum(conf_matrix * cost_matrix)  # 计算总代价
    total_costs.append(total_cost)

# 4. 绘制代价曲线
plt.figure(figsize=(8, 6))
plt.plot(thresholds, total_costs, marker='o', linestyle='-', color='b')
plt.xlabel('Threshold')
plt.ylabel('Total Cost')
plt.title('Cost Curve')
plt.grid(True)
plt.show()

请确保 predicted_probs 包含模型的预测概率,true_labels 包含测试数据的真实类别。这个示例代码中,total_costs 中存储了不同阈值下的总代价,将其绘制成曲线,你就得到了代价曲线。希望这个例子能帮到你理解代价曲线的绘制过程。

例子

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix

# 加载数据 这个类别是有三种
iris = load_iris()
X = iris.data
y = iris.target

# 假设一个代价矩阵,其中假阳性代价为1,类1的假阴性代价为10,类2的假阴性代价为5
cost_matrix = np.array([[0, 1, 1],
                        [10, 0, 5],
                        [5, 5, 0]])

# 将数据集分成训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练一个随机森林分类器
rf_classifier = RandomForestClassifier(random_state=42)
rf_classifier.fit(X_train, y_train)

# 作出预测
"""
三个类别(类别0、类别1、类别2),rf_classifier.predict_proba(X_test) 的输出
可能是一个形状为 (n_samples, 3) 的数组,其中 n_samples 
是测试样本的数量。这个数组的每一行表示一个测试样本,每一列的值表示该样本属于对应类别的概率。
"""
y_pred_probs = rf_classifier.predict_proba(X_test)


# 计算每个阈值的总成本
thresholds = np.linspace(0, 1, 100)
total_costs = []

for threshold in thresholds:
    y_pred_thresholded = np.argmax(y_pred_probs, axis=1)
    y_pred_thresholded = np.where(y_pred_probs.max(axis=1) > threshold, y_pred_thresholded, -1)
    conf_matrix = confusion_matrix(y_test, y_pred_thresholded, labels=[0, 1, 2])
    total_cost = np.sum(conf_matrix * cost_matrix)  # 使用给定的成本矩阵计算总成本
    total_costs.append(total_cost)

# Plot the cost curve
plt.figure(figsize=(8, 6))
plt.plot(thresholds, total_costs, marker='o', linestyle='-', color='b')
plt.xlabel('Threshold')
plt.ylabel('Total Cost')
plt.title('Cost Curve')
plt.grid(True)
plt.show()


print(y_pred_probs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1142325.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

区块链技术的未来:去中心化应用和NFT的崛起

区块链技术正在以前所未有的速度改变着金融和数字资产领域。它的演进为去中心化应用和非替代性代币(NFT)的崛起提供了坚实的基础。在本文中,我们将深入探讨这一数字革命的关键方面,从区块链的基本原理到它如何改变金融领域&#x…

修改svc的LoadBalancer的IP引发的惨案

文章目录 背景修改externalIPs的操作api-server报错日志挽救教训 背景 k8s集群没有接外部负载均衡,部署istio的时候ingressgateway一直pending。 于是手动修改了这个lb svc的externalIP,于是k8s就崩了,如何崩的,且听我还道来。 …

Redis(03)| 数据结构-链表

大家最熟悉的数据结构除了数组之外,我相信就是链表了。 Redis 的 List 对象的底层实现之一就是链表。C 语言本身没有链表这个数据结构的,所以 Redis 自己设计了一个链表数据结构。 链表节点结构设计 先来看看「链表节点」结构的样子: type…

ReentrantLock 的实现原理

ReentrantLock ReentrantLock 是一种可重入的排它锁,主要用来解决多线程对共享资源竞争的问题。它的核心特性有几个: 它支持可重入,也就是获得锁的线程在释放锁之前再次去竞争同一把锁的时候,不需要加锁就可以直接访问。它支持公…

企业微信接入芋道SpringBoot项目

背景:使用芋道框架编写了一个数据看板功能需要嵌入到企业微信中,方便各级人员实时观看 接入企业微信的话肯定不能像平常pc端一样先登录再根据权限看页面,不然的话不如直接手机浏览器打开登录账号来得更为方便,所以迎面而来面临两…

结构体数组经典运用---选票系统

结构体的引入 1、概念:结构体和其他类型基础数据类型一样,例如int类型,char类型,float类型等。整型数,浮点型数,字符串是分散的数据表示,有时候我们需要用很多类型的数据来表示一个整体&#x…

网络通信 | 内网穿透

内网穿透 / 获取临时域名 : 下载且安装软件 : cpolar获得 “Authtoken” 且配置 cpolar ,生成“内网穿透”工具配置文件启动服务,临时获取到一个IP地址 (临时域名) 让当前 电脑能获取一个公网的IP地址,让微信后台能调用到当前外卖系统的后端服…

亚马逊发布卖家专用的AI工具,能为商品添加背景

🦉 AI新闻 🚀 亚马逊发布卖家专用的AI工具,能为商品添加背景 摘要:亚马逊发布了一款针对卖家的生成式AI工具,允许卖家为商品添加由AI生成的背景。亚马逊称有75%的广告主因广告内容质量不佳而导致营销失败。通过比对两…

企业文件传输速度慢?宽带利用率不高怎么办?

在当今数字化时代,企业间的文件传输已经成为日常工作中不可或缺的一部分。很多企业业务全球化趋势明显,拥有分布在不同地理位置的办事处、工厂、供应商和合作伙伴。它们需要在这些地点之间传输文件,如合同、报告、设计文件等。企业常常需要传…

【Idea】idea启动同一程序不同端口

前言 在idea中配置两个不同端口,同时运行两个相同的主程序。 更多端口配置同理 idea版本:2022.2.3 1.在service中复制一个进程,指定不同端口 右键打开点击 copy Configuration 2.点击Modify option 2. 选择VM option(用于指定新的端口) 页面就会出现下面的指定…

代码随想录算法训练营第3天| 203.移除链表元素 、 707.设计链表 、 206.反转链表

JAVA代码编写 203. 移除链表元素 给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val val 的节点,并返回 新的头节点 。 示例 1: 输入:head [1,2,6,3,4,5,6], val 6 输出:[1,2,3,4,5]示…

[Leetcode] 0101. 对称二叉树

101. 对称二叉树 题目描述 给你一个二叉树的根节点 root , 检查它是否轴对称。 示例 1: 输入:root [1,2,2,3,4,4,3] 输出:true示例 2: 输入:root [1,2,2,null,3,null,3] 输出:false提示&#…

“智”造时代引领AI机器视觉技术进步,加速渗透下游应用赛道

机器视觉作为智能制造的“智慧之眼”,为智能制造打开了新的“视”界,是实现工业自动化和智能化的关键核心技术。人工智能与机器视觉的结合使用,逐渐成为现阶段研究与应用的热点,同时也成为机器视觉行业发展的核心驱动力。 近年来…

Compose 自定义 - 绘制 Draw

一、概念 所有的绘制操作都是通过调整像素大小来执行的。若要确保项目在不同的设备密度和屏幕尺寸上都能采用一致的尺寸,请务必使用 .toPx() 对 dp 进行转换或者采用小数尺寸。 二、Modifier 修饰符绘制 官方页面 在修饰的可组合项之上或之下绘制。 .drawWithCon…

机器视觉3D项目评估的基本要素及测量案例分析

目录 一. 检测需求确认 1、产品名称:【了解是什么产品上的零件,功能是什么】 2、*产品尺寸:【最大兼容尺寸】 3、*测量项目:【确认清楚测量点位】 4、*精度要求:【若客户提出的精度值过大或者过小,可以和客…

论坛介绍 | COSCon'23 开源文化(GL)

众多开源爱好者翘首期盼的开源盛会:第八届中国开源年会(COSCon23)将于 10月28-29日在四川成都市高新区菁蓉汇举办。本次大会的主题是:“开源:川流不息、山海相映”!各位新老朋友们,欢迎到成都&a…

hadoop使用简介

git clone hadoop源码地址:https://gitee.com/CHNnoodle/hadoop.git git clone错误: Filename too long错误,使用git config --global core.longpaths true git clone https://gitee.com/CHNnoodle/hadoop.git -b rel/release-3.2.2 拉取指定…

Spigot 通过 BuildTools 构建 MineCraft Spigot 官方服务端文件

文章目录 从 Spigot 官方下载 BuildTools spigotmc / buildtools确保你有正确版本的 Java(例如构建 1.20.2 的服务端一般需要有 Java17)在 BuildTools.jar 同名文件夹打开 cmd 命令行(点击红色圈圈区域输入 cmd 按 enter 即可) …

如何查找特定基因集合免疫基因集 炎症基因集

温故而知新,再次看下Msigdb数据库。它更新了很多内容。给我们提供了一个查询基因集的地方。 关注微信:生信小博士 比如纤维化基因集: 打开网址:https://www.gsea-msigdb.org/gsea/msigdb/index.jsp 2.点击search 3.比如我对纤维…

第四章 文件管理 十、文件系统的全局结构

目录 一、文件系统的建立 1、原始磁盘 2、物理格式化后 3、逻辑格式化后 二、文件系统在内存中的结构 三、系统调用背后的过程 一、文件系统的建立 1、原始磁盘 2、物理格式化后 物理格式化,即低级格式化――划分扇区,检测坏扇区,并用…