用Pytorch实现线性回归(Linear Regression with Pytorch)

news2024/9/21 10:40:18

使用pytorch写神经网络的第一步就是需要准备好数据集,设计模型(用于计算y_hat(y的预测值)),构造损失函数和优化器(使用PyTorch API),写训练周期(前馈(算loss)+反馈(算梯度)+更新(更新权重))

一:准备数据

现在使用mini-batch的方式,X和Y为3x1(可以变,但是x和y要相同)的矩阵形式。

从代码中也可以看出来,x和y都是3x1的矩阵。

二:设计模型(构造计算图)

此处使用了一个仿射模型(在pytorch中叫做线性单元)

在我们设计的例子中,我们需要设置权重w的数值,和偏置量b。

那w和b的形状(几x几的矩阵),是由y_hat和x来共同确定。

之后将y_hat和y放入loss函数中进行计算,得出loss的值(一定是一个标量)。

看下模型设计的代码:

#需要继承自module ,因为module中有很多方法我们需要使用
class LinearModel(torch.nn.Module):
    def __init__(self): #构造函数 在初始化对象时默认调用的函数
        super(LinearModel,self).__init__() #super调用父类的构造
        self.linear = torch.nn.Linear(1,1) #构造一个对象 linear Unit中的w和b(linear来自父类,可以自动反向传播)
    
    def forward(self,x): #前馈需要进行的计算 发现没有backword模块,因为Module中自动根据计算图实现backword过程
        y_pred = self.linear(x)
        return y_pred

model = LinearModel() #实例化 在之后既可以使用model(x)将x传入forword中的x,求得y_pred

其中torch.nn.Linear 的使用方法如下 

三:构造loss和optimizer

此处我们使用MSEloss,需要的参事时y_hat和y,就可以求出loss。

代码如下:

criterion = torch.nn.MSELoss(size_average=False)

我们使用SGD优化器(不会构建计算图),代码如下

optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

四:训练过程

for epoch in range(100):
    y_pred = model(x_data)  #先计算出y_hat
    loss = criterion(y_pred,y_data) #再计算出loss
    print(epoch,loss.item()) 
    
    optimizer.zero_grad()#在反馈前将梯度清0
    loss.backward()#反馈
    optimizer.step()#更新

最后打印一些相关内容

# w b
print('w=',model.linear.weight.item())
print('b=',model.linear.weight.item())

#Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=',y_test.data)

发现当range为1000时,已经达到了我们的预期。

五:整体流程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1930691.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

线性表的链式存储结构————单链表(java)

线性表的链式存储结构————单链表(java) 文章目录 线性表的链式存储结构————单链表(java)链表的概述单链表单链表的创建插入结点的操作尾插法头插法 求单链表的长度输出单链表查找单链表数据元素对应的索引值删除数据元素总…

「AI得贤招聘官」通过首批“AI产业创新场景应用案例”评估

近日,上海近屿智能科技有限公司的「AI得贤招聘官」,经过工业和信息化部工业文化发展中心数字科技中心的严格评估,荣获首批“AI产业创新场景应用案例”。 据官方介绍,为积极推进通用人工智能产业高质量发展,围绕人工智能…

HCNA VRP基础

交换机可以隔离冲突域,路由器可以隔离广播域,这两种设备在企业网络中应用越来越广泛。随着越来越多的终端接入到网络中,网络设备的负担也越来越重,这时网络设备可以通过专有的VRP系统来提升运行效率。通过路由平台VRP是华为公司数…

Large Language Model系列之二:Transformers和预训练语言模型

Large Language Model系列之二:Transformers和预训练语言模型 1 Transformer模型 Transformer模型是一种基于自注意力机制的深度学习模型,它最初由Vaswani等人在2017年的论文《Attention Is All You Need》中提出,主要用于机器翻译任务。随…

mac安装win10到外接固态硬盘

1、制作win10系统 1.1 下载 winToUSB,打开后选择第一个 1.2 选择本地下载镜像, 我用的分区方案是适用于UEFI的GPT模式 1.3 点右下角执行,等待执行完成即可 2、mac系统下载win驱动 2.1 comman空格 搜索启动转换助理,打开后选择…

Linux shell编程学习笔记64:vmstat命令 获取进程、内存、虚拟内存、IO、cpu等信息

0 前言 在系统安全检查中,通常要收集进程、内存、IO等信息。Linux提供了功能众多的命令来获取这些信息。今天我们先研究vmstat命令。 1.vmstat命令的功能、用法、选项说明和注意事项 1.1 vmstat命令的功能 vmstat是 Virtual Meomory Statistics(虚拟内…

React 实现五子棋

简介 本文将会基于React 实现五子棋小游戏&#xff0c;游戏规则为先让5颗棋子连成1线的一方获胜。 实现效果 技术实现 页面布局 <div><table style{{border: 1px solid #000, borderCollapse: collapse, backgroundColor: lightgray}}><tbody>{squares.ma…

在 Windows 上运行 Linux:WSL2 完整指南(一)

系列文章目录 在 Windows 上运行 Linux&#xff1a;WSL2 完整指南&#xff08;一&#xff09;&#x1f6aa; 在 Windows 上运行 Linux&#xff1a;WSL2 完整指南&#xff08;二&#xff09; 文章目录 系列文章目录前言一、什么是 WSL&#xff1f;1.1 WSL 的主要特性1.2 WSL 的…

STM32智能工业自动化监控系统教程

目录 引言环境准备智能工业自动化监控系统基础代码实现&#xff1a;实现智能工业自动化监控系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;工业自动化与管理问题解决方案与优化收尾与总结 1. 引言 智能…

【unity笔记】十、Obi绳索插件使用

一. 创建绳索 1.1 新建蓝图 在Assets中右键选择创建->Obi->Rope Blueprint&#xff0c;其属性如图所示 1.2 Obi solver 在场景下创建一个obi solver对象&#xff0c;在该对象下再创建Obi Rope对象。 随后将蓝图拖到Obi Rope对象下的Obi Rope组件&#xff0c;即可看到…

【Web服务与Web应用开发】【C#】VS2019 创建ASP.NET Web应用程序,以使用WCF服务

目录 0.简介 1.环境 2.知识点 3.详细过程 1&#xff09;创建空项目 2&#xff09;添加Web表单 3&#xff09;使用Web表单的GUI设计 4&#xff09;添加服务引用 5&#xff09;在Web的button函数中调用服务&#xff0c;获取PI值 6&#xff09;测试 0.简介 本文属于一个…

环境配置|PyCharm——Pycharm本地项目打包上传到Github仓库的操作步骤

一、Pycharm端的设置操作 通过Ctrl+Alt+S快捷组合键的方式,打开设置,导航到版本控制一栏中的Git,在Git可执行文件路径中,输入Git.exe。 按照下图顺序,依次点击,完成测试。输出如图标④的结果,即可完成测试。 输出下图结果,配置Git成功,如本地未安装Git,需自行安装。

设计模式9-工厂模式(Factory Method)

[TOC](工厂模式(Factory Method)) 写在前面 对象创建模式 通过对象超级模式绕开。动态内存分配&#xff08;new)&#xff0c;来避免对象创建过程中所导致的紧耦合(依赖具体类)&#xff0c;从而支持对象创建的稳定&#xff0c;它是结构抽象之后的第一步工作。 典型模式&…

新版本安卓更换下载源解决gradle时间太久问题

老版本android studio 解决方法如下 : android studio gradle:build model执行时间太久 最近又做到安卓的任务了,下载的安卓studio最新版 这个版本的android studio 不能用上面那种老版本的方法了,需要更新方法 新版本需要跟换两个地方 gradle/wrapper/gradle-wrapper.proper…

JAVA 异步编程(异步,线程,线程池)一

目录 1.概念 1.1 线程和进程的区别 1.2 线程的五种状态 1.3 单线程,多线程,线程池 1.4 异步与多线程的概念 2. 实现异步的方式 2.1 方式1 裸线程&#xff08;Thread&#xff09; 2.1 方式2 线程池&#xff08;Executor&#xff09; 2.1.1 源码分析 2.1.2 线程池创建…

三丰云评测:免费虚拟主机与免费云服务器体验

今天我来为大家分享一下我对三丰云的评测。作为一家知名的云服务提供商&#xff0c;三丰云一直以来备受用户好评。他们提供免费虚拟主机和免费云服务器服务&#xff0c;深受网站建设者和开发者的喜爱。 首先谈谈免费虚拟主机服务。三丰云的免费虚拟主机方案性价比非常高&#x…

代码随想录算法训练营第33天|LeetCode 509. 斐波那契数、70. 爬楼梯、746. 使用最小花费爬楼梯

1. LeetCode 509. 斐波那契数 题目链接&#xff1a;https://leetcode.cn/problems/fibonacci-number/ 文章链接&#xff1a;https://programmercarl.com/0509.斐波那契数.html 视频链接&#xff1a;https://www.bilibili.com/video/BV1f5411K7mo 思路&#xff1a; 动态规划步骤…

django-ckeditor富文本编辑器

一.安装django-ckeditor 1.安装 pip install django-ckeditor2.注册应用 INSTALLED_APPS [...ckeditor&#xff0c; ]3.配置model from ckeditor.fields import RichTextFieldcontent RichTextField()4.在项目中manage.py文件下重新执行迁移&#xff0c;生成迁移文件 py…

jvm优化

1.jvm组成 什么是jvm&#xff0c;java是跨平台语言&#xff0c;对不同的平台&#xff08;windos&#xff0c;linux&#xff09;&#xff0c;有不同的jvm版本。jvm屏蔽了平台的不同&#xff0c;提供了统一的运行环境&#xff0c;让java代码无需考虑平台的差异。 jdk包含jre包含…

ValueError和KeyError: ‘bluegrass’的问题解决

项目场景&#xff1a; 项目相关背景&#xff1a; 问题描述 遇到的问题1&#xff1a; KeyError: ‘bluegrass’ 不能识别某标签 遇到的问题2&#xff1a; xml etree.fromstring(xml_str) ValueError: Unicode strings with encoding declaration are not supported. Please …