【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感

news2024/11/23 13:18:02

本文将介绍如何使用LSTM训练一个能够创作诗歌的模型。为了训练出效果优秀的模型,我整理了来自网络的4万首诗歌数据集。我们的模型可以直接使用预先训练好的参数,这意味着您无需从头开始训练,即可在自己的电脑上体验AI作诗的乐趣。我已经为您准备好了这些训练好的参数,让您能够轻松地在自己的设备上开始创作。本文将详细讲解如何在个人电脑上运行该模型,即使您没有机器学习方面的背景知识,也能轻松驾驭,让您的AI模型在自己的电脑上运行起来,体验AI创作诗歌的乐趣.所有的代码和资料都在仓库:https://gitee.com/yw18791995155/generate_poetry.git

秋风吹拂,窗外的树叶似灵动的舞者翩翩而舞,落日余晖将天际晕染成一片醉人的橘红。
与此同时,AI 于知识的瀚海中遨游,遍览数千篇文章后,开启了它的首次创作之旅。

在对近 4 万首唐诗深度学习之后,赋诗如下:

在这里插入图片描述

此诗颇具韵味,实乃勤勉研习之硕果。汲取全唐诗之精华,方成就这般非凡之能,常人岂易企及?

本博客将简要分析其中的技术细节,若有阐释未尽之处,在此诚挚欢迎诸君于评论区畅所欲言,各抒己见。先呈上仓库链接https://gitee.com/yw18791995155/generate_poetry.git

若诸位无暇详阅,不妨为该项目点亮 star 或进行 fork,诸君的每一份支持都将如熠熠星光,化作我砥砺前行之强劲动力源泉。言归正传,让我们一同开启打造AI诗人的旅程吧

01 环境配置

在开始之前,确保你的电脑已经安装了必要的依赖库:PyTorch 和 NumPy。安装命令如下:

   pip install torch torchvision torchaudio numpy

一切就绪,我们可以开始了!

02 初识LSTM

长短期记忆网络 LSTM是一种特殊类型的循环神经网络(RNN),它被设计用来解决传统RNN在处理长序列数据时遇到的长期依赖性问题(梯度消失和梯度爆炸问题)。
在这里插入图片描述

LSTM的核心优势在于其能够学习并记住长期的信息依赖关系。这种能力使得LSTM在处理长文本内容时比普通RNN更为出色。LSTM网络中包含了四个主要的组件,它们通过门控机制来控制信息的流动:

  1. 遗忘门(Forget Gate):决定哪些信息应该被遗忘,不再保留在单元状态中。
  2. 输入门(Input Gate):决定哪些新信息将被存储在单元状态中。
  3. 单元状态(Cell State):携带数据穿越时间的信息带,可以看作是LSTM的“记忆”。
  4. 输出门(Output Gate):决定哪些信息将从单元状态输出到下一个隐藏状态。

这些门控机制使得LSTM能够有选择性地保留或遗忘信息,从而有效地捕捉和利用长期依赖性。这种设计灵感来源于对传统RNN在处理长序列时遗忘信息的挑战的回应,LSTM通过这些门控结构,使得网络能够更加灵活地处理时间序列数据。

03处理数据

接下来,首先要做的就是读取准备好的诗歌数据。然后对数据进行清洗,剔除那些包含特殊字符或长度不符合要求的诗歌。清洗完数据后,我们会为每首诗加上开始和结束的标志,确保生成的诗歌有明确的起止符号。

然后,我们会构建词典,为每个词分配一个唯一的索引,同时建立词汇到索引、索引到词汇的映射关系。最后,把每首诗转换成数字序列,这样就能让模型进行处理了。

import collections
import numpy as np
import torch

# 定义起始和结束标记
start_token = 'B'
end_token = 'E'

def process_poems(file_name):
    """
    处理诗歌文件,将诗歌转换为数字序列,并构建词汇表。

    :param file_name: 诗歌文件的路径
    :return:
        - poems_vector: 诗歌的数字序列列表
        - word_to_idx: 词汇到索引的映射字典
        - idx_to_word: 索引到词汇的映射列表
    """
    # 初始化诗歌列表
    poems = []

    # 读取文件并处理每一行
    with open(file_name, "r", encoding='utf-8') as f:
        for line in f.readlines():
            try:
                # 分割标题和内容
                title, content = line.strip().split(':')
                content = content.replace(' ', '')

                # 过滤掉包含特殊字符的诗歌
                if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
                        start_token in content or end_token in content:
                    continue

                # 过滤掉长度不符合要求的诗歌
                if len(content) < 5 or len(content) > 79:
                    continue

                # 添加起始和结束标记
                content = start_token + content + end_token
                poems.append(content)
            except ValueError as e:
                pass

    # 统计所有单词的频率
    all_words = [word for poem in poems for word in poem]
    counter = collections.Counter(all_words)
    words = sorted(counter.keys(), key=lambda x: counter[x], reverse=True)

    # 添加空格作为填充符
    words.append(' ')
    words_length = len(words)

    # 构建词汇到索引和索引到词汇的映射
    word_to_idx = {word: i for i, word in enumerate(words)}
    idx_to_word = [word for word in words]

    # 将诗歌转换为数字序列
    poems_vector = [[word_to_idx[word] for word in poem] for poem in poems]

    return poems_vector, word_to_idx, idx_to_word

def generate_batch(batch_size, poems_vec, word_to_int):
    """
    生成批量训练数据。
    :param batch_size: 批量大小
    :param poems_vec: 诗歌的数字序列列表
    :param word_to_int: 词汇到索引的映射字典
    :return:
        - x_batches: 输入数据批次
        - y_batches: 目标数据批次
    """
    # 计算可以生成的批次数
    num_example = len(poems_vec) // batch_size

    x_batches = []
    y_batches = []

    for i in range(num_example):
        start_index = i * batch_size
        end_index = start_index + batch_size

        # 获取当前批次的诗歌
        batches = poems_vec[start_index:end_index]

        # 找到当前批次中最长的诗歌长度
        length = max(map(len, batches))

        # 初始化输入数据,使用空格进行填充
        x_data = np.full((batch_size, length), word_to_int[' '], np.int32)

        # 填充输入数据
        for row, batch in enumerate(batches):
            x_data[row, :len(batch)] = batch

        # 创建目标数据,目标数据是输入数据向右移一位
        y_data = np.copy(x_data)
        y_data[:, :-1] = x_data[:, 1:]

        """
        x_data             y_data
        [6,2,4,6,9]       [2,4,6,9,9]
        [1,4,2,8,5]       [4,2,8,5,5]
        """

        # 将当前批次的数据添加到列表中
        yield torch.tensor(x_data), torch.tensor(y_data)

04创建模型

现在是时候搭建我们的 LSTM 模型了!我们将创建一个双层 LSTM 网络。双层 LSTM 比单层的更有能力捕捉复杂的模式和结构,能够更好地处理诗歌这种带有丰富语言特征的任务。

import torch
import torch.nn as nn
import torch.optim as optim


class RNNModel(nn.Module):
    def __init__(self, vocab_size, rnn_size=128, num_layers=2):
        """
        构建RNN序列到序列模型。
        :param vocab_size: 词汇表大小
        :param rnn_size: RNN隐藏层大小
        :param num_layers: RNN层数

        """
        super(RNNModel, self).__init__()

        # 选择LSTM单元
        # 参数说明:输入大小、隐藏层大小、层数、batch_first=True表示输入数据的第一维是批次大小
        self.cell = nn.LSTM(rnn_size, rnn_size, num_layers, batch_first=True)

        # 嵌入层,将词汇表中的词转换为向量
        # vocab_size + 1 是因为在词嵌入中需要有一个特殊标记,用于表示填充位置,所以词嵌入时会加一个词。
        self.embedding = nn.Embedding(vocab_size + 1, rnn_size)

        # RNN隐藏层大小
        self.rnn_size = rnn_size

        # 全连接层,用于输出预测
        # 输入大小为RNN隐藏层大小,输出大小为词汇表大小加1
        self.fc = nn.Linear(rnn_size, vocab_size + 1)

    def forward(self, input_data, hidden):
        """
        前向传播
        :param input_data: 输入数据,形状为 (batch_size, sequence_length)
        :param output_data: 输出数据(训练时提供),形状为 (batch_size, sequence_length)
        :return: 输出结果或损失
        """
        # 获取批次大小
        batch_size = input_data.size(0)

        # 嵌入层,将输入数据转换为向量
        # 输入数据形状为 (batch_size, sequence_length),嵌入后形状为 (batch_size, sequence_length, rnn_size)
        embedded = self.embedding(input_data)
        # 通过RNN层
        # 输入形状为 (batch_size, sequence_length, rnn_size),输出形状为 (batch_size, sequence_length, rnn_size)
        outputs, hidden = self.cell(embedded, hidden)
        # 将输出展平
        # 展平后的形状为 (batch_size * sequence_length, rnn_size)
        outputs = outputs.contiguous().view(-1, self.rnn_size)
        # 通过全连接层
        # 输入形状为 (batch_size * sequence_length, rnn_size),输出形状为 (batch_size * sequence_length, vocab_size + 1)
        logits = self.fc(outputs)
        return logits, hidden

05训练模型

接下来,就是我们最考验耐性的部分——训练模型了。训练过程中,你可能需要一些时间,所以建议使用 GPU 加速。经过实测,使用 GPU 训练速度大约是 CPU 的四倍左右。所以,如果你有条件,最好让 GPU 出马,省时省力。

import torch
from model import RNNModel
from torch import nn
from poem_data_processing import *
import os
import time

# 检查是否有可用的GPU,如果没有则使用CPU
# windows用户使用torch.cuda.is_available()来检查是否有可用的GPU。
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print(f"Using device: {device}")

def train(poems_path, num_epochs, batch_size, lr):
    """
    训练RNN模型并进行预测。

    参数:
    poems_path (str): 诗歌数据文件路径。
    num_epochs (int): 训练的轮数。
    batch_size (int): 批次大小。
    lr (float): 学习率。
    """
    # 确保模型保存目录存在
    if not os.path.exists('./model'):
        os.makedirs('./model')

    # 处理诗歌数据,生成向量化表示和映射字典
    poems_vector, word_to_idx, idx_to_word = process_poems(poems_path)
    # 初始化RNN模型并将其移动到指定设备
    model = RNNModel(len(idx_to_word), 128, num_layers=2).to(device)
    # 使用Adam优化器初始化训练器
    trainer = torch.optim.Adam(model.parameters(), lr=lr)
    # 使用交叉熵损失函数
    loss_fn = nn.CrossEntropyLoss()

    # 开始训练过程
    for epoch in range(num_epochs):
        loss_sum = 0
        start = time.time()

        # 生成并迭代训练批次
        for X, Y in generate_batch(batch_size, poems_vector, word_to_idx):
            # 将输入和目标数据移动到指定设备
            X = X.to(device)
            Y = Y.to(device)

            state = None
            # 前向传播
            outputs, state = model(X, state)
            Y = Y.view(-1)
            # 计算损失
            l = loss_fn(outputs, Y.long())
            # 反向传播和优化
            trainer.zero_grad()
            l.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
            trainer.step()
            loss_sum += l.item() * Y.shape[0]

        end = time.time()
        print(f"Time cost: {end - start}s")
        print(f"epoch: {epoch}, loss: {loss_sum / len(poems_vector)}")

    # 保存模型和优化器的状态
    try:
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': trainer.state_dict(),
        }, os.path.join('./model', 'torch-latest.pth'))
    except Exception as e:
        print(f"Error saving model: {e}")

if __name__ == "__main__":
    file_path = "./data/poems.txt"
    train(file_path, num_epochs=100, batch_size=64, lr=0.002)

经过了约 20 分钟的训练,终于,模型训练完成!训练结束后,模型的参数会自动保存到文件中,这样下次就可以直接加载预训练的模型,省去重新训练的麻烦

06测试模型

终于,我们来到了最激动人心的环节——AI作诗。经过几个小时的努力,我们的AI诗人已经准备好创作一首藏头诗,以此来弥补我因编程而失去的头发。
鸡枝蝉及九层峰,内邸曾随佛统衣。

你写明时何处寻,大江蕃戴帝来儿。

太古能弗岂何如,惟无百物恣蹉跎。

美人迟意识王机,马首辞来六堕愁。

测试代码

# 导入必要的库
import torch
from model import RNNModel
from poem_data_processing import process_poems
import numpy as np

# 定义开始和结束标记
start_token = 'B'
end_token = 'E'
# 模型保存的目录
model_dir = './model/'
# 诗歌数据文件路径
poems_file = './data/poems.txt'

# 学习率
lr = 0.0002

def to_word(predict, vocabs):
    """
    将预测结果转换为词汇表中的字。

    参数:
    predict: 模型的预测结果,一个概率分布。
    vocabs: 词汇表,包含所有可能的字。

    返回:
    从预测结果中随机选择的一个字。
    """
    predict = predict.numpy()[0]
    predict /= np.sum(predict)
    sample = np.random.choice(np.arange(len(predict)), p=predict)
    if sample > len(vocabs):
        return vocabs[-1]
    else:
        return vocabs[sample]

def gen_poem(begin_word):
    """
    生成诗歌。

    参数:
    begin_word: 诗歌的第一个字。

    返回:
    生成的诗歌,以字符串形式返回。
    """
    batch_size = 1
    # 处理诗歌数据,得到诗歌向量、字到索引的映射和索引到字的映射
    poems_vector, word_to_idx, idx_to_word = process_poems(poems_file)

    # 初始化模型
    model = RNNModel(len(idx_to_word), 128, num_layers=2)
    # 加载模型参数
    checkpoint = torch.load(f'{model_dir}/torch-latest.pth')
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()

    # 初始化输入序列
    x = torch.tensor([word_to_idx[start_token]], dtype=torch.long).view(1, 1)
    hidden = None

    # 生成诗歌
    with torch.no_grad():
        output, hidden = model(x, hidden)
        predict = torch.softmax(output, dim=1)
        word = begin_word or to_word(predict, idx_to_word)
        poem_ = ''

        i = 0
        while word != end_token:
            poem_ += word
            i += 1
            if i > 24:
                break
            x = torch.tensor([word_to_idx[word]], dtype=torch.long).view(1, 1)
            output, hidden = model(x, hidden)
            predict = torch.softmax(output, dim=1)
            word = to_word(predict, idx_to_word)

        return poem_

def pretty_print_poem(poem_):
    """
    格式化打印诗歌。

    参数:
    poem_: 生成的诗歌,以字符串形式输入。
    """
    poem_sentences = poem_.split('。')
    for s in poem_sentences:
        if s != '' and len(s) > 10:
            print(s + '。')

if __name__ == '__main__':
    # 用户输入第一个字
    begin_char = input('请输入第一个字 please input the first character: \n')
    print('AI作诗 generating poem...')
    # 生成诗歌
    poem = gen_poem(begin_char)
    # 打印诗歌
    pretty_print_poem(poem_=poem)

效果出乎意料地好,所有的努力都值了。是不是觉得很有趣?快来下载代码,亲自体验AI作诗的乐趣吧。

项目目录

在这里插入图片描述

  • 训练模型,运行train.py文件。
  • 想直接体验AI作诗,运行test.py文件。

如果不想从头训练,可以直接使用预训练好的模型参数,这些参数已经保存在文件中,只需下载仓库的所有代码和文件即可
仓库地址:https://gitee.com/yw18791995155/generate_poetry.git

读到这里,如果你觉得这篇文章有点意思,不妨转发点赞。如果你对AI小项目感兴趣,欢迎关注我,我会持续分享更多有趣的项目。

感谢你的阅读,愿你的代码永远没有bug,头发永远浓密!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2246026.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鸿蒙网络编程系列50-仓颉版TCP回声服务器示例

1. TCP服务端简介 TCP服务端是基于TCP协议构建的一种网络服务模式&#xff0c;它为HTTP&#xff08;超文本传输协议&#xff09;、SMTP&#xff08;简单邮件传输协议&#xff09;等高层协议的应用程序提供了可靠的底层支持。在TCP服务端中&#xff0c;服务器启动后会监听一个或…

基于 SpringBoot 的作业管理系统【附源码】

基于 SpringBoot 的作业管理系统 效果如下&#xff1a; 系统注册页面 学生管理页面 作业管理页面 作业提交页面 系统管理员主页面 研究背景 随着社会的快速发展&#xff0c;信息技术的广泛应用已经渗透到各个行业。在教育领域&#xff0c;课程作业管理是学校教学活动中的重要…

怎么只提取视频中的声音?从视频中提取纯音频技巧

在数字媒体的广泛应用中&#xff0c;提取视频中的声音已成为一项常见且重要的操作。无论是为了学习、娱乐、创作还是法律用途&#xff0c;提取声音都能为我们带来诸多便利。怎么只提取视频中的声音&#xff1f;本文将详细介绍提取声音的原因、工具、方法以及注意事项。 一、为什…

Java多态的优势和弊端

1. public class text {public static void main(String[] args) {animal dnew dog();d.eat();// dog a (dog) d;//类似强制转换//a.lookhome();/* if(d instanceof dog){dog a(dog)d;a.lookhome();}else if(d instanceof cat){cat c(cat) d;c.work();}else{System.out.print…

FPGA 14 ,硬件开发板分类详解,FPGA开发板与普通开发板烧录的区别

目录 前言 在嵌入式系统开发中&#xff0c;硬件开发板是工程师常用的工具之一。不同类型的开发板有不同的特点和用途&#xff0c;其中最常见的两大类是普通开发板和FPGA开发板。这里分享记录&#xff0c;这两类开发板的分类&#xff0c;并深入探讨它们在烧录过程中的具体区别…

冲破AI 浪潮冲击下的 迷茫与焦虑

在这个科技日新月异的时代&#xff0c;人工智能如汹涌浪潮般席卷而来&#xff0c;不断改变我们的生活。你是否对 AI 充满好奇&#xff0c;却不知它将如何改变你的工作与生活&#xff1f;又是否会在 AI 浪潮的冲击下陷入迷茫与焦虑&#xff1f;《AI 时代&#xff1a;弯道超车新思…

时序论文23|ICML24谷歌开源零样本时序大模型TimesFM

论文标题&#xff1a;A DECODER - ONLY FOUNDATION MODEL FOR TIME - SERIES FORECASTING 论文链接&#xff1a;https://arxiv.org/abs/2310.10688 论文链接&#xff1a;https://github.com/google-research/timesfm 前言 谷歌这篇时间序列大模型很早之前就在关注&#xff…

Redis的基本使用命令(GET,SET,KEYS,EXISTS,DEL,EXPIRE,TTL,TYPE)

目录 SET GET KEYS EXISTS DEL EXPIRE TTL redis中的过期策略是怎么实现的&#xff08;面试&#xff09; 上文介绍reids的安装以及基本概念&#xff0c;本章节主要介绍 Redis的基本使用命令的使用 Redis 是一个基于键值对&#xff08;KEY - VALUE&#xff09;存储的…

大疆上云api开发

目前很多公司希望使用上云api开发自己的无人机平台,但是官网资料不是特别全,下面浅谈一下本人开发过程中遇到的一系列问题。 本人使用机场为大疆机场2&#xff0c;飞机为M3TD&#xff0c;纯内网使用 部署 链接: 上云api代码. 首先从github上面拉去代码 上云api代码github. 后…

实现管易云到金蝶云星空的数据无缝集成

管易云数据集成到金蝶云星空&#xff1a;案例分享 在企业信息化系统中&#xff0c;数据的高效流动和准确对接是业务顺利运行的关键。本文将聚焦于一个具体的系统对接集成案例——通过轻易云数据集成平台实现管易云数据到金蝶云星空的无缝迁移&#xff0c;方案名称为“wk_店铺_…

Ubuntu上安装MySQL并且实现远程登录

目录 下载网络工具 查看网络连接 更新系统软件包&#xff1b; 安装mysql数据库 查看mysql数据库状态 以数字ip形式显示mysql的监听状态。&#xff08;默认监听端口是3306&#xff09; 查看安装mysql数据库时系统创建的目录信息。 根据查询到的系统用户名以及随机密码&a…

卷积神经网络各层介绍

目录 1 卷积层 2 BN层 3 激活层 3.1 ReLU&#xff08;Rectified Linear Unit&#xff09; 3.2 sigmoid 3.3 tanh&#xff08;双曲正切&#xff09; 3.4 Softmax 4 池化层 5 全连接层 6 模型例子 1 卷积层 卷积是使用一个卷积核&#xff08;滤波器&#xff09;对矩阵进…

LVS

一、 lvs简介 LVS:Linux Virtual Server &#xff0c;负载调度器&#xff0c;内核集成&#xff0c;章文嵩&#xff0c;阿里的四层 SLB(Server LoadBalance) 是基 于 LVSkeepalived 实现 LVS 官网 : http://www.linuxvirtualserver.org/ LVS 相关术语 VS: Virtual Serve…

使用 Elastic AI Assistant for Search 和 Azure OpenAI 实现从 0 到 60 的转变

作者&#xff1a;来自 Elastic Greg Crist Elasticsearch 推出了一项新功能&#xff1a;Elastic AI Assistant for Search。你可以将其视为 Elasticsearch 和 Kibana 开发人员的内置指南&#xff0c;旨在回答问题、引导你了解功能并让你的生活更轻松。在 Microsoft AI Services…

掺铒光纤激光器

一、光纤激光器的特点 实现灵活的激光光源&#xff08;窄线宽、可调谐、多波长、超短光脉冲源&#xff09;易获得高功率、高的光脉冲能量激光波长与光纤通信传输窗口相匹配采用激光器泵浦形式&#xff08;半导体激光器泵浦&#xff09;热稳定性、价格低廉、易小型化 二、放大…

AP+AC组网——STA接入

扫描 主动扫描&#xff1a;STA发送Probe Request帧&#xff0c;AP收到回复Probe Response 可以带着SSID扫描寻找指定WIFI&#xff0c;也可以带着空SSID扫描进入周围可用WLAN 被动扫描&#xff1a; 客户端通过侦听AP定期发送的Beacon帧&#xff08;100TUs&#xff0c;1TU1024…

基于 ESP-AT (v3.x)固件通过 AT+SYSMFG 指令更新证书设置

AT 固件里的证书文件通过 mfg_nvs.csv 文件管理&#xff0c;所有证书都是写入 mfg_nvs 分区。可以先查看 mfg_nvs.csv 文件的内容来确定有哪些证书文件被管理&#xff0c;如下&#xff1a; 通过 AT 指令更新证书的方式如下&#xff1a; // 获取证书类型 ATSYSMFG&#xff1f;/…

投资策略规划最优决策分析

目录 一、投资策略规划问题详细 二、存在最优投资策略&#xff1a;每年都将所有钱投入到单一投资产品中 &#xff08;一&#xff09;状态转移方程 &#xff08;二&#xff09;初始条件与最优策略 &#xff08;三&#xff09;证明最优策略总是将所有钱投入到单一投资产品中…

android 性能分析工具(03)Android Studio Profiler及常见性能图表解读

说明&#xff1a;主要解读Android Studio Profiler 和 常见性能图表。 Android Studio的Profiler工具是一套功能强大的性能分析工具集&#xff0c;它可以帮助开发者实时监控和分析应用的性能&#xff0c;包括CPU使用率、内存使用、网络活动和能耗等多个方面。以下是对Android …

(UI自动化测试)web自动化测试

web自动化测试 UI自动化测试介绍 自动化测试理论&#xff1a; 图片上的文字等等不能做测试&#xff0c;只能发现固定的bug 工具选择及介绍 浏览器驱动&#xff1a;找元素--核心&#xff1a;驱动&#xff08;操作元素&#xff09;--通过代码