CMU 10-414/714: Deep Learning Systems --hw4

news2025/1/12 12:31:19

通过之前作业中完成的所有组件,使用高性能的网络结构来解决一些问题。首先会增加一些新的算子(使用CPU/CUDA后端),然后完成卷积、以及用一个卷积神经网络来在CIFAR-10图像数据集上训练一个分类器。接着需要完成循环神经网络(带LSTM),并在Penn Treebank数据集上完成字符级的预测

实现功能

  1. python/needle/ops.py中添加三个新算子:TanhStackSplit
  2. 实现CIFAR-10数据集的Dataset以及在此基础上训练自己实现的卷积神经网络(本篇暂只关注下面的循环神经网络,这个先跳过)
  3. 实现循环神经网络
    先实现单个RNN cell(class RNNCell),然后将多个RNNcell叠加起来,形成更复杂的网络结构(class RNN),如下图所示,图1为单个RNNcell,图二为多个RNNcell组成的RNN
    在这里插入图片描述
    在这里插入图片描述
    • 实现单个RNN cell(class RNNCell(Module)):
      • \_\_init__()初始化函数:进行一系列参数初始化,用均匀分布初始化参数矩阵,并选用激活函数(Tanh或ReLU)
      • forward(X,h)函数:即实现 h t = tanh ⁡ ( x t W i h + b i h + h ( t − 1 ) W h h + b h h ) h_t=\tanh(x_tW_{ih}+b_{ih}+h_{(t-1)}W_{hh}+b_{hh}) ht=tanh(xtWih+bih+h(t1)Whh+bhh)(其中tanh也可换为relu),代码如下:
        def forward(self, X, h=None):
            batch_size, _ = X.shape
            if h is None:
                h = init.zeros(batch_size, self.hidden_size, device=self.device, dtype=self.dtype)
            if self.bias:
                return self.nonlinearity(X @ self.W_ih + self.bias_ih.reshape((1, self.hidden_size)).broadcast_to((batch_size, self.hidden_size)) \
                                       + h @ self.W_hh + self.bias_hh.reshape((1, self.hidden_size)).broadcast_to((batch_size, self.hidden_size)))
            else:
                return self.nonlinearity(X @ self.W_ih + h @ self.W_hh)
        
    • 实现多层RNN(class RNN(Module)
      • \_\_init__()初始化函数:
      • forward(X,h0)函数:将X(seq_len, bs, input_size)分割成seq_len个X(bs, input_size),将每个分割部分分别输入每个RNN cells,得到num_layers个hiddens和seq_len个out。最终将out组装到一起、将hiddens组装到一起,返回out和hs
        def forward(self, X, h0=None):
            _, batch_size, _ = X.shape
            if h0 is None:
                h0 = [init.zeros(batch_size, self.hidden_size, device=self.device, dtype=self.dtype) for _ in range(self.num_layers)]
            else:
                h0 = tuple(ops.split(h0, 0))
            h_n = []
            inputs = list(tuple(ops.split(X, 0)))
            for num_layer in range(self.num_layers):
                h = h0[num_layer]
                for t, input in enumerate(inputs):
                    h = self.rnn_cells[num_layer](input, h)
                    inputs[t] = h
                h_n.append(h)
            return ops.stack(inputs, 0), ops.stack(h_n, 0)
        
  4. 实现LSTM网络(在此学习使用numpy一步步实现lstm的流程))
    • 实现Sigmoid
    • 实现单个LSTMCell
      • __init__()函数:初始化
      • forward()函数:
        def forward(self, X, h=None):
            batch_size, _ = X.shape  # X:(batch_size, feature_size)
            if h is None:
                h0, c0 = init.zeros(batch_size, self.hidden_size, device=self.device, dtype=self.dtype), \
                         init.zeros(batch_size, self.hidden_size, device=self.device, dtype=self.dtype)
            else:
                h0, c0 = h
            if self.bias:
                gates_all = X @ self.W_ih + self.bias_ih.reshape((1, 4 * self.hidden_size)).broadcast_to((batch_size, 4 * self.hidden_size)) \
                            + h0 @ self.W_hh + self.bias_hh.reshape((1, 4 * self.hidden_size)).broadcast_to((batch_size, 4 * self.hidden_size))
            else:
                gates_all = X @ self.W_ih + h0 @ self.W_hh
            gates_all_split = tuple(ops.split(gates_all, axis = 1))
            gates = []
            for i in range(4):
                gates.append

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1534982.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Cpolar异地组网,在vscode上ssh远程开发ubuntu主机

目录 开发环境 操作流程 参考资料 在机器人被搬到另一个屋之后,通过局域网进行ssh开发就变成了个困难的问题。因此尝试了异地组网来解决这个问题,看了一些资料后发现基于cpolar进行异地组网也不困难,这里记录一下步骤。 开发环境 硬件&…

Transformer的前世今生 day06(Self-Attention和RNN、LSTM的区别

Self-Attention和RNN、LSTM的区别 RNN(循环神经网络) RNN,当前的输出 o t o_t ot​取决于上一个的输出 o t − 1 o_{t-1} ot−1​(作为当前的输入 x t − 1 x_{t-1} xt−1​)和当前状态下前一时间的隐变量 h t h_t h…

Vue3学习记录(七)--- 组合式API之指令和插件

一、内置指令 1、v-memo ​ 该指令是Vue3的v3.2版本之后新增的指令,用于实现组件模板缓存,优化组件更新时的性能。该指令接收一个固定长度的依赖值数组,在组件进行更新渲染时,如果数组中的每个依赖值都与上一次渲染时的值相同&a…

web前端笔记+表单练习题+五彩导航栏练习题

web前端笔记 1-骨架快捷方式!enter<!DOCTYPE html><html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>骨架部分</titl…

刚进公司第一天-电脑环境搭建

写在前面 之前在公司做过一次开发小工具的分享&#xff0c;这两天有个同事找我学习一些小工具开发的知识&#xff0c;但是我发现他的基础是真的差&#xff0c;想学开发知识却连自己本地电脑环境都没弄好&#xff0c;确实&#xff0c;有些人工作了很久&#xff0c;由于自己工作中…

sentinel整合gateway实现服务限流

导入依赖: <dependencies><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-gateway</artifactId></dependency><dependency><groupId>com.alibaba.csp</groupId><…

SpringCloudGateway之高性能篇

SpringCloudGateway之高性能篇 背景 在公司的开放平台中&#xff0c;为了统一管理对外提供的接口、处理公共逻辑、实现安全防护及流量控制&#xff0c;确实需要一个API网关作为中间层。 场景 统一接入点: API网关作为所有对外服务的单一入口&#xff0c;简化客户端对内部系统…

使用Go语言创建HTTP服务器并展示网页

使用Go语言创建一个简单的服务器时可以先建立一个项目根目录&#xff0c;随后在根目录中建立一个用于存放静态文件&#xff08;HTML/CSS/JavaScript&#xff09;的文件夹 GGboy&#xff0c;接下来输入命令初始化Go模块 go mod init GGboy // 项目名称是 GGboy 在出现 go.mod 文…

QT:三大特性

QT的三大特性&#xff1a; 1、信号与槽 2、内存管理 3、事件处理 1、信号与槽 当信号产生时&#xff0c;就会自动调用绑定的槽函数。 自定义信号: 类中需要添加O_OBJECT宏 声明: signals标签之下进行声明 定义&#xff1a; 信号不需要定义 …

Java项目基于Docker打包发布

一、后端项目 1.打包应用 mvn clean package -DskipTests 2、新建dockerfile文件 #基础镜像 FROM openjdk:8 #工作空间 WORKDIR /opt #复制文件 COPY wms-app-1.0-SNAPSHOT.jar app.jar #配置容器暴漏的端口 EXPOSE 8080 RUN ls #强制执行命令 ENTRYPOINT [ "java&quo…

谷粒商城——缓存的概念

1. 使用缓存的好处&#xff1a;减少数据库的访问频率&#xff0c;提高用户获取数据的速度。 2. 什么样的数据适合存储到缓存中&#xff1f; ①及时性、数据一致性要求不高的数据&#xff0c;例如物流信息、商品类目信息 ②访问量大更新频率不高的数据(读多、写少) 3. 读模式…

sentinel熔断规则详解

1、慢调用降级熔断 1.1、参数详解 最大RT&#xff1a;调用接口的最大时间。 比例阈值&#xff1a;超过了最大RT调用时间的请求的比例。 熔断时长&#xff1a;触发熔断后&#xff0c;熔断的时间 最小请求数据&#xff1a;每秒最少的请求数量&#xff0c;只有大于等于这个数…

电网的正序参数和等值电路(一)

本篇为本科课程《电力系统稳分析》的笔记。 本篇为第二章的第一篇笔记。 电力系统正常运行中&#xff0c;可以认为系统的三相结构和三相负荷完全对称。而对称三相的计算可以用一相来完成&#xff0c;其中所有给出的标称电压都是线电压的有效值&#xff0c;假定系统全部是Y-Y型…

基于Python3的数据结构与算法 - 16 链表

目录 链表 1. 创建链表 2. 链表的插入和删除 3. 双链表 4. 链表总结 链表 链表是由一系列节点组成的元素集合。每个节点包含两部分&#xff0c;数据域item和指向下一个节点得指针next。通过节点之间的相互连接&#xff0c;最终串联成一个链表。 class Node:def __init…

Unity访问安卓(Android)或苹果(iOS)相册

1.下载Native Gallery for Android & iOS插件 2.在场景中添加截图按钮、选择图片按钮、选择视频按钮等 using OpenCVForUnity.CoreModule; using OpenCVForUnity.ImgprocModule; using OpenCVForUnity.UnityUtils; using System.Collections; using System.Collections.Gen…

好用的GPTs:指定主题搜索、爬虫、数据清洗、数据分析自动化

好用的GPTs&#xff1a;指定主题搜索、爬虫、数据清洗、数据分析自动化 Scholar&#xff1a;搜索 YOLO小目标医学方面最新论文Scraper&#xff1a;爬虫自动化数据清洗数据分析 点击 Explore GPTs&#xff1a; Scholar&#xff1a;搜索 YOLO小目标医学方面最新论文 搜索 Scho…

C语言 swab 函数学习

swab函数交换字符串中相邻两个字节&#xff1b; void _swab( char *src, char *dest, int n ); char *src&#xff1a; 要拷贝、转换的字符串&#xff0c; char *dest&#xff0c;转换后存储到dest所表示的字符串&#xff0c; int n要拷贝、转换的字节数&#xff1b; 所…

交通事故档案管理系统|基于JSP技术+ Mysql+Java+Tomcat的交通事故档案管理系统设计与实现(可运行源码+数据库+设计文档)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 2024年56套包含java&#xff0c;ssm&#xff0c;springboot的平台设计与实现项目系统开发资源&#xff08;可…

关于OceanBase中旁路导入的应用分享

背景 前段时间&#xff0c;在用户现场协助进行OceanBase的性能测试时&#xff0c;我注意到用户常常需要运用 insert into select 将上亿行的数据插入到一张大宽表里&#xff0c;这样的批量数据插入操作每次都需要耗时半个小时左右。对这一情况&#xff0c;我提议用户尝试采用旁…

福建科立讯通信 指挥调度管理平台 SQL注入漏洞复现(CVE-2024-2620、CVE-2024-2621)

0x01 产品简介 福建科立讯通信指挥调度管理平台是一个专门针对通信行业的管理平台。该产品旨在提供高效的指挥调度和管理解决方案,以帮助通信运营商或相关机构实现更好的运营效率和服务质量。该平台提供强大的指挥调度功能,可以实时监控和管理通信网络设备、维护人员和工作任…