241130_昇思MindSpore函数式自动微分

news2025/2/12 0:49:49

241130_昇思MindSpore函数式自动微分

函数式自动微分是Mindspore学习框架所特有的,更偏向于数学计算的习惯。这里也是和pytorch差距最大的部分,具体体现在训练部分的代码,MindSpore是把各个梯度计算、损失函数计算

image-20241130201017018

在这幅图中,右边这个就是函数式编程,首先先自己定义一个loss函数,最后一行使用grad(),把loss function传进去,因为传进去的是一个函数,做的是一个函数闭包,所以返回的还是一个函数。

深度学习的计算流程

首先是先正向计算,得到一个logits,然后会计算这个logits和真实的targets的误差loss,然后反向传播backwards,得到梯度,然后再送到优化器里面去更新网络权重。

MindSpore是整图计算,将模型的前向、反向、梯度更新过程全部视为一个完整的计算图,这样就有效提高了执行速度,但是带来的弊端就是代码不好书写,和pytorch差异较大,但到了2.0就得到了有效解决。

MindSpore后来使用面向对象+函数式混合使用,1-2和pytorch一样,后面3-6和原来的函数式编程一样。

1、用类构建神经网络

2、实例化Network对象

3、Network+Loss直接构造正向函数

4、函数变换,获得梯度计算(反向传播)函数

5、构造训练过程函数

6、调用函数进行训练

具体实现看以下代码:

image-20241130201123780

pytorch的实现

看右边pytorch的代码就可以很直观的看出来上面说的计算流程的几步。

正向计算得到logits

logits=net(inputs)

计算logits和target的误差loss

loss=loss_fn(logits,targets)

反向传播backward

loss.backward()

送到优化器里去更新网络权重

optimizer.step()

以上是pytorch对应步骤的代码

MindSpore2.x的实现

相对来说MindSpore没有那么直观,但逻辑上都是一样的。

首先,MindSpore这边正向计算需要定义一个函数方法,里面写了loss的计算

def forword_fn(inputs,targets):
    logits=net(inputs)
    loss=loss_fn(logits,targets)
    return loss

然后把整个方法作为参数传入value_and_grad方法,做一个函数闭包,然后得到一个同样的grad_fn方法(第三个参数往往是net的所有可训练参数,看上图mindspore这边第三行代码也能看出来)。

grad_fn=value_and_grad(forward_fn,None,optim.parameters)

训练的每个step我们也需要定义一个方法,里面把刚才得到的grad_fn方法拿过来,一次得到损失loss和梯度grads,然后直接把梯度传进优化器进行更新权重。

def train_step(inputs,targets):
    loss,grads=grad_fn(inputs,targets)
    optimizer(grads)
    return loss

在实际epoch循环中,我们只需要读入data和其targets,然后直接传入单步训练的train_step就可以了

其实要说函数式微分的话,封装的第一个forward_fn已经是函数式微分了,train_step反而不像函数式微分,就是一个单纯的计算,没有涉及到函数闭包,函数套函数这样的写法,那为什么还要封装呢。

这里就涉及到MindSpore想实现加速的问题,实现整图下发,避免来回传输数据受到的带宽限制,具体实现只需要给train_step上方添加一行修饰器代码

@ms.jit

官方的教学示例notebook

接下来我们将定义几个变量x,y,w,b,z

x是输入,y是真实targets,w是权重,b是偏置,z是我们计算得到的label

w和b是我们要优化的东西

image-20241130211124203

x置全1矩阵,y给全0矩阵,w和b随机给个初始值

image-20241130212009351

然后可以构建一个loss计算的fuction

image-20241130212036985

然后就是使用函数式微分,函数套函数,计算梯度

image-20241130213146336

image-20241130213307493

简单来说,就是你如果计算loss的哪个fuction返回多个参数的话,传入计算梯度的方法,计算结果就会出现偏差,这时候我们就要调用接口去实现stop_gradient(也不算手动实现吧。就是输出的时候嵌套一下)

image-20241130213546914

我们在loss里面返回的z,现在看起来也没用,即返回了,又要排除他的影响,那为什么还要返回呢,数据又拿不出来。其实不对,数据是可以拿出来的。

调用grad_fn的时候我们就可以拿到这个数据,相当于z只是到grad_fn中转了一圈,没干什么。

image-20241130214011239

接下来主要就是一个没有封装train_step的操作方法,主要逻辑上文也说过的,这里也不再赘述

候我们就可以拿到这个数据,相当于z只是到grad_fn中转了一圈,没干什么。

[外链图片转存中…(img-3PZXndA7-1732978830210)]

接下来主要就是一个没有封装train_step的操作方法,主要逻辑上文也说过的,这里也不再赘述

image-20241130214123571

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2251044.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

菱形打印(Python)

“以块组合块”,以行凝结循环打印。 (笔记模板由python脚本于2024年11月30日 19:55:22创建,本篇笔记适合正在学习python循环的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖免费“圣经”…

【QT入门到晋级】QT项目打生产环境包--(Linux和window)

前言 使用QTcreator完成正常编译后,在构建目录中有可执行程序生成,如果直接把可执行程序拷贝到干净的生产环境上是无法运行成功的,使用ldd(查看程序依赖包)会发现缺失很多QT的特性包,以及将介绍国产Linux桌…

Super Vlan与Mux Vlan

SuperVlan VLAN Aggregation, 也称 Super-VLAN : 指 在一个物理网络内,用多个 VLAN (称为 Sub-VLAN )隔离 广播域,并将这些 Sub-VLAN 聚合成一个逻辑的 VLAN (称为 Super-VLAN ),这…

蓝牙定位的MATLAB程序,四个锚点、三维空间

这段代码通过RSSI信号强度实现了在三维空间中的蓝牙定位,展示了如何使用锚点位置和测量的信号强度来估计未知点的位置。代码涉及信号衰减模型、距离计算和最小二乘法估计等基本概念,并通过三维可视化展示了真实位置与估计位置的关系。 目录 程序描述 运…

Hutool 秒速实现 2FA 两步验证

前言 随着网络安全威胁的日益复杂,传统的用户名和密码认证方式已不足以提供足够的安全保障。为了增强用户账户的安全性,越来越多的应用和服务开始采用多因素认证(MFA)。基于时间的一次性密码(TOTP, Time-based One-Ti…

【继承】—— 我与C++的不解之缘(十九)

前言: 面向对象编程语言的三大特性:封装、继承和多态 本篇博客来学习C中的继承,加油! 一、什么是继承? ​ 继承(inheritance)机制是⾯向对象程序设计使代码可以复⽤的最重要的⼿段,它允许我们在保持原有类…

【目标跟踪】Anti-UAV数据集详细介绍

Anti-UAV数据集是在2021年公开的专用于无人机跟踪的数据集,该数据集采用RGB-T图像对的形式来克服单个类型视频的缺点,包含了318个视频对,并提出了相应的评估标准(the state accurancy, SA)。 文章链接:https://arxiv.…

偏差-方差权衡(Bias–Variance Tradeoff):理解监督学习中的核心问题

偏差-方差权衡(Bias–Variance Tradeoff):理解监督学习中的核心问题 在机器学习中,我们希望构建一个能够在训练数据上表现良好,同时对未见数据也具有强大泛化能力的模型。然而,模型的误差(尤其…

Figma入门-原型交互

Figma入门-原型交互 前言 在之前的工作中,大家的原型图都是使用 Axure 制作的,印象中 Figma 一直是个专业设计软件。 最近,很多产品朋友告诉我,很多原型图都开始用Figma制作了,并且很多组件都是内置的,对…

Windows系统怎么把日历添加在桌面上用来记事?

在众多电脑操作系统中,Windows系统以其广泛的用户基础和强大的功能,成为许多人的首选。对于习惯于在电脑前工作和学习的用户来说,能够直接在桌面上查看和记录日历事项,无疑会大大提高工作效率和生活便利性。今天,就为大…

蓝桥杯备赛笔记(一)

这里的笔记是关于蓝桥杯关键知识点的记录,有别于基础语法,很多内容只要求会用就行,无需深入掌握。 文章目录 前言一、编程基础1.1 C基础格式和版本选择1.2 输入输出cin和cout: 1.3 string以下是字符串的一些简介:字符串…

大数据新视界 -- 大数据大厂之 Hive 数据压缩:优化存储与传输的关键(上)(19/ 30)

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

RNN And CNN通识

CNN And RNN RNN And CNN通识一、卷积神经网络(Convolutional Neural Networks,CNN)1. 诞生背景2. 核心思想和原理(1)基本结构:(2)核心公式:(3)关…

求整数的和与均值

求整数的和与均值 C语言代码C 代码Java代码Python代码 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 读入n&#xff08;1 < n < 10000&#xff09;个整数&#xff0c;求它们的和与均值。 输入 输入第一行是一个整数n&#xff0c;…

配置idea环境进行scala编程

这里用的jdk是jdk-8u161,scala版本是2.12.0 在d盘新建一个本地仓库用来存放下载的maven包&#xff0c;在里面创建如下两个文件 更改settings文件为下面的样子 点击左下角的设置&#xff0c;更改maven本地仓库的位置&#xff08;默认在c盘用户目录下的.m2文件中&#xff0c;更改…

WSL简介与安装流程(Windows 下的 Linux 子系统)

目录 1.wsl安装 1.1 WSL简介 1.1.1 WSL 的主要功能 1.1.2 WSL 的版本 1.1.3 为什么使用 WSL&#xff1f; 1.1.4 WSL 的工作原理 1.1.5 WSL 的常见使用场景 1.1.6 与虚拟机的区别 1.1.7 适合使用 WSL 的人群 1.2 启用 WSL 1.2.1 打开 PowerShell&#xff08;管理员模…

【Java树】二叉树遍历的简单实现

二叉树的遍历 二叉树的遍历是值按照一定顺序访问二叉树中所有结点的过程&#xff0c;确保每个结点被访问且仅被访问一次。遍历操作是对二叉树的基础操作&#xff0c;用于后续的查找、排序和路径计算等功能。 二叉树的遍历有以下几种常见方式&#xff1a;深度遍历&#xff08;…

STL算法之set相关算法

STL一共提供了四种与set(集合)相关的算法&#xff0c;分别是并集(union)、交集(intersection)、差集(difference)、对称差集(symmetric difference)。 目录 set_union set_itersection set_difference set_symmetric_difference 所谓set&#xff0c;可细分为数学上定义的和…

鸿蒙ArkUI-X已更新适配API13啦

ArkUI-X 5.0.1 Release版配套OpenHarmony 5.0.1 Rlease&#xff0c;API 13&#xff0c;新增适配部分API 13接口支持跨平台&#xff1b;框架能力进一步完善&#xff0c;支持Android应用非压缩模式&#xff0c;支持Android Fragment对接跨平台。ACE Tools工具易用性提升&#xff…

rest-assured multiPart上传中文名称文件,文件名乱码

rest-assured是一个基于java语言的REST API测试框架&#xff0c;在使用rest-assured的multipart 上传文件后&#xff0c;后端获取的文件名称乱码。截图如下&#xff1a; 原因是rest-assured multipart/form-data默认的编码格式是US-ASCII&#xff0c;需要设置为UTF-8。 Befo…