Lecture6 逻辑斯蒂回归(Logistic Regression)

news2024/10/1 5:33:46

目录

1 常用数据集

1.1 MNIST数据集

1.2 CIFAR-10数据集

2 课堂内容

2.1 回归任务和分类任务的区别

2.2 为什么使用逻辑斯蒂回归

2.3 什么是逻辑斯蒂回归

2.4 Sigmoid函数和饱和函数的概念

2.5 逻辑斯蒂回归模型

2.6 逻辑斯蒂回归损失函数

2.6.1 二分类损失函数

2.6.2 小批量二分类损失函数

3 代码实现


1 常用数据集

1.1 MNIST数据集

MNIST是一个手写数字图像数据集,主要用于训练和测试机器学习模型。它由60,000个训练图像和10,000个测试图像组成,每个图像都是28x28像素的灰度图像,表示一个手写数字。MNIST数据集已成为许多机器学习算法的基准数据集之一,尤其是用于图像分类任务和数字识别任务。

图1 MNIST数据集

下载方式

import torchvision
train_set = torchvision.datasets.MNIST(root='../dataset/mnist', train=True,  download=True)
test_set  = torchvision.datasets.MNIST(root='../dataset/mnist', train=False, download=True)

这段代码中,包括两个参数:root和train,download为可选参数。root指定数据集下载的根目录,train指定要加载的数据集类型,train=True表示加载训练集,train=False表示加载测试集。download=True表示如果本地不存在则进行该数据集的下载,最后将训练集和测试集分别保存在train_set和test_set两个变量中。

1.2 CIFAR-10数据集

图2 CIFAR-10数据集

 CIFAR-10是一个常用的图像分类数据集,由10个不同类别的60000个32x32彩色图像组成。这些类别包括:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。数据集被分为训练集和测试集,其中训练集包含50000个图像,测试集包含10000个图像。CIFAR-10数据集被广泛用于计算机视觉和深度学习研究领域,特别是用于图像分类算法的开发和测试。

import torchvision
train_set = torchvision.datasets.CIFAR10(…)
test_set  = torchvision.datasets.CIFAR10(…)

2 课堂内容

2.1 回归任务和分类任务的区别

  回归任务和分类任务是机器学习中两个重要的任务类型,其主要区别在于预测目标的不同。分类任务的目标是将输入数据分到预定义的类别中。例如,手写数字分类任务中,目标是将手写数字图像分到0-9十个数字中的一个类别中。回归任务的目标是预测连续的数值。例如,房价预测任务中,目标是预测房屋的售价,这是一个连续的数值。

  因此,回归任务和分类任务的主要区别在于预测目标的类型:分类任务的目标是预测一个离散的类别,而回归任务的目标是预测一个连续的数值。

2.2 为什么使用逻辑斯蒂回归

  对于分类问题,由于值是离散的,所以用逻辑斯蒂回归(logistic regression)。二分类问题中,我们需要求得结果 是 输入数据 所属的类别。例如下图,结果要么通过要么不通过,属于典型的二分类(binary classification)问题。

图3 二分类问题

  二分类问题是机器学习中最常见的问题之一,应用广泛,例如判断一封邮件是否是垃圾邮件、判断一个人是否患有某种疾病等。

2.3 什么是逻辑斯蒂回归

  在二分类问题中,我们衡量输入的数据运算过后的所属类别的方法,一般是通过概率来表示,我们得到对应类别的概率值,哪个概率值大,就认为它是属于哪个类别。

  Logistic回归是一种用于二分类问题的线性分类模型。它的基本思想是,将输入特征和对应的类别之间的关系建模为一个线性函数,并通过Sigmoid函数将其映射到[0,1]的区间上,以得到对应于正例的概率,所以可以说sigmoid函数的作用是使数据映射到[0,1]区间上。

2.4 Sigmoid函数和饱和函数的概念

  在数学中,饱和函数(Saturated Function)是一类函数,当输入值接近正或负无穷时,函数的输出值趋向于一个有限的上下限。饱和函数通常用于神经网络的激活函数,例如Sigmoid函数

  当输入值x的绝对值很大时,Sigmoid函数的输出值会趋近于0或1,因此称为“饱和函数”。

图4 Sigmoid函数

图5 Sigmoid函数图像

  Sigmoid 函数在输入为负无穷时,输出为 0,在输入为正无穷时,输出为 1,因此可以将预测结果解释为概率值,并进行阈值分类。

  在神经网络中,饱和函数作为激活函数具有平滑、可导的性质,并能够将输出值限制在一定范围内,使神经网络的训练更加稳定。但是,由于饱和函数在梯度计算时可能会出现梯度消失的问题,因此在深度神经网络中,更常使用一些非饱和函数作为激活函数,例如ReLU函数。

常见的Sigmoid函数还有:

图6 常见的Sigmoid函数

2.5 逻辑斯蒂回归模型

添加完logistic函数后的模型相较于普通的线性回归模型的变化:

图7 线性回归与逻辑斯蒂回归的模型区别

2.6 逻辑斯蒂回归损失函数

2.6.1 二分类损失函数

  二分类损失函数(Binary Cross Entropy, BCE)较线性回归的损失函数也发生了些许变化,主要是引入了交叉熵(Cross Entropy)这个概念。交叉熵是在给定一组真实标签和一组预测标签的情况下,衡量这两组标签之间的差异的一种方法。在机器学习中,交叉熵通常被用来作为损失函数,用于优化分类模型的参数。

交叉熵计算两个分布差异的公式:

图8 交叉熵计算公式

  二分类交叉熵是指特定于二元分类任务而设计的交叉熵损失函数。对于二元分类问题,二分类损失函数可以表示为:

图9 二分类损失函数

  其中,y是真实标签(0或1),ŷ是模型的预测标签(介于0和1之间的概率值)。这个公式中的第一项是如果y=1时的损失,第二项是如果y=0时的损失。

BCE和MSE的区别?

  BCE主要应用于二元分类问题,它的损失函数形式简单,可以直接衡量模型对于每个样本预测出的概率值和真实标签之间的差距。而 MSE 更适合用于回归问题,它可以衡量模型对于每个样本预测出的数值和真实值之间的差距。 

2.6.2 小批量二分类损失函数

  在实际应用中,我们通常需要对大规模数据集进行训练。如果使用全量数据计算梯度,会占用过多的内存和计算资源,从而导致训练速度缓慢或者无法完成训练。而使用小批量的BCELoss,可以将数据集分批次读入内存,逐个小批量地计算损失函数,进而计算梯度,从而加快训练速度,同时还能够有效降低内存使用和计算复杂度。

图10 小批量二分类损失函数

图11 小批量二分类损失函数的应用

由上图可以看到,如果模型的预测值越接近真实值,交叉熵的损失就越小。因此,通过最小化交叉熵损失,我们可以训练一个能够准确分类数据的模型。

3 代码实现

1、这段代码定义了一个 logistic 回归模型类 LogisticRegressionModel,继承了 Module 类。该模型使用单个线性层(Linear)来预测输入 x 的输出。

class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = F.sigmoid(self.linear(x)) # sigmoid 函数将输入值压缩到 [0,1] 的区间内,可以将线性层的输出转换为概率值,用于二分类问题的预测。
        return y_pred

2、定义一个二分类交叉熵损失函数对象,并将其赋值给变量criterion:

criterion = torch.nn.BCELoss (size_average=False)

在PyTorch中,torch.nn.BCELoss是二分类交叉熵损失函数的实现,它用于度量模型输出与真实标签之间的差异,其返回值即为模型的损失值。size_average参数表示是否对每个batch的损失值求平均,默认为True,如果设为False,则不求平均,返回的是每个batch的总和。因为我们一般在训练神经网络时,采用小批量梯度下降法,因此需要对每个小批量的损失值求平均。

输出图像

import numpy as np
import matplotlib.pyplot as plt

'''这里使用NumPy的linspace函数在0到10之间生成了200个等间距的数字,
并将其转换为PyTorch张量,方便后续计算。'''
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))

y_t = model(x_t)

'''这里使用PyTorch张量的data属性将其转换为NumPy数组,
并使用matplotlib库的plot函数绘制出曲线。'''
y = y_t.data.numpy()
plt.plot(x, y)

plt.plot([0, 10], [0.5, 0.5], c='r') # 这里绘制了红色的水平分界线,表示y值为0.5时的x轴取值范围
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

完整代码

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
#-------------------------------------------------------#
class LogisticRegressionModel(torch.nn.Module):
def __init__(self):
super(LogisticRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
y_pred = F.sigmoid(self.linear(x))
return y_pred
model = LogisticRegressionModel()
#-------------------------------------------------------#
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#-------------------------------------------------------#
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()

使用逻辑斯蒂回归处理二分类问题的整个过程:

标题

官方文档链接:https://pytorch.org/docs/stable/nn.html?highlight=bceloss#torch.nn.BCELoss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/370269.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3-1 图文并茂说明raid0,raid1, raid10, raid01, raid5等原理

文章目录简介RAID类型RAID0RAID1RAID5RAID6RAID10RAID01RAID对比图简介 一、RAID 是什么? RAID ( Redundant Array of Independent Disks )即独立磁盘冗余阵列,简称为「磁盘阵列」,其实就是用多个独立的磁盘组成在一起…

Jenkins第一讲

目录 一、Jenkins 1.1 敏捷开发与持续集成 1.1.1 敏捷开发 1.1.2 持续集成 1.2 持续集成工具 1.2.1 jenkins和hudson 1.2.2 技术组合 1.2.3 部署方式对比 1.3 安装Jenkins 1.3.1 下载Jenkins的war包 1.3.2 开启Jenkins 1.4 Jenkins全局安全配置 1.5 使用Jenkins部…

InfluxDB docker安装与界面的使用

influxdb github主页:https://github.com/influxdata/influxdb chronograf github主页:https://github.com/influxdata/chronograf Docker安装InfluxDB docker run -p 8086:8086 --name influxdb-dev influxdb:latest这里博主安装的是2.2.1版本 然后…

Python学习-----排序问题2.0(sort()函数和sorted()函数)

目录 前言: 1.sort() 函数 示例1:阿斯克码比较 示例2:(设置reverse,由大到小排序) 示例3:基于key排序(传入一个参数) 示例4:key的其他应用 2.sorted() …

平时技术积累很少,面试时又会问很多这个难题怎么破?别慌,没事看看这份Java面试指南,解决你的小烦恼!

前言技术面试是每个程序员都需要去经历的事情,随着行业的发展,新技术的不断迭代,技术面试的难度也越来越高,但是对于大多数程序员来说,工作的主要内容只是去实现各种业务逻辑,涉及的技术难度并不高&#xf…

Allegro如何画Photoplot_Outline操作指导

Allegro如何画Photoplot_Outline操作指导 在用Allegro进行PCB设计的时候,最后进行光绘输出前,Photoplot_Outline是必备一个图形,所有在Photoplot_Outline中的图形将被输出,Photoplot_Outline以外的图形都将不被输出。 如何绘制Photoplot_Outline,具体操作如下 点击Shape点…

视觉人培训团队把它称之为,工业领域人类最伟大的软件创造,它的名字叫Halcon

目前为止,世界上综合能力强大的机器视觉软件,,它的名字叫Halcon。 视觉人培训团队把它称之为,工业领域人类最伟大的软件创造,它的名字叫Halcon。 持续不断更新最新的图像技术,软件综合能力持续提升。 综…

常量和变量——“Python”

各位CSDN的uu们你们好呀,今天,小雅兰的内容是Python的一些基础语法噢,会讲解一些常量和变量的知识点,那么,现在就让我们进入Python的世界吧 常量和表达式 变量和类型 变量是什么 变量的语法 变量的类型 常量和表达式 …

go面向对象思想封装继承多态

go貌似都没有听说过继承,当然这个继承不像c中通过class类的方式去继承,还是通过struct的方式,所以go严格来说不是面向对象编程的语言,c和java才是,不过还是可以基于自身的一些的特性实现面向对象的功能,面向…

TCP 的演化史-byte stream 和 packet

不想写太多代码,我想直接抄一个 TCP sack 实现,参考了 lwIP TCP,很遗憾:TCP: Implement handling received SACKs 无奈不得不自己实现 sack option 的处理。由于 tso/gso/lro/gro,在软件层面难免遇到下面的情况&#…

Java 如何学习?这份5000页Java学习手册值得拥有,适合零基础自学也适合查漏补缺!

学习技巧 在以前大部分人学习都是先去找本书,先看看,再试,要是不懂了在去网上去查,再在继续啃着书本。但现在向书学习和在网上学习这掌握的效果是不同的,要学会用适合自己的学习方式。 目前的学习要是能看进去书本&a…

【5】linux命令每日分享——touch创建文件

大家好,这里是sdust-vrlab,Linux是一种免费使用和自由传播的类UNIX操作系统,Linux的基本思想有两点:一切都是文件;每个文件都有确定的用途;linux涉及到IT行业的方方面面,在我们日常的学习中&…

飞桨 Tensor 介绍

Tensor 介绍 一、Tensor 的概念介绍 飞桨使用张量(Tensor) 来表示神经网络中传递的数据,Tensor 可以理解为多维数组,类似于 Numpy 数组(ndarray) 的概念。与 Numpy 数组相比,Tensor 除了支持运…

C语言 深度剖析数据在内存中的存储

目录数据类型详细介绍整形在内存中的存储:原码,反码,补码大小端字节序介绍及判断浮点型在内存中的存储解析数据类型详细介绍整形:1.为什么char类型也会归类到整形家族当中去呢?字符存储和表示的时候本质上使用的是ASCI…

【华为OD机试模拟题】用 C++ 实现 - 最大相连男生数(2023.Q1)

最近更新的博客 【华为OD机试模拟题】用 C++ 实现 - 货币单位换算(2023.Q1) 【华为OD机试模拟题】用 C++ 实现 - 选座位(2023.Q1) 【华为OD机试模拟题】用 C++ 实现 - 停车场最大距离(2023.Q1) 【华为OD机试模拟题】用 C++ 实现 - 重组字符串(2023.Q1) 【华为OD机试模…

integrationobjects/OPC AE Client ActiveX Crack

使用 OPC AE 客户端 ActiveX 进行快速 OPC 警报和事件客户端编程! OPC AE Client ActiveX包括多个 OPC ActiveX 控件,可以轻松嵌入到最流行的 OLE 容器中。这允许用户与任何 OPC AE 服务器连接并实时检索警报和事件。 这种易于使用的 OPC AE ActiveX 简化…

论文笔记|固定效应的解释和使用

DeHaan E. Using and interpreting fixed effects models[J]. Available at SSRN 3699777, 2021. 虽然固定效应在金融经济学研究中无处不在,但许多研究人员对作用的了解有限。这篇论文解释了固定效应如何消除遗漏变量偏差并影响标准误差,并讨论了使用固…

【C语言进阶】文件的顺序读写、随机读写、文本文件和二进制文件、文件读取结束的判定以及文件缓冲区相关知识

​ ​📝个人主页:Sherry的成长之路 🏠学习社区:Sherry的成长之路(个人社区) 📖专栏链接:C语言进阶 🎯长路漫漫浩浩,万事皆有期待 文章目录1.文件操作1.1 概述…

优思学院:《改变世界的机器・精益生产之道》是什么著作?

《改变世界的机器》(The Machine That Changed the World)是一本经典的商业管理书籍,由詹姆斯P温斯顿(James P. Womack)、丹尼尔T琼斯(Daniel T. Jones)和丹尼尔罗斯(Daniel Roos&am…

带组态物联网平台源码 代码开源可二次开发 web MQTT Modbus

物联网IOT平台开发辅助文档 技术栈:JAVA [ springmvc / spring / mybatis ] 、Mysql 、Html 、 Jquery 、css 使用协议和优势: TCP/IP、HTTP、MQTT 通讯协议 1.1系统简介 IOT通用物联网系统平台带组态,是一套面向通用型业务数据处理的系统…