PyTorch使用------自动微分模块

news2024/9/20 18:37:01

目录

🍔 梯度基本计算

1.1 单标量梯度的计算

1.2 单向量梯度的计算

1.3 多标量梯度计算

1.4 多向量梯度计算

1.5 运行结果💯

🍔 控制梯度计算

2.1 控制不计算梯度

2.2 注意: 累计梯度

2.3 梯度下降优化最优解

2.4 运行结果💯

🍔 梯度计算注意

3.1 detach 函数用法

3.2 detach 前后张量共享内存

3.3 运行结果💯

🍔 小节


学习目标

🍀 掌握梯度计算


自动微分(Autograd)模块对张量做了进一步的封装,具有自动求导功能。自动微分模块是构成神经网络训练的必要模块,在神经网络的反向传播过程中,Autograd 模块基于正向计算的结果对当前的参数进行微分计算,从而实现网络权重参数的更新。

🍔 梯度基本计算

我们使用 backward 方法、grad 属性来实现梯度的计算和访问.

import torch

1.1 单标量梯度的计算

   

 # y = x**2 + 20
    def test01():

    # 定义需要求导的张量
    # 张量的值类型必须是浮点类型
    x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
    # 变量经过中间运算
    f = x ** 2 + 20
    # 自动微分
    f.backward()
    # 打印 x 变量的梯度
    # backward 函数计算的梯度值会存储在张量的 grad 变量中
    print(x.grad)


1.2 单向量梯度的计算

# y = x**2 + 20
def test02():

    # 定义需要求导张量
    x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)
    # 变量经过中间计算
    f1 = x ** 2 + 20

    # 注意:
    # 由于求导的结果必须是标量
    # 而 f 的结果是: tensor([120., 420.])
    # 所以, 不能直接自动微分
    # 需要将结果计算为标量才能进行计算
    f2 = f1.mean()  # f2 = 1/2 * x

    # 自动微分
    f2.backward()

    # 打印 x 变量的梯度
    print(x.grad)

1.3 多标量梯度计算

# y = x1 ** 2 + x2 ** 2 + x1*x2
def test03():

    # 定义需要计算梯度的张量
    x1 = torch.tensor(10, requires_grad=True, dtype=torch.float64)
    x2 = torch.tensor(20, requires_grad=True, dtype=torch.float64)

    # 经过中间的计算
    y = x1**2 + x2**2 + x1*x2

    # 将输出结果变为标量
    y = y.sum()

    # 自动微分
    y.backward()

    # 打印两个变量的梯度
    print(x1.grad, x2.grad)


1.4 多向量梯度计算

def test04():

    # 定义需要计算梯度的张量
    x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)
    x2 = torch.tensor([30, 40], requires_grad=True, dtype=torch.float64)

    # 经过中间的计算
    y = x1 ** 2 + x2 ** 2 + x1 * x2
    print(y)

    # 将输出结果变为标量
    y = y.sum()

    # 自动微分
    y.backward()

    # 打印两个变量的梯度
    print(x1.grad, x2.grad)


if __name__ == '__main__':
    test04()

1.5 运行结果💯

tensor(20., dtype=torch.float64)
tensor([ 5., 10., 15., 20.], dtype=torch.float64)
tensor(40., dtype=torch.float64) tensor(50., dtype=torch.float64)
tensor([1300., 2800.], dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([50., 80.], dtype=torch.float64) tensor([ 70., 100.], dtype=torch.float64)

🍔 控制梯度计算

我们可以通过一些方法使得在 requires_grad=True 的张量在某些时候计算不进行梯度计算。

import torch

2.1 控制不计算梯度

def test01():

    x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
    print(x.requires_grad)

    # 第一种方式: 对代码进行装饰
    with torch.no_grad():
        y = x ** 2
    print(y.requires_grad)

    # 第二种方式: 对函数进行装饰
    @torch.no_grad()
    def my_func(x):
        return x ** 2
    print(my_func(x).requires_grad)


    # 第三种方式
    torch.set_grad_enabled(False)
    y = x ** 2
    print(y.requires_grad)


2.2 注意: 累计梯度

def test02():

    # 定义需要求导张量
    x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)

    for _ in range(3):

        f1 = x ** 2 + 20
        f2 = f1.mean()

        # 默认张量的 grad 属性会累计历史梯度值
        # 所以, 需要我们每次手动清理上次的梯度
        # 注意: 一开始梯度不存在, 需要做判断
        if x.grad is not None:
            x.grad.data.zero_()

        f2.backward()
        print(x.grad)


2.3 梯度下降优化最优解

def test03():

    # y = x**2
    x = torch.tensor(10, requires_grad=True, dtype=torch.float64)

    for _ in range(5000):

        # 正向计算
        f = x ** 2

        # 梯度清零
        if x.grad is not None:
            x.grad.data.zero_()

        # 反向传播计算梯度
        f.backward()

        # 更新参数
        x.data = x.data - 0.001 * x.grad

        print('%.10f' % x.data)


if __name__ == '__main__':
    test01()
    test02()
    test03()

2.4 运行结果💯

True
False
False
False
tensor([ 5., 10., 15., 20.], dtype=torch.float64)
tensor([ 5., 10., 15., 20.], dtype=torch.float64)
tensor([ 5., 10., 15., 20.], dtype=torch.float64)

🍔 梯度计算注意

当对设置 requires_grad=True 的张量使用 numpy 函数进行转换时, 会出现如下报错:

Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

此时, 需要先使用 detach 函数将张量进行分离, 再使用 numpy 函数.

注意: detach 之后会产生一个新的张量, 新的张量作为叶子结点,并且该张量和原来的张量共享数据, 但是分离后的张量不需要计算梯度。

import torch

3.1 detach 函数用法

def test01():

    x = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)

    # Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
    # print(x.numpy())  # 错误
    print(x.detach().numpy())  # 正确


3.2 detach 前后张量共享内存

def test02():

    x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)

    # x2 作为叶子结点
    x2 = x1.detach()

    # 两个张量的值一样: 140421811165776 140421811165776
    print(id(x1.data), id(x2.data))
    x2.data = torch.tensor([100, 200])
    print(x1)
    print(x2)

    # x2 不会自动计算梯度: False
    print(x2.requires_grad)


if __name__ == '__main__':
    test01()
    test02()

3.3 运行结果💯

10. 20.]
140495634222288 140495634222288
tensor([10., 20.], dtype=torch.float64, requires_grad=True)
tensor([100, 200])
False

🍔 小节

本小节主要讲解了 PyTorch 中非常重要的自动微分模块的使用和理解。我们对需要计算梯度的张量需要设置 requires_grad=True 属性,并且需要注意的是梯度是累计的,在每次计算梯度前需要先进行梯度清零。

 

😀 小言在此感谢大家的支持😀 

顺便问一下大佬们,最擅长使用的编程语言是什么呢~

欢迎评论区讨论哦~

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2149696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在grafana上配置显示全部node资源信息概览

在grafana上配置显示全部node资源信息概览&#xff0c;便于巡检 1&#xff0c;注册grafana官网账号&#xff1a;Grafana dashboards | Grafana Labs&#xfeff; 2、寻找可以展示所有node资源概览信息的dashboard&#xff0c;并下载支持prometheus数据源的dashboard&#xff…

论文开题不用愁,5步带你轻松搞定研究计划!

开题报告是每位研究生在论文写作初期必须完成的一项重要任务。它不仅是对自己研究方向的初步规划&#xff0c;也是导师和评审专家衡量课题可行性的重要依据。一份优秀的开题报告不仅能够清晰阐述研究的背景、目的和意义&#xff0c;还能展示研究的创新性和可行性&#xff0c;从…

与谷歌旗下自动驾驶公司扩大合作后,Uber股票值得买入吗?

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 猛兽财经核心观点&#xff1a; &#xff08;1&#xff09;与Waymo扩大在无人驾驶出租车方面的合作后&#xff0c;Uber的股价上涨了5%。 &#xff08;2&#xff09;Uber的第二季度的财务业绩非常强劲&#xff0c;收入增长了…

【Python】练习:控制语句(二)第4关

第4关&#xff1a;控制结构综合实训 第一题第二题&#xff08;※&#xff09;第三题&#xff08;※&#xff09;第四题&#xff08;※&#xff09;第五题&#xff08;※&#xff09;第六题&#xff08;※&#xff09; 第一题 #第一题def rankHurricane(velocity):#请在下面编写…

[Simpfun游戏云1]搭建MC Java+基岩互通生存游戏服务器

众所周知&#xff0c;MC有多个客户端&#xff0c;像常见的比如Java Edition和基岩等&#xff0c;这就导致&#xff0c;比如我知道一个超级好玩的JE服务器&#xff0c;但我又想使用基岩版来玩&#xff0c;肯定是不行的&#xff0c;因为通讯协议不一样。 这就有一些人才发明了多…

浅谈死锁以及判断死锁的方法

引言 我们在并发情况下见过很多种锁&#xff0c;synchronized&#xff0c;ReentrantLock 等等&#xff0c;这些锁是为了保证线程安全&#xff0c;使线程同步的锁&#xff0c;与今天所要学习的死锁并不相同&#xff0c;死锁并不是一种锁&#xff0c;而是一种现象。 官方定义&a…

Live800:从心出发,以情动人:构建深度客户服务文化

在当今这个竞争激烈的市场环境中&#xff0c;企业之间的较量已不仅仅局限于产品质量的比拼&#xff0c;更在于谁能提供更优质、更贴心的客户服务。在这个背景下&#xff0c;“从心出发&#xff0c;以情动人”成为了构建深度客户服务文化的核心理念&#xff0c;它要求企业不仅要…

Nacos注册中心(Nacos安装,快速入门,多级存储,负载均衡,环境隔离,配置管理,热更新,集群搭建,nginx反向代理)

Nacos注册中心 1. Nacos安装 (windows) 1.1 官网下载 网址:https://github.com/alibaba/nacos/releases/tag/1.4.1 这里下载nacos1.4.1的Windows版本为例1.2 解压到本地 注: 解压到非中文目录 nacos默认端口号:8848,可在配置文件properties中修改1.3 启动nacos 在G:\Sp…

基于单片机的智能电话控制系统设计

摘要&#xff1a; 为了能够使用电话实现电器设备的控制&#xff0c;文中通过单片机及双音多频解码集成电路&#xff0c;使用用 户通过电话输入相应的指令就能够实现远程设备的智能化控制。文章主要对系统的构成、软件及 硬件设计进行了简单的介绍&#xff0c;并且对其中的电路…

前端vue-v-for循环遍历

&#xff08;item,index&#xff09;in list中&#xff0c;index这个索引可加可不加&#xff0c;item代表list中的每一个元素&#xff0c;list可以是数组&#xff0c;也可以是对象&#xff0c;要遍历谁就把 &#xff08;item,index&#xff09;in list加在哪里。 关于加不加&a…

抓机遇,促发展——2025第十二届广州国际汽车零部件加工技术及汽车模具展览会

新能源时代&#xff0c;电动化、智能化正在重塑全球汽车市场格局。中国自主品牌新能源汽车的市占率不断提升、头部效应初显&#xff0c;更有机会带动相关供应链企业成长。中国的零部件企业有望抓住变局下的机会&#xff0c;在新一轮竞争中崛起。 智能电动车时代&#xff0c;汽车…

OpenHarmony(鸿蒙南向开发)——小型系统内核(LiteOS-A)【内核通信机制】上

往期知识点记录&#xff1a; 鸿蒙&#xff08;HarmonyOS&#xff09;应用层开发&#xff08;北向&#xff09;知识点汇总 鸿蒙&#xff08;OpenHarmony&#xff09;南向开发保姆级知识点汇总~ 子系统开发内核 轻量系统内核&#xff08;LiteOS-M&#xff09; 轻量系统内核&#…

操作系统迁移(CentOs -> Ubuntu)

目录 1. CentOs操作系统:备份数据 1.1 gitee备份 1.1.1 CentOs安装git 1.1.1.1 运行安装命令 1.1.1.2 运行安装命令时出错 1.1.1.3 再次执行安装命令 1.1.2 gitee创建仓库 1.1.2.1 创建仓库 1.1.3 备份 1.1.3.1 复制链接 1.1.3.2 克隆仓库 1.1.3.3 备份 1.3.3.4 查…

docker学习第一步:基于Linux安装docker

要求Linux下的CentOS 7.0 以上的版本 01 安装docker版本仓库 1、设置仓库 sudo yum install -y yum-utils device-mapper-persistent-data lvm2 2、稳定仓库 sudo yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 02 安装及使用do…

一文看懂误植域名的威胁和对策

你熟悉“误植域名”&#xff08;typosquatting&#xff09;这个术语吗&#xff1f;这是一种域名抢注&#xff0c;可能会损害你的品牌和声誉。我们在这篇博文中将探讨误植域名是什么、它与域名抢注的区别&#xff0c;并提供如何防止这种行为的要点。另外&#xff0c;我们将讨论沦…

选择优质代理IP建议分享

“在互联网的广阔世界中&#xff0c;代理IP作为一种重要的网络工具&#xff0c;扮演着连接用户与目标服务器之间的桥梁角色。不同类型的代理IP适用于不同的场景和需求&#xff0c;因此选择合适的代理IP类型对于提高网络访问效率、保护用户隐私至关重要。” 一、代理IP类型概述 …

如何在多台Linux虚拟机上安装和配置Zookeeper集群

Zookeeper 是一个高性能的协调服务&#xff0c;广泛应用于分布式系统中。本文将详细介绍如何在多台Linux虚拟机上安装和配置Zookeeper集群。下面以三台服务器&#xff08;node1、node2、node3&#xff09;进行讲解。 前置准备&#xff1a; 配置多台Linux虚拟机参考&#xff1a;…

跨国公司决策的影响与中国IT产业的应对

跨国公司在华研发中心的调整是一个复杂的现象&#xff0c;它可能受到多种因素的影响&#xff0c;包括全球经济环境的变化、成本考量、战略重心的转移以及地缘政治因素等。IBM中国研发中心的撤出可能会对中国IT行业造成短期的就业压力&#xff0c;加速人才流动&#xff0c;并促使…

YOLOv5白皮书-第Y1周:调用官方权重进行检测

>- **&#x1f368; 本文为[&#x1f517;365天深度学习训练营](小团体&#xff5e;第八波) 中的学习记录博客** >- **&#x1f356; 原作者&#xff1a;[K同学啊](K同学啊-CSDN博客)** 一、前言 拖了好久&#xff0c;终于要开始目标检测系列了。自己想过好几次&#xf…

【Python机器学习】NLP信息提取——值得提取的信息

目录 提取GPS信息 提取日期 如下一些关键的定量信息值得“手写”正则表达式&#xff1a; GPS位置&#xff1b;日期&#xff1b;价格&#xff1b;数字。 和上述可以通过正则表达式轻松捕获的信息相比&#xff0c;其他一些重要的自然语言信息需要更复杂的模式&#xff1a; 问…