昇思MindSpore学习入门-高阶自动微分

news2024/11/16 21:46:33

mindspore.ops模块提供的grad和value_and_grad接口可以生成网络模型的梯度。grad计算网络梯度,value_and_grad同时计算网络的正向输出和梯度。本文主要介绍如何使用grad接口的主要功能,包括一阶、二阶求导,单独对输入或网络权重求导,返回辅助变量,以及如何停止计算梯度。

一阶求导

计算一阶导数方法:mindspore.grad,其中参数使用方式为:

  • fn:待求导的函数或网络。
  • grad_position:指定求导输入位置的索引。若为int类型,表示对单个输入求导;若为tuple类型,表示对tuple内索引的位置求导,其中索引从0开始;若是None,表示不对输入求导,这种场景下,weights非None。默认值:0。
  • weights:训练网络中需要返回梯度的网络变量。一般可通过weights = net.trainable_params()获取。默认值:None。
  • has_aux:是否返回辅助参数的标志。若为True,fn输出数量必须超过一个,其中只有fn第一个输出参与求导,其他输出值将直接返回。默认值:False。

下面先构建自定义网络模型Net,再对其进行一阶求导,通过这样一个例子对grad接口的使用方式做简单介绍,即公式:

𝑓(𝑥,𝑦)=𝑥∗𝑥∗𝑦∗𝑧

首先定义网络模型Net、输入x和输入y:

import numpy as np

from mindspore import ops, Tensor

import mindspore.nn as nn

import mindspore as ms

# 定义输入x和y

x = Tensor([3.0], dtype=ms.float32)

y = Tensor([5.0], dtype=ms.float32)

class Net(nn.Cell):

    def __init__(self):

        super(Net, self).__init__()

        self.z = ms.Parameter(ms.Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):

        out = x * x * y * self.z

        return out

对输入求一阶导

对输入x, y进行求导,需要将grad_position设置成(0, 1):

对权重进行求导

对权重z进行求导,这里不需要对输入求导,将grad_position设置成None:

返回辅助变量

同时对输入和权重求导,其中只有第一个输出参与求导,示例代码如下:

停止计算梯度

可以使用stop_gradient来停止计算指定算子的梯度,从而消除该算子对梯度的影响。

在上面一阶求导使用的矩阵相乘网络模型的基础上,再增加一个算子out2并禁止计算其梯度,得到自定义网络Net2,然后看一下对输入的求导结果情况。

示例代码如下:

从上面的打印可以看出,由于对out2设置了stop_gradient,所以out2没有对梯度计算有任何的贡献,其输出结果与未加out2算子时一致。

下面删除out2 = stop_gradient(out2),再来看一下输出结果。示例代码为:

打印结果可以看出,把out2算子的梯度也计算进去之后,由于out2和out1算子完全相同,因此它们产生的梯度也完全相同,所以可以看到,结果中每一项的值都变为了原来的两倍(存在精度误差)。

高阶求导

高阶微分在AI支持科学计算、二阶优化等领域均有应用。如分子动力学模拟中,利用神经网络训练势能时,损失函数中需计算神经网络输出对输入的导数,则反向传播便存在损失函数对输入、权重的二阶交叉导数。

此外,AI求解微分方程(如PINNs方法)还会存在输出对输入的二阶导数。又如二阶优化中,为了能够让神经网络快速收敛,牛顿法等需计算损失函数对权重的二阶导数。

MindSpore可通过多次求导的方式支持高阶导数,下面通过几类例子展开阐述。

单输入单输出高阶导数

例如Sin算子,其公式为:

𝑓(𝑥)=𝑠𝑖𝑛(𝑥)

其一阶导数、二阶导数为:

其二阶导数(-Sin)实现如下:

从上面的打印结果可以看出,−𝑠𝑖𝑛(3.1415926)的值接近于0。

单输入多输出高阶导数

对如下公式求导:

(1)𝑓(𝑥)=(𝑓1(𝑥),𝑓2(𝑥))

其中:

(2)𝑓1(𝑥)=𝑠𝑖𝑛(𝑥)

(3)𝑓2(𝑥)=𝑐𝑜𝑠(𝑥)

梯度计算时由于MindSpore采用的是反向自动微分机制,会对输出结果求和后再对输入求导。因此其一阶导数是:

(4)𝑓′(𝑥)=𝑐𝑜𝑠(𝑥)−𝑠𝑖𝑛(𝑥)

其二阶导数为:

(5)𝑓″(𝑥)=−𝑠𝑖𝑛(𝑥)−𝑐𝑜𝑠(𝑥)

从上面的打印结果可以看出,−𝑠𝑖𝑛(3.1415926)−𝑐𝑜𝑠(3.1415926)的值接近于1。

多输入多输出高阶导数

对如下公式求导:

(1)𝑓(𝑥,𝑦)=(𝑓1(𝑥,𝑦),𝑓2(𝑥,𝑦))

其中:

(2)𝑓1(𝑥,𝑦)=𝑠𝑖𝑛(𝑥)−𝑐𝑜𝑠(𝑦)

(3)𝑓2(𝑥,𝑦)=𝑐𝑜𝑠(𝑥)−𝑠𝑖𝑛(𝑦)

梯度计算时由于MindSpore采用的是反向自动微分机制, 会对输出结果求和后再对输入求导。

求和:

(4)∑𝑜𝑢𝑡𝑝𝑢𝑡=𝑠𝑖𝑛(𝑥)+𝑐𝑜𝑠(𝑥)−𝑠𝑖𝑛(𝑦)−𝑐𝑜𝑠(𝑦)

输出和关于输入𝑥的一阶导数为:

(5)d∑𝑜𝑢𝑡𝑝𝑢𝑡d𝑥=𝑐𝑜𝑠(𝑥)−𝑠𝑖𝑛(𝑥)

输出和关于输入𝑥的二阶导数为:

(6)d∑𝑜𝑢𝑡𝑝𝑢𝑡2d2𝑥=−𝑠𝑖𝑛(𝑥)−𝑐𝑜𝑠(𝑥)

输出和关于输入𝑦的一阶导数为:

(7)d∑𝑜𝑢𝑡𝑝𝑢𝑡d𝑦=−𝑐𝑜𝑠(𝑦)+𝑠𝑖𝑛(𝑦)

输出和关于输入𝑦的二阶导数为:

(8)d∑𝑜𝑢𝑡𝑝𝑢𝑡2d2𝑦=𝑠𝑖𝑛(𝑦)+𝑐𝑜𝑠(𝑦)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1948923.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录算法训练营Day 63| 图论 part03 | 417.太平洋大西洋水流问题、827.最大人工岛、127. 单词接龙

代码随想录算法训练营Day 63| 图论 part03 | 417.太平洋大西洋水流问题、827.最大人工岛、127. 单词接龙 文章目录 代码随想录算法训练营Day 63| 图论 part03 | 417.太平洋大西洋水流问题、827.最大人工岛、127. 单词接龙17.太平洋大西洋水流问题一、DFS二、BFS三、本题总结 82…

解析capl文件生成XML Test Module对应的xml工具

之前一直用的CAPL Test Module来写代码,所有的控制都是在MainTest()函数来实现的,但是有一次,代码都写完了,突然需要用xml的这种方式来实现,很突然,之前也没研究过,整理这个xml整的一身汗&#…

【1】CPU飙升到200%以上问题汇总

原链接 【1】CPU飙升到200%以上问题汇总 CPU飙升到200%以上是生成中常见的问题 注意: 1. linux的cpu使用频率是根据cpu个数和核数决定的 2. top,然后你按一下键盘的1,这就是单个核心的负载,不然是所有核心的负载相加,…

Golang | 腾讯一面

go的调度 Golang的调度器采用M:N调度模型,其中M代表用户级别的线程(也就是goroutine),而N代表的事内核级别的线程。Go调度器的主要任务就是N个OS线程上调度M个goroutine。这种模型允许在少量的OS线程上运行大量的goroutine。 Go调度器使用了三种队列来…

基于STM32瑞士军刀--【FreeRTOS开发】学习笔记(二)|| 堆 / 栈

堆和栈 1. 堆 堆就是空闲的一块内存,可以通过malloc申请一小块内存,用完之后使用再free释放回去。管理堆需要用到链表操作。 比如需要分配100字节,实际所占108字节,因为为了方便后期的free,这一小块需要有个头部记录…

2024年7月25日(Git gitlab以及分支管理 )

分布式版本控制系统 一、Git概述 Git 是一种分布式版本控制系统,用于跟踪和管理代码的变更。它是由Linus Torvalds创建的,最 初被设计用于Linux内核的开发。Git允许开发人员跟踪和管理代码的版本,并且可以在不同的开 发人员之间进行协作。 Github 用的就是Git系统来管理它们的…

JVM面试题之内存区域、类加载篇

文章目录 引言JVM是什么?1. JVM内存划分2. 对象如何在JVM中创建2.1 内存分配2.2 创建对象步骤 3. JVM类加载流程3.1 双亲委派 总结 引言 Java开发人员在面试中基本都会被问到关于JVM的问题。想要成为高级的开发人员,了解和学习Java运行的原理和JVM是必不…

webpack插件给所有的:src文件目录增加前缀

1.webpack4的版本写法 class AddPrefixPlugin {apply(compiler) {compiler.hooks.compilation.tap(AddPrefixPlugin, (compilation) > {HtmlWebpackPlugin.getHooks(compilation).beforeEmit.tapAsync(AddPrefixPlugin,(data, cb) > {// 使用正则表达式替换所有包含 /st…

阿里云服务器安装Anaconda后无法检测到

前言 问题如标题所言,就是conda -V验证错误,不过后来发现其实就是虽然安装时,同意了写入环境变量,但是其实还没有写入,需要手动写入。下面也会重复一遍安装流程。 安装 到[Anaconda下载处](Download Now | Anaconda)查…

基于微信小程序+SpringBoot+Vue的流浪动物救助(带1w+文档)

基于微信小程序SpringBootVue的流浪动物救助(带1w文档) 基于微信小程序SpringBootVue的流浪动物救助(带1w文档) 本系统实现的目标是使爱心人士都可以加入到流浪动物的救助工作中来。考虑到救助流浪动物的爱心人士文化水平不齐,所以本系统在设计时采用操作简单、界面…

通过IEC104转MQTT网关对接阿里云、华为云、亚马逊AWS、ThingsBoard、Ignition、Zabbix

随着工业互联网的快速发展,传统电力系统中的IEC 104协议设备正逐步向更加开放、灵活的物联网架构转型。MQTT(Message Queuing Telemetry Transport)作为一种轻量级的消息传输协议,因其低带宽消耗、高可靠性和广泛的支持性&#xf…

【JavaSE】基础知识复习 (二)

1.面向对象 对象内存分析 举例: class Person { //类:人String name;int age;boolean isMale; } public class PersonTest { //测试类public static void main(String[] args) {Person p1 new Person();p1.name "赵同学";p1.age 20;p1.is…

CentOS搭建Apache服务器

安装对应的软件包 [roothds ~]# yum install httpd mod_ssl -y 查看防火墙的状态和selinux [roothds ~]# systemctl status firewalld [roothds ~]# cat /etc/selinux/config 若未关闭,则关闭防火墙和selinux [roothds ~]# systemctl stop firewalld [roothds ~]# …

使用html2canvas制作一个截图工具

0 效果 1 下载html2canvas npm install html2canvas --save 2 创建ClipScreen.js import html2canvas from html2canvas; // 样式 const cssText {box: overflow:hidden;position:fixed;left:0;top:0;right:0;bottom:0;background-color:rgba(255,255,255,0.9);z-index: 10…

【嵌入式硬件】快衰减和慢衰减

1.引语 在使用直流有刷电机驱动芯片A4950时,这款芯片采用的是PWM控制方式,我发现他的正转、反转有两种控制方式,分别是快衰减和慢衰减。 2.理解 慢衰减:相当于加在电机(感性原件)两端电压消失,将电机两端正负短接。 快衰减:相当于加在电机(感性原件)两端电压消失,将电机…

一篇文章讲清楚html css js三件套之html

目录 HTML HTML发展史 HTML概念和语法 常见的HTML标签: HTML 调试 错误信息分析 HTML文档结构 HTML5的新特性 结论 HTML HTML是网页的基础,它是一种标记语言,用于定义网页的结构和内容。HTML标签告诉浏览器如何显示网页元素,例如段落…

快速安装torch-gpu和Tensorflow-gpu(自用,Ubuntu)

要更详细的教程可以参考Tensorflow PyTorch 安装(CPU GPU 版本),这里是有基础之后的快速安装。 一、Pytorch 安装 conda create -n torch_env python3.10.13 conda activate torch_env conda install cudatoolkit11.8 -c nvidia pip ins…

WINUI——Microsoft.UI.Xaml.Markup.XamlParseException:“无法找到与此错误代码关联的文本。

开发环境 VS2022 .net core6 问题现象 在Canvas内的子控件要绑定Canvas的兄弟控件的一个属性,在运行时出现了下述报错。 可能原因 在 WinUI(特别是用于 UWP 或 Windows App SDK 的版本)中,如果你尝试在 XAML 中将 Canvas 内的…

CSS 的工作原理

我们已经学习了CSS的基础知识,它的用途以及如何编写简单的样式表。在本课中,我们将了解浏览器如何获取 CSS 和 HTML 并将其转换为网页。 先决条件:已安装基本软件,了解处理文件的基本知识以及 HTML 基础知识(学习 HTML 简介。目的:要了解浏览器如何解析 CSS 和 HTML 的基…