【机器学习】 逻辑回归算法:原理、精确率、召回率、实例应用(癌症病例预测)

news2024/10/6 8:29:04

1. 概念理解

逻辑回归,简称LR,它的特点是能够将我们的特征输入集合转化为0和1这两类的概率。一般来说,回归不用在分类问题上,但逻辑回归却能在二分类(即分成两类问题)上表现很好。

逻辑回归本质上是线性回归,只是在特征到结果的映射中加入了一层Sigmod函数映射,即先把特征线形求和,然后使用Sigmoid函数将最为假设函数来概率求解,再进行分类

Sigmoid函数为: S(x)=\frac{1}{1+e^{-x}}

sigmoid函数形如s曲线下侧无限接近0,上侧无限接近1

例如,在进行预测的过程中,预测结果大于0.5的认为是属于一类,小于0.5的我们认为是第二类,进而我们实现二分类。

优点: 适合需要得到一个分类概率的场景,简单,速度快

缺点: 只能用来处理二分类问题,不好处理多分类问题

应用: 是否患病、金融诈骗、是否虚假账号等


2. 精确率和召回率

如下表所示,如果我预测出一个人得了癌症,他的真实值也是得了癌症,那么这种情况称为TP真正例;如果我预测出一个人得了癌症,而他的真实值是没有得癌症,这种情况称为FN假反例。

(1)精确率:预测结果为正例样本中真实为正例的比例(用于表示查得准不准)

        公式为: P=TP/(TP+FP)

        例:100个人中,我预测的结果是有20个人得了癌症。在这20个人中,真实得癌症的只有5个人,没得癌症的有15人。那么精确率为 P=5/(5+15)=0.25

(2)召回率:真实为正例的样本中预测结果为正例的比例(表示查的全,对正样本的区分能力)

        公式为: R=TP/(TP+FN)

        例:现在有20个人得了癌症,在这些人中我检测到有18个人得了癌症,还有2个人没有检测出来,召回率R=18/(18+2)

(3)综合指标:P和R指标有时候会出现的矛盾的情况,这样就需要综合考虑他们,最常见的方法就是F-Measure。

        公式为: F_{1}=\frac{2*P*R}{P+R}

        若F1较大的话,综合性能较好


导入方法: from sklearn.metrics import classification_report
classification_report()  函数参数

y_true1维数组,或标签指示器数组/稀疏矩阵,真实值。

y_pred1维数组,或标签指示器数组/稀疏矩阵,预测值

labels列表,shape = [n_labels],报表中包含的标签索引的可选列表。

target_names字符串列表,与标签匹配的可选显示名称(相同顺序)

sample_weight类似于shape = [n_samples]的数组,可选项,样本权重 

digitsint,输出浮点值的位数


3. 实例应用 -- 癌症病例预测

3.1 Sklearn 实现

逻辑回归方法导入: from sklearn.linear_model import LogisticRegression
参数设置: 参考博客 https://blog.csdn.net/jark_/article/details/78342644

penalty惩罚项,str类型,可选参数为l1和l2,默认为l2。用于指定惩罚项中使用的规范。

        L1规范假设的是模型的参数满足拉普拉斯分布L2假设的模型参数满足高斯分布。所谓的范式就是加上对参数的约束,使得模型不会过拟合,加约束的情况下,理论上应该可以获得泛化能力更强的结果。

dual对偶或原始方法,bool类型,默认为False。对偶方法只用在求解线性多核的L2惩罚项上。当样本数量>样本特征的时候,dual通常设置为False

tol停止求解的标准,float类型,默认为1e-4。就是求解到多少的时停止,认为已经求出最优解。

C正则化系数λ的倒数,float类型,默认为1.0。必须是正浮点型数。像SVM一样,越小的数值表示越强的正则化。

fit_intercept是否存在截距或偏差,bool类型,默认为True

intercept_scaling仅在正则化项为”liblinear”,且fit_intercept设置为True时有用。float类型,默认为1。
class_weight用于标示分类模型中各种类型的权重,可以是一个字典或者’balanced’字符串,默认为不输入,也就是不考虑权重,即为None。

        如果选择输入的话,可以选择balanced让类库自己计算类型权重,或者自己输入各个类型的权重。举个例子,比如对于0,1的二元模型,我们可以定义class_weight={0:0.9,1:0.1},这样类型0的权重为90%,而类型1的权重为10%。如果class_weight选择balanced,那么类库会根据训练样本量来计算权重。某种类型样本量越多,则权重越低,样本量越少,则权重越高。当class_weight为balanced时,类权重计算方法如下:n_samples / (n_classes * np.bincount(y))。n_samples为样本数,n_classes为类别数量,np.bincount(y)会输出每个类的样本数,例如y=[1,0,0,1,1],则np.bincount(y)=[2,3]。

random_state随机数种子,int类型,可选参数,默认为无,仅在正则化优化算法为sag,liblinear时有用。

solver优化算法选择参数,只有五个可选参数,即newton-cg, lbfgs, liblinear, sag, saga。默认为liblinear。

solver参数决定了我们对逻辑回归损失函数的优化方法,有四种算法可以选择,分别是:
liblinear使用了开源的liblinear库实现,内部使用了坐标轴下降法来迭代优化损失函数。
lbfgs拟牛顿法的一种,利用损失函数二阶导数矩阵即海森矩阵来迭代优化损失函数。
newton-cg牛顿法的一种,利用损失函数二阶导数矩阵即海森矩阵来迭代优化损失函数。
sag即随机平均梯度下降,是梯度下降法的变种,和普通梯度下降法的区别是每次迭代仅仅用一部分的样本来计算梯度,适合于样本数据多的时候。
saga线性收敛的随机优化算法的的变重。

verbose日志冗长度,int类型。默认为0。就是不输出训练过程,1的时候偶尔输出结果,大于1,对于每个子模型都输出。

warm_start热启动参数,bool类型。默认为False。如果为True,则下一次训练是以追加树的形式进行(重新使用上一次的调用作为初始化)。


3.1 癌症预测

数据集包含10项特征值数据和1项目标数据,字符'?'代表缺失数据,目标中数字2代表癌症良性,4代表癌症恶性。

数据集下载地址:Index of /ml/machine-learning-databases/breast-cancer-wisconsin

names中存放的是每一项数据的列索引名称,pandas导入数据集时会默认将数据第一行当作数据索引名,而原数据没有列索引名,我们需要自定义列 pd.read_csv(文件路径,names=列名称)

#(1)数据获取
import pandas as pd
import numpy as np
# 癌症数据路径
filepath = 'C:\\Users\\admin\\.spyder-py3\\test\\文件处理\\癌症\\breast-cancer-wisconsin.data'
# 癌症的每一项特征名
names = ['Sample code number', 'Clump Thickness', 'Uniformity of Cell Size', 'Uniformity of Cell Shape','Marginal Adhesion', 'Single Epithelial Cell Size', 'Bare Nuclei', 'Bland Chromatin','Normal Nucleoli', 'Mitoses', 'Class']
# breast存放癌症数据,不默认将第一行作为列索引名,自定义列索引名
breast = pd.read_csv(filepath,names=names)
# 查看唯一值,Class这列代表的是否得癌症,使用.unique()函数查看该列有哪些互不相同的值
unique = breast['Class'].unique()  #只有两种情况,是二分类问题,2代表良性,4代表恶性


3.2 数据处理 

首先通过 .info() 函数查看数据中是否存在缺失数据nan和重复数据,本例子中没有。然后对字符'?'进行处理,先将'?'转换成nan值,再使用 .dropna() 函数将nan所在的行删除。完成以后划分特征值和目标值。再划分训练集和测试集,测试集取25%的数据。

#(2)数据处理
breast.info()  #查看是否有缺失值、重复数据
# 该数据集存在字符串类型数据'?'
# 将'?'转换成nan
breast = breast.replace(to_replace='?',value=np.nan)
# 将nan所在的行删除
breast = breast.dropna()
 
# 特征值是除了class列以外的所有数据
features = breast.drop('Class',axis=1)
# 目标值是class这一列
targets = breast['Class']
 
#(3)划分训练集和测试集
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(features,targets,test_size=0.25)

3.3 标准化处理

由于单位不一以及数据跨度过大等问题会影响模型准确度,因此对训练数据的和测试数据的特征值进行标准化处理。特征工程的具体方法会在后续章节中介绍,此处先做了解。

#(4)特征工程
# 导入标准化方法
from sklearn.preprocessing import StandardScaler
# 接收标准化方法
transfer = StandardScaler()
# 对训练的特征值x_train提取特征并标准化处理
x_train = transfer.fit_transform(x_train)
# 对测试的特征值x_test标准化处理
x_test = transfer.transform(x_test)

3.4 逻辑回归预测

由于癌症数据中结果只有2和4,良性和恶性,属于二分问题,可以使用逻辑回归方法来预测,此处,为方便各位理解,采用默认参数的逻辑回归方法。其中.fit()函数接收训练模型所需的特征值和目标值,预测函数.predict()接收的是预测所需的特征值,评分法.score()通过真实结果和预测结果计算准确率。计算得到的模型准确率为0.97

#(5)逻辑回归预测
# 导入逻辑回归方法
from sklearn.linear_model import LogisticRegression
# 接收逻辑回归方法
logist = LogisticRegression()
# penalty=l2正则化;tol=0.001损失函数小于多少时停止;C=1惩罚项,越小惩罚力度越小,是岭回归的乘法力度的分之一
# 训练
logist.fit(x_train,y_train)
# 预测
y_predict = logist.predict(x_test)
# 评分法计算准确率
accuracy = logist.score(x_test,y_test)
 

3.5 准确率和召回率

#(6)准确率和召回率
# 导入
from sklearn.metrics import classification_report
# classification_report()
# 参数(真实值,预测值,labels=None,target_names=None)
# labelsclass列中每一项,如该题的24,给它们取名字
# target_names:命名
 
# 计算准确率和召回率
res = classification_report(y_test,y_predict,labels=[2,4],target_names=['良性','恶性'])
print(res)

precision表示准确率;recall表示召回率;f1-score表示综合指标;support表示预测的人数。本模型的召回率,良性达到0.97,恶性达到0.96;该例子是检测癌症,我们希望能找到所有得癌症的人,即使他不是癌症,也可以做进一步检查,因此我们需要一个召回率高的模型。 


数据集获取:

Index of /ml/machine-learning-databases/breast-cancer-wisconsin

完整代码:
#(1)数据获取
import pandas as pd
import numpy as np
# 癌症数据路径
filepath = 'C:\\Users\\admin\\.spyder-py3\\test\\文件处理\\癌症\\breast-cancer-wisconsin.data'
# 癌症的每一项特征名
names = ['Sample code number', 'Clump Thickness', 'Uniformity of Cell Size', 'Uniformity of Cell Shape','Marginal Adhesion', 'Single Epithelial Cell Size', 'Bare Nuclei', 'Bland Chromatin','Normal Nucleoli', 'Mitoses', 'Class']
# breast存放癌症数据,不默认将第一行作为列索引名,自定义列索引名
breast = pd.read_csv(filepath,names=names)
# 查看唯一值,Class这列代表的是否得癌症,使用.unique()函数查看该列有哪些互不相同的值
unique = breast['Class'].unique()  #只有两种情况,是二分类问题,2代表良性,4代表恶性
 
#(2)数据处理
breast.info()  #查看是否有缺失值、重复数据
# 该数据集存在字符串类型数据'?'
# 将'?'转换成nan
breast = breast.replace(to_replace='?',value=np.nan)
# 将nan所在的行删除
breast = breast.dropna()
 
# 特征值是除了class列以外的所有数据
features = breast.drop('Class',axis=1)
# 目标值是class这一列
targets = breast['Class']
 
#(3)划分训练集和测试集
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(features,targets,test_size=0.25)
 
#(4)特征工程
# 导入标准化方法
from sklearn.preprocessing import StandardScaler
# 接收标准化方法
transfer = StandardScaler()
# 对训练的特征值x_train提取特征并标准化处理
x_train = transfer.fit_transform(x_train)
# 对测试的特征值x_test标准化处理
x_test = transfer.transform(x_test)
 
 
#(5)逻辑回归预测
# 导入逻辑回归方法
from sklearn.linear_model import LogisticRegression
# 接收逻辑回归方法
logist = LogisticRegression()
# penalty=l2正则化;tol=0.001损失函数小于多少时停止;C=1惩罚项,越小惩罚力度越小,是岭回归的乘法力度的分之一
# 训练
logist.fit(x_train,y_train)
# 预测
y_predict = logist.predict(x_test)
# 评分法计算准确率
accuracy = logist.score(x_test,y_test)
 
 
#(6)准确率和召回率
# 导入
from sklearn.metrics import classification_report
# classification_report()
# 参数(真实值,预测值,labels=None,target_names=None)
# labelsclass列中每一项,如该题的24,给它们取名字
# target_names:命名
 
# 计算准确率和召回率
res = classification_report(y_test,y_predict,labels=[2,4],target_names=['良性','恶性'])
# precision准确率;recall召回率;综合指标F1-score;support:预测的人数
print(res)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1221757.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习中的独立和同分布 (IID):假设和影响

一、介绍 在机器学习中,独立和同分布 (IID) 的概念在数据分析、模型训练和评估的各个方面都起着至关重要的作用。IID 假设是确保许多机器学习算法和统计技术的可靠性和有效性的基础。本文探讨了 IID 在机器学习中的重要性、其假设及其对模型开…

JUC工具类_CyclicBarrier与CountDownLatch

最近被问到CyclicBarrier和CountDownLatch相关的面试题,CountDownLatch平时工作中经常用到,但是CyclicBarrier没有用过,一时答不上来,因此简单总结记录一下 1.什么是CyclicBarrier? 1.1 概念 CyclicBarrier&#xff…

php中RESTful API使用

1、RESTful AP是什么 RESTful API是一种软件架构风格 RESTful API基于HTTP协议,并遵循一系列约定和原则。它的设计理念是将资源(Resource)作为核心概念,并通过一组统一的接口对资源进行操作。API的资源通常通过URL进行标识&…

Linux grep 命令

Linux grep 命令 1: 作用 ​ grep是一种文本搜索工具,它能使用特定的搜索模式,包括[正则表达式]搜索文本,并默认输出匹配行。 ​ windows类似的命令是findstr. 2:语法 grep -options(参数)…

Spring Boot - devtools 热部署

spring-boot-devtools是Spring Boot提供的一组开发工具,它旨在提高开发体验。这些工具包括应用程序的自动重新启动、自动刷新和远程调试等功能。下面是将spring-boot-devtools整合到Spring Boot应用程序中的步骤: 0、启用"Build project automatic…

【AI视野·今日Sound 声学论文速览 第三十六期】Mon, 30 Oct 2023

AI视野今日CS.Sound 声学论文速览 Mon, 30 Oct 2023 Totally 7 papers 👉上期速览✈更多精彩请移步主页 Daily Sound Papers Style Description based Text-to-Speech with Conditional Prosodic Layer Normalization based Diffusion GAN Authors Neeraj Kumar, A…

开发知识点-前端-webpack

webpack技术笔记 一、 介绍二、 下载使用 一、 介绍 Webpack是一个现代 JavaScript 应用程序的静态模块打包器 打包:可以把js、css等资源按模块的方式进行处理然后再统一打包输出 静态:最终产出的静态资源都可以直接部署到静态资源服务器上进行使用 模…

达尔优EK87键盘说明书

EK87说明书连接说明: **有线模式:**开关拨到最右边,然后插线连接电脑即可使用 2.4G **接收器模式:**开关拨到中间,然后接收器插入电脑USB接口即可使用 **蓝牙模式:**开关拨到最左边,然后按FNQ长…

<Linux>(极简关键、省时省力)《Linux操作系统原理分析之Linux 进程管理 2》(6)

《Linux操作系统原理分析之Linux 进程管理 2》(6) 4 Linux 进程管理4.2 Linux 进程的状态和标识4.2.1 Linux 进程的状态及转换4.2.2 Linux 进程的标识4.2.3 进程标识哈希表 4 Linux 进程管理 4.2 Linux 进程的状态和标识 4.2.1 Linux 进程的状态及转换…

Navicat 基于 GaussDB 主备版的快速入门

Navicat Premium(16.2.8 Windows版或以上) 已支持对GaussDB 主备版的管理和开发功能。它不仅具备轻松、便捷的可视化数据查看和编辑功能,还提供强大的高阶功能(如模型、结构同步、协同合作、数据迁移等),这…

API接口怎么对接电商平台获取商品详情数据

对于api接口的对接,你可以按照以下步骤进行操作: 1. 确定需求:首先要明确你的对接需求,即想要通过对接api接口实现什么功能,例如获取数据、实现支付等。 2. 寻找文档:在对接之前,要找到相关ap…

一、认识STM32

目录 一、初识STM32 1.1 STM32的命名规则介绍 1.2 STM32F103ZET6资源配置介绍 二、如何识别芯片管脚 2.1 如何寻找 IO 的功能说明 三、构成最小系统的要素 一、初识STM32 1.1 STM32的命名规则介绍 以 STM32F103ZET6 来讲解下 STM32 的命名方法: &…

AH4056线性锂电池充电IC:高效、安全的充电解决方案

随着移动设备的普及,人们对电池续航能力的要求越来越高。为了满足这一需求,电池充电技术不断创新。本文将为您介绍一款AH4056线性锂电池充电IC,采用同步整流技术,具有宽输入电压范围、大充电电流、温度保护等优点,适用…

虾皮产品标题生成器:为您的商品打造吸引眼球的标题

在电商平台上,一个引人注目的商品标题是吸引潜在买家点击进入您的产品页面的第一步。然而,很多商家在创建商品标题时遇到困难,不知道如何吸引更多的目标受众。幸运的是,现在有一个名为知虾工具的强大工具,可以帮助商家…

【Git学习二】时光回溯:git reset和git checkout命令详解

😁 作者简介:一名大四的学生,致力学习前端开发技术 ⭐️个人主页:夜宵饽饽的主页 ❔ 系列专栏:JavaScript小贴士Git等软件工具技术的使用 👐学习格言:成功不是终点,失败也并非末日&a…

基于STM32的无线传感器网络(WSN)通信方案设计与实现

无线传感器网络(Wireless Sensor Network,简称WSN)是由一组分布式的无线传感器节点组成的网络,用于监测和收集环境中的各类物理信息。本文将基于STM32微控制器,设计并实现一个简单的无线传感器网络通信方案&#xff0c…

系统韧性研究(5)| 常用的系统韧性技术

如果不利事件或条件导致系统无法正常运行,则它们可能会对有价值的资产造成各种形式的损害。正如我在本系列的前几篇文章中概述的那样,系统韧性很重要,因为没有人想要一个无法克服“不可避免的逆境”的脆弱系统。 在本系列的第一篇文章中&…

nacos客户端连接服务端报Client not connected, current status:STARTING

说明&#xff1a; nacos服务端版本&#xff1a;v2.1.2 nacos客户端版本&#xff1a;2.1.2 结果启动项目报错&#xff1a; Client not connected, current status:STARTING 解决&#xff1a; 降低客户端版本至 1.4.1 就Ok了 <dependency><groupId>com.alibaba.naco…

AI监管规则:各国为科技监管开辟了不同的道路

AI监管规则&#xff1a;各国为科技监管开辟了不同的道路 一份关于中国、欧盟和美国如何控制AI的指南。 编译 李升伟 茅 矛 &#xff08;特趣生物科技有限公司&#xff0c;广东深圳&#xff09; 插图&#xff1a;《自然》尼克斯宾塞 今年5月&#xff0c;科技公司OpenAI首席…

List is a raw type. References to generic type List<E> should be parameterized

List is a raw type. References to generic type List<E> should be parameterized 都是代码习惯问题懒