2-2 自动微分机制

news2024/11/28 4:37:08

神经网络是如何更新网络参数的呢?
为什么需要自动微分机制?
以及什么是自动微分机制?
神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情。而深度学习框架可以帮助我们自动地完成这种求梯度运算。
Pytorch一般通过反向传播 backward 方法 实现这种求梯度计算。该方法求得的梯度将存在对应自变量张量的grad属性下。
除此之外,也能够调用torch.autograd.grad 函数来实现求梯度计算。
这就是Pytorch的自动微分机制。

一、利用backward方法求导数

backward 方法通常在一个标量张量上调用,该方法求得的梯度将存在对应自变量张量的grad属性下。
如果调用的张量非标量,则要传入一个和它同形状的gradient参数张量。
相当于用该gradient参数张量与调用张量作向量点乘,得到的标量结果再反向传播。

1.标量的反向传播

import numpy as np 
import torch 

# f(x) = a*x**2 + b*x + c的导数

x = torch.tensor(0.0,requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c 

y.backward() # 利用backward() 注意是y调用这个方法 而前面已经说明 x 需要被求导
dy_dx = x.grad # 
print(dy_dx)

image.png

2.非标量的反向传播

如果调用的张量非标量,则要传入一个和它同形状 的gradient参数张量。

import numpy as np 
import torch 

# f(x) = a*x**2 + b*x + c

x = torch.tensor([[0.0,0.0],[1.0,2.0]],requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c 

gradient = torch.tensor([[1.0,1.0],[1.0,1.0]])

print("x:\n",x)
print("y:\n",y)
y.backward(gradient = gradient)
x_grad = x.grad
print("x_grad:\n",x_grad)

image.png

3.非标量的反向传播可以用标量的反向传播实现

接2 2就相当于用该gradient参数张量与调用张量作向量点乘,得到的标量结果再反向传播。

import numpy as np 
import torch 

# f(x) = a*x**2 + b*x + c

x = torch.tensor([[0.0,0.0],[1.0,2.0]],requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c 

gradient = torch.tensor([[1.0,1.0],[1.0,1.0]])
z = torch.sum(y*gradient) # torch.sum计算标量相乘获取标量之后再反向传播

print("x:",x)
print("y:",y)
z.backward()
x_grad = x.grad
print("x_grad:\n",x_grad)

image.png

二、利用autograd.grad方法求导数

import numpy as np 
import torch 

# f(x) = a*x**2 + b*x + c的导数

x = torch.tensor(0.0,requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c


# create_graph 设置为 True 将允许创建更高阶的导数 
dy_dx = torch.autograd.grad(y,x,create_graph=True)[0] # 去除[0]报错 AttributeError: 'tuple' object has no attribute 'data'
print(dy_dx.data) # 一阶导

# 求二阶导数
dy2_dx2 = torch.autograd.grad(dy_dx,x)[0] 

print(dy2_dx2.data) # 二阶导


image.png

import numpy as np 
import torch 

x1 = torch.tensor(1.0,requires_grad = True) # x需要被求导
x2 = torch.tensor(2.0,requires_grad = True)

y1 = x1*x2
y2 = x1+x2


# 允许同时对多个自变量求导数
(dy1_dx1,dy1_dx2) = torch.autograd.grad(outputs=y1,inputs = [x1,x2],retain_graph = True)
print(dy1_dx1,dy1_dx2)

# 如果有多个因变量,相当于把多个因变量的梯度结果求和
(dy12_dx1,dy12_dx2) = torch.autograd.grad(outputs=[y1,y2],inputs = [x1,x2])
print(dy12_dx1,dy12_dx2)


image.png

三、利用自动微分和优化器求最小值

import numpy as np 
import torch 

# f(x) = a*x**2 + b*x + c的最小值

x = torch.tensor(0.0,requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)

optimizer = torch.optim.SGD(params=[x],lr = 0.01)


def f(x):
    result = a*torch.pow(x,2) + b*x + c 
    return(result)

for i in range(500):
    optimizer.zero_grad()
    y = f(x)
    y.backward()
    optimizer.step()
   
print("y=",f(x).data,";","x=",x.data) # 当 x=0 时 y取得最小值0

image.png
参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1008557.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PostgreSQL数据库IPC——SI Message Queue

SI Message Queue代码位于src/backend/storage/ipc/sinvaladt.c和src/backend/storage/ipc/sinval.c文件中,属于PostgreSQL数据库IPC进程间通信的一种方式【之前介绍过PostgreSQL数据库PMsignal——后端进程\Postmaster信号通信也是作为PostgreSQL数据库IPC进程间通…

Linux下Minio分布式存储安装配置(图文详细)

文章目录 Linux下Minio分布式存储安装配置(图文详细)1 资源准备1.1 创建存储目录1.2 获取Minio Server资源1.3 获取Minio Client资源 2 Minio Server安装配置2.1 切换目录2.2 后台启动2.3 查看进程2.4 控制台测试 3 Minio Client安装配置3.1 切换目录3.2 移动mc脚本3.2 运行mc命…

记录一次LiteFlow项目实战

文章目录 学习LiteFlowspring boot整合LiteFlow依赖配置组件定义spring boot配置文件规则文件的定义 执行 组件EL规则串行并行 动态构建组件动态构建chain(流程)销毁chain高级特性 题外话: 最近喜欢上骑摩托车了,不是多大排量的摩…

C++ if...else 语句

一个 if 语句 后可跟一个可选的 else 语句,else 语句在布尔表达式为假时执行。 语法 C 中 if…else 语句的语法: if(boolean_expression) {// 如果布尔表达式为真将执行的语句 } else {// 如果布尔表达式为假将执行的语句 }如果布尔表达式为 true&…

联想拯救者Lenovo Legion Y7000 2020H(81Y7)原厂Win10系统

链接:https://pan.baidu.com/s/15GwrfjvhK3-_OlhkfPyUbA?pwd4v03 系统自带所有驱动、出厂主题壁纸LOGO、Office办公软件、联想电脑管家等预装程序 所需要工具:16G或以上的U盘 文件格式:ISO 文件大小:10.7GB 注:恢复…

用vb语言编写一个抄底的源代码程序实例

以下是一个基于通达信软件编写的简单抄底源代码程序,用于自动识别股票的底部形态并发出买入信号: vbs 复制 导入通达信软件自带的股票数据接口 Dim TdxApi Set TdxApi CreateObject("TdxApi.TdxLocalAPI") 设置股票代码、周期、数据类型等参…

算法通关村第十九关:白银挑战-动态规划高频问题

白银挑战-动态规划高频问题 1. 最少硬币数 LeetCode 322 https://leetcode.cn/problems/coin-change/description/ 思路分析 尝试用回溯来实现 假如coins[2,5,7],amount27,求解过程中,每个位置都可以从[2,5,7]中选择,因此可以…

深入理解JVM虚拟机第六篇:内存结构与类加载子系统概述

文章目录 一:内存结构概述 1:运行时数据区 2:运行时数据区简图 3:运行时数据区详细图中英文版 二:类加载器子系统 1:加载 2:连接 3:初始化 一:内存结构概述 1…

CAN基础概念

文章目录 目的控制器、收发器、总线帧格式CAN2.0和CAN-FD波特率与采样点工作模式总结 目的 CAN是非常常用的一种数据总线,被广泛用在各种车辆系统中。大多数时候CAN的控制器和收发器干了比较多的工作,从而对于写代码使用来说比较简单。这篇文章将对CAN使…

经历网数据库共享

经历网,为留住您的经历而生 点击 经历网 进入网站查看当前数据 经历网网址:https://www.jili20.com/ 以下 数据库 数据 截止至 2023年9月13日 1)百度网盘 提取 链接:https://pan.baidu.com/s/1WwR4cI9lbSAYTuffo8qmVQ 或点击 此…

微信小程序的在线课外阅读打卡记录系统uniapp

本文从管理员、学生和教师的功能要求出发,中学课外阅读记录系统中的功能模块主要是实现学生、教师、阅读任务、阅读打卡、提醒信息、阅读排行、任务计划、阅读类型、在线考试等。经过认真细致的研究,精心准备和规划,最后测试成功,…

zemax畸变与消畸变

物体不同位置的放大率不同,产生图形变形 这里选择zemax自带的案例: 畸变效果: 明显的负畸变(桶形畸变) 从场曲畸变图中可以看出: 该系统的最大畸变大约为38% 放入图片观察成像效果: 优化操作数…

GpsAndMap模块开源,欢迎测评

背景 之前的文章有提到,最近在使用folium的过程中,深感对于一个非专业人员来说,GPS坐标以及其所隐含的GPS坐标系,以及不同GPS坐标系之间的相互转换关系,不是一个十分清晰的概念,往往造成在使用GPS坐标在fo…

基本的SELECT语句——“MySQL数据库”

各位CSDN的uu们好呀,好久没有更新小雅兰的MySQL数据库专栏啦,接下来一段时间,小雅兰都会更新MySQL数据库的知识,下面,让我们进入今天的主题吧——基本的SELECT语句!!! SQL概述 SQL语…

Linux - 性能可观察性工具

文章目录 常用的Linux性能可观察性工具图解小结 常用的Linux性能可观察性工具 以下是一些常用的Linux性能可观察性工具: top: 显示实时的系统性能数据,包括CPU使用率、内存使用情况、进程信息等。 htop: 类似于top,但提供了更多的交互式功能…

谷粒商城----rabbitmq

一、 为什么要用 MQ? 三大好处,削峰,解耦,异步。 削峰 比如秒杀,或者高铁抢票,请求在某些时间点实在是太多了,服务器处理不过来,可以把请求放到 MQ 里面缓冲一下,把一秒内收到的…

Arcgis栅格转点时ERROR 999999: 执行函数时出错。 无法创建要素数据集。 执行(RasterToPoint)失败

Arcgis栅格转点时ERROR 999999: 执行函数时出错。 无法创建要素数据集。 执行(RasterToPoint)失败。 问题描述 原因 输出点要素的位置不对 解决方案 点击新建文件地理数据库 然后在该文件地理数据库下输出

RocketMQ 消息传递模型

文章目录 0. 前言1. RocketMQ的消息传递模型1.1. 同步发送1.2. 异步发送1.3. 单向发送 2. RocketMQ的批量发送和消费2.1 批量发送2.2 批量消费2.3 Spring Boot集成RocketMQ官方starter 示例 3. 总结4. 参考文档5. 源码地址 0. 前言 RocketMQ 支持6种消息传递方式,我…

【Java 基础篇】Java 泛型:类型安全的编程指南

在 Java 编程中,泛型是一项强大的特性,它允许您编写更通用、更安全和更灵活的代码。无论您是初学者还是有经验的 Java 开发人员,了解和掌握泛型都是非常重要的。本篇博客将从基础概念一直深入到高级应用,详细介绍 Java 泛型。 什…

nrf52832蓝牙GAP 通用访问规范

nrf52832蓝牙GAP 通用访问规范 文章目录 nrf52832蓝牙GAP 通用访问规范前言一、蓝牙GAP(通用访问配置文件)可以设置什么参数?二、使用步骤广播名称修改广播名字长度;全显示和自定义显示中文显示广播名称 蓝牙图标没有图标加入图标…