《昇思25天学习打卡营第7天|函数式自动微分》

news2024/11/19 19:43:08

文章目录

  • 今日所学:
  • 一、函数与计算图
  • 二、微分函数与梯度计算
  • 三、Stop Gradient
  • 四、Auxiliary data
  • 五、神经网络梯度计算
  • 总结


今日所学:

今天我学习了神经网络训练的核心原理,主要是反向传播算法。这个过程包括将模型预测值(logits)和正确标签(label)输入到损失函数(loss function)中计算loss,然后通过反向传播算法计算梯度(gradients),最终更新模型参数(parameters)。自动微分技术能够在某点计算可导函数的导数值,是反向传播算法的一个广义实现。它的主要作用是将复杂的数学运算分解为一系列简单的基本运算,从而屏蔽了大量求导的细节和过程,显著降低了使用深度学习框架的门槛。

MindSpore采用函数式自动微分的设计理念,提供了更接近数学语义的自动微分接口,例如grad和value_and_grad。为了更好地理解这些概念,我还学习了如何使用一个简单的单层线性变换模型进行实践。


一、函数与计算图

MindSpore之前的还不熟悉的相关内容可以见:《昇思25天学习打卡营第1天|基本介绍》

计算图是一种借助图论来描绘数学函数的一种方法,同时也是深度学习框架用以表达神经网络模型的通用方式。以下,我们将以此计算图为基础,来构建计算函数和神经网络:
在这里插入图片描述
在本节所学的这个模型中,𝑥为输入,𝑦为正确值,𝑤和𝑏是我们需要优化的参数,根据计算图描述的计算过程,构造计算函数,执行计算函数,可以获得计算的loss值,代码与结果如下所示:

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

loss = function(x, y, w, b)
print(loss)

结果如下:

Tensor(shape=[], dtype=Float32, value= 0.914285)

二、微分函数与梯度计算

在之后学习内容中为了优化模型参数,需要求参数对loss的导数:

∂loss∂𝑤

∂loss∂𝑏

此时我们调用mindspore.grad函数,来获得function的微分函数。其中grad函数的两个入参,分别为fn(待求导的函数)与grad_position(指定求导输入位置的索引),代码如下:

grad_fn = mindspore.grad(function, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

结果如下:

在这里插入图片描述

使用grad获得微分函数是一种函数变换,即输入为函数,输出也为函数。

三、Stop Gradient

在常规的情况下,求导操作主要是计算loss相对于参数的导数,由此,函数的输出仅有loss一项。然而,当我们期望函数有多项输出时,微分函数将会计算所有输出项相对于参数的导数。在这种情况下,如果我们希望实现特定输出项的梯度截断,或者需要消除某个Tensor对梯度的影响,那么我们将需要使用Stop Gradient操作。在这里,我们会将function改造成同时输出loss和z的function_with_logits,并获取微分函数以供执行。

如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient接口,将梯度在此处截断。

代码如下:

def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z

grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)

grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

截断前结果:
在这里插入图片描述

截断后结果:

在这里插入图片描述

四、Auxiliary data

我深入理解了Auxiliary data(辅助数据)的概念和应用。我明白了辅助数据其实就是函数的非主要输出项。在实际应用中,我们常将函数的主要输出设为loss,而其它的所有输出则被视为辅助数据。对于grad和value_and_grad函数,我享受到了has_aux参数带来的便利。当将其设为True,它就能自动实现之前需要手动添加的stop_gradient操作。这种设计巧妙地使我在返回辅助数据的同时,不受梯度计算的任何影响。

在后续的实践中,我继续使用了function_with_logits,并设置了has_aux为True进行操作。整个过程顺畅无比,加深了我对这一主题的理解。我会持续探索,并将这些知识应用到更广泛的场景中去:

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)

grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

结果如下:

在这里插入图片描述

五、神经网络梯度计算

前面章节已经讲述了网络构建,还不了解的可见这篇文章:《昇思25天学习打卡营第6天|网络构建》

接下来,我深入了解了如何通过Cell去构造神经网络,以及利用函数式自动微分来实现反向传播的过程。我首先继承了nn.Cell来构建单层线性变换神经网络。有意思的是,这个过程中我直接使用了之前的 𝑤 和 𝑏 来作为模型参数。这种做法完全打破了我早前的理解,让我认识到原来我们可以直接使用现有的参数以节约时间和计算资源。我将这些参数用mindspore.Parameter包装起来作为内部属性,并在construct内实现了与之前一样的Tensor操作:

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z
        
# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss
    
grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())

loss, grads = grad_fn(x, y)
print(grads)

结果如下:
在这里插入图片描述

可以看出,执行微分函数后的梯度值和前文function求得的梯度值一致。

在这里插入图片描述

总结

在今天的学习中,我深入理解了神经网络训练的核心原理,包括反向传播算法和如何利用自动微分技术来计算梯度并更新模型参数。我也学习了如何使用MindSpore框架的函数式自动微分接口来进行实践,并利用计算图进行模型参数优化。此外,我理解了Stop Gradient操作和辅助数据对梯度计算的影响,以及如何在神经网络的梯度计算中有效利用它们。通过理论学习和实践操作,我对这些概念有了更深入的理解,期待在明天的学习中继续进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1890423.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CX1概念车空气动力学设计,打造典型“美式肌肉车”风格

Altair CX1概念车的设计 CX1是Altair公司为了满足汽车行业的设计和虚拟仿真需求而开发的一款概念车型。该车总长接近5米,外观具有典型的美式肌肉车的风格,具有视觉冲击力。 车辆的外形设计过程就是风险管理,设计师想要一个大胆而富有表现力的…

Continual Test-Time Domain Adaptation--论文笔记

论文笔记 资料 1.代码地址 https://github.com/qinenergy/cotta 2.论文地址 https://arxiv.org/abs/2203.13591 3.数据集地址 论文摘要的翻译 TTA的目的是在不使用任何源数据的情况下,将源预先训练的模型适应到目标域。现有的工作主要考虑目标域是静态的情况…

拉曼光谱入门:1.光谱的分类与散射光谱发展史

一、光谱是什么? 在一个宁静的午后,年轻的艾萨克牛顿坐在他母亲花园里的一棵苹果树下,手握一块精致的三棱镜。他沉思着光的奥秘,意识到光并非单一的白色,而是一种由多彩色组成的复杂结构。 他决心进行一次实验&#xf…

静态时序分析:ideal_clock、propagated_clock以及generated_clock的关系及其延迟计算规则(二)

相关阅读 静态时序分析https://blog.csdn.net/weixin_45791458/category_12567571.html?spm1001.2014.3001.5482 生成时钟 上一节中,我们讨论了理想时钟和传播时钟的创建和使用,本节将讨论生成时钟及其与理想时钟和传播时钟的关系。 图1所示的是一个简…

Java环境变量的设置

JAVA环境变量的设置 1.设置环境变量的作用2.如何设置环境变量2.1 找到系统的环境变量2.2 设置环境变量 1.设置环境变量的作用 说明:在Java中设置环境变量主要是为了能够让Java运行时能够找到Java开发工具包(JDK)的安装位置以及相关的库文件。…

JavaSE阶段面试题(一)

目录 1.int a 1, int b 1, Integer c 1, Integer d 1;四个区别和联系,以及c和d是同一个吗? 2.为什么重写HashCode必须重写euqals,两者之间的关系? 3.创建对象的方式有哪些 4.重写和重载的区别 5.抽象类和接口…

Webpack: Dependency Graph 管理模块间依赖

概述 Dependency Graph 概念来自官网 Dependency Graph | webpack 一文,原文解释: Any time one file depends on another, webpack treats this as a dependency. This allows webpack to take non-code assets, such as images or web fonts, and als…

算法day1 两数之和 两数相加 冒泡排序 快速排序

两数之和 最简单的思维方式肯定是去凑两个数,两个数的和是目标值就ok。这里两遍for循环解决。 两数相加 敲了一晚上哈哈,结果超过int范围捏,难受捏。 public class Test2 {public static void main(String[] args) { // ListNode l1 …

像学Excel 一样学 Pandas系列-创建数据分析维度

嗨,小伙伴们。又到喜闻乐见的Python 数据分析王牌库 Pandas 的学习时间。按照数据分析处理过程,这次轮到了新增维度的部分了。 老样子,我们先来回忆一下,一个完整数据分析的过程,包含哪些部分内容。 其中&#xff0c…

四十篇:内存巨擘对决:Redis与Memcached的深度剖析与多维对比

内存巨擘对决:Redis与Memcached的深度剖析与多维对比 1. 引言 在现代的系统架构中,内存数据库已经成为了信息处理的核心技术之一。这类数据库系统的高效性主要来源于其对数据的即时访问能力,这是因为数据直接存储在RAM中,而非传统…

二叉树的前中后序遍历(递归法、迭代法)leetcode144、94/145

leetcode144、二叉树的前序遍历 给你二叉树的根节点 root ,返回它节点值的 前序 遍历。 示例 1: 输入:root [1,null,2,3] 输出:[1,2,3] 示例 2: 输入:root [] 输出:[] 示例 3: 输…

前端入门超级攻略:你的第一步学习指南

如果您觉得这篇文章有帮助的话!给个点赞和评论支持下吧,感谢~ 作者:前端小王hs 阿里云社区博客专家/清华大学出版社签约作者/csdn百万访问前端博主/B站千粉前端up主/知名前端开发者/网络工程师 前言 由于前端技术的快速迭代性,国…

解决ps暂存盘已满的问题

点击编辑->首选项->暂存盘 ps默认暂存盘使用的是c盘,我们改成d盘即可 然后重启ps

STM32之五:TIM定时器(2-通用定时器)

目录 通用定时器(TIM2~5)框图 1、 输入时钟源选择 2、 时基单元 3 、输入捕获:(IC—Input Capture) 3.1 输入捕获通道框图(TI1为例) 3.1.1 滤波器: 3.1.2 边沿检测器&#xf…

移动智能终端数据安全管理方案

随着信息技术的飞速发展,移动设备已成为企业日常运营不可或缺的工具。特别是随着智能手机和平板电脑等移动设备的普及,这些设备存储了大量的个人和敏感数据,如银行信息、电子邮件等。员工通过智能手机和平板电脑访问企业资源,提高…

【等保2.0是什么意思?等保2.0的基本要求有哪些? 】

一、等保2.0是什么意思? 等保2.0又称“网络安全等级保护2.0”体系,它是国家的一项基本国策和基本制度。在1.0版本的基础上,等级保护标准以主动防御为重点,由被动防守转向安全可信,动态感知,以及事前、事中…

SSM玉林师范学院宿舍管理系统-计算机毕业设计源码19633

摘要 随着大学生人数的增加,宿舍管理成为高校管理中的重要问题。本论文旨在研究玉林师范学院宿舍管理系统,探讨其优势和不足,并提出改进建议。通过对相关文献的综述和实地调研,我们发现该系统在宿舍分配、卫生评分、失物招领、设施…

什么是 URL ?

统一资源定位符(URL)是一个字符串,它指定了一个资源在互联网上的位置以及如何访问它。URL 是由几部分组成的,每部分都有其特定的作用: 协议/方案:这是 URL 的开头部分,表明了用于访问资源的协议…

基于uniapp(vue3)H5附件上传组件,可限制文件大小

代码&#xff1a; <template><view class"upload-file"><text>最多上传5份附件&#xff0c;需小于50M</text><view class"" click"selectFile">上传</view></view><view class"list" v…

WPF自定义模板--TreeView 实现菜单连接线

有些小伙伴说&#xff0c;在TreeView中&#xff0c;怎么每一个都加上连接线&#xff0c;进行显示连接。 代码和效果如下&#xff1a; 其实就是在原来的模板中增加一列显示线条&#xff0c;然后绘制即可 <Window x:Class"XH.TemplateLesson.TreeViewWindow"xmln…