昇思25天学习打卡营第07天|函数式自动微分

news2024/10/6 7:31:09

神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。

compute-graph

在这个模型中,𝑥𝑥为输入,𝑦𝑦为正确值,𝑤𝑤和𝑏𝑏是我们需要优化的参数。

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter
import time

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

# binary_cross_entropy_with_logits 是一个损失函数
def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

loss = function(x, y, w, b)
print(loss)

grad_fn = mindspore.grad(function, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

# 将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。
def function_with_logits(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, z

grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)



def function_stop_gradient(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss, ops.stop_gradient(z)

grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)

# Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。
# grad和value_and_grad提供has_aux参数,当其设置为True时,可以自动实现前文手动添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)

grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z
    
# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()

# Define forward function
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())

loss, grads = grad_fn(x, y)
print(grads)


print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'skywp')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1905271.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ubuntu 20.04下多版本CUDA的安装与切换 超详细教程

目录 前言一、安装 CUDA1.找到所需版本对应命令2.下载 .run 文件3.安装 CUDA4.配置环境变量4.1 写入环境变量4.2 软连接 5.验证安装 二、安装 cudnn1.下载 cudnn2.解压文件3.替换文件4.验证安装 三、切换 CUDA 版本1.切换版本2.检查版本 前言 当我们复现代码时,总会…

孟德尔随机化与痛风3

写在前面 检索检索,刚好发现一篇分区还挺高,但结果内容看上去还挺熟悉的文章,特记录一下。 文章 Exploring the mechanism underlying hyperuricemia using comprehensive research on multi-omics Sci Rep IF:3.8中科院分区:2区 综合性期…

从一个(模型设计的)想法到完成模型验证的步骤

从有一个大型语言模型(LLM)设计的想法到完成该想法的验证,可以遵循以下实践步骤: 需求分析: 明确模型的目的和应用场景。确定所需的语言类型、模型大小和性能要求。分析目标用户群体和使用环境。 文献调研&#xff1a…

前端html面试常见问题

前端html面试常见问题 1. !DOCTYPE (文档类型)的作用2. meta标签3. 对 HTML 语义化 的理解?语义元素有哪些?语义化的优点4. HTML中 title 、alt 属性的区别5. src、href 、url 之间的区别6. script标签中的 async、defer 的区别7. 行内元素、块级元素、空…

运维系列.Nginx:自定义错误页面

运维系列 Nginx:自定义错误页面 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/…

mac|Mysql WorkBench 或终端 导入 .sql文件

选择Open SQL Script导入文件 在第一行加入use 你的schema名字,相当于选择了这个schema 点击运行即可将sql文件导入database 看到下面成功了即可 这时候可以看看左侧的目标database中有没有成功导入table,如果没有看到的话,可以点一下右上角的…

【算法笔记自学】第 8 章 提高篇(2)——搜索专题

8.1深度优先搜索&#xff08;DFS&#xff09; #include <cstdio>const int MAXN 5; int n, m, maze[MAXN][MAXN]; bool visited[MAXN][MAXN] {false}; int counter 0;const int MAXD 4; int dx[MAXD] {0, 0, 1, -1}; int dy[MAXD] {1, -1, 0, 0};bool isValid(int …

执行力不足是因为选择模糊

选择模糊&#xff1a;执行力不足的根源 选择模糊是指在面对多个选项时&#xff0c;缺乏明确的目标和方向。这种模糊感会导致犹豫不决&#xff0c;进而影响我们的执行力。 选择模糊的表现&#xff1a; 目标不明确&#xff0c;不知道应该做什么。优先级混乱&#xff0c;不清楚…

【持续集成_03课_Jenkins生成Allure报告及Sonar静态扫描】

1、 一、构建之后的配置 1、安装allure插件 安装好之后&#xff0c;可以在这里搜到已经安装的 2、配置allure的allure-commandline 正常配置&#xff0c;是要么在工具里配置&#xff0c;要么在系统里配置 allure-commandline是在工具里进行配置 两种方式进行配置 1&#xff…

人工智能、机器学习、神经网络、深度学习和卷积神经网络的概念和关系

人工智能&#xff08;Artificial Intelligence&#xff0c;缩写为AI&#xff09;--又称为机器智能&#xff0c;是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学。 人工智能是智能学科重要的组成部分&#xff0c;它企图了解智能的实质…

[护网训练]原创应急响应靶机整理集合

前言 目前已经出了很多应急响应靶机了&#xff0c;有意愿的时间&#xff0c;或者正在准备国护的师傅&#xff0c;可以尝试着做一做已知的应急响应靶机。 关于后期&#xff1a; 后期的应急响应会偏向拓扑化&#xff0c;不再是单单一台机器&#xff0c;也会慢慢完善整体制度。…

基于Java的企业客户信息反馈平台

你好呀&#xff0c;我是计算机学姐码农小野&#xff01;如果有相关需求&#xff0c;可以私信联系我。 开发语言&#xff1a; Java 数据库&#xff1a; MySQL 技术&#xff1a; Java MySQL B/S架构 SpringBoot框架 工具&#xff1a; Eclipse、MySQL环境配置工具、浏览…

【每日一练】python算数练习题(函数.随机.判断综合运用)

""" 幼儿园加减法练习题 答对点赞表情&#xff0c;答错炸弹表情 表情随机出现 如果全答对有大奖 """ import random df0 #定义答对函数 def dd():global dfdf10bq["&#x1f339;&#x1f339;&#x1f339;","&#x1f389;&…

试用笔记之-汇通窗口颜色显示软件(颜色值可供Delphi编程用)

首先下载汇通窗口颜色显示软件 http://www.htsoft.com.cn/download/wdspy.rar 通过获得句柄颜色&#xff0c;显示Delphi颜色值和HTML颜色值

【74LS163做24进制计数器】2021-11-19

缘由用74LS163做24进制计数器-其他-CSDN问答,仿真multisim两个74LS163芯片如何构成47进制计数器-吐槽问答-CSDN问答 参考74ls163中文资料汇总&#xff08;74ls163引脚图及功能_内部结构图及应用电路&#xff09; - 电子发烧友网

weblogic加入第三方数据库代理驱动jar包(Oracle为例)

做的是国企项目&#xff0c;项目本身业务并不复杂&#xff0c;最复杂的却是服务器部署问题&#xff0c;对方给提供的服务器分内网、外网交换网&#xff0c;应用在交换网&#xff0c;数据库在内网&#xff0c;应用不能直接访问内网数据库&#xff0c;只能通过安全隔离网闸访问内…

electron 初始使用

electron electron文档地址deno下载地址安装命令 yarn config set electron_mirror https://cdn.npm.taobao.org/dist/electron/ npm install下载文件 文件下载完成后&#xff0c;新建dist目录&#xff0c;解压到list目录下&#xff1b;path文件中写入electron.exe 运行命令 …

【三级等保】等保整体建设方案(Word原件)

建设要点目录&#xff1a; 1、系统定级与安全域 2、实施方案设计 3、安全防护体系建设规划 软件全文档&#xff0c;全方案获取方式&#xff1a;本文末个人名片直接获取。

Python28-7.5 降维算法之t-分布邻域嵌入t-SNE

t-分布邻域嵌入&#xff08;t-distributed Stochastic Neighbor Embedding&#xff0c;t-SNE&#xff09;是一种用于数据降维和可视化的机器学习算法&#xff0c;尤其适用于高维数据的降维。t-SNE通过将高维数据嵌入到低维空间&#xff08;通常是二维或三维&#xff09;中&…

尚品汇-(十四)

&#xff08;1&#xff09;提交git 商品后台管理到此已经完成&#xff0c;我们可以把项目提交到公共的环境&#xff0c;原来使用svn&#xff0c;现在使用git 首先在本地创建ssh key&#xff1b; 命令&#xff1a;ssh-keygen -t rsa -C "your_emailyouremail.com" I…