《PyTorch深度学习实践》第六讲 逻辑斯蒂回归

news2025/1/12 18:57:42

b站刘二大人《PyTorch深度学习实践》课程第六讲逻辑斯蒂回归笔记与代码:https://www.bilibili.com/video/BV1Y7411d7Ys?p=6&vd_source=b17f113d28933824d753a0915d5e3a90


分类问题:

  • MNIST数据集:手写数字数据集;6万个训练样本,1万个测试样本,10个分类(0 ~ 9)
    • PyTorch框架中有一个配套的工具包torchvision,里面提供了一些流行的数据集,但是安装的时候并不会包含数据集,需要额外下载
image-20230630150805519
import torchvision

# root是数据集的保存位置,train表示是不是训练集,download表示是否要下载数据集
train_set = torchvision.datasets.MNIST(root='D:/pycharm_workspace/Liuer_lecturer/dataset/mnist', train=True, download=True)
test_set = torchvision.datasets.MNIST(root='D:/pycharm_workspace/Liuer_lecturer/dataset/mnist', train=False, download=True)
  • CIFAR-10数据集:5万个训练样本,1万个测试样本,10个分类
image-20230630153350033
import torchvision

train_set = torchvision.datasets.CIFAR10(...)
test_set = torchvision.datasets.CIFAR10(...)

回归任务和分类任务的区别

  • 之前的学习时间和分数的例子是回归任务,其中的 y y y是一个具体的数值
  • 分类任务的 y y y不再是具体分数,而是能否通过,fail / pass两个类别
    • 二分类 P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat{y} = 1) + P(\hat{y} = 0) = 1 P(y^=1)+P(y^=0)=1,某一个类别的概率
image-20230630160411477
  • 分类问题中,模型的输出是输入属于某个类别的概率

使用线性模型的时候,模型输出 y ^ = w x + b \hat{y} = wx + b y^=wx+b是实数,而现在所需要的分类输出是概率,即模型输出 y ^ ∈ [ 0 , 1 ] \hat{y} \in \left[0,1\right] y^[0,1]

那么就需要将线性模型的输出值由实数空间映射到0到1之间,逻辑斯蒂回归就是用于完成该映射任务

image-20230701124932795 image-20230701125011712

其他的sigmoid function

image-20230701125236539

Logistic Regression Model

image-20230701125530919

Loss Function for Binary Classification

  • 原来的线性回归中的Loss函数是计算两个实数之间的差值,即在一个数轴上计算 y y y y ^ \hat{y} y^的距离
  • 在二分类中输出的不再是一个数值,而是一个分布,表示 y ^ \hat{y} y^分类为1的概率是多少, 1 − y ^ 1-\hat{y} 1y^就是分类为0的概率
image-20230701125807388
  • 在上述问题中,所想要比较的是两个分布之间的差异

    • KL散度(相对熵):用来衡量两个分布之间的差异程度,若两者差异越小,KL散度越小,反之亦反。当两分布一致时,其KL散度为0

      • 衡量的是当用一个分布Q来拟合真实分布P时所需要的额外信息的平均量
      image-20230701130513583
      • https://zhuanlan.zhihu.com/p/365400000
      • KL散度通常用于无监督学习任务中,如聚类、降维和生成模型等。在这些任务中,我们没有相应的标签信息,因此无法使用交叉熵来评估模型的性能,所以需要一种方法来衡量模型预测的分布和真实分布之间的差异,这时就可以使用KL散度来衡量模型预测的分布和真实分布之间的差异。
    • 交叉熵:衡量了模型预测的概率分布与真实概率分布之间的差异,即模型在预测上的不确定性与真实情况的不确定性之间的差距

      • P(x)和Q(x)分别表示真实概率分布和模型预测的概率分布中事件x的概率
      image-20230701130634723
      • https://en.wikipedia.org/wiki/Cross_entropy
      • 在机器学习中,交叉熵通常用于衡量模型预测和真实标签之间的差异。例如,在分类任务中,交叉熵被用作损失函数,以衡量模型预测的类别分布和真实标签之间的差。
    • 区别:https://baijiahao.baidu.com/s?id=1763841223452070719&wfr=spider&for=pc

image-20230701132011149

目的是让 y y y y ^ \hat{y} y^之间的差异最小

上述函数也被称为BCE

image-20230701132304618

代码实现:

image-20230701132612959 image-20230701133000577
  • size_average设置是否求均值,这个会影响到学习率的设置
image-20230701133148397

完整代码:

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# 数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])


class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = F.sigmoid(self.linear(x))
        return y_pred


model = LogisticRegressionModel()

# criterion = torch.nn.MSELoss(size_average=False) pytorch更新后被弃用了
criterion = torch.nn.BCELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练过程
for epoch in range(1000):
    y_pred = model(x_data)              # 前馈:计算y_hat
    loss = criterion(y_pred, y_data)    # 前馈:计算损失
    print(epoch, loss.item())

    optimizer.zero_grad()   # 反馈:在反向传播开始将上一轮的梯度归零
    loss.backward()         # 反馈:反向传播(计算梯度)
    optimizer.step()        # 更新权重w和偏置b

# 输出权重和偏置
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# 测试模型
x = np.linspace(0, 10, 200)  # 每周学习时间从0 ~ 10小时采样200个点
x_t = torch.Tensor(x).view((200, 1))  # 将学习时间x转成200行1列的张量,view类似numpy中的reshape
y_t = model(x_t)    # 输给模型
y = y_t.data.numpy()    # 将y_t的数据拿出来

plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.ylabel('Hours')
plt.xlabel('Probability of Pass')
plt.grid()
plt.show()
image-20230701134131786

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/706650.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3+vite+ts视频背景酷炫登录模板【英雄联盟主题】

最近我准备在自己的网站上开发一个博客系统,首先要实现后台登录界面。我选择使用Vue 3 Vite TypeScript框架来构建,下面是针对该主题的详细说明: 在网页中使用视频作为背景图已经相当常见了,而且网上也有很多相关的插件可供使用…

QT Creator上位机学习(四)多线程操作

系列文章目录 文章目录 系列文章目录前言多线程操作多线程创建基本概念接口函数线程类的定义实例 线程同步基础互斥量的线程同步基于QReadWriteLock的线程同步基于QWaitCondition的线程同步基于信号量的线程同步 总结 前言 由于目前时间比较赶,同时还在学习FreeRTO…

ModaHub 魔搭社区:火山方舟是如何解决大模型互信问题的

火山方舟是一个全面的大模型服务平台,通过整合多个大模型公司的产品,为需要大模型的企业提供联系和选择的机会。它不仅提供相关工具和服务,还构建了大模型"安全互信计算架构",解决了大模型互信的问题。 这个安全互信计算…

【ArcGIS微课1000例】0069:用ArcGIS提取一条线的高程值

本实验讲解用ArcGIS软件,基于数字高程模型DEM提取一条线的高程值并导出。 文章目录 一、加载实验数据二、将线转为折点三、提取折点高程值四、导出高程值五、注意事项【相关阅读】:【GlobalMapper精品教程】060:用dem提取一条线的高程值 一、加载实验数据 本实验使用的数据…

AI创作与游戏开发(三)世界观设计

本文将从实践出发,全方位的在美术,程序,策划, 音乐方面使用AIGC进行游戏开发的辅助创作,来探索AI的上限。 写在前面 不管AI发展到什么地步,要记住一点的是。它只是工具,还是要以我为主,为我所…

Lake Shore475高斯计使用教程

475高斯计具有双排20字符真空荧光显示屏。在正常操作下,显示屏用来显示磁场读数和功能(最大、最小值、相对读数等)信息。另外也可以被配置为显示被测磁场温度和频率等信息。当设置高斯计参数或功能时,屏幕会显示操作提示和反馈信息…

华为云Could not connect to ‘121.37.92.110‘ (port 22): Connection failed.

今天在使用xshell连接服务器的时候,一直报错,爆的心态都炸了: 在输入主机和密码都正确的情况下,还是连接不上服务器: 后来经过长时间摸索,发现xshell软件要通过镜像系统来操作,而自己买的服务器…

走进人工智能|自动驾驶 迈向无人驾驶未来

前言: 自动驾驶是一种技术,通过使用传感器、人工智能和算法来使车辆能够在不需要人类干预的情况下自主地感知、决策和操作。 文章目录 序言背景核心技术支持传感器技术人工智能与机器学习 迈向无人驾驶未来目前形式领跑人困境和挑战 总结 本篇带你走进自…

【Mysql】X-DOC:Mysql数据库大量数据查询加速(定时JOB和存储过程应用案例)

X-DOC:Mysql数据库大量数据查询加速(定时JOB和存储过程应用案例) 1、案例背景2、解决思路3、实现方式3.1 开启定时调度功能3.2 创建JOB日志表3.3 创建JOB任务3.4 创建JOB3.5 JOB的维护及查看 4、总结 1、案例背景 在某中台系统中&#xff0c…

基于HTML5的手术室信息管理系统的设计与实现(源码+文档+数据库)

本文通过对现有手术室信息管理系统分析,设计了一套基于 HTML的手术室信息管理系统,实现了患者信息、手术记录及术后随访等功能,提高了手术室工作效率。 本系统实现了患者基本资料的录入及基本信息的查询,提供了术前准备情况及术中…

Android Studio 下载安装教程

在我们下载前,先来了解一下Android的4大组件: 1.活动 2.服务:类似线程,听歌时跳转发信息,后台进行播放音乐,前台交互,后台运行任务 3.广播接收者:【例1】感知充电线充电进度&#xf…

【Spring Boot统一功能处理】统一异常处理,统一的返回格式,@ControllerAdvice简单分析,即将走进SSM项目的大门! ! !

前言: 大家好,我是良辰丫,在上一篇文章中我们已经学习了一些统一功能处理的相关知识,今天我们继续深入学习这些知识,主要学习统一异常处理,统一的返回格式,ControllerAdvice简单分析.💌💌💌 🧑个人主页:良辰针不戳 &am…

邀请功能的实现分析

邀请功能 功能分析 场景:项目中出现用户邀请其他用户加入群组的功能 需求:用户点击生成邀请链接可以生成一个url,将这个url分享给其他用户,其他用户点击后对用户登录状态进行校验,校验通过即可加入群组,…

【dubbo triple provider 底层流转】

一、maven依赖 <dependency><groupId>io.netty</groupId><artifactId>netty-codec-http2</artifactId><version>4.1.90.Final</version> </dependency><dependency><groupId>org.apache.dubbo</groupId>&l…

vue3 父子组件传值 记录

最近这个组件之间传值用的较多&#xff0c;我这该死的记性&#xff0c;总给忘记写法&#xff0c;特此记录下 第一种 父传子 补充&#xff1a;LeftView.vue 是父组件&#xff1b; Video.vue 是子组件 第二种 子传父 Video.vue 子组件 第一步 引入&#xff1a; import { de…

Linux搭建Discuz论坛

环境&#xff1a;redhat 9 mysql 8 Discuz 3.5 题目要求&#xff1a;在 bbs.example.com 主机上创建 Discuz 论坛&#xff0c;数据库服务器使用 db.example.com 主机的 bbs 数据库实例&#xff0c;该实例由 MySQL数据库软件提供服务。 题目要求没有说是在一台虚拟机…

PostgreSQL学习笔记

目录 一、PostgreSQL安装 1、下载 2、安装 二、PostgreSQL操作 1、数据库操作 2、表操作 3、数据操作 一、PostgreSQL安装 本章节以windows系统安装为例&#xff0c;讲解PostgreSQL 15.0的安装过程。 1、下载 访问PostgreSQL官方网站&#xff0c;下载对应的安装包&am…

Qt/C++编写超精美自定义控件(历时9年更新迭代/超202个控件/祖传原创)

一、前言 无论是哪一门开发框架&#xff0c;如果涉及到UI这块&#xff0c;肯定需要用到自定义控件&#xff0c;越复杂功能越多的项目&#xff0c;自定义控件的数量就越多&#xff0c;最开始的时候可能每个自定义控件都针对特定的应用场景&#xff0c;甚至里面带了特定的场景的…

多元回归预测 | Matlab基于麻雀算法(SSA)优化混合核极限学习机HKELM回归预测, SSA-HKELM数据回归预测,多变量输入模型

文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 多元回归预测 | Matlab基于麻雀算法(SSA)优化混合核极限学习机HKELM回归预测, SSA-HKELM数据回归预测,多变量输入模型 评价指标包括:MAE、RMSE和R2等,代码质量极高,方便学习和替换数据。要求2018版本及以上。 …

Idea中使用Git详细教学

目录 一、配置 Git 二、创建项目远程仓库 三、初始化本地仓库 方法一&#xff1a; 方法二&#xff1a; 四、连接远程仓库 五、提交与拉取到本地仓库 六、推送到远程仓库 七、克隆远程仓库到本地 方法一&#xff1a; 方法二&#xff1a; 八、Git分支操作 一、配置 G…