Pytorch基本概念和使用方法

news2024/9/22 9:45:27

目录

1 Adam及优化器optimizer(Adam、SGD等)是如何选用的?

1)Momentum

2)RMSProp

3)Adam

2 Pytorch的使用以及Pytorch在以后学习工作中的应用场景。

1)Pytorch的使用

2)应用场景

3 不同的数据、数据集加载方式以及加载后各部分的调用处理方式。如DataLoder的使用、datasets内置数据集的使用。

4 如何加快训练速度以及减少GPU显存占用

技巧1:inplace=True

技巧2:with torch.no_grad():

技巧3:forward中的变量命名

技巧4:Dataloader数据读取

技巧5:gradient accumulation


1 Adam及优化器optimizer(Adam、SGD等)是如何选用的?

深度学习的优化算法主要有GD,SGD,Momentum,RMSProp和Adam算法。Adam是一种计算每个参数的自适应学习率的方法。相当于 RMSprop + Momentum。

在讲这个算法之前说一下移动指数加权平均。移动指数加权平均法加权就是根据同一个移动段内不同时间的数据对预测值的影响程度,分别给予不同的权数,然后再进行平均移动以预测未来值。假定给定一系列数据值那么,我们根据这些数据来拟合一条曲线,所得的值就是如下的公式:

其中,在上面的公式中,β等于历史值的加权率。根据这个公式我们可以根据给定的数据,拟合出下图类似的一条比较平滑的曲线。

1)Momentum

通常情况我们在训练深度神经网络的时候把数据拆解成一小批地进行训练,这就是我们常用的mini-batch SGD训练算法,然而虽然这种算法能够带来很好的训练速度,但是在到达最优点的时候并不能够总是真正到达最优点,而是在最优点附近徘徊。另一个缺点就是这种算法需要我们挑选一个合适的学习率,当我们采用小的学习率的时候,会导致网络在训练的时候收敛太慢;当我们采用大的学习率的时候,会导致在训练过程中优化的幅度跳过函数的范围,也就是可能跳过最优点。我们所希望的仅仅是网络在优化的时候网络的损失函数有一个很好的收敛速度同时又不至于摆动幅度太大。

所以 Momentum 优化器刚好可以解决我们所面临的问题,它主要是基于梯度的移动指数加权平均。假设在当前的迭代步骤第 t 步中,那么基于 Momentum 优化算法可以写成下面的公式:

其中,在上面的公式中vdwvdb分别是损失函数在前 t-1 轮迭代过程中累积的梯度动量,β是梯度累积的一个指数,这里我们一般设置值为0.9。所以Momentum优化器的主要思想就是利用了类似于移动指数加权平均的方法来对网络的参数进行平滑处理的,让梯度的摆动幅度变得更小。

dW和db分别是损失函数反向传播时候所求得的梯度,下面两个公式是网络权重向量和偏置向量的更新公式,α是网络的学习率。当我们使用Momentum优化算法的时候,可以解决mini-batch SGD优化算法更新幅度摆动大的问题,同时可以使得网络的收敛速度更快。

2)RMSProp

RMSProp算法的全称叫 Root Mean Square Prop,是Geoffrey E. Hinton在Coursera课程中提出的一种优化算法,在上面的Momentum优化算法中,虽然初步解决了优化中摆动幅度大的问题。所谓的摆动幅度就是在优化中经过更新之后参数的变化范围,如下图所示,蓝色的为Momentum优化算法所走的路线,绿色的为RMSProp优化算法所走的路线。

为了进一步优化损失函数在更新中存在摆动幅度过大的问题,并且进一步加快函数的收敛速度,RMSProp算法对权重 W 和偏置 b 的梯度使用了微分平方加权平均数。 其中,假设在第 t 轮迭代过程中,各个公式如下所示:

算法的主要思想就用上面的公式表达完毕了。在上面的公式中sdw和sdb分别是损失函数在前 t−1 轮迭代过程中累积的梯度动量,β是梯度累积的一个指数。所不同的是,RMSProp算法对梯度计算了微分平方加权平均数。这种做法有利于消除了摆动幅度大的方向,用来修正摆动幅度,使得各个维度的摆动幅度都较小。另一方面也使得网络函数收敛更快。(比如当 dW或者 db中有一个值比较大的时候,那么我们在更新权重或者偏置的时候除以它之前累积的梯度的平方根,这样就可以使得更新幅度变小)。为了防止分母为零,使用了一个很小的数值 来进行平滑,一般取值为10的负八次方。

3)Adam

有了上面两种优化算法,一种可以使用类似于物理中的动量来累积梯度,另一种可以使得收敛速度更快同时使得波动的幅度更小。那么将两种算法结合起来所取得的表现一定会更好。Adam(Adaptive Moment Estimation)算法是将Momentum算法和RMSProp算法结合起来使用的一种算法。

很多论文里都会用 SGD,没有 Momentum 等。SGD 虽然能达到极小值,但是比其他算法用的时间长,而且可能会被困在鞍点

如果需要更快的收敛,或者是训练更深更复杂的神经网络,需要用一种自适应的算法。

整体来讲,Adam 是最好的选择。

2 Pytorch的使用以及Pytorch在以后学习工作中的应用场景。

1)Pytorch的使用

①安装pytorch

②使用Spyder创建一个project,点击Projects--->New Project

③在其中输入project名称,选择项目地址就ok了,比如我们创建Handwritten_numeral_recognition(手写数字识别)

④创建一个module,创建一个test.py。

⑤输入import torch,就可以开始pytorch的使用了。

2)应用场景

①医疗

医学图像分割

基于U-net的医学影像分割

通过Pytorch深度学习框架,编写分割脑部解剖结构程序。

②工业

比如通过Pytorch深度学习框架,编写设备的剩余寿命预测、故障诊断程序。

3 不同的数据、数据集加载方式以及加载后各部分的调用处理方式。如DataLoder的使用、datasets内置数据集的使用。

- dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问;

- 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问;

- 也可以使用`for inputs, labels in dataloaders`进行可迭代对象的访问;

- 一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;

pytorch 的数据加载到模型的操作顺序是这样的:

① 创建一个 Dataset 对象

② 创建一个 DataLoader 对象

③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练

dataset = MyDataset()

dataloader = DataLoader(dataset)

num_epoches = 100

for epoch in range(num_epoches):

    for img, label in dataloader:

        ....

所以,作为直接对数据进入模型中的关键一步, DataLoader非常重要。

4 如何加快训练速度以及减少GPU显存占用

到底什么在占用显存?

输入的数据占用空间其实并不大,比如一个(256, 3, 100, 100)的Tensor(相当于batchsize=256的100*100的三通道图片。)只占用31M显存。

实际上,占用显存的大头在于:1. 动辄上千万的模型参数;2. 模型中间变量;3. 优化器中间参数。

第一点模型参数不必介绍;第二点,中间变量指每个语句的输出。而在backward时,这部分中间变量会翻倍(因为需要保留原中间值)。第三点,优化器在梯度下降时,模型参数在更新时会产生保存中间变量,也就是模型的params在这时翻倍。

技巧1:inplace=True

一些激活函数与Dropout有一个参数"inplace",默认设置为False,当设置为True时,我们在通过ReLU()计算时得到的新值不会占用新的空间而是直接覆盖原来的值,这也就是为什么当inplace参数设置为True时可以节省一部分内存的缘故。但在某些需要原先的值的情况下,就不可设置inplace。

此操作相当于针对显存占用第二点(模型中间变量)的优化。

技巧2:with torch.no_grad():

对于只需要forward而不需要backward的过程(validation和test),使用torch.no_grad做上下文管理器(注意要在model.eval()之后),可以让测试时batchsize扩大近十倍,而且也可以加速测试过程。此操作相当于针对显存占用第二点(因为直接没有backward了)和第三点进行优化。

model.eval()

with torch.no_grad():

pass

技巧3:forward中的变量命名

在研究pytorch官方架构和大神的代码后可发现大部分的forward都是以x=self.conv(x)的形式,很少引入新的变量,所以启发两点以减少显存占用(1)把不需要的变量都用x代替,(2)变量用完之后马上用del删除(此操作慎用,清除显存的同时使得backProp速度变慢)。此操作相当于针对第二点(模型中间变量)进行优化。

技巧4:Dataloader数据读取

一定要使用pytorch的Dataloader来读取数据。按照以下方式来设置:

loader = data.Dataloader(PYTORCH_DATASET, num_works=CPU_COUNT,

                         pin_memory=True, drop_last=True)

第一个参数是用pytorch制作的TensorDataset,第二个参数是CPU的数量(默认为0,在真正训练时建议调整),第三个参数默认为False,用来控制是否把数据先加载到缓存再加载到GPU,建议设置为True,第四个参数用于扔掉最后一个batch,使得训练更为稳定。

将pin_memory开启后,在通过dataloader读取数据后将数据to进GPU时把non_blocking设置为True,可以大幅度加快数据计算的速度。

for input_tensor in loader:

    input_tensor.to(gpu, non_blocking=True)

    model.forward(input_tensor)

技巧5:gradient accumulation

梯度积累通过累计梯度来解决本地显存不足的问题,即不在每个batch都更新模型参数,而是每经过accumulation steps步后,更新一次模型参数。相当于针对第三点(n步才更新一次参数)来进行优化。且由于参数更新的梯度计算是算力消耗的一部分,故梯度累计还可以一定程度上加快训练速度。

loss = model(input_tensor)

loss.backward()

if batch_idx % accumulate_steps == 0:

    optim.step()

    optim.zero_grad()

相当于一个epoch的步数(step)变少了(一个step相当于参数更新一次),但单个step的计算时间变长了(略小于n倍的原来时间)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/594158.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue methods 互相调用的方法

methods是一个内置的函数,主要用于两个组件之间的数据传递,也就是调用方法。下面给大家介绍一个在 vue中互相调用的方法,在使用过程中可以参考一下。 methods实现了两个组件之间数据的传递,我们先来看一下 Methods是如何实现数据传…

统计软件与数据分析Lesson17----利用pytorch构建LSTM预测股票收益率详细教程

利用pytorch构建LSTM预测股票收益率详细教程 1. 整体实现思路2.代码编写2.1 step1:导入所需的库2.2 step2: 读取数据、构建训练样本2.3 step3: 定义部分辅助函数2.4 step4:LSTM模型构建2.5 step5:模型训练2.6 step6:模型预测和评估 3. 小结 1. 整体实现思路 step1:导入所需的库…

对抗样本攻击

目录 一、对抗样本攻击的基本原理 1.1 什么是对抗样本攻击和对抗样本 1.2 对抗样本攻击的基本思路 1.3 对抗样本攻击的分类 1.3.1 按攻击效果分类 1.3.2 按攻击者能力分类 1.3.3 按攻击环境分类 1.4 对抗扰动的衡量 二、对抗样本攻击方法 一、对抗样本攻击的基本原理 …

华为OD机试真题B卷 Java 实现【最少交换次数】,附详细解题思路

一、题目描述 给出数字K&#xff0c;请输出所有小于K的整数组合到一起的最小交换次数。 组合一起是指满足条件的数字相邻&#xff0c;不要求相邻后在数组中的位置。 取值范围&#xff1a; -100 < K < 100 -100 < 数组中的数值 < 100 二、输入描述 第一行输入…

网络安全合规-ISO 27001(一)

实施ISO27001认证的步骤 在长期实践过程中&#xff0c;总结创新了一套高效可行的ISO27001/ISMS项目实施的规范流程。 一、现状调研分析&#xff1a;我方派咨询师去企业了解基本情况&#xff1b;本阶段主要是前期的准备和计划工作&#xff0c;包括明确评估目标&#xff0c;确定…

如何远程控制电脑,远程控制电脑的设置方法

很多人无论是在工作还是生活中使用电脑的时候都需要用到远程控制&#xff0c;因为它可以方便我们解决很多需要到现场操作的问题&#xff0c;在很大方面提升了我们的工作效率&#xff0c;下面来跟大家分享一下&#xff0c;如何远程控制电脑&#xff0c;远程控制电脑的设置方法 …

Web应用技术(第十五周/持续更新)

本次练习基于how2j和课本&#xff0c;进行SSM的初步整合&#xff0c;理解SSM整合的原理、好处。 SSM整合应用 1.简单的实例项目&#xff1a;2.原理分析&#xff1a;3.浅谈使用SSM框架化&#xff1a; 1.简单的实例项目&#xff1a; how2j 2.原理分析&#xff1a; 具体见流程图…

【网络】基础知识1

目录 网络发展 独立模式 网络互联 局域网LAN 广域网WAN 什么是协议 初识网络协议 协议分层 OSI七层模型 TCP/IP四层&#xff08;或五层&#xff09;模型 OSI和TCP/IP对比 网络传输流程 什么是报头 局域网通信原理 同网段的主机通讯 跨网段的主机通讯 数据包封装…

Kali搭建GVM完整版-渗透测试模拟环境(7)

上一篇:OpenVAS、GSA配置验证-渗透测试模拟环境(6)_luozhonghua2000的博客-CSDN博客 在bt5上面进行了安装,调试等配置验证,这篇在kali上面继续安装调试卸载等配置验证,中途版本问题,依赖问题,脚本编写都一一解决。 特别是因网络原因造成的rsync: [Receiver] safe_read f…

Sinkhorn-Knopp算法

Sinkhorn-Knopp是为了解决最优传输问题所提出的。 Sinkhorn算法原理 最优运输问题的目标就是以最小的成本将一个概率分布转换为另一个概率分布。即将概率分布 c 以最小的成本转换到概率分布 r&#xff0c;此时就要获得一个分配方案 P ∈ R n m 其中需满足以下条件&#xff1…

数据分析应该怎么学习?适合什么人学?

先来分享下适合学习数据分析的人群&#xff1a; 数据爱好者&#xff1a;对数据比较感兴趣&#xff0c;喜欢从数据中发现问题&#xff0c;有一定的见解&#xff0c;那么数据分析可以让这类小伙伴能够更好的理解和解释数据。市场营销、运营、业务分析&#xff1a;这类小伙伴学习…

SAP从入门到放弃系列之MRP区域

注&#xff1a;MRP AREA&#xff0c;本文中MRP范围或MRP区域都是指MRP AREA。另外MRP组和MRP区域是两个概念。 目录 MRP区域-库位层级 MRP区域-分包 其他事项 MRP区域-库位层级 除了在单个工厂级别、物料级别或产品组级别运行 MRP 之外&#xff0c;如果业务需要为以下运行 …

NLPChatGPTLLMs技术、源码、案例实战210课

NLP&ChatGPT&LLMs技术、源码、案例实战210课 超过12.5万行NLP/ChatGPT/LLMs代码的AI课程 讲师介绍 现任职于硅谷一家对话机器人CTO&#xff0c;专精于Conversational AI 在美国曾先后工作于硅谷最顶级的机器学习和人工智能实验室 CTO、杰出AI工程师、首席机器学习工程…

【机器学习】浅析过拟合

过度拟合 我们来想象如下一个场景&#xff1a;我们准备了10000张西瓜的照片让算法训练识别西瓜图像&#xff0c;但是这 10000张西瓜的图片都是有瓜梗的&#xff0c;算法在拟合西瓜的特征的时候&#xff0c;将西瓜带瓜梗当作了一个一般性的特征。此时出现一张没有瓜梗的西瓜照片…

探索Java面向对象编程的奇妙世界(七)

⭐ 字符串 String 类详解⭐ 阅读 API 文档⭐ String 类常用的方法⭐ 字符串相等的判断⭐ 内部类 ⭐ 字符串 String 类详解 String 是最常用的类&#xff0c;要掌握 String 类常见的方法&#xff0c;它底层实现也需要掌握好&#xff0c;不然在工作开发中很容易犯错。 &#x…

UI设计师必备的远程软件有哪些?

远程工作时&#xff0c;选择高效的远程软件非常重要。以下是3款提高工作效率的远程软件&#xff0c;希望对你有所帮助&#xff01; 1、即时设计协同设计 是国内首款集合原型、设计、交付、协作和资源管理于一体的高效远程设计软件。它提供实时在线协作功能&#xff0c;使用户…

14肖特基二极管

目录 一、介绍 二、结构 三、关键参数 1、导通压降VF 2、反向饱和漏电流IR 3、额定电流Io/IF 4、最大浪涌电流IFSM 5、最大反向峰值电压VRM 6、最大直流反向电压VR 7、最高工作频率fM 8、反向恢复时间Trr 9、最大耗散功率P 四、特点 1、反向恢复时间 2、缺点 五…

vue router 拆分路由 自动导入

目录 目录结构&#xff1a;拆分路由&#xff1a;自动导入&#xff1a;配置路由&#xff1a; 不求甚解&#xff0c;直接照搬就行了。 目录结构&#xff1a; 拆分路由&#xff1a; // danweiRouter.js export default {path: /danwei,name: danwei,component: () > import(.…

详解RGB和YUV色彩空间转换

前言 首先指出本文中的RGB指的是非线性RGB&#xff0c;意思就是经过了伽马校正&#xff0c;按照行业规矩应当写成RGB&#xff0c;但是为了书写方便&#xff0c;仍写成RGB。关于YUV有多种叫法&#xff0c;分别是YUV&#xff0c;YPbPr&#xff0c;YCbCr。因此本文将首先指出他们之…

这 13 种职业用AI提效的 40 类场景盘点

随着人工智能技术的发展&#xff0c;职业领域出现了诸如我们“小蜜蜂助手Beezy”等神奇的工具&#xff0c;大幅度提升了各行各业里从业人员的工作效率。 笔者今天将详述13种常见职业&#xff0c;分别是如何利用这些工具在实际工作过程中来帮助自己提升效率的。大量干货和私藏宝…