Pytorch 实现简单的 线性回归 算法

news2024/11/15 8:49:03

Pytorch实现简单的线性回归算法

简单 tensor的运算

Pytorch涉及的基本数据类型是tensor(张量)和Autograd(自动微分变量)

import torch
x = torch.rand(5, 3) #产生一个5*3的tensor,在 [0,1) 之间随机取值
y = torch.ones(5, 3) #产生一个5*3的Tensor,元素都是1  

z = x + y                    #两个tensor可以直接相加
q = x.mm(y.transpose(0, 1))  #x乘以y的转置  mm为矩阵的乘法,矩阵相乘必须某一个矩阵的行与另一个矩阵的列相等

Tensor与numpy.ndarray之间的转换

import numpy as np        #导入numpy包
a = np.ones([5, 3])       #建立一个5*3全是1的二维数组(矩阵)
b = torch.from_numpy(a)   #利用from_numpy将其转换为tensor
c = torch.FloatTensor(a)  #另外一种转换为tensor的方法,类型为FloatTensor,还可以使LongTensor,整型数据类型
b.numpy()                 #从一个tensor转化为numpy的多维数组
from torch.autograd import Variable                  # 导入自动梯度的运算包,主要用Variable这个类
x = Variable(torch.ones(2, 2), requires_grad=True)   # 创建一个Variable,包裹了一个2*2张量,将需要计算梯度属性置为True

用pytorch做一个简单的线性关系预测

线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,该模型的系数可以用最小二乘法进行求解。生活中的场景往往会比较复杂,需要考虑多元线性关系和非线性关系,用其他的回归分析方法求解。


x = Variable(torch.linspace(0, 100, 100).type(torch.FloatTensor))  # 生成一些样本点作为原始数据
rand = Variable(torch.randn(100)) * 10                             # 随机生成100个满足标准正态分布的随机数,均值为0,方差为1.将这个数字乘以10,标准方差变为10
y = x + rand                                                       # 将x和rand相加,得到伪造的标签数据y。所以(x,y)应能近似地落在y=x这条直线上

import matplotlib.pyplot as plt  
plt.figure(figsize=(10,8))                    #设定绘制窗口大小为10*8 inch
plt.plot(x.data.numpy(), y.data.numpy(), 'o') #绘制数据,考虑到x和y都是Variable,需要用data获取它们包裹的Tensor,并专成numpy
plt.xlabel('X') 
plt.ylabel('Y') 
plt.show() 

在这里插入图片描述

构建模型

#a,b就是要构建的线性函数的系数
a = Variable(torch.rand(1), requires_grad = True) #创建a变量,并随机赋值初始化
b = Variable(torch.rand(1), requires_grad = True) #创建b变量,并随机赋值初始化
print('Initial parameters:', [a, b])

learning_rate = 0.0001 #设置学习率
for i in range(1000):
    ### 增加了这部分代码,清空存储在变量a,b中的梯度信息,以免在backward的过程中会反复不停地累加
    if (a.grad is not None) and (b.grad is not None):  
        a.grad.data.zero_() 
        b.grad.data.zero_() 
    predictions = a.expand_as(x) * x+ b.expand_as(x)  #计算在当前a、b条件下的模型预测数值
    # 在 PyTorch 中,a.expand_as(x) 用于将张量 a 扩展(expand)为与张量 x 具有相同的形状
    loss = torch.mean((predictions - y) ** 2)         #通过与标签数据y比较,计算误差
    print('loss:', loss)

    loss.backward() #对损失函数进行梯度反传,backward的方向传播算法
    a.data.add_(- learning_rate * a.grad.data)  #利用上一步计算中得到的a的梯度信息更新a中的data数值
    b.data.add_(- learning_rate * b.grad.data)  #利用上一步计算中得到的b的梯度信息更新b中的data数值

绘制结果


x_data = x.data.numpy()                       # 将tensor 转为 numpy
plt.figure(figsize = (10, 7))
xplot = plt.plot(x_data, y.data.numpy(), 'o') # 绘制原始数据
yplot = plt.plot(x_data, a.data.numpy() * x_data + b.data.numpy())  #绘制拟合数据
plt.xlabel('X') 
plt.ylabel('Y') 
str1 = str(a.data.numpy()[0]) + 'x +' + str(b.data.numpy()[0]) # 图例信息 拟合的直线
plt.legend(['Obs', 'Model']) #绘制图例
plt.show()

在这里插入图片描述

x_test = Variable(torch.FloatTensor([1, 2, 10, 100, 1000])) #随便选择一些点1,2,……,1000
predictions = a.expand_as(x_test) * x_test + b.expand_as(x_test) #计算模型的预测结果
predictions  #输出预测的数值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1813738.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ATFX汇市:非农数据超预期靓丽,美指重新站上105关口

ATFX汇市:6月7日,美国劳工统计局公布5月份非农就业报告,其中提到:5月份增加了27.2万个岗位,大幅高于前值16.5万人,数据超预期靓丽;几个行业的就业人数继续呈上升趋势,其中医疗领域增…

操作系统 c语言模仿 磁盘文件操作

1.实验目的 深入了解磁盘文件系统的实现。 2.实验预备知识 文件的操作; 文件的逻辑结构和物理结构; 磁盘空间的管理; 磁盘目录结构。 3.实验内容 设计一个简单的文件系统,用文件模拟磁盘&…

springboot+vue前后端分离项目中使用jwt实现登录认证

文章目录 一、后端代码1.响应工具类2.jwt工具类3.登录用户实体类4.登录接口5.测试接口6.过滤器7.启动类 二、前端代码1.登录页index 页面 三、效果展示 一、后端代码 1.响应工具类 package com.etime.util;import com.etime.vo.ResponseModel; import com.fasterxml.jackson.…

RAG核心算法

一、分块与向量化 首先,我们的目标是创建一个向量索引,用以代表我们文档的内容,然后在运行时寻找所有这些向量与查询向量之间的最小余弦距离,以匹配最接近的语义含义。 1、分块 由于 Transformer 模型具有固定的输入序列长度,即便输入上下文窗口很大,一个句子或几个句…

【全网最有效,保姆级教程】KEPServerEX 6下载安装解决时长问题

1、下载KEPServer KEPServerEX 6下载链接(为了防止版本不兼容,一定要使用下面链接里面的版本!): https://pan.baidu.com/s/19pAXzhWa5nxduU3mi1V4Nw?pwd1234 提取码:1234 2、安装KEPServer 基本上都是默认下一步,选择中文&…

python中用列表实现栈

【小白从小学Python、C、Java】 【考研初试复试毕业设计】 【Python基础AI数据分析】 python中用列表实现栈 选择题 以下代码最后一次输出的结果是? stack [] stack.append(1) stack.append(2) stack.append(3) print(【显示】stack ,stack) print(【显示】stack.…

SpringCloud 前端-网关-微服务-微服务间实现信息共享传递

目录 1 网关获取用户校验信息并保存至请求头(前端-网关) 2 微服务获取网关中的用户校验信息(网关-微服务) 2.1 一般的做法是在公共的module中添加,此处示例为common 公共配置module中添加 2.2 定义拦截器 2.3 定义…

C++|哈希结构封装unordered_set和unordered_map

上一篇章,学习了unordered系列容器的使用,以及哈希结构,那么这一篇章将通过哈希结构来封装unordered系列容器,来进一步的学习他们的使用以及理解为何是如此使用。其实,哈希表的封装方式和红黑树的封装方式形式上是差不…

鸿蒙低代码开发的局限性

在版本是DevEco Studio 3.1.1 Release,SDK是3.1.0(API9) 的基础上。 1、低代码插件没有WebView组件。 2、低代码插件没有空白的自定义组件,当前提供的所谓自定义组件,只能用列表中提供的组件来拼接新的组件。 3、使用ets代码自定义的组件&…

JVM 常量池汇总

Tips JVM常量池分为静态常量池和运行时常量池,因为Jdk1.7后字符串常量池从运行时常量池存储位置剥离,故很多博客也是区分开来,存储位置和内容注意区别! 字符串常量池底层是由C实现,是一个类似于HashTable的数据结构&am…

Spring 中使用MyBatis

一、Mybatis 的作用 1、MyBatis(前身为iBatis)是一个开源的Java持久层框架,它主要用于与数据库交互,帮助开发者更轻松地进行数据库操作。 持久层:指的是就是数据访问层(dao),是用来操作数据库的。 2、MyB…

Filament 【表单操作】修改密码

场景描述: 新增管理员信息时需要填写密码,修改管理员信息时密码可以为空(不修改密码),此时表单中密码输入有冲突,需要对表单中密码字段进项条件性的判断,使字段在 create 操作时为必需填写&…

深度学习-注意力机制和分数

深度学习-注意力机制 注意力机制定义与起源原理与特点分类应用领域实现方式优点注意力机制的变体总结注意力分数定义计算方式注意力分数的作用注意力分数的设计总结 注意力机制(Attention Mechanism)是一个源自对人类视觉研究的概念,现已广泛…

实测 WordPress 最佳优化方案:WP Super Cache+Memcached+CDN

说起 WordPress 优化加速来可以说是个经久不衰的话题了,包括明月自己都撰写发表了不少相关的文章。基本上到现在为止明月的 WordPress 优化方案已经固定成型了,那就是 WP Super CacheMemcachedCDN 的方案,因为这个方案可以做到免费、稳定、安…

如何用R语言ggplot2画高水平期刊散点图

文章目录 前言一、数据集二、ggplot2画图1、全部代码2、细节拆分1)导包2)创建图形对象3)主题设置4)轴设置5)图例设置6)散点颜色7)保存图片 前言 一、数据集 数据下载链接见文章顶部 处理前的数据…

基于FreeRTOS+STM32CubeMX+LCD1602+MCP6S26(SPI接口)的6通道模拟可编程增益放大器Proteus仿真

一、简介: MCP6S26是模拟可 编程增益放大器(Programmable Gain Amplifiers, PGA)。它们可配置为输出 +1 V/V 到 +32 V/V 之间的增 益,输入复用器可通过 SPI 端口选择最多 6 个通道中的 一个。串行接口也可以将 PGA 置为关断模式,以降低 功耗。这些 PGA 针对高速度、低失调…

Python编程基础5

邮件编程 SMTP(Simple Mail Transfer Protocol)简单邮件传输协议,使用TCP协议25端口,它是一组用于由源地址到目的地址传送邮件的规则,由它来控制信件的中转方式。python的smtplib提供了一种很方便的途径发送电子邮件。…

【python】tkinter GUI开发: Button和Entry的应用实战探索

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

MySQL----排序ORDER BY

在对数据进行处理的时候,我们通常需要对读取的数据进行排序。而 MySQL 的也提供了 ORDER BY 语句来满足我们的排序要求。 ORDER BY 可以按照一个或多个列的值进行升序(ASC)或降序(DESC)排序。 语法 SELECT column1…

航班进出港管理系统的设计

管理员账户功能包括:系统首页,个人中心,管理员管理,用户管理,航班信息管理,航飞降落请求管理,公告信息管理 前台账户功能包括:系统首页,个人中心,公告信息&a…