pytorch学习--第一个模型(线性模型)

news2024/12/27 3:14:20

目标

我们想通过随机初始化的参数 ω , b \omega ,b ω,b能在迭代过程中使预测值和目标值能无限接近
y = ω x + b y=\omega x+b y=ωx+b

定义数据

x = torch.rand([60, 1])*10
y = x*2 + torch.randn(60,1)

构建模型

利用pytorch中的nn.Module
想要构建模型时,继承这个类即可
一些重写nn.Module类时的注意事项
(1)一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中;
(2)一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
(3)forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

from torch import nn
class Lr(nn.Module):
    def __init__(self):
        super(Lr, self).__init__()  #继承父类init的参数
        self.linear = nn.Linear(1, 1) #只有线性层(全链接层)
 
    def forward(self, x):
        out = self.linear(x)#输出
        return out

输出的数量nn.Linear(in_features, out_features);nn.Linear(1, 1)这里的参数易知我们通过方程得到的最终是一列数

# 实例化模型
model = Lr()
# 传入数据,计算结果
predict = model(x)

优化器

1、优化器主要是在模型训练阶段对模型可学习参数进行更新, 常用优化器有 SGD,RMSprop,Adam等
2、优化器初始化时传入传入模型的可学习参数,以及其他超参数如 lr,momentum等
3、在训练过程中先调用 optimizer.zero_grad() 清空梯度,再调用 loss.backward() 反向传播,最后调用 optimizer.step()更新模型参数
4、参数可以使用model.parameters()来获取,获取模型中所有requires_grad=True的参数

optimizer = optim.SGD(model.parameters(), lr=1e-3) #1. 实例化,1e-3也可以写成0.001
optimizer.zero_grad() #2. 梯度置为0
loss.backward() #3. 计算梯度
optimizer.step()  #4. 更新参数的值

损失函数

torch中有很多损失函数

1、均方误差:nn.MSELoss(),常用于回归问题

2、交叉熵损失:nn.CrossEntropyLoss(),常用于分类问题

criterion = nn.MSELoss() # 实例化损失函数

训练模型

1、定义一个epoch,代表需要将所有数据训练epoch个轮次
2、数据传入模型,获取预测值
3、将预测值和目标值传入损失函数,计算损失
4、优化器的梯度归零,在每次更新参数中必须进行此步骤,否则梯度会一直累加
5、计算梯度,此步骤在4之后进行
6、更新梯度,参数随之更新
7、(可选)在训练过程中每隔一段时间打印下损失,观察收敛速度

#训练模型
for i in range(30000):
    out = model(x)  # 3.1 获取预测值
    loss = criterion(y, out)  # 3.2 计算损失
    optimizer.zero_grad()  # 3.3 梯度归零
    loss.backward()  # 3.4 计算梯度
    optimizer.step()  # 3.5 更新梯度
    if i % 300 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(i, 30000, loss.data))

模型测试

在模型的测试中,我们一般会使用测试集来评估训练得到的模型,这时候我们不需要梯度相关的操作,只需要将数据通过模型,得到损失、精确率等即可。测试中有以下需要注意:

model.eval()  # 设置模型为评估模式,即预测模式
predict = model(x)

绘图

predict = predict.data.numpy()
plt.scatter(x.data.numpy(), y.data.numpy(), c="r")
plt.plot(x.data.numpy(), predict)
plt.show()

在GPU上运行

判断GPU是否可用torch.cuda.is_available()

1、torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

device(type=‘cuda’, index=0) #使用gpu
device(type=‘cpu’) #使用cpu

2、把模型参数和input数据转化为cuda的支持类型

model.to(device)
x_true.to(device)

3、在GPU上计算结果也为cuda的数据类型,需要转化为numpy或者torch的cpu的tensor类型

predict = predict.cpu().detach().numpy()
detach()的效果和data的相似,但是detach()是深拷贝,data是取值,是浅拷贝

完整代码

import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt

# 1. 定义数据
x = torch.rand([60, 1])*10
y = x*2 + torch.randn(60,1)


# 2 .定义模型
class Lr(nn.Module):
    def __init__(self):
        super(Lr, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x)
        return out


# 2. 实例化模型,loss,和优化器
model = Lr()
# 损失函数
criterion = nn.MSELoss()
# 优化器
optimizer = optim.SGD(model.parameters(), lr=1e-3)
# 3. 训练模型
for i in range(30000):
    out = model(x)  # 3.1 获取预测值
    loss = criterion(y, out)  # 3.2 计算损失
    optimizer.zero_grad()  # 3.3 梯度归零
    loss.backward()  # 3.4 计算梯度
    optimizer.step()  # 3.5 更新梯度
    if i % 300 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(i, 30000, loss.data))

# 4. 模型评估
model.eval()  # 设置模型为评估模式,即预测模式
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(), y.data.numpy(), c="r")
plt.plot(x.data.numpy(), predict)
plt.show()

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/771105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SylixOS下SSH和SFTP连接

简要 基于网络的连接(telnet,ftp)方便高效,但其是基于明文的通信,容易被窃取、篡改和攻击,存在网络安全问题,尤其在进行远程访问时,穿过复杂未知的公网环境非常危险,为此…

linux之Ubuntu系列(八)用户管理 修改文件权限

修改文件权限 chown 修改拥有者 修改 文件|目录 的拥有者 sudo chown 用户名 文件名|目录 递归修改文件|目录的组 sudo chgrp -R 组名 文件|目录 递归修改文件权限 chmod -R xxx 文件|目录

mongoDB 分组汇总统计-执行语句(亲测)

# 注意 “gl_id” 需要分组的字段名 “gl_idlCount” 分组后获取数量的字段名(可随意命名) db.getCollection(集合名).aggregate([{ "$group": {"_id": {"gl_id": "$gl_id"},"gl_idlCount": { "…

一文读懂FPC(16)- 关于过孔盖油和过孔开窗的区分

FPC系列文章目录 1.什么是FPC 2.什么是R-FPC 3,FPC的基材 4.FPC基材压延铜和电解铜的区别 5,FPC的辅材 6,FPC常见的四种类型 7,FPC的生产流程简介 8,R-FPC的生产流程简介 9,FPC的发展及应用 10&a…

优思学院|六西格玛管理:依据事实的质量管理方式

一个企业的质量管理制度是否规范,也就是质量管理体系是否很完备的问题,要考察管理体系是否还有哪里不尽完美?各部门之间的连系、调整是否能够顺利进行?各自是否达成在质量保证上的任务等,进行质量管理体系的审核&#…

【无线通信模块】什么是PCB板载天线,PCB板载天线UART/USB接口WiFi模块

基于射频技术的无线模块需要通过天线来发射和接收电磁波信号,市场上常见的天线类型有陶瓷天线、板载天线、棒状天线以及外接天线,外接天线是通过在PCB板上预留IPEX座子,可选天线类型就比较多。本篇SKYLAB小编带大家了解一下板载天线的UART接口…

ROS:pluginlib

目录 一、前言二、概念三、作用四实际用例4.1需求4.2流程4.3准备4.4创建基类4.5创建插件4.6注册插件4.7构建插件库4.8使插件可用于ROS工具链4.8.1配置xml4.8.2导出插件 4.9使用插件4.10执行 一、前言 pluginlib直译是插件库,所谓插件字面意思就是可插拔的组件&…

vue使用docxtemplater导出word实现使用textarea输入的内容换行

注:本文只做导出word并且换行操作,不做vue引入docxtemplater步骤 先看一下实现效果 这是文本域输入的 这是导出来的结果 可以看出来导出来的结果也是换行的呢 接下来我们手摸手操作一下流程 首先咱们捋一捋思路 知道文本域的换行的换行标识符,我们发…

Spingboot 多模块引入第三方jar包

1. 在需要的模块中引入jar包 2. 在此模块中的pom.xml 中引用 3. 要想打包部署服务器&#xff0c;需要在启动模块中添加配置信息 ps&#xff1a;启动模块要引用此模块才能将此一起jar打包部署 <build><plugins><plugin><groupId>org.springframework.…

对链表进行插入排序

给定单个链表的头 head &#xff0c;使用 插入排序 对链表进行排序&#xff0c;并返回 排序后链表的头 。 插入排序 算法的步骤: 插入排序是迭代的&#xff0c;每次只移动一个元素&#xff0c;直到所有元素可以形成一个有序的输出列表。 每次迭代中&#xff0c;插入排序只从输…

刷新vue项目后,在非routerview中的组件获取路由路径永远是“/“的解决方案

问题&#xff1a; //此文件时项目导航栏&#xff0c;不属于router-view的内容 //route.fullPath表示当前的页面路径onMounted(() > {setTimeout(() > {console.log("100ms________", route.fullPath);}, 100);setTimeout(() > {console.log("500ms___…

Comparator.comparing()实现中文排序及空指针处理

一、 Comparator.comparing()的用法请详见以下上一篇文章的汇总介绍。 Comparator用法_乞力马扎罗の黎明的博客-CSDN博客 二、应用示例&#xff1a; 1、中文排序、空值处理 Collator instance Collator.getInstance(Locale.CHINA); checkItemVoList.stream().sorted(Compar…

微服务保护——Sentinel【快速入门】

一、雪崩问题&#x1f349; (一) 什么是雪崩&#x1f95d; 微服务调用链路中的某个服务故障&#xff0c;引起整个链路中的所有微服务都不可用&#xff0c;这就是雪崩。服务D故障引起服务A故障&#xff0c;服务A引起其他服务故障&#xff0c;渐渐导致所有微服务都不可用。有人…

【深入探究人工智能】:历史、应用、技术与未来

深入探究人工智能 前言人工智能的历史人工智能的应用人工智能的技术人工智能的未来当代的人工智能产物结语&#x1f340;小结&#x1f340; &#x1f389;博客主页&#xff1a;小智_x0___0x_ &#x1f389;欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 &am…

Linux云服务器,docker compose文件部署多个jar,docker部署多模块boot项目

前提条件 Linux服务器 服务器已经安装docker docker已经安装jdk镜像 docker已经安装mysql镜像 将要部署的项目的jar包打包好&#xff0c;项目是多模块springboot项目 部署过程 项目是3个模块的Spring boot项目&#xff0c;打出来3个jar&#xff0c;将这些jar包拷贝到…

小米手把手教你轻松搞定复杂需求!聊聊商家与店铺的关系优化方案

大家好&#xff0c;我是小米&#xff01;今天要和大家分享一次与产品大佬张小姐的有趣对话和我所面对的一项“小需求”。废话不多说&#xff0c;让我们一起来看看如何应对这个需求挑战吧&#xff01; 系统现状 在开始之前&#xff0c;先让我们简单回顾一下目前系统的现有功能…

echarts实现渐变折线图并添加点击事件

折线图点击事件代码: let myChart = this.$echarts.init(document.getElementById(trendBoxECharts))myChart.getZr().on(click, params => {console.log(params)let pointInPixel = [params.offsetX, params.offsetY]if (myChart.containPixel(grid, pointInPixel)) {//点…

云迁移第二波热潮来袭,你准备好了吗?

最近&#xff0c;云迁移再次被频繁提及&#xff0c;企业对云迁移的需求量有回升趋势&#xff0c;究其根本&#xff0c;主要有以下原因&#xff1a; 企业数字化进程加速&#xff0c;本地上云需求强劲 根据《2021中国企业上云指数洞察报告》&#xff0c;我国实体经济上云渗透率…

vue 当新增样式无法生效的情况下如何处理

使用scoped属性时&#xff0c;会遇到样式问题。需要使用样式穿透解决 <style lang"scss" scoped> </style> 可以使用以下方法 &#xff1a;deep css 使用 >>> less 使用 /deep/ scss 使用 ::v-deep 代码写法如下: .a :deep(.b) { } .…

Word文档突然无法打开?如何修复损坏文档?

在工作学习中&#xff0c;通常会遇到这种情况&#xff0c;我们正在编辑Word文件&#xff0c;电脑忽然断电关机&#xff0c;或者死机需要重启。当电脑重启以后&#xff0c;辛辛苦苦编辑很久的Word文件却忽然打不开了&#xff01;一直提示文件错误&#xff0c;如何解决Word无法打…