Logistic Regression——逻辑回归

news2024/11/16 11:35:18

1. 为什么需要逻辑回归     

        在前面学习的线性回归中,我们的预测值都是任意的连续值,例如预测房价。除此之外,还有一个常见的问题就是分类问题,而逻辑回归是一个解决分类问题的模型,其预测值是离散的

        分类问题又包括二分类问题与多分类问题,对于二分类问题来说,预测值只可能是\否即1\0,

        对于多分类问题来说,预测值可能是多个分类中的一个,例如我输入的是一些动物的图片,我想让模型辨认这些是什么动物,我可以设定预测值1代表模型认为输入是一只猫,预测值2代表模型认为输入是一只狗,预测值3代表模型认为输入是一只猪。

2. 二分类逻辑回归

2.1 从线性回归到分类

        如果有这样一个场景,输入x为肿瘤的大小,而需要预测是否是恶性的。接下来我们仍然使用线性回归模型,但如果我们这增设这样一个阈值

                

        这样一来,所有预测值都将变成1或者0,实现了分类的目的

2.2 逻辑回归模型

        对于线性回归的模型来说,其输出值是任意的,常常会远远大于1或者远远小于0,仅仅上述的阈值可能并不会起到作用或者效果很差。

        对此,逻辑回归会先将所有预测值通过sigmoid 函数映射到[0,1]区间,函数表达式和图像如下图

                                                        (z为输入)

        ​​​​​​​        ​​​​​​​        

        sigmoid 函数是一个非线性函数,当x大于0时,输出值大于0.5,当x<0时输出值小于0.5

        最终我们得到逻辑回归的模型如下

        ​​​​​​​        

           h_{\theta }(x)作用是,对于给定的输入变量,通过参数\theta计算输出变量为1的可能性是多少

        

        假如对于一个输入x,最终计算出h_{\theta }(x)=0.7,则模型认为有70%的可能其为正向类(=1),相反负向类的可能性就为1-0.7=0.3

        最后在分类时,再入加上之前的阈值

        所以逻辑回归就是线性回归再嵌套一个非线性的sigmoid函数,其本质还是回归

2.4 决策边界(Decision Boundary

        假如分类这样一些数据,‘x’为1,圈为0

        ​​​​​​​        ​​​​​​​        ​​​​​​​        

        通过建立逻辑回归模型

        ​​​​​​​        ​​​​​​​        ​​​​​​​        

        假设经过训练我们得到了这样一组参数,于是得到嵌套在逻辑回归里的线性回归模型\theta^{\top }X=-3+x_{1}+x_{2},根据逻辑回归的原理当-3+x_{1}+x_{2}>=0时预测1,当-3+x_{1}+x_{2}<0时预测0,于是分隔情况就是-3+x_{1}+x_{2}=0,我们可以画出这个直线

        ​​​​​​​        ​​​​​​​        ​​​​​​​        

        这条线便是模型的决策边界

        如果是这样的数据

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        

        建立逻辑回归模型

        得到参数

        

        同样的原理,得到其决策边界,是一个圆心在原点,半径为1的圆

        

        ​​​​​​​        ​​​​​​​        ​​​​​​​                

2.5 损失函数

2.5.1 为什么不用MSE损失函数

       根据上述的理论可以知道,逻辑回归的和线性回归的本质是一样的。那是不是意味着损失函数也可以用MSE。

        在线性回归中损失函数如下

                                                J(\theta )=\frac{1}{2m}\sum_{i=1}^{m}(h_{\theta }^{i}-y^{i})^{2}

        我们将​​​​​​​带入可以得到

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​J(\theta )=\frac{1}{2m}\sum_{i=1}^{m}(\frac{1}{1+e^{\theta ^{\top }x}}-y)^{2}

        得到的是一个非凸函数(non-convexfunction),这会很大程度上影响梯度下降法寻找全局最小值,很可能停留在在某个局部极小值

2.5.2 对数损失函数

        介于上述问题,对于二分类逻辑回归来说,使用的是对数损失函数。

对于一个样本来说,预测值会有1和0两种情况,对应两个损失值

     

(log一般以e为底)

        ​​​​​​​        ​​​​​​​        

        当实际y=1时,如果预测值h_{\theta }(x)=1,此时预测是完全正确的,代入上式计算误差为0,如果预测值h_{\theta }(x)不为1,代表模型没有100%的把握认为这是正向类的,此时误差会随着h_{\theta }(x)的减小而变大。

        ​​​​​​​        ​​​​​​​        ​​​​​​​        

        当实际y=0时,如果预测值h_{\theta }(x)=0,此时预测是完全正确的,代入上式计算误差为0,如果预测值h_{\theta }(x)不为0,代表模型没有100%的把握认为这是负向类的,此时误差会随着h_{\theta }(x)的增大而变大。

        ​​​​​​​        ​​​​​​​        ​​​​​​​        

        将这两种情况合在一起

        再求和取平均得到最终损失函数表达式

        采用矩阵的形式表达

2.6 梯度下降

 

        矩阵表达式为

        使用梯度下降

        矩阵表达式为

        \theta = \theta -\frac{\alpha }{m}X^{\top }(h-y)

3. 多分类逻辑回归

        多分类逻辑回归的实现依赖于二分类

        将其中一个类标记为正向类,然后将其他类都标记为负向类,得到一个模型h_{\theta }^{1}(x),接着选择另外一个类标记为正向类,然后将其他类都标记为负向类,又得到一个模型h_{\theta }^{2}(x),以此类推,我们可以得到一系列模型,假设有k个类

        h_{\theta }^{i}(x)=p(y=i|x;\theta ),i=(1,2,3,4……k)

        训练好这一系列模型后,对于一个输入x,让其在所有的分类器都得到一个输出,最后选择一个maxh_{\theta }(x)作为最终的输出

4. 逻辑回归的实例

        ex2data1数据集包含100行数据前两列是学生的两种考试的成绩,最后一列是他们被是否录取。需要根据学生的两种考试的成绩来预测他们被是否录取。

1.读取数据集

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

data = pd.read_csv('ex2data1.txt',names=['exam1','exam2','admitted'])
data.head()

# 根据admitted的值分类
plt.scatter(positive['exam1'],positive['exam2'],marker='o',label='Admitted')
plt.scatter(negative['exam1'],negative['exam2'],marker='x',label='Not Admitted')
plt.xlabel('Exam1 Score')
plt.ylabel('Exam2 Score')
plt.legend()
plt.show()

2.数据预处理

data.insert(0,'ones',1)
X = data.iloc[:,0:-1].values
y = data.iloc[:,-1].values
y = y.reshape(100,1)

3.定义Sigmoid函数

def sigmoid(z):
    return 1/(1+np.exp(-z))

4.定义损失函数

def lossFunction(X,y,theta):
    m = len(X)
    h = sigmoid(X@theta)
    return (1/m)*np.sum(-y.T@np.log(h)-(1-y).T@np.log(1-h))

5.模型训练

def train(X,y,alpha,epochs):
    loss_history = []
    theta = np.random.rand(3,1)
    for i in range(epochs):
        m = len(X)
        h = sigmoid(X@theta)
        theta = theta - (alpha/m)*X.T@(h-y)
        current_loss = lossFunction(X,y,theta)
        loss_history.append(current_loss) 
        if (i+1) % 100 == 0:
            print("epochs={},current_loss={}".format(i+1,current_loss))
     # 绘制损失函数图像
    plt.plot(range(1,epochs+1),loss_history)
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.title('Loss Curve')
    plt.show()
    return theta

# 参数
alpha = 0.1
epochs = 1000
theta = train(X,y,alpha,epochs)

admitted = X[y.flatten() == 1]
not_admitted = X[y.flatten() == 0]
plt.scatter(admitted[:, 1], admitted[:, 2], label='Admitted', marker='o')
plt.scatter(not_admitted[:, 1], not_admitted[:, 2], label='Not Admitted', marker='x')
plt.xlabel('Exam 1 score')
plt.ylabel('Exam 2 score')

# 绘制决策边界
plot_x = np.array([min(X[:, 1]) - 2, max(X[:, 1]) + 2])
plot_y = (-1 / theta[2]) * (theta[1] * plot_x + theta[0])
plt.plot(plot_x, plot_y, label='Decision Boundary')
plt.legend()
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1314599.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniGUI学习之UniTreeview

UniTreeview中能改变一级目录的字体和颜色 function beforeInit(sender, config) { ID"#"config.id; Ext.util.CSS.createStyleSheet( ${ID} .x-tree-node-text{color:green;font-weight:800;} ${ID} .x-tree-elbow-line ~ span{color:black;font-weight:400;} ); }

通义千问 Qwen-72B-Chat在PAI-DSW的微调推理实践

01 引言 通义千问-72B&#xff08;Qwen-72B&#xff09;是阿里云研发的通义千问大模型系列的720亿参数规模模型。Qwen-72B的预训练数据类型多样、覆盖广泛&#xff0c;包括大量网络文本、专业书籍、代码等。Qwen-72B-Chat是在Qwen-72B的基础上&#xff0c;使用对齐机制打造的…

Unity中 URP Shader 常量缓冲区CBUFFER

文章目录 前言一、常量缓冲区CBUFFER 使用步骤1、在属性面版定义我们需要使用的属性2、在Pass中&#xff0c;使用前需要提前声明3、使用时&#xff0c;直接使用即可 二、使用 常量缓冲区CBUFFER 的好处三、ShaderGraph属性 和 对应Shader的功能1、我们创建一个颜色属性2、使用&…

springcloudalibaba01

整合springcloud 和 springcloudalibaba&#xff0c;&#xff0c;&#xff0c; 版本对应关系 <dependencyManagement><dependencies><!--每个springcloud的工具都有一个版本每个springcloud alibaba的工具都有一个版本统一版本--> <!-- 整合…

MySQL作为服务端的配置过程与实际案例

MySQL是一款流行的关系型数据库管理系统&#xff0c;广泛应用于各种业务场景中。作为服务端&#xff0c;MySQL的配置过程对于数据库的性能、安全性和稳定性至关重要。本文将详细介绍MySQL作为服务端的配置过程&#xff0c;并通过一个实际案例进行举例说明。 一、MySQL服务端配…

Leaflet.Graticule源码分析以及经纬度汉化展示

目录 前言 一、源码分析 1、类图设计 2、时序调用 3、调用说明 二、经纬度汉化 1、改造前 2、汉化 3、改造效果 总结 前言 在之前的博客基于Leaflet的Webgis经纬网格生成实践中&#xff0c;已经深入介绍了Leaflet.Graticule的实际使用方法和进行了简单的源码分析。认…

DSP的ADC简单笔记

DSP不需要复用GPIO&#xff0c;是单独的ADC引脚&#xff0c;与GPIO不共用 ADC时钟在PCLKCR0寄存器 所以还要配置HSPCLK HISPCP/HSPCLK寄存器 所以ADC的输入时钟&#xff0c;有固定公式&#xff1b; 控制寄存器1 简单配置3个东西&#xff1b; 控制寄存器2 设置为1软件触发 控…

使用kibana查看es数据

前提 已安装好es还有kibana&#xff0c;启动es及kibana 修改kibana配置文件 在kibana文件中配置es的地址及索引&#xff0c;我的kibana安装在mac端了 修改配置文件 /usr/local/opt/kibana/config/kibana.yml 重启kibana 配置kibana 下面查询数据 例如查询 traceId 为192…

Actor-Critic 跑 CartPole-v1

gym-0.26.1 CartPole-v1 Actor-Critic 这里采用 时序差分残差 ψ t r t γ V π θ ( s t 1 ) − V π θ ( s t ) \psi_t r_t \gamma V_{\pi _ \theta} (s_{t1}) - V_{\pi _ \theta}({s_t}) ψt​rt​γVπθ​​(st1​)−Vπθ​​(st​) 详细请参考 动手学强化学习 简…

C语言--clock()时间函数【详细介绍】

一.clock()时间函数介绍 在 C/C 中&#xff0c;clock() 函数通常用于处理和测量程序运行时间&#xff08;时钟时间&#xff09;。它是一种数据类型&#xff0c;表示 CPU 执行指定任务所耗费的“时钟计数”数量&#xff0c;单位为“时钟周期”。 这个函数通常包含在 time.h 头文…

后缀数组模板

详细理解后缀数组求sa数组的函数&#xff0c;该函数可以看为主要分为三个部分&#xff0c;第一个部分是预处理&#xff1b;第二个部分是进行基数排序&#xff0c;首先根据第二关键词排序&#xff0c;然后根据第一关键字排序&#xff1b;第三个部分是根据排序后的结果重新为每个…

Bytebase 2.12.0 - 改进自动补全和布局导航

&#x1f680; 新功能 支持 MySQL 高级自动补全。支持从 UI 上导入分类分级配置。 &#x1f514; 重大变更 作废已有企业版试用证书。之后可以通过提交申请获取新的试用证书。 &#x1f384; 改进 改进整体布局和导航。 支持在 SQL 编辑器里显示以及查询 PostgreSQL 数据…

DDOS 攻击是什么?有哪些常见的DDOS攻击?

DDOS简介 DDOS又称为分布式拒绝服务&#xff0c;全称是Distributed Denial of Service。DDOS本是利用合理的请求造成资源过载&#xff0c;导致服务不可用&#xff0c;从而造成服务器拒绝正常流量服务。就如酒店里的房间是有固定的数量的&#xff0c;比如一个酒店有50个房间&am…

继续看回溯问题

关卡名 继续看回溯问题 我会了✔️ 内容 1.复习递归和N叉树&#xff0c;理解相关代码是如何实现的 ✔️ 2.理解回溯到底怎么回事 ✔️ 3.掌握如何使用回溯来解决二叉树的路径问题 ✔️ 1 复原IP地址 这也是一个经典的分割类型的回溯问题。LeetCode93.有效IP地址正好由四…

生产环境_Spark处理轨迹中跨越本初子午线的经度列

使用spark处理数据集&#xff0c;解决gis轨迹点在地图上跨本初子午线的问题&#xff0c;这个问题很复杂&#xff0c;先补充一版我写的 import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.func…

t-io 程序执行后,jvm不退出的原因

基于t-io 1.7.3 版本分析源码 1、设定当前时间&#xff0c;每10毫秒执行一次 (非守护线程) 2、对应线程池的核心线程在AioServer启动时全部激活&#xff0c;并且添加空任务到阻塞队列&#xff0c;让核心线程(非守护线程)一直存活

ArcGIS Pro SDK文件选择对话框

文件保存对话框 // 获取默认数据库var gdbPath Project.Current.DefaultGeodatabasePath;//设置文件的保存路径SaveItemDialog saveLayerFileDialog new SaveItemDialog(){Title "Save Layer File",OverwritePrompt true,//获取或设置当同名文件已存在时是否出现…

七. 使用ts写一个贪吃蛇小游戏

之前学习了几篇的ts基础&#xff0c;今天我们就使用ts来完成一个贪吃蛇的小游戏。 游戏拆解 我们将我们的任务进行简单拆解分析。 首先我们应该有一个窗口&#xff0c;我们叫做屏幕。让蛇在里面移动&#xff0c;所有我们应该想到要设计一个大盒子当作地图。考虑到食物以及蛇…

【Java代码审计】文件上传篇

【Java代码审计】文件上传篇 1.Java常见文件上传方式2.文件上传漏洞修复 1.Java常见文件上传方式 1、通过文件流的方式上传 public static void uploadFile(String targetURL, String filePath) throws IOException {File file new File(filePath);FileInputStream fileInpu…

【单调栈】【区间合并】LeetCode85:最大矩形

作者推荐 【动态规划】【广度优先搜索】LeetCode:2617 网格图中最少访问的格子数 本文涉及的知识点 单调栈 区间合并 题目 给定一个仅包含 0 和 1 、大小为 rows x cols 的二维二进制矩阵&#xff0c;找出只包含 1 的最大矩形&#xff0c;并返回其面积。 示例 1&#xff1…