使用PyTorch实现LSTM生成ai诗

news2024/11/24 22:38:14

最近学习torch的一个小demo。

什么是LSTM?

长短时记忆网络(Long Short-Term Memory,LSTM)是一种循环神经网络(RNN)的变体,旨在解决传统RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM引入了一种特殊的存储单元和门控机制,以更有效地捕捉和处理序列数据中的长期依赖关系。

通俗点说就是:LSTM是一种改进版的递归神经网络(RNN)。它的主要特点是可以记住更长时间的信息,这使得它在处理序列数据(如文本、时间序列、语音等)时非常有效。

步骤如下

数据准备

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import string
import os

# 数据加载和预处理
def load_data(filepath):
    with open(filepath, 'r', encoding='utf-8') as file:
        text = file.read()
    return text

def preprocess_text(text):
    text = text.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    return text

data_path = 'poetry.txt'  # 替换为实际的诗歌数据文件路径
text = load_data(data_path)
text = preprocess_text(text)
chars = sorted(list(set(text)))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(chars)

print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")

模型构建

定义LSTM模型:

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=2):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, hidden):
        lstm_out, hidden = self.lstm(x, hidden)
        output = self.fc(lstm_out[:, -1, :])
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),
                  weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
        return hidden

训练模型

将数据转换成LSTM需要的格式:

def prepare_data(text, seq_length):
    inputs = []
    targets = []
    for i in range(0, len(text) - seq_length, 1):
        seq_in = text[i:i + seq_length]
        seq_out = text[i + seq_length]
        inputs.append([char_to_idx[char] for char in seq_in])
        targets.append(char_to_idx[seq_out])
    return inputs, targets

seq_length = 100
inputs, targets = prepare_data(text, seq_length)

# Convert to tensors
inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)

batch_size = 64
input_size = vocab_size
hidden_size = 256
output_size = vocab_size
num_epochs = 20
learning_rate = 0.001

model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    h = model.init_hidden(batch_size)
    total_loss = 0

    for i in range(0, len(inputs), batch_size):
        x = inputs[i:i + batch_size]
        y = targets[i:i + batch_size]
        x = nn.functional.one_hot(x, num_classes=vocab_size).float()
        
        output, h = model(x, h)
        loss = criterion(output, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(inputs):.4f}")

生成

def generate_text(model, start_str, length=100):
    model.eval()
    with torch.no_grad():
        input_eval = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0)
        input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()
        h = model.init_hidden(1)
        predicted_text = start_str

        for _ in range(length):
            output, h = model(input_eval, h)
            prob = torch.softmax(output, dim=1).data
            predicted_idx = torch.multinomial(prob, num_samples=1).item()
            predicted_char = idx_to_char[predicted_idx]
            predicted_text += predicted_char

            input_eval = torch.tensor([[predicted_idx]], dtype=torch.long)
            input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()
        
        return predicted_text

start_string = "春眠不觉晓"
generated_text = generate_text(model, start_string)
print(generated_text)

运行结果如下:

运行的肯定不好,但至少出结果了。诗歌我这边只放了几句,可以自己通过外部文件放入更多素材。

整体代码直接运行即可:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import string

# 预定义一些中文诗歌数据
text = """
春眠不觉晓,处处闻啼鸟。
夜来风雨声,花落知多少。
床前明月光,疑是地上霜。
举头望明月,低头思故乡。
红豆生南国,春来发几枝。
愿君多采撷,此物最相思。
"""


# 数据预处理
def preprocess_text(text):
    text = text.replace('\n', '')
    return text


text = preprocess_text(text)
chars = sorted(list(set(text)))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(chars)

print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")


class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=2):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, hidden):
        lstm_out, hidden = self.lstm(x, hidden)
        output = self.fc(lstm_out[:, -1, :])
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),
                  weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
        return hidden


def prepare_data(text, seq_length):
    inputs = []
    targets = []
    for i in range(0, len(text) - seq_length, 1):
        seq_in = text[i:i + seq_length]
        seq_out = text[i + seq_length]
        inputs.append([char_to_idx[char] for char in seq_in])
        targets.append(char_to_idx[seq_out])
    return inputs, targets


seq_length = 10
inputs, targets = prepare_data(text, seq_length)

# Convert to tensors
inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)

batch_size = 64
input_size = vocab_size
hidden_size = 256
output_size = vocab_size
num_epochs = 50
learning_rate = 0.003

model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    h = model.init_hidden(batch_size)
    total_loss = 0

    for i in range(0, len(inputs), batch_size):
        x = inputs[i:i + batch_size]
        y = targets[i:i + batch_size]
        if x.size(0) != batch_size:
            continue
        x = nn.functional.one_hot(x, num_classes=vocab_size).float()

        output, h = model(x, h)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(inputs):.4f}")


def generate_text(model, start_str, length=100):
    model.eval()
    with torch.no_grad():
        input_eval = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0)
        input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()
        h = model.init_hidden(1)
        predicted_text = start_str

        for _ in range(length):
            output, h = model(input_eval, h)
            prob = torch.softmax(output, dim=1).data
            predicted_idx = torch.multinomial(prob, num_samples=1).item()
            predicted_char = idx_to_char[predicted_idx]
            predicted_text += predicted_char

            input_eval = torch.tensor([[predicted_idx]], dtype=torch.long)
            input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()

        return predicted_text


start_string = "春眠不觉晓"
generated_text = generate_text(model, start_string, length=100)
print(generated_text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1827876.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++高阶】C++继承学习手册:全面解析继承的各个方面

📝个人主页🌹:Eternity._ ⏩收录专栏⏪:C “ 登神长阶 ” 🤡往期回顾🤡:模板进阶 🌹🌹期待您的关注 🌹🌹 继承 📖1. 继承的概念及定义…

【云原生】docker swarm 使用详解

目录 一、前言 二、容器集群管理问题 2.1 docker集群管理问题概述 2.1.1 docker为什么需要容器部署 2.2 docker容器集群管理面临的挑战 三、docker集群部署与管理解决方案 四、Docker Swarm概述 4.1 Docker Swarm是什么 4.1.1 Docker Swarm架构图 4.1.2 Docker Swarm几…

摄影师在人工智能竞赛中与机器较量并获胜

摄影师在人工智能竞赛中与机器较量并获胜 自从生成式人工智能出现以来,由来已久的人机大战显然呈现出一边倒的态势。但是有一位摄影师,一心想证明用人眼拍摄的照片是有道理的,他向算法驱动的竞争对手发起了挑战,并取得了胜利。 迈…

数据资产治理与数据质量提升:构建完善的数据治理体系,确保数据资产的高质量与准确性

一、引言 随着信息技术的迅猛发展,数据已经成为企业和社会发展的重要资产。然而,数据资产的有效治理与数据质量的提升,是企业实现数字化转型、提升竞争力的关键。本文旨在探讨数据资产治理与数据质量提升的重要性,并提出构建完善…

Arnoldi Iteration 思考

文章目录 1. 投影平面2. Arnoldi Iteration3. python 代码 1. 投影平面 假设我们有一个向量q,我们需要关于向量q,构建一个投影平面P,使得给定任何向量v,可以通过公式 p P v pPv pPv,快速得到向量v在投影平面P上的投影向量p. 计算向量内积,…

Scala运算符及流程控制

Scala运算符及流程控制 文章目录 Scala运算符及流程控制写在前面运算符算数运算符关系运算符赋值运算符逻辑运算符位运算符运算符本质 流程控制分支控制单分支双分支多分支 循环控制for循环while循环循环中断嵌套循环 写在前面 操作系统:Windows10JDK版本&#xff…

redis源码编译安装

源码下载地址http://download.redis.io/releases/ 1 环境准备 安装编译环境 sudo yum install gcc -y gcc -v 查看版本 sudo yum -y install centos-release-scl sudo yum -y install devtoolset-10-gcc devtoolset-10-gcc-c devtoolset-10-binutils scl enable devtool…

马斯克在2024年特斯拉股东大会上的年度发言

马斯克表示,“如果市盈率是20或25倍,那就意味着,光是Optimus就能带来20万亿美元的市值。而自动驾驶汽车的市值可能在5到10万亿美元之间。因此,特斯拉的市值达到当今市值最高公司的10倍,是可以想象的,也是有…

【MySQL基础随缘更系列】DML语句

文章目录 一、表记录操作-上1.1、DML概述1.2、插入记录 二、表记录操作-下2.1、更新记录2.2、删除记录 🌈你好呀!我是 山顶风景独好 🎈欢迎踏入我的博客世界,能与您在此邂逅,真是缘分使然!😊 &a…

创新案例 | 3个关键策略:乳制品品牌认养一头牛如何通过私域流量运营获取1400万会员

探索认养一头牛如何运用创新的私域流量运营策略,在竞争激烈的乳制品市场中脱颖而出,实现会员数量的飞速增长至1400万。本文深入分析了其数据驱动的广告投放、高效的会员运营体系和创新的用户互动机制,为企业提供提升用户粘性和品牌忠诚度的宝…

第19章 大数据架构设计理论与实践

19.1 传统数据处理系统存在的问题 海量数据的,数据库过载,增加消息队列、甚至数据分区、读写分离、以及备份以及传统架构的性能的压榨式提升,都没有太明显的效果,帮助处理海量数据的新技术和新架构开发被提上日程。 19.2 大数据处…

设计模式——观察者模式(发布/订阅模式)

观察者模式(发布/订阅模式) 是一种行为模式,允许你定义一种订阅机制,可在对象事件发生时通知多个“观察”该对象的其他对象 观察者模式定义了一种一对多的依赖关系,让多个观察者对象同时监听某一主题对象。这个主题对象在状态发生变化时&am…

springboot汽车配件管理系统(源码+sql+论文报告)

绪论 1.1 研究意义和背景 随着我国经济的持续发展,汽车已经逐步进入了家庭。汽车行业的发展,也带动了汽车配件行业的快速发展。 汽车配件行业的迅猛发展, 使得汽配行业的竞争越来越激烈。如何在激烈的竞争中取胜,是每家汽车零部…

【MYSQL】MYSQL操作库

1.数据库字符编码集/数据库校验集 当我们在数据库中保存数据时,需要存和取时候编码一致,比方说你用汉语保存的数据,当你读的时候为了避免乱码问题,也必须用汉语读,这就叫做数据库字符编码集一致。 当我们进行查找&…

基于单片机的太阳能无线 LED 灯设计

摘 要 : 文章设计一款太阳能 LED 灯 , 经过太阳能给锂电池充电 , 利用 51 单片机通过检测电路对整个系统施行管理和监控, 可以使用手机和 WIFI 作为通信工具 , 利用光敏电阻检测光照 , 进而控制灯的亮…

【制作100个unity游戏之29】使用unity复刻经典游戏《愤怒的小鸟》(完结,附带项目源码)

最终效果 文章目录 最终效果前言素材下载简单搭建环境控制小鸟生成弹簧 限制小鸟的控制范围弹簧线的显示隐藏飞行新增木头木头销毁不同血量的木头状态配置更多物品爆炸效果创建敌人的小猪创建多个小鸟循环游戏结束相机跟随加分特效不同定义技能的鸟加速鸟回旋鸟爆炸鸟效果 轨迹…

快手爬票概述

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 无论是出差还是旅行,都无法离开交通工具的支持。现如今随着科技水平的提高,高铁与动车成为人们喜爱的交通工具。如果想要知道…

【C#】图形图像编程

实验目标和要求: 掌握C#图形绘制基本概念;掌握C#字体处理;能进行C#图形图像综合设计。 运行效果如下所示: 1.功能说明与核心代码 使用panel为画板,完成以下设计内容: 使用pen绘制基础图形;使…

浅谈golang字符编码

1、 Golang 字符编码 Golang 的代码是由 Unicode 字符组成的,并由 Unicode 编码规范中的 UTF-8 编码格式进行编码并存储。 Unicode 是编码字符集,囊括了当今世界使用的全部语言和符号的字符。有三种编码形式:UTF-8,UTF-16&#…

【LeetCode215】数组中的第K个最大元素

题目地址 1. 基本思路 用一个基准数e将集合S分解为不包含e在内的两个小集合 S 1 S_{1} S1​和 S 2 S_{2} S2​,其中 S 1 S_{1} S1​的任何元素均大于等于e, S 2 S_{2} S2​的任何元素均小于e,记 ∣ S ∣ |S| ∣S∣代表集合S元素的个数&…