RNN:文本生成

news2024/9/20 16:59:14

文章目录

    • 一、完整代码
    • 二、过程实现
      • 2.1 导包
      • 2.2 数据准备
      • 2.3 字符分词
      • 2.4 构建数据集
      • 2.5 定义模型
      • 2.6 模型训练
      • 2.7 模型推理
    • 三、整体总结

采用RNN和unicode分词进行文本生成

一、完整代码

作者在文章开头地址中使用C++实现了这一过程,为了便于理解,这里我们使用python代码进行实现

# 完整代码在这里
import tensorflow as tf
import keras_nlp
import numpy as np

tokenizer = keras_nlp.tokenizers.UnicodeCodepointTokenizer(vocabulary_size=400)

# tokens - ids
ids = tokenizer(['Why are you so funny?', 'how can i get you'])

# ids - tokens
tokenizer.detokenize(ids)

def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

# 准备数据
text = open('./shakespeare.txt', 'rb').read().decode(encoding='utf-8')
dataset = tf.data.Dataset.from_tensor_slices(tokenizer(text))
dataset = dataset.batch(64, drop_remainder=True)
dataset = dataset.map(split_input_target).batch(64)


input, ouput = dataset.take(1).get_single_element()

# 定义模型

d_model = 512
rnn_units = 1025

class CustomModel(tf.keras.Model):
    def __init__(self, vocabulary_size, d_model, rnn_units):
        super().__init__(self)
        self.embedding = tf.keras.layers.Embedding(vocabulary_size, d_model)
        self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocabulary_size, activation='softmax')

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x

model = CustomModel(tokenizer.vocabulary_size(), d_model, rnn_units)

# 查看模型结构
model(input)
model.summary()

# 模型配置
model.compile(
    loss = tf.losses.SparseCategoricalCrossentropy(),
    optimizer='adam',
    metrics=['accuracy']
)

# 模型训练
model.fit(dataset, epochs=3)

# 模型推理
class InferenceModel(tf.keras.Model):
    def __init__(self, model, tokenizer):
        super().__init__(self)
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, inputs, length, return_states=False):
        inputs = inputs = tf.constant(inputs)[tf.newaxis]
        
        states = None
        input_ids = self.tokenizer(inputs).to_tensor()
        outputs = []
        for i in range(length):
            predicted_logits, states = model(inputs=input_ids, states=states, return_state=True)
            input_ids = tf.argmax(predicted_logits, axis=-1)
            outputs.append(input_ids[0][-1].numpy())

        outputs = self.tokenizer.detokenize(lst).numpy().decode('utf-8')
        if return_states:
            return outputs, states
        else:
            return outputs

infere = InferenceModel(model, tokenizer)


# 开始推理
start_chars = 'hello'
outputs = infere.generate(start_chars, 1000)
print(start_chars + outputs)

二、过程实现

2.1 导包

先导包tensorflow, keras_nlp, numpy

import tensorflow as tf
import keras_nlp
import numpy as np

2.2 数据准备

数据来自莎士比亚的作品 storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt;我们将其下载下来存储为shakespeare.txt

2.3 字符分词

这里我们使用unicode分词:将所有字符都作为一个词来进行分词

tokenizer = keras_nlp.tokenizers.UnicodeCodepointTokenizer(vocabulary_size=400)

# tokens - ids
ids = tokenizer(['Why are you so funny?', 'how can i get you'])

# ids - tokens
tokenizer.detokenize(ids)

2.4 构建数据集

利用tokenizertext数据构建数据集

def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

text = open('./shakespeare.txt', 'rb').read().decode(encoding='utf-8')
dataset = tf.data.Dataset.from_tensor_slices(tokenizer(text))
dataset = dataset.batch(64, drop_remainder=True)
dataset = dataset.map(split_input_target).batch(64)


input, ouput = dataset.take(1).get_single_element()

2.5 定义模型

d_model = 512
rnn_units = 1025

class CustomModel(tf.keras.Model):
    def __init__(self, vocabulary_size, d_model, rnn_units):
        super().__init__(self)
        self.embedding = tf.keras.layers.Embedding(vocabulary_size, d_model)
        self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocabulary_size, activation='softmax')

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x

model = CustomModel(tokenizer.vocabulary_size(), d_model, rnn_units)

# 查看模型结构
model(input)
model.summary()

2.6 模型训练

model.compile(
    loss = tf.losses.SparseCategoricalCrossentropy(),
    optimizer='adam',
    metrics=['accuracy']
)

model.fit(dataset, epochs=3)

2.7 模型推理

定义一个InferenceModel进行模型推理配置;

class InferenceModel(tf.keras.Model):
    def __init__(self, model, tokenizer):
        super().__init__(self)
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, inputs, length, return_states=False):
        inputs = inputs = tf.constant(inputs)[tf.newaxis]
        
        states = None
        input_ids = self.tokenizer(inputs).to_tensor()
        outputs = []
        for i in range(length):
            predicted_logits, states = model(inputs=input_ids, states=states, return_state=True)
            input_ids = tf.argmax(predicted_logits, axis=-1)
            outputs.append(input_ids[0][-1].numpy())

        outputs = self.tokenizer.detokenize(lst).numpy().decode('utf-8')
        if return_states:
            return outputs, states
        else:
            return outputs

infere = InferenceModel(model, tokenizer)


start_chars = 'hello'
outputs = infere.generate(start_chars, 1000)
print(start_chars + outputs)

生成结果如下所示,感觉很差:

hellonofur us:
medous, teserwomador.
walled o y.
as
t aderemowate tinievearetyedust. manonels,
w?
workeneastily.
watrenerdores aner'shra
palathermalod, te a y, s adousced an
ptit: mamerethus:
bas as t: uaruriryedinesm's lesoureris lares palit al ancoup, maly thitts?
b veatrt
watyeleditenchitr sts, on fotearen, medan ur
tiblainou-lele priniseryo, ofonet manad plenerulyo
thilyr't th
palezedorine.
ti dous slas, sed, ang atad t,
wanti shew.
e
upede wadraredorenksenche:
wedemen stamesly ateara tiafin t t pes:
t: tus mo at
io my.
ane hbrelely berenerusedus' m tr;
p outellilid ng
ait tevadwantstry.
arafincara, es fody
'es pra aluserelyonine
pales corseryea aburures
angab:
sunelyothe: s al, chtaburoly o oonis s tioute tt,
pro.
tedeslenali: s 't ing h
sh, age de, anet: hathes: s es'tht,
as:
wedly at s serinechamai:
mored t.
t monatht t athoumonches le.
chededondirineared
t

er
p y
letinalys
ani
aconen,
t rs:
t;et, tes-
luste aly,
thonort aly one telus, s mpsantenam ranthinarrame! a
pul; bon
s fofuly

三、整体总结

RNN结合unicode分词能进行文本生成但是效果一言难尽!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1273722.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

近期知识点

aop (1) 要求利用AOP记录用户操作日志。内容包含以下信息:ip地址、用户名、请求的地址,请求的时间 ( 4 分) (2)要求利用AOP记录用户操作日志,日志记录到文本文件。内容包含以下信息&#xff…

[读论文][跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion

github: GitHub - Nota-NetsPresso/BK-SDM: A Compressed Stable Diffusion for Efficient Text-to-Image Generation [ICCV23 Demo] [ICML23 Workshop] ICML 2023 Workshop on ES-FoMo 简化方式 蒸馏方式(训练Task蒸馏outKD-FeatKD) 训练数据集 评测指标…

最新Midjourney绘画提示词Prompt

最新Midjourney绘画提示词Prompt 一、AI绘画工具 SparkAi【无需魔法使用】: SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作ChatGPT?小编这里写一个详细图文教程吧!本系统使用NestjsVueTypescript框架技术&am…

Vue 和 React 的优点分别是什么?如何选择?

目录 为什么我更喜欢Vue? 低代码平台的前端框架采用Vue的优势有哪些? JNPF-Web-Vue3 的技术栈介绍 (1)Vue3.x (2)Vue-router4.x (3)Vite4.x (4)Ant-D…

Echarts 设置数据条颜色 宽度

设置数据条颜色(推荐) let yData [{value: 500,time: 2012-11-12,itemStyle: //设置数据条颜色{normal: { color: red }}},{value: 454,time: 2020-5-17},{value: 544,time: 2022-1-22},{value: 877,time: 2013-1-30}, {value: 877,time: 2012-11-12}]…

如何通过linux调用企业微信发送告警消息

一、前期准备 1、企业微信具备管理企业权限。 2、服务器有公网IP或者可以将本机端口通过net映射到公网。 二、通过脚本向企业微信发送消息 1、创建sh脚本用来发送消息。 vim 2.sh 注意:脚本中xxxx信息需要在企业微信管理后台获取。 #!/bin/bash # 设置企业…

力扣:1419. 数青蛙

题目&#xff1a; 代码&#xff1a; class Solution { public:int minNumberOfFrogs(string croakOfFrogs){string s "croak";int ns.size();//首先创建一个哈希表来标明每个元素出现的次数&#xff01;vector<int>hash(n); //不用真的创建一个hash表用一个数…

一、Linux系统概述和安装

目录 1、Linux系统概述 2、Linux发行版介绍 3、虚拟机软件介绍 4、VMware安装 5、Linux系统&#xff08;CentOS&#xff09;系统安装 6、登录并查看IP地址 7、Linux连接工具CRT使用 7.1 概述 7.2 CRT安装 7.3 使用步骤 7.4 文件上传 8、Linux的快照 8.1 作用 8.2…

传统算法:使用 Pygame 实现二分查找

使用 Pygame 模块实现了二分查找的动画演示。首先,它生成一个有序数组,并通过 Pygame 在屏幕上绘制这个数组的条形图。接着,通过二分查找算法对有序数组进行查找,动画效果可视化每一步的变化。在查找的过程中,程序通过比较目标值和数组中间元素,逐步缩小搜索范围,高亮显…

Python-简单模拟斗地主洗牌发牌

额滴名片儿 &#x1f388; 博主&#xff1a;一只程序猿子 &#x1f388; 博客主页&#xff1a;一只程序猿子 博客主页 &#x1f388; 个人介绍&#xff1a;爱好(bushi)编程&#xff01; &#x1f388; 创作不易&#xff1a;如喜欢麻烦您点个&#x1f44d;或者点个⭐&#xff01…

【人工智能Ⅰ】实验5:AI实验箱应用之贝叶斯

实验5 AI实验箱应用之贝叶斯 一、实验目的 1. 用实验箱的摄像头拍摄方块上数字的图片&#xff0c;在图像处理的基础上&#xff0c;应用贝叶斯方法识别图像中的数字并进行分类。 二、实验内容和步骤 1. 应用实验箱机械手臂上的摄像头拍摄图像&#xff1b; 2. Opencv处理图像…

生成对抗网络(GAN)手写数字生成

文章目录 一、前言二、前期工作1. 设置GPU&#xff08;如果使用的是CPU可以忽略这步&#xff09; 二、什么是生成对抗网络1. 简单介绍2. 应用领域 三、网络结构四、构建生成器五、构建鉴别器六、训练模型1. 保存样例图片2. 训练模型 七、生成动图 一、前言 我的环境&#xff1…

可行性研究:2023年废旧金属回收行业前景及市场数据分析

废品收购是再生资源行业的重要业务之一。是指将各种废弃物品分类后按不同种类和性能卖给不同的生产厂商或直接出售给再制造厂家&#xff08;如重新使用报废汽车拆解的零件&#xff09;。废旧金属是指暂时失去使用价值的金属或合金制品&#xff0c;一般的废旧金属都含有有用的金…

车牌限行_分支结构的C语言实现xdoj7

试题名称 车牌限行 时间限制: 1 秒 内存限制: 256KB 问题描述 问题描述 受雾霾天气影响&#xff0c;某市决定当雾霾指数超过设定值时对车辆进行限行&#xff0c;假设车牌号全为数字&#xff0c;且长度不超过6位&#xff0c;限行规则如下&#xff1a; &#xff08;…

C++相关闲碎记录(2)

1、误用shared_ptr int* p new int; shared_ptr<int> sp1(p); shared_ptr<int> sp2(p); //error // 通过原始指针两次创建shared_ptr是错误的shared_ptr<int> sp1(new int); shared_ptr<int> sp2(sp1); //ok 如果对C相关闲碎记录(1)中记录的shar…

AI - Steering behaviorsII(碰撞避免,跟随)

Steering Behaviors系统中的碰撞避免&#xff0c;路径跟随&#xff0c;队长跟随 Collision Avoid 在物体前进的方向&#xff0c;延伸一定长度的向量进行检测。相当于物体对前方一定可使范围进行检测障碍物的碰撞 延伸的向量与碰撞物圆心的距离小于碰撞物的半径&#xff0c;则…

docker-compose脚本编写及常用命令

安装 linux DOCKER_CONFIG/usr/local/lib/docker/cli-plugins sudo mkdir -p $DOCKER_CONFIG/cli-plugins sudo curl -SL https://521github.com/docker/compose/releases/download/v2.6.1/docker-compose-linux-x86_64 -o $DOCKER_CONFIG/cli-plugins/docker-compose sudo c…

numpy知识库:深入理解numpy.resize函数和数组的resize方法

前言 numpy中的resize函数顾名思义&#xff0c;可以用于调整数组的大小。但具体如何调整&#xff1f;数组形状变了&#xff0c;意味着数组中的元素个数发生了变化(增加或减少)&#xff0c;如何确定resize后的新数组中每个元素的数值呢&#xff1f;本次博文就来探讨并试图回答这…

润申信息企业标准化管理系统 SQL注入漏洞复现

0x01 产品简介 润申信息科技企业标准化管理系统通过给客户提供各种灵活的标准法规信息化管理解决方案&#xff0c;帮助他们实现了高效的标准法规管理&#xff0c;完成个性化标准法规库的信息化建设。 0x02 漏洞概述 润申信息科技企业标准化管理系统 CommentStandardHandler.as…

蓝桥杯每日一题2023.11.30

题目描述 九数组分数 - 蓝桥云课 (lanqiao.cn) 题目分析 此题目实际上是使用dfs进行数字确定&#xff0c;每次循环中将当前数字与剩下的数字进行交换 eg.1与2、3、4、、、进行交换 2与3、4、、、进行交换 填空位置将其恢复原来位置即可&#xff0c;也就直接将其交换回去即可…