Pytorch最最适合研究生的入门教程,Q3 开始训练

news2024/11/29 10:58:33

文章目录

  • Pytorch最最适合研究生的入门教程
    • Q3 开始训练
      • 3.1 训练的见解
      • 3.2 Pytorch基本训练框架
        • work

Pytorch最最适合研究生的入门教程

Q3 开始训练

3.1 训练的见解

如何理解深度学习能够完成任务? 考虑如下回归问题
由函数 y = f ( x ) y=f(x) y=f(x)采样得到的100个点
1 x l n x + 0.65 ∣ l n ( x + x 2 − l g ( x + 1 ) ) ∣ + 0.05 R ( t ) \frac{1}{x}lnx+0.65\lvert ln(x+x^2 - lg(x+1)) \rvert + 0.05R(t) x1lnx+0.65ln(x+x2lg(x+1))∣+0.05R(t)
其中 R ( t ) R(t) R(t)函数用于生成0-1的随机数
在这里插入图片描述

而我们在回归任务中主要有两个
①通过前80个点进行训练,推理得到后20个点
②通过训练100个点中随机80个点,推理其余20个点的值
以上①属于外推任务,②属于内插任务

内插
内插是指利用已知数据点来预测或估计已知数据点之间的值

①仅限于已知数据点的范围内,即预测已知数据之间的值。
②由于数据点是已知的,内插通常比外推更可靠,因为预测的值更接近实际值。
③内插常用于插值计算,例如在绘图、科学计算和工程领域。

外推
外推是指利用已知数据点来预测或估计未知数据点,尤其是那些位于已知数据点之外的点的值。

①通常用于预测已知数据点之外的值,即向数据范围的更远处进行预测。
②因为预测的是未知区域,所以外推通常伴随着较高的不确定性,结果可能不太可靠。
③在外推中,可能会使用曲线拟合、回归分析或更复杂的数学模型来预测趋势。

以下是个人理解,
相对来说,深度学习更加适合内插任务。

比如 1, 5, 10, 30, 50,预测下一个数
和 1, 10, 30, 50,预测第二个数,其难度是完全不一样的

当数据合适且都处于内插范围,即使是网络结构简单,都能有不错的效果
这项结论在CV、NLP任务中也绝对是成立的,即当训练集基本涵盖了所有可能出现的特征时,预测其余特征的难度会大幅度下降。这一点体现了神经网络的记忆性
而在针对外推等先验信息不足的任务的适合,任何结构的神经网络推理能力都是有限的!
所以,针对内插任务,我们考虑模型函数
P = g ( X , W ) P = g(X, W) P=g(X,W)
其中 P P P为神经网络的输出, X X X为模型输入(特征向量), W W W为所有参数的集合
当我们满足以下关系

X
Model
Function
Y
P

如果满足
P → Y P \to Y PY
则可以说在 U ˚ ( X , δ ) \mathring{U}(X, \delta) U˚(X,δ)满足
M o d e l → F u n c t i o n Model \to Function ModelFunction
此时称模型训练结束,且得到模型为精度最优模型
但实际训练过程中, 基本采用 P → Y + r ( X , Y ) P \to Y+r(X,Y) PY+r(X,Y)作为目标函数
其中 r ( X , Y ) r(X,Y) r(X,Y)损失函数
则我们最终优化式为
a r g m i n W r ( X , Y ) \mathop{argmin}\limits_{W} {r(X,Y)} Wargminr(X,Y)
而神经网络的训练过程就是通过梯度下降算法来式式子最小


3.2 Pytorch基本训练框架

我们这里规定,所有的训练代码,基本都要符合如下训练框架。而后续我们的教程也是围绕这个基本框架展开

模型训练
载入批数据
前馈得到结果
计算损失
反向传播
载入数据
载入模型
载入优化器
载入损失函数
结束训练

对应以上框架,写出最最最基础的代码

import torch
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

# 超参数设置
batches = 16  # 批大小
lr = 1e-3  # 学习率
epochs = 100

# 创造数据
X = torch.linspace(0, 1, 10000).reshape(-1, 10)
Y = torch.sigmoid(X).mean(dim=1, keepdim=True) + 0.05 * torch.rand(X.shape[0], 1)

# 创建移入Dataset
dataset = TensorDataset(X, Y)

# 创建移入DataLoad
dataloader = DataLoader(dataset, batch_size=batches)

# 创建模型
model = torch.nn.Sequential(
    torch.nn.Linear(10, 128),
    torch.nn.Sigmoid(),
    torch.nn.Linear(128, 128),
    torch.nn.Sigmoid(),
    torch.nn.Linear(128, 1)
)

# 创建优化器
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# 创建损失函数
criterion = torch.nn.MSELoss()

# 训练
for epoch in range(epochs):
    for idx, data in enumerate(dataloader):
        x, y = data
        p = model(x)

        loss = criterion(p, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# 结束训练
with torch.no_grad():
    P = model(X)
loss = criterion(P, Y)
print(f'Total Loss: {float(loss):.3f}')

plt.plot(P, label='prediction')
plt.plot(Y, label='Evaluation', ls='--')
plt.plot(torch.abs(P - Y), label='Absolute Loss')
plt.legend()
plt.show()

Total Loss: 0.001
在这里插入图片描述

其中某些参数的解释

参数名词解析
batches批大小指一次前馈中用于训练的样本数量(加速训练)
lr学习率学习率指梯度下降过程中的超参数
epochs迭代次数指总共模型迭代次数
datasetTorch中数据集类训练中使用dataloader取出dataset的数据
dataloaderTorch中数据迭代类训练中每次取出(batches)个样本
work

将Q2中work中的模型运用起来,修改参数后使用iris数据集进行训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2189141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

现在的新电脑在任务管理器里又多了个NPU?它是啥?

前言 今年中旬各家品牌的新笔记本感觉上都是很不错,搞得小白自己心痒痒,突然间想要真的买一台Windows笔记本来耍耍了。 但今天这个文章并不是什么商品宣传啥的,而是小白稍微尝试了一下新笔记本之后的一些发现。 在今年的新笔记本上都多了一…

【GESP】C++一级练习BCQM3025,输入-计算-输出-6

题型与BCQM3024一样,计算逻辑上稍微复杂了一点点,代码逻辑没变,仍属于小学3,4年级的题目水平。 题解详见:https://www.coderli.com/gesp-1-bcqm3025/ https://www.coderli.com/gesp-1-bcqm3025/https://www.coderli.c…

数据提取之JSON与JsonPATH

第一章 json 一、json简介 json简单说就是javascript中的对象和数组,所以这两种结构就是对象和数组两种结构,通过这两种结构可以表示各种复杂的结构 > 1. 对象:对象在js中表示为{ }括起来的内容,数据结构为 { key&#xff1…

最新版本SkyWalking【10.1.0】部署

这里写目录标题 前言前置条件启动Skywalking下载解压启动说明 集成Skywalking Agent下载Agent在IDEA中添加agent启动应用并访问SpringBoot接口 说明 前言 基于当前最新版10.1.0搭建skywalking 前置条件 装有JDK11版本的环境了解SpringBoot相关知识 启动Skywalking 下载 地…

浑元换算策略和武德换算策略-《分析模式》漫谈36

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 “Analysis Patterns”的第3章有这么一句: A conversion, however deterministic, does not follow that faithfully. 2004(机械工业出版社)中译本…

HTB:Explosion[WriteUP]

目录 连接至HTB服务器并启动靶机 1.What does the 3-letter acronym RDP stand for? 2.What is a 3-letter acronym that refers to interaction with the host through a command line interface? 3.What about graphical user interface interactions? 4.What is the…

【MySQL 08】复合查询

目录 1.准备工作 2.多表查询 笛卡尔积 多表查询案例 3. 自连接 4.子查询 1.单行子查询 2.多行子查询 3.多列子查询 4.在from子句中使用子查询 5.合并查询 1.union 2.union all 1.准备工作 如下三个表,将作为示例,理解复合查询 EMP员工表…

深入探究:在双链表的前面进行插入操作的顺序

归纳编程学习的感悟, 记录奋斗路上的点滴, 希望能帮到一样刻苦的你! 如有不足欢迎指正! 共同学习交流! 🌎欢迎各位→点赞 👍 收藏⭐ 留言​📝惟有主动付出,才有丰富的果…

一次解决Go编译问题的经过

用Go语言编写了一个小的项目,项目开发环境是在本地的Windows环境中,一切单元测试和集成测试通过后,计划将项目部署到VPS服务器上自动运行,但在服务器上执行go run运行时,程序没有任何响应和回显,甚至main函…

有没有一款软件,可以在二楼电脑直接唤醒三楼的电脑?

前言 今天有个小姐姐找到我,咨询能不能在二楼的电脑直接访问到三楼电脑的资料。 这个肯定是可以的啊! 其实事情很简单,只需要弄好共享文件夹这个功能,只要手机、平板或者电脑在同个局域网下,就能访问到三楼电脑里的…

深入理解Dubbo源码核心原理-Part4

现在开始研究,消费端真正调用proxy的方法时,走的rpc调用 接下来就要走client,发送request请求了 Dubbo协议是怎样的呢? 具体每个字段什么含义请参照官网 链接:Dubbo协议头含义 编码器按照Dubbo协议来进行编码请求 Ne…

JVM内存回收机制

目录 1.JVM运行时数据区 2.JVM类加载过程 3.双清委派模型 4.垃圾回收机制(GC) 找出谁是垃圾方案一:引用计数 找出谁是垃圾:方案二,可达性分析 释放垃圾的内存空间 判断垃圾:jvm依据对象的年龄对 对象…

基于Zynq SDIO WiFi移植三(支持2.4/5G)

应用问题-WIFI作为AP-hostapd多次连接 设备作为WIFI热点时,连接出现了下述问题: 1 手机连接需要三次,三次都需要输入密码; 2 平板连接需要三次,三次都需要输入密码; 3 电脑连接需要一次,无感…

隧道人员定位UWB双通道定位终端

大家好,我是华星智控小智,今天我给大家介绍我们的UWB双通道定位终端。 双通道定位终端(型号STD)主要用于隧道人员或天车定位,终端基于无线脉冲技术,采用双天线设计,可实现对2路方向的测距定位&a…

实施威胁暴露管理、降低网络风险暴露的最佳实践

随着传统漏洞管理的发展,TEM 解决了因攻击面扩大和安全工具分散而产生的巨大风险。 主动式 TEM 方法优先考虑风险并与现有安全工具无缝集成,使组织能够在威胁被有效利用之前缓解威胁。 为什么威胁暴露管理 (TEM) 在现代网络安全策略中变得至关重要&…

使用模拟和真实的 Elasticsearch 来测试你的 Java 代码

作者:来自 Elastic Piotr Przybyl 在本文中,我们将介绍并解释两种使用 Elasticsearch 作为外部系统依赖项来测试软件的方法。我们将介绍使用模拟测试和集成测试的测试,展示它们之间的一些实际差异,并给出一些关于每种风格的提示。…

嵌入式C语言自我修养:编译链接

源文件生成可执行文件的过程? 源文件经过预处理、编译、汇编、链接生成一个可执行的目标文件。 编译器驱动程序,包括预处理器、编译器、汇编器和链接器。Linux用户可以调用GCC驱动程序来完成整个编译流程。 使用GCC驱动程序将示例程序从ASCII码源文件转换…

如何使用EventChannel

文章目录 1 知识回顾2 示例代码3 经验总结我们在上一章回中介绍了MethodChannel的使用方法,本章回中将介绍EventChannel的使用方法.闲话休提,让我们一起Talk Flutter吧。 1 知识回顾 我们在前面章回中介绍了通道的概念和作用,并且提到了通道有不同的类型,本章回将其中一种…

仿RabbitMQ实现消息队列服务端(一)

文章目录 交换机数据管理队列数据管理绑定信息(交换机-队列)管理队列消息管理虚拟机管理交换机路由管理队列消费者/订阅者管理 整体框架:工具模块及项目整体模块框架 交换机数据管理 交换机数据管理就是描述了交换机应该有哪些数据 定义交换机数据类 1、交换机的名…

Linux忘记root用户密码怎么重设密码

直接说步骤: 1.重启客户机 2.在选择内核页面快速按e键,进入编辑模式 进入后应该是这个样子 在这里只能按上下键切换行 找到Linux16这里 3.按右方向键切换到行尾,也就是UTF-8处,在后面添加一个空格,然后加上这段话 …