用PyTorch轻松实现二分类:逻辑回归入门

news2025/1/11 15:07:30

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦什么是逻辑回归?
  • 🥦分类问题
  • 🥦交叉熵
  • 🥦代码实现
  • 🥦总结

🥦引言

当谈到机器学习和深度学习时,逻辑回归是一个非常重要的算法,它通常用于二分类问题。在这篇博客中,我们将使用PyTorch来实现逻辑回归。PyTorch是一个流行的深度学习框架,它提供了强大的工具来构建和训练神经网络,适用于各种机器学习任务。

在机器学习中已经使用了sklearn库介绍过逻辑回归,这里重点使用pytorch这个深度学习框架

🥦什么是逻辑回归?

我们首先来回顾一下什么是逻辑回归?

逻辑回归是一种用于二分类问题的监督学习算法。它的主要思想是通过一个S形曲线(通常是Sigmoid函数)将输入特征映射到0和1之间的概率值,然后根据这些概率值进行分类决策。在逻辑回归中,我们使用一个线性模型和一个激活函数来实现这个映射。

🥦分类问题

这里以MINIST Dataset手写数字集为例
在这里插入图片描述

这个数据集中包含了6w个训练集1w个测试集,类别10个
这里我们不再向之前线性回归那样,根据属于判断具体的数值大小;而是根据输入的值判断从0-9每个数字的概率大小记为p(0)、p(1)…而且十个概率值和为1,我们的目标就是根据输入得到这十个分类对于输入的每一个的概率值,哪个大就是我们需要的。

这里介绍一下与torch相关联的库—torchvision
torchvision:

  • “torchvision” 是一个PyTorch的附加库,专门用于处理图像和视觉任务。
    它包含了一系列用于数据加载、数据增强、计算机视觉任务(如图像分类、目标检测等)的工具和数据集。
  • “torchvision” 提供了许多预训练的视觉模型(例如,ResNet、VGG、AlexNet等),可以用于迁移学习或作为基准模型。
    此外,它还包括了用于图像预处理、转换和可视化的函数。

上图已经清楚的显示了,这个库包含了一些自带的数据集,但是并不是我们安装完这个库就有了,而且需要进行调用的,类似在线下载,root指定下载的路径,train表示你需要训练集还是测试集,通常情况下就是两个一个训练,一个测试,download就是判断你下没下载,下载了就是摆设,没下载就给你下载了

我们再来看一个数据集(CIFAR-10)
在这里插入图片描述
包含了5w训练样本,1w测试样本,10类。调用方式与上一个类似。

接下来我们从一张图更加直观的查看分类和回归
在这里插入图片描述

左边的是回归,右边的是分类


在这里插入图片描述

过去我们使用回归例如 y ^ \hat{y} y^=wx+b∈R,这是属于一个实数的;但是在分类问题, y ^ \hat{y} y^∈[0,1]
这说明我们需要寻找一个函数,将原本实数的值经过函数的映射转化为[0,1]之间。这里我们引入Logistic函数,使用极限很清楚的得出x趋向于正无穷的时候函数为1,x趋向于负无穷的时候,函数为0,x=0的时候,函数为0.5,当我们计算的时候将 y ^ \hat{y} y^带入这样就会出现一个0到1的概率了。

下图展示一些其他的Sigmoid函数
在这里插入图片描述

🥦交叉熵

过去我们所使用的损失函数普遍都是MSE,这里引入一个新的损失函数—交叉熵

==交叉熵(Cross-Entropy)==是一种用于衡量两个概率分布之间差异的数学方法,常用于机器学习和深度学习中,特别是在分类问题中。它是一个非常重要的损失函数,用于衡量模型的预测与真实标签之间的差异,从而帮助优化模型参数。

在交叉熵的上下文中,通常有两个概率分布:

  • 真实分布(True Distribution): 这是指问题的实际概率分布,表示样本的真实标签分布。通常用 p ( x ) p(x) p(x)表示,其中 x x x表示样本或类别。

  • 预测分布(Predicted Distribution): 这是指模型的预测概率分布,表示模型对每个类别的预测概率。通常用 q ( x ) q(x) q(x)表示,其中 x x x表示样本或类别。

交叉熵的一般定义如下:
在这里插入图片描述其中, H ( p , q ) H(p, q) H(p,q) 表示真实分布 p p p 和预测分布 q q q 之间的交叉熵。

交叉熵的主要特点和用途包括:

  • 度量差异性: 交叉熵度量了真实分布和预测分布之间的差异。当两个分布相似时,交叉熵较小;当它们之间的差异增大时,交叉熵增大。

  • 损失函数: 在机器学习中,交叉熵通常用作损失函数,用于衡量模型的预测与真实标签之间的差异。在分类任务中,通常使用交叉熵作为模型的损失函数,帮助模型优化参数以提高分类性能。

  • 反向传播: 交叉熵在训练神经网络时非常有用。通过计算交叉熵的梯度,可以使用反向传播算法来调整神经网络的权重,从而使模型的预测更接近真实标签。

在分类问题中,常见的交叉熵损失函数包括二元交叉熵(Binary Cross-Entropy)和多元交叉熵(Categorical Cross-Entropy)。二元交叉熵用于二分类问题,多元交叉熵用于多类别分类问题。

刘二大人的PPT中也介绍了
在这里插入图片描述
右边的表格中每组y与 y ^ \hat{y} y^对应的BCE,BCE越高说明越可能,最后将其求均值

🥦代码实现

在这里插入图片描述

根据上图可知,线性回归和逻辑回归的流程与函数只区别于Sigmoid函数
在这里插入图片描述
这里就是BCEloss的调用,里面的参数代表求不求均值

完整代码如下

import torch.nn.functional as F
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
class LogisticRegressionModel(torch.nn.Module):
	def __init__(self):
		super(LogisticRegressionModel, self).__init__() 
		self.linear = torch.nn.Linear(1, 1)
	def forward(self, x):
		y_pred = F.sigmoid(self.linear(x))
		return y_pred
model = LogisticRegressionModel() 
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  
for epoch in range(1000):
	y_pred = model(x_data)
	loss = criterion(y_pred, y_data)
	print(epoch, loss.item())
	optimizer.zero_grad() 
	loss.backward()
	optimizer.step()

最后绘制一下

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))  # 相当于reshape
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r') 
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

运行结果如下
在这里插入图片描述

🥦总结

这就是使用PyTorch实现逻辑回归的基本步骤。逻辑回归是一个简单但非常有用的算法,可用于各种分类问题。希望这篇博客能帮助你开始使用PyTorch构建自己的逻辑回归模型。如果你想进一步扩展你的知识,可以尝试在更大的数据集上训练模型或探索其他深度学习算法。祝你好运!

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1071602.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用ffmpeg删除视频的音轨,让视频静音

ffmpeg -i ~/video/video.mp4 -an -vcodec copy ~/video/muteVideo.mp4 删除以后我们查看muteVideo的文件信息,只有一个Stream:video信息了。 再对比看一下video.mp4的信息,是有两个Stream信息,一个video,一个audio。…

gitlab登录出现的Invalid login or password问题

前提 我是在一个项目里创建的gitlab账号,想在别的项目里登录或者官网登录发现怎么都登陆不上 原因 在GitLab中,有两种不同的账号类型:项目账号和个人账号(官网账号)。 项目账号:项目账号是在特定GitLab…

竞赛 深度学习 opencv python 实现中国交通标志识别

文章目录 0 前言1 yolov5实现中国交通标志检测2.算法原理2.1 算法简介2.2网络架构2.3 关键代码 3 数据集处理3.1 VOC格式介绍3.2 将中国交通标志检测数据集CCTSDB数据转换成VOC数据格式3.3 手动标注数据集 4 模型训练5 实现效果5.1 视频效果 6 最后 0 前言 🔥 优质…

竞赛选题 深度学习 python opencv 火焰检测识别 火灾检测

文章目录 0 前言1 基于YOLO的火焰检测与识别2 课题背景3 卷积神经网络3.1 卷积层3.2 池化层3.3 激活函数:3.4 全连接层3.5 使用tensorflow中keras模块实现卷积神经网络 4 YOLOV54.1 网络架构图4.2 输入端4.3 基准网络4.4 Neck网络4.5 Head输出层 5 数据集准备5.1 数…

vscode package.json文件开头的{总是提升警告

警告如下 Problems loading reference https://json.schemastore.org/stylelintrc.json: Unable to load schema from https://json.schemastore.org/stylelintrc.json: read ECONNRESET. 解决如下 在设置(settings.json)里 新增一条属性 "ht…

0501 货仓选址 【中位数 距离和的最小值】

0501 货仓选址 【中位数 距离和的最小值】 描述 在一条数轴上有N家商店,它们的坐标分别为 A[1]~A[N]。现在需要在数轴上建立一家货仓,每天清晨,从货仓到每家商店都要运送一车商品。为了提高效率,求把货仓建在何处,可以…

git+码云提交PR流程记录

前提条件:注册码云账号,本地安装git 如果不知道怎么注册和安装,可以参考gitgitee入门教程(https://bbs.huaweicloud.com/forum/thread-55222-1-1.html) 登录自己的码云账号 登陆了之后,在码云上打开目标项…

k8s-8 ingress-nginx

nodeport 默认端口 nodeport默认端口是30000-32767,超出会报错 添加如下参数,端口范围可以自定义 externalname ingress-nginx 通过一个外部的vip 地址 访问到集群内的多个service 一种全局的、为了代理不同后端 Service 而设置的负载均衡服务&…

照片处理软件Lightroom Classic mac中文版功能介绍(Lrc2021)

Lightroom Classic 2022 mac是一款桌面编辑工具,lrc2021 mac包括提亮颜色、使灰暗的摄影更加生动、删除瑕疵、将弯曲的画面拉直等。您可以在电脑桌面上轻松整理所有照片。使用Lightroom Classic, 轻松整理编辑照片,为您的作品锦上添花。 Ligh…

Vega Prime入门教程11:软件界面

本文首发于:Vega Prime入门教程11:软件界面 Vega Prime工具包中,包含了一个重要的编辑器Lynx prime(以后简称LP),它为VP提供一个人机交互界面。 启动 打开桌面上的快捷方式 软件会自动打开模板工程 界面构成 LynX Prime用户界…

阿加犀AI应用案例征集活动 持续进行中!

当下,人工智能正经历着迅猛的技术进步和广泛的应用拓展,边缘端计算运行也成为了一个重要的趋势。边缘计算通过降低延迟、节省带宽、增强隐私保护、提高系统可靠性等特性,为AI和IoT应用提供了强大的支持,使得智能应用更加灵活、高效…

通信与网络及软件工具的使用心得与记录

在当今的信息时代,通信工程和网络工具已经成为我们工作和生活中不可或缺的一部分。为了更好地利用这些工具,我们需要了解它们的基本原理和使用方法。本文将为您详细介绍一些重要的通信工程和网络工具,以及它们在实际应用中的使用心得和笔记。…

阶段六-Day01-Linux入门

一、 Linux简介 1. 概念 Linux是一款操作系统。和Windows操作系统类似。 2. Linux操作系统的优势 2.1 稳定性 Linux采取了许多安全技术措施,其中有对读、写进行权限控制、审计跟踪、核心授权等技术,这些都为安全提供了保障。 据说Linux系统可以十年…

应用案例 | dataFEED OPC Suite为化工行业中的质量控制和成本节约提供数据集成方案

一 背景 在当今化工行业中,质量控制对于特种塑料供应商至关重要。一家国际性的特种塑料供应商在全球拥有五个生产基地,每个基地都运行着2-6台塑料挤出机。为了确保塑料质量,他们需要每两小时分析一次挤出样品——导致这项工作占用了较大的生…

如何向客户推广 API 商品数据接口,如何跟进项目和程序员对接?

一、了解 API商品数据接口 在推广 API 商品数据接口之前,首先需要了解它的基本概念、优势以及如何选择合适的接口。 1.API 商品数据接口的基本概念 API 是 Application Programming Interface 的缩写,即应用程序编程接口。API 商品数据接口是一种允许…

C++ 01.学习C++的意义-狄泰软件学院

一些历史 UNIX操作系统诞生之初是用汇编语言编写的随着UNIX系统的发展,汇编语言的开发效率成为瓶颈,所以需要一个新的语言替代汇编语言1971年通过对B语言改良,使其能直接产生机器代码,C语言诞生UNIX使用C语言重写,同时…

力扣 -- 5. 最长回文子串

解题步骤&#xff1a; 参考代码&#xff1a; class Solution { public:string longestPalindrome(string s) {int ns.size();vector<vector<bool>> dp(n,vector<bool>(n));//最长回文串的起始位置int start0;//最长回文串的长度int len0;for(int in-1;i>…

前端开发转岗项目经理有什么建议吗?

前端开发转岗项目经理是一个不同领域的跨越&#xff0c;需要经历很多的学习和实践。当一个前端开发转岗项目经理时&#xff0c;需要做出一些改变&#xff0c;以适应新的角色和职责。在这篇文章中&#xff0c;我将分享一些建议&#xff0c;帮助前端开发转岗项目经理更好地适应新…

TensorFlow入门(十一、图的基本操作)

建立图 一个TensorFlow程序默认是建立一个图的,除了系统自动建图以外,还可以用tf.Graph()手动建立,并做一些其他的操作 如果想要获得程序一开始默认的图,可以使用tf.get_default_graph()函数 如果想要重新建立一张图代替原来的图,可以使用tf.reset_default_graph()函数 注意:在…

PL/SQL拉链表

练习:-- 拉链表练习: 维度表源表 ID M_NAME REST UP_DATE 1 车贷 0.01 2022/12/1 2 房贷 0.03 2022/12/1 3 经营贷 0.015 2022/12/1 维度表拉链表 ID M_NAME REST BEGIN_DATE END_DATE 1 车贷 …