昇思MindSpore学习总结六——函数式自动微分

news2025/1/19 20:32:44

        神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

        MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad。下面我们使用一个简单的单层线性变换模型进行介绍。

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

1、函数与计算图

        计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

 在这个模型中,𝑥为输入,𝑦为正确值,𝑤和𝑏是我们需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias
print(x)
print(y)

2、构造损失函数

         我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。

mindspore.ops.binary_cross_entropy_with_logits(logitslabelweight=Nonepos_weight=Nonereduction='mean')

        输入经过sigmoid激活函数后作为预测值,binary_cross_entropy_with_logits 计算预测值和目标值之间的二值交叉熵损失。

将输入 logits 设置为 𝑋 ,输入 labels 设置为 𝑌 ,输入 weight 设置为 𝑊 ,输出设置为 𝐿 。则,

𝑖 表示 𝑖^𝑡ℎ 样例, 𝑗 表示类别。则,

ℓ 表示计算损失的方法。

        有三种方法:第一种方法是直接提供损失值,第二种方法是计算所有损失的平均值,第三种方法是计算所有损失的总和。

        该算子会将输出乘以相应的权重。 𝑤𝑒𝑖𝑔ℎ𝑡 表示一个batch中的每条数据分配不同的权重, 𝑝𝑜𝑠_𝑤𝑒𝑖𝑔ℎ𝑡 为每个类别的正例子添加相应的权重。

        此外,它可以通过向正例添加权重来权衡召回率和精度。 在多标签分类的情况下,损失可以描述为:

其中 c 是类别数目(c>1 表示多标签二元分类,c=1 表示单标签二元分类),n 是批次中样本的数量,𝑃𝑐 是 第c类正例的权重。 𝑃𝑐>1 增大召回率, 𝑃𝑐<1 增大精度。

【参数】

  • logits (Tensor) - 输入预测值。其数据类型为float16或float32。

  • label (Tensor) - 输入目标值,shape与 logits 相同。数据类型为float16或float32。

  • weight (Tensor,可选) - 指定每个批次二值交叉熵的权重。支持广播,使其shape与 logits 的shape保持一致。数据类型必须为float16或float32。默认值:None , weight 是值为 1 的Tensor。

  • pos_weight (Tensor,可选) - 指定正类的权重。是一个长度等于分类数的向量。支持广播,使其shape与 logits 的shape保持一致。数据类型必须为float16或float32。默认值:None , pos_weight 是值为 1 的Tensor。

  • reduction (str,可选) - 指定应用于输出结果的规约计算方式,可选 'none' 、 'mean' 、 'sum' ,默认值: 'mean' 。

    • 'none':不应用规约方法。

    • 'mean':计算输出元素的加权平均值。

    • 'sum':计算输出元素的总和。

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    #mindspore.ops.matmul(input, other)计算两个数组的乘积。
    #input (Tensor) - 输入Tensor,不支持Scalar, input 的最后一维度和 other 的倒数第二维度相
#等,且 input 和 other 彼此支持广播。
    #other (Tensor) - 输入Tensor,不支持Scalar, input 的最后一维度和 other 的倒数第二维度相
#等,且 input 和 other 彼此支持广播。
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss
loss = function(x, y, w, b)
print(loss)

3、微分函数与梯度计算

        为了优化模型参数,需要求参数对loss的导数:∂loss/∂𝑤和∂loss/∂𝑏,此时我们调用mindspore.grad函数,来获得function的微分函数。

mindspore.grad(fngrad_position=0weights=Nonehas_aux=Falsereturn_ids=False)

        生成求导函数,用于计算给定函数的梯度。

函数求导包含以下三种场景:

  1. 对输入求导,此时 grad_position 非None,而 weights 是None;

  2. 对网络变量求导,此时 grad_position 是None,而 weights 非None;

  3. 同时对输入和网络变量求导,此时 grad_position 和 weights 都非None。

【参数】

  • fn (Union[Cell, Function]) - 待求导的函数或网络。

  • grad_position (Union[NoneType, int, tuple[int]]) - 指定求导输入位置的索引。若为int类型,表示对单个输入求导;若为tuple类型,表示对tuple内索引的位置求导,其中索引从0开始;若是None,表示不对输入求导,这种场景下, weights 非None。默认值: 0 。

  • weights (Union[ParameterTuple, Parameter, list[Parameter]]) - 训练网络中需要返回梯度的网络变量。一般可通过 weights = net.trainable_params() 获取。默认值: None 。

  • has_aux (bool) - 是否返回辅助参数的标志。若为 True , fn 输出数量必须超过一个,其中只有 fn 第一个输出参与求导,其他输出值将直接返回。默认值: False 。

  • return_ids (bool) - 是否返回由返回的梯度和指定求导输入位置的索引或网络变量组成的tuple。若为 True ,其输出中所有的梯度值将被替换为:由该梯度和其输入的位置索引,或者用于计算该梯度的网络变量组成的tuple。默认值: False 。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对𝑤和𝑏求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

grad_fn = mindspore.grad(function, (2, 3))
# 执行微分函数,即可获得 𝑤、 𝑏对应的梯度。
grads = grad_fn(x, y, w, b)
print(grads)

3.1 Stop Gradient

        通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

        这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z

grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

        可以看到求得𝑤𝑤、𝑏𝑏对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。我们将function实现加入stop_gradient,并执行。

mindspore.ops.stop_gradient(value)

        用于消除某个值对梯度的影响,例如截断来自于函数输出的梯度传播。

【参数】

  • value (Any) - 需要被消除梯度影响的值。

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

可以看到,求得𝑤𝑤、𝑏𝑏对应的梯度值与初始function求得的梯度值一致。 

3.2 Auxiliary data

        Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

  gradvalue_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

        可以看到,求得𝑤𝑤、𝑏𝑏对应的梯度值与初始function求得的梯度值一致,同时z能够作为微分函数的输出返回。

4、神经网络梯度计算

        神经网络构造是继承自面向对象编程范式的nn.Cell。接下来我们通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播。

        首先我们继承nn.Cell构造单层线性变换神经网络。这里我们直接使用前文的𝑤、𝑏作为模型参数,使用mindspore.Parameter进行包装后,作为内部属性,并在construct内实现相同的Tensor操作。

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z

接下来我们实例化模型和损失函数。

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

         完成后,由于需要使用函数式自动微分,需要将神经网络和损失函数的调用封装为一个前向计算函数。

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

完成后,我们使用value_and_grad接口获得微分函数,用于计算梯度。

        由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())
loss, grads = grad_fn(x, y)
print(grads)

 执行微分函数,可以看到梯度值和前文function求得的梯度值一致。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1880930.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言】--分支和循环(1)

&#x1f37f;个人主页: 起名字真南 &#x1f9c7;个人专栏:【数据结构初阶】 【C语言】 目录 前言1 if 语句1.1 if1.2 else1.3 嵌套if1.4 悬空else 前言 C语言是结构化的程序设计语言&#xff0c;这里的结构指的是顺序结构、选择结构、循环结构。 我们可以用if、switch实现分支…

鸿蒙开发设备管理:【@ohos.multimodalInput.inputDevice (输入设备)】

输入设备 输入设备管理模块&#xff0c;用于监听输入设备连接、断开和变化&#xff0c;并查看输入设备相关信息。比如监听鼠标插拔&#xff0c;并获取鼠标的id、name和指针移动速度等信息。 说明&#xff1a; 本模块首批接口从API version 8开始支持。后续版本的新增接口&…

Todesk远程Ubuntu桌面系统100%但是进不去桌面

1、报错情况 如下图所示&#xff0c;用Todesk远程Ubuntu桌面&#xff0c;看到连接100%了&#xff0c;但是进不去桌面 ubuntu系统看起来的话&#xff0c;已经像被远程成功了 我就首先把todesk卸载重新安装了&#xff0c;后面发现还是这样&#xff0c;于是我就找客服去问了&…

数据结构笔记第2篇:单向链表

1、链表的概念及结构 概念&#xff1a;链表是一种物理结构上非连续、非顺序的存储结构&#xff0c;数据结构的逻辑顺序是通过链表中的指针链接次序实现的。 就像图中的小火车&#xff0c;每节车厢都是一个节点&#xff0c;每个节点都存储着一个数据。它们本身并不是顺序存储的…

python自动化办公之PyPDF2

用到的库&#xff1a;PyPDF2 实现效果&#xff1a;打开pdf文件&#xff0c;把每一页的内容读出来 代码&#xff1a; import PyPDF2 # 打开pdf文件 fileopen(friday.pdf,rb) # 创建pdf文件阅读器对象 readerPyPDF2.PdfReader(file) # 获取pdf文件的总页数 total_pageslen(rea…

SonicSense:声学振动丰富机器人的物体感知能力

在通过声学振动进行物体感知方面&#xff0c;尽管以往的研究已经取得了一些有希望的结果&#xff0c;但目前的解决方案仍然受限于几个方面。首先&#xff0c;大多数现有研究集中在只有少数&#xff08;N < 5&#xff09;基本物体的受限设置上。这些物体通常具有均质材料组成…

抖音直播自动点赞脚本:让点赞变得简单

抖音直播自动点赞脚本&#xff1a;让点赞变得简单 简介 点赞是社交媒体上表达喜爱的一种方式&#xff0c;尤其在抖音这样的平台上&#xff0c;点赞不仅能够增加主播的人气&#xff0c;还能鼓励他们创作更多优质内容。然而&#xff0c;手动点赞往往既耗时又费力。为了解决这个…

远程连接mysql等支持网络服务的数据库

1.ubuntu服务器上的mysql用datagrip连接需要专门去给mysql在服务器上运行的端口开放安全组吗 在使用 DataGrip 或任何其他数据库管理工具远程连接到 Ubuntu 服务器上的 MySQL 时&#xff0c;确实需要确保服务器的防火墙和安全组设置允许从你的 IP 地址访问 MySQL 所运行的端口&…

在Linux (Ubuntu 16) 下安装LabVIEW

用户尝试在Ubuntu 16操作系统上安装LabVIEW&#xff0c;但找不到合适的安装文件来支持Ubuntu。已经下载了运行时文件&#xff0c;并尝试将.rpm包转换为.deb包并安装在Ubuntu上。然而&#xff0c;安装完成后&#xff0c;没有在应用程序中看到LabVIEW的图标。 用户希望能够在Ubu…

Apache Ranger 2.4.0 集成Hive 3.x(Kerbos)

一、解压tar包 tar zxvf ranger-2.4.0-hive-plugin.tar.gz 二、修改install.propertis POLICY_MGR_URLhttp://localhost:6080REPOSITORY_NAMEhive_repoCOMPONENT_INSTALL_DIR_NAME/BigData/run/hiveCUSTOM_USERhadoop 三、进行enable [roottv3-hadoop-01 ranger-2.4.0-hive…

【SGX系列教程】(八)Intel-SGX 官方示例分析(SampleCode)——Seal Unseal

文章目录 一.Seal Unseal原理介绍1.1 Intel SGX supported Sealing Policies 二.源码分析2.1 README2.2 重点代码分析2.2.1 主要代码模块交互流程分析2.2.2 App/App.cpp2.2.3 Enclave_Seal/Enclave_Seal.cpp2.2.4 Enclave_Unseal/Enclave_Unseal.cpp 2.3 总结 三.参考文献四.感…

Debugging using Visual Studio Code

One of the key features of Visual Studio Code is its great debugging support. VS Code’s built-in debugger helps accelerate your edit, compile, and debug loop. Debugger extensions VS Code 内置了对 Node.js 运行时的调试支持,可以调试 JavaScript、TypeScript…

HDFS学习

3.5 HDFS存储原理 3.5.1 冗余数据保存 作为一个分布式文件系统&#xff0c;为了保证系统的容错性和可用性&#xff0c;HDFS采用了多副本方式对数据进行冗余存储&#xff0c;通常一个数据块的多个副本会被分布到不同的数据节点上。 如图所示&#xff0c;数据块1被分别存放到…

【IVI】CarService启动-Android13

【IVI】CarService启动-Android13 1、CarServiceImpl启动概述2、简要时序图 1、CarServiceImpl启动概述 【IVI】CarService启动&#xff1a; CarServiceHelperService中绑定CarServiceICarImpl初始化各种服务 packages/services/Car/README.md 2、简要时序图

RabbitMQ-交换机的类型以及流程图练习-01

自己的飞书文档:‌‍‬‍‬‍​‍‬​⁠‍​​​‌⁠​​‬‍​​​‬‬‌​‌‌​​&#xfeff;​​​​&#xfeff;‍​‍​‌&#xfeff;⁠‬&#xfeff;&#xfeff;&#xfeff;​RabbitMQ的流程图和作业 - 飞书云文档 (feishu.cn) 作业 图片一张 画rabbit-mq 消息发…

imx6ull/linux应用编程学习(6)jpeg和png的图片显示

1.JPEG图片显示 JPEG&#xff08;Joint Photographic Experts Group&#xff09;是由国际标准组织为静态图像所建立的第一个国际数字图像压缩标准&#xff0c;也是至今一直在使用的、应用最广的图像压缩标准。JPEG 由于可以提供有损压缩&#xff0c;因此压缩比可以达到其他传统…

sqlserver开启CDC

1、背景 由于需要学习flink cdc&#xff0c;并且数据选择sqlserver&#xff0c;所以这里记录sqlserver的cdc开启操作步骤。 2、基础前提 官方介绍地址&#xff1a;https://learn.microsoft.com/zh-cn/sql/relational-databases/track-changes/enable-and-disable-change-dat…

快递物流仓库管理系统java项目springboot和vue的前后端分离系统java课程设计java毕业设计

文章目录 快递物流仓库管理系统一、项目演示二、项目介绍三、部分功能截图四、部分代码展示五、底部获取项目源码&#xff08;9.9&#xffe5;带走&#xff09; 快递物流仓库管理系统 一、项目演示 快递物流仓库管理系统 二、项目介绍 语言: Java 数据库&#xff1a;MySQL 前…

论文阅读《U-KAN Makes Strong Backbone for MedicalImage Segmentation and Generation》

Abstract U-Net 已成为图像分割和扩散概率模型等各种视觉应用的基石。虽然通过结合transformer或 MLP&#xff0c;U-Net 已经引入了许多创新设计和改进&#xff0c;但仍然局限于线性建模模式&#xff0c;而且可解释性不足。为了应对这些挑战&#xff0c;我们的直觉受到了 Kolm…

PotPlayer安装及高分辨率设置

第1步&#xff1a; 下载安装PotPlayer软件 PotPlayer链接&#xff1a;https://pan.baidu.com/s/1hW168dJrLBonUnpLI6F3qQ 提取码&#xff1a;z8xd 第2步&#xff1a; 下载插件&#xff0c;选择系统对应的位数进行运行&#xff0c;该文件不能删除&#xff0c;删除后将失效。 …