【Pytorch项目实战】之自然语言处理:RNN、LSTM、GRU、Transformer

news2025/1/11 2:17:32

文章目录

  • 自然语言处理
    • 算法一:循环神经网络(Recurrent Natural Network,RNN)
    • 算法二:长短时记忆神经网络(Long Short-Term Memory,LSTM)
    • 算法三:门控循环单元神经网络(Gated Recurrent Unit,GRU)
    • 算法四:Transformer
    • (一)实战:基于LSTM预测股票行情(数据集:Tushare数据包)

自然语言处理

算法一:循环神经网络(Recurrent Natural Network,RNN)

RNN与CNN的区别?:CNN的输入图像大小固定,而在语音识别中,每句话的长度都是不一样的,且一句话的前后也是有关系的。

应用:自然语言处理(Neuro-Linguistic Programming,NLP)、语音识别、机器翻译等时序问题中。
难题:网络结构的加深使得模型忘记了先前学习的信息,即只能实现短时记忆
原因
(1)RNN相邻时间步是连接在一起的,,故(以横向网络看)每个权重都会朝相同方向迭代更新。
(2)随着层数的增加,梯度小则越小,最后趋近于0(梯度消失);梯度大则越大,最后非常大(梯度爆炸)。
(3)最终导致网络无法训练,无法实现长时记忆。

RNN网络的结构图
在这里插入图片描述
RNN由输入层、隐含层、输出层组成。其中,U是输入x到隐含层的权重矩阵;W是状态s到隐含层的权重矩阵;V是隐含层到输出层o的权重矩阵。
RNN的参数共享:即各个时间节点x(t-1)、x(t)、x(t+1)对应的W、U、V权重矩阵都是不变的。
RNN的最大特点:隐层状态。可以捕获一个序列的信息。

隐含层的详细结构图
在这里插入图片描述
假设:输入x为n维向量,隐含层的神经元个数为m,输出层的神经元个数为r。则U=n*mW=m*mV=m*r。x(t)、a(t)、o(t)都是向量。

  • x(t)是时刻t的输入向量;例如:一句话。
  • a(t)是时刻t的隐层状态;可以捕获之前时刻发生的信息。
    • a(t)=f(U*x(t)+W*a(t-1))。其中,x(t)表示当前时刻的输入;a(t-1)表示前一时刻的状态;f是非线性函数(如:tanh、ReLU)。
  • o(t)是时刻t的输出向量;例如:想要预测句子的下一个词,它将会是词汇表中的概率向量。o(t)=softmax(W*a(t))

RNN既可以像CNN一样横向发展(增加时间步或序列长度),也可以纵向扩展为多层RNN。
在这里插入图片描述

为加深对RNN结构图的理解,举例说明。
假设
详细过程如下图(RNN前向传播的计算过程)
在这里插入图片描述

上述案例的代码实现如下:

import numpy as np

X = [1,2]
state = [0.0, 0.0]
w_cell_state = np.asarray([[0.1, 0.2], [0.3, 0.4],[0.5, 0.6]])
b_cell = np.asarray([0.1, -0.1])
w_output = np.asarray([[1.0], [2.0]])
b_output = 0.1

for i in range(len(X)):
    state=np.append(state,X[i])
    before_activation = np.dot(state, w_cell_state) + b_cell
    state = np.tanh(before_activation)
    final_output = np.dot(state, w_output) + b_output
    print("状态值_%i: "%i, state)
    print("输出值_%i: "%i, final_output)

RNN的反向传播算法
称为随时间反向传播(Backpropagation Through Time,BPTT),原理与CNN的反向传播一样。区别:CNN按照层进行反向传播,而BPTT按照时间 t 进行反向传播。
在这里插入图片描述

算法二:长短时记忆神经网络(Long Short-Term Memory,LSTM)

时期:1997年提出,通过" 门 "的结构控制信息的增加或去除。
优点能够有效解决信息的短时记忆,避免梯度消失或爆炸。
缺点:模型不能并行学习,只能左右学习,导致面对大语料库(NLP)时训练效率非常低。
最大特点: 通过三个门控制对以往信息的取舍,即循环体结构

  • (1)遗忘门(Forget Gate):决定了上一时刻的单元状态c(t-1)有多少保留到当前时刻c(t);
  • (2)输入门(Input Gate):决定了当前时刻网络的输入x(t)有多少保存到单元状态c(t);
  • (3)输出门(Output Gate):控制单元状态c(t)有多少输出到LSTM的当前输出值h(t)。

在这里插入图片描述

算法三:门控循环单元神经网络(Gated Recurrent Unit,GRU)

背景:2014年提出,GRU是LSTM的变种。
优点:计算效率更高,占用内存相对较少。且在实际使用中,两者差异不大,故越来越流行。
主要两大改动

  • (1)将输入门、遗忘门、输出门变为两个门:更新门z(t)、重置门r(t)
  • (2)将单元状态与输出合并为一个状态h(t)。

在这里插入图片描述
GRU的结构图及公式

算法四:Transformer

2017 年6月,Google团队Ashish Vaswani等人在论文 Attention is All you need 中提出了 Transformer 模型。其使用 Self-Attention 结构取代了在 NLP 任务中常用的 RNN 的顺序网络结构,使得模型可以并行化训练,而且能够充分利用训练资料的全局信息,加入Transformer的Seq2seq模型在NLP的各个任务上都有了显著的提升。

核心机制:Self-Attention。Self-Attention机制的本质来自于人类视觉注意力机制。当人视觉在感知东西时候往往会更加关注某个场景中显著性的物体,为了合理利用有限的视觉信息处理资源,人需要选择视觉区域中的特定部分,然后集中关注它。
注意力机制的主要目:对输入进行注意力权重的分配,即决定需要关注输入的哪部分,并对其分配有限的信息处理资源给重要的部分。
在这里插入图片描述
Transformer 模型详解
Transformer详解(附代码)
详解Transformer中Self-Attention以及Multi-Head Attention

(一)实战:基于LSTM预测股票行情(数据集:Tushare数据包)

链接:https://pan.baidu.com/s/1hfmWHjDQxIHsR9xZNJuvYg?pwd=s1qc
提取码:s1qc

在这里插入图片描述

import pandas as pd
import matplotlib.pyplot as plt
import datetime
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

import tushare as ts    # (1)pip install tushare(2)pip install pytdx
from pandas.plotting import register_matplotlib_converters

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'     		# "OMP: Error #15: Initializing libiomp5md.dll"
#################################################################################
def generate_data_by_n_days(series, n, index=False):
	"""生成训练数据"""
    if len(series) <= n:
        raise Exception("The Length of series is %d, while affect by (n=%d)." % (len(series), n))
    df = pd.DataFrame()
    for i in range(n):
        df['c%d' % i] = series.tolist()[i:-(n - i)]
    df['y'] = series.tolist()[n:]
    if index:
        df.index = series.index[n:]
    return df


def readData(column='high', n=30, all_too=True, index=False, train_end=-500):
	"""读取训练集"""
    df = pd.read_csv("sh300.csv", index_col=0)		# 读取csv文件
    df.index = list(map(lambda x: datetime.datetime.strptime(x, "%Y-%m-%d"), df.index))		# 以日期为索引
    df_column = df[column].copy()		# 获取每天的最高价数据(column='high')
    # 拆分为训练集和测试集
    df_column_train, df_column_test = df_column[:train_end], df_column[train_end - n:]
    # 生成训练数据
    df_generate_train = generate_data_by_n_days(df_column_train, n, index=index)
    if all_too:
        return df_generate_train, df_column, df.index
    return df_generate_train


class LSTM(nn.Module):
	"""LSTM网络模型"""
    def __init__(self, input_size):
        super(LSTM, self).__init__()
        self.LSTM = nn.LSTM(input_size=input_size, hidden_size=64, num_layers=1, batch_first=True)
        self.out = nn.Sequential(nn.Linear(64, 1))

    def forward(self, x):
        r_out, (h_n, h_c) = self.LSTM(x, None)  	# None即隐层状态用0初始化
        out = self.out(r_out)
        return out


class mytrainset(Dataset):
	"""数据与标签分离"""
    def __init__(self, data):
        self.data, self.label = data[:, :-1].float(), data[:, -1].float()

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)


#################################################################################
# (1)加载数据
cons = ts.get_apis()				# (API)建立连接
df = ts.bar('000300', conn=cons, asset='INDEX', start_date='2010-01-01', end_date='')
df = df.dropna()            		# 删除有null的行		# 若报错提示: no module dropna。则删除该行。
df.to_csv('sh300.csv')				# 将df保存到当前目录下

# 若本地已下载文件:'sh300.csv'。(忽略上述代码,改调用下述代码)
# df = pd.read_csv('sh300.csv')		# 读取csv文件
# print('文件表头', df.columns)		# 打印文件表头
# df_describe = df.describe			# 查看统计信息
# 获取沪深指数(000300)的信息,包括交易日期(datetime)、开盘价(open)、收盘价(close)、最高价(high)、最低价(low)、成交量(vol)、成交金额(amount)、涨跌幅(p_change)
#################################################################################
# (2)设置超参数
n = 30		# 30天数据
LR = 0.001
EPOCH = 200
batch_size = 20
train_end = -600
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#################################################################################
# (3)装载数据
# 获取训练数据、最高价数据、最高价数据的索引(以日期为索引)
df, df_all, df_index = readData('high', n=n, train_end=train_end)
df_index = df_index.tolist()			# 最高价数据的索引转化为数组
df_all = np.array(df_all.tolist())		# 最高价数据转化为数组
df_numpy = np.array(df)					# 训练数据集转化为数组
# 归一化处理
df_numpy_mean = np.mean(df_numpy)
df_numpy_std = np.std(df_numpy)
df_numpy = (df_numpy - df_numpy_mean) / df_numpy_std
df_tensor = torch.Tensor(df_numpy)		# numpy转换为Tensor
trainset = mytrainset(df_tensor)		# 训练数据集的数据与标签拆分开
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=False)
################################################
# (4)训练模型
LSTM = LSTM(n).to(device)			# n为序列长度(即时间步长)
optimizer = torch.optim.Adam(LSTM.parameters(), lr=LR)		# 优化器(权重参数, 学习率)
loss_func = nn.MSELoss()			# 均方差损失函数
losses = []
for step in range(EPOCH):
	train_loss = 0
    for tx, ty in trainloader:
        tx = tx.to(device)
        ty = ty.to(device)
        # 在第1个维度上添加一个维度为1的维度,形状变为[batch,seq_len,input_size]
        output = LSTM(torch.unsqueeze(tx, dim=1)).to(device)
        loss = loss_func(torch.squeeze(output), ty)
        optimizer.zero_grad()		# 梯度清零
        loss.backward()				# 随时间反向传播
        optimizer.step()			# 梯度更新
        # 记录误差
        train_loss += loss.item()
    losses.append(train_loss/len(trainloader))
#################################################################################
# (5)测试模型
generate_data_train = []
generate_data_test = []
test_index = len(df_all) + train_end
df_all_normal = (df_all - df_numpy_mean) / df_numpy_std
df_all_normal_tensor = torch.Tensor(df_all_normal)
for i in range(n, len(df_all)):
    x = df_all_normal_tensor[i - n:i].to(device)
    # LSTM的输入必须是3维,故需添加两个1维的维度,最后成为[1,1,input_size]
    x = torch.unsqueeze(torch.unsqueeze(x, dim=0), dim=0)
    y = LSTM(x).to(device)
    if i < test_index:
        generate_data_train.append(torch.squeeze(y).detach().cpu().numpy() * df_numpy_std + df_numpy_mean)
    else:
        generate_data_test.append(torch.squeeze(y).detach().cpu().numpy() * df_numpy_std + df_numpy_mean)
#################################################################################
# (6)画图
plt.subplot(221), plt.plot(df_index, df_all), plt.title('high data')

plt.subplot(222), plt.plot(np.arange(len(losses)), losses), plt.xlabel('EPOCH'), plt.ylabel('Loss'), plt.title('train loss') 

plt.subplot(223) 
plt.plot(df_index[n:train_end], generate_data_train, label='generate_train')
plt.plot(df_index[train_end:], generate_data_test, label='generate_test')
plt.plot(df_index[train_end:], df_all[train_end:], label='real-data')
plt.legend()

plt.subplot(224)
plt.plot(df_index[train_end:-500], generate_data_test[-600:-500], label='test_data')
plt.plot(df_index[train_end:-500], df_all[train_end:-500], label='real_data')
plt.legend(), plt.title('predict results')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/182486.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于android的新闻阅读系统

需求信息&#xff1a; 从模块的角度将APP的主要内容划分为登录模块、新闻模块、留言模块、报道模块、关注模块、语音模块这六个功能模块&#xff0c;完成以下功能&#xff1a; &#xff08;1&#xff09;登录模块 当用户打开应用程序后&#xff0c;如果直接登录&#xff0c;由于…

ElasticSearch - DSL查询语法

目录 DSL查询分类 全文检索查询 精确查询 地理查询 复合查询 相关性算分 算分函数查询 BooleanQuery DSL查询分类 Elasticsearch提供了基于JSON的DSL(Domain Specific Language)来定义查询常见的查询类型包括&#xff1a; 查询所有&#xff1a;查询出所有的数据&#x…

Rust库交叉编译以及在Android与iOS中使用

本篇是关于交叉编译Rust库&#xff0c;生成Android和iOS的二进制文件&#xff08;so与a文件&#xff09;&#xff0c;以及简单的集成使用。 1.环境 系统&#xff1a;macOS 13.0 M1 Pro&#xff0c;Windows 10 Python: 3.9.6 Rust: 1.66.1 NDK: 21.4.7075529 这里就不具体说…

收藏贴!新手到底应该购买Salesforce专业版还是企业版?

Salesforce专业版&#xff08;Professional Edition&#xff09;是一个适用于小型企业的工具&#xff0c;它具有完整Salesforce套件的许多功能&#xff0c;但也有一些明显的限制。本篇文章将具体阐明Salesforce专业版是什么&#xff0c;它的优势以及其与企业版&#xff08;Ente…

SQL Server 2008如何创建定期自动备份任务

我们知道&#xff0c;利用SQL Server 2008数据库可以实现数据库的定期自动备份。方法是用SQL SERVER 2008自带的维护计划创建一个计划对数据库进行备份&#xff0c;下面我们将SQL SERVER 2008定期自动备份的方法分享给大家。 首先需要启动SQL Server Agent服务&#xff0c;这个…

Python实现vlog生成器

Python实现vlog生成器 vlog&#xff0c;全称为Video blog&#xff0c;意为影音博客&#xff0c;也有翻译为微录。 本文将尝试用Python基于Moviepy从一个文本文件中自动生成一个视频格式的vlog&#xff0c;实现的功能如下&#xff1a; 将文件的第一行标题生成视频的片头将文件…

C++——红黑树

目录 红黑树介绍 红黑树实现 节点的插入 完整代码 红黑树介绍 红黑树&#xff0c;是一种二叉搜索树&#xff0c;但在每个结点上增加一个存储位表示结点的颜色&#xff0c;可以是Red或Black。 通过对任何一条从根到叶子的路径上各个结点着色方式的限制&#xff0c;红黑树确…

5 使用pytorch实现线性回归

文章目录前提的python的知识补充基本流程准备数据构造计算图loss以及oprimizer循环训练课程代码课程来源&#xff1a; 链接课程内容参考&#xff1a; 链接以及&#xff08;强烈推荐&#xff09;Birandaの前提的python的知识补充 pytorch 之 call, init,forward pytorch系列nn.…

Python进行因子分析

1 因子分析 1.1 定义 因子分析法(Factor Analysis)是一种利用降维的思想&#xff0c;从研究原始变量相关矩阵内部的依赖关系出发&#xff0c;把一些具有错综复杂关系的变量归结为少数几个综合因子的一种多变量统计分析方法。其优势在于不仅可以在减少大量指标分析的工作量的同…

尚硅谷hello scala-配置idea2022.1.2版本创建scala2.11.8版本maven文件

0_前置说明 软件版本 idea2022.1.2 scala2.11.8 java1.8.0_144 尚硅谷资源下载 关注b站尚硅谷 idea资源 百度网盘&#xff1a;https://pan.baidu.com/s/1Gbavx34OfF29LZqJ8dc85g?pwdyyds 提取码: yyds B站直达&#xff1a;​https://www.bilibili.com/video/BV1CK411d7a…

LabVIEW NI Linux Real-Time深层解析

LabVIEW NI Linux Real-Time深层解析NI LabVIEW Real-Time模块支持NI Linux Real-Time操作系统&#xff0c;在选定的NI硬件上提供。本文介绍了具体的新特性和高级功能&#xff0c;可让您为应用充分利用NI Linux Real-Time。Linux Shell支持NI Linux Real-Time操作系统提供了全面…

《Linux0.11源码趣读》学习笔记day6

到上次记录&#xff0c;整个操作系统的全部代码就已经从硬盘加载到内存中了&#xff0c;然后这些代码又通过jmpi跳转到0x90200处&#xff0c;即硬盘第二个扇区开始处的内容 这些内容就是第二个操作系统源代码文件setup.s 不过现在先来看一下操作系统的编译过程 操作系统的编译…

后端学习 - Docker

文章目录基本概念三个核心概念&#xff1a;镜像、容器、仓库联合文件系统 UnionFS常用命令Docker File基本概念 一次配置&#xff0c;处处使用运行在同一宿主机上的容器是相互隔离的&#xff0c;各自拥有独立的文件系统容器模型和虚拟机模型的主要区别 相较于虚拟机而言&#…

【Pytorch项目实战】之生成式网络:编码器-解码器、自编码器AE、变分自编码器VAE、生成式对抗网络GAN

文章目录生成式网络 - 生成合成图像算法一&#xff1a;编码器-解码器算法二&#xff1a;自编码器&#xff08;Auto-Encoder&#xff0c;AE&#xff09;算法三&#xff1a;变分自编码器&#xff08;Variational Auto Encoder&#xff0c;VAE&#xff09;算法四&#xff1a;生成式…

九型人格是什么?

九型人格是什么? 九型人格学(Enneagram/Ninehouse)是一个有2000多年历史的古老学问,它按照人们习惯性的思维模式,情绪反应和行为习惯等性格特质,将人的性格分为九种,又被称为九柱图,起源于中亚西亚地区,和中国的八卦图有点像,近代的九型是由六十年代智利的一位心理学…

计算机组成原理 | 第四章:存储器 | 存储器与CPU连接 | 存储器的校验 | Cache容量计算

文章目录&#x1f4da;概述&#x1f407;存储器分类&#x1f407;存储器的层次结构&#x1f955;原理&#x1f955;主存速度慢的原因&#x1f955;存储器三个主要特征的关系&#x1f955;缓存-主存层次和主存-辅存层次⭐️&#x1f4da;主存储器&#x1f407;概述&#x1f955;…

【opencv】Haar分类器及Adaboost算法人脸识别理论讲解

提到opencv,就不得不提其图像识别能力,最近旷世开源的YoloX项目兴起,作为目前Yolo系列中的最强者,本人对其也很感兴趣,但是完全没用机器学习和计算机视觉的基础,知其然,不知其所以然,于是想稍稍入坑一下opencv图像识别,了解一下相关算法,(说不定以后毕设会用到呢)。…

磨金石教育影视干货分享|朋友亲身经历—给新人剪辑师的三个建议

大学的时候有一个同学很喜欢视频剪辑。平时没事就蹲在电脑前&#xff0c;下载一些素材&#xff0c;自学剪辑软件&#xff0c;慢慢的搞一些创意剪辑。那时候自媒体短视频已经很火爆&#xff0c;这位同学剪辑的视频&#xff0c;不管质量如何就往上面发。一开始我们对于新事物的认…

Java---微服务---分布式搜索引擎elasticsearch(2)

分布式搜索引擎elasticsearch&#xff08;2&#xff09;1.DSL查询文档1.1.DSL查询分类1.2.全文检索查询1.2.1.使用场景1.2.2.基本语法1.2.3.示例1.2.4.总结1.3.精准查询1.3.1.term查询1.3.2.range查询1.3.3.总结1.4.地理坐标查询1.4.1.矩形范围查询1.4.2.附近查询1.5.复合查询1…

SpringBoot+Vue项目学生读书笔记共享平台

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7/8.0 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.3.9 浏…