自然语言处理入门6——RNN生成文本

news2025/4/18 16:29:59

一、文本生成

我们在前面的文章中介绍了LSTM,根据输入时序数据可以输出下一个可能性最高的数据,如果应用在文字上,就是根据输入的文字,可以预测下一个可能性最高的文字。利用这个特点,我们可以用LSTM来生成文本。输入一个单词,做embedding处理后,再输入到LSTM,会输出备选单词的得分,经过softmax得到概率,我们选择概率最高的单词作为输出。这种输出叫做确定性输出,如下图所示:

如果我们根据概率随机选择一个单词输出,这叫做概率性输出,类似于大语言模型中的temperature,当temperature高时,生成的变化越多,确定性越低。下面的代码是根据RNN模型生成文字的代码,其中start_id代表生成开始的第一个单词编号,skip_ids代表需要过滤的单词,比如空或者数字等等,sample_size代表采样大小,这里就是生成的单词的数量。

class RnnlmGen(Rnnlm):
    def generate(self, start_id, skip_ids=None, sample_size=100):
        word_ids = [start_id] # 最终的单词编号列表
        x = start_id # 第一个单词编号
        # 如果生成的单词列表长度还小于sample_size就继续生成
        while len(word_ids) < sample_size:
            # 转变shape便于输入模型
            x = np.array(x).reshape(1,1)
            # 根据模型预测输出单词得分
            score = self.predict(x)
            # 得到概率
            p = softmax(score.flatten())
            # 根据概率选择输出的单词编号
            sampled = np.random.choice(len(p), size=1, p=p)
            if (skip_ids is None) or (sampled not in skip_ids):
                x = sampled
                word_ids.append(int(x))
        return word_ids
        
corpus, word_to_id, id_to_word = load_data('train')
vocab_size = len(word_to_id)
corpus_size = len(corpus)
model = RnnlmGen()
# 设定start单词和skip单词
start_word = 'you'
start_id = word_to_id[start_word]
skip_words = ['N','<unk>','$']
skip_ids = [word_to_id[w] for w in skip_words]
# 生成文本
word_ids = model.generate(start_id, skip_ids)
txt = ' '.join(id_to_word[i] for i in word_ids)
txt = txt.replace(' <eos>','.\n')
print(txt)
# 输出:
you freshman retreat teeth instantly enhanced university brands exceptionally affiliates 
various unfair leslie our assumes studies begin monitored bart leap reasonably gary poorer 
industry southeast cemetery tables epo supportive nervous sooner inc soybeans scientific 
expertise applying lufthansa introduction leventhal casting lights carries feared revamping 
solar sachs widen training reins moves industrials technologies extent diagnostic narcotics 
regularly literally hanover primarily reinsurance pro-life serve specifications fm jumbo 
penalty actions l.p. ann keenan princeton despite stuart arise instrumentation classic exposed 
violation dishonesty warner-lambert nicaragua infringed fantasy marcus portrait imported jordan 
spurring component perestroika undo remic sacrifice veterans arms-control postal relying 
homelessness quack voters

可以看到输出的文字几乎没有什么含义,这是因为我们使用的模型是原始的LSTM,没有经过训练,如果我们加载了上篇文字训练过后的模型BetterRnnlm后,效果会有明显提升。

... ...
model = RnnlmGen()
model.load_params('Rnnlm.pkl')
... ...
# 输出:
you place a short part of their nation 's relatively rumored air.
 far everyone agreed and yet as out next is a market to insist in  out of wohlstetter
  is violates that on why it took the max for one. a caution on a chiefs of substantial 
  chips three-year firstsouth in the ratio for candidates more investors the tax in 
  medical lawsuits will be something about three to government simultaneous.
 a new market will n't be surprising these specify must very be exactly like economic 
 houses to manufacturers competitors where he adds some error in not new

二、序列到序列模型

用RNN生成文本,还有一种更通用的用法,称为序列到序列的模型,也就是sequence to sequence。最典型的seq2seq应用就是机器翻译,输入一串用某种语言表示的文字,输出用另一种文字表示的文字。另外典型的应用包括:

自动摘要,它是输入一个长文本,输出一个表达核心含义的短文本;

问答系统:输入一个问题文本,输出一个答案文本;

聊天机器人:输入人类的文本,输出机器的文本;

算法学习:输入一串算法描述,输出计算答案;

图像文字生成:输入图像,这里图像也可以通过CNN等网络表示成一串向量,输出描述图像的文字;

可以看出,seq2seq可以有两个模块构成,一个模块处理输入文本,一个模块生成输出文本,处理输入文本的模块我们称为Encoder,生成文本的模块我们称为Decoder。一般来说,一个seq2seq就是由一个Encoder和一个Decoder合并在一起得到的。以书中的例子,日语翻译成英语为例的话,流程图如下:

下面我们以模拟一个加法的学习来实现这个seq2seq模型。这个模型的输入是一个三位数字以内的加法表达式,如“32+100”,输出是运算的结果,如“132”,编码器对“32+100”这个表达式拆分成“3”,“2”,“+”,“1”,“0”,“0”等几个字符,作为文本输入到编码器,得到隐藏信息,解码器输入隐藏信息以及“1”,“3”,“2”作为标签值,得到输出值,将输出值与“1”,“3”,“2”比较,得到损失,进行反向传播,实现整个训练过程。

不过这里有几个需要注意的地方:因为输入到编码器中的加法表达式长度可能不同,所以需要解决这个问题,最方便的方法是padding,也就是在表达式的前后插入填充字符。如:

总体流程如下图所示。训练的时候采用编码器和解码器训练,实际使用中,把编码器得到的隐藏信息输出到生成器,生成结果。

代码实现是基于之前的LSTM代码基础之上的,其实和LSTM的代码构建有很多类似的地方,编码器基本就是一个普通的LSTM,编码器代码:

class Encoder:
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V,D,H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn
        embed_W = (rn(V,D)/100).astype('f')
        lstm_Wx = (rn(D,4*H)/np.sqrt(D)).astype('f')
        lstm_Wh = (rn(H,4*H)/np.sqrt(H)).astype('f')
        lstm_b = np.zeros(4*H).astype('f')
        self.embed = TimeEmbedding(embed_W)
        self.lstm = TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful = False)
        self.params = self.embed.params + self.lstm.params
        self.grads = self.embed.grads + self.lstm.grads
        self.hs = None
        
    def forward(self,xs):
        xs = self.embed.forward(xs)
        hs = self.lstm.forward(xs)
        self.hs = hs
        return self.hs[:,-1,:]
    
    def backward(self,dh):
        dhs = np.zeros_like(self.hs)
        dhs[:,-1,:] = dh
        dout = self.lstm.backward(dhs)
        dout = self.embed.backward(dout)
        return dout

解码器和编码器的区别就在于,解码器还要多输入一个隐藏信息,并且正向传播输出多一个打分步骤。generate是实际生成文本结果的函数,和前面所述生成文本的区别在于,这里是生成加法结果的,所以不用概率性输出,而采用确定性输出,就是用argmax选择得分最高的输出,解码器代码:

class Decoder:
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V,D,H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn
        embed_W = (rn(V,D)/100).astype('f')
        lstm_Wx = (rn(D,4*H)/np.sqrt(D)).astype('f')
        lstm_Wh = (rn(H,4*H)/np.sqrt(H)).astype('f')
        lstm_b = np.zeros(4*H).astype('f')
        affine_W = (rn(H,V)/np.sqrt(H)).astype('f')
        affine_b = np.zeros(V).astype('f')
        self.embed = TimeEmbedding(embed_W)
        self.lstm = TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful = True)
        self.affine = TimeAffine(affine_W, affine_b)
        self.params, self.grads = [], []
        for layer in (self.embed, self.lstm, self.affine):
            self.params += layer.params
            self.grads += layer.grads
        
    def forward(self, xs, h):
        self.lstm.set_state(h)
        out = self.embed.forward(xs)
        out = self.lstm.forward(out)
        score = self.affine.forward(out)
        return score
    
    def backward(self, dscore):
        dout = self.affine.backward(dscore)
        dout = self.lstm.backward(dout)
        dout = self.embed.backward(dout)
        dh = self.lstm.dh
        return dh
    
    def generate(self, h, start_id, sample_size):
        sampled = []
        sample_id = start_id
        self.lstm.set_state(h)
        for _ in range(sample_size):
            x = np.array(sample_id).reshape((1,1))
            out = self.embed.forward(x)
            out = self.lstm.forward(out)
            score = self.affine.forward(out)
            sample_id = np.argmax(score.flatten())
            sampled.append(int(sample_id))
        return sampled

基于上述编码器和解码器,构建seq2seq模型:

class Seq2seq(BaseModel):
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V,D,H = vocab_size, wordvec_size, hidden_size
        self.encoder = Encoder(V,D,H)
        self.decoder = Decoder(V,D,H)
        self.softmax = TimeSoftmaxWithLoss()
        self.params = self.encoder.params + self.decoder.params
        self.grads = self.encoder.grads + self.decoder.grads
        
    def forward(self, xs, ts):
        # 样本从开始到倒数第二个,标签从第1个开始到最后一个
        decoder_xs, decoder_ts = ts[:,:-1], ts[:,1:]
        h = self.encoder.forward(xs)
        score = self.decoder.forward(decoder_xs, h)
        loss = self.softmax.forward(score, decoder_ts)
        return loss
    
    def backward(self, dout=1):
        dout = self.softmax.backward(dout)
        dh = self.decoder.backward(dout)
        dout = self.encoder.backward(dh)
        return dout
    
    def generate(self, xs, start_id, sample_size):
        h = self.encoder.forward(xs)
        sampled = self.decoder.generate(h, start_id, sample_size)
        return sampled

采用该模型训练25个epoch后,预测精确度约在11%左右。一个改进办法是,将输入的表达式反转:

另一个改进方法是Peeky,它的特点就是把编码器传过来的隐藏信息,都输入到解码器的每个节点中,而之前只有解码器的第一个节点接收编码器传过来的隐藏信息。

训练代码如下(完整代码可以参考书的附带源代码):

# 读入数据集
(x_train,t_train),(x_test,t_test) = load_data('addition.txt', seed=1984)
char_to_id, id_to_char = get_vocab()
x_train, x_test = x_train[:,::-1], x_test[:,::-1] # 反转

# 设定超参数
vocab_size = len(char_to_id)
wordvec_size = 16
hidden_size = 128
batch_size = 128
max_epoch = 25
max_grad = 5.0

# 生成模型/优化器/训练器
model = PeekySeq2seq(vocab_size, wordvec_size, hidden_size)
optimizer = Adam()
trainer = Trainer(model, optimizer)

acc_list = []
for epoch in range(max_epoch):
    trainer.fit(x_train, t_train, max_epoch=1, batch_size=batch_size, max_grad=max_grad)
    correct_num = 0
    for i in range(len(x_test)):
        question, correct = x_test[[i]], t_test[[i]]
        verbose = i < 10
        correct_num += eval_seq2seq(model, question, correct, id_to_char, verbose)
    acc = float(correct_num)/len(x_test)
    acc_list.append(acc)
    print('val acc %.3f%%' % (acc*100))

acc_list3 = acc_list
plt.plot([i for i in range(max_epoch)], [acc*100 for acc in acc_list3], label='peeky+reverse', c='r',linestyle='--',marker='o')
plt.plot([i for i in range(max_epoch)], [acc*100 for acc in acc_list2], label='reverse', c='y',linestyle='-.',marker='>')
plt.plot([i for i in range(max_epoch)], [acc*100 for acc in acc_list1], label='original', c='g',linestyle=':',marker='*')
plt.xlabel('iterations')
plt.ylabel('accuracy')
plt.legend()
plt.show()

这里,我稍微修改了一下书中代码,我把三个精度图放在一起比较了。acc_list1代表原始seq2seq模型的训练精度,acc_list2代表输入反转后的模型训练精度,acc_list3代码输入反转并且加入Peeky后的模型训练精度。

可以看到,原始的seq2seq模型在训练25个epoch后,精度大约11%,反转输入后,训练精度大概55%,再加入Peeky后,精度已经非常接近100%了,一般可以达到96%~98%之间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2332080.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FPGA_DDR错误总结

1otp 31-67 解决 端口没连接 必须赋值&#xff1b; 2.PLACE 30-58 TERM PLINITCALIBZ这里有问题 在顶层输出但是没有管脚约束报错 3.ERROR: [Place 30-675] 这是时钟不匹配IBUF不在同一个时钟域&#xff0c;时钟不在同一个时钟域里&#xff0c;推荐的不建议修改 问题 原本…

NOIP2011提高组.玛雅游戏

目录 题目算法标签: 模拟, 搜索, d f s dfs dfs, 剪枝优化思路*详细注释版代码精简注释版代码 题目 185. 玛雅游戏 算法标签: 模拟, 搜索, d f s dfs dfs, 剪枝优化 思路 可行性剪枝 如果某个颜色的格子数量少于 3 3 3一定无解因为要求字典序最小, 因此当一个格子左边有…

基于ssm框架的校园代购服务订单管理系统【附源码】

1、系统框架 1.1、项目所用到技术&#xff1a; javaee项目 Spring&#xff0c;springMVC&#xff0c;mybatis&#xff0c;mvc&#xff0c;vue&#xff0c;maven项目。 1.2、项目用到的环境&#xff1a; 数据库 &#xff1a;mysql5.X、mysql8.X都可以jdk1.8tomcat8 及以上开发…

【10】数据结构的矩阵与广义表篇章

目录标题 二维以上矩阵矩阵存储方式行序优先存储列序优先存储 特殊矩阵对称矩阵稀疏矩阵三元组方式存储稀疏矩阵的实现三元组初始化稀疏矩阵的初始化稀疏矩阵的创建展示当前稀疏矩阵稀疏矩阵的转置 三元组稀疏矩阵的调试与总代码十字链表方式存储稀疏矩阵的实现十字链表数据标签…

猜猜乐游戏(python)

import randomprint(**30) print(欢迎进入娱乐城) print(**30)username input(输入用户名&#xff1a;) cs 0answer input( 是否加入"猜猜乐"游戏(yes/no)? )if answer yes:while True:num int(input(%s! 当前你的金币数为%d! 请充值(100&#xffe5;30币&…

spring boot 2.7 集成 Swagger 3.0 API文档工具

背景 Swagger 3.0 是 OpenAPI 规范体系下的重要版本&#xff0c;其前身是 Swagger 2.0。在 Swagger 2.0 之后&#xff0c;该规范正式更名为 OpenAPI 规范&#xff0c;并基于新的版本体系进行迭代&#xff0c;因此 Swagger 3.0 实际对应 OpenAPI 3.0 版本。这一版本着重强化了对…

Dinky 和 Flink CDC 在实时整库同步的探索之路

摘要&#xff1a;本文整理自 Dinky 社区负责人&#xff0c;Apache Flink CDC contributor 亓文凯老师在 Flink Forward Asia 2024 数据集成&#xff08;二&#xff09;专场中的分享。主要讲述 Dinky 的整库同步技术方案演变至 Flink CDC Yaml 作业的探索历程&#xff0c;并深入…

视频融合平台EasyCVR搭建智慧粮仓系统:为粮仓管理赋能新优势

一、项目背景 当前粮仓管理大多仍处于原始人力监管或初步信息化监管阶段。部分地区虽采用了简单的传感监测设备&#xff0c;仍需大量人力的配合&#xff0c;这不仅难以全面监控粮仓复杂的环境&#xff0c;还容易出现管理 “盲区”&#xff0c;无法实现精细化的管理。而一套先进…

3D Gaussian Splatting as MCMC 与gsplat中的应用实现

3D高斯泼溅(3D Gaussian splatting)自2023年提出以后,相关研究paper井喷式增长,尽管出现了许多改进版本,但依旧面临着诸多挑战,例如实现照片级真实感、应对高存储需求,而 “悬浮的高斯核” 问题就是其中之一。浮动高斯核通常由输入图像中的曝光或颜色不一致引发,也可能…

C++初阶-C++的讲解1

目录 1.缺省(sheng)参数 2.函数重载 3.引用 3.1引用的概念和定义 3.2引用的特性 3.3引用的使用 3.4const引用 3.5.指针和引用的关系 4.nullptr 5.总结 1.缺省(sheng)参数 &#xff08;1&#xff09;缺省参数是声明或定义是为函数的参数指定一个缺省值。在调用该函数是…

STM32_USB

概述 本文是使用HAL库的USB驱动 因为官方cubeMX生成的hal库做组合设备时过于繁琐 所以这里使用某大神的插件,可以集成在cubeMX里自动生成组合设备 有小bug会覆盖生成文件里自己写的内容,所以生成一次后注意保存 插件安装 下载地址 https://github.com/alambe94/I-CUBE-USBD-Com…

STM32 的编程方式总结

&#x1f9f1; 按照“是否可独立工作”来分&#xff1a; 库/方式是否可独立使用是否依赖其他库说明寄存器裸写✅ 是❌ 无完全自主控制&#xff0c;无库依赖标准库&#xff08;StdPeriph&#xff09;✅ 是❌ 只依赖 CMSIS自成体系&#xff08;F1专属&#xff09;&#xff0c;只…

MFC工具栏CToolBar从专家到小白

CToolBar m_wndTool; //创建控件 m_wndTool.CreateEx(this, TBSTYLE_FLAT|TBSTYLE_NOPREFIX, WS_CHILD | WS_VISIBLE | CBRS_FLYBY | CBRS_TOP | CBRS_SIZE_DYNAMIC); //加载工具栏资源 m_wndTool.LoadToolBar(IDR_TOOL_LOAD) //在.rc中定义&#xff1a;IDR_TOOL_LOAD BITMAP …

大厂机考——各算法与数据结构详解

目录及其索引 哈希双指针滑动窗口子串普通数组矩阵链表二叉树图论回溯二分查找栈堆贪心算法动态规划多维动态规划学科领域与联系总结​​ 哈希 ​​学科领域​​&#xff1a;计算机科学、密码学、数据结构 ​​定义​​&#xff1a;通过哈希函数将任意长度的输入映射为固定长度…

10:00开始面试,10:08就出来了,问的问题有点变态。。。

从小厂出来&#xff0c;没想到在另一家公司又寄了。 到这家公司开始上班&#xff0c;加班是每天必不可少的&#xff0c;看在钱给的比较多的份上&#xff0c;就不太计较了。没想到8月一纸通知&#xff0c;所有人不准加班&#xff0c;加班费不仅没有了&#xff0c;薪资还要降40%…

基于ueditor编辑器的功能开发之给编辑器图片增加水印功能

用户需求&#xff0c;双击编辑器中的图片的时候&#xff0c;出现弹框&#xff0c;用户可以选择水印缩放倍数、距离以及水印所放置的方位&#xff08;当然有很多水印插件&#xff0c;位置大小透明度用户都能够自定义&#xff0c;但是用户需求如此&#xff0c;就自己写了&#xf…

【CSS基础】- 02(emmet语法、复合选择器、显示模式、背景标签)

css第二天 一、emmet语法 1、简介 ​ Emmet语法的前身是Zen coding,它使用缩写,来提高html/css的编写速度, Vscode内部已经集成该语法。 ​ 快速生成HTML结构语法 ​ 快速生成CSS样式语法 2、快速生成HTML结构语法 生成标签 直接输入标签名 按tab键即可 比如 div 然后tab…

【码农日常】vscode编码clang-format格式化简易教程

文章目录 0 前言1 工具准备1.1 插件准备1.2 添加.clang-format1.3 添加配置 2 快速上手 0 前言 各路大神都说clangd好&#xff0c;我也来试试。这篇主要讲格式化部分。 1 工具准备 1.1 插件准备 照图安装。 1.2 添加.clang-format 右键添加文件&#xff0c;跟添加个.h或者.c…

金融数据分析(Python)个人学习笔记(7):网络数据采集以及FNN分类

一、网络数据采集 证券宝是一个免费、开源的证券数据平台&#xff08;无需注册&#xff09;&#xff0c;提供大盘准确、完整的证券历史行情数据、上市公司财务数据等&#xff0c;通过python API获取证券数据信息。 1. 安装并导入第三方依赖库 baostock 在命令提示符中运行&…

死锁 手撕死锁检测工具

目录 引言 一.理论联立 1.死锁的概念和原因 2.死锁检测的基本思路 3.有向图在死锁检测中的应用 二.代码实现案例&#xff08;我们会介绍部分重要接口解释&#xff09; 1.我们定义一个线性表来存线程ID和锁ID 2.表中数据的查询接口 3.表中数据的删除接口 4.表中数据的添…