pytorch之诗词生成3--utils

news2024/11/16 13:42:14

先上代码:


import numpy as np
import settings


def generate_random_poetry(tokenizer, model, s=''):
    """
    随机生成一首诗
    :param tokenizer: 分词器
    :param model: 用于生成古诗的模型
    :param s: 用于生成古诗的起始字符串,默认为空串
    :return: 一个字符串,表示一首古诗
    """
    # 将初始字符串转成token
    token_ids = tokenizer.encode(s)
    # 去掉结束标记[SEP]
    token_ids = token_ids[:-1]
    while len(token_ids) < settings.MAX_LEN:
        # 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
        output = model(np.array([token_ids, ], dtype=np.int32))
        _probas = output.numpy()[0, -1, 3:]
        del output
        # print(_probas)
        # 按照出现概率,对所有token倒序排列
        p_args = _probas.argsort()[::-1][:100]
        # 排列后的概率顺序
        p = _probas[p_args]
        # 先对概率归一
        p = p / sum(p)
        # 再按照预测出的概率,随机选择一个词作为预测结果
        target_index = np.random.choice(len(p), p=p)
        target = p_args[target_index] + 3
        # 保存
        token_ids.append(target)
        if target == 3:
            break
    return tokenizer.decode(token_ids)


def generate_acrostic(tokenizer, model, head):
    """
    随机生成一首藏头诗
    :param tokenizer: 分词器
    :param model: 用于生成古诗的模型
    :param head: 藏头诗的头
    :return: 一个字符串,表示一首古诗
    """
    # 使用空串初始化token_ids,加入[CLS]
    token_ids = tokenizer.encode('')
    token_ids = token_ids[:-1]
    # 标点符号,这里简单的只把逗号和句号作为标点
    punctuations = [',', '。']
    punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
    # 缓存生成的诗的list
    poetry = []
    # 对于藏头诗中的每一个字,都生成一个短句
    for ch in head:
        # 先记录下这个字
        poetry.append(ch)
        # 将藏头诗的字符转成token id
        token_id = tokenizer.token_to_id(ch)
        # 加入到列表中去
        token_ids.append(token_id)
        # 开始生成一个短句
        while True:
            # 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
            output = model(np.array([token_ids, ], dtype=np.int32))
            _probas = output.numpy()[0, -1, 3:]
            del output
            # 按照出现概率,对所有token倒序排列
            p_args = _probas.argsort()[::-1][:100]
            # 排列后的概率顺序
            p = _probas[p_args]
            # 先对概率归一
            p = p / sum(p)
            # 再按照预测出的概率,随机选择一个词作为预测结果
            target_index = np.random.choice(len(p), p=p)
            target = p_args[target_index] + 3
            # 保存
            token_ids.append(target)
            # 只有不是特殊字符时,才保存到poetry里面去
            if target > 3:
                poetry.append(tokenizer.id_to_token(target))
            if target in punctuation_ids:
                break
    return ''.join(poetry)

我们首先看第一个函数generate_random_poetry,用于随机生成古诗。

def generate_random_poetry(tokenizer, model, s=''):
    """
    随机生成一首诗
    :param tokenizer: 分词器
    :param model: 用于生成古诗的模型
    :param s: 用于生成古诗的起始字符串,默认为空串
    :return: 一个字符串,表示一首古诗
    """

 tokenizer是一个分词器,用于将文本转化为模型可以理解的token序列。

# 将初始字符串转成token
    token_ids = tokenizer.encode(s)
    # 去掉结束标记[SEP]
    token_ids = token_ids[:-1]

是将我们传入的字符串转化为id序列。并使用切片操作切除token序列的最后一个元素。这个元素是用来标记结束标志[SEP]的。

我们接着看:

 while len(token_ids) < settings.MAX_LEN:
        # 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
        output = model(np.array([token_ids, ], dtype=np.int32))
        _probas = output.numpy()[0, -1, 3:]
        del output

这里我们就可以看出为什么我们要对最后一个 token进行切片处理了,是为了控制预测的长度。

之后我们进行预测,通过model模型对当前的token_ids序列进行预测,model接受一个形状为(batch_size,sequence_length)的输入,这里batch_size为1,sequence_length为token_ids的长度。因此,我们将token_ids包装成一个形状为(1,sequence_length)的numpy数组,并将其传递给model进行预测。

输出的结果是一个3D数组,形状为(1,sequence_length,vocab_size)。其中vocab_size是词汇表,得到的是对下一个词的预测。

然后通过output.numpy()将输出转化为numpy数组,由于我们只关注当前序列的最后一个token的预测结果,所以使用[0,-1,3:]来获取这部分概率分布。具体的,[0,-1]表示获取第一个样例的最后一个token的预测,而[3:]表示去除前三个token,即[PAD],[UNK]和[CLS]。

最后,通过del output释放output变量所占用的内存空间,这样做是为了及时释放内存,避免内存占用过多。

解释一下四个特殊的token在自然语言处理任务中的不同作用。

  • [PAD](填充):在序列中用于填充长度不足的部分,使得序列长度统一,在很多深度学习模型中,要求输入的序列长度是相同的,因此当序列长度不足时,可以使用[PAD]进行填充,使得序列达到统一的长度。
  • [UNK](未知):用于表示模型未见过的或者无法识别的词汇,当模型在处理文本时遇到不在词汇表中的词汇,就会用[UNK]表示,这样可以避免模型对于未知词汇的处理错误。
  • [CLS](分类):在自然语言处理中的一些任务中,例如文本分类,文本语义相似度等,[CLS]被用作特殊的起始标记。模型会将[CLS]作为整个序列的表示,并用于后续任务的处理。例如:对于文本分类任务,模型可以根据[CLS]表示进行预测分类标签。
  • [SEP](分隔符):用于表示序列的分隔,常见于处理多个句子或多个文本之间的任务。他可以用来分隔句子,句子对,文本片段等,在一些任务中,如文本对分类,问答系统等,输入可能包含多个句子或文本片段,模型可以通过识别[SEP]来理解不同句子之间的关系。另外,[SEP]还可以用于标记序列的结束,在一些 情况下,序列的长度是固定的,使用[SEP]标记还可以表示序列的结束。

继续看我们的代码:

 # 按照出现概率,对所有token倒序排列
        p_args = _probas.argsort()[::-1][:100]
        # 排列后的概率顺序
        p = _probas[p_args]
        # 先对概率归一
        p = p / sum(p)
        # 再按照预测出的概率,随机选择一个词作为预测结果
        target_index = np.random.choice(len(p), p=p)
        target = p_args[target_index] + 3
        # 保存
        token_ids.append(target)

 首先,通过argsort()函数对得到的所有outputs进行从小到大的排序,然后经过[::-1]将排序结果进行反转,这样就变成了从大到小的排序,[:100]表示只选择排名前100的token。返回的是_probas中元素从小到大排序的索引数组。

得到索引数组之后,我们通过p=_probas[p_args],根据排名靠前的token的索引p_args,从预测的概率分布_probas中获取对应的概率值,这样得到的p就是排列后的概率顺序(概率值从大到小)。

之后对概率进行归一化,将概率值除以概率之和,使其概率和为1。

target_index = np.random.choice(len(p), p=p)根据参数p(给定的概率分布)在长度为len(p)的范围内进行随机选择,并返回选择的索引target_index。(显然,元素对应的p值越大,被选择的概率就越高)。

后面的target = p_args[target_index] + 3,其中+3是必不可少的,我刚刚因为没有+3就运行报错了,因为我们得到的target是对应于token_ids的索引,而token_ids的前三项是特殊字符,为了和实际中的模型相对应,我们需要将预测的token编号做一些偏离处理。

最后我们将我们的预测列表中加上我们预测的词的索引值,也就是target。

看着一段:

 if target == 3:
        break
return tokenizer.decode(token_ids)

如果生成的target的值是3,表示生成的token是特殊的[CLS]标记,这意味着生成的序列已结束。
返回我们解码之后的结果也就是生成的诗词。等于3的情况是存在的,3意味着结束标志SPE,当句子的长度合适时,其预测的下一个词很可能是3,我们规定3也属于可预测范围。

看生成藏头诗的代码:

def generate_acrostic(tokenizer, model, head):
    """
    随机生成一首藏头诗
    :param tokenizer: 分词器
    :param model: 用于生成古诗的模型
    :param head: 藏头诗的头
    :return: 一个字符串,表示一首古诗
    """
    # 使用空串初始化token_ids,加入[CLS]
    token_ids = tokenizer.encode('')
    token_ids = token_ids[:-1]
    # 标点符号,这里简单的只把逗号和句号作为标点
    punctuations = [',', '。']
    punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
    # 缓存生成的诗的list
    poetry = []
    # 对于藏头诗中的每一个字,都生成一个短句
    for ch in head:
        # 先记录下这个字
        poetry.append(ch)
        # 将藏头诗的字符转成token id
        token_id = tokenizer.token_to_id(ch)
        # 加入到列表中去
        token_ids.append(token_id)
        # 开始生成一个短句
        while True:
            # 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
            output = model(np.array([token_ids, ], dtype=np.int32))
            _probas = output.numpy()[0, -1, 3:]
            del output
            # 按照出现概率,对所有token倒序排列
            p_args = _probas.argsort()[::-1][:100]
            # 排列后的概率顺序
            p = _probas[p_args]
            # 先对概率归一
            p = p / sum(p)
            # 再按照预测出的概率,随机选择一个词作为预测结果
            target_index = np.random.choice(len(p), p=p)
            target = p_args[target_index] + 3
            # 保存
            token_ids.append(target)
            # 只有不是特殊字符时,才保存到poetry里面去
            if target > 3:
                poetry.append(tokenizer.id_to_token(target))
            if target in punctuation_ids:
                break
    return ''.join(poetry)

我们看一下tokenizer.id_to_token(3,2,1,0)

这段代码和上面的随机生成古诗类似,不同的是:

上面的是完整生成一首古诗,其中就算遇到“,”或者“。”不会中断,直到遇到结束符号[SEP]才会中止。

而我们生成的藏头诗是对每一个字进行生成,这里遇到“,”,“。”就进行下一个字的继续生成。然后将每个字的自动生成进行拼接。得到完整的诗句。(当然,遇到[SEP]也是进行终止的,保证了句式的协调)。

return ''.join(poetry)

作用是将我们的peotry中的内容(短句),连接成一个字符串。并将其作为函数值返回。

如果不加join函数的话,输出是一个列表:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1517936.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度解析Java JDK 1.8中Stream流的源码实现:带你探寻数据流的奥秘

文章目录 一、 Stream流概述1.1 什么是Stream流&#xff0c;以及它的主要特点和优势1.2 Stream流的基本操作&#xff1a;过滤、映射、排序等 二、 Stream流源码解析2.1 接口和基本概念2.2 创建流2.3 源码分析2.3.1 流的起始2.3.2 流的初始2.3.3 认识BaseStream2.3.4 Stream接口…

软考高级:软件工程单元测试(驱动模块、被测模块、桩模块)概念和例题

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

Qt 线程池 QThreadPool

一.Qt 线程池 QThreadPool介绍 Qt线程池是一种管理多个线程的并发编程模型&#xff0c;通过使用线程池可以提高性能、控制并发度、提供任务队列和简化线程管理。 在Qt中&#xff0c;线程池的使用主要涉及以下几个步骤&#xff1a; 创建任务类&#xff1a;需要定义一个任务类&am…

【计算机视觉】一、计算机视觉概述

文章目录 一、计算机视觉二、计算机视觉与其它学科领域的关系1、图像处理2、计算机图形学3、模式识别4、人工智能&#xff08;AI&#xff09;5、神经生理学与认知科学 三、计算机视觉的应用1. 人脸识别2. 目标检测3. 图像生成4. 城市建模5. 电影特效6. 体感游戏动作捕捉7. 虚拟…

Android 仿天通卫星对准(卫星在圆形卫星轨道上转动)效果实现

效果图 View源码 package com.android.circlescalebar.view;import android.animation.ObjectAnimator; import android.content.Context; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.graphics.Canvas; import android.graphics…

linux对于文件操作其他命令

tac&#xff0c;与cat不同&#xff0c;tac可以倒序查看文件内容 管道符&#xff0c;可以将第一条命令的结果当作第二条命令的输入 more分屏显示文件内容 head&#xff0c;可以查看文件前多少行&#xff0c;tail可以查看文件后多少行

JavaEE之多线程(创建线程的五种写法)详解

&#x1f63d;博主CSDN主页: 小源_&#x1f63d; &#x1f58b;️个人专栏: JavaEE &#x1f600;努力追逐大佬们的步伐~ 目录 1. 前言 2. 操作系统"内核" 3. 创建线程的五种写法 (我们重点要掌握最后一种写法!!) 3.1 继承 Thread, 重写 run 3. 2 实现 Runnabl…

SpringBoot Servlet容器启动解析

介绍 容器架构 容器处理请求 容器启动全局流程解析 启动前准备 WebServer创建入口 WebServer创建 Servlet启动 Web容器工厂类加载解析 Web容器个性化配置 属性注入 工厂类初始化 BeanPostProcessor方法实现 定制化流程 面试题 请描述下Servlet容器启动流程&#xff1f;介绍下…

最新android icon和splashScreen适配兼容至2024android

android在12做了splashScreen的变动&#xff0c;即&#xff0c;android12有自带的screenSplash过渡&#xff0c;不论你是否自己有变化&#xff0c;都会插入该动画。 android8做了icon的巨大变动。13做了图标的主题兼容。 一、icon制作 制作 使用android自带的工具&#xff0…

Modbus -tcp协议使用第二版

1.1 协议描述 1.1.1 总体通信结构 MODBUS TCP/IP 的通信系统可以包括不同类型的设备&#xff1a; &#xff08;1&#xff09;连接至 TCP/IP 网络的 MODBUS TCP/IP 客户机和服务器设备&#xff1b; &#xff08;2&#xff09;互连设备&#xff0c;例如&#xff1a;在 TCP/IP…

外包干了5天,技术明显退步。。。。。

先说一下自己的情况&#xff0c;本科生&#xff0c;19年通过校招进入南京某软件公司&#xff0c;干了接近2年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了2年的功能测试&…

C# Onnx yolov8 building segmentation

目录 效果 模型信息 项目 代码 下载 C# Onnx yolov8 building segmentation 效果 模型信息 Model Properties ------------------------- date&#xff1a;2023-12-22T10:51:07.627471 author&#xff1a;Ultralytics task&#xff1a;segment license&#xff1a;AGPL-…

HBase分布式数据库的原理和架构

一、HBase简介 HBase是是一个高性能、高可靠性、面向列的分布式数据库&#xff0c;它是为了在廉价的硬件集群上存储大规模数据而设计的。HBase利用Hadoop HDFS作为其文件存储系统&#xff0c;且Hbase是基于Zookeeper的。 二、HBase架构 *图片引用 Hbase采用Master/Slave架构…

YOLOv9(3):YOLOv9损失(Loss)计算

1. 写在前面 YOLOv9的Loss计算与YOLOv8如出一辙&#xff0c;仅存在略微的差异。多说一句&#xff0c;数据的预处理和导入方式都是一样的。因此如果你已经对YOLOv8了解的比较透彻&#xff0c;那么对于YOLOv9你也只是需要多关注网络结构就可以。 YOLOv9本身也是Anchor-Free的&a…

彩虹外链网盘界面UI美化版超级简洁好看

彩虹外链网盘界面UI美化版 彩虹外链网盘&#xff0c;是一款PHP网盘与外链分享程序&#xff0c;支持所有格式文件的上传&#xff0c;可以生成文件外链、图片外链、音乐视频外链&#xff0c;生成外链同时自动生成相应的UBB代码和HTML代码&#xff0c;还可支持文本、图片、音乐、…

机试:成绩排名

问题描述: 代码示例: #include <bits/stdc.h> using namespace std;int main(){cout << "样例输入" << endl; int n;int m;cin >> n;int nums[n];for(int i 0; i < n; i){cin >> nums[i];}// 排序for(int i 0; i < n; i){//…

Leetcode 3.14

Leetcode hot100 二叉树1.二叉树的层序遍历2.验证二叉搜索树3.二叉树的右视图 二叉树 1.二叉树的层序遍历 二叉树的层序遍历 二叉树的层序遍历可以用先进先出的队列来实现。 将每一层的所有node都添加到队列中&#xff0c;记录下当前队列的长度&#xff0c;即该层的元素数量&…

Java操作Sql语句 出现迭代死循环 (Bug排查)

目录 1. 问题所示2. 原理分析3. 解决方法4. 彩蛋1. 问题所示 Java执行Sql语句来查询一些数据的时候 虽说数据量很大,但是查询过程中一直迭代查询 截图如下所示: 2. 原理分析 至于迭代死循环,可能是不满足的条件也进入查询(本身我的数据量就很大) 主要可能引起的两个原…

创新发展,探索智慧园区平台架构设计与实现

随着信息技术的快速发展&#xff0c;智慧园区平台作为集成物联网、大数据、人工智能等技术的综合性服务平台&#xff0c;正逐步成为推动企业数字化转型的重要驱动力。本文将深入探讨智慧园区平台的架构设计思路、关键技术和应用场景&#xff0c;助力读者了解如何打造智慧化、协…

【洛谷 P8637】[蓝桥杯 2016 省 B] 交换瓶子 题解(贪心算法)

[蓝桥杯 2016 省 B] 交换瓶子 题目描述 有 N N N 个瓶子&#xff0c;编号 1 ∼ N 1 \sim N 1∼N&#xff0c;放在架子上。 比如有 5 5 5 个瓶子&#xff1a; 2 , 1 , 3 , 5 , 4 2,1,3,5,4 2,1,3,5,4 要求每次拿起 2 2 2 个瓶子&#xff0c;交换它们的位置。 经过若干次…