pytorch复习笔记--loss.backward()、optimizer.step()和optimizer.zero_grad()的用法

news2024/11/30 0:41:19

目录

1--loss.backward()的用法

2--optimizer.step()的用法

3--optimizer.zero_grad()的用法

4--举例说明

5--参考


1--loss.backward()的用法

作用:将损失loss向输入测进行反向传播;这一步会计算所有变量x的梯度值 \small \frac{d}{dx}loss,并将其累积为\small x*grad 进行备用,即 \small x*grad=(x*grad)_{pre}+\frac{d}{dx}loss ,公式中的 \small (x*grad)_{pre} 指的是上一个epoch累积的梯度。

2--optimizer.step()的用法

作用:利用优化器对参数x进行更新,以随机梯度下降SGD为例,更新的公式为:\small x = x - lr*(x*grad),lr 表示学习率 learning rate,减号表示沿着梯度的反方向进行更新;

3--optimizer.zero_grad()的用法

作用:清除优化器关于所有参数x的累计梯度值 \small x*grad,一般在loss.backward()前使用,即清除 \small (x*grad)_{pre}

4--举例说明

① 展示 loss.backward() 和 optimizer.step() 的用法:

import torch

# 初始化参数值x
x = torch.tensor([1., 2.], requires_grad=True)

# 模拟网络运算,计算输出值y
y = 100*x

# 定义损失
loss = y.sum() 

print("x:", x)
print("y:", y)
print("loss:", loss)


print("反向传播前, 参数的梯度为: ", x.grad)
# 进行反向传播
loss.backward()  # 计算梯度grad, 更新 x*grad    
print("反向传播后, 参数的梯度为: ", x.grad)

# 定义优化器
optim = torch.optim.SGD([x], lr = 0.001) # SGD, lr = 0.001

print("更新参数前, x为: ", x) 
optim.step()  # 更新x
print("更新参数后, x为: ", x)

② 展示不使用 optimizer.zero_grad() 的梯度:

import torch

# 初始化参数值x
x = torch.tensor([1., 2.], requires_grad=True)

# 模拟网络运算,计算输出值y
y = 100*x

# 定义损失
loss = y.sum() 

print("x:", x)
print("y:", y)
print("loss:", loss)


print("反向传播前, 参数的梯度为: ", x.grad)
# 进行反向传播
loss.backward()  # 计算梯度grad, 更新 x*grad    
print("反向传播后, 参数的梯度为: ", x.grad)

# 定义优化器
optim = torch.optim.SGD([x], lr = 0.001) # SGD, lr = 0.001

print("更新参数前, x为: ", x) 
optim.step()  # 更新x
print("更新参数后, x为: ", x)

# 再进行一次网络运算
y = 100*x

# 定义损失
loss = y.sum()

# 不进行optimizer.zero_grad()
loss.backward()  # 计算梯度grad, 更新 x*grad 
print("不进行optimizer.zero_grad(), 参数的梯度为: ", x.grad)

③ 展示使用 optimizer.zero_grad() 的梯度:

import torch

# 初始化参数值x
x = torch.tensor([1., 2.], requires_grad=True)

# 模拟网络运算,计算输出值y
y = 100*x

# 定义损失
loss = y.sum() 

print("x:", x)
print("y:", y)
print("loss:", loss)


print("反向传播前, 参数的梯度为: ", x.grad)
# 进行反向传播
loss.backward()  # 计算梯度grad, 更新 x*grad    
print("反向传播后, 参数的梯度为: ", x.grad)

# 定义优化器
optim = torch.optim.SGD([x], lr = 0.001) # SGD, lr = 0.001

print("更新参数前, x为: ", x) 
optim.step()  # 更新x
print("更新参数后, x为: ", x)

# 再进行一次网络运算
y = 100*x

# 定义损失
loss = y.sum()

# 进行optimizer.zero_grad()
optim.zero_grad()
loss.backward()  # 计算梯度grad, 更新 x*grad 
print("进行optimizer.zero_grad(), 参数的梯度为: ", x.grad)

通过②和③的对比,能看出 optimizer.zero_grad() 的作用是清除之前累积的梯度值。

5--参考

参考1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/49025.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

迎合国家新政策,共享购联合共享经济,三方互利,消费增值

共享单车打通出行“最后一公里”,共享充电宝让人们出门在外免于“电量烦恼”,共享办公降低办公成本……共享经济已深入到人们日常生活。近日,国家信息中心发布的《中国共享经济发展报告(2022)》显示,2021年…

EMR-Jindo Spark 核心引擎优化

Jindo-Spark 是阿里云智能E-MapReduce 团队在开源的Apache Spark 基础上自主研发的分布式云原生 OLAP 引擎,已经在近千E-MapReduce 客户中大规模部署使用。Jindo Spark 在开源版本基础上做了大量优化和扩展,深度集成和连接了众多阿里云基础服务。凭借该引…

工作流-流程实例【ProcessInstance】与执行实例【Execution】

一、ProcessInstance与Execution的区别 这是一个Activiti的难点,能够懂得这个,工作流也就入门大半了。 下面,我就细致的讲解一下他们的区别。 (1)首先,我们来看一张我总结的图片(这个图片中两条…

Flink-处理函数以及TopN运用案例

7 处理函数 7.1 概述 更底层的操作,直接对流进行操作,直接调用处理函数 7.2 基本处理函数ProcessFunction 分析 ProcessFunction的来源 处理函数继承了AbstractRichFunction富函数抽象类,因此就具有访问状态(state)和其他运行时环境 例…

Day39——Dp专题

文章目录01背包二维数组一维数组6.整数拆分7.不同的二叉搜索01背包 01背包:每一个物品只能选一次,选或者不选 状态标识:f[i][j]:所有只考虑前i个物品,且总体积不超j的所有选法的集合 属性:Max 状态计算&a…

链表之反转链表

文章目录链表之反转链表题目描述解题思路代码实现链表之反转链表 力扣链接 题目描述 定义一个函数,输入一个链表的头节点,反转该链表并输出反转后链表的头节点。 示例: ​ 输入: 1->2->3->4->5->NULL ​ 输出: 5->4-&…

如何设计高性能架构

高性能复杂度模型 高性能复杂度分析和设计 单机 集群 任务分配 将任务分配给多个服务器执行 复杂度分析 增加“任务分配器”节点,可以是独立的服务器,也可以是SDK任务分配器需要管理所有的服务器,可以通过配置文件,也可以通过…

RK3588移植-opencv交叉编译aarch64

文章参考:https://blog.csdn.net/KayChanGEEK/article/details/80365320 文章目录概括准备资源交叉编译OPENCV修改CMakelist文件将lib库复制到/lib目录注意:本文中的所有配置相关路径都与当前安装的路径有关,需要根据自己的环境进行自行修改&…

『Java课设』JavaSwing+MySQL实现学生成绩管理系统

👨‍🎓作者简介:一位喜欢写作,计科专业大三菜鸟 🏡个人主页:starry陆离 如果文章有帮到你的话记得点赞👍收藏💗支持一下哦 『Java课设』JavaSwingMySQL实现学生成绩管理系统前言1.开…

SparkMlib 之随机森林及其案例

文章目录什么是随机森林?随机森林的优缺点随机森林示例——鸢尾花分类什么是随机森林? 随机森林算法是机器学习、计算机视觉等领域内应用极为广泛的一个算法,它不仅可以用来做分类,也可用来做回归即预测,随机森林机由…

RabbitMQ之可靠性分析

在使用任何消息中间件的过程中,难免会出现消息丢失等异常情况,这个时候就需要有一个良好的机制来跟踪记录消息的过程(轨迹溯源),帮助我们排查问题。在RabbitMQ 中可以使用Firehose 功能来实现消息追踪,Fire…

艾美捷MTT细胞增殖检测试剂盒结果示例引用文献

艾美捷MTT细胞增殖检测试剂盒测定原理: 该试剂盒提供了比色形式测量和监测细胞增殖,含有足够的试剂用于评估在96孔板中进行960次测定或在24孔板中进行192次测定。细胞可以被镀,然后用影响增殖的化合物或药剂。然后用增殖试剂检测细胞&#x…

3.矩阵计算及导数基础

1. 梯度 将导数拓展到向量。 1. 标量对向量求导 x是列向量,y是标量,求导之后变成了行向量 ps: x1^2 2x2^2 这个函数可以画成等高线,对于(x1,x2)这个点,可以做等高线的切线,再做出…

Spark Streaming(二)

声明: 文章中代码及相关语句为自己根据相应理解编写,文章中出现的相关图片为自己实践中的截图和相关技术对应的图片,若有相关异议,请联系删除。感谢。转载请注明出处,感谢。 By luoyepiaoxue2014 B站&#xff…

动态规划算法(2)最长回文子串详解

文章目录最长回文字串动态规划代码示例前篇: (1)初识动态规划 最长回文字串 传送门: https://leetcode.cn/problems/longest-palindromic-substring/description/ 给你一个字符串 s,找到 s 中最长的回文子串。 s &qu…

大数据学习:使用Java API操作HDFS

文章目录一、创建Maven项目二、添加依赖三、创建日志属性文件四、在HDFS上创建文件五、写入HDFS文件1、将数据直接写入HDFS文件2、将本地文件写入HDFS文件六、读取HDFS文件1、读取HDFS文件直接在控制台显示2、读取HDFS文件,保存为本地文件一、创建Maven项目 二、添加…

Spring Security 中重要对象汇总

前言 已经写了好几篇关于 Spring Security 的文章了,相信很多读者还是对 Spring Security 的云里雾里的。这是因为对 Spring Security 中的对象还不了解。本文就来介绍介绍一下常用对象。 认证流程 SecurityContextHolder 用户认证通过后,为了避免用…

【JavaWeb】Servlet系列 --- HttpServlet【底层源码分析】

HttpServlet一、什么是协议?什么是HTTP协议?二、HTTP的请求协议(B -- > S)1. HTTP的请求协议包括4部分(记住)2. HTTP请求协议的具体报文:GET请求3. HTTP请求协议的具体报文:POST请…

生成式模型和判别式模型

决策函数Yf(x)Y f(x)Yf(x)或者条件概率分布 P(Y∣X)P(Y|X)P(Y∣X) 监督学习的任务都是从数据中学习一个模型(也叫做分类器),应用这一模型,对给定的输入xxx预测相应的输出YYY,这个模型的一般形式为:决策函数Yf(x)Y f(x)Yf(x)&…

java 每日一练(6)

java 每日一练(6) 文章目录单选不定项选择题编程题单选 1.关于抽象类与最终类,下列说法错误的是?   A 抽象类能被继承,最终类只能被实例化。   B 抽象类和最终类都可以被声明使用   C 抽象类中可以没有抽象方法,最终类中可以没…