【探索AI】二十一 深度学习之第4周:循环神经网络(RNN)与长短时记忆(LSTM)

news2024/11/15 11:07:15

循环神经网络(RNN)与长短时记忆(LSTM)

    • RNN的基本原理与结构
    • LSTM的原理与实现
    • 序列建模与文本生成任务
    • 实践:使用RNN或LSTM进行文本分类或生成任务
      • 步骤 1: 数据准备
      • 步骤 2: 构建模型
      • 步骤 3: 定义损失函数和优化器
      • 步骤 4: 训练模型
      • 步骤 5: 评估模型
      • 步骤 6: 使用模型进行文本生成

RNN的基本原理与结构

RNN,即循环神经网络(Recurrent Neural Network),是一种专门用于处理序列数据的神经网络。它的基本原理和结构主要基于以下几点:

基本原理:RNN的基本原理是,序列中的每个元素都与其前后元素存在某种关联或依赖,这种关联或依赖就是序列的时序关系。RNN通过捕捉并记忆这种时序关系,实现对序列数据的建模。RNN不是刚性地记忆所有固定长度的序列,而是通过隐藏状态来存储之前时间步的信息。
结构特点:RNN的结构特点主要体现在其循环性上。RNN的每个神经元不仅接收当前时刻的输入,还接收上一时刻的输出,并将其作为当前时刻的输入。这种结构使得RNN能够处理具有时序关系的数据,并且在处理过程中,RNN会不断地将之前时刻的信息传递到当前时刻,从而实现对序列数据的建模。
权重共享:RNN的另一个重要特点是权重共享。在RNN中,每个时刻的神经元都使用相同的权重,这意味着RNN在处理不同时刻的数据时,使用的是相同的参数。这种权重共享的方式大大减少了RNN的参数量,使得RNN能够更有效地处理序列数据。
综上所述,RNN的基本原理和结构使得它能够有效地处理具有时序关系的数据,实现对序列数据的建模和预测。这使得RNN在自然语言处理、时间序列分析等领域具有广泛的应用前景。

LSTM的原理与实现

LSTM(长短期记忆)是一种特殊的循环神经网络(RNN),设计用于解决传统RNN在处理长期依赖关系时遇到的问题。LSTM通过引入“门”的概念和细胞状态来实现这一点。

  1. 细胞状态与水平线:LSTM的关键在于细胞状态,它类似于传送带,直接在整个链上运行。这个状态只有少量的线性交互,因此信息在上面流传保持不变会很容易。

  2. 门结构:为了实现信息的添加或删除,LSTM使用了一种叫做“门”的结构。门可以实现选择性地让信息通过,这主要通过一个sigmoid神经层和一个逐点相乘的操作来实现。sigmoid层输出的每个元素都是一个在0和1之间的实数,表示让对应信息通过的权重。例如,0表示“不让任何信息通过”,1表示“让所有信息通过”。

  3. 三个门:LSTM通过三个这样的门结构来实现信息的保护和控制,分别是输入门、遗忘门和输出门。

    • 遗忘门:LSTM的第一步是决定从细胞状态中丢弃什么信息。这个决定通过一个称为忘记门层完成。该门会读取上一个时刻的隐藏状态 h t − 1 h_{t-1} ht1和当前时刻的输入 x t x_t xt,输出一个在0到1之间的数值给每个在细胞状态 C t − 1 C_{t-1} Ct1中的数字。1表示“完全保留”,0表示“完全舍弃”。
    • 输入门:负责处理当前时刻的输入,决定哪些信息需要被存储在细胞状态中。它包含两个步骤:首先,一个sigmoid层决定哪些信息需要更新;其次,一个tanh层生成新的候选值向量,这些值可能会被添加到状态中。
    • 输出门:基于细胞状态来决定当前的输出。它首先通过sigmoid层来决定细胞状态的哪些部分将输出到LSTM的当前输出值;然后,将细胞状态通过tanh进行处理(得到一个在-1到1之间的值),再与sigmoid门的输出相乘,从而得到最终的输出。
  4. 实现:在实现LSTM时,通常使用深度学习框架(如TensorFlow、PyTorch等)来构建网络。这些框架提供了高级的API,使得构建和训练LSTM模型变得相对简单。在实现过程中,需要定义网络结构、损失函数、优化器等,并进行模型的训练和评估。

总的来说,LSTM通过引入细胞状态和门结构,有效地解决了传统RNN在处理长期依赖关系时遇到的问题。这使得LSTM在许多序列处理任务中取得了显著的成果,如语音识别、机器翻译、情感分析等。

序列建模与文本生成任务

序列建模与文本生成任务是自然语言处理(NLP)领域中的两个重要概念。

序列建模是指在给定一组输入序列的情况下,预测或生成相应的输出序列。在机器学习和自然语言处理领域,序列建模问题被广泛应用。例如,在语音识别任务中,根据音频信号的输入序列预测对应的语音文本序列;在机器翻译任务中,根据源语言的输入序列生成目标语言的输出序列。序列建模面临着许多挑战,如数据稀疏性、长距离依赖、多模态输入等。为了解决这些问题,研究者们提出了许多有效的方法,如循环神经网络(RNN)、长短时记忆网络(LSTM)、门控循环单元(GRU)等。

文本生成任务是指通过计算机算法和模型,以一定的策略和规则生成特定领域的文本内容。文本生成任务在多个领域都有广泛的应用,如机器翻译、文本摘要、文本生成等。这些任务可以通过深度学习技术,如递归神经网络(RNN)特别是长短期记忆网络(LSTM)和门控循环单元(GRU)等生成模型来实现。生成模型根据输入信息生成文本,而训练数据则用于训练这些模型。损失函数用于评估模型的预测性能,并指导模型的优化。

综上所述,序列建模与文本生成任务是自然语言处理领域中两个密切相关的概念。序列建模为文本生成提供了基础和支持,而文本生成任务则是序列建模的一个重要应用领域。随着深度学习技术的不断发展,序列建模与文本生成任务将在更多的领域发挥重要作用。

实践:使用RNN或LSTM进行文本分类或生成任务

实践使用RNN或LSTM进行文本分类或生成任务需要一系列步骤。下面我将提供一个简单的指导,使用PyTorch库进行文本分类任务。请注意,为了执行此实践,您需要具备Python编程和PyTorch库的基础知识。

步骤 1: 数据准备

首先,您需要准备用于训练和测试的数据集。数据集应该包含文本数据和相应的标签。您可以将文本数据预处理为适合RNN或LSTM模型的形式。

# 导入必要的库
import torch
from torchtext.legacy import data
from torchtext.legacy import datasets

# 定义字段和文本预处理器
TEXT = data.Field(sequential=True, tokenize='spacy', lower=True)
LABEL = data.LabelField(dtype=torch.float)

# 创建数据管道
fields = [('text', TEXT), ('label', LABEL)]
train_data, test_data = datasets.TabularDataset.splits(
    path='.', train='train.csv', test='test.csv', format='csv',
    skip_header=True, fields=fields
)

# 构建词汇表
TEXT.build_vocab(train_data, max_size=25000, vectors="glove.6B.100d", unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

# 创建数据迭代器
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, test_iterator = data.BucketIterator.splits((train_data, test_data), batch_size=batch_size, device=device)

步骤 2: 构建模型

接下来,您需要定义RNN或LSTM模型。下面是一个简单的LSTM模型的例子。

import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super(LSTMModel, self).__init__()
        
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, text):
        embedded = self.embedding(text)
        output, (hidden, cell) = self.lstm(embedded)
        assert torch.equal(hidden[-1,:,:], cell[-1,:,:])
        return self.fc(hidden[-1,:,:])

input_dim = len(TEXT.vocab)
embedding_dim = 100
hidden_dim = 256
output_dim = 1  # 假设是二分类问题

model = LSTMModel(input_dim, embedding_dim, hidden_dim, output_dim)

步骤 3: 定义损失函数和优化器

选择适当的损失函数和优化器。对于分类任务,通常使用交叉熵损失(CrossEntropyLoss)和Adam优化器。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

步骤 4: 训练模型

现在,您可以开始训练模型了。

num_epochs = 10

for epoch in range(num_epochs):
    for batch in train_iterator:
        optimizer.zero_grad()
        predictions = model(batch.text).squeeze(1)
        loss = criterion(predictions, batch.label)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch: {epoch+1:02}, Loss: {loss.item():.4f}')

步骤 5: 评估模型

在训练完成后,您可以使用测试数据集评估模型的性能。

model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for batch in test_iterator:
        predictions = model(batch.text).squeeze(1)
        predicted = torch.argmax(predictions, dim=1)
        correct += (predicted == batch.label).sum().item()
        total += batch.label.size(0)

print(f'Accuracy of the network on the test data: {100 * correct / total:.2f}%')

步骤 6: 使用模型进行文本生成

在步骤6中,我们将使用之前训练好的模型来进行文本生成。这通常涉及到将模型设置为评估模式(evaluation mode),然后提供一个初始的文本片段作为种子(seed),让模型从这个种子开始生成后续的文本。

以下是一个使用PyTorch和LSTM模型进行文本生成的例子:

import torch
from torch import nn
from torch.autograd import Variable
from model import LSTM, vocab  # 假设你有一个名为LSTM的模型和名为vocab的词汇表

# 假设我们有一个训练好的LSTM模型
model = LSTM(vocab_size=len(vocab), embedding_dim=256, hidden_dim=512, num_layers=2)
model.load_state_dict(torch.load('path_to_saved_model.pt'))  # 加载预训练模型
model.eval()  # 将模型设置为评估模式

# 设置超参数
max_length = 100  # 生成文本的最大长度
starting_text = "I enjoy"  # 初始文本种子
temperature = 1.0  # 控制生成文本多样性的参数(softmax的温度参数)

# 将初始文本转换为模型可以理解的格式
starting_text = starting_text.lower().replace('.', ' .')  # 将所有文本转换为小写,并在句末添加空格
starting_text = [vocab.stoi[word] for word in starting_text.split()]  # 将单词转换为索引
starting_text = torch.LongTensor(starting_text).to(device)  # 转换为PyTorch张量并移到设备上

# 初始化隐藏状态
hidden = model.init_hidden(starting_text.size(0))

# 开始生成文本
generated_text = []
for i in range(max_length):
    # 获取模型的输出
    output, hidden = model(starting_text, hidden)
    
    # 使用softmax函数获取预测单词的概率分布
    predicted = torch.multinomial(output.squeeze(1), num_samples=1)
    
    # 根据概率分布选择一个单词
    predicted_index = predicted.item()
    
    # 将预测的单词添加到生成的文本中
    generated_text.append(predicted_index)
    
    # 将预测的单词作为下一个输入
    starting_text = Variable(torch.LongTensor([predicted_index]).to(device))
    
    # 如果生成的单词是结束标记,则停止生成
    if vocab.itos[predicted_index] == '<eos>':
        break

# 将生成的索引转换为文本
generated_text = [vocab.itos[idx] for idx in generated_text]
generated_sentence = ' '.join(generated_text)

# 打印生成的文本
print(generated_sentence)

在上面的代码中,我们首先将预训练的模型加载到内存中,并将其设置为评估模式。然后,我们定义了生成文本的超参数,包括最大长度和初始文本种子。

我们将初始文本转换为模型可以理解的格式,即单词的索引序列。然后,我们初始化模型的隐藏状态,并开始循环生成文本。在每个循环中,我们将当前生成的文本作为输入传递给模型,并获取模型的输出。我们使用torch.multinomial函数根据模型的输出概率分布选择一个单词,并将其添加到生成的文本中。如果生成的单词是结束标记(例如<eos>),则我们停止生成。

最后,我们将生成的索引序列转换回文本,并打印生成的文本。

请注意,生成的文本的质量和多样性取决于许多因素,包括模型的结构、训练数据的质量、训练时间以及超参数的选择(如温度参数)。在实际应用中,可能需要调整这些参数以获得最佳结果。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1487663.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数字电路三宝:锁存器、寄存器和触发器

在数字电路设计中&#xff0c;很多电子工程师经常会用到锁存器、寄存器和触发器&#xff0c;它们各自承担着不同的功能&#xff0c;但共同为数字电路的稳定性和高效性提供了坚强保障&#xff0c;下面将谈谈这三大元件&#xff0c;希望对小伙伴们有所帮助。 1、锁存器&#xff0…

HOOPS Communicator对3D大模型轻量化加载与渲染的4种解决方案

今天给大家介绍一些关于3D Web轻量化引擎HOOPS Commuicator的关键概念&#xff0c;这些概念可以帮您在HOOPS Communicator流缓存服务器之上更好地构建您自己的模型流服务器。如果您是有大型数据集&#xff0c;那么&#xff0c;使用流缓存服务器可以极大地帮助您最大限度地减少内…

EthSign联合创始人 POTTER LI 确认出席Hack .Summit() 香港区块链开发者大会!

thSign联合创始人 POTTER LI确认将出席由 Hack VC 主办&#xff0c;并由 AltLayer 和 Berachain 联合主办&#xff0c;与 SNZ 和数码港合作&#xff0c;由 Techub News 承办的Hack.Summit() 2024区块链开发者盛会。 Potter Li&#xff0c;南加州大学应有数学系&#xff0c;南加…

hook函数——useReducer

目录 1.useReducer定义2.useReducer用法3.useState和useReducer区别 1.useReducer定义 const [state, dispatch] useReducer(reducer, initialArg, init?) reducer&#xff1a;用于更新 state 的纯函数。参数为 state 和 action&#xff0c;返回值是更新后的 state。state …

excel统计分析——拉丁方设计

参考资料&#xff1a;生物统计学 拉丁方设计也是随机区组设计&#xff0c;是对随机区组设计的一种改进。它在行的方向和列的方向都可以看成区组&#xff0c;因此能实现双向误差的控制。在一般的试验设计中&#xff0c;拉丁方常被看作双区组设计&#xff0c;用于提高发现处理效应…

新《公司法》规定5年内完成注册资本实缴有哪些影响

2024年对很多企业可谓是一个洗牌的年份。随着新公司法的颁布&#xff0c;很多企业都忧心忡忡面临着各种挑战。其中新《公司法》规定5年内完成注册资本实缴就让很多企业老板睡不着觉。新《公司法》规定注册资本实缴制度将对市场和企业产生一系列影响。主要有以下这几方面&#x…

【Java项目介绍和界面搭建】拼图小游戏——键盘、鼠标事件

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【Java】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收藏 …

【javascript】快速入门javascript

本文前言及说明 适合学过一门语言有一定基础的人看。 省略最初学习编程时的各种编程重复的基础知识。 javascript简介 编程语言&#xff08;主前端&#xff09; 用途&#xff1a;主web前后端&#xff0c;游戏&#xff0c;干别人网站 优点&#xff1a;速度快&#xff0c;浏…

Python之Web开发初学者教程----卸载ubuntu系统

Python之Web开发初学者教程----卸载ubuntu系统 Windows 10自带了Subsytem for Linux (WSL)功能&#xff0c;可以让用户在Windows命令行环境下运行Linux命令。用户可以在Windows应用商店中下载和安装Ubuntu子系统&#xff0c;有时在使用过程中需要完全删除Ubuntu子系统以释放硬…

Go语言学习-实现一个workshop

Creating new Go packages 1、创建一个Go package&#xff0c;叫&#xff1a; MyLib • Let’s create a Go package called MyLib and use it in our program 2、在go_project文件夹下开启终端&#xff0c;输入指令创建go.mod文件。 go mod init go_project• Assuming our…

【HTML】HTML基础6.1(表格以及常见属性)

目录 表格介绍 表格标签 表格标签的常见属性 案例 知识点总结 表格介绍 在浏览器中&#xff0c;我们经常见到形如 这样的表格形式&#xff0c;一般来说&#xff0c;表格是为了让数据看起来更加清晰&#xff0c;增强数据的可读性 有的程序员也会用表格进行排版 表格标签 &…

(UE4升级UE5)Selected Level Actor节点升级到UE5

本问所用工具为&#xff1a; UE5 UE4 插件AssetDeveTool包含&#xff1a;快速选择功能自动化批量LOD功能自动化批量展UV功能自动化批量减面功能自动化批量修改查找替换材质功能批量重命名工具碰撞器修改工具资源整理工具支持4.26 - 5.3版本https://mbd.pub/o/bread/mbd-ZZubkp…

ControlNet作者新作LayerDiffusion,让SD直接生成生成透明图像,堪比商用抠图软件

ControlNet作者又出新工作&#xff0c;这次的工作LayerDiffusion它使得大规模预训练的Stable Diffusion能够生成透明图像。该方法允许生成单个透明图像或多个透明图层&#xff0c;效果堪比商业产品Adobe Stock。而且LayerDiffusion和ControlNet一样支持基于SD微调的模型。 &quo…

Flutter的线程模型

在Flutter框架中&#xff0c;Embedder层负责把Flutter嵌入到各个平台上去&#xff0c;其所做的主要工作包括线程设置、渲染Surface设置&#xff0c;以及插件等。因此&#xff0c; Embedder负责线程的创建和管理&#xff0c;并且提供Task Runner给Engine使用。Engine则是负责提供…

钉钉h5应用 环境报错Error: Do not support the current environment:notInDingTalk

钉钉h5应用 环境报错 Error: Do not support the current environment&#xff1a;notInDingTalk problem Error: Do not support the current environment&#xff1a;notInDingTalk reason 前端页面运行在普通浏览器 solution 需要将h5页面在后台发布后&#xff0c;在钉…

Java中的日期时间类详解(建议收藏)!!!

Java中的日期时间类详解 1. LocalDate、LocalTime和LocalDateTime2. DateTimeFormatter3. 日期时间计算和比较4. **时区和日历**&#xff1a; 总结 本文详细解释了Java提供了 java.time 包来处理日期和时间的方式。 1. LocalDate、LocalTime和LocalDateTime LocalDate &#…

【HarmonyOS】鸿蒙开发之Stage模型-UIAbility的启动模式——第4.4章

UIAbi lity的启动模式简介 一共有四种:singleton,standard,specified,multion。在项目目录的:src/main/module.json5。默认开启模式为singleton(单例模式)。如下图 singleton&#xff08;单实例模式&#xff09;启动模式 每个UIAbility只存在唯一实例。任务列表中只会存在一…

【EI会议征稿通知】第四届控制与智能机器人国际学术会议(ICCIR 2024)

第四届控制与智能机器人国际学术会议&#xff08;ICCIR 2024&#xff09; 2024 4th International Conference on Control and Intelligent Robotics 第四届控制与智能机器人国际学术会议&#xff08;ICCIR 2024&#xff09;由华南理工大学自动化科学与工程学院主办&#xff…

【Android移动开发】helloworld项目文件剖析

本文讨论了一个Android应用的Gradle项目的各个方面。涵盖了Gradle的启动脚本&#xff0c;项目的配置文件&#xff08;如build.gradle和gradle.properties&#xff09;&#xff0c;以及应用的源代码和资源文件。具体内容包括了项目结构、Gradle插件的配置、AndroidManifest.xml文…

SSM框架,SpringMVC框架的学习(上)

SpringMVC介绍 Spring Web MVC是基于Servlet API构建的原始Web框架&#xff0c;从一开始就包含在Spring Framework中。正式名称“Spring Web MVC”来自其源模块的名称&#xff08; spring-webmvc &#xff09;&#xff0c;但它通常被称为“Spring MVC”。 SpringMVC涉及组件 …