【深度学习】RNN的简单实现

news2024/11/26 3:31:58

目录

1.RNNCell

2.RNN

3.RNN_Embedding


1.RNNCell

import torch

input_size = 4
hidden_size = 4
batch_size = 1

idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]  # 输入:hello
y_data = [3, 1, 2, 3, 2]  # 期待:ohlol

# 独热向量
one_hot_lookup = [[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]

inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)  # (seqLen,batchSize,inputSize)
labels = torch.LongTensor(y_data).view(-1, 1)  # (seqLen,1)


class Model(torch.nn.Module):
    def __init__(self, input_size, hidden_size, batch_size):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnncell = torch.nn.RNNCell(input_size=self.input_size,
                                        hidden_size=self.hidden_size)

    def forward(self, input, hidden):
        hidden = self.rnncell(input, hidden)  # input:(batch, input_size) hidden:(batch, hidden_size)
        return hidden

    def init_hidden(self):
        return torch.zeros(self.batch_size, self.hidden_size)


net = Model(input_size, hidden_size, batch_size)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)


for epoch in range(15):
    loss = 0
    optimizer.zero_grad()
    hidden = net.init_hidden()
    print('Predicted string: ', end='')
    for input, label in zip(inputs, labels):
        hidden = net(input, hidden)
        loss += criterion(hidden, label)
        idx = torch.argmax(hidden, dim=1)
        print(idx2char[idx.item()], end='')
    loss.backward()
    optimizer.step()
    print(', Epoch [%d/15] loss=%.4f' % (epoch+1, loss.item()))

2.RNN

import torch

input_size = 4  # 输入的维度,例如hello为四个字母表示,其维度为四
hidden_size = 4  # 隐藏层维度
num_layers = 1  # number of layers
batch_size = 1
seq_len = 5

idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]  # 输入:hello
y_data = [3, 1, 2, 3, 2]  # 期待:ohlol

# 独热向量
one_hot_lookup = [[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]

inputs = torch.Tensor(x_one_hot).view(seq_len, batch_size, input_size)
labels = torch.LongTensor(y_data)  # (seqSize*batchSize, 1)


class Model(torch.nn.Module):
    def __init__(self, input_size, hidden_size, batch_size, num_layers=1):
        super(Model, self).__init__()
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnn = torch.nn.RNN(input_size=self.input_size,
                                hidden_size=self.hidden_size,
                                num_layers=num_layers)

    def forward(self, input):
        hidden = torch.zeros(self.num_layers,
                             self.batch_size,
                             self.hidden_size)  # (numLayers, batchSize, hiddenSize)
        out, hidden_last = self.rnn(input, hidden)  # out:(seqLen, batchSize, hiddenSize), hidden_last:最后一个hidden
        return out.view(-1, self.hidden_size)  # (seqLen×batchSize, hiddenSize)


net = Model(input_size, hidden_size, batch_size, num_layers)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

for epoch in range(15):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    idx = torch.argmax(outputs, dim=1)
    print('Predicted string: ', ''.join([idx2char[i.item()] for i in idx]), end='')
    print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

3.RNN_Embedding

import torch

# parameters
num_class = 4  # 引入线性层,不用不必要求一个输入就有一个输出,可以多个
input_size = 4
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = 1
seq_len = 5

idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]]  # (batch:1, seq_len:5)
y_data = [3, 1, 2, 3, 2]  # (batch * seq_len)
inputs = torch.LongTensor(x_data)
labels = torch.LongTensor(y_data)


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = torch.nn.Embedding(input_size, embedding_size)
        self.rnn = torch.nn.RNN(input_size=embedding_size,
                                hidden_size=hidden_size,
                                num_layers=num_layers,
                                batch_first=True)
        # batchSize在第一位: (batchSize:一共几个句子, seqLen:每个句子有几个单词, inputSize:每个单词有多少特征)
        self.fc = torch.nn.Linear(hidden_size, num_class)

    def forward(self, x):
        hidden = torch.zeros(num_layers, x.size(0), hidden_size)
        x = self.emb(x)  # (batch, seqLen, embeddingSize), 输入数据x首先经过嵌入层,将字符索引转换为向量
        x, _ = self.rnn(x, hidden)
        x = self.fc(x)
        return x.view(-1, num_class)


net = Model()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

for epoch in range(15):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    idx = torch.argmax(outputs, dim=1)
    print('Predicted string: ', ''.join([idx2char[i.item()] for i in idx]), end='')
    print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2211918.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

滑雪——记忆化搜索

题目 代码 //#pragma GCC optimize(3)#include <bits/stdc.h> const int N 310; using namespace std; int dx[4] {-1, 0, 1, 0}, dy[4] {0, 1, 0, -1}; int ans; int g[N][N]; int r, c; int f[N][N]; int dfs(int x, int y) {if(~f[x][y]) return f[x][y];f[x][y] …

TikTok直播带货话术分享,轻松实现销量翻倍

随着TikTok直播带货的不断壮大&#xff0c;越来越多的国内用户开始尝试使用英语进行直播带货。这不仅能够吸引国际观众&#xff0c;还能够扩大市场和提升品牌影响力。 TikTok直播通用带货话术 1. 开场白 开场时&#xff0c;主播可以用热情的语言吸引观众的注意力&#xff1a;…

闭着眼学机器学习——支持向量机分类

引言&#xff1a; 在正文开始之前&#xff0c;首先给大家介绍一个不错的人工智能学习教程&#xff1a;https://www.captainbed.cn/bbs。其中包含了机器学习、深度学习、强化学习等系列教程&#xff0c;感兴趣的读者可以自行查阅。 1. 算法介绍 支持向量机(Support Vector Mach…

AI 能否替代程序员?且听我来一唠!

关于 AI 替代程序员这事儿&#xff0c;咱得仔细唠唠。随着 AI 技术的飞速发展&#xff0c;越来越多人担心程序员会被 AI 取代。程序员会不会失业呢&#xff1f;答案是&#xff1a;没那么简单&#xff01; 首先&#xff0c;AI 确实已经可以干很多程序员的活儿了&#xff0c;比如…

如何构建高效的公路工程资料管理系统?

本文介绍了构建高效的公路工程资料管理系统的方法&#xff0c;涵盖了系统需求分析、功能设计、开发平台选择、开发过程、系统上线与培训、持续改进与维护等关键环节。通过合理规划和科学管理&#xff0c;可以确保系统满足用户需求&#xff0c;提高工作效率&#xff0c;保障公路…

基于Java的超级玛丽游戏的设计与实现(论文+源码)-kaic

摘 要 “超级玛丽”游戏是是任天堂情报开发本部开发的Family Computer横版卷轴动作游戏&#xff0c;它因操作简单、娱乐性强而广受欢迎。Java 的优势在于网络编程与多线程&#xff0c;但其作为一门全场景语言&#xff0c;依然提供了强大的GUI开发API。本论文利用Java的GUI界…

某普SSLVPN 任意文件读取

0x01 产品描述&#xff1a; ‌ 迪普科技的VPN产品是一款面向广域互联应用场景的专业安全网关产品&#xff0c;集成了IPSec、SSL、L2TP、GRE等多种VPN技术&#xff0c;支持国密算法&#xff0c;实现分支机构、移动办公人员的统一安全接入&#xff0c;提供内部业务跨互联网的…

公开课 | 2024最新清华大模型公开课 第4课 大模型学习方法

本文由readlecture.cn转录总结。ReadLecture专注于音、视频转录与总结&#xff0c;2小时视频&#xff0c;5分钟阅读&#xff0c;加速内容学习与传播。 大纲 引言 介绍大模型的训练方法 强调大模型在多领域的应用 大模型的训练阶段 预训练过程 Tokenization的重要性 预训练模…

​面向异构硬件架构:软件支撑和优化技术

面向异构硬件架构&#xff1a;软件支撑和优化技术 本文来自“面向异构硬件架构软件支撑和优化技术”&#xff0c;重点分析了异构硬件成为发展新趋势&#xff0c;系统软件扮演重要新角色&#xff0c;硬件能力单一性与应用需求多样性间的矛盾带来系统性挑战。为了解决这个问题&am…

第五届大数据、人工智能与物联网工程国际会议

第五届大数据、人工智能与物联网工程国际会议&#xff08;ICBAIE 2024&#xff09;定于2024年10月25-27号在中国深圳隆重举行。会议主要围绕大数据、人工智能与物联网工程等研究领域展开讨论。会议旨在为从事大数据、人工智能与物联网工程研究的专家学者、工程技术人员、技术研…

Pandas DataFrame在预测时同样需要传入一个带有相同特征名称的数据框

问题 修改前的代码 import pandas as pd from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import train_test_splitmusic_datapd.read_csv("music.csv") X music_data.drop(columns[genre]) ymusic_data[genre] modelDecisionTree…

Vivado时序报告五:Report Exceptions详解

目录 一、前言 二 Report Exceptions 2.1 配置界面 2.2 设计示例 2.3 Exception报告 2.3.1 General information 2.3.2 Summary 2.3.3 Exceptions 2.3.4 Ignored Objects 一、前言 时序约束中&#xff0c;有一类约束属于Exceptions类&#xff0c;之所以称为Exceptions…

Base16编码解码在线工具

具体请前往&#xff1a;在线Base16编码/解码工具-支持utf-8,Latin1,ascii,GBK,Hex等编码

【股市人生】中年投资者的教训:短期盈利与家庭幸福的抉择,理智投资才是成功之道!

大家好&#xff0c;我是肝脑图弟&#xff0c;一个每天都在肝脑图的男人。今天我们来聊聊一个在股市中跌宕起伏的中年人的故事&#xff0c;这个故事不仅让人感慨&#xff0c;也给我们带来了深刻的反思。 股市就像一块美味的蛋糕&#xff0c;吸引着无数人前来品尝。初入股市的人&…

Gin框架教程01:创建一个简单的 Gin 应用

Gin是目前最流行&#xff0c;性能最好的的GOWEB框架&#xff0c;是学习GOLANG必备的知识。本人最近也在学Gin&#xff0c;在b站搜了很多教程&#xff0c;发现有的教程不够详细&#xff0c;有的教程工具包安装有问题&#xff0c;而官方文档又太简短&#xff0c;于是我就想&#…

Java项目:152 基于springboot的仓库管理系统

作者主页&#xff1a;舒克日记 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 系统概要 ​ 本文将介绍一款基于Java开发的仓库管理系统&#xff0c;该系统可以帮助企业实现对仓库物品的高效管理&#xff0c;提高仓库运营效率。文章将详细介…

EMQX服务器的搭建,实现本地机和虚拟机之间的MQTT通信(详细教程)

前言 MQTT是一个基于客户端-服务器的消息发布/订阅传输协议。MQTT协议是轻量、简单、开放和易于实现的&#xff0c;这些特点使它适用范围非常广泛。 MQTT协议中有三种身份&#xff1a;发布者&#xff08;Publish&#xff09;、代理&#xff08;Broker&#xff09;&#xff08;…

【MySQL】数据库基础指令(一)

前言 个人感觉 MySQL 没有太多的逻辑问题&#xff0c;只有对语句的熟练使用&#xff0c;会对数据进行增删查改操作即可。本章节的内容将会收集一些常用的 MySQL 的指令的使用。 目录 前言 解决MySQL无法输入中文字符的问题 数据库操作 显示当前的数据库 创建数据库 删除数据库 …

大数据存储,搜索智能化的实践分享 | OceanBase 城市交流会精彩回顾

9月21日&#xff0c;“OceanBase 城市交流会”来到了深圳&#xff0c;携手货拉拉大数据技术与产品部&#xff0c;联合举办了“走进货拉拉”的技术交流活动。货拉拉、万家数科、云集、百丽等多家企业的一线技术专家&#xff0c;就大数据存储、AI等热点话题&#xff0c;深入探讨并…

《学习方法报》是什么级别的报纸?

《学习方法报》是什么级别的报纸&#xff1f; 《学习方法报》是省级报纸。 它由山西省教育厅主管&#xff0c;山西教育教辅传媒集团主办。该报创办于 1993 年&#xff0c;国内统一刊号为 CN14-0706/(F)。其作为中国高教学会学习科学研究分会会报&#xff0c;以传递最新教改信…