深度学习之生成唐诗案例(Pytorch版)

news2024/11/16 11:49:30

主要思路:

对于唐诗生成来说,我们定义一个"S" 和 "E"作为开始和结束。

 示例的唐诗大概有40000多首,

首先数据预处理,将唐诗加载到内存,生成对应的word2idx、idx2word、以及唐诗按顺序的字序列。

Dataset_Dataloader.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


def deal_tangshi():
    with open("tangshi.txt", "r", encoding="utf-8") as fr:
        lines = fr.read().strip().split("\n")

    tangshis = []
    for line in lines:
        splits = line.split(":")
        if len(splits) != 2:
            continue
        tangshis.append("S" + splits[1] + "E")

    word2idx = {"S": 0, "E": 1}
    word2idx_count = 2

    tangshi_ids = []

    for tangshi in tangshis:
        for word in tangshi:
            if word not in word2idx:
                word2idx[word] = word2idx_count
                word2idx_count += 1

    idx2word = {idx: w for w, idx in word2idx.items()}

    for tangshi in tangshis:
        tangshi_ids.extend([word2idx[w] for w in tangshi])

    return word2idx, idx2word, tangshis, word2idx_count, tangshi_ids


word2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()


class TangShiDataset(Dataset):
    def __init__(self, tangshi_ids, num_chars):
        # 语料数据
        self.tangshi_ids = tangshi_ids
        # 语料长度
        self.num_chars = num_chars
        # 词的数量
        self.word_count = len(self.tangshi_ids)
        # 句子数量
        self.number = self.word_count // self.num_chars

    def __len__(self):
        return self.number

    def __getitem__(self, idx):
        # 修正索引值到: [0, self.word_count - 1]
        start = min(max(idx, 0), self.word_count - self.num_chars - 2)

        x = self.tangshi_ids[start: start + self.num_chars]
        y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]

        return torch.tensor(x), torch.tensor(y)


def __test_Dataset():
    dataset = TangShiDataset(tangshi_ids, 8)
    x, y = dataset[0]

    print(x, y)


if __name__ == '__main__':
    # deal_tangshi()
    __test_Dataset()
TangShiModel.py:唐诗的模型
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as F


class TangShiRNN(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # 初始化词嵌入层
        self.ebd = nn.Embedding(vocab_size, 128)
        # 循环网络层
        self.rnn = nn.RNN(128, 128, 1)
        # 输出层
        self.out = nn.Linear(128, vocab_size)

    def forward(self, inputs, hidden):

        embed = self.ebd(inputs)

        # 正则化层
        embed = F.dropout(embed, p=0.2)

        output, hidden = self.rnn(embed.transpose(0, 1), hidden)

        # 正则化层
        embed = F.dropout(output, p=0.2)

        output = self.out(output.squeeze())

        return output, hidden

    def init_hidden(self):
        return torch.zeros(1, 64, 128)

 main.py:

import time

import torch

from Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train():
    dataset = TangShiDataset(tangshi_ids, 128)
    epochs = 100
    model = TangShiRNN(word2idx_count).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for idx in range(epochs):
        dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
        start_time = time.time()
        total_loss = 0
        total_num = 0
        total_correct = 0
        total_correct_num = 0
        hidden = model.init_hidden()

        for x, y in tqdm(dataloader):
            x = x.to(device)
            y = y.to(device)
            # 隐藏状态
            hidden = model.init_hidden()
            hidden = hidden.to(device)
            # 模型计算
            output, hidden = model(x, hidden)
            # print(output.shape)
            # print(y.shape)
            # 计算损失
            loss = criterion(output.permute(1, 2, 0), y)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

            total_loss += loss.sum().item()
            total_num += len(y)
            total_correct_num += y.shape[0] * y.shape[1]
            # print(output.shape)
            total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()

        print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %
              (idx + 1, total_loss / total_num, total_correct / total_correct_num, time.time() - start_time))

        torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")


if __name__ == '__main__':
    train()

predict.py:

import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def predict():
    model = TangShiRNN(word2idx_count)
    model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))

    model.eval()

    hidden = torch.zeros(1, 1, 128)

    start_word = input("输入第一个字:")

    flag = None

    tangshi_strs = []

    while True:
        if not flag:
            outputs, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)
            tangshi_strs.append("S")
            flag = True
        else:
            tangshi_strs.append(start_word)
            outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)
            top_i = torch.argmax(outputs, dim=-1)

            if top_i.item() == word2idx["E"]:
                break

            print(top_i)

            start_word = idx2word[top_i.item()]
        print(tangshi_strs)


if __name__ == '__main__':
    predict()

完整代码如下:

https://github.com/STZZ-1992/tangshi-generator.giticon-default.png?t=N7T8https://github.com/STZZ-1992/tangshi-generator.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1236309.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

注册中心CAP架构剖析

Nacos 支持 AP 或 CP AP Nacos 通过临时节点实现 AP 架构,将服务列表放在内存中; CP Nacos 通过持久化节点实现 CP 架构,将服务列表放在文件中,并同步到内存,通过 Raft 协议算法实现; 通过配置 epheme…

中科创达:所有产品都可以用生成式AI重做一遍

对于制造企业的数字化转型来说,生成式AI究竟具备怎样的意义和价值? 在与亚马逊云科技的合作中,中科创达对此有着深刻的领会和感悟。 生成式AI助力制造业数字化转型 “科技是第一生产力”,对于这句脍炙人口的名言,制造企…

x shell 用作串口调试助手

x shell 用作串口调试助手 Xshell 介绍 是一个强大的安全终端模拟软件,它支持SSH1, SSH2, 以及Microsoft Windows 平台的TELNET 协议。Xshell 通过互联网到远程主机的安全连接以及它创新性的设计和特色帮助用户在复杂的网络环境中享受他们的工作。 Xshell可以在Wi…

PDF文件无密码,如何解密?

PDF文件有两种密码,一个打开密码、一个限制编辑密码,因为PDF文件设置了密码,那么打开、编辑PDF文件就会受到限制。想要解密,我们需要输入正确的密码,但是有时候我们可能会出现忘记密码的情况,或者网上下载P…

tomcat (SCI)ServletContainerInitializer 的加载原理

问题:使用WebScoket的时候发现通过ServerEndpoint方式注册上去的url无法访问,报错404 经过排查发现在WsServerContainer这个类中的addEndpoint方法一直没有触发ServerEndpoint注解的扫描 通过该方法来源于StandardContext.startInternal()方法的调用如下…

基于单片机仓库温湿度监测报警系统仿真设计

**单片机设计介绍,基于单片机仓库温湿度监测报警系统仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的仓库温湿度监测报警系统可以被设计成能够实时监测仓库内的温度和湿度,并根据预设…

steamui.dll找不到指定模块,要怎么修复steamui.dll文件

当我们使用Steam进行游戏时,有时可能会面对一些令人无奈的技术问题。一种常见的问题是“找不到指定模块steamui.dll”,这可能是由于缺少文件、文件损坏或软件冲突等原因导致。但别担心,这篇文章将提供几种解决此问题的方法,并针对…

设计模式总结-笔记

一个目标:管理变化,提供复用! 两种手段:分解vs.抽象 八大原则: 依赖倒置原则(DIP) 开放封闭原则(OCP) 单一职责原则(SRP) Liskov替换原则&a…

【Qt一坑】qt编译出现“常量中有换行符”

在qt编译过程中出现“常量中有换行符”,原因有以下几点(qt版本5.14.2): 1.中文编码格式问题,将UTF-8编码格式改成 UTF-8 BOM。 或者使用QtCreator 进行如下设置(找到Qt的左边列表里的项目,下的…

git -1

1.创建第一个仓库并配置local用户信息 git config git config --global 对当前用户所有仓库有效 git config --system 对系统所有登录的用户有效 git config --local 只对某个仓库有效 git config --list 显示配置 git config --list --global 所有仓库 git config --list…

Vue3鼠标拖拽生成区域块并选中元素

Vue3鼠标拖拽生成区域块并选中元素&#xff0c;选中的元素则背景高亮(或者其它逻辑)。 <script setup> import { ref } from vue// 区域ref const regionRef ref(null)// 内容ref const itemRefs ref(null)// 是否开启绘画区域 const enable ref(false)// 鼠标开始位置…

第十二章 pytorch中使用tensorboard进行可视化(工具)

PyTorch 从 1.2.0 版本开始&#xff0c;正式自带内置的 Tensorboard 支持了&#xff0c;我们可以不再依赖第三方工具来进行可视化。 tensorboard官方教程地址&#xff1a;https://github.com/tensorflow/tensorboard/blob/master/README.md 1、tensorboard 下载 step 1 此次…

时间复杂度和运算

时间复杂度 在算法和数据结构中&#xff0c;有许多时间复杂度比 O(1) 更差的情况。以下是一些常见的时间复杂度&#xff0c;按照从最优到最差的顺序排列&#xff1a; O(1)&#xff1a; 常数时间复杂度&#xff0c;操作的运行时间与输入规模无关&#xff0c;是最理想的情况。 O…

源码安装Apache

一、下载Apache,源码安装Apache #下载 [rootlocalhost opt]# wget -c https://mirrors.aliyun.com/apache/httpd/httpd-2.4.58.tar.gz [rootlocalhost opt]# ls httpd-2.4.58.tar.gz [rootlocalhost opt]# tar -xf httpd-2.4.58.tar.gz [rootlocalhost opt]# ls httpd-2.4.58…

移远通信推出六款新型天线,为物联网客户带来更丰富的产品选择

近日&#xff0c;移远通信重磅推出六款新型天线&#xff0c;覆盖5G、非地面网络&#xff08;NTN&#xff09;等多种新技术&#xff0c;将为物联网终端等产品带来全新功能和更强大的连接性能。 移远通信COO张栋表示&#xff1a;“当前&#xff0c;物联网应用除了需要高性能的天线…

【Docker】从零开始:4.为什么Docker会比VM虚拟机快

【Docker】从零开始&#xff1a;4.为什么Docker会比VM虚拟机快 docker有着比虚拟机更少的抽象层docker利用的是宿主机的内核,而不需要加载操作系统OS内核 docker有着比虚拟机更少的抽象层 由于docker不需要Hypervisor(虚拟机)实现硬件资源虚拟化,运行在docker容器上的程序直接…

【C++】C++11(1)

文章目录 一、C11简介二、统一的列表初始化1.&#xff5b;&#xff5d;初始化2.std::initializer_list 三、声明1.auto2.decltype3.nullptr 四、STL中一些变化五、右值引用和移动语义1.左值引用和右值引用2.左值引用与右值引用比较3.右值引用使用场景和意义4.右值引用引用左值及…

如何化解从数据到数据资源入表的难题

继数据成为生产要素后&#xff0c;各种跟数据相关的概念就出来了&#xff0c;首先我们要弄明白有关数据的几个高频词汇。数据&#xff1a;指“原始数据”&#xff0c;即记录事实的结果&#xff0c;用来描述事实的未经加工的素材。数据资源&#xff1a;指加工后具有经济价值的数…

使用Pytorch实现linear_regression

使用Pytorch实现线性回归 # import necessary packages import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt# Set necessary Hyper-parameters. input_size 1 output_size 1 num_epochs 60 learning_rate 0.001# Define a Toy datas…

Unity团结引擎使用总结

团结引擎创世版以 Unity 2022 LTS 为研发基础&#xff0c;与 Unity 2022 LTS 兼容、UI 也基本保持一致&#xff0c;使 Unity 开发者可以无缝转换到团结引擎。融入了团结引擎独有功能和优化&#xff0c;未来会加入更多为中国开发者量身定制的功能和优化。 目前正在内测&#xf…