PyTorch从零开始实现LSTM

news2024/7/4 6:05:37

文章目录

  • LSTM基础理论
  • 从零开始实现LSTM
  • 简洁版LSTM实现
  • 参考资料

LSTM基础理论

关于LSTM的基础理论不再赘述,可以参考资料:

  • RNN神经网络-LSTM模型结构
  • https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter06_RNN/6.8_lstm.md

从零开始实现LSTM

  1. 导入所需的依赖包,获取设备:
import torch
from torch import nn, optim
import numpy as np
import zipfile

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  1. 定义数据加载代码,这里以周杰伦的歌词数据为例:
def load_data_jay_lyrics():
    with zipfile.ZipFile('./data/jaychou_lyrics.txt.zip') as zin:
        with zin.open('jaychou_lyrics.txt') as f:
            corpus_chars = f.read().decode('utf-8')
    # corpus_chars[:40]  # '想要有直升机\n想要和你飞到宇宙去\n想要和你融化在一起\n融化在宇宙里\n我每天每天每'

    # 将换行符替换成空格;仅使用前1万个字符来训练模型
    corpus_chars = corpus_chars.replace('\n', ' ').replace('\r', ' ')
    corpus_chars = corpus_chars[0:10000]

    # 将每个字符映射成索引
    idx_to_char = list(set(corpus_chars))
    char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)])
    vocab_size = len(char_to_idx)  # 1027
    corpus_indices = [char_to_idx[char] for char in corpus_chars]
    sample = corpus_indices[:20]
    return corpus_indices, char_to_idx, idx_to_char, vocab_size
  1. 定义参数初始化函数:
def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)

    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_inputs, num_hiddens)),
                torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))

    W_xi, W_hi, b_i = _three()  # 输入门参数
    W_xf, W_hf, b_f = _three()  # 遗忘门参数
    W_xo, W_ho, b_o = _three()  # 输出门参数
    W_xc, W_hc, b_c = _three()  # 候选记忆细胞参数

    # 输出层参数
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
    return nn.ParameterList([W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q])


# 加载周杰伦歌词数据
corpus_indices, char_to_idx, idx_to_char, vocab_size = load_data_jay_lyrics()
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size


def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))
  1. 定义LSTM模型:
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)
        F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)
        O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)
        C_tilda = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * C.tanh()
        Y = torch.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H, C)
  1. 训练模型。
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']


train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, False, num_epochs, num_steps, lr,
                          clipping_theta, batch_size, pred_period, pred_len,
                          prefixes)

这里 train_and_predict_rnn 为训练函数,可以参考
https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter06_RNN/6.5_rnn-pytorch.md
中的实现。

简洁版LSTM实现

  1. 定义RNNModel类:
# 本类已保存在d2lzh_pytorch包中方便以后使用
class RNNModel(nn.Module):
    def __init__(self, rnn_layer, vocab_size):
        super(RNNModel, self).__init__()
        self.rnn = rnn_layer
        self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
        self.vocab_size = vocab_size
        self.dense = nn.Linear(self.hidden_size, vocab_size)
        self.state = None

    def forward(self, inputs, state): # inputs: (batch, seq_len)
        # 获取one-hot向量表示
        X = to_onehot(inputs, self.vocab_size) # X是个list
        Y, self.state = self.rnn(torch.stack(X), state)
        # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
        # 形状为(num_steps * batch_size, vocab_size)
        output = self.dense(Y.view(-1, Y.shape[-1]))
        return output, self.state
def to_onehot(X, n_class):  
    # X shape: (batch, seq_len), output: seq_len elements of (batch, n_class)
    return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]
  1. 定义LSTM模型并训练模型:
lr = 1e-2 # 注意调整学习率

lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)

model = RNNModel(lstm_layer, vocab_size)

train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                corpus_indices, idx_to_char, char_to_idx,
                                num_epochs, num_steps, lr, clipping_theta,
                                batch_size, pred_period, pred_len, prefixes)

参考资料

  • https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter06_RNN/6.8_lstm.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1884024.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

直播的js代码debug解析找到protobuf消息的定义

我们都知道直播的弹幕消息是通过websocket发送的,而且是通过protobuf传输的,那么这里面传输了哪些内容,这个proto文件又要怎么定义?每个消息叫什么,消息里面又包含有哪些字段,每个字段又是什么类型&#xf…

1-3.文本数据建模流程范例

文章最前: 我是Octopus,这个名字来源于我的中文名–章鱼;我热爱编程、热爱算法、热爱开源。所有源码在我的个人github ;这博客是记录我学习的点点滴滴,如果您对 Python、Java、AI、算法有兴趣,可以关注我的…

# [0701] Task05 策略梯度、Actor-critic 算法

easy-rl PDF版本 笔记整理 P4、P9 joyrl 比对 补充 P9 - P10 相关 代码 整理 最新版PDF下载 地址:https://github.com/datawhalechina/easy-rl/releases 国内地址(推荐国内读者使用): 链接: https://pan.baidu.com/s/1isqQnpVRWbb3yh83Vs0kbw 提取码: us…

LeetCode中MySQL题目 176.第二高的薪水

题目图片: 题目解答: SELECTIFNULL((SELECT DISTINCT SalaryFROM EmployeeORDER BY Salary DESCLIMIT 1 OFFSET 1),NULL) AS SecondHighestSalary解答解析: 就是用了一个叫做IFNULL的函数进行判断,如果查找出来的内容为空&…

信息系统的安全模型

1. 信息系统的安全目标 信息系统的安全目标是控制和管理主体(含用户和进程)对客体(含数据和程序)的访问。作为信息系统安全目标,就是要实现: 保护信息系统的可用性; 保护网络系统服务的…

第1章 人工智能的基础概念与应用导论

亲爱的读者朋友们,你们好!欢迎来到这个充满神奇与奥秘的人工智能世界。我知道,对于很多人来说,人工智能(AI)可能是个既神秘又高大上的词汇,仿佛遥不可及,只存在于科幻电影或者顶级科…

大数据学习之Clickhouse

Clickhouse-23.2.1.2537 学习 一、Clickhouse概述 clickhouse 官网网址:https://clickhouse.com/ ClickHouse是一个用于联机分析(OLAP)的列式数据库管理系统(DBMS)。 OLTP(联机事务处理系统)例如mysql等关系型数据库,在对于存储小数据量的时候&#xff…

Linux内核——Linux内核体系模式(二)

1 Linux系统的中断机制 Linux内核将中断分为两类:硬件中断和软件中断(异常)。每个中断是由0-255之间的一个数字进行标识。 中断int0-int31(0x00-0x1f)作为异常int32-int255由用户自己设定 int32-int47对应与8259A中断…

怎么永久禁止win10系统自动更新?一键屏蔽系统自动更新

现在 Windows 10 系统是很多办公用户的主力操作系统,可是 Windows 系统会自动更新,这会严重影响系统稳定性。因为微软虽然以提供更新为服务,但并不是每次更新它都是安全的。 接下来和我一起看看如何使用联想开发的小工具一键屏蔽系统自动更新…

数据库定义语言(DDL)

数据库定义语言(DDL) 一、数据库操作 1、 查询所有的数据库 SHOW DATABASES;效果截图: 2、使用指定的数据库 use 2403 2403javaee;效果截图: 3、创建数据库 CREATE DATABASE 2404javaee;效果截图: 4、删除数据…

Datax快速使用之牛刀小试

前言 一次我发现业务他们在用 datax数据同步工具,我尤记得曾经 19 年使用过,并且基于当时的版本还修复了个 BUG并且做了数据同步管道的集成开发。没想到时间过的飞快,业务方基于海豚调度 2.0.6 的版本中有在使用,由于业务方还没有…

光伏设计的原则和必备要素

光伏设计是一项复杂的工程任务,它涉及到将太阳能转换为电能的过程,并在各种环境条件下确保系统的稳定、高效运行。以下是光伏设计应遵循的原则和必备的要素。 一、光伏设计的原则 1、最大化能量产出:光伏设计的首要原则是通过合理的布局和选…

RedHat9 | 内部YUM本地源服务器搭建

服务器参数 标识公司内部YUM服务器主机名yum-server网络信息192.168.37.1/24网络属性静态地址主要操作用户root 一、基础环境信息配置 修改主机名 [rootyum-server ~]# hostnamectl hostname yum-server添加网络信息 [rootyum-server ~]# nmcli connection modify ens160 …

Python和tkinter单词游戏

Python和tkinter单词游戏 数据字典文本文件,文件名为Dictionary.txt,保存编码格式为:utf-8。文本内容:每行一个 单词 ,单词和解释用空格分隔,如 a art.一(个);每一(个) ability n.能力&#…

EKF+UKF+CKF+PF的效果对比|三维非线性滤波|MATLAB例程

前言 标题里的EKF、UKF、CKF、PF分别为:扩展卡尔曼滤波、无迹卡尔曼滤波、容积卡尔曼滤波、粒子滤波。 EKF是扩展卡尔曼滤波,计算快,最常用于非线性状态方程或观测方程下的卡尔曼滤波。 但是EKF应对强非线性的系统时,估计效果不如…

MySQL5.7安装初始化错误解决方案

问题背景 今天在给公司配数据库环境时,第一次报initializing database 数据库初始化错误? 起初没管以为是安装软件原因,然后就出现以下错误:如下图 点开log,我们观察日志会发现 无法识别的参数 ‘mysqlx_port=0.0’,???,官方的安装程序还能出这问题?

排序(堆排序、快速排序、归并排序)-->深度剖析(二)

前言 前面介绍了冒泡排序、选择排序、插入排序、希尔排序,作为排序中经常用到了算法,还有堆排序、快速排序、归并排序 堆排序(HeaSort) 堆排序的概念 堆排序是一种有效的排序算法,它利用了完全二叉树的特性。在C语言…

【Linux】:环境变量

朋友们、伙计们,我们又见面了,本期来给大家解读一下有关Linux环境变量的相关知识点,如果看完之后对你有一定的启发,那么请留下你的三连,祝大家心想事成! C 语 言 专 栏:C语言:从入门…

万字总结随机森林原理、核心参数以及调优思路

万字总结随机森林原理、核心参数以及调优思路 在机器学习的世界里,随机森林(Random Forest, RF)以其强大的预测能力和对数据集的鲁棒性而备受青睐。作为一种集成学习方法,随机森林通过构建多个决策树并将它们的预测结果进行汇总&…

SpringCloud_Eureka注册中心

概述 Eureka是SpringCloud的注册中心。 是一款基于REST的服务治理框架,用于实现微服务架构中的服务发现和负载均衡。 在Eureka体系中,有两种角色: 服务提供者和服务消费者。 服务提供者将自己注册到Eureka服务器,服务消费者从Eureka服务器中…