昇思MindSpore学习笔记7--函数式自动微分

news2024/11/25 19:20:35

摘要:

介绍了昇思MindSpore神经网络训练反向传播算法中函数式自动微分的使用方法和步骤。包括构造计算函数和神经网络、grad获得微分函数,以及如何处理停止渐变、获取辅助数据等内容。

一、概念要点

神经网络训练主要使用反向传播算法

        准备模型预测值logits与正确标签label

        损失函数loss function计算loss

        反向传播计算梯度gradients

        更新模型参数parameters

自动微分

        计算可导函数在某点处的导数值

        将复杂的数学运算分解为一系列简单的基本运算

MindSpore函数式自动微分接口

        grad

        value_and_grad

二、环境准备

安装minspore模块

!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

导入numpy、minspore、nn、ops等相关模块

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

三、函数与计算图

计算图是用图论语言来表示数学函数。

深度学习框架用来表达神经网络模型。

以下图为例构造计算函数和神经网络。

在这个模型中,x为输入,z为预测值,y为目标值,w和b是需要优化的参数。

根据计算图表达的计算过程,构造计算函数。

binary_cross_entropy_with_logits

        损失函数,计算预测值z和目标值y之间的二值交叉熵损失。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss

loss = function(x, y, w, b)
print(loss)

输出:

1.5375124

四、微分函数与梯度计算

优化模型参数需要对参数w、b求loss的导数:

调用mindspore.grad函数获得function的微分函数。

 注:grad获得微分函数是一种函变换,即输入为函数,输出也为函数。

grad_fn = mindspore.grad(function, (2, 3))

mindspore.grad函数的两个入参:

fn: 求导函数。

grad_position:求导参数的索引位置。

参数 w和b在function入参对应的位置为(2, 3)。

执行微分函数获得w、b对应的梯度。

grads = grad_fn(x, y, w, b)
print(grads)

输出:

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01]]), 
Tensor(shape=[3], dtype=Float32, value=
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01]))

五、Stop Gradient停止渐变

消除某个Tensor对梯度的影响或实现对某个输出项的梯度截断

通常情况下,求导函数的输出只有loss一项。

微分函数会对参数求所有输出项的导数。

function_with_logits修改原function支持同时输出loss和z

获得自动微分函数并执行。

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

输出:

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.27216899e+00,  1.20658815e+00,  1.28590846e+00],
 [ 1.27216899e+00,  1.20658815e+00,  1.28590846e+00],
 [ 1.27216899e+00,  1.20658815e+00,  1.28590846e+00],
 [ 1.27216899e+00,  1.20658815e+00,  1.28590846e+00],
 [ 1.27216899e+00,  1.20658815e+00,  1.28590846e+00]]),
 Tensor(shape=[3], dtype=Float32, value=
 [ 1.27216899e+00,  1.20658815e+00,  1.28590846e+00]))

wb对应的梯度值变了。

使用ops.stop_gradient接口屏蔽z对梯度的影响

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

输出:

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01]]),
 Tensor(shape=[3], dtype=Float32, value=
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01]))

此时wb对应的梯度值与初始function求得的梯度值一致。

六、Auxiliary data辅助数据

函数除第一个输出项外的其他输出。

通常会将loss设置为函数的第一个输出,其他的输出即为辅助数据。

grad和value_and_gradhas_aux参数

设置True,自动实现stop_gradient,返回辅助数据同时不影响梯度计算的效果。

下面仍使用function_with_logits,配置has_aux=True,并执行。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

输出:

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01]]),
 Tensor(shape=[3], dtype=Float32, value=
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01]))
 [1.4928596  0.48854822 1.7965223 ]

此时wb 对应的梯度值与初始function求得的梯度值一致

同时z能够作为微分函数的输出返回。

七、神经网络梯度计算

下面通过继承nn.Cell构造单层线性变换神经网络

利用函数式自动微分来实现反向传播。

使用mindspore.Parameter封装w、b模型参数作为内部属性

在construct内实现相同的Tensor操作。

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b
​
    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

使用函数式自动微分需要将神经网络和损失函数的调用封装为一个前向计算函数。

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

使用value_and_grad接口获得微分函数,用于计算梯度。

Cell封装神经网络模型,模型参数为Cell的内部属性,不需要使用grad_position指定对函数输入求导,因此将其配置为None。

使用model.trainable_params()方法从weights参数取出可以求导的参数。

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())
loss, grads = grad_fn(x, y)
print(grads)

输出:

(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01],
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01]]),
 Tensor(shape=[3], dtype=Float32, value=
 [ 2.72169024e-01,  2.06588134e-01,  2.85908401e-01]))

执行微分函数,可以看到梯度值和前文function求得的梯度值一致。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1879140.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习平台推荐_菜鸟教程官网

网址: 菜鸟教程 - 学的不仅是技术,更是梦想!菜鸟教程(www.runoob.com)提供了编程的基础技术教程, 介绍了HTML、CSS、Javascript、Python,Java,Ruby,C,PHP , MySQL等各种编程语言的基础知识。 同…

[数据集][目标检测]猪只状态吃喝睡站检测数据集VOC+YOLO格式530张4类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):530 标注数量(xml文件个数):530 标注数量(txt文件个数):530 标注类别…

【操作系统期末速成】EP06 | 学习笔记(基于五道口一只鸭)

文章目录 一、前言🚀🚀🚀二、正文:☀️☀️☀️2.1 考点十四:同步互斥的基本概念2.2 考点十五:实现临界区互斥的基本方法2.3 考点十六:信号量的含义及常用信号量 一、前言🚀&#x1…

土体中应力的计算

土中的应力的计算 非水面以下土体中应力的计算:水面以下的土中的应力计算 参考视频: https://www.bilibili.com/video/BV1Rh411J72h/?spm_id_from333.788&vd_source02b2bad477a153eaeb9c48cbbedaf8df 非水面以下土体中应力的计算: 按成…

深入理解 Spring MVC:原理与架构解析

文章目录 前言一、MVC二、Spring MVC三、Spring MVC 工作流程四、小结推荐阅读 前言 Spring MVC 是一种基于 Java 的 Web 应用开发框架,它通过模型-视图-控制器(Model-View-Controller, MVC)的设计模式来组织和管理 Web 应用程序。本文将深入…

【Uniapp微信小程序】图片左右分割/分割线切割图片/图片批量分割线切割

特别说明:本文章因业务组件功能,不完全开放/暂vip可见,有需要者留言找博主! ps:注意!!本效果为图片分割切割!!不是文档切割!!图片仅供参考&#x…

【机器学习】Python中sklearn中数据基础处理与分析过程

📝个人主页:哈__ 期待您的关注 目录 1. 简介 ​编辑 1.1 什么是Scikit-learn 介绍Scikit-learn 应用领域 1.2 安装Scikit-learn 安装步骤 必要的依赖 2. 数据处理 2.1 创建示例数据 2.2 数据预处理 处理缺失值 特征编码 特征缩放 3. 数据…

kali/ubuntu安装vulhub

无须更换源,安装docker-compose apt install docker.io docker -vdocker-compose #提示没有,输入y安装mkdir -p /etc/docker vi /etc/docker/daemon.json #更换dockerhub国内源┌──(root㉿kali)-[/home/kali/vulhub-master/tomcat/CVE-2017-12615] …

Java对象创建过程

在日常开发中,我们常常需要创建对象,那么通过new关键字创建对象的执行中涉及到哪些流程呢?本文主要围绕这个问题来展开。 类的加载 创建对象时我们常常使用new关键字。如下 ObjectA o new ObjectA();对虚拟机来讲首先需要判断ObjectA类的…

测评推荐:企业管理u盘的软件有哪些?

U盘作为一种便携的存储设备,方便易用,被广泛应用于企业办公、个人学习及日常工作中。然而,U盘的使用也带来了数据泄露、病毒传播等安全隐患。为了解决这些问题,企业管理U盘的软件应运而生。 本文将对市面上流行的几款U盘管理软件…

大模型RAG问答中的文档分段

昨天,我们谈了句子分段,我们再来回顾一下段落的分段方法,目前已经有其他方案,图来自于:https://www.rungalileo.io/blog/mastering-rag-advanced-chunking-techniques-for-llm-applications,可以看到其中的…

Java17-时间类、包装类

目录 Date类 概述 常用方法 SimpleDateFormat类 概述 构造方法 格式规则 常用方法 Calendar类 概述 常用方法 get方法示例 set方法示例 add方法示例 JDK8时间相关类 ZoneId 时区 Instant 时间戳 ZoneDateTime 带时区的时间 DateTimeFormatter 用于时间的格式…

【Lua】第一篇:在Linux系统中安装搭建lua5.4.1环境

文章目录 一. 远程下载安装包二. 解压安装包三. 编译安装Lua环境 一. 远程下载安装包 输入以下命令即可在当前目录下,远程下载安装包lua-5.4.1.tar.gz: wget http://www.lua.org/ftp/lua-5.4.1.tar.gzPS:其他版本的安装包如下,可…

Django 模版继承

1&#xff0c;设计母版页 Test/templates/6/base.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><!-- 修正了模板标签的全角字符问题 -->{% block title %}<title>这个是母版页</title>{…

教师资格证(教资)笔试如何备考?含备考资料

教师资格证&#xff08;教资&#xff09;笔试如何备考&#xff1f;含备考资料 前言 教师&#xff0c;一直以来的热门职业&#xff0c;而要成为一名教师&#xff0c;考取教师资格证则是基本条件&#xff0c;那么教资笔试如何备考呢&#xff1f;&#xff0c;这里准备笔试备考攻…

南方空调企业疑似暗讽对手用铝代铜,偷工减料,空调不耐用

空调行业卷到何种程度了&#xff1f;已经开始偷工减料了&#xff0c;日前南方一家空调企业的老板就公开指出一些企业用铝管替代铜管&#xff0c;如此做后果将是导致空调的耐久性和稳定性不足&#xff0c;其实还有散热效果不好&#xff0c;导致耗电量、制冷效果下降。 今年空调的…

Spring-循环依赖是如何解决的

1、bean被创建保存到spring容器的过程 1、实例化 -> 获取对象&#xff1b; 2、填充属性&#xff1b;这里可能需要依赖其它的bean。 3、AOP代理对象替换&#xff1b; 4、加入单例池&#xff1b; 问题&#xff1a; 循环依赖怎么处理 ServiceA 中有属性ServiceB b&#…

记一次小程序渗透

这次的小程序渗透刚好每一个漏洞都相当经典所以记录一下。 目录 前言 漏洞详情 未授权访问漏洞/ 敏感信息泄露&#xff08;高危&#xff09; 水平越权&#xff08;高危&#xff09; 会话重用&#xff08;高危&#xff09; 硬编码加密密钥泄露&#xff08;中危&#xff0…

FFmpeg 命令行 音视频格式转换

&#x1f4da;&#xff1a;FFmpeg 提供了丰富的命令行选项和功能&#xff0c;可以用来处理音视频文件、流媒体等&#xff0c;掌握命令行的使用&#xff0c;可以有效提高工作效率。 目录 一、视频转换和格式转换 &#x1f535; 将视频文件转换为另一种格式 &#x1f535; 指定…

代码随想录--字符串--替换数字

题目 给定一个字符串 s&#xff0c;它包含小写字母和数字字符&#xff0c;请编写一个函数&#xff0c;将字符串中的字母字符保持不变&#xff0c;而将每个数字字符替换为number。 例如&#xff0c;对于输入字符串 “a1b2c3”&#xff0c;函数应该将其转换为 “anumberbnumber…