Pytorch深度学习实践(5)逻辑回归

news2025/1/11 1:17:15

逻辑回归

逻辑回归主要是解决分类问题

  • 回归任务:结果是一个连续的实数
  • 分类任务:结果是一个离散的值

分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 09),因为各个类别之间没有大小之差。

因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果

下载MNIST数据集

import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
  • 通过train参数来指定训练集和测试集

逻辑回归

将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试

x(hours)y(pass/fail)
10(fail)
20(fail)
31(pass)
4?

其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1

当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定的值

对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率

s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述

S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:

  • 饱和函数
  • 单调递增
  • 有极限

当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(xω+b)

损失函数

线性回归使用的损失函数是计算预测值和真实值之差

而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 01分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=(ylogy^+(1y)log(1y^))
即我们比较的是分布之间的差异

交叉熵 c r o s s − e n t r o p y cross-entropy crossentropy

存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x) P D 2 ( x ) P_{D2}(x) PD2(x)

两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) i=1nPD1(xi)lnPD2(xi) 来衡量

上述公式越大时,两个分布的差异越小

模型的改变

模型构造的改变
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
        # 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可
        # 在这里不需要实例化
    def forward(self, x):
        y_pred = F.sigmoid(self.linear(x))
        return y_pred

需要先将输入写入到linear()线性模型中,再使用Sigmoid()函数

模型损失函数的改变

使用交叉熵函数BCELoss

criterion = torch.nn.BCELoss(size_average=False)

整体代码

import torch
import matplotlib.pyplot as plt
import numpy as np

########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])

########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
        # 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可
        # 在这里不需要实例化
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = LogisticRegressionModel()

########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

########## 模型训练 ##########
for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy()  # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1950936.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python毕业设计选题协同过滤算法在音乐推荐系统

✌网站介绍:✌10年项目辅导经验、专注于计算机技术领域学生项目实战辅导。 ✌服务范围:Java(SpringBoo/SSM)、Python、PHP、Nodejs、爬虫、数据可视化、小程序、安卓app、大数据等设计与开发。 ✌服务内容:免费功能设计、免费提供开题答辩P…

【进程检测】使用pywin32捕获window进程信息

需求 检测win系统依赖服务进程的运行情况,版本信息(进程检测器)检测内外网连接情况 实现 进程检测 # 使用pywin32获取进程版本信息 def get_version_info(path):try:info GetFileVersionInfo(path, \\)ms info[FileVersionMS]ls info[…

基于单片机控制的气动机械手设计

摘 要: 机械手拥有灵活的运动结构,可以在控制系统控制下完成复杂的运动,从而实现高效率的自动化生产方式,因而成为发展工业生产技术的重要方向。气动技术和单片机技术已相当成熟,工业应用广泛,该文将基于单…

使用 useRequestURL 组合函数访问请求URL

title: 使用 useRequestURL 组合函数访问请求URL date: 2024/7/26 updated: 2024/7/26 author: cmdragon excerpt: 摘要:本文介绍了Nuxt 3中的useRequestURL组合函数,用于在服务器端和客户端环境中获取当前页面的URL信息。通过示例展示了如何在页面中…

html+css 实现水波纹按钮

前言:哈喽,大家好,今天给大家分享htmlcss 绚丽效果!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 文…

Vue 3 实现左侧列表点击跳转滚动到右侧对应区域的功能

使用 Vue 3 实现左侧列表点击跳转到右侧对应区域的功能 1. 引言 在这篇博客中,我们将展示如何使用 Vue 3 实现一个简单的页面布局,其中左侧是一个列表,点击列表项时,右侧会平滑滚动到对应的内容区域。这种布局在很多应用场景中都…

云计算实训15——shell脚本、变量、自动化安装脚本、条件判断、循环

一、shell 脚本 1.基本概念 shell脚本就是由Shell命令组成的执行文件,将一些命令整合到一个文件 中,进行处理业务逻辑,脚本不用编译即可运行,它从一定程度上减轻 了工作量,提高了工作效率,还可以批量、定…

云服务部署项目(Spring + Vue)

云计算:腾讯云 操作系统:Ubuntu 22.04.4 LTS 项目:若依前后端分离项目(SpringBoot Vue) 首先要安装基本的一些依赖环境,大家可以看一下我往期的文章: Ubuntu在线JDK Ubuntu在线安装Nginx Ubunt…

Debain安装PostgreSql

目录 Debian和Centos区别 安装PostgreSql 更新包索引: 安装 PostgreSQL: 配置自动启动和启用 PostgreSQL 服务: 配置postGreSql 切换到 PostgreSQL 用户: 访问 PostgreSQL Shell: 设置密码 退出 PostgreSQL …

【C++题解】1066. 字符图形2-星号直角

问题&#xff1a;1066. 字符图形2-星号直角 类型&#xff1a;嵌套循环、图形输出 题目描述&#xff1a; 打印字符图形。 输入&#xff1a; 一个整数&#xff08; 0<n<10 &#xff09;。 输出&#xff1a; 一个字符图形。 样例&#xff1a; 输入&#xff1a; 3输…

集成显卡和独立显卡有什么区别?

显卡&#xff08;Video card&#xff0c;Graphics card&#xff09;全称显示接口卡&#xff0c;又称显示适配器&#xff0c;是计算机最基本配置、最重要的配件之一。平时大家可能都听过集显和独显&#xff0c;那么这两个显卡类别都是什么意思呢&#xff1f; 集成显卡 集成显卡…

Unity UGUI 之 坐标转换

本文仅作学习笔记与交流&#xff0c;不作任何商业用途 本文包括但不限于unity官方手册&#xff0c;唐老狮&#xff0c;麦扣教程知识&#xff0c;引用会标记&#xff0c;如有不足还请斧正 本文在发布时间选用unity 2022.3.8稳定版本&#xff0c;请注意分别 前置知识&#xff1a;…

混凝土做的风机?风领新能源推出钢-超高性能混凝土体内预应力塔筒

随着全球能源结构的转型和技术的不断进步&#xff0c;陆上风电行业已正式迈入平价时代。这一里程碑式的转变不仅降低了风电项目的投资门槛&#xff0c;更激发了行业对于机组大型化的热情与追求。在这样的市场下&#xff0c;主流风机制造商纷纷推出10MW平台级别机组&#xff0c;…

VLAN通讯实验

目录 拓扑图 需求 需求分析 配置过程 1、手工配置 2、 使用DHCP获得IP地址信息 3、测试全网是否可达 拓扑图 需求 1、PC1、PC3属于VLAN 2 2、PC2、PC4属于VLAN 3 3、通过DHCP使得PC获取IP地址信息 4、全网可达 需求分析 1、先手工配置网段&#xff0c;VLAN 2为192.168.1…

杭州东网约车管理再出行方面取得的显著成效

随着科技的飞速发展&#xff0c;网约车已成为人们日常出行的重要选择。在杭州这座美丽的城市&#xff0c;网约车服务更是如雨后春笋般蓬勃发展。特别是杭州东站&#xff0c;作为杭州的重要交通枢纽&#xff0c;网约车管理显得尤为重要。近日&#xff0c;沧穹科技郑重宣告已助力…

【iOS】 消息传递和消息转发

消息传递和消息转发 objc_msgSend执行流程消息发送动态方法解析消息转发super面试题 objc_msgSend执行流程 OC中的方法调用&#xff0c;其实都是转化为objc_msgSend函数调用&#xff0c;objc_msgSend的执行流程可以分为3大阶段 消息发送 动态方法解析 消息转发 消息发送 消…

情绪稳定的人有什么特点?

第一部分&#xff1a;至纯之人&#xff0c;大器晚成 1.1 单纯&#xff0c;不是天真 你知道吗&#xff1f;那些能够成就大事的人&#xff0c;往往在人性上非常单纯。他们对外界的需求很低&#xff0c;更多的是向内寻求。这样的人&#xff0c;他们的内心世界像一片净土&#xff…

Internet Download Manager(IDM)6.43免费试用体验版下载

1. Internet Download Manager&#xff08;IDM&#xff09;是一款功能强大的下载管理器&#xff0c;能够将下载速度提高多达5倍。 2. IDM具备多线程下载、断点续传、自动捕获链接等特性。 IDM马丁正版下载如下: https://wm.makeding.com/iclk/?zoneid34275 idm最新绿色版一…

HarmonyOS 核心功能:一多特性

简述定义和目标&#xff0c;分述界面级&#xff08;核心&#xff09;、功能级、工程级的一多开发 定义&#xff1a; 一套代码&#xff0c;一次开发&#xff0c;多端按需部署 目标&#xff1a;支持开发者快速开发多设备上的应用 一 界面级一多开发&#xff08;重点&#xf…

宝塔单ip,新建多站点

报错如上&#xff1a; 那么如何新建多站点呢 先随便写个名字上去&#xff0c;然后再重新绑定别的端口… 这个时候访问99端口即可 。 如果是有域名&#xff0c;则不需要这样做 、直接80端口也可以多站点