昇思25天学习打卡营第6天|基础知识-函数式自动微分

news2024/12/23 11:46:23

目录

环境

函数与计算图

微分函数与梯度计算

Stop Gradient

Auxiliary data

神经网络梯度计算

学习打卡时间


神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。

自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。

自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad

环境

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

导包

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

在这个模型中,𝑥x为输入,𝑦y为正确值,𝑤w和𝑏b是我们需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

执行计算函数,可以获得计算的loss值。

loss = function(x, y, w, b)
print(loss)

# Tensor(shape=[], dtype=Float32, value= 0.914285)

微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数:∂loss/∂𝑤 和 ∂loss/∂𝑏,

调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对𝑤和𝑏求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

grad_fn = mindspore.grad(function, (2, 3))

执行微分函数,即可获得𝑤w、𝑏b对应的梯度。

grads = grad_fn(x, y, w, b)
print(grads)

"""
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01]))
"""

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

"""
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.03263855e+00,  1.24709988e+00,  1.21991932e+00],
 [ 1.03263855e+00,  1.24709988e+00,  1.21991932e+00],
 [ 1.03263855e+00,  1.24709988e+00,  1.21991932e+00],
 [ 1.03263855e+00,  1.24709988e+00,  1.21991932e+00],
 [ 1.03263855e+00,  1.24709988e+00,  1.21991932e+00]]), Tensor(shape=[3], dtype=Float32, value= [ 1.03263855e+00,  1.24709988e+00,  1.21991932e+00]))
"""

可以看到求得𝑤w、𝑏b对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。我们将function实现加入stop_gradient,并执行。

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

"""
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01]))
"""

可以看到,求得𝑤w、𝑏b对应的梯度值与初始function求得的梯度值一致。

Auxiliary data

Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

"""
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01])) [-2.2206016  1.0527335  0.6622162]
"""

可以看到,求得𝑤w、𝑏b对应的梯度值与初始function求得的梯度值一致,同时z能够作为微分函数的输出返回。

神经网络梯度计算

前述章节主要根据计算图对应的函数介绍了MindSpore的函数式自动微分,但我们的神经网络构造是继承自面向对象编程范式的nn.Cell。接下来我们通过Cell构造同样的神经网络,利用函数式自动微分来实现反向传播。

首先我们继承nn.Cell构造单层线性变换神经网络。这里我们直接使用前文的𝑤、𝑏作为模型参数,使用mindspore.Parameter进行包装后,作为内部属性,并在construct内实现相同的Tensor操作。

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z

接下来我们实例化模型和损失函数。

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

完成后,由于需要使用函数式自动微分,需要将神经网络和损失函数的调用封装为一个前向计算函数。

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

完成后,我们使用value_and_grad接口获得微分函数,用于计算梯度。

由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())
loss, grads = grad_fn(x, y)
print(grads)

"""
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01],
 [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 3.26385535e-02,  2.47099832e-01,  2.19919339e-01]))
"""

执行微分函数,可以看到梯度值和前文function求得的梯度值一致。

学习打卡时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1960839.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++红黑树应用】模拟实现STL中的map与set

目录 🚀 前言一: 🔥 红黑树的修改二: 🔥 红黑树的迭代器 三: 🔥 perator() 与 operator--() 四: 🔥 红黑树相关接口的改造✨ 4.1 Find 函数的改造✨ 4.2 Insert 函数的改…

推荐珍藏已久的 3 款优质电脑软件,每一款都值得拥有

Advanced Find and Replace Advanced Find and Replace是一款功能强大的文本查找和替换工具,能够高效地在多个文档中进行复杂的内容操作。它支持通配符和正则表达式,使得用户可以精确地定位和替换特定的文本内容。该软件不仅适用于普通文本文件&#xff…

防洪评价报告编制方法与水流数学模型建模技术

原文链接:防洪评价报告编制方法与水流数学模型建模技术https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247610610&idx2&sn432d30cb40ec36160d635603c7f22c96&chksmfa827115cdf5f803ddcaa03a21e3721d6949d6a336062bb38170e3f9d5bd4d391cc36cc…

【速记!】3DMAX的50个常用快捷键

分享一组基本的3dMax动画和建模快捷键,以用于你的建筑项目。 3dMax是创建三维模型和动画的设计师中流行的软件。它用于建筑、电子游戏或其他需要高清晰度和高精度图形的视觉项目,是视觉艺术家寻找新工具的理想伴侣,这些工具可以帮助他们详细…

Vue3实战案例 知识点全面 推荐收藏 超详细 及附知识点解读

最近经常用到vue中的一些常用知识点,打算系统性的对 vue3 知识点进行总结,方便自己查看,另外也供正在学习 vue3 的同学参考,本案例基本包含 Vue3所有的基本知识点,欢迎参考,有问题评论区留言,谢…

Linux基本功能

Linux 操作系统,作为开源社区的明星之一,以其稳定性、安全性和灵活性在全球范围内得到广泛应用。 1. 多用户和多任务支持 Linux 是一个真正的多用户系统,允许多个用户同时登录并在同一时间内运行多个程序。每个用户拥有自己的账户和权限&…

每日OJ_牛客HJ86 求最大连续bit数

目录 牛客HJ86 求最大连续bit数 解析代码 牛客HJ86 求最大连续bit数 求最大连续bit数_牛客题霸_牛客网 解析代码 根据位运算&#xff0c;获取每一位的二进制值。获取第i位的值&#xff1a; (n >> i) & 1或者 n & (1 << i)。如果1连续&#xff0c;则计数…

Redis 安装和数据类型

Redis 安装和数据类型 一、Redis 1、Redis概念 redis 缓存中间件&#xff1a;缓存数据库 nginx web服务 php 转发动态请求 tomcat web页面&#xff0c;也可以转发动态请求 springboot 自带tomcat 数据库不支持高并发&#xff0c;一旦访问量激增&#xff0c;数据库很快就…

网工内推 | 合资公司、上市公司数据库工程师,OCP/OCM认证优先,双休

01 欣旺达电子股份有限公司 &#x1f537;招聘岗位&#xff1a;数据库管理高级工程师 &#x1f537;岗位职责&#xff1a; 1、负责数据库规划、管理、调优工作&#xff1b; 2、负责数据库应急预案制定、应急预案维护和应急支持&#xff1b; 3、负责数据库异常处理&#xff…

Unity UGUI 之 事件触发器

本文仅作学习笔记与交流&#xff0c;不作任何商业用途 本文包括但不限于unity官方手册&#xff0c;唐老狮&#xff0c;麦扣教程知识&#xff0c;引用会标记&#xff0c;如有不足还请斧正 本文在发布时间选用unity 2022.3.8稳定版本&#xff0c;请注意分别 1.什么是UI事件触发器…

linux安装jdk和jps(为rocketMq准备)

20240730 一、安装rocketMq之前的准备工作1. 安装jkd&#xff08;这里以1.8为例子&#xff09;1.1 下载jdk1.81.2 上传到linux&#xff08;拖拽&#xff09;1.3 解压1.4 配置环境变量1.5 使配置文件生效1.6 验证结果 2. JPS2.1 解决 一、安装rocketMq之前的准备工作 1. 安装jk…

angular入门基础教程(十)管道即过滤器

管道 何为管道&#xff0c;ng 翻译的真烂&#xff0c;但是听多了你就理解了&#xff0c;类似于 vue2 中的过滤器&#xff0c;过滤器在 vue3 中已经废弃 从common包里面引入&#xff0c;并注册 import { Component, inject } from "angular/core"; import { UpperC…

C# 调用Webservice接口接受数据测试

1.http://t.csdnimg.cn/96m2g 此链接提供测试代码&#xff1b; 2.http://t.csdnimg.cn/64iCC 此链接提供测试接口&#xff1b; 关于Webservice的基础部分不做赘述&#xff0c;下面贴上我的测试代码&#xff08;属于动态调用Webservice&#xff09;&#xff1a; 1&#xff…

Appium自动化测试 ------ 常见模拟操作!

Appium自动化测试中的常见模拟操作涵盖了多种用户交互行为&#xff0c;这些操作对于自动化测试框架来说至关重要&#xff0c;因为它们能够模拟真实用户的使用场景&#xff0c;从而验证应用程序的功能和稳定性。 以下是一些Appium自动化测试中常见的模拟操作&#xff1a; 基本操…

XPathParser类

XPathParser类是mybatis对 javax.xml.xpath.XPath的包装类。 接下来我们来看下XPathParser类的结构 1、属性 // 存放读取到的整个XML文档private final Document document;// 是否开启验证private boolean validation;// 自定义的DTD约束文件实体解析器&#xff0c;与valida…

JavaSE面向对象进阶

static 介绍 static表示静态&#xff0c;是Java中的一个修饰符可以修饰成员方法、成员变量 被static修饰的成员变量&#xff0c;叫做静态变量被static修饰的成员方法&#xff0c;叫做静态方法 静态变量 特点&#xff1a;被该类所有对象共享 调用方式&#xff1a; 类名调用&am…

关于@Async

Spring Boot 2.x 开始&#xff0c;默认情况下&#xff0c;Spring AOP 使用 CGLIB 代理 Async不能在同一个类中直接调用 关于在控制器不能使用Async 并不是因为SpringBoot2以前使用JDK代理 因为JDK代理需要类实现接口,控制器没有实现接口等原因 真正原因是 Async 不能…

windows@powershell@任务计划@自动任务计划@taskschd.msc.md

文章目录 使用任务计划windows中的任务计划任务计划命令行程序开发windows 应用中相关api传统图形界面FAQ schtasks 命令常见用法创建计划任务删除计划任务查询计划任务修改计划任务运行计划任务 PowerShell ScheduledTasks常用 cmdlet 简介1. Get-ScheduledTask2. Register-Sc…

手动在ubuntu上搭建一个nginx,并安装证书的最简化完整过程

背景&#xff1a;由于想做个测试&#xff1a;即IP为A的服务器&#xff0c;绑定完域名X后&#xff0c;如果再绑定域名Y&#xff0c;能不能被访问到。&#xff08;假设对A不做绑定域名的设置&#xff09; 这个问题的来源&#xff0c;见上一篇文章&#xff1a;《云服务器被非法域名…

kaggle使用api下载数据集

背景 kaggle通过api并配置代理下载数据集datasets 步骤 获取api key 登录kaggle&#xff0c;点个人资料&#xff0c;获取到自己的api key 创建好的key会自动下载 将key放至家目录下的kaggle.json文件中 我这里是windows的administrator用户。 装包 我用了虚拟环境 pip …