古诗生成-pytorch

news2025/1/11 8:06:50

本文为RNN做古诗生成的一个小demo,只要是为了完成课上的作业(由于训练比较慢,所以周期仅设置为3,大一点性能可能会更好),如有需要可以在这基础之上进行加工,数据集没办法上传,如有需要,可以私信我。

LSTM:

如上图所示LSTM神经元存在两个状态向量:h(t)和c(t)(可将h(t)视为短期状态,c(t)视为长期状态) 首先,将当前输入向量x(t)和先前的短期状态h(t-1)馈入四个不同的全连接层(FC)。它们都有不同的目的:

主要层是输出g(t)的层:它通常的作用是分析当前输入x(t)和先前(短期)状态 h(t-1),得到本时间步的信息。

遗忘门(由f(t)控制):控制长期状态的哪些部分应当被删除。

输入门(由i(t)控制):控制应将g(t)的哪些部分添加到长期状态。

输出门(由o(t)控制):控制应在此时间步长读取长期状态的哪些部分并输出 到h(t)和y(t)。

如图1,LSTM神经元运用了三个sigmoid激活函数和一个tanh激活函数,

Tanh 作用在于帮助调节流经网络的值,使得数值始终限制在 -1 和 1 之间。
Sigmoid 激活函数与 tanh 函数类似,不同之处在于 sigmoid 是把值压缩到0~1 这样的设置有助于更新或忘记信息,可将其理解为比例(任何数乘以 0 都得 0,这部分信息就会剔除掉;同样的,任何数乘以 1 都得到它本身,这部分信息就会完美地保存下来)因记忆能力有限,记住重要的,忘记不重要的。
例子:以输入门为例,首先输入x(t)和先前(短期)状态 h(t-1),得到本时间步的信息向量g(t) = (g1(t),g2(t),g3(t)……gn(t))(其中n个神经元的个数,g1(t)取值范围为(-1,1)),然后与向量i(t)=(i1(t),i2(t),i3(t)……in(t))(ii(t)取值范围为(0,1))对应元素相乘,得到向量(g1(t)*i1(t), g2(t)*i2(t)……gn(t)*in(t)),即本时间步有用信息,然后把他加上长期记忆c(t-1)中进行保存。

LSM关键的思想是网络可以学习长期状态下存储的内容、丢弃的内容以及从中读取的内容。当长期状态c(t-1)从左到右遍历网络时,可以看到它首先经过一个遗 忘门,丢掉了一些记忆,然后通过加法操作添加了一些新的记忆(由输入门选择的记忆)。结果c(t)直接送出来,无须任何进一步的转换。因此,在每个时间步长中,都会 丢掉一些记忆,并添加一些记忆。此外,在加法运算之后,长期状态被复制并通过tanh函数传输,然后结果被输出门滤波。这将产生短期状态h(t)(等于该时间步长的单元输出 y(t))。

 原理:

本文使用LSTM生成古诗,那么RNN是怎么用作我们的文本生成呢?话不多说,其实用RNN来生成的思想很简单, 就是将前一个字进行词嵌入,后一个字作为标签,将这个组合输入到RNN的网络里面等待训练拟合之后,再用一个引导词,训练出它的预测结果,再用其预测结果,来训练下一个词,循环往复,从而实现RNN生成文本的效果.
 

main.py

import numpy as np
import collections
import torch
from torch.autograd import Variable
import torch.optim as optim

import rnn

start_token = 'G'
end_token = 'E'
batch_size = 64

def process_poems1(file_name):
    """

    :param file_name:
    :return: poems_vector  have two dimmention ,first is the poem, the second is the word_index
    e.g. [[1,2,3,4,5,6,7,8,9,10],[9,6,3,8,5,2,7,4,1]]

    """
    poems = []
    i = 1
    with open(file_name, "r", encoding='utf-8', ) as f:
        for line in f.readlines():
            try:
                i = i+1

                title, content = line.strip().split(':')
                # content = content.replace(' ', '').replace(',','').replace('。','')
                content = content.replace(' ', '')
                if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
                                start_token in content or end_token in content:
                    continue
                if len(content) < 5 or len(content) > 80:
                    continue
                content = start_token + content + end_token
                poems.append(content)
            except ValueError as e:
                print(line)
                print(i)
                print("error")
                pass
    # 按诗的字数排序
    poems = sorted(poems, key=lambda line: len(line))
    # print(poems)
    # 统计每个字出现次数
    all_words = []
    j = 0
    for poem in poems:
        all_words += [word for word in poem]  # 数据连接
    counter = collections.Counter(all_words)  # 统计词和词频。

    count_pairs = sorted(counter.items(), key=lambda x: -x[1])  # d.items() 以列表的形式返回可遍历的元组数组 逆序排序
    words, _ = zip(*count_pairs) # zip(*) 可理解为解压,返回二维矩阵式

    words = words[:len(words)] + (' ',) #(‘ ’,) 为一个元素的元祖

    word_int_map = dict(zip(words, range(len(words))))
    poems_vector = [list(map(word_int_map.get, poem)) for poem in poems] # 第一位为一个函数,后一位为一个迭代器
    return poems_vector, word_int_map, words # 诗句的向量表示,单词映射表,单词表 

def process_poems2(file_name):
    """
    :param file_name:
    :return: poems_vector  have tow dimmention ,first is the poem, the second is the word_index
    e.g. [[1,2,3,4,5,6,7,8,9,10],[9,6,3,8,5,2,7,4,1]]

    """
    poems = []
    with open(file_name, "r", encoding='utf-8', ) as f:
        # content = ''
        for line in f.readlines():
            try:
                line = line.strip()
                if line:
                    content = line.replace(' '' ', '').replace(',','').replace('。','')
                    if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
                                    start_token in content or end_token in content:
                        continue
                    if len(content) < 5 or len(content) > 80:
                        continue
                    # print(content)
                    content = start_token + content + end_token
                    poems.append(content)
                    # content = ''
            except ValueError as e:
                # print("error")
                pass
    # 按诗的字数排序
    poems = sorted(poems, key=lambda line: len(line))
    # print(poems)
    # 统计每个字出现次数
    all_words = []
    for poem in poems:
        all_words += [word for word in poem]
    counter = collections.Counter(all_words)  # 统计词和词频。
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])  # 排序
    words, _ = zip(*count_pairs)
    words = words[:len(words)] + (' ',)
    word_int_map = dict(zip(words, range(len(words))))
    poems_vector = [list(map(word_int_map.get, poem)) for poem in poems]
    return poems_vector, word_int_map, words

def generate_batch(batch_size, poems_vec, word_to_int):
    #生成训练数据

    n_chunk = len(poems_vec) // batch_size  #34813/100 = 348 古诗的向量表示
    x_batches = []
    y_batches = []
    for i in range(n_chunk):
        start_index = i * batch_size
        end_index = start_index + batch_size
        x_data = poems_vec[start_index:end_index]
        y_data = []
        for row in x_data:
            y  = row[1:]
            y.append(row[-1])
            y_data.append(y)
        """
        x_data             y_data
        [6,2,4,6,9]       [2,4,6,9,9]  文本生成,所以用后面一位数据做label
        [1,4,2,8,5]       [4,2,8,5,5]
        """
        # print(x_data[0])
        # print(y_data[0])
        # exit(0)
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches


def run_training():
    # 处理数据集
    # poems_vector, word_to_int, vocabularies = process_poems2('./tangshi.txt')
    poems_vector, word_to_int, vocabularies = process_poems1('./poems.txt') 
    # 生成batch
    print("finish  loadding data")
    BATCH_SIZE = 100

    torch.manual_seed(5)
    word_embedding = rnn.word_embedding( vocab_length= len(word_to_int) + 1 , embedding_dim= 100) #6123 x 100
    #print(word_embedding.shape)

    rnn_model = rnn.RNN_model(batch_sz = BATCH_SIZE,vocab_len = len(word_to_int) + 1 ,word_embedding = word_embedding ,embedding_dim= 100, lstm_hidden_dim=128)
    # optimizer = optim.Adam(rnn_model.parameters(), lr= 0.001)
    optimizer=optim.RMSprop(rnn_model.parameters(), lr=0.01)

    loss_fun = torch.nn.NLLLoss()
    # rnn_model.load_state_dict(torch.load('./poem_generator_rnn'))  # if you have already trained your model you can load it by this line.

    for epoch in range(3):
        batches_inputs, batches_outputs = generate_batch(BATCH_SIZE, poems_vector, word_to_int) #生成训练数据 由batch组成的数组 348
        n_chunk = len(batches_inputs)
        for batch in range(n_chunk):
            batch_x = batches_inputs[batch]
            batch_y = batches_outputs[batch] # (batch , time_step)

            loss = 0
            for index in range(BATCH_SIZE):  #batch_size = 100
                x = np.array(batch_x[index], dtype = np.int64)
                y = np.array(batch_y[index], dtype = np.int64)
        
                x = Variable(torch.from_numpy(np.expand_dims(x,axis=1))) #将数组转换成张量 np.expand_dims扩展数据的形状 x.sahpe = 7x1, 
                y = Variable(torch.from_numpy(y ))
                pre = rnn_model(x) # 7 x 6125
                loss += loss_fun(pre , y)
                if index == 0:
                    _, pre = torch.max(pre, dim=1)# pre为张量,tolist转换成列表 
                    print('prediction', pre.data.tolist()) # the following  three line can print the output and the prediction
                    print('b_y       ', y.data.tolist())   # And you need to take a screenshot and then past is to your homework paper.
                    print('*' * 30)
            loss  = loss  / BATCH_SIZE
            print("epoch  ",epoch,'batch number',batch,"loss is: ", loss.data.tolist())
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(rnn_model.parameters(), 1)    # 梯度裁剪 可以预防梯度爆炸,参数的平方和
            optimizer.step()  #训练参数

            if batch % 20 ==0:
                torch.save(rnn_model.state_dict(), './poem_generator_rnn')
                print("finish  save model")



def to_word(predict, vocabs):  # 预测的结果转化成汉字
    sample = np.argmax(predict)

    if sample >= len(vocabs):
        sample = len(vocabs) - 1
    return vocabs[sample]


def pretty_print_poem(poem):  # 令打印的结果更工整
    shige=[]
    for w in poem:
        if w == start_token or w == end_token:
            break
        shige.append(w)
    poem_sentences = poem.split('。')
    for s in poem_sentences:
        if s != '' and len(s) > 2:
        #     print(s + '。')
            print(s + '。')


def gen_poem(begin_word):
    # poems_vector, word_int_map, vocabularies = process_poems2('./tangshi.txt')  #  use the other dataset to train the network
    poems_vector, word_int_map, vocabularies = process_poems1('./poems.txt')
    word_embedding = rnn.word_embedding(vocab_length=len(word_int_map) + 1, embedding_dim=100)
    rnn_model = rnn.RNN_model(batch_sz=64, vocab_len=len(word_int_map) + 1, word_embedding=word_embedding,
                                   embedding_dim=100, lstm_hidden_dim=128)

    rnn_model.load_state_dict(torch.load('./poem_generator_rnn'))
    # 指定开始的字

    poem = begin_word
    word = begin_word
    while word != end_token:
        input = np.array([word_int_map[w] for w in poem],dtype= np.int64)
        input = Variable(torch.from_numpy(input))
        output = rnn_model(input, is_test=True)
        word = to_word(output.data.tolist(), vocabularies)
        poem += word
        if len(poem) > 30:
            break
    return poem



#run_training()  # 如果不是训练阶段 ,请注销这一行 。 网络训练时间很长。


pretty_print_poem(gen_poem("日"))
pretty_print_poem(gen_poem("红"))
pretty_print_poem(gen_poem("山"))
pretty_print_poem(gen_poem("夜"))
pretty_print_poem(gen_poem("湖"))
pretty_print_poem(gen_poem("湖"))
pretty_print_poem(gen_poem("湖"))
pretty_print_poem(gen_poem("君"))

rnn.py

import torch.nn as nn
import torch
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np

def weights_init(m):
    classname = m.__class__.__name__  #   obtain the class name
    if classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size()) #6123 x 128
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
        print("inital  linear weight ")


class word_embedding(nn.Module):
    def __init__(self,vocab_length , embedding_dim):
        super(word_embedding, self).__init__()
        w_embeding_random_intial = np.random.uniform(-1,1,size=(vocab_length ,embedding_dim)) #生成服从均匀分布的随机数
        self.word_embedding = nn.Embedding(vocab_length,embedding_dim) #创建一个embedding层
        self.word_embedding.weight.data.copy_(torch.from_numpy(w_embeding_random_intial))
    def forward(self,input_sentence):
        """
        :param input_sentence:  a tensor ,contain several word index.
        :return: a tensor ,contain word embedding tensor
        """
        sen_embed = self.word_embedding(input_sentence)
        return sen_embed


class RNN_model(nn.Module):
    def __init__(self, batch_sz ,vocab_len ,word_embedding,embedding_dim, lstm_hidden_dim):
        super(RNN_model,self).__init__()

        self.word_embedding_lookup = word_embedding
        self.batch_size = batch_sz
        self.vocab_length = vocab_len
        self.word_embedding_dim = embedding_dim
        self.lstm_dim = lstm_hidden_dim
        #########################################
        # here you need to define the "self.rnn_lstm"  the input size is "embedding_dim" and the output size is "lstm_hidden_dim"
        # the lstm should have two layers, and the  input and output tensors are provided as (batch, seq, feature)
        # ???

        self.rnn_lstm = nn.LSTM(input_size=embedding_dim,hidden_size=lstm_hidden_dim, num_layers=2,batch_first=True)

        ##########################################
        self.fc = nn.Linear(lstm_hidden_dim, vocab_len )
        self.apply(weights_init) # call the weights initial function.
        self.softmax = nn.LogSoftmax() # the activation function.
        # self.tanh = nn.Tanh()
    def forward(self,sentence,is_test = False):
        batch_input = self.word_embedding_lookup(sentence).view(1,-1,self.word_embedding_dim)  # sentence=[7,1] [7x1x100] batch_input=[1,7,100])
        # print(batch_input.size()) # print the size of the input
        ################################################
        # here you need to put the "batch_input"  input the self.lstm which is defined before.
        # the hidden output should be named as output, the initial hidden state and cell state set to zero.
        # ???
        #print(batch_input.shape)
        output,_ = self.rnn_lstm(batch_input) # 1x7x128
        ################################################
        out = output.contiguous().view(-1,self.lstm_dim)  #1x128
        #print(out.shape)
        out =  F.relu(self.fc(out))
        out = self.softmax(out)

        if is_test:
            prediction = out[ -1, : ].view(1,-1) #[1,6125]
            #prediction = torch.max(out,0)
            output = prediction
        else:
           output = out

        # print(out)
        return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/590629.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FreeRTOS_从底层学习实时操作系统

目录 1. 裸机系统和多任务系统 2. 任务的定任务切换的实现 2.1 什么是任务&#xff1f; 2.2 调度器 2.3 临界段 3. 空闲任务和阻塞延迟 4. 时间片 1. 裸机系统和多任务系统 裸机系统&#xff1a; 裸机系统分为轮询系统和前后台系统&#xff1b;&#xff08;51单片机就属…

八大排序:直接插入排序、希尔排序、选择排序、堆排序、冒泡排序、快速排序、归并排序、计数排序

文章目录 排序概念常见的排序算法常见排序算法的实现直接插入排序希尔排序选择排序堆排序冒泡排序快速排序递归实现Hoare版本挖坑法前后指针法 非递归实现Hoare版本挖坑法前后指针法 快速排序俩个优化 归并排序递归实现非递归实现外排序 计数排序 常见排序算法的性能分析 排序概…

【已完美解决】scons问题求助:如何设置编译输出目录搞清楚后,有些编译输出的obj文件却在源码目录,而不是设置的输出目录。

【已完美解决】scons问题求助&#xff1a;如何设置编译输出目录搞清楚后&#xff0c;有些编译输出的obj文件却在源码目录&#xff0c;而不是设置的输出目录。 文章目录 1 前置背景2 我的疑问3 一手点拨4 问题解决 1 前置背景 最近在基于目前已有的rt-thread构建框架&#xff0…

【Spring源码解读一】IoC容器之AnnotationConfigApplicationContext

根据AnnotationConfigApplicationContext类去阅读其将Bean对象交给IoC容器管理的过程。以下这三个代码块是将配置类注册进IoC容器的例子。下面是关于这个类的继承与实现的类图关系树。 public class Test {public static void main(String[] args) {// 配置类注册进IoC容器App…

解决Ubuntu16中安装opencv后找不到vtk库的问题

最近一个项目中要用到OpenCV的VTK库&#xff0c;但引入头文件#include <opencv2/viz.hpp>时却说找不到这个库&#xff0c;网上搜了下说在编译opencv源码的时候&#xff0c;需要加上编译VTK库的选项&#xff0c;于是重新下载、编译、安装了源码&#xff0c;在cmake时加上了…

最流行的AI绘图工具Midjourney,你不得不知道的使用技巧

​关注文章下方公众号&#xff0c;可免费获取AIGC最新学习资料 本文字数&#xff1a;1500&#xff0c;阅读时长大约&#xff1a;10分钟 Midjourney成为了最受欢迎的生成式AI工具之一。它的使用很简单。输入一些文本&#xff0c;Midjourney背后的大脑&#xff08;或计算机&#…

Linux 权限

目录 一、 从问题开始 问题一: 什么叫shell? 问题二: 为什么不能直接使用kernel呢? 问题三: shell 与bash 有什么不同吗? 二、 Linux权限 0x01 Linux用户 0x02 切换用户命令 0x03 sudo命令 0x04 权限的相关概念 0x05 chmod 0x06 chown 0x07 chgrp 0x08 文件权…

重磅!软著申请不需要邮寄纸质材料啦,附软著申请流程。

重磅&#xff01;软著申请不需要邮寄纸质材料啦&#xff0c;附软著申请流程。 最新消息申请流程一&#xff0c;准备申请材料二&#xff0c;申请人填写申请表三&#xff0c;提交申请材料四&#xff0c;补正五&#xff0c;审查六&#xff0c;发布公告七&#xff0c;接受异议八&am…

力扣---二叉树OJ题(多种题型二叉树)

文章目录 前言&#x1f31f;一、剑指 Offer 55 - I. 二叉树的深度&#x1f30f;1.1 链接&#xff1a;&#x1f30f;1.2 代码一&#xff1a;&#x1f30f;1.3 代码二&#xff1a;&#x1f30f;1.4 流程图&#xff1a; &#x1f31f;二、100. 相同的树&#x1f30f;2.1 链接&…

超强实用!利用xfsdump和xfsrestore打造无懈可击的数据备份与恢复策略

前言 上次我们分析了EXT文件系统的恢复方式&#xff0c;借助于extundelete工具仅可以恢复EXT类型的文件&#xff0c;但无法恢复CentOS 7系统&#xff0c;因为centos7默认采用xfs类型的文件。 xfs文件系统恢复工具有以下几种&#xff1a; xfsprogs&#xff1a;xfs文件系统扩展…

HTB MonitorsTwo

MonitorsTwo HTB MonitorsTwo 老规矩信息收集了&#xff1a; NMAP信息收集 ┌──(kali㉿kali)-[~/桌面] └─$ sudo nmap --min-rate 1000 10.10.11.211 Starting Nmap 7.93 ( https://nmap.org ) at 2023-05-19 09:18 CST Nmap scan report for 10.10.11.211 Host is up…

Python入门(十六)函数(四)

函数&#xff08;四&#xff09; 1.传递列表1.1 在函数中修改列表 2.传递任意数量的实参2.1 结合使用位置实参和任意数量实参2.2 使用任意数量的关键字实参 作者&#xff1a;Xiou 1.传递列表 我们经常会发现&#xff0c;向函数传递列表很有用&#xff0c;其中包含的可能是名字…

设计模式-模板方法模式

模板方法模式 问题背景解决方案&#xff1a;模板方法模式基本介绍解决问题代码示例运行结果 钩子方法注意事项和细节 问题背景 豆浆的制作&#xff1a; 1&#xff09;制作豆浆的流程&#xff1a;选材—>添加配料—>浸泡—>放到豆浆机打碎 2&#xff09;通过添加不同…

高可用性和双机热备浅析

在用户眼里&#xff0c;业务需要永远正常对外提供服务&#xff0c;这就要求应用系统的高可用&#xff08;High availability&#xff0c;即 HA&#xff09;。高可用主要是针对架构而言&#xff0c;第一步一般会采用分层的思想将一个庞大的应用系统拆分成应用层、中间件、数据存…

SpringBoot+MyBatis-plus实现CRUD (踩坑总结!!!)

一、创建项目&#xff0c;引入相应的依赖 (项目源码在文末) &#xff08;不要选Module !!!!&#xff09; <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.4.2</version&g…

linux常见的二十多个指令

目录 一、指令的概念 二、28个常见的指令 ⭐2.1 ls指令 ⭐2.2 pwd指令 ⭐2.3 cd指令 ⭐2.4tree指令 ⭐2.5 mkdir指令 ⭐2.6 touch指令 ⭐2.7 rmdir指令 ⭐2.8 rm指令 ⭐2.9 clear指令 ⭐2.10 man指令 ⭐2.11 cp指令 ⭐2.12 mv指令 ⭐2.13 cat指令&#xff08;适…

正规文法、正规表达式、有限自动机及其之间的转换(笔记)

The Equivalent Transforming among RG, RE and FA 正规文法 A Grammar G is a quadruple (四元组):G (VN, VT, S, P ) Where, VN is a finite set of nonterminals.VT is a finite set of terminals.S is the start symbol, S ∈ \in ∈ VN.P is a finite set of product…

.Net 使用OpenAI开源语音识别模型Whisper

.Net 使用OpenAI开源语音识别模型 Whisper 前言 Open AI在2022年9月21日开源了号称其英文语音辨识能力已达到人类水准的 Whisper 神经网络&#xff0c;且它亦支持其它98种语言的自动语音辨识。 Whisper系统所提供的自动语音辨识&#xff08;Automatic Speech Recognition&…

python基础知识(四):input语句、if语句和pass语句

目录 1. input语句2. 强制转换3. if语句4. pass语句 1. input语句 input语句是程序获取从键盘输入的内容&#xff0c;会把输入的内容自动转换成字符串。 使用方法: 变量名 input(“提示语”) 例如 language input("你最爱什么语言?") print(language)这两行代码…

RK3588平台开发系列讲解(项目篇)常见模型结构

平台内核版本安卓版本RK3588Linux 5.10Android 12文章目录 一、DNN二、CNN三、RNN沉淀、分享、成长,让自己和他人都能有所收获!😄 📢 AI 模型常采用人工神经网络来模拟人脑神经的记忆和处理信号的能力。常见的人工神经网络类型有深度神经网络(Deep Neural Network,DNN)…