昇思MindSpore学习入门-函数式自动微分

news2025/1/13 17:50:04

函数式自动微分

神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口grad和value_and_grad。下面我们使用一个简单的单层线性变换模型进行介绍。

函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

在这个模型中,𝑥为输入,𝑦为正确值,𝑤和𝑏是我们需要优化的参数。

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。

执行计算函数,可以获得计算的loss值。

微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数:

,此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对𝑤和𝑏求导,因此配置其在function入参对应的位置(2, 3)。

执行微分函数,即可获得𝑤、𝑏对应的梯度。

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

可以看到求得𝑤、𝑏对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。我们将function实现加入stop_gradient,并执行。

可以看到,求得𝑤、𝑏对应的梯度值与初始function求得的梯度值一致。

Auxiliary data

Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

grad和value_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

可以看到,求得𝑤、𝑏对应的梯度值与初始function求得的梯度值一致,同时z能够作为微分函数的输出返回。

神经网络梯度计算

前述章节主要根据计算图对应的函数介绍了MindSpore的函数式自动微分,但我们的神经网络构造是继承自面向对象编程范式的nn.Cell。接下来我们通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播。

首先我们继承nn.Cell构造单层线性变换神经网络。这里我们直接使用前文的𝑤、𝑏作为模型参数,使用mindspore.Parameter进行包装后,作为内部属性,并在construct内实现相同的Tensor操作。

接下来我们实例化模型和损失函数。

由于需要使用函数式自动微分,需要将神经网络和损失函数的调用封装为一个前向计算函数。

完成后,我们使用value_and_grad接口获得微分函数,用于计算梯度。

由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1879062.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为何同一PDF文档用不同软件打印效果不同?

通过扫描仪生成的同一PDF文档,同样的设置,为什么别的电脑打出来是白底我的打出来有灰色格子背景?这种情况通常是由于PDF阅读软件的不同造成的差异。 ### 可能的原因和解决方法: 1. **PDF阅读软件的不同**: - **解决方…

Java高级重点知识点-17-异常

文章目录 异常异常处理自定义异常 异常 指的是程序在执行过程中,出现的非正常的情况,最终会导致JVM的非正常停止。Java处 理异常的方式是中断处理。 异常体系 异常的根类是 java.lang.Throwable,,其下有两个子类:ja…

Java 面试指南合集

JVM 篇 线程篇 springBoot篇 SpringCloud篇 待更新 黑夜无论怎样悠长,白昼总会到来。 此文会一直更新哈 如果你希望成功,当以恒心为良友,以经验为参谋,以当心为兄弟,以希望为哨兵。

Java单体架构项目_云霄外卖-特殊点

项目介绍: 定位: 专门为餐饮企业(餐厅、饭店)定制的一款软件商品 分为: 管理端:外卖商家使用 用户端(微信小程序):点餐用户使用。 功能架构: &#xff08…

改进经验模态分解方法-通过迭代方式(IMF振幅加权频率,Python)

一种新颖的改进经验模态分解方法-通过迭代方式(IMF振幅加权频率)有效缓解了模态混叠缺陷,以后慢慢讲,先占坑。 import numpy as np import matplotlib.pyplot as plt import os import seaborn as sns from scipy import stats i…

Pikachu 不安全的文件下载(Unsafe file download)概述 附漏洞利用案例

目录 获取下载链接 修改链接 重新构造链接 拓展 不安全的文件下载概述 文件下载功能在很多web系统上都会出现,一般我们当点击下载链接,便会向后台发送一个下载请求,一般这个请求会包含一个需要下载的文件名称,后台在收到请求…

回溯法基本思想-01背包、N皇后回溯法图解

基本思想: ​ 回溯法是一种系统地搜索问题解空间的算法,常用于解决组合优化和约束满足问题。其核心思想是利用深度优先搜索逐步构建可能的解,同时在搜索过程中进行剪枝操作,以排除那些无法满足问题约束或不能产生最优解的分支&am…

第十一节:学习通过动态调用application.properties参数配置实体类(自学Spring boot 3.x的第二天)

大家好,我是网创有方。这节实现的效果是通过代码灵活地调用application.properties实现配置类参数赋值。 第一步:编写配置类 package cn.wcyf.wcai.config;import org.springframework.beans.factory.annotation.Value; import org.springframework.boo…

【后端面试题】【中间件】【NoSQL】ElasticSearch 节点角色、写入数据过程、Translog和索引与分片

中间件的常考方向: 中间件如何做到高可用和高性能的? 你在实践中怎么做的高可用和高性能的? Elasticsearch节点角色 Elasticsearch的节点可以分为很多种角色,并且一个节点可以扮演多种角色,下面列举几种主要的&…

python第一课 环境准备篇

一、所需工具 电脑:windows或mac 二、安装教程 1、访问 Python 的官方网站(https://www.python.org/ ),找到 DownLoad ,无法访问百度网盘下载 链接:百度网盘 请输入提取码 提取码:8cho 2、选…

用Java操作MySQL数据中的日期类型的数据存取问题分析及其解决办法

目录 一、问题说明二、问题分析三、解决办法1.Java日期向数据存方法一:方法二: 2.从数据库中取日期最后 在Java中向MySQL数据库存取日期类型的数据时,可能会遇到一些常见问题,以下是一些关键点和解决办法: 一、问题说…

基于bootstrap的12种登录注册页面模板

基于bootstrap的12种登录注册页面模板,分三种类型,默认简单的登录和注册,带背景图片的登录和注册,支持弹窗的登录和注册页面html下载。 微信扫码下载

Spring学习01-[Spring实现IOC的几种方式]

Spring实现IOC的几种方式 基于xml实现Spring的IOC基于注解实现Spring的IOC基于JavaConfig实现的Spring的IOC基于SpringBoot实现Spring的IOC 基于xml实现Spring的IOC 引入spring核心依赖 <!--spring核心容器--><dependency><groupId>org.springframework<…

【最新鸿蒙应用开发】——用户信息封装

用户管理工具封装 1. 为什么要封装 在进行如下登录功能时&#xff0c; 通常需要将一些用户信息以及token进行持久化保存&#xff0c;以方便下次进行数据请求时携带这些用户信息来进行访问后端数据。下面分享一下鸿蒙当中实用的持久化封装操作。 2. 步骤 封装用户信息管理工具…

数据恢复篇:如何在没有备份的情况下从恢复已删除的照片

许多用户更喜欢将他们的私人照片保存在他们的 Android 设备上的一个单独的安全空间中&#xff0c;以确保他们的记忆不仅被存储&#xff0c;而且受到保护。这就是“安全文件夹”功能派上用场的地方。您可以使用 PIN 码、密码、指纹或图案锁定此文件夹&#xff0c;即使您的设备落…

springboot汽车租赁管理系统-计算机毕业设计源码08754

目 录 摘 要 第 1 章 引 言 1.1 选题背景和意义 1.2 国内外研究现状 1.3 论文结构安排 第 2 章 系统的需求分析 2.1 系统可行性分析 2.1.1 技术方面可行性分析 2.1.2 经济方面可行性分析 2.1.3 法律方面可行性分析 2.1.4 操作方面可行性分析 2.2 系统功能需求分析…

正版软件 | R-Studio Technician:数据恢复领域的专业利器

在数据恢复的专业领域&#xff0c;每一个挑战都需要精准而强大的工具来应对。R-Studio Technician 是一款专为 Windows、Mac 和 Linux 系统设计的高级数据恢复软件&#xff0c;为数字取证实验室、数据恢复企业或个人提供了全面的解决方案。 专业级工具&#xff0c;全面功能 R-S…

MySQL高级-MVCC-原理分析(RC级别)

文章目录 1、RC隔离级别下&#xff0c;在事务中每一次执行快照读时生成ReadView2、先来看第一次快照读具体的读取过程&#xff1a;3、再来看第二次快照读具体的读取过程: 1、RC隔离级别下&#xff0c;在事务中每一次执行快照读时生成ReadView 我们就来分析事务5中&#xff0c;两…

微信小程序毕业设计-垃圾分类系统项目开发实战(附源码+论文)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;微信小程序毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计…

IP配置SSL的方式

近年SSL证书的运用群体越来越多&#xff0c;实现网站https访问已经成为了常态。 目前SSL证书广泛应用在域名服务器上&#xff0c;所以大家最熟悉的证书类型可能就是单域名SSL证书、泛域名SSL证书&#xff08;通配符SSL证书、泛解析SSL证书&#xff09;、以及方便集成化管理的多…