pytorch(四、五)用pytorch实现线性回归和逻辑斯蒂回归(分类)

news2025/1/11 6:55:54

文章目录

  • 线性回归
    • 代码过程
    • 准备数据
    • 设计模型
    • 设计构造函数与优化器
    • 训练过程
    • 训练代码和结果
    • pytorch中的Linear层的底层原理(个人喜欢,不用看)
      • 普通矩阵乘法实现
      • Linear层实现
    • 回调机制
  • 逻辑斯蒂回归
    • 模型
    • 损失函数
    • 代码和结果

线性回归

代码过程

训练过程:

  1. 准备数据集
  2. 设计模型(用来计算 y ^ \hat y y^
  3. 构造损失函数和优化器(API)
  4. 训练周期(前馈、反馈、更新)

准备数据

这里的输入输出数据均表示为3×1的,也就是维度均为1

# 行表示实例数量,列表示维度feature
import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])

设计模型

模型继承Module类,并且必须要实现 init 和 forward 两个方法,其中self.linear=torch.nn.Linear(1,1)表示实例化Linear类,这个类是可调用的,其__call__函数调用了 forward 方法

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel,self).__init__()
        # weight 和 bias 1 1 
        self.linear=torch.nn.Linear(1,1)
        
    def forward(self,x):
        # callable
        y_pred=self.linear(x)
        return y_pred

# callable
model=LinearModel()

pytorch中的linear类是在某一个数据上应用线性转换,其公式表达为 y = x w T + b y=xw^T+b y=xwT+b

class torch.nn.Linear(in_features,out_features,bias=True) :其中in_features和out_features分别表示输入和输出的数据的维度(列的数量),bias表示偏置,默认是true,该类有两个参数

  • weight:可学习参数,值从均匀分布 U ( − k , k ) U(-\sqrt k,\sqrt k) U(k ,k )中获取,其中 k = 1 i n _ f e a t u r e s k=\frac{1}{in\_features} k=in_features1
  • bias:shape和输出的维度一样,也是从分布 U ( − k , k ) U(-\sqrt k,\sqrt k) U(k ,k )中初始化的
    在这里插入图片描述

设计构造函数与优化器

# 构造损失函数和优化器
criterion=torch.nn.MSELoss(size_average=False)

# w和b--->parameters
opyimizer=torch.optim.SGD(model.parameters(),lr=0.01)

在这里插入图片描述
在这里插入图片描述

训练过程

# 训练过程
for epoch in range(100):
    y_pred=model(x_data)
    loss=criterion(y_pred,y_data)
    # loss标量,自动调用__str__()
    print(epoch,loss)
    
    optimizer.zero_grad()
    # backward
    loss.backward()
    # update
    optimizer.step()

训练代码和结果

# 行表示实例数量,列表示维度feature
import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel,self).__init__()
        # weight 和 bias 1 1 
        self.linear=torch.nn.Linear(1,1)
        
    def forward(self,x):
        # callable
        y_pred=self.linear(x)
        return y_pred

# callable
model=LinearModel()

# 构造损失函数和优化器
criterion=torch.nn.MSELoss(size_average=False)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)

# 训练过程
for epoch in range(100):
    y_pred=model(x_data)
    loss=criterion(y_pred,y_data)
    # loss标量,自动调用__str__()
    print(epoch,loss)
    
    optimizer.zero_grad()
    # backward
    loss.backward()
    # update
    optimizer.step()
    
    
# 打印信息
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())

x_test=torch.Tensor([4.0])
y_test=model(x_test)
print('y_pred=',y_test.data)

在这里插入图片描述


pytorch中的Linear层的底层原理(个人喜欢,不用看)

我们在课本使用到的线性函数的基本公式表达为 y = x w T + b y=xw^T+b y=xwT+b,但是在Linear层中,当输入特征被Linear层接收是,它会接收后转置,然后乘以权重矩阵,得到的是输出特征的转置,换句话说可以在底层使用Linear,它实际上做的是 y T = w x T + b y^T=wx^T+b yT=wxT+b。可以使用下面的案例进行验证:

在这里插入图片描述

普通矩阵乘法实现

很明显,上面的图标表示一个 3×4 的矩阵乘以 4×1 的矩阵,得到一个 3×1 的输出矩阵,使用普通矩阵的乘法实现如下。

import torch

in_features=torch.tensor([1,2,3,4],dtype=torch.float32)
weight_matrix=torch.tensor([
    [1,2,3,4],
    [2,3,4,5],
    [3,4,5,6]
],dtype=torch.float32)

weight_matrix.matmul(in_features)# 矩阵乘法

实现截图:
在这里插入图片描述

Linear层实现

# 这里还是使用上面使用过的数据
import torch
in_features=torch.tensor([1,2,3,4],dtype=torch.float32)
weight_matrix=torch.tensor([
    [1,2,3,4],
    [2,3,4,5],
    [3,4,5,6]
],dtype=torch.float32)

print(weight_matrix.matmul(in_features))# 矩阵乘法

fc = torch.nn.Linear(in_features=4, out_features=3, bias=False)
# 这里是随机一个权重矩阵
print('fc.weight',fc.weight)
fc(in_features)

输出结果:
在这里插入图片描述

print('fc.weight',fc.weight)

# 使用上面的权重矩阵进行计算
fc.weight = torch.nn.Parameter(weight_matrix)
print('fc.weight',fc.weight)
fc(in_features)

结果截图:
在这里插入图片描述

可以看到上面截图与下面的截图的区别,一开始随机一个权重的时候,进行运算,使用到前面提及到的权重矩阵后,Linear层进行运算之后,得到与使用普通矩阵乘法一样的结果,相同的结果说明,Linear底层的实现与上面的矩阵乘法的逻辑是一致的

以上的论证可以说明,Linear的底层实现其实是 y T = w x T + b y^T=wx^T+b yT=wxT+b,而不是 y = x w T + b y=xw^T+b y=xwT+b,可能会有人好奇,为什么书本上都是写的后者而不是写前者,其实本质上二者都一样,前者的转置就是后者。

回调机制

在pytorch学习(一)线性模型中,第一个代码中,我们没有通过pytorch实现线性模型的时候,我们会显式调用forward函数,计算前馈的值,我们是这样写的y_pred_val=forward(x_val),但是在使用pytorch之后,我们是这样写的y_pred=model(x_data),直接实例化一个对象,然后通过对象直接计算预测值(前馈值),但是并没有使用到forward函数。这是因为pytorch模块类中实现了python中一个特殊的函数,也就是回调函数

如果一个类实现了回调方法,那么只要对象实例被调用,这个特殊的方法也会被调用。我们不直接调用forward()方法,而是调用对象实例。在对象实例被调用之后,在底层调用了__ call __方法,然后调用了forward()方法。这适用于所有的PyTorch神经网络模块。

以上仅代表小白个人学习观点,如有错误欢迎批评指正。

参考



逻辑斯蒂回归

逻辑斯蒂回归解决的事分类问题,分类输出的是类别的概率

模型

在线性模型中,通过 y = w x + b y=wx+b y=wx+b输出的是一个实数值,但是在分类问题中,输出的是类别的概率,所以需要一个函数,把实数值映射到[0,1]之间,表示概率,这个函数为 sigmoid 函数 y = 1 1 + e − x ∈ [ 0 , 1 ] y=\frac{1}{1+e^{-x}}∈[0,1] y=1+ex1[0,1],sigmoid函数属于饱和函数。

以上,逻辑斯蒂回归模型的公式为: y ^ = σ ( x ∗ w + b ) \hat{y}=\sigma(x*w+b) y^=σ(xw+b)

在这里插入图片描述

损失函数

在线性回归中,计算损失一般是使用均方误差(预测值与真实值的差值的平方和的累加),在回归问题中,均方误差表示数轴上两个值之间的距离,但是分类问题中,输出的结果表示的是概率(分布),使用距离是没有意义的,所以分类问题中的损失函数并不是均方误差。

在逻辑斯蒂回归中,使用的是BCE
在这里插入图片描述

代码和结果

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt


x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[0],[0],[1]])

class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel,self).__init__()
        self.linear=torch.nn.Linear(1,1)
        
    def forward(self,x):
        y_pred=F.sigmoid(self.linear(x))
        return y_pred
    
model=LogisticRegressionModel()

# 损失函数
criterion=torch.nn.BCELoss(size_average=False)
# 优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)

# 训练
for epoch in range(1000):
    y_pred=model(x_data)
    loss=criterion(y_pred,y_data)
    print(epoch,loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
# linspace与range函数类似,用于生成均匀分布的数值序列
# np.linspace(start=0,stop=10,num=200)
x=np.linspace(0,10,200)
# 数据集,test,生成200*1的矩阵
x_t=torch.Tensor(x).view((200,1))
y_t=model(x_t)
y=y_t.data.numpy()
plt.plot(x,y)
plt.plot([0,10],[0.5,0.5],c='r')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
# 显示网格线
plt.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1499194.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jumpserver项目配置讲解

下载地址:https://community.fit2cloud.com/#/products/jumpserver/downloads 产品文档:https://docs.jumpserver.org/zh/v3/ [rootbogon ~]# tar -xf jumpserver-offline-installer-v3.9.3-amd64.tar.gz [rootbogon ~]# cd jumpserver-offline-instal…

Python Tkinter GUI 基本概念

归纳编程学习的感悟, 记录奋斗路上的点滴, 希望能帮到一样刻苦的你! 如有不足欢迎指正! 共同学习交流! 🌎欢迎各位→点赞 👍 收藏⭐ 留言​📝如果停止,就是低谷&#xf…

【机器学习】实验6,基于集成学习的 Amazon 用户评论质量预测

清华大学驭风计划课程链接 学堂在线 - 精品在线课程学习平台 (xuetangx.com) 代码和报告均为本人自己实现(实验满分),此次代码开源大家可以自行参考学习 有任何疑问或者问题,也欢迎私信博主,大家可以相互讨论交流哟…

Go的安装

一. 下载地址 Go官方下载地址:https://golang.org/dl/ Go中文网:https://go.p2hp.com/go.dev/dl/ 根据不同系统下载不同的包。 二. 配置GOPATH GOPATH是一个环境变量,用来表明你写的go项目的存放路径。 GOPATH路径最好只设置一个&#xff0…

C++ 特殊的类设计

目录 1.请设计一个类,不能被拷贝 2. 请设计一个类,只能在堆上创建对象 3. 请设计一个类,只能在栈上创建对象 4. 请设计一个类,不能被继承 5. 请设计一个类,只能创建一个对象(单例模式) 1.请设计一个类,…

Ant Design Vue 修改Model弹框 样式不生效

今天在使用 Ant Design Vue 组件库中又踩了一个坑 其他的样式都可以更改,唯独更改 Model 弹框组件的样式一直不生效 于是研究了好久才找到样式不生效的原因 最后又折腾了好久,参考了不少资料才得出的解决方案:

蓝桥杯备赛之二分专题

常用的算法二分模板 1. 在数组a[]中找大于等于x的第一个数的下标 //int ans lower_bound(a, a n, x) - a //相当于下方 int l 0, r n - 1; while(l < r) {int mid l r >> 1;if(a[mid] > x) r mid;else l mid 1; } cout << r;2. 在数组a[]中找大于…

CVPR 2022 Oral | Bailando: 基于编舞记忆和Actor-Critic GPT的3D舞蹈生成

目录 测试结果&#xff1a; 02 提出的方法 测试结果&#xff1a; 预测有3个步骤&#xff0c;速度比较慢 02 提出的方法 1. 针对舞蹈序列的VQ-VAE和编舞记忆 与之前的方法不同&#xff0c;我们不学习从音频特征到 3D 关键点序列的连续域的直接映射。相反&#xff0c;我们先让…

基于springboot实现线上阅读系统项目【项目源码+论文说明】

基于springboot实现线上阅读系统演示 摘要 随着社会发展速度的愈来愈快&#xff0c;以及社会压力变化的越来越快速&#xff0c;致使很多人采取各种不同的方法进行解压。大多数人的稀释压力的方法&#xff0c;是捧一本书籍&#xff0c;心情地让自己沉浸在情节里面&#xff0c;以…

基于亚马逊云科技新功能:Amazon SageMaker Canvas 无代码机器学习—以构建货物的交付状态检测模型实战为例深度剖析以突显其特性

授权说明&#xff1a;本篇文章授权活动官方亚马逊云科技文章转发、改写权&#xff0c;包括不限于在亚马逊云科技开发者社区、 知乎、自媒体平台、第三方开发者媒体等亚马逊云科技官方渠道。 亚马逊云科技 2023 re:Invent 全球大会是亚马逊云科技举办的一场技术盛会&#xff0c;…

数据治理实践——YY 直播业务指标治理实践

目录 一、问题背景 1.1 问题场景 1.2 问题小结 二、治理方案 2.1 治理目标 2.2 团队协同&#xff0c;共建规范 2.3 指标管理的定位 2.4 指标管理的目标及思路 2.5 指标管理&#xff0c;规范内容落地 2.6 数仓设计-关联指标维度 2.7 数据报表开发-配置口径说明 2.8 …

windows重装系统后如何恢复自带的正版office

前言 重装系统后&#xff0c;正版office如何安装 登录微软官网 https://www.microsoft.com 下载office&#xff0c;在已购买的产品中找到Office产品&#xff0c;点击安装,选择默认即可 https://account.microsoft.com/services

信号处理--基于EEG脑电信号处理研究概述

目录 前言 EEG特点 EEG预处理 EEG通道选择 EEG数据增强 EEG 维度降低 EEG特征提取 传统特征提取 深度学习自动提取特征 未来展望 创新的预处理方法 跨被试性能问题 模型融合 参考 前言 脑电信号&#xff08;EEG&#xff09;因其安全性、便携性、易用性、高时间分…

【你也能从零基础学会网站开发】Web建站之HTML+CSS入门篇 CSS常用属性

&#x1f680; 个人主页 极客小俊 ✍&#x1f3fb; 作者简介&#xff1a;web开发者、设计师、技术分享 &#x1f40b; 希望大家多多支持, 我们一起学习和进步&#xff01; &#x1f3c5; 欢迎评论 ❤️点赞&#x1f4ac;评论 &#x1f4c2;收藏 &#x1f4c2;加关注 CSS常用属性…

python--宣传篇--personal-qrcode个性二维码

文章目录 准备代码效果 准备 代码 from MyQR import myqr import osdef get_img_qrcode(words, save_name, picture, colorizedTrue):if save_name[-3:] in ["jpg", "png", "gif"]:if picture[-3:] in ["png", "jpg", &qu…

Github 2024-03-08 Java开源项目日报 Top10

根据Github Trendings的统计,今日(2024-03-08统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Java项目9C++项目1非开发语言项目1《Hello 算法》:动画图解、一键运行的数据结构与算法教程 创建周期:476 天协议类型:OtherStar数量:63556…

Android 性能优化--APK加固(2)加密

文章目录 字符串加密图片加密如何避免应用被重新签名分发APK 加壳的方案简析DEX加密原理及实现 本文首发地址&#xff1a;https://h89.cn/archives/212.html 最新更新地址&#xff1a;https://gitee.com/chenjim/chenjimblog 通过 前文 介绍&#xff0c;我们知晓了如何使用代码…

AI安全白皮书 | “深度伪造”产业链调查以及四类防御措施

以下内容&#xff0c;摘编自顶象防御云业务安全情报中心正在制作的《“深度伪造”视频识别与防御白皮书》&#xff0c;对“深度伪造”感兴趣的网友&#xff0c;可在文章留言中写下邮箱&#xff0c;在该白皮书完成后&#xff0c;会为您免费寄送一份电子版。 “深度伪造”就是创建…

OpenCV开发笔记(七十六):相机标定(一):识别棋盘并绘制角点

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://blog.csdn.net/qq21497936/article/details/136535848 各位读者&#xff0c;知识无穷而人力有穷&#xff0c;要么改需求&#xff0c;要么找专业人士&#xff0c;要么自己研究 红胖子(红模仿…

排序算法——梳理总结

✨冒泡 ✨选择 ✨插入  ✨标准写法  &#x1f3ad;不同写法 ✨希尔排序——标准写法 ✨快排 ✨归并 ✨堆排 ✨冒泡 void Bubble(vector<int>& nums) {// 冒泡排序只能先确定最右边的结果&#xff0c;不能先确定最左边的结果for (int i 0; i < nums.size(); i){…