[2022-12-17]神经网络与深度学习 hw9 - bptt

news2024/9/29 15:23:27

contents

  • hw9 - Back Propagation Through Time
    • task1
      • 题目内容
      • 题目思路+题目解答
      • 题目总结
    • task2
      • 题目内容
      • 题目思路+题目解答
      • 题目总结

hw9 - Back Propagation Through Time

task1

题目内容

推导RNN反向传播算法BPTT。

题目思路+题目解答

首先我们要清楚RNN进行前向传播的过程:

  1. 由输入层→state层: 输入层部分除了原始的输入资料外会再加上t-1时间的state状态,一同向前传递到t时间的state。
  2. state层→输出层: 这边向前传递就什么特殊的部分,跟一般MLP差不多。


    由此得到如下推导
    在这里插入图片描述

题目总结

本题考查的是对于循环神经网络的随时间反向传播过程推导,要了解计算的过程。

task2

题目内容

设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试。

题目思路+题目解答

本题承接上一题的推导过程,形成代码即可。
我们这边还是继承自自造轮子的基类LayerBase,代码如下:

class RNN(LayerBase):
    activation_functions = {
        'tanh':tanh
    }
    def __init__(self, n_hidden_states, activation='tanh', n_bptt_steps=5, input_shape=None):
        self.input_shape = input_shape
        self.n_hidden_states = n_hidden_states
        self.activation = RNN.activation_functions[activation]()
        self.trainable = True
        self.n_bptt_steps = n_bptt_steps
        self.W = None # 前一状态权重
        self.V = None # 输出权重
        self.U = None # 输入权重

    def init(self, optimizer):
        timesteps, input_dim = self.input_shape
        limit = 1 / math.sqrt(input_dim)
        self.U  = np.random.uniform(-limit, limit, (self.n_hidden_states, input_dim))
        limit = 1 / math.sqrt(self.n_hidden_states)
        self.V = np.random.uniform(-limit, limit, (input_dim, self.n_hidden_states))
        self.W  = np.random.uniform(-limit, limit, (self.n_hidden_states, self.n_hidden_states))

        self.U_opt  = copy.copy(optimizer)
        self.V_opt = copy.copy(optimizer)
        self.W_opt = copy.copy(optimizer)

    def forward(self, X, training=True):
        self.layer_input = X
        batch_size, timesteps, input_dim = X.shape

        # 保存用于在反向传播中使用的值
        self.state_input = np.zeros((batch_size, timesteps, self.n_hidden_states))
        self.states = np.zeros((batch_size, timesteps+1, self.n_hidden_states))
        self.outputs = np.zeros((batch_size, timesteps, input_dim))

        # 将最后一个时间步设置为零以计算时间步为零的 state_input
        self.states[:, -1] = np.zeros((batch_size, self.n_hidden_states))
        for t in range(timesteps):
            # state_t 的输入是当前输入和前时刻的输出
            self.state_input[:, t] = X[:, t].dot(self.U.T) + self.states[:, t-1].dot(self.W.T)
            self.states[:, t] = self.activation(self.state_input[:, t])
            self.outputs[:, t] = self.states[:, t].dot(self.V.T)

        return self.outputs

    def backward(self, accum_grad):
        _, timesteps, _ = accum_grad.shape

        # 我们保存每个参数的累积梯度的变量
        grad_U = np.zeros_like(self.U)
        grad_V = np.zeros_like(self.V)
        grad_W = np.zeros_like(self.W)
        # 层输入的梯度将会传递到网络中的上一层
        accum_grad_next = np.zeros_like(accum_grad)

        # BPTT
        for t in reversed(range(timesteps)):
            # 在时间步 t 更新梯度 V
            grad_V += accum_grad[:, t].T.dot(self.states[:, t])
            # 基于状态输入计算梯度
            grad_wrt_state = accum_grad[:, t].dot(self.V) * self.activation.gradient(self.state_input[:, t])
            # 层输入梯度
            accum_grad_next[:, t] = grad_wrt_state.dot(self.U)
            # 通过反向传播更新 W 和 U 的梯度,至多至时间 t
            for t_ in reversed(np.arange(max(0, t - self.n_bptt_steps), t+1)):
                grad_U += grad_wrt_state.T.dot(self.layer_input[:, t_])
                grad_W += grad_wrt_state.T.dot(self.states[:, t_-1])
                # 根据前面状态计算梯度
                grad_wrt_state = grad_wrt_state.dot(self.W) * self.activation.gradient(self.state_input[:, t_-1])

        # 更新权重
        self.U = self.U_opt.update(self.U, grad_U)
        self.V = self.V_opt.update(self.V, grad_V)
        self.W = self.W_opt.update(self.W, grad_W)

        return accum_grad_next

将随机数据传入,将自定义算子和torch算子的权重进行同步并运行,进行比较,得到:
在这里插入图片描述

题目总结

本题考察的是依据之前的推导过程将理论变为代码的能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/96573.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0. Canal 的安装和使用

我看过一场风景,后来我才知道,那是我人生中最美的一段时光。 我爱的人,爱我的人,都能度过这场新型感冒,那该多好。 Canal 的官网: https://github.com/alibaba/canal Canal 能干什么 为什么出现 Canal Canal 是阿里…

[ 数据结构 -- 手撕排序算法第二篇 ] 冒泡排序

文章目录前言一、常见的排序算法二、冒泡排序的实现2.1 基本思想2.2 单趟冒泡排序2.2.1 思路分析2.2.2 单趟代码实现三、冒泡排序的实现五、冒泡排序的时间复杂度5.1 最坏情况5.2 最好情况优化六、冒泡排序的特性总结总结前言 手撕排序算法第一篇:插入排序&#xf…

截止12.17 bitahub踩坑,mask无数次更改,lama代码的那些痛,羊了个羊,imwrite不生效

前面那篇跑出了STCN,倒是STCN熟悉了很多了 对bitahub,需要注意一个问题 要进ssh请用debug卡!!!! 要进ssh请用debug卡!!!! 要进ssh请用debug卡!&…

AQS-semaphoreCyclicBarrierCountDownLatch源码学习

上文:jdk-BlockingQueue源码学习源码下载:https://gitee.com/hong99/jdk8semaphore&cyclicbarrier&CountDownLatch的介绍semaphore基础功能semaphore简称信号量,主要用于控制访问特定资源的线程数目,底层用的是AQS的状记s…

[附源码]Python计算机毕业设计Django万佳商城管理系统

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

Volatile和高速缓存的关系

“volatile关键字有什么用?” 1 常见理解错误 把volatile当成一种锁机制,认为给变量加上了volatile,就好像是给函数加sychronized,不同的线程对于特定变量的访问会去加锁把volatile当成一种原子化的操作机制,认为加了…

Dubbo 3 Dubbo 快速入门 3.1 Zookeeper 安装

Dubbo 【黑马程序员Dubbo快速入门,Java分布式框架dubbo教程】 3 Dubbo 快速入门 文章目录Dubbo3 Dubbo 快速入门3.1 Zookeeper 安装3.1.1 Zookeeper 安装3.1 Zookeeper 安装 3.1.1 Zookeeper 安装 在Dubbo 架构图中 Dubbo官方推荐使用Zookeeper作为注册中心【Re…

【学习总结】注解和元注解

目录 一、注解 1、注解与XML区别 2、注解的用途 3、注解的三种分类 二、什么是元注解? 1、元注解有几种? 1、Retention存活时间 2、Target使用范围 3、Document保存到javadoc 4、Inherited注解继承 三、如何实现的注解 四、问提: …

为解决BERT模型对语料中低频词的不敏感性

来源:投稿 作者:COLDR 编辑:学姐 (内容如有错漏,可在评论区指出) 摘要 Dict-BERT为了解决BERT模型对语料中低频词(rare words)的不敏感性,通过在预训练中加入低频词词典…

人工智能/计算机期刊会议测评(持续更新...更新速度取决于我水论文的速度...)

IEEE Transactions on Knowledge and Data Engineering 中科院2区,CCF A。你为什么是二区????????????做梦都想中的刊。 …

5天带你读完《Effective Java》(二)

《Effective Java》是Java开发领域无可争议的经典之作,连Java之父James Gosling都说: “如果说我需要一本Java编程的书,那就是它了”。它为Java程序员提供了90个富有价值的编程准则,适合对Java开发有一定经验想要继续深入的程序员…

Servlet 原来是这个玩意、看完恍然大悟

1. 什么是 Servlet? 先让时间回到 25 年前,我国刚刚接入互联网不到两年时间。那时候的电脑长这个样子: 当时的网页技术还不是很发达,大家打开浏览器只能浏览一些静态的页面,例如图片、文本信息等。 随着时间的发展&a…

[附源码]Python计算机毕业设计Django社区生活废品回收APP

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

idea配置tomcat日志中文乱码,且修改后idea正常,但cmd窗口任然中文乱码解决方法

idea日志乱码问题的原因是tomcat的日志配置文件有两行有问题需要删掉,cmd乱码是Windows系统cmd窗口默认不是utf-8 首先解决idea中tomcat的日志乱码问题,在idea中进行如下的配置 Trans...........可以不勾选,它的作用是用选定的字符集把项目的…

[附源码]Python计算机毕业设计Django室内设计类网站

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

2023面试专题:JAVA基础

ArrayList和LinkedList有哪些区别 ArrayList扩容机制: ArrayList() 会使用长度为零的数组ArrayList(int initialCapacity) 会使用指定容量的数组public ArrayList(Collection<? extends E> c) 会使用 c 的大小作为数组容量add(Object o) 首次扩容为 10&#xff0c;再次…

【OpenCV】透视变换应用——实现鸟瞰图与贴图

透视变换是3D转换&#xff0c;透视变换的本质是将图像投影到一个新的视平面&#xff1b; 据此&#xff0c;我们可以使用透视变化来实现鸟瞰图和图形贴图的效果&#xff1b; 一、鸟瞰图 实现前&#xff1a; 实现效果&#xff1a; 1.准备一个空的mat对象 用于保存转换后的图 M…

asp.net mvc+elementUI 实现增删改查

最开始心想着一直都是前端玩这些玩意&#xff0c;个人虽然不是纯前端。好歹做为一个.net全栈开发多年&#xff0c;我就不太想用node去搭建&#xff0c;那么试试吧&#xff0c;总归不是那么几个css和js的文件引用&#xff0c;如果对vue.js不太熟悉&#xff0c;最好先去看看。 那…

智能家居创意DIY之智能触摸面板开关

触摸开关&#xff0c;即通过触摸方式控制的墙壁开关&#xff0c;其感官场景如同我们的触屏手机&#xff0c;只需手指轻轻一点即可达到控制电器的目的&#xff0c;随着人们生活品质的提高&#xff0c;触摸开关将逐渐将换代传统机械按键开关。 触摸开关控制原理 触摸开关我们把…

springboot入门案例

今天写一个springboot入门案例&#xff0c;接下来我将带大家走进springboot第一课的案例。如果有问题&#xff0c;望大家指正。 目录 1. 简介 2. 开发示例 2.1 创建springboot工程 3. 启动类 4. 常用注解 5. springboot配置文件 6. 开发一个controller 1. 简介 Spring …