《昇思25天学习打卡营第6天 | 函数式自动微分》

news2025/1/18 13:51:08

《昇思25天学习打卡营第6天 | 函数式自动微分》

目录

  • 《昇思25天学习打卡营第6天 | 函数式自动微分》
    • 函数式自动微分
    • 简单的单层线性变换模型
    • 函数与计算图
    • 微分函数与梯度计算
    • Stop Gradient

函数式自动微分

神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口grad和value_and_grad。

简单的单层线性变换模型

我们通过学习使用一个简单的单层线性变换模型来了解

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

compute-graph
在这个模型中, 𝑥
为输入, 𝑦
为正确值, 𝑤
和 𝑏
是我们需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

我们根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

执行计算函数,可以获得计算的loss值。

loss = function(x, y, w, b)
print(loss)

Tensor(shape=[], dtype=Float32, value= 0.914285)

微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数: ∂ loss ⁡ ∂ w \frac{\partial \operatorname{loss}}{\partial w} wloss ∂ loss ⁡ ∂ b \frac{\partial \operatorname{loss}}{\partial b} bloss,此时我们调用mindspore.grad函数,来获得function的微分函数。

这里使用了grad函数的两个入参,分别为:

  • fn:待求导的函数。
  • grad_position:指定求导输入位置的索引。

由于我们对 w w w b b b求导,因此配置其在function入参对应的位置(2, 3)

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

grad_fn = mindspore.grad(function, (2, 3))

执行微分函数,即可获得 𝑤
、 𝑏
对应的梯度。

grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01],
[ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01]]),
Tensor(shape=[3], dtype=Float32, value= [ 6.56869709e-02, 5.37334494e-02, 3.01467031e-01]))

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

这里我们将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00],
[ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00]]),
Tensor(shape=[3], dtype=Float32, value= [ 1.06568694e+00, 1.05373347e+00, 1.30146706e+00]))

可以看到求得 w w w b b b对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。我们将function实现加入stop_gradient,并执行。

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1881240.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode题练习与总结:重排链表--143

一、题目描述 给定一个单链表 L 的头节点 head ,单链表 L 表示为: L0 → L1 → … → Ln - 1 → Ln请将其重新排列后变为: L0 → Ln → L1 → Ln - 1 → L2 → Ln - 2 → … 不能只是单纯的改变节点内部的值,而是需要实际的进…

centos 7.9 离线环境安装GPU服务环境

文章目录 centos 7.9 离线环境安装GPU服务环境系统配置更新 gcc更新内核安装显卡驱动安装cuda安装docker 和 nvidia-container-runtime验证 centos 7.9 离线环境安装GPU服务环境 基于centos 7.9 离线安装gpu 服务基础环境,用于在docker 中运行算法服务 系统配置 …

python open函数中文乱码怎么解决

首先在D盘下新建一个html文档,接着在里面输入含有中文的Html字符,使用中文格式对读取的字符进行解码,再用utf-8的模式对字符进行编码,然后就能正确输出中文字符。 代码如下: # -*- coding: UTF-8 -*- file1 open(&quo…

Python | Leetcode Python题解之第207题课程表

题目: 题解: class Solution:def canFinish(self, numCourses: int, prerequisites: List[List[int]]) -> bool:edges collections.defaultdict(list)indeg [0] * numCoursesfor info in prerequisites:edges[info[1]].append(info[0])indeg[info[…

【Web3项目案例】Ethers.js极简入门+实战案例:实现ERC20协议代币查询、交易

苏泽 大家好 这里是苏泽 一个钟爱区块链技术的后端开发者 本篇专栏 ←持续记录本人自学智能合约学习笔记和经验总结 如果喜欢拜托三连支持~ 目录 简介 前景科普-ERC20 Ethers极简入门教程:HelloVitalik(非小白可跳) 教程概览 开发工具 V…

LeetCode-刷题记录-二分法合集(本篇blog会持续更新哦~)

一、二分查找概述 二分查找(Binary Search)是一种高效的查找算法,适用于有序数组或列表。(但其实只要满足二段性,就可以使用二分法,本篇博客后面博主会持续更新一些题,来破除一下人们对“只有有…

44 - 50题高级字符串函数 / 正则表达式 / 子句 - 高频 SQL 50 题基础版

目录 1. 相关知识点2.例子2.44 - 修复表中的名字2.45 - 患某种疾病的患者2.46 - 删除重复的电子邮箱2.47 - 第二高的薪水2.48 - 按日期分组销售产品2.49 - 列出指定时间段内所有的下单产品2.50 - 查找拥有有效邮箱的用户 1. 相关知识点 相关函数 函数含义concat()字符串拼接upp…

[AIGC] Shell脚本在工作中的常用用法

Shell脚本是一种为 shell 编写的脚本程序。商业上的 Unix Shell 一般都配备图形界面,主要包括:Bourne Shell(/usr/bin/sh或/bin/sh)、Bourne Again Shell(/bin/bash)、C Shell(/usr/bin/csh&…

专题三:Spring源码中新建module

前面我们构建好了Spring源码,接下来肯定迫不及待来调试啦,来一起看看大名鼎鼎ApplicationContext 新建模块 1、基础步骤 1.1 自定义模块名称如:spring-self 1.2 选择构建工具因为spring使用的是gradle,所以这边需要我们切换默认…

KV260视觉AI套件--PYNQ-DPU

目录 1. 简介 2. DPU 原理介绍 2.1 基本原理 2.2 增强型用法 3. DPU 开发流程 3.1 添加 DPU IP 3.2 在 BD 中调用 3.3 配置 DPU 参数 3.4 DPU 与 Zynq MPSoC互联 3.5 分配地址 3.6 生成 Bitstream 3.7 生成 BOOT.BIN 4. 总结 1. 简介 在《Vitis AI 环境搭建 &…

爬虫中如何创建Beautiful Soup 类的对象

在使用 lxml 库解析网页数据时,每次都需要编写和测试 XPath 的路径表达式,显得非常 烦琐。为了解决这个问题, Python 还提供了 Beautiful Soup 库提取 HTML 文档或 XML 文档的 节点。 Beautiful Soup 使用起来很便捷,…

数据结构-期末复习题

数据结构-期末复习题 一、选择题 1、在数据结构中,与所使用的计算机无关的是数据的( ) 结构。 A. 存储B. 物理C. 逻辑D. 物理和存储 【答案】C 【解析】暂无解析2、算法分析的两个主要方面是 ( )。 A. 正确性和简单性B. 可读性和文档性C. 空间复杂度…

拼多多滑块逆向

声明(lianxi a15018601872) 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 前言(lianxi …

onInterceptTouchEvent() 与 onTouch() 事件分析

前言 本文主要分析 onTouch() 与 onTouchEvent() 事件的差异 正文 先看布局文件&#xff1a; <?xml version"1.0" encoding"utf-8"?> <com.longzhiye.intercepttouch.MyFrameLayout xmlns:android"http://schemas.android.com/apk/res…

Big Data Tools插件

一些介绍 在Jetbrains的产品中&#xff0c;均可以安装插件&#xff0c;其中&#xff1a;Big Data Tools插件可以帮助我们方便的操作HDFS&#xff0c;比如 IntelliJ IDEA&#xff08;Java IDE&#xff09; PyCharm&#xff08;Python IDE&#xff09; DataGrip&#xff08;SQL …

ctfshow-web入门-命令执行(web71-web74)

目录 1、web71 2、web72 3、web73 4、web74 1、web71 像上一题那样扫描但是输出全是问号 查看提示&#xff1a;我们可以结合 exit() 函数执行php代码让后面的匹配缓冲区不执行直接退出。 payload&#xff1a; cvar_export(scandir(/));exit(); 同理读取 flag.txt cinclud…

mybatis实现动态sql

第一章、动态SQL MyBatis 的强大特性之一便是它的动态 SQL。如果你有使用 JDBC 或其它类似框架的经验&#xff0c;你就能体会到根据不同条件拼接 SQL 语句的痛苦。例如拼接时要确保不能忘记添加必要的空格&#xff0c;还要注意去掉列表最后一个列名的逗号。利用动态 SQL 这一特…

网安小贴士(2)OSI七层模型

一、前言 OSI七层模型是一种网络协议参考模型&#xff0c;用于描述计算机网络体系结构中的不同层次和功能。它由国际标准化组织 (ISO) 在1984年开发并发布。 二、定义 OSI七层模型&#xff0c;全称为开放式系统互联通信参考模型&#xff08;Open Systems Interconnection Refe…

微型导轨:手术机器人的高精度“骨骼”

微型导轨精度高&#xff0c;摩擦系数小&#xff0c;自重轻&#xff0c;结构紧凑&#xff0c;被广泛应用在医疗器械中&#xff0c;尤其是在手术机器人中的应用&#xff0c;通过手术机器人&#xff0c;外科医生可以远离手术台操纵机器人进行手术。可以说&#xff0c;是当之无愧的…

6-linux操作文件与数据上传

目录 通配符学习 创建和上传下载文件 系统盘与数据盘 创建和查找文件夹 文件上传与移动 Windows上传文件到服务器方法1 Windows上传文件到服务器方法2&#xff1a;网盘等 删除文件 rm 复制文件cp 移动文件 常用的命令 cp:复制文件和目录。 mv:移动或重命名文件和目…