Pytorch深度学习笔记(七)逻辑斯蒂回归

news2024/11/19 22:37:17

目录

1. logistic(逻辑斯蒂)函数

2.二分类任务(binary classification)损失函数

3.二分类任务(binary classification)最小批量损失函数

4.逻辑斯蒂回归代码实现

附:pytorch提供的数据集


 

推荐课程:06.逻辑斯蒂回归_哔哩哔哩_bilibili

回归是对连续变量预测。

分类是对离散变量预测。通过比较分类的概率来判断预测的结果。

回归&分类

以学生学习为例,回归任务:学习时间预测学习成绩,分类任务:学习时间预测通过考试的概率,两个类别标签,通过与不通过,这是一个二分类任务。

逻辑斯蒂回归是一种分类任务

1. logistic(逻辑斯蒂)函数

 这里的x更换为\hat{y}

适用于线性模型将输出值由实数空间映射到[0,1]之间,以此进行分类。与线性回归模型相比logistic(逻辑斯蒂)回归模型,多增加了一个映射函数。

映射

注:只要满足饱和函数的规定,都属于sigmoid函数,如logistic(逻辑斯蒂)函数。所以logistic回归有时也叫sigmoid

2.二分类任务(binary classification)损失函数

也称为BCELoss()函数,二分类交叉熵(cross entorpy)

在二分类任务中,\hat{y}为class=1的概率, 1-\hat{y}为class=0的概率。ylog\hat{y}交叉熵表示log前后两个分布概率的差异大小。如果y=0表示class=0的概率为1,class=1的概率为0。如y=1表示class=1的概率为1,class=0的概率为0。

当y=1时,loss=-log\hat{y},表示\hat{y}值越大越接近class=1的概率为1真实分布概率,损失值越小。当y=0时, loss=-log(1 - \hat{y}), 表示\hat{y}值越小,class=0的概率越大,越接近class=0的概率为1真实分布概率,损失值越小。可见下图。

3.二分类任务(binary classification)最小批量损失函数

 求损失量均值。

4.逻辑斯蒂回归代码实现

1.数据准备

2.设计模型

3.构造损失函数和优化器

4.训练周期(前馈—>反馈—>更新)

逻辑斯蒂回归完整代码:

import torch
import torch.nn.functional as F

#…1.准备数据………………………………………………………………………………………………………………………………………#
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
# 二分类
y_data = torch.Tensor([[0], [0], [1]])

#…2.设计模型………………………………………………………………………………………………………………………………………#
# 继承torch.nn.Module,定义自己的计算模块,neural network
class LogisticRegressionModel(torch.nn.Module):
    # 构造函数
    def __init__(self):
        # 调用父类构造
        super(LogisticRegressionModel, self).__init__()
        # 定义输入样本和输出样本的维度
        self.linear = torch.nn.Linear(1, 1)

    # 前馈函数
    def forward(self, x):
        # 返回x线性计算后的预测值
        # sigmoid()作映射变化
        y_pred = F.sigmoid(self.linear(x))
        return y_pred

#……3.构造损失函数和优化器……………………………………………………………………………………………………………#
# 实例化自定义模型,返回做logistic变化(也叫sigmoid)的预测值
model = LogisticRegressionModel()
# 实例化损失函数,返回损失值
criterion = torch.nn.BCELoss(size_average=False)
# 实例化优化器,优化权重w
# model.parameters(),取出模型中的参数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

#……4.训练周期………………………………………………………………………………………………………………………………………#
for epoch in range(1000):
    # 获得预测值
    y_pred = model(x_data)
    # 获得损失值
    loss = criterion(y_pred, y_data)
    # 不会产生计算图,因为__str()__
    print(epoch, loss.item())
    # 梯度归零
    optimizer.zero_grad()
    # 反向传播
    loss.backward()
    # 更新权重w
    optimizer.step()
# 打印权重和偏值
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

#……5.绘图………………………………………………………………………………………………………………………………………#
#用于在大型、多维数组上执行数值运算
import numpy as np
import matplotlib.pyplot as plt

# 定义均匀间隔创建数值序列,指定间隔起始点、终止端,指定分隔值总数
x = np.linspace(0, 10, 200)
# 重新调整维度为200*1
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
# 将tensor转化为numpy类型
y = y_t.data.numpy()
# 图线1,x,y 轴上的数值
plt.plot(x, y)
# 图线2,x,y 轴上的数值,设置颜色
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
# 绘制刻度线的网格线
plt.grid()
plt.show()

附:pytorch提供的数据集

pytorch的免费数据集由两个上层的API提供,分别是torchvision和torchtext。

torchvision提供了对照片数据处理相关的API和数据,数据所在位置:torchvision.datasets,比如torchvision.datasets.MNIST(手写数字照片数据),torchvision.datasets.cifar(十类彩色图像数据)。

torchtext提供了对文本数据处理相关的API和数据,数据所在位置:torchtext.datasets,比如torchtext.datasets.IMDB(电影评论文本数据)。

import torchvision
# 训练集
train_set = torchvision.datasets.MNIST(root="../dataset/mnist", train=True, download=True)
# 测试集
test_set = torchvision.datasets.MNIST(root="../dataset/mnist", train=False, download=True)
  • rootstring)– 数据集的根目录,其中存放processed/training.ptprocessed/test.pt文件。
  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • download(bool, 可选)– 如果设置为True, 从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
  • transform可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop
  • target_transform可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。

torchvision.datasets-PyTorch 1.0 中文文档 & 教程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/451054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法小课堂(六)回溯算法

目录 一、概述 1.1概念(树形结构) 1.2区别 1.3步骤 1.4回溯法模板 1.5应用 1.6回溯三部曲 二、组合问题 2.1组合 回溯算法 优化剪枝操作 2.2组合总和 2.3组合总和2 2.4组合总和3 2.5电话号码的字母组合 三、切割问题 3.1分割回文串 3.2…

简述Nginx中的负载均衡、正向代理、反向代理

前言 今天吃饭的时候看某站,然后就刷到了一个视频,感觉图片蛮好看的,讲的也适合入门,这里就跟大家分享一下 视频链接: https://www.bilibili.com/video/BV1vm4y1z7EB/?spm_id_from333.1007.tianma.4-3-13.click&…

实现栅格形式的进度条+奇特的渐变边框效果

介绍 效果图展示:(颜色自定义哦~js控制) 实现逻辑介绍: (1)主要实现方案就是使用css渐变背景实现的。(linear-gradient) (2)因为需要js控制颜色&#xff…

DDD领域驱动设计:支付系统中的应用一

文章目录 前言一、DDD意义1 为什么需要DDD2 DDD的价值 二、DDD设计流程1 战略设计2 战术设计 三、DDD代码落地四、参考文献总结 前言 DDD作为一种优秀的设计思想,为复杂业务治理带来了曙光。然而又因为DDD本身难以掌握,很容易造成DDD从理论到工程落地之…

V4L2系列 之 V4L2驱动框架(1)

目录 前言一、V4L2驱动框架概览1、应用层 -》中间层-》驱动层2、主要代码文件(Linux 4.19版本内核) 二、怎么写V4L2驱动1、如何写一个设备的驱动?2、Video设备主要结构体3、怎么写V4L2驱动 三、V4L2的调试工具1、v4l2-ctl2、dev_debug3、v4l2-compliance 前言 本篇文…

00后卷王的自述,我真有同事口中说的那么卷?

前言 前段时间去面试了一个公司,成功拿到了offer,薪资也从14k涨到了20k,对于工作都还没几年的我来说,还是比较满意的,毕竟一些工作5、6年的可能还没我高。 我可能就是大家口中的卷王,感觉自己年轻&#xf…

一文读懂Redis哨兵

Redis哨兵(sentinel) 哨兵是什么? 吹哨人巡查监控后台master主机是否故障,如果故障了根据投票数自动将某一个从库转换为新主库,继续对外服务。 俗称,无人值守运维。 干什么? 主从监控&…

Win10系统下VS2019编译Qt的Ribbon控件 -- SARibbon

Win10系统下VS2019编译Qt的Ribbon控件 -- SARibbon 一、源码下载二、源码编译三、封装成库四、Qt配库五、运行测试 原文链接:https://blog.csdn.net/m0_51204289/article/details/126431338 一、源码下载 【1】https://gitee.com/czyt1988/SARibbon/tree/master/s…

Python开发工具PyCharm v2023.1正式发布——推出全新的用户界面

JetBrains PyCharm是一种Python IDE,其带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具。此外,该IDE提供了一些高级功能,以用于Django框架下的专业Web开发。 PyCharm v2023.1正式版下载 更新日志如下: 推出新的…

【UE】三步创建自动追踪自爆可造成伤害的敌人

效果 可以看到造成伤害时在右上角打印玩家当前的生命值 步骤 1. 首先拖入导航网格体边界体积 2. 首先复制一份“ThirdPersonCharacter”,命名为“ExplodingAI” 打开“ExplodingAI”,删除事件图表中所有节点 添加一个panw感应组件 在事件图表中添加如…

机器学习实战:Python基于PCA主成分分析进行降维分类(七)

文章目录 1 前言1.1 主成分分析的介绍1.2 主成分分析的应用[](https://chat.openai.com/ "openai") 2 Mushroom分类数据演示2.1 导入函数2.2 导入数据2.3 PCA可视化2.4 PCA散点图2.5 PCA散点图 3 讨论 1 前言 1.1 主成分分析的介绍 主成分分析(Principa…

Consistency Models

Consistency Models- 理解 问题定义研究动机本文中心论点 相关工作和进展Consistency Models创新点review扩散模型 Consistency Model-Definition一致性模型的定义一致性模型参数化一致性模型采样 Training Consistency Models via DistillationTraining Consistency Models in…

ChatGPT on Notes/Domino

大家好,才是真的好。 随着春节过去,小盆友也开始陆续到幼儿园报到,我们又回来和大家一起继续Notes/Domino传奇之旅。 去年年底ChatGPT横空出世,让大家震惊了一把。 可能有些老Notes/Domino人,还不知道ChatGPT是什么…

MySQL_第11章_数据处理之增删改

第11章_数据处理之增删改 讲师:尚硅谷 - 宋红康(江湖人称:康师傅) 官网: http://www.atguigu.com 1. 插入数据 1.1 实际问题 解决方式:使用 INSERT 语句向表中插入数据。 1.2 方式1:VA…

在OpenHarmony 开发者大会2023,听见百业同鸣

加强开源,助推中国科技强国战略,已经成为中国科技繁荣的必要条件,“十四五”规划中首次提到了“开源”两个字,并明确指出,支持数字技术开源社区等创新联合体的发展。 在中国发展开源,有着拓荒的色彩&#x…

Springsecurity笔记14-18章JWT+Spring Security+redis+mysql 实现认证【动力节点】

15 SpringSecurity 集成thymeleaf 此项目是在springsecurity-12-database-authorization-method 的基础上进行 复制springsecurity-12-database-authorization-method 并重命名为springsecurity-13-thymeleaf 15.1 添加thymeleaf依赖 | <groupId>org.springframewor…

西门子s7-300/400PLC-MMC密码解密

西门子s7-300/400-MMC密码解密 简介西门子加密工具及操作密码验证 简介 目前&#xff0c;市面上或网络上有很多针对s7-200&#xff0c;300&#xff0c;400&#xff0c;1200&#xff0c;1500的密码解密破解软件&#xff0c;但很多时候只能解数字或英文密码&#xff0c;对设置了…

Linux-初学者系列——篇幅5_系统目录相关命令

系统目录相关命令-目录 一、系统目录层级1、目录绝对路径2、目录相对路径3、目录层级结构查看-tree不带任何参数获取目录结构数据信息以树形结构显示目录下的所有内容&#xff08;包含隐藏信息&#xff09;只列出根目录下第一层的目录结构信息只显示目录结构信息中的所有目录信…

ThingsBoard如何自定义topic

1、背景 业务需要,mqtt设备,他们协议和topic都定义好了,想使用tb的mqtt直接接入设备,但是设备的topic和tb规定的不一致,该如何解决呢? 2、要求 设备的topic要求规则是这样的 首先第二点是满足的,网关的发布主题是可以通过tb的设备配置来自定义遥测和属性的topic,问题…

qiankun应用级缓存-多页签缓存

需求&#xff1a; A&#xff1a;主应用 B&#xff1a;子应用 项目框架&#xff1a;vue2 全家桶 qiankun 应用间切换需要保存页面缓存&#xff08;多页签缓存&#xff09;&#xff0c;通过vue keep-alive只能实现页面级缓存&#xff0c;在单独打开的应用里能实现缓存&#xf…