5.3 用PyTorch实现Logistic回归

news2024/9/23 17:23:27

一、数据准备

Logistic回归常用于解决二分类问题。

为了便于描述,我们分别从两个多元高斯分布 N₁(μ₁,Σ₁ )、N₂(μ₂,Σ₂)中生成数据 x₁ 和 x₂,这两个多元高斯分布分别表示两个类别,分别设置其标签为 y₁ 和 y₂。

PyTorch 的 torch.distributions 提供了 MultivariateNormal 构建多元高斯分布

下面代码设置两组不同的均值向量和协方差矩阵,μ₁(mul)和 μ₂(mul)是二维均值向量,Σ₁(sigmal)和Σ₂(sigma2)是2*2的协方差矩阵。

前面定义的均值向量和协方差矩阵作为差数传入 MultivariateNormal,就实例化了两个多元高斯分布 m₁和 m₂。

调用 m₁和 m₂ 的sample方法分别生成100个样本。

设置样本对应的标签 y,分别用 0 和 1 表示不同高斯分布的数据,也就是正样本和负样本。

使用 cat 函数将 x₁(m1)和 x₂(m2)组合在一起。

打乱样本和标签的顺序,将数据重新随机排列,这是十分重要的步骤,否则算法的每次迭代只会学习到同一个类别的信息,容易造成模型过拟合。 

将生成的样本用 plt.scatter 绘制出来。

绘制结果如图:

可以明显的看出多元高斯分布生成的样本聚成了两个簇,并且簇的中心分别处于不同的位置(多元高斯分布的均值向量决定了其位置)。

右上角簇的样本分布比较稀疏,而左下角簇的样本分布紧凑(多元高斯分布的协方差矩阵决定了分布形状)。

【可调整

mu1 = -3 * torch.ones(2)

mu2 = 3 * torch.ones(2)

的参数,观察变化!

二、线性方程

Logistic回归用输入变量x的线性函数表示样本为正类的对数概率。nn.Linear 实现了 y = xAᵀ + b,我们可以直接调用它来实现Logistic回归的线性部分。

定义线性模型的输入维度D_in 和输出维度 D_out,因为前面定义的多元高斯分布 m₁(m1)和 m₂(m2)产生的变量是二维的,所以线性模型的输入维度应该定义为D_in = 2 ;而Logistic回归是二分类模型,预测的是变量为正类的概率,所以输出的维度应该为D_in = 1。

实例化了nn.Linear,将线性模型应用到数据 x 上,得到计算结果output。

Linear的初始参数是随机设置的,可以调用Linear.weight 和 Linear.bias 获取线性模型的参数。

输出输入的变量x,模型参数weight和bias,以及计算结果output的维度。

定义线性模型my_linear,将my_linear的计算结果和PyTorch的计算结果output进行比较,可以发现他们是一致的。

输出:

三、激活函数

前面介绍了nn.Linear可用于实现线性模型,除此之外,torch.nn还提供了机器学习中常用的激活函数。当Logistic回归用于二分类问题时,使用sigmoid函数将线性模型的计算结果映射到0和1之间,得到的计算结果作为样本为正类的置信概率。nn.Sigmoid提供了sigmoid函数的计算,在使用时,将Sigmoid类实例化,再将需要计算的变量作为参数传递给实例化的对象。

输出:

【def my_sigmoid(x):

        x = 1 / (1 + torch.exp(-x))

        return x手动实现sigmoid函数;

print(torch.sum(sigmoid(output) - sigmoid_(output)))通过PyTorch验证我们的实现结果,其结果一致】

四、损失函数

1、

Logistic回归使用交叉熵作为损失函数

PyTorch的torch.nn提供了许多标准的损失函数,我们可以直接使用 nn.BCELoss 计算二值交叉熵损失。

调用BCELoss来计算我们实现Logistic回归模型的输出结果sigmoid(output)和数据的标签y。

自定义二值交叉熵函数

将my_loss 和PyTorch的BCELoss进行比较,发现其结果一致。

2、

前面的代码中,我们使用了torch.nn包中的线性模型nn.Linear、激活函数nn.Softmax、损失函数nn.BCELoss,他们都继承自nn.Module类。

而在PyTorch中,我们通过继承nn.Module来构建我们自己的模型。

下面用nn.Module来实现Logistic回归。

输出:

3、

当通过继承nn.Module实现自己的模型时,forward方法是必须被子类覆写的,

在forward内部应当定义每次调用模型时执行的计算。

从代码中可以看出,nn.Module类的主要作用就是接收Tensor然后计算并返回结果。

在一个Module中,还可以嵌套其他的Module,被嵌套的Module的属性就可以被自动获取,比如可以调用nn.Module.parameters方法获取Module所有保留的参数,调用nn.Module.to方法将模型的参数放置到GPU上等。

输出:

五、优化算法

Logistic回归通常采用梯度下降法优化目标函数。

PyTorch的torch.optim包实现了大多数常用的优化算法,使用起来非常简单。

首先构建一个优化器,在构建时,需要将学习的参数传入,然后传入优化器需要的参数,比如学习率。

构建完优化器,就可以迭代地对模型进行训练,

两个步骤:

(1)调用损失函数的backward方法计算模型的梯度

(1)调用优化器的step方法更新模型的参数。

需要注意:应当提前调用优化器的zero_grad方法清空参数的梯度。

六、模型可视化

Logistic回归模型的判决边界在高维空间是一个超平面,而我们的数据集是二维的,所以判决边界只是平面内的一条直线,在线的一侧被预测为正类,另一侧被预测为负类。下面我们实现draw_decision_boundary函数。

它接收线性模型的参数w和b,以及数据集x。

绘制判决边界的方法十分简单。

,只需要计算一些数据在线性模型的映射值,然后调用plt.plot绘制线条即可。如下图:

      --------------------------------------------------------------------------------------------------------------

七、回归VS分类

在之前的回归任务中,我们是预测分值是多少;

分类任务中就可以变成根据学习时间判断是否能通过考试,即结果分为两类:fail、pass。

我们的任务就是计算不同时间 x 分别是 fail、pass 的概率。(二分类问题其实只需要计算一个概率; 另一个概况就是1-算的概率)

如果预测pass概率为0.6,fail概率就是0.4,那么判断为pass。

1、sigmoid函数

σ(x)= 1 / 1+e⁻ˣ

sigmoid函数在x无限趋近于正无穷、 负无穷时,y无线趋近于1、0;

可以看到当x非常大或者非常小的时候,函数梯度变化就非常小了。这 种函数称为饱和函数

八、逻辑回归

1、逻辑回归模型

只是在线性回归之后加了一个sigmoid激活函数!将值映 射在【0,1】之间。

在线性回归中,我们假设随机变量𝑥1,⋯,𝑥𝑛与𝑦之间的关系是线性的。

但在实际中,我们通常会遇到非线性关系。这个时候,我们可以使用一个非线性变化g(·),使得线性回归模型 𝑓(⋅) 实际上对g(y)而非y进行拟合,即:

y = g⁻¹(f(x))

其中 f(·)仍为:

f(x)= wᵀx + b

因此这样的回归模型称为广义线性回归模型

广义线性回归模型使用非常广泛。例如在二元分类任务中,我们的目标是拟合这样一 个分离超平面𝑓(𝒙)=𝒘ᵀ𝒙+𝑏,使得目标分类𝑦可表示为以下阶跃函数:

但是在分类问题中,由于𝑦取离散值,这个阶跃判别函数是不可导的。

不可导的性质使得许多数学方法不能使用。我们考虑 通常可以使用一个函数σ(·)来近似这个离散的阶跃函数,通常可以使用 logistic(Sigmoid函数)

2、损失函数

MSE loss:计算 数值之间的差异

BCE loss:计算 分布之间的差异

3、Logistic回归代码实现

训练结果为:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1542978.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Redis知识点总结】(七)——缓存雪崩、缓存穿透、缓存击穿、Redis高级用法

Redis知识点总结(七)——缓存雪崩、缓存穿透、缓存击穿、Redis高级用法 缓存雪崩缓存穿透布隆过滤器 缓存击穿Redis高级用法bitmapHyperLogLog 缓存雪崩 缓存雪崩是指,同一时间有大量的缓存key失效,或者redis节点直接宕机了&…

基于nodejs+vue学生作业管理系统python-flask-django-php

他们不仅希望页面简单大方,还希望操作方便,可以快速锁定他们需要的线上管理方式。基于这种情况,我们需要这样一个界面简单大方、功能齐全的系统来解决用户问题,满足用户需求。 课题主要分为三大模块:即管理员模块和学生…

(已解决)vue运行时出现Moudle Bulid error: this.getoptios is not a function at........

1.首先查看你的less-loader版本 点击package.json 即可查看less-loader版本,我之前的版本是12.0,太高了,出现了不兼容的问题 2、卸载less-loader ctrlshift~ 打开项目终端 ,输入: npm uninstall less-loader 3.重…

Linux中的常用基础操作

ls 列出当前目录下的子目录和文件 ls -a 列出当前目录下的所有内容(包括以.开头的隐藏文件) ls [目录名] 列出指定目录下的子目录和文件 ls -l 或 ll 以列表的形式列出当前目录下子目录和文件的详细信息 pwd 显示当前所在目录的路径 ctrll 清屏 cd…

突破极限!DNF全新装备现身,无限高达流震撼登场

近期,DNF(地下城与勇士)的玩家们热议的话题不再是往日的副本攻略或职业平衡,而是一件神秘的装备——“玉化腰带”。这个装备的出现,引发了一股前所未有的热潮,因为它带来的无限高达流,彻底颠覆了…

0325 ISP流程介绍 RAW/RGB/YUV

学习参考:isp流程介绍(yuv格式阶段) - 知乎 一、RAW 1.1 ISP RAW 之DPC DPC(defective pixel correction)也就是坏点矫正,在sensor接收光信号,并做光电转换之后。 这一步设计的意义在于:摄像头sensor的感光元件通常很多会存在一些工艺缺陷…

硅谷甄选项目笔记

硅谷甄选运营平台 此次教学课程为硅谷甄选运营平台项目,包含运营平台项目模板从0到1开发,以及数据大屏幕、权限等业务。 此次教学课程涉及到技术栈包含***:vue3TypeScriptvue-routerpiniaelement-plusaxiosecharts***等技术栈。 一、vue3组件通信方式 通信仓库地…

maven下载与使用

Maven介绍 Maven 是一款为 Java 项目构建管理、依赖管理的工具(软件),使用 Maven 可以自动化构建、测试、打包和发布项目,大大提高了开发效率和质量。 总结:Maven就是一个软件,掌握软件安装、配置、以及基…

苍穹外卖项目笔记

软件开发流程 需求分析:说明书和原型 设计:UI,数据库,接口设计 编码:项目代码,单元测试 测试:测试用例,测试报告 上线运维:软件环境安装,配置 软件环境…

《大模型技术要求标准》重磅发布,九章云极DataCanvas公司助力我国大模型技术发展

近日,中国信息通信研究院(简称“中国信通院”)重磅发布**《人工智能开发平台通用能力要求 第4部分:大模型技术要求》**(以下简称“大模型技术要求标准”),九章云极DataCanvas公司依托在大模型领…

Linux收到一个网络包是怎么处理的?

目录 摘要 ​编辑 1 从网卡开始 2 硬中断,有点短 2.1 Game Over 3 接力——软中断 3.1 NET_RX_SOFTIRQ 软中断的开始 3.2 数据包到了协议栈 3.3 网络层处理 3.4 传输层处理 4 应用层的处理 5 总结 摘要 一个网络包的接收始于网卡,经层层协议栈…

Samtec应用漫谈 | EVSE基础设施的前进道路

【摘要/前言】 能源的格局正在不断变化,燃油价格正在上升,因此越来越多的人正在考虑个人电动车的经济性。 此外,全球物流公司正在更多地采用电动面包车和卡车进行物流运输。大街小巷中,大家可以看到各式新能源汽车,电…

闻了刚脱下的袜子导致肺部感染真菌?后果很严重

之前有网友分享自己只因闻了刚脱下的袜子,就导致了肺部感染真菌的经历,引发众多网友的关注与热议。 那么,臭袜子又怎么会和肺部感染有关系呢?臭袜子为什么不能闻呢?袜子上面到底有什么有危险的成分呢? 图源…

10-shell编程-辅助功能

一、字体颜色设置 第一种: \E[1:色号m需要变色的字符串\E[0m 第二种: \033[1:色号m需要变色的字符串\033[0m ########################### \E或者\033 #开启颜色功能 [1: #效果 31m #颜色色号 \E[0m #结束符 1,颜色案例 2,效果案例 二、gui&am…

应急响应实战笔记04Windows实战篇(1)

第1篇:FTP暴力破解 0x00 前言 ​ FTP是一个文件传输协议,用户通过FTP可从客户机程序向远程主机上传或下载文件,常用于网站代码维护、日常源码备份等。如果攻击者通过FTP匿名访问或者弱口令获取FTP权限,可直接上传webshell&#…

C语言学习 五、一维数组与字符数组

5.1一维数组 5.1.1数组的定义 数组特点: 具有相同的数据类型使用过程中需要保存原始数据 C语言为了方便操作这些数据,提供了一种构造数据类型——数组,数组是指一组具有相同数据类型的数据的有序集合。 一维数组的定义格式为 数据类型 数…

删除数组中的指定元素(了解如何删除数组中的指定元素,并返回一个新的数组,看这一篇就足够了!)

前言:有时候我们会遇到要在数组中删除指定元素,但是不能创建新的数组,那么这个时候应该如何操作呢? ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要了解更多内容可以访问我的主页秋刀鱼不做梦-CSDN博客 废话不多讲,让我们…

阿里 Modelscope 创空间部署在本地环境操作文档

创建创空间的步骤直接跳过。 备注:我的电脑是Windows 第一步:获取创空间代码,直接下载代码太慢了,建议通过git获取代码 第二步:复制链接,打开cmd 直接粘贴回车下载。下载完之后的到了我的Service-Assistant文件夹。再git clone https://gith…

幻尔机械臂FPV安装darknet_ros(YOLO V3)

mkdir -p catkin_workspace/src cd catkin_workspace/src git clone --recursive gitgithub.com:leggedrobotics/darknet_ros.git cd ../ 在ROS工作空间目录下,执行命令: catkin_make -DCMAKE_BUILD_TYPERelease 发布摄像头图像话题: …

电商API数据采集接口——电商大数据构建及智能应用

现在越来越多的电商企业和运营都开始关注数据的应用,在13年淘宝运营技巧的爆发,这其实就是数据带来的红利。在数据大爆炸的时代,数据分析已经成为了企业制定策略、发现问题的重要方法,所以,数据分析绝对是企业管理的贤…