昇思25天学习打卡营第07天 | 函数式自动微分

news2024/10/6 18:28:55

昇思25天学习打卡营第07天 | 函数式自动微分

文章目录

  • 昇思25天学习打卡营第07天 | 函数式自动微分
    • 函数与计算图
      • 微分函数与梯度
      • Stop Gradient
      • Auxiliary data
    • 神经网络梯度计算
    • 总结
    • 打卡

神经网络的训练主要使用反向传播算法,首先计算模型预测值(logits)与正确标签(label)之间的loss,然后进行反向传播,通过梯度来更新模型参数从而完成网路的训练。

MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口gradvalue_and_grad

函数与计算图

计算图是图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。
compute-graph
在这个模型中, x x x为输入, z z z为输出, y y y为正确值, w w w b b b是需要优化的参数。

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

def function(x, y, w, b):
	z = ops.matmul(x, w) + b
	loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

通过执行function获得loss值:

loss = function(x,y,w,b)

微分函数与梯度

为了优化参数 w w w b b b,需要求参数对loss的导数 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss ∂ l o s s ∂ b \frac{\partial loss}{\partial b} bloss

可以通过mindspore.grad函数来获得function的微分函数:

grad_fn = mindspore.grad(function, (2, 3))

grads = grad_fn(x, y, w, b)

此处使用了grad的两个入参:

  • fn:待求导的函数;
  • grad_position:指定求导输入位置的索引。

Stop Gradient

通常情况下,求导时会求loss对参数的导数,因此函数只输出loss一项。
如果函数输出多项时,微分函数会求所有输出对参数的导数。

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z
    
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

此处function_with_logits输出的z会影响梯度。

如果想要实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)
    
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

Auxiliary data

Auxiliary data为辅助数据,是函数除第一个输出项外的其他输出。通常loss值为函数的第一个输出,而其它输出即为辅助数据。

gradvalue_and_grad提供has_aux参数,设置为True时,可以自动实现前文中stop_gradient的功能。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

神经网络梯度计算

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z

# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

实例化网络和损失函数后,将其封装为一个前向计算函数,用于自动微分:

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

由于使用nn.Cell封装网络模型,其参数为Cell的内部属性,因此不需要指定grad_position参数,直接设置为None

对模型参数求导时,使用weights参数,指定为通过model.trainable_params()方法从Cell中取出的可以求导的参数:

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())

loss, grads = grad_fn(x, y)
print(grads)

总结

这一节从一个简单的线性函数 w x + b wx+b wx+b出发,介绍了网络模型中数学函数的统一表示方法(即计算图),与loss的计算过程。使用gradvalue_and_grad方法可以通过自动微分获取目标函数的微分函数,从而得到参数对loss的梯度,进而优化参数。对于需要输出辅助数据的函数来说,可以通过ops.stop_gradient进行梯度截断,或设置has_aux=True来自动完成。
在通过Cell封装的网络模型中,需要将模型和loss的调用封装为一个前向计算函数,从而进行自动微分。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1898579.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Prompt-Free Diffusion: Taking “Text” out of Text-to-Image Diffusion Models

CVPR2024 SHI Labshttps://arxiv.org/pdf/2305.16223https://github.com/SHI-Labs/Prompt-Free-Diffusion 问题引入 在SD模型的基础之上,去掉text prompt,使用reference image作为生成图片语义的指导,optional structure image作为生成图片…

【Leetcode笔记】406.根据身高重建队列

文章目录 1. 题目要求2.解题思路 注意3.ACM模式代码 1. 题目要求 2.解题思路 首先,按照每个人的身高属性(即people[i][0])来排队,顺序是从大到小降序排列,如果遇到同身高的,按照另一个属性(即p…

关于SAP SAP NetWeaver AS JAVA 授权问题漏洞(CVE-2020-6287)及修复

路径参考 SAP NetWeaver AS Java 严重漏洞 (CVE-2020-6287) 安全通告 - 威胁通告 - 绿盟科技-巨人背后的专家 SAP NOTE ​​​​​​https://me.sap.com/notes/2939665 找到路径 导航到 http(s)://<主机名>:port/nwa -> 配置 -> 基础架构 -> Java HTTP 提供…

Leetcode - 周赛403

目录 一&#xff0c;3200. 三角形的最大高度 二&#xff0c;3195. 包含所有 1 的最小矩形面积 I 三&#xff0c;3196. 最大化子数组的总成本 四&#xff0c;3197. 包含所有 1 的最小矩形面积 II 一&#xff0c;3200. 三角形的最大高度 本题是一道模拟题&#xff0c;可以先排…

从零开始手写STL库:Vector

从零开始手写STL库–Vector部分 文章目录 从零开始手写STL库--Vector部分Vector是什么Vector需要包含什么函数1&#xff09;基础成员函数2&#xff09;核心功能 基础成员函数的编写核心功能函数的编写总结 Vector是什么 std::vector 是一个动态数组&#xff0c;它在内存中以连…

安装Nginx以及简单使用 —— windows系统

一、背景 Nginx是一个很强大的高性能Web和反向代理服务&#xff0c;也是一种轻量级的Web服务器&#xff0c;可以作为独立的服务器部署网站&#xff0c;应用非常广泛&#xff0c;特别是现在前后端分离的情况下。而在开发过程中&#xff0c;我们常常需要在window系统下使用Nginx作…

SwiftUI中List的liststyle样式及使用详解添加、移动、删除、自定义滑动

SwiftUI中的List可是个好东西&#xff0c;它用于显示可滚动列表的视图容器&#xff0c;类似于UITableView。在List中可以显示静态或动态的数据&#xff0c;并支持垂直滚动。List是一个数据驱动的视图&#xff0c;当数据发生变化时&#xff0c;列表会自动更新。针对List&#xf…

关于下载obsidian SimpRead Sync中报错的问题

参考Kenshin的配置方法&#xff0c;我却在输入简悦的配置文件目录时多次报错。 bug如下&#xff1a; 我发现导出来的配置文件格式如下&#xff1a; 然后根据报错的bug对此文件名进行修改&#xff0c;如下&#xff1a; 解决。

【后端面试题】【中间件】【NoSQL】MongoDB查询优化2(优化排序、mongos优化)

优化排序 在MongoDB里面&#xff0c;如果能够利用索引来排序的话&#xff0c;直接按照索引顺序加载数据就可以了。如果不能利用索引来排序的话&#xff0c;就必须在加载了数据之后&#xff0c;再次进行排序&#xff0c;也就是进行内存排序。 可想而知&#xff0c;如果内存排序…

elasticsearch-users和elasticsearch-reset-password介绍

elasticsearch 内置 elastic, kibana, logstash_system,beats_system 共4个用户&#xff0c;用途如下&#xff1a; elastic 账号&#xff1a;内置的超级用户&#xff0c;拥有 superuser 角色。 kibana 账号&#xff1a;用来连接 elasticsearch 并与之通信。Kibana 服务器以该用…

ACL2023 | 如何用175条种子数据打造顶级指令模型?揭秘self-instruct:媲美InstructGPT001的秘密武器

1. 论文的核心问题和核心贡献 核心问题&#xff1a;该论文解决的问题是大规模语言模型在微调响应指令时过于依赖人工编写的指令数据&#xff0c;这些数据往往在数量、种类和创意上都存在局限&#xff0c;阻碍了模型的广泛泛化能力。研究的主要目标是开发一种方法&#xff0c;通…

Java实习手册(小白也看得懂)

秃狼说 距离俺发布的学习路线已经六个月了&#xff0c;那我给小伙伴的学习周期是四五个月左右&#xff0c;我相信大多的小伙伴已经学习的差不多了。正好赶上暑期实习的阶段&#xff0c;在暑期找到实习就成为暑期的头等大事。 实习经验在校招的起到决定性的作用&#xff0c;所…

代码随想录算法训练营第九天|151.翻转字符串里的单词、右旋字符串、28. 实现 strStr()、459.重复的子字符串

打卡Day9 1.151.翻转字符串里的单词2.右旋字符串3.28. 实现 strStr()4.459.重复的子字符串 1.151.翻转字符串里的单词 题目链接&#xff1a;翻转字符串里的单词 文档讲解&#xff1a; 代码随想录 思路&#xff1a;首先&#xff0c;移除多余的空格&#xff1b;然后&#xff0c…

Amesim应用篇-信号传递

前言 在Amesim中常见的信号传递是通过信号线连接&#xff0c;针对简单的模型通过信号线连接还可以是信号线清晰规整&#xff0c;方便查看。如果模型较复杂&#xff0c;传递信号的元件较多时&#xff0c;此时再继续使用信号线进行信号传递&#xff0c;可能会使草图界面看起来杂…

比赛获奖的武林秘籍:02 国奖秘籍-大学生电子计算机类竞赛快速上手的流程,小白必看

比赛获奖的武林秘籍&#xff1a;02 国奖秘籍-大学生电子计算机类竞赛快速上手的流程&#xff0c;小白必看 摘要 本文主要介绍了大学生参加电子计算机类比赛&#xff08;电赛、光电设计大赛、计算机设计大赛、嵌入式芯片与系统设计大赛等比赛&#xff09;的流程和涉及到的知识…

一本超简单能用Python实现办公自动化的神书!让我轻松摆脱办公烦恼!

《超简单&#xff1a;用Python让Excel飞起来》 这本书旨在通过Python与Excel的“强强联手”&#xff0c;为办公人员提供一套高效的数据处理方案。书中还介绍了如何在Excel中调用Python代码&#xff0c;进一步拓宽了办公自动化的应用范围。 全书共9章。第1~3章主要讲解Python编…

软件设计之Java入门视频(11)

软件设计之Java入门视频(11) 视频教程来自B站尚硅谷&#xff1a; 尚硅谷Java入门视频教程&#xff0c;宋红康java基础视频 相关文件资料&#xff08;百度网盘&#xff09; 提取密码&#xff1a;8op3 idea 下载可以关注 软件管家 公众号 学习内容&#xff1a; 该视频共分为1-7…

【C++】 解决 C++ 语言报错:Memory Leak

文章目录 引言 内存泄漏&#xff08;Memory Leak&#xff09;是 C 编程中常见且严重的内存管理问题之一。当程序分配了内存而没有正确释放&#xff0c;导致内存无法被重新利用时&#xff0c;就会发生内存泄漏。这种错误会导致程序占用越来越多的内存&#xff0c;最终可能导致系…

Using a text embedding model locally with semantic kernel

题意&#xff1a;在本地使用带有语义核&#xff08;Semantic Kernel&#xff09;的文本嵌入模型 问题背景&#xff1a; Ive been reading Stephen Toubs blog post about building a simple console-based .NET chat application from the ground up with semantic-kernel. Im…

C++基础21 二维数组及相关问题详解

这是《C算法宝典》C基础篇的第21节文章啦~ 如果你之前没有太多C基础&#xff0c;请点击&#x1f449;C基础&#xff0c;如果你C语法基础已经炉火纯青&#xff0c;则可以进阶算法&#x1f449;专栏&#xff1a;算法知识和数据结构&#x1f449;专栏&#xff1a;数据结构啦 ​ 目…