昇思25天学习打卡营第6天 | 函数式自动微分

news2024/12/25 4:08:17

神经网络的训练主要使用反向传播算法

模型预测值(logits)正确标签(label)送入损失函数(loss function)获得loss

然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)

自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。

自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad。下面我们使用一个简单的单层线性变换模型进行介绍。

1.函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

在这个模型中,𝑥𝑥为输入,𝑦𝑦为正确值,𝑤𝑤和𝑏𝑏是我们需要优化的参数。

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失

执行计算函数,可以获得计算的loss值。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

loss = function(x, y, w, b)
print(loss)

2.微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数:\frac{\partial loss }{\partial w}\frac{\partial loss}{\partial b},此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对𝑤𝑤和𝑏𝑏求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

执行微分函数,即可获得𝑤𝑤、𝑏𝑏对应的梯度。

grad_fn = mindspore.grad(function, (2, 3))

grads = grad_fn(x, y, w, b)
print(grads)

3.Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。

此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

1

可以看到,求得𝑤𝑤、𝑏𝑏对应的梯度值与初始function求得的梯度值一致。

4.Auxiliary data

Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)

grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

可以看到,求得𝑤𝑤、𝑏𝑏对应的梯度值与初始function求得的梯度值一致,同时z能够作为微分函数的输出返回。

5.神经网络梯度计算

我们的神经网络构造是继承自面向对象编程范式的nn.Cell。接下来我们通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播。

首先我们继承nn.Cell构造单层线性变换神经网络。这里我们直接使用前文的𝑤𝑤、𝑏𝑏作为模型参数,使用mindspore.Parameter进行包装后,作为内部属性,并在construct内实现相同的Tensor操作。

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z

接下来我们实例化模型和损失函数:

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

完成后,由于需要使用函数式自动微分,需要将神经网络和损失函数的调用封装为一个前向计算函数:

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

完成后,我们使用value_and_grad接口获得微分函数,用于计算梯度。

由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())

loss, grads = grad_fn(x, y)
print(grads)

代码实现:

总结一下神经网络的梯度计算:

构建模型,基于nn.cell基类,_init_,定义类,construct类。

实例化:model=Network()

实例化损失函数:loss_fn = nn.BCEWithLogitsLoss()

函数式自动微分,需要把神经网络和损失函数调用封装为一个前向的计算函数def forward_fn()

使用接口grad_position获取微分函数

用weight函数求导,用model.trainable_params()取出求导的参数

byebye~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1859758.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kubernetes 中 ElasticSearch 中的 MinIO 审核日志

无论您是在本地还是在云中,您都希望确保以同构的方式设置工具和流程。无论在何处访问基础结构,您都希望确保用于与各种基础结构进行交互的工具与其他区域相似。 考虑到这一点,在部署您自己的 MinIO 对象存储基础架构时,深入了解您…

【Python】已解决:urllib模块设置代理ip

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决:urllib模块设置代理ip 一、分析问题背景 在使用Python的urllib模块进行网络请求时,有时我们需要通过代理服务器来发送请求,以达到隐藏真…

python中的<class ‘complex‘>

一般编程里面不怎么会讲&#xff0c;但是还是挺强大的一个类。 在 Python 中&#xff0c;<class complex> 表示复数类型。复数是一种包含实部和虚部的数学数&#xff0c;可以用 a bj 的形式表示&#xff0c;其中 a 表示实部&#xff0c;b 表示虚部&#xff0c;j 是虚数…

2024年广西三支一扶报名详细流程(附报名照处理流程)​

2024年广西将招募1650名高校毕业生到基层从事支农、支医、支教和帮扶乡村振兴工作&#xff08;简称“三支一扶”&#xff09;。 招募对象为全日制普通高校应届及择业期内2022年至2024年毕业的全日制普通高校毕业生。 ➡️招募条件。 1.具有全日制大专&#xff08;含高职高专&am…

【C++题解】1058 - 求出100至999范围内的所有水仙花数。

问题:1058 - 求出100至999范围内的所有水仙花数。 类型&#xff1a;简单循环 题目描述&#xff1a; 所谓水仙花数&#xff0c;就是指各位数字立方之和等于该数的数&#xff1b;a^3 称为 a 的立方&#xff0c;即等于 aaa 的值。 例如&#xff1a;因为 15313533^3 &#xff0c;…

【JavaEE 进阶(六)】Mybatis操作数据库

❣博主主页: 33的博客❣ ▶️文章专栏分类:JavaEE◀️ &#x1f69a;我的代码仓库: 33的代码仓库&#x1f69a; &#x1faf5;&#x1faf5;&#x1faf5;关注我带你了解更多进阶知识 目录 1.前言2.JDBC操作3.Mybatis3.1Mybatis入门3.1.1准备工作3.1.2配置数据库连接3.1.2写持久…

13 物理层介质及设备

物理层介质及设备 一、线缆的连接 &#xff08;一&#xff09;线序 ​ 线序&#xff1a; RJ-45连接头12345678568A绿白绿橙白蓝蓝白橙棕白棕568B橙白橙绿白蓝蓝白绿棕白棕 ​ 1、2发送&#xff0c;3、6接收 &#xff08;二&#xff09;线缆的应用 1.线缆的连接 ​ 标准…

OpenHarmony 5.0 纯血鸿蒙系统

OpenHarmony-v5.0-Beta1 版本已于 2024-06-20 发布。 OpenHarmony 5.0 Beta1 版本标准系统能力持续完善&#xff0c;ArkUI 完善了组件通过 C API 调用的能力&#xff1b;应用框架细化了生命周期管理能力&#xff0c;完善了应用拉起、跳转的能力&#xff1b;分布式软总线连接能力…

MindManager2024思维导图软件重磅发布更新!

大家好啊&#xff01;&#x1f44b; 今天我超级激动要分享给大家一款改变我工作和学习方式的工具——MindManager2024思维导图软件&#xff01;这可不仅仅是个工具哦&#xff0c;它更像是我的私人思维助手&#xff0c;帮我整理思绪&#xff0c;规划时间&#xff0c;还能激发创新…

HTML+CSS 3D旋转登录表单

效果演示 实现了一个具有3D旋转效果的登录框&#xff0c;背景为太空图片&#xff0c;登录框位于太空中心&#xff0c;可以通过输入用户名和密码进行登录。登录框使用了CSS3的3D变换和动画效果&#xff0c;使其具有立体感和动态效果。同时&#xff0c;登录框的样式也经过精心设计…

TMC2209驱动模式详解

TMC2209驱动模式详解 1.TMC2209封装2.TMC2209引脚定义 1.TMC2209封装 2.TMC2209引脚定义

关于ONLYOFFICE8.1版本桌面编辑器测评——AI时代的领跑者

关于作者&#xff1a;个人主页 目录 一.产品介绍 1.关于ONLYOFFICE 2.关于产品的多元化功能 二.关于产品体验方式 1.关于套件的使用网页版登录 2.关于ONLYOFFICE本地版 三.关于产品界面设计 四.关于产品文字处理器&#xff08;Document Editor&#xff09; 1.电子表格&a…

破碎的像素地牢探险:游戏分享

软件介绍 《破碎的像素地牢》是开源一款地牢冒险探索类的游戏&#xff0c;融合了日系RPG经典风格&#xff0c;玩家将控制主角进行未知场景的探索。除了经典地牢玩法外&#xff0c;游戏还添加了更多创意内容&#xff0c;如黑屏状态前的挑战性等&#xff0c;使得游戏更加富有挑战…

【源码】Spring Data JPA原理解析之Auditing执行原理

Spring Data JPA系列 1、SpringBoot集成JPA及基本使用 2、Spring Data JPA Criteria查询、部分字段查询 3、Spring Data JPA数据批量插入、批量更新真的用对了吗 4、Spring Data JPA的一对一、LazyInitializationException异常、一对多、多对多操作 5、Spring Data JPA自定…

C#的无边框窗体项目模板 - 开源研究系列文章

继续整理和编写代码及博文。 这次将笔者自己整理的C#的无边框窗体项目的基本模板进行总结&#xff0c;得出了基于C#的.net framework的Winform的4个项目模板&#xff0c;这些模板具有基本的功能&#xff0c;即已经初步将代码写了&#xff0c;直接在其基础上添加业务代码即可&am…

Maya 白膜渲染简单教程

零基础渲染小白&#xff0c;没关系&#xff0c;一篇超简单教程带你学会渲染白膜。 先打开Maya&#xff0c;看看面板有没有渲染器&#xff0c;这里以Arnold为主。 要是没有这个&#xff0c;就去找插件管理器&#xff0c; Arnold的是mtoa&#xff0c;在搜索栏搜&#xff0c;然后把…

uni-app 微信小程序开发到发布流程

1. uni-app 微信小程序开发到发布流程 1.1. 新建一个uni-app 项目 1.2. 发行微信小程序 1.3. 微信开发者平台的微信小程序appid 复制进来&#xff08;点击发行&#xff09; 1.4. IDE may already started at port xxxx, trying to connect &#xff08;1&#xff09;关闭微信…

【Golang - 90天从新手到大师】Day14 - 方法和接口

一&#xff0e; go方法 go方法&#xff1a;在函数的func和函数名间增加一个特殊的接收器类型&#xff0c;接收器可以是结构体类型或非结构体类型。接收器可以在方法内部访问。创建一个接收器类型为Type的methodName方法。 func (t Type) methodName(parameter list) {}go引入…

解析JavaScript中逻辑运算符和||的返回值机制

本文主要内容&#xff1a;了解逻辑运算符 &&&#xff08;逻辑与&#xff09;和 ||&#xff08;逻辑或&#xff09;的返回值。 在JavaScript中&#xff0c;逻辑运算符 &&&#xff08;逻辑与&#xff09;和 ||&#xff08;逻辑或&#xff09;的返回值可能并不总…

大模型之-Seq2Seq介绍

大模型之-Seq2Seq介绍 1. Seq2Seq 模型概述 Seq2Seq&#xff08;Sequence to Sequence&#xff09;模型是一种用于处理序列数据的深度学习模型&#xff0c;常用于机器翻译、文本摘要和对话系统等任务。它的核心思想是将一个输入序列转换成一个输出序列。 Seq2Seq模型由两个主…