2 线性回归demo数据-参数配置|训练回归模型|常见的tensor格式|Hub模块简介|气温数据集与任务介绍

news2025/1/13 7:58:10

文章目录

    • 线性回归demo数据-参数配置
    • 线性回归demo数据-训练回归模型
    • 常见的tensor格式
      • scalar
      • vector
      • matrix

线性回归demo数据-参数配置

# 先传入数据,可以是手动定义,也可以导入,这边就直接拿一条直线y=2x+1,来进行模拟了。
# 构造x和y(x的值为0-10)(y的值为2x+1)
x_value = [i for i in range(11)]
y_value = [2*i+1 for i in x_value]
# 由于都是列表(数组),也就是需要转化为torch可以接受的模式
# np.array是将一个数组转化为张量.
# reshape即将张量的格式进行修改
x_train = np.array(x_value,dtype=np.float32)
y_train = np.array(y_value,dtype=np.float32)
x_train = x_train.reshape(-1,1)
y_train = y_train.reshape(-1,1)
# 导入包
import torch
import torch.nn as nn
# 线性回归模型
class LinearRegressionModel(nn.Module) :
    def __init__(self,input_dim,out_dim):
        super(LinearRegressionModel,self).__init__()
        # 这里是构造一个全连接层
        self.linear = nn.Linear(input_dim,out_dim)
    # 创建x与y的构建,即此处判定了输出入x即得到y的一条线性链
    def forward(self,x):
        y = self.linear(x)
        return y

input_dim = 1
out_dim = 1
model = LinearRegressionModel(input_dim,out_dim)
# 指定好参数以及损失函数
# epochs训练次数 learning_rate学习率
# SGD是一个优化器,前一个参数代表着需要优化的参数,后一个是学习率,具体一点他是一个随机梯度优化器
# MSELoss是一个损失函数,目的是用来计算出L的
epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)
criterion = nn.MSELoss()

关于linear的理解:链接
关于SGD的理解:链接
关于MSELoss损失函数的理解:链接

线性回归demo数据-训练回归模型

# 指定好参数以及损失函数
# epochs训练次数 learning_rate学习率
# SGD是一个优化器,前一个参数代表着需要优化的参数,后一个是学习率,具体一点他是一个随机梯度优化器
# MSELoss是一个损失函数,目的是用来计算出L的
epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)
criterion = nn.MSELoss()
# 训练回归模型
for epho in range(epochs):
    epho+=1
    # 转化为tensor
    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)
    # 梯度每一步都需要清零
    optimizer.zero_grad()
    # 前向传播
    outputs = model.forward(inputs)
    # 计算损失值
    loss = criterion(outputs,labels)
    # 反向传播
    loss.backward()
    # 更新权重参数
    optimizer.step()

测试数据:

# 测试模型预测的结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
print(predicted)

输出结果:
在这里插入图片描述

# 模型的保存以及提取
torch.save(model.state_dict(),'model.pkl')
model.load_state_dict(torch.load('model.pkl'))

总结以及关于GPU的训练:(仅有两处不同)

# 线性回归
# 先传入数据,可以是手动定义,也可以导入,这边就直接拿一条直线y=2x+1,来进行模拟了。
# 构造x和y(x的值为0-10)(y的值为2x+1)
x_value = [i for i in range(11)]
y_value = [2*i+1 for i in x_value]
# 由于都是列表(数组),也就是需要转化为torch可以接受的模式
# np.array是将一个数组转化为张量.
# reshape即将张量的格式进行修改
x_train = np.array(x_value,dtype=np.float32)
y_train = np.array(y_value,dtype=np.float32)
x_train = x_train.reshape(-1,1)
y_train = y_train.reshape(-1,1)
# 导入包
import torch
import torch.nn as nn
# 线性回归模型
class LinearRegressionModel(nn.Module) :
    def __init__(self,input_dim,out_dim):
        super(LinearRegressionModel,self).__init__()
        # 这里是构造一个全连接层
        self.linear = nn.Linear(input_dim,out_dim)
    # 创建x与y的构建,即此处判定了输出入x即得到y的一条线性链
    def forward(self,x):
        y = self.linear(x)
        return y

input_dim = 1
out_dim = 1
model = LinearRegressionModel(input_dim,out_dim)
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~不同处1~~~~~~~~~~~~~~~~~~~~~~~~~~~
device = torch.device("cuda:0"if torch.cuda.is_available() else "cpu")
model.to(device)
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# 指定好参数以及损失函数
# epochs训练次数 learning_rate学习率
# SGD是一个优化器,前一个参数代表着需要优化的参数,后一个是学习率,具体一点他是一个随机梯度优化器
# MSELoss是一个损失函数,目的是用来计算出L的
epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)
criterion = nn.MSELoss()
# 训练回归模型
for epho in range(epochs):
    epho+=1
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~不同点2~~~~~~~~~~~~~~~~~~~
    # 转化为tensor
    inputs = torch.from_numpy(x_train).to(device)
    labels = torch.from_numpy(y_train).to(device)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # 梯度每一步都需要清零
    optimizer.zero_grad()
    # 前向传播
    outputs = model.forward(inputs)
    # 计算损失值
    loss = criterion(outputs,labels)
    # 反向传播
    loss.backward()
    # 更新权重参数
    optimizer.step()
    # if epho%100==0 :
    #     print(loss.item())
# 测试模型预测的结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
print(predicted)
# 模型的保存以及提取
torch.save(model.state_dict(),'model.pkl')
model.load_state_dict(torch.load('model.pkl'))

常见的tensor格式

常见的tensor格式主要有下面四种:

  1. scalar
  2. vector
  3. matrix
  4. n-dimensional tensor

scalar

代表着一个值:

x = tensor(42)
print(x)
print(x.dim())# 维度
print(x.item())

输出结果:
在这里插入图片描述

vector

x = tensor([1,2,3])
print(x)
print(x.dim())
print(x.size())

输出结果:
在这里插入图片描述

matrix

基本上也是一致的,只是维度扩展了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/180211.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一篇五分生信临床模型预测文章代码复现——Figure 7 外部数据集验证模型

之前讲过临床模型预测的专栏,但那只是基础版本,下面我们以自噬相关基因为例子,模仿一篇五分文章,将图和代码复现出来,学会本专栏课程,可以具备发一篇五分左右文章的水平: 本专栏目录如下: Figure 1:差异表达基因及预后基因筛选(图片仅供参考) Figure 2. 生存分析,…

python爬虫学习笔记-SQL学习

Sql概述 先来看一个例子:小王第一次使用数据库,然后跟数据库来了个隔空对话 其实,我们想一想,mysql是一个软件,它有它自己一套的管理规则,我们想要跟它打交道,就必须遵守它的规则,如…

【stl -- 常用算法】

目录:前言一、遍历算法for_each、transform二、查找、统计算法find、find_ifadjacent_findbinary_searchcount、count_if三、排序算法sortrandom_shufflemergereverse拷贝、替换算法copyreplace、replace_ifswap算数生成算法accumulatefill集合算法set_intersection…

Day10 @Import整合第三方框架原理

1 前言Spring与MyBatis注解方式整合有个重要的技术点就是lmport,第三方框架与Spring整合xml方式很多是凭借自定义标签完成的,而第三方框架与Spring整合注解方式很多是靠import注解完成的。然后Import可以导入如下三种类:普通的配置类&#xf…

【蓝桥杯】历届真题 画廊(决赛)Java

【资源限制】 内存限制:256.0MB C/C时间限制:1.0s Java时间限制:3.0s Python时间限制:5.0s 【问题描述】 小蓝办了一个画展,在一个画廊左右两边陈列了他自己的作品。为了使画展更有意思,小…

英语学习打卡day6

2023.1.26 1.promiscuous adj.混杂的;杂乱的;滥交的 pro(往前)misc(mix):在混乱上勇往直前 2.susceptible adj.易受影响(或伤害等);敏感;过敏;感情丰富的;善感的 accept(抓)接受 be susceptible to对…敏感 She isn…

【数据结构】7.3 树表的查找

文章目录7.3.1 二叉排序树1. 二叉排序树的定义2. 二叉排序树的查找二叉排序树算法二叉排序树算法分析3. 二叉排序树的插入4. 二叉排序树的生成5. 二叉排序树的删除7.3.2 平衡二叉树1. 平衡二叉树的定义2. 平衡二叉树的平衡调整方法LL型调整RR型调整LR型调整RL型调整3. 构造平衡…

C#手动操作DataGridView之------使用各种数据源填充表格实例

C#中的表格控件只有一个,那就是datagridview,不像QT中可以用QTableview,QTableWidget。新手拿到datagridview的第一个问题就是数据从哪里来?难道从设计器中一个个手动输入,到时候要变怎办?所以,…

Python3 PIL处理任意尺寸图片为1920*1080 图片模糊 虚化 图片合并居中叠加

各位好,我是宋哈哈,很久没更新文章了,其实这篇代码是我在年前已经写好了。代码呢,也比较冷门适合人很少。仅仅对会AE , PR 视频剪辑, 又要会 python 的人,而且在公司领导又要你来做相册视频,在公…

MSBuild 命令行编译Delphi

为了构建项目,IDE现在使用MSBuild而不是以前的内部生成系统。IDE中的build、compile和make命令调用Microsoft的新生成引擎:MSBuild,它提供了全面的依赖性检查。MSBuild项目文件基于XML,包含描述项目的特定项、属性、任务和目标的部…

[GWCTF 2019]枯燥的抽奖

目录 信息收集 知识回顾 解题思路 信息收集 查看源码&#xff0c;发现check.php <?php #这不是抽奖程序的源代码&#xff01;不许看&#xff01; header("Content-Type: text/html;charsetutf-8"); session_start(); if(!isset($_SESSION[seed])){ $_SESSIO…

html表格

1.基本标签 标签名说明table表示整体&#xff0c;用于包裹多个trtr表格每行&#xff0c;用于包裹tdtd表格单元格&#xff0c;用于包裹内容 注意点&#xff1a; 表格嵌套关系&#xff1a;table>tr>td 表格table的常见属性&#xff1a; 修饰table属性的标签 需要写道tab…

【FA-GAN:超分辨率MRI图像】

FA-GAN: Fused attentive generative adversarial networks for MRI image super-resolution &#xff08;FA-GAN&#xff1a;融合注意生成对抗网络的MRI图像超分辨率&#xff09; 高分辨率磁共振图像可以提供细粒度的解剖信息&#xff0c;但是获取这样的数据需要长的扫描时间…

UVA11426 - GCD - Extreme (II)(数论,欧拉函数)

题目链接&#xff1a;GCD - Extreme (II) - UVA 11426 - Virtual Judge (vjudge.net)​​​​​ 题意 给一个数N&#xff0c;求&#xff1a; ​​​​​​​ 其中&#xff0c;多组输入&#xff0c;输入以0结束&#xff0c;保证答案在long long范围内。 思路 很好的一道题…

移动web主轴设置和flex总结

移动web主轴设置和flex总结设置主轴方向修改主轴经常的使用场景&#xff1a;弹性盒子换行设置侧轴对齐方式flex 总结梳理主轴排列方式侧轴对齐方式-单行对齐侧轴对齐方式-多行弹性盒子换行设置主轴方向伸缩比设置主轴方向 主轴默认是水平方向, 侧轴默认是垂直方向 修改主轴方…

Linux常用命令——semanage命令

在线Linux命令查询工具(http://www.lzltool.com/LinuxCommand) semanage 默认目录的安全上下文查询与修改 补充说明 semanage命令是用来查询与修改SELinux默认目录的安全上下文。SELinux的策略与规则管理相关命令&#xff1a;seinfo命令、sesearch命令、getsebool命令、set…

天地图矢量注记图坑

http://lbs.tianditu.gov.cn/server/MapService.html瓦片图案例见下文&#xff0c;注意其中的LAYER:&#xff0c;记住口诀&#xff0c;地址里用什么&#xff0c;这个layer就用什么。比如影像注记里&#xff0c;地址用了cia_w&#xff0c;那么这个layer后面必须是cia_w&#xff…

别总写代码,这130个网站比涨工资都重要

今天推荐一些学习资源给大家&#xff0c;当然大家可以留言评论自己发现的优秀资源地址 搞学习 找书籍 冷知识 / 黑科技 写代码 资源搜索 小工具 导航页&#xff08;工具集&#xff09; 看视频 学设计 搞文档 找图片 搞学习 TED&#xff08;最优质的演讲&#xff09;&#xff1…

解析JVM类加载器

文章目录1、何为类加载器2、三层类加载器3、双亲委派模型参考资料&#xff1a;《深入理解Java虚拟机》 1、何为类加载器 类加载过程中&#xff0c;加载阶段第一步操作就是通过一个类的全限定名获取此类的二进制字节流。实现这个动作的代码就是类加载器。 任意一个类都必须由加…

mybatis-plus1(前言技术)

目录 一、Mybatis-plus入门 1.什么是mybatis-plus 2.初体验 ① 准备数据库脚本 ② 初始化工程 ③ 编码 ④ 开始使用 3.日志 二、Mybatis-plus主键生成策略 1.更新 三、Mybatis-plus自动填充 1&#xff1a;通过数据库完成自动填充 2&#xff1a;使用程序完成自…