用通俗易懂的方式讲解:LSTM原理及生成藏头诗(Python)

news2025/4/14 7:36:00

一、基础介绍

1.1 神经网络模型

常见的神经网络模型结构有前馈神经网络(DNN)、RNN(常用于文本 / 时间系列任务)、CNN(常用于图像任务)等等。

前馈神经网络是神经网络模型中最为常见的,信息从输入层开始输入,每层的神经元接收前一级输入,并输出到下一级,直至输出层。整个网络信息输入传输中无反馈(循环)。即任何层的输出都不会影响同级层,可用一个有向无环图表示。图片

1.2 RNN 介绍

循环神经网络(RNN)是基于序列数据(如语言、语音、时间序列)的递归性质而设计的,是一种反馈类型的神经网络,它专门用于处理序列数据,如逐字生成文本或预测时间序列数据(例如股票价格、诗歌生成)。图片

RNN和全连接神经网络的本质差异在于“输入是带有反馈信息的”,RNN除了接受每一步的输入x(t) ,同时还有输入上一步的历史反馈信息——隐藏状态h (t-1) ,也就是当前时刻的隐藏状态h(t) 或决策输出O(t) 由当前时刻的输入 x(t) 和上一时刻的隐藏状态h (t-1) 共同决定。从某种程度,RNN和大脑的决策很像,大脑接受当前时刻感官到的信息(外部的x(t) )和之前的想法(内部的h (t-1) )的输入一起决策。

图片

RNN的结构原理可以简要概述为两个公式

RNN的隐藏状态为:h(t) = f( U * x(t) + W * h(t-1) + b1), f为激活函数,常用tanh、relu;

RNN的输出为:o(t) = g( V * h(t) + b2),g为激活函数,当用于分类任务,一般用softmax;

1.3 从RNN到LSTM

但是在实际中,RNN在长序列数据处理中,容易导致梯度爆炸或者梯度消失,也就是长期依赖(long-term dependencies)问题,其根本原因就是模型“记忆”的序列信息太长了,都会一股脑地记忆和学习,时间一长,就容易忘掉更早的信息(梯度消失)或者崩溃(梯度爆炸)。

梯度消失:历史时间步的信息距离当前时间步越长,反馈的梯度信号就会越弱(甚至为0)的现象,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。

改善措施:可以使用 ReLU 激活函数;门控RNN 如GRU、LSTM 以改善梯度消失。

梯度爆炸:网络层之间的梯度(值大于 1)重复相乘导致的指数级增长会产生梯度爆炸,导致模型无法有效学习。

改善措施:可以使用 梯度截断;引导信息流的正则化;ReLU 激活函数;门控RNN 如GRU、LSTM(和普通 RNN 相比多经过了很多次导数都小于 1激活函数,因此 LSTM 发生梯度爆炸的频率要低得多)以改善梯度爆炸。

所以,如果我们能让 RNN 在接受上一时刻的状态和当前时刻的输入时,有选择地记忆和遗忘一部分内容(或者说信息),问题就可以解决了。比如上上句话提及”我去考试了“,然后后面提及”我考试通过了“,那么在此之前说的”我去考试了“的内容就没那么重要,选择性地遗忘就好了。这也就是长短期记忆网络(Long Short-Term Memory, LSTM)的基本思想。

二、LSTM原理

LSTM是种特殊RNN网络,在RNN的基础上引入了“门控”的选择性机制,分别是遗忘门、输入门和输出门,从而有选择性地保留或删除信息,以能够较好地学习长期依赖关系。如下图RNN(上) 对比 LSTM(下):

图片

2.1 LSTM的核心

在RNN基础上引入门控后的LSTM,结构看起来好复杂!但其实LSTM作为一种反馈神经网络,核心还是历史的隐藏状态信息的反馈,也就是下图的Ct:图片对标RNN的ht隐藏状态的更新,LSTM的Ct只是多个些“门控”删除或添加信息到状态信息。由下面依次介绍LSTM的“门控”:遗忘门,输入门,输出门的功能,LSTM的原理也就好理解了。

2.2 遗忘门

LSTM 的第一步是通过"遗忘门"从上个时间点的状态Ct-1中丢弃哪些信息。

具体来说,输入Ct-1,会先根据上一个时间点的输出ht-1和当前时间点的输入xt,并通过sigmoid激活函数的输出结果ft来确定要让Ct-1,来忘记多少,sigmoid后等于1表示要保存多一些Ct-1的比重,等于0表示完全忘记之前的Ct-1。图片

2.3 输入门

下一步是通过输入门,决定我们将在状态中存储哪些新信息。

我们根据上一个时间点的输出ht-1和当前时间点的输入xt 生成两部分信息i t 及C~t,通过sigmoid输出i t,用tanh输出C~t。之后通过把i t 及C~t两个部分相乘,共同决定在状态中存储哪些新信息。图片

在输入门 + 遗忘门控制下,当前时间点状态信息Ct为:

图片

2.4 输出门

最后,我们根据上一个时间点的输出ht-1和当前时间点的输入xt 通过sigmid 输出Ot,再根据Ot 与 tanh控制的当前时间点状态信息Ct 相乘作为最终的输出。图片

综上,一张图可以说清LSTM原理:图片

三、LSTM简单写诗

本节项目利用深层LSTM模型,学习大小为10M的诗歌数据集,自动可以生成诗歌。图片

如下代码构建LSTM模型。如需完整代码,文末获取

图片

model = tf.keras.Sequential([
    # 不定长度的输入
    tf.keras.layers.Input((None,)),
    # 词嵌入层
    tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),
    # 第一个LSTM层,返回序列作为下一层的输入
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    # 第二个LSTM层,返回序列作为下一层的输入
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    # 对每一个时间点的输出都做softmax,预测下一个词的概率
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
])

# 查看模型结构
model.summary()
# 配置优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)

模型训练,考虑训练时长,就简单训练2个epoch。图片

class Evaluate(tf.keras.callbacks.Callback):
    """
    训练过程评估,在每个epoch训练完成后,保留最优权重,并随机生成SHOW_NUM首古诗展示
    """

    def __init__(self):
        super().__init__()
        # 给loss赋一个较大的初始值
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # 在每个epoch训练完成后调用
        # 如果当前loss更低,就保存当前模型参数
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save(BEST_MODEL_PATH)
        # 随机生成几首古体诗测试,查看训练效果
        print("cun'h")
        for i in range(SHOW_NUM):
            print(generate_acrostic(tokenizer, model, head="春花秋月"))

# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=TRAIN_EPOCHS,
                    callbacks=[Evaluate()])

加载简单训练的LSTM模型,输入关键字(如:算法进阶)后,自动生成藏头诗。可以看出诗句粗略看上去挺优雅,但实际上经不起推敲。后面增加训练的epoch及数据集应该可以更好些。

# 加载训练好的模型
model = tf.keras.models.load_model(BEST_MODEL_PATH)

keywords = input('输入关键字:\n')


# 生成藏头诗
for i in range(SHOW_NUM):
    print(generate_acrostic(tokenizer, model, head=keywords),'\n')

图片

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了NLP面试与技术交流群, 想要进交流群、需要本文源码、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、添加微信号:mlc2060,备注:技术交流
方式②、微信搜索公众号:机器学习社区,后台回复:技术交流

资料
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1361704.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot 调用mybatis报错:Invalid bound statement (not found):

启动SpringBoot报错&#xff1a;Invalid bound statement (not found): 参考此文排查 命中了第6条 记录一手坑爹的Invalid bound statement (not found)&#xff08;六个方面&#xff09; mapper文件路径配置错误 订正以后 问题解决

项目从npm迁移到pnpm

场景如下&#xff1a;在安装Vue3时默认为使用Npm安装&#xff0c;如图所示&#xff1a; 安装完后项目就包含了基于NPM的node_modules、package.json&#xff0c;以及package-lock.json 如果想使用pnpm去安装依赖项的话&#xff0c;可以通过如下几个步骤实现&#xff1a; ①删…

基于R语言(SEM)结构方程模型教程

详情点击链接&#xff1a;基于R语言&#xff08;SEM&#xff09;结构方程模型教程 01、R/Rstudio (2)R语言基本操作&#xff0c;包括向量、矩阵、数据框及数据列表等生成和数据提取等 (3)R语言数据文件读取、整理&#xff08;清洗&#xff09;、结果存储等&#xff08;含tidve…

JAVA中小型医院信息管理系统源码 医院系统源码

开发框架&#xff1a;SpringBootJpathymeleaf 搭建环境&#xff1a;jdk1.8idea/eclipsemaven3mysql5.6 基于SpringBoot的中小型医院信息管理系统&#xff0c;做的比较粗糙&#xff0c;但也实现了部分核心功能。 就诊卡提供了手动和读卡两种方式录入&#xff0c;其中IC读卡器使用…

跟随chatgpt从零开始安装git(Windows系统)

为什么我们要安装Git&#xff1f;Git有什么用&#xff1f; 1. 版本控制&#xff1a;Git 可以追踪代码的所有变化&#xff0c;记录每个提交的差异&#xff0c;使您能够轻松地回溯到任何历史版本或比较不同版本之间的差异。 2. 分支管理&#xff1a;通过 Git 的分支功能&#xff…

【C语言:可变参数列表】

文章目录 1.什么是可变参数列表2.可变参数列表的分析与使用2.1使用2.2分析原理2.3分析原码 1.什么是可变参数列表 对于一般的函数而言&#xff0c;参数列表都是固定的&#xff0c;而且各个参数之间用逗号进行分开。这种函数在调用的时候&#xff0c;必须严格按照参数列表中参数…

云卷云舒:【实战篇】Redis迁移

1. 简介 Remote Dictionary Server(Redis)是一个由Salvatore Sanfilippo写的key-value存储系统&#xff0c;是一个开源的使用ANSIC语言编写、遵守BSD协议、支持网络、可基于内存亦可持久化的日志型、Key-Value数据库&#xff0c;并提供多种语言的API。 2. 迁移原理 redis-sh…

src refspec master does not match any

新项目推送至 Git 空仓库时抛出如下异常 src refspec master does not match any 初始化 init 都做了但反复尝试 git push -u origin master 均无果 后发现权限不够 .... 起初设置为开发者,后变更为了主程序员再次尝试 push 成功 .... 以上便是此次分享的全部内容&#xff0c;…

Pandas数据可视化

pandas库是Python数据分析的核心库 它不仅可以加载和转换数据&#xff0c;还可以做更多的事情&#xff1a;它还可以可视化 pandas绘图API简单易用&#xff0c;是pandas流行的重要原因之一 Pandas 单变量可视化 单变量可视化&#xff0c; 包括条形图、折线图、直方图、饼图等 …

记一次服务器被入侵的排查过程

起因 阿里云安全中心报告了告警信息&#xff0c;同时手机短信、邮件、电话也接收到了来自阿里云的风险通知&#xff0c;感觉这方面阿里云还是不错。 排查及解决过程 这条wget指令究竟是怎么被运行的 我无法定位到攻击人员是通过什么样的方式让我的java程序执行了wget这条指…

打造清晰的日志管理策略:如何在 NestJS 中集成 winston 高级日志系统

前言 在Web应用程序的开发过程中&#xff0c;日志管理是不可或缺的一部分。日志可以帮助我们了解应用程序的运行状态&#xff0c;监控系统行为&#xff0c;以及在出现问题时快速定位和解决问题。 对于使用NestJS框架的项目来说&#xff0c;集成一个高效、可扩展的日志系统尤为…

Java:IO流详解

文章目录 基础流1、IO概述1.1 什么是IO1.2 IO的分类1.3 顶级父类们 2、字节流2.1 一切皆为字节2.2 字节输出流 OutputStream2.3 FileOutputStream类2.3.1 构造方法2.3.2 写出字节数据2.3.3 数据追加续写2.3.4 写出换行 2.4 字节输入流 InputStream2.5 FileInputStream类2.5.1 构…

RFM会员价值度模型

模型基本原理 会员价值度用来评估用户的价值情况&#xff0c;是区分会员价值的重要模型和参考依据&#xff0c;也是衡量不同营销效果的关键指标。 价值度模型一般基于交易行为产生&#xff0c;衡量的是有实体转化价值的行为。常用的价值度模型是RFM RFM模型是根据会员 最近…

C#,简单选择排序算法(Simple Select Sort)的源代码与数据可视化

排序算法是编程的基础。 常见的四种排序算法是&#xff1a;简单选择排序、冒泡排序、插入排序和快速排序。其中的快速排序的优势明显&#xff0c;一般使用递归方式实现&#xff0c;但遇到数据量大的情况则无法适用。实际工程中一般使用“非递归”方式实现。本文搜集发布四种算法…

8 单链表---带表头节点

上节课所学的顺序表的缺点 顺序表的最大问题&#xff1a;插入和删除时需要移动大量元素 链式存储的定义 链式存储的逻辑结构 链表中的基本概念&#xff1a; 注意&#xff1a;表头节点并不属于数据元素 单链表图示&#xff1a; 把3个需要的结构体定义出来&#xff1a; typdef …

《网络是怎样连接的》2.3节图表(自用)

图4.1&#xff1a;TCP拆分数据与ACK号 图4.2&#xff1a;实际工作中ACK号与序号的交互过程 首先&#xff0c;客户端在连接时需要计算出与从客户端到服务器方向通信相关的序号初始值&#xff0c;并将这个值发送给服务器&#xff08;①&#xff09;。 接下来&#xff0c;服务器会…

MySQL学习笔记2: MySQL的前置知识

目录 1. MySQL是什么?2. 什么是客户端&#xff0c;什么是服务器&#xff1f;3. 服务器的特点4. 安装mysql5. mysql 客户端6. mysql 服务器7. mysql的本体8. MySQL 使用什么来存储数据&#xff1f;9. 数据库的多种含义10. MySQL 存储数据的组织方式 1. MySQL是什么? MySQL 是…

新手养布偶猫如何选择猫主食冻干?K9、sc、希喂三个品牌推荐!

布偶猫是食肉动物&#xff0c;但由于肠胃脆弱敏感&#xff0c;所以在饮食上需要特别关注哦&#xff01;为了给它们最好的呵护&#xff0c;现在有了主食冻干这种优质猫主食&#xff01;它不仅符合猫咪的天然饮食习惯&#xff0c;还用了新鲜生肉做原料呢&#xff01;营养满分不说…

如何设计企业级业务流程?学习华为的流程六级分类经验

业务流程管理&#xff08;BPM&#xff09;是一种系统化的方法&#xff0c;用于分析、设计、执行、监控和优化组织的业务流程&#xff0c;以实现预期的目标和价值。业务流程管理中&#xff0c;流程的分级方法有多种&#xff0c;常见的有以下几种&#xff1a; APQC的流程分级方法…

Agilent安捷伦E4407B频谱分析仪26.5GHz

E4407B是安捷伦ESA-E系列频谱分析仪&#xff0c;它是一款能够适应未来需要的中性能频谱分析仪解决方案。该系列在测量速度、动态范围、精度和功率分辨能力上&#xff0c;都为类似价位的产品建立了性能标准。其灵活的平台设计使得研发、制造和现场服务工程师能够自定义产品&…