LSTM模型的讲解与运用

news2025/1/21 9:27:20

实验目的:

循环神经网络(RNN)是一类以序列数据为输入,在序列的演进方向进行循环且所有循环单元按链式连接的神经网络。目前RNN已经广泛应用于语音识别、文本分类等自然语言处理任务中。本实验通过采用循环神经网络的变体长短期记忆网络(LSTM),合理设计网络结构和算法来实现自动写诗和藏头诗的功能。

环境配置:

利用pytcharm配置pytroch和opencv;

参数:

epoches = 8

Lr=0.001

Batch=575

实验步骤:(以自动写诗为例)

图 5 加载数据

由图5加载诗的数据,返回标签。

图 6 lstm网络

图 7 训练过程

图 8 自动写诗代码

图 9 根据首字母生成诗句代码

图5到图9分别从加载图片,训练神经网络,到生成诗句的过程。

准确率:

图 10 训练结果

结果分析:

图 11 自动写诗的过程

由该图可以看出,对床前明月光的自动续写结果。

图 12 根据首字生成诗

由test函数可以生成龙碧和弦四个字开头的诗句。效果如图12。

源代码:

import numpy as np      # tang.npz的压缩格式处理
import os       # 打开文件
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchnet import meter
import tqdm

def get_data():
    if os.path.exists(data_path):
        datas = np.load(data_path, allow_pickle=True)      # 加载数据
        data = datas['data']  # numpy.ndarray
        word2ix = datas['word2ix'].item()   # dic
        ix2word = datas['ix2word'].item()  # dic
        return data, word2ix, ix2word


class Net(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Net, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2, batch_first=False)
        # lstm输入为:seq, batch, input_size
        # lstm输出为:seq * batch * 256; (2 * batch * 256,...)
        self.linear1 = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        seq_len, batch_size = input.size()
        if hidden is None:
            h_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()
            c_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()
            h_0, c_0 = Variable(h_0), Variable(c_0)
        else:
            h_0, c_0 = hidden
        embeds = self.embeddings(input)     # (seq_len, batch_size, embedding_dim), (124,128,128)
        output, hidden = self.lstm(embeds, (h_0, c_0))      #(seq_len, batch_size, hidden_dim), (124,128,256)
        output = self.linear1(output.view(seq_len*batch_size, -1))      # ((seq_len * batch_size),hidden_dim), (15872,256) → (15872,8293)
        return output, hidden


def train():
    modle = Net(len(word2ix), 128, 256)     # 模型定义:vocab_size, embedding_dim, hidden_dim —— 8293 * 128 * 256
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available() == True:
        print('Cuda is available!')
        modle = modle.cuda()
        optimizer = torch.optim.Adam(modle.parameters(), lr=1e-3)  # 学习率1e-3
        criterion = criterion.cuda()
        loss_meter = meter.AverageValueMeter()

        period = []
        loss2 = []
        for epoch in range(8):     # 最大迭代次数为8
            loss_meter.reset()
            for i, data in tqdm.tqdm(enumerate(dataloader)):    # data: torch.Size([128, 125]), dtype=torch.int32
                data = data.long().transpose(0,1).contiguous()      # long为默认tensor类型,并转置, [125, 128]
                data = data.cuda()
                optimizer.zero_grad()

                input, target = Variable(data[:-1, :]), Variable(data[1:, :])
                output, _ = modle(input)
                loss = criterion(output, target.view(-1))       # torch.Size([15872, 8293]), torch.Size([15872])
                loss.backward()
                optimizer.step()

                loss_meter.add(loss.item())    # loss:tensor(3.3510, device='cuda:0', grad_fn=<NllLossBackward>)loss.data:tensor(3.0183, device='cuda:0')

                period.append(i + epoch * len(dataloader))
                loss2.append(loss_meter.value()[0])

                if (1 + i) % 575 == 0:       # 每575个batch可视化一次
                    print(str(i) +':' + generate(modle,'床前明月光', ix2word, word2ix))

        torch.save(modle.state_dict(), './model_poet.pth')
        plt.plot(period, loss2)
        plt.show()


def generate(model, start_words, ix2word, word2ix):     # 给定几个词,根据这几个词生成一首完整的诗歌
    txt = []
    for word in start_words:
        txt.append(word)
    input = Variable(torch.Tensor([word2ix['<START>']]).view(1,1).long())      # tensor([8291.]) → tensor([[8291.]]) → tensor([[8291]])
    input = input.cuda()
    hidden = None
    num = len(txt)
    for i in range(48):      # 最大生成长度
        output, hidden = model(input, hidden)
        if i < num:
            w = txt[i]
            input = Variable(input.data.new([word2ix[w]])).view(1, 1)
        else:
            top_index = output.data[0].topk(1)[1][0]
            w = ix2word[top_index.item()]
            txt.append(w)
            input = Variable(input.data.new([top_index])).view(1, 1)
        if w == '<EOP>':
            break
    return ''.join(txt)

def gen_acrostic(model, start_words, ix2word, word2ix):
    result = []
    txt = []
    for word in start_words:
        txt.append(word)
    input = Variable(
        torch.Tensor([word2ix['<START>']]).view(1, 1).long())  # tensor([8291.]) → tensor([[8291.]]) → tensor([[8291]])
    input = input.cuda()
    hidden = None

    num = len(txt)
    index = 0
    pre_word = '<START>'
    for i in range(48):
        output, hidden = model(input, hidden)
        top_index = output.data[0].topk(1)[1][0]
        w = ix2word[top_index.item()]

        if (pre_word in {'。', '!', '<START>'}):
            if index == num:
                break
            else:
                w = txt[index]
                index += 1
                input = Variable(input.data.new([word2ix[w]])).view(1,1)
        else:
            input = Variable(input.data.new([word2ix[w]])).view(1,1)
        result.append(w)
        pre_word = w
    return ''.join(result)


def test():
    modle = Net(len(word2ix), 128, 256)     # 模型定义:vocab_size, embedding_dim, hidden_dim —— 8293 * 128 * 256
    if torch.cuda.is_available() == True:
        modle.cuda()
        modle.load_state_dict(torch.load('./model_poet.pth'))
        modle.eval()
        # txt = generate(modle, '床前明月光', ix2word, word2ix)
        # print(txt)
        txt = gen_acrostic(modle, '龙碧和弦', ix2word, word2ix)
        print(txt)

if __name__ == '__main__':
    data_path = './tang.npz'
    data, word2ix, ix2word = get_data()
    data = torch.from_numpy(data)
    dataloader = torch.utils.data.DataLoader(data, batch_size=10, shuffle=True, num_workers=1)  # shuffle=True随机打乱
    #train()
    test()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1095251.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决axios不发起请求的问题

一个很简单axios发起的get请求&#xff0c;但是network就是没有xhr记录&#xff1a; 找了半天也没有找到原因&#xff0c;后来问了chatgpt才找到原因&#xff1a; 请求地址url前面没有const关键字&#xff0c;此时url相当于undefined&#xff0c;而axios在url为undefined时不会…

安装Linux系统对硬件的要求

很多初学者在安装 Linux 系统时&#xff0c;都对自己的电脑配置存在质疑&#xff0c;担心其是否能够满足安装 Linux 的要求。本节就从 CPU、内存、硬盘、显卡等这些方面&#xff0c;详细介绍一下安装 Linux 系统的最低配置。 基于硬件的快速发展以及操作系统核心功能的增加&…

基于QPlainTextEdit带标签行号的文本编辑器

关键代码 CodeEditor.h 文件 #ifndef CODEEDITOR_H #define CODEEDITOR_H#include <QPlainTextEdit> #include <QPaintEvent> #include <QContextMenuEvent> #include <QMouseEvent> #include <QMouseEvent> #include <QPainter> #includ…

Hadoop 配置 Kerberos 认证

1、安装 Kerberos 服务器和客户端 1.1 规划 服务端&#xff1a; bigdata3 客户端&#xff08;Hadoop集群&#xff09;&#xff1a; bigdata0 bigdata1 bigdata2 192.168.50.7 bigdata0.example.com bigdata0 192.168.50.8 bigdata1.example.com bigdata1 192.168.50.9 b…

vue补充继上一篇

组合式API-reactive和ref函数 1.reactive() 作用&#xff1a;接受对象类型数据的参数传入并返回一个响应式的对象 1.从vue包中导入reactive函数 2.在<script setup>中执行reactive函数并传入类型为对象的初始值&#xff0c;并使用变量接受返回值。 2.ref() 作用&am…

如何做好一个管理者

一、管理的目标 管理的目的是效率和效益。管理的核心是人。管理的本质是协调&#xff0c;协调的中心是人。管理的真谛是聚合企业的各类资源&#xff0c;充分运用管理的功能&#xff0c;以最优的投入获得最佳的回报&#xff0c;以实现企业既定目标。 二、管理中的核心 2.1、核…

前馈型BP神经网络

1.感知机和激活函数 感知机&#xff0c;是构成神经网络的基本单位&#xff0c;一个感知机可以接收n个输入X&#xff08;x1,x2,x3…xn)T&#xff08;每个输入&#xff0c;可以理解为一种特征&#xff09;,n个输入对应n个权值W&#xff08;w1,w2,w3…wn),此外还有一个偏置项b&am…

AI换脸之Faceswap技术原理与实践

目录 1.方法介绍 2.相关资料 3.实践记录 ​4.实验结果 1.方法介绍 Faceswap利用深度学习算法和人脸识别技术&#xff0c;可以将一个人的面部表情、眼睛、嘴巴等特征从一张照片或视频中提取出来&#xff0c;并将其与另一个人的面部特征进行匹配。主要应用在图像/视频换脸&am…

GMM模型与EM算法

GMM模型与EM算法 --> 聚类 -> 无监督机器学习[参考] 一、单个高斯分布GM的估计参数 1.1 高斯分布 结果趋近于正态分布 每次弹珠往下走的时候&#xff0c;碰到钉子会随机往左还是往右走&#xff0c;可以观测到多次随机过程结合的 高斯分布的似然函数 X1-XN 全部发生的…

【2023】redis-stream配合spring的data-redis详细使用(包括广播和组接收)

目录 一、简介1、介绍2、对比 二、整合spring的data-redis实现1、使用依赖2、配置类2.1、配置RedisTemplate bean2.2、异常类 3、实体类3.1、User3.2、Book 4、发送消息4.1、RedisStreamUtil工具类4.2、通过延时队列线程池模拟发送消息4.3、通过http主动发送消息 5、&#x1f3…

电动力学专题研讨:运动电荷之间的相互作用是否满足牛顿第三定律?

电动力学专题研讨&#xff1a;运动电荷之间的相互作用是否满足牛顿第三定律&#xff1f;​​​​​​​ 两个稳恒电流元之间的相互作用不一定服从牛顿第三定律常见的解释是&#xff1a;稳恒电流元是不能孤立存在的&#xff0c;因此不能得出结论 符号约定 两个运动点电荷之间的力…

for循环遍历的`form表单组件`rules规则校验失效问题——下拉框选择之后还是报红---亲测有效

问题: 大概的效果就是这种, for循环选择之后还是还是报红 看文章之前 : 先检查 model rules pops 有没有判定好 解决: 参考了他的 for循环遍历的form表单组件rules规则校验失效问题——输入内容后依然提示必填&#xff0c;亲测有效——基础积累_a-form-model的validat…

java项目之学生综合考评管理系统()

项目简介 学生综合考评管理系统实现了以下功能&#xff1a; 管理员&#xff1a;个人中心、通知公告管理、班级管理、学生管理、教师管理&#xff0c;课程信息管理、作业布置管理、作业提交管理、留言信息管理、课程成绩管理。学生&#xff1a;个人中心、通知公告管理、教师管…

Python赋值运算符:= 等号 以及 复合赋值运算符 += -= *=

赋值运算符 运算符描述实例赋值**将右侧的结果赋值给等号左侧的变量 案例1&#xff1a;把某个值赋值给某个变量 num 10 案例2&#xff1a;多个变量同时进行赋值操作 n 5 f 10.88 s ‘hello world’ 简写为&#xff1a;----------就类似一对一的&#xff0c;一个序号对应…

推荐一个备受好评的电子期刊制作网站

在当今数字化的时代&#xff0c;电子期刊制作已成为许多企业和个人的必备技能。对于想要制作出高质量、专业的电子期刊的人来说&#xff0c;选择一个简单易用、功能强大的网站至关重要。 下面分享一款备受好评的电子期刊制作网站-----FLBOOK在线制作电子杂志平台&#xff0c;在…

【frp实现内网穿透踩坑到成功篇】

【frp实现内网穿透踩坑到成功篇】 背景&需求配置服务器端配置客户端总结 背景&需求 白嫖了一个tencent入门级服务器 ∗ 1 *1 ∗1&#xff0c;学校实验室内网服务器 ∗ 1 *1 ∗1&#xff0c;需要访问内网的服务器。一顿搜寻资料后确定大致的路子&#xff1a; 第一步是…

因为写保护,U盘会“假死”。如何在Windows 10上删除写保护

本文介绍如何从USB驱动器、SD卡或单个文件中删除写保护。说明适用于Windows 10、Windows 8和Windows 7。 如何使用锁定开关解除写保护 如果你的计算机告诉你介质受写保护&#xff0c;请在USB或SD卡上查找写保护开关&#xff08;也称为锁定开关&#xff09;。如果介质有此开关…

文心一言 VS 讯飞星火 VS chatgpt (113)-- 算法导论10.2 5题

五、用go语言&#xff0c;使用单向循环链表实现字典操作 INSERT、DELETE 和 SEARCH&#xff0c;并给出所写过程的运行时间。 文心一言&#xff1a; 在Go语言中&#xff0c;我们首先需要定义一个Node结构体来表示单向循环链表的节点&#xff0c;然后再定义一个LinkedList结构体…