PyTorch深度学习实践1——线性回归和Logistic回归

news2024/10/5 15:32:28

PyTorch的风格

  • 准备数据集
  • 使用类设计模型
  • 计算损失函数和优化器
  • 训练【前向、反向和更新】

线性回归



import torch

# 准备数据集
# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

#LinearModel类继承torch的Module模块
class LinearModel(torch.nn.Module):
    #需要事先__init__和forward两个函数
    def __init__(self):
        #父类初始化
        super(LinearModel, self).__init__()
        # (1,1)是指输入x和输出y的特征维度,这里数据集中的x和y的特征都是1维的
        # 该线性层需要学习的参数是w和b  获取w/b的方式分别是~linear.weight/linear.bias
        self.linear = torch.nn.Linear(1, 1)
        #linear相当于一个类

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

#构造模型
model = LinearModel()

# 构造损失函数和优化器
#损失函数
criterion = torch.nn.MSELoss(size_average = False)
#size_average=false表示不将最后结果求平均值

#优化器SGD
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# model.parameters()自动完成参数的初始化操作 lr学习率

# training cycle forward, backward, update
#每一次的epoch过程总结就是
#1、前向传播计算y_pred值【预测值】
#2、根据预测值和测试值计算损失
#3、反向传播backward【计算梯度】
#4、根据梯度更新参数

for epoch in range(100):
    y_pred = model(x_data)  # forward:predict
    loss = criterion(y_pred, y_data)  # forward: loss
    #loss.item()表示直接输出数值
    print(epoch, loss.item())

    optimizer.zero_grad()  # the grad computer by .backward() will be accumulated. so before backward, remember set the grad to zero
    loss.backward()  # backward: autograd,自动计算梯度
    optimizer.step()  # update 参数,即更新w和b的值

#输出偏置值和权重
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

#测试预测
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

Logistic回归模型

如下图所示,与线性回归模型不同的是,Logistic回归是对线性回归计算的值使用sigmoid函数对其进行变换,映射在0-1之间
在这里插入图片描述

在这里插入图片描述

Logistic回归损失函数计算公式
在这里插入图片描述
在这里插入图片描述
说明:预测与标签越接近,BCE损失越小。

对比线性回归和Logistic回归构造的类,Logistic回归是使用torch.nn.functional模块中的sigmoid函数
在这里插入图片描述
Logistic回归的损失函数为BCELoss
在这里插入图片描述
构建Logistic回归模型
第一步:准备数据
在这里插入图片描述
第二步:设计类模型
在这里插入图片描述
第三步:设置损失函数和优化器
在这里插入图片描述
第四步:训练
在这里插入图片描述
在这里插入图片描述
可视化
在这里插入图片描述

import torch
# import torch.nn.functional as F
 
# prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
 
#design model using class
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1,1)
 
    def forward(self, x):
        # y_pred = F.sigmoid(self.linear(x))
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
model = LogisticRegressionModel()
 
# construct loss and optimizer
# 默认情况下,loss会基于element平均,如果size_average=False的话,loss会被累加。
criterion = torch.nn.BCELoss(size_average = False) 
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
 
# training cycle forward, backward, update
for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
 
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
 
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
 
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1002503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

关于Java的类加载机制

1、概述 类会在运行期间第一次使用时,被类加载器动态加载至JVM。JVM不会一次性加载所有类。因为如果一次性加载,会占用很多的内存。 2、类的生命周期 类的生命周期包括以下 7 个阶段: 加载(Loading)验证(…

关于软件的功能复用

有一些人总在说软件要复用,开发一个项目时要想想怎么在另一个项目中能重用。你问他怎么做到复用,就会听到微服务、中台一些名词 复用的层次 说到复用,首先要想明白复用的是啥 级别越低,粒度越小,复用的范围越广&#…

实现div的height从0到auto的过渡效果

通过修改max-height打到高度自适应的过程。 展开状态 收起状态 一般场景描述需求(与项目业务无关): 需要完成一种过渡效果,即height是变化的,但不是数字到数字的变化,因为不知道展开之后的高度到底是多少?不确定!!!。…

Qt加载本地图片转为YUV420P格式数据

一、背景介绍 在流媒体应用中,视频编码是必不可少的一环。视频编码的作用是将高带宽、高码率的原始视频流压缩成低带宽、低码率的码流,以便于传输和存储。H264是一种高效的视频编码标准,具有良好的压缩性能和广泛的应用范围,在实…

《C++ primer》练习3.20:输出vector相邻元素的和输出vector头尾对象的和

最近看《C primer》,有这样一个题目 输出vector相邻元素的和 读入一组整数并把它们存入一个vector对象,将每对相邻整数的和输出出来。 这里要注意输入的奇数个和偶数个的数的区别。偶数个整数的话刚好数全部用完,奇数个整数最后一个数空出来…

淘宝平台开放接口API接口

淘宝平台开放接口API接口是指淘宝平台提供给第三方开发者的一组接口,用于实现与淘宝平台的数据交互和功能扩展。通过API接口,第三方开发者可以获取淘宝平台上的商品信息、订单信息、用户信息等数据,也可以实现商品的发布、订单的创建和支付等…

【图解RabbitMQ-7】图解RabbitMQ五种队列模型(简单模型、工作模型、发布订阅模型、路由模型、主题模型)及代码实现

🧑‍💻作者名称:DaenCode 🎤作者简介:CSDN实力新星,后端开发两年经验,曾担任甲方技术代表,业余独自创办智源恩创网络科技工作室。会点点Java相关技术栈、帆软报表、低代码平台快速开…

2023年世界机器人大会回顾

1、前记: 本次记录是我自己去世界机器人博览会参观的一些感受,所有回顾为个人感兴趣部分的机器人产品分享。整个参观下来最大的感受就是科学技术、特别是机器人技术和人工智能毫无疑问地、广泛的应用在我们日常生活的方方面面,在安全巡检、特…

【洛谷算法题】P5704-字母转换【入门1顺序结构】

👨‍💻博客主页:花无缺 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 本文由 花无缺 原创 收录于专栏 【洛谷算法题】 文章目录 【洛谷算法题】P5704-字母转换【入门1顺序结构】🌏题目描述🌏输入格式&a…

ClickHouse进阶(十三):Clickhouse数据字典-3-文件数据源及Mysql数据源

进入正文前,感谢宝子们订阅专题、点赞、评论、收藏!关注IT贫道,获取高质量博客内容! 🏡个人主页:含各种IT体系技术,IT贫道_大数据OLAP体系技术栈,Apache Doris,Kerberos安全认证-CSDN博客 📌订阅…

metinfo __ 6.0.0 __ file-read

metinfo __ 6.0.0 __ file-read 说明内容漏洞编号–漏洞名称MetInfo 6.0.0 任意文件读取漏洞漏洞评级高危影响范围6.0.0.0漏洞描述MetInfo 存在任意文件读取漏洞,攻击者利用该漏洞,在具有权限的情况下,可以读取网站任意文件,包括…

ASP.NET+sqlserver公司项目管理系统

一、源码描述 这是一款简洁十分美观的ASP.NETsqlserver源码,界面十分美观,功能也比较全面,比较适合 作为毕业设计、课程设计、使用,感兴趣的朋友可以下载看看哦 二、功能介绍 该源码功能十分的全面,具体介绍如下&…

centos定期清理磁盘

centos/linux定期清理磁盘 要定时清理空间,我们需要了解一个命令,find 命令,这个命令可以查询目录下特定文件名,生成日期的文件 小白教程,一看就会,一做就成。 1.查找需要删除的 find /data_back/zhhyba…

buuctf crypto 【Cipher】解题记录

1.打开题目就有密文 2.一点思路没有,看看大佬的wp(BUUCTF Cipher 1_cipher buuctf_玥轩_521的博客-CSDN博客),捏麻麻的原来玄机就在“公平的玩吧”这句话里,playfair也是一种加密方式,密钥猜测也是playfair…

SMB 协议详解之-NTLM身份认证

前面的文章说明了SMB协议交互的过程,在SMB交互的Session Setup Request/Response会对请求者的身份进行验证,这其中涉及到两个主要的协议NTLM以及Kerberos,本文将对NTLM协议进行详细的说明。 什么是NTLM NTLM是 NT LAN Manager (NTLM) Authentication Protocol 的缩写,主要…

MediaCodec源码分析 configure流程

前言 本文梳理MediaCodec configure流程,基于7.0代码,这里只分析AVC和HEVC的视频硬解,流程图如下。 代码见: frameworks/base/media/java/android/media/MediaCodec.java frameworks/base/media/jni/android_media_MediaCodec.h frameworks/base/media/jni/android_media_…

C#,数值计算——指数微分(exponential deviates)的计算方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { /// <summary> /// 指数偏差 /// Structure for exponential deviates. /// </summary> public class Expondev : Ran { private double beta { get; set; } /// <s…

windows或者任何系统通过二进制安装最新的Protocol Buffer Compiler

此处使用二进制法安装,适用于任何操作系统 安装预编译的二进制文件&#xff08;任何操作系统&#xff09; 要从预编译的二进制文件安装最新版本的协议编译器&#xff0c;请按照以下说明操作&#xff1a; 1.从 github.com/google/protobuf/releases 手动下载与您的操作系统和计…

学习 [Spring MVC] 的JSR 303和拦截器,提高开发效率

&#x1f3ac; 艳艳耶✌️&#xff1a;个人主页 &#x1f525; 个人专栏 &#xff1a;《推荐】Spring与Mybatis集成整合》 ⛺️ 生活的理想&#xff0c;不断更新自己 ! 1.JSR303 1.1JSR303是什么 JSR 303是Java规范请求&#xff08;Java Specification Request&#xff09;…

【AI理论学习】语言模型Performer:一种基于Transformer架构的通用注意力框架

语言模型Performer&#xff1a;一种基于Transformer架构的通用注意力框架 Performer论文解读Regular Attention MechanismFAVOR&#xff1a;通过矩阵相关性实现快速注意力 Attention的时间复杂性绕过softmax瓶颈通过Gaussian kernel求Softmax kernel寻找更稳定的Softmax内核使用…