AI写诗--基于GPT2预训练模型

news2024/10/5 15:25:42

目录

 1  处理数据

 1.1  加载预训练的分词器¶

2  自定义创建数据集 

 2.1  创建dataset

2.2  自定义collate_fn(数据批量输出的方法) 

 2.3  创建数据加载器 

3  创建模型 

4  训练过程代码 

 5  保存训练好的模型

 6  加载保存好的模型

 7  测试预测阶段代码


 

#目前,NLP与CV主要使用transformers库

#框架:主要使用PyTorch

#NLP任务的大体流程:
#处理数据: 中文字符 ---> 数字
#创建数据集。 把处理好的数据变成PyTorch的数据集
#生成模型, 一般使用transformers库,不需要自己建模
#训练预测过程

#配置代理
# import os

# os.environ['http_proxy'] = '127.0.0.1:10809'
# os.environ['https_proxy'] = '127.0.0.1:10809'
#这里是本地加载预训练模型,不需要

 1  处理数据

 1.1  加载预训练的分词器¶

from transformers import AutoTokenizer   #AutoTokenizer分词器 可以使中文字符转变成数字


#我这里是本地加载的模型文件
tokenizer = AutoTokenizer.from_pretrained('../data/model/gpt2-chinese-cluecorpussmall/')
print(tokenizer)

 

BertTokenizerFast(name_or_path='../data/model/gpt2-chinese-cluecorpussmall/', vocab_size=21128, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
#编码分词试算
text = [ '明朝驿使发,一夜絮征袍.素手抽针冷,那堪把剪刀.裁缝寄远道,几日到临洮.',
         '长安一片月,万户捣衣声.秋风吹不尽,总是玉关情.何日平胡虏,良人罢远征.']
#输出结果为一个字典,包含'input_ids'、'token_type_ids'、'attention_mask'
tokenizer.batch_encode_plus(text)
{'input_ids': [[101, 3209, 3308, 7731, 886, 1355, 117, 671, 1915, 5185, 2519, 6151, 119, 5162, 2797, 2853, 7151, 1107, 117, 6929, 1838, 2828, 1198, 1143, 119, 6161, 5361, 2164, 6823, 6887, 117, 1126, 3189, 1168, 707, 3826, 119, 102], [101, 7270, 2128, 671, 4275, 3299, 117, 674, 2787, 2941, 6132, 1898, 119, 4904, 7599, 1430, 679, 2226, 117, 2600, 3221, 4373, 1068, 2658, 119, 862, 3189, 2398, 5529, 5989, 117, 5679, 782, 5387, 6823, 2519, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

2  自定义创建数据集 

 2.1  创建dataset

import torch


class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        #从本地读取数据
        with open('../data/datasets/chinese_poems.txt', encoding='utf-8') as f:
            lines = f.readlines()  #读取的每一行数据都会以一个字符串的形式 依次添加到一个列表中
            
        #split()函数可以根据指定的分隔符将字符串拆分成多个子字符串,并将这些子字符串存储在一个列表中。
        #strip()函数默认移除字符串两端的空白字符(包括空格、制表符、换行符等)
        
        lines = [line.strip() for line in lines]  #输出的lines是一个一维列表,里面的每一行诗都是一个字符串
        #self.的变量在类里面可以调用
        self.lines = lines   #self.lines是一个列表,里面的元素都是一个个字符串
        
        
    def __len__(self):
        return len(self.lines)
    
    def __getitem__(self, i):
        """可以向列表一样通过索引来获取数据"""
        return self.lines[i]

#试跑一下
dataset = Dataset()        
len(dataset), dataset[0]        
        

 

(304752, '欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.')

dataset数据集只能一条一条数据的输出,不能一批批数据传输,
需要将datatset变成pytorch中dataloader的数据形式,将数据可以批量输出 

2.2  自定义collate_fn(数据批量输出的方法) 

def collate_fn(batch):
    #使用分词器 把中文编码成数字
    #tokenizer分词器的输出结果data是一个字典,包含'input_ids'、'token_type_ids'、'attention_mask'
    data = tokenizer.batch_encode_plus(batch, 
                                padding=True,
                                truncation=True,
                                max_length=512,
                                return_tensors='pt')
    #向字典data中添加数据标签目标值labels, 用data原数据中的['input_ids']诗句文字编码来赋值,
    #克隆一份对原数据无影响
    data['labels'] = data['input_ids'].clone()
    return data

 2.3  创建数据加载器 

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=4, 
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)
#dataloader不能直接访问数据,需要for循环来获取数据
#查看第一批数据
for i, data in enumerate(loader):
    break  #只循环一次
i
data  #data是一个字典, 包含'input_ids'、'token_type_ids'、'attention_mask'、'labels'

 

{'input_ids': tensor([[ 101, 2708, 4324, 2406,  782, 1777, 1905, 3918,  117, 7345, 5125, 7346,
         7790, 6387, 4685, 2192,  119,  738, 4761, 5632, 3300, 1921, 1045, 1762,
          117, 6475,  955,  865, 6778, 4212, 2769, 1412,  119,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101, 1921, 4495,  671, 4954,  117, 5966, 1434, 3369, 7755,  119, 7755,
         3323, 2768, 1759,  117, 1759, 5543, 4495, 4289,  119, 5310,  702, 5872,
         5701,  117, 2899, 6627, 2336, 1880,  119, 3719, 5564, 6762, 1726,  117,
         6631,  676,  686,  867,  119,  102,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101,  753, 2399, 3736,  677, 6224, 3217, 2495,  117, 7564, 2682, 7028,
         3341, 2769, 3313, 1726,  119, 3922, 6862,  686, 7313, 6443, 3160, 2533,
          117, 4856, 2418, 4685, 6878,  684, 6124, 3344,  119,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101, 3926, 7599,  711, 2769, 6843, 2495, 5670,  117, 3144, 5108, 7471,
         3351, 6629, 5946, 4170,  119, 2359, 2512, 2661, 7607, 4904, 3717,  100,
          117, 3587, 1898, 3009, 3171, 1911, 7345, 6068,  119, 1126,  782, 2157,
         1762, 3983, 1928, 2279,  117,  671, 4275,  756, 4495, 3717, 2419, 1921,
          119, 4007, 4706, 5679, 3301, 3187, 1962, 6983,  117, 3634, 2552, 2347,
         2899,  736, 3736, 6804,  119,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ 101, 2708, 4324, 2406,  782, 1777, 1905, 3918,  117, 7345, 5125, 7346,
         7790, 6387, 4685, 2192,  119,  738, 4761, 5632, 3300, 1921, 1045, 1762,
          117, 6475,  955,  865, 6778, 4212, 2769, 1412,  119,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101, 1921, 4495,  671, 4954,  117, 5966, 1434, 3369, 7755,  119, 7755,
         3323, 2768, 1759,  117, 1759, 5543, 4495, 4289,  119, 5310,  702, 5872,
         5701,  117, 2899, 6627, 2336, 1880,  119, 3719, 5564, 6762, 1726,  117,
         6631,  676,  686,  867,  119,  102,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101,  753, 2399, 3736,  677, 6224, 3217, 2495,  117, 7564, 2682, 7028,
         3341, 2769, 3313, 1726,  119, 3922, 6862,  686, 7313, 6443, 3160, 2533,
          117, 4856, 2418, 4685, 6878,  684, 6124, 3344,  119,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101, 3926, 7599,  711, 2769, 6843, 2495, 5670,  117, 3144, 5108, 7471,
         3351, 6629, 5946, 4170,  119, 2359, 2512, 2661, 7607, 4904, 3717,  100,
          117, 3587, 1898, 3009, 3171, 1911, 7345, 6068,  119, 1126,  782, 2157,
         1762, 3983, 1928, 2279,  117,  671, 4275,  756, 4495, 3717, 2419, 1921,
          119, 4007, 4706, 5679, 3301, 3187, 1962, 6983,  117, 3634, 2552, 2347,
         2899,  736, 3736, 6804,  119,  102]])}

3  创建模型 

#LM:语言模型
#AutoModelForCausalLM 语言模型的加载器
# from transformers import AutoModelForCausalLM, GPT2Model
from transformers import AutoModelForCausalLM
#加载模型
model = AutoModelForCausalLM.from_pretrained('../data/model/gpt2-chinese-cluecorpussmall/')


#查看加载的预训练模型的参数量
print(sum(p.numel() for p in model.parameters()))
102068736
#试算预测一下
with torch.no_grad():  #模型预测时,参数不需要梯度下降
    #outs是一个元组,包含'loss'(损失)和'logits'(概率)
    outs = model(**data)   


outs['logits'].shape
#4:batch_size
#197:每个句子的序列长度
#21128:每个词对应的21128(vocab_size)个词概率

 

torch.Size([4, 66, 21128])
outs['loss'], outs['logits']

 

(tensor(8.5514),
 tensor([[[ -9.9143,  -9.7647,  -9.8217,  ...,  -9.6961,  -9.7799,  -9.6771],
          [ -7.4731,  -8.7423,  -8.4802,  ...,  -8.2767,  -8.6411,  -9.1488],
          [ -8.7324,  -9.3639,  -9.3685,  ...,  -9.7467,  -9.2594,  -9.9237],
          ...,
          [ -3.6951,  -3.9939,  -4.2000,  ...,  -4.2021,  -4.6660,  -4.4627],
          [ -3.7271,  -4.0562,  -4.2753,  ...,  -4.2301,  -4.7670,  -4.5282],
          [ -3.6152,  -3.9949,  -4.1994,  ...,  -4.1643,  -4.6812,  -4.4797]],
 
         [[ -9.9143,  -9.7647,  -9.8217,  ...,  -9.6961,  -9.7799,  -9.6771],
          [ -8.5889,  -9.2279,  -9.2168,  ...,  -8.6957,  -8.1567,  -8.5526],
          [ -8.8908,  -8.8825,  -8.7488,  ...,  -9.8976,  -9.4964, -10.1446],
          ...,
          [ -3.8280,  -3.7346,  -4.4447,  ...,  -3.8380,  -4.3585,  -4.2275],
          [ -4.0099,  -3.8985,  -4.6581,  ...,  -3.9868,  -4.5486,  -4.3698],
          [ -3.8161,  -3.7165,  -4.4473,  ...,  -3.8579,  -4.3969,  -4.2764]],
 
         [[ -9.9143,  -9.7647,  -9.8217,  ...,  -9.6961,  -9.7799,  -9.6771],
          [ -7.7595,  -8.7731,  -8.8029,  ...,  -9.2167,  -8.4741,  -8.4485],
          [ -9.1754,  -8.8637,  -9.1363,  ...,  -8.7321,  -8.7189,  -8.9582],
          ...,
          [ -3.7426,  -4.1014,  -4.2192,  ...,  -4.3925,  -4.5313,  -4.6184],
          [ -3.8279,  -4.2058,  -4.3173,  ...,  -4.4447,  -4.6614,  -4.6665],
          [ -3.7733,  -4.1448,  -4.2570,  ...,  -4.4249,  -4.6132,  -4.6397]],
 
         [[ -9.9143,  -9.7647,  -9.8217,  ...,  -9.6961,  -9.7799,  -9.6771],
          [ -6.8225,  -7.6599,  -7.4913,  ...,  -7.5897,  -7.4440,  -7.5681],
          [ -7.3068,  -7.6038,  -7.2369,  ...,  -7.8313,  -8.0071,  -7.8388],
          ...,
          [ -5.6309,  -5.6956,  -5.5271,  ...,  -5.4339,  -5.3443,  -5.6756],
          [ -6.4130,  -6.3038,  -6.4816,  ...,  -6.5781,  -6.2063,  -6.4139],
          [ -3.6458,  -4.0801,  -3.7062,  ...,  -4.2418,  -3.9411,  -4.0311]]]))

4  训练过程代码 

from transformers import AdamW
from transformers.optimization import get_scheduler  #学习率的衰减策略


#训练
def train():
    #model是在此函数外部创建的,在此函数内调用前,需要声明model是全局变量
    global model
    #设置设备
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    #将模型传到设备上
    model = model.to(device)
    
    #创建梯度下降的优化器
    optimizer = AdamW(model.parameters(), lr=5e-5)      #lr=0.00005,  -5表示有5位小数
    #创建学习率衰减计划
    scheduler = get_scheduler(name='linear',   #线性的
                              num_warmup_steps=0,  #学习率从一开始就开始衰减,没有预热缓冲期
                              num_training_steps=len(loader),  #loader中有多少批数据就训练多少次
                              optimizer=optimizer)
    model.train()
    for i, data in enumerate(loader):
        for k in data.key():
            #将字典data中每个key所对应的value都传到设备上,再赋值给data[k],相当于把data传到了设备上
            data[k] = data[k].to(device)
        #将设备上的data传入模型中,获取输出结果outs(一个字典,包含loss和logits(概率分布))
        outs = model(**data)  #data是一个字典, **data将字典解包成关键字参数传入
        
        #从outs中获取损失,在训练过程中观察loss是不是在下降,不下降就是不正常
        loss = outs['loss']
        #反向传播
        loss.backward()
        #为了梯度下降的稳定性,防止梯度太大,进行梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters, 1.0)  #公式中的c最大值就是1
        #梯度更新
        optimizer.step()
        scheduler.step()
        #梯度清零
        optimizer.zero_grad()
        model.zero_grad()
        
        
        if i % 1000 == 0:  #每1000个步数,就输出打印内容
            #下一句诗句是上一句的预测目标真实值,有一个偏移量
            labels = data['labels'][:, 1:]
            #预测值
            outs = outs['logits'].argmax(dim=2)[:, :-1]
            
            #筛选条件
            select = labels != 0   #0是补得pad没有意义,需要筛掉
            #分别对labels和outs进行筛选
            labels = labels[select]
            outs = outs[select]
            del select  #后面这个变量没有用了, 删除防止占用过多内存
            
            #计算准确率
            #labels.numel()  求labels内元素的总个数
            #.item() 在pytorch中,取出tensor标量的数值
            cccuracy = (labels == outs).sum().item() / labels.numel()  
            #取出学习率
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), lr, accuracy)



train()
    

 5  保存训练好的模型

#保存训练好的模型
# model = model.to('cpu')  #将模型传到设备上
# torch.save(model, 'model.pt')

 6  加载保存好的模型

# model_2 = torch.laod('../data/model/AI-Poem-save.model')

 7  测试预测阶段代码

def generate(text, row, col, model):
    """
    text:传入的数据
    row, col:预测的诗句是几行几列的
    model:使用的是哪个模型来预测
    """
    def generate_loop(data):
        """循环来预测"""
        #模型预测时,不需要求导来反向传播
        with torch.no_grad():
            outs = model(**data)
            
        #从outs中获取分类概率,   输出形状与输入形状一致,所以batch_size在后面
        # outs形状 [5(五言诗,序列长度), batch_size, vocab_size]
        outs = outs['logits']
        #outs形状 [5(五言诗,序列长度), vocab_size]
        #只取一个元素会把对应的维度降调
        outs = outs[:, -1]  #最后一个是预测值

        #写诗:预测概率最高的词不一定是最合适的
        #取出概率较高的前50个
        #[5, vocab_size]  --> [5, 50]
        topk_value = torch.topk(outs, 50).values   #按从小到大排序的
        #取最后一个就是概率最大的那一个
        #[5, 50] --> [5] ,升维--> [5, 1]
        topk_value = topk_value[:, -1].unsqueeze(dim=1)

        #赋值    # -float('inf')负无穷大 ,表示没有意义
        outs = outs.masked_fill(outs < topk_value, -float('inf')) 

        #不允许写特殊字符, 将其赋值为负无穷大
        outs[:, tokenizer.sep_token_id] = -float('inf')  #分隔符
        outs[:, tokenizer.unk_token_id] = -float('inf')  #未知字符
        outs[:, tokenizer.pad_token_id] = -float('inf')  #填充pad

        for i in ',。':
            outs[:, tokenizer.get_vocab()[i]] = -float('inf')

        #根据概率做一个无放回的采样:不会出现重复的数据
        #[5, vocab_size] ---> [5, 1]
        outs = outs.softmax(dim=1)
        outs = outs.multinomial(num_sample=1)  #从中取一个

        #强制添加标点
        c = data['input_ids'].shape[1] / (col + 1)
        #若c为整数
        if c % 1 == 0:
            #若为偶数行
            if c % 2 == 0:
                outs[:, 0] = tokenizer.get_vocab()['。']
            else:
                outs[:, 0] = tokenizer.get_vocab()[',']

        #将原始的输入数据和预测的结果拼到一起, 当做下一次预测的输入, 依次循环
        data['input_ids'] = torch.cat([data['input_ids'], outs], dim=1)
        data['attention_mask'] = torch.ones_like(data['input_ids'])
        data['token_type_ids'] = torch.zeros_like(data['input_ids'])
        data['labels'] = data['input_ids'].clone()

        # row * col + 1   : 总字数+标点符号
        if data['input_ids'].shape[1] >= row * col + 1:
            return data
        return generate_loop(data)

    #重复三遍:一次生成三首,一次生成的效果可能不太好
    data = tokenizer.batch_encode_plus([text]*3, return_tensors='pt')
    data['input_ids'] = data['input_ids'][:, :-1]  #最后一个不要
    data['attention_mask'] = torch.ones_like(data['input_ids'])
    data['token_type_ids'] = torch.zeros_like(data['input_ids'])
    data['labels'] = data['input_ids'].clone()

    data = generate_loop(data)

    for i in range(3):
        #一次生成三首,按索引打印输出其中一首
        print(i, tokenizer.decode(data['input_ids'][i]))
model_2 = torch.load('../data//model/AI-Poem-save.model')

generate('秋高气爽', row=4, col=7, model=model_2)

0 [CLS] 秋 高 气 爽 雁 初 飞 , 云 树 高 峰 落 叶 稀 。 人 尽 夜 归 山 外 宿 , 鸡 鸣 霜 月 下 寒 衣 。
1 [CLS] 秋 高 气 爽 木 生 秋 , 何 处 仙 方 未 可 求 。 莫 遣 夜 猿 催 老 去 , 东 风 吹 老 上 林 丘 。
2 [CLS] 秋 高 气 爽 早 蝉 喧 , 清 籁 无 声 响 自 喧 。 野 望 岂 容 云 梦 见 , 江 涵 应 属 月 华 昏 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2190086.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IP 数据包分包组包

为什么要分包 由于数据链路层MTU的限制,对于较⼤的IP数据包要进⾏分包. 什么是MTU MTU相当于发快递时对包裹尺⼨的限制.这个限制是不同的数据链路对应的物理层,产⽣的限制. • 以太⽹帧中的数据⻓度规定最⼩46字节,最⼤1500字节,ARP数据包的⻓度不够46字节,要在后⾯补填 充…

【ONE·Web || HTML】

总言 主要内容&#xff1a;HTML基本知识入门&#xff0c;主要介绍了常见的一些标签使用&#xff0c;以及简单案例演示。       文章目录 总言0、前置说明1、认识HTML1.1、是什么1.2、初识 HTML 标签、HTML 文件基本结构1.2.1、相关说明1.2.2、vscode如何快速生成代码 2、HT…

实时数仓分层架构超全解决方案

传统意义上的数据仓库主要处理T1数据&#xff0c;即今天产生的数据分析结果明天才能看到&#xff0c;T1的概念来源于股票交易&#xff0c;是一种股票交易制度&#xff0c;即当日买进的股票要到下一个交易日才能卖出。 随着互联网以及很多行业线上业务的快速发展&#xff0c;让…

【精】Java编程中的Lambda表达式与Stream API

一、引言 随着Java 8的发布&#xff0c;引入了许多令人兴奋的新特性&#xff0c;其中最引人注目的就是Lambda表达式和Stream API。这些新功能不仅让Java这门语言更加现代化&#xff0c;而且也极大地提高了开发效率&#xff0c;使代码更加简洁、易读。本文将深入探讨Lambda表达…

Rust 做桌面应用这么轻松?Pake 彻底改变你的开发方式

Rust 做桌面应用这么轻松&#xff1f;Pake 彻底改变你的开发方式 网页应用装不下了&#xff1f;别担心&#xff0c;Pake 用 Rust 帮你打包网页&#xff0c;快速搞定桌面应用。比起动不动就 100M 的 Electron 应用&#xff0c;它轻如鸿毛&#xff0c;功能却一点都不少&#xff0…

案例-任务清单

文章目录 效果展示初始化面演示画面 代码区 效果展示 初始化面 演示画面 任务清单 代码区 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, in…

Linux下的IO模型

阻塞与非阻塞IO&#xff08;Input/Output&#xff09; 阻塞与非阻塞IO&#xff08;Input/Output&#xff09;是计算机操作系统中两种不同的文件或网络通信方式。它们的主要区别在于程序在等待IO操作完成时的行为。 阻塞IO&#xff08;Blocking IO&#xff09; 在阻塞IO模式下…

付费计量系统通用功能(13)

11.17 Class 17: Security function Capability of maintaining the integrity of data elements, functions and processes. 数据单元、功能和过程的可靠性 Maintains the integrity of the system.系统的可靠 Some examples of security function at…

Meta推出Movie Gen 旗下迄今最先进的视频生成AI模型

Meta 今天发布了 MovieGen 系列媒体基础AI模型&#xff0c;该模型可根据文本提示生成带声音的逼真视频。 MovieGen 系列包括两个主要模型&#xff1a; MovieGen Video 和 MovieGen Audio。 MovieGen Video 是一个具有 300 亿个参数的变换器模型&#xff0c;可根据单个文本提示生…

一“填”到底:深入理解Flood Fill算法

✨✨✨学习的道路很枯燥&#xff0c;希望我们能并肩走下来! 文章目录 目录 文章目录 前言 一 floodfill算法是什么&#xff1f; 二 相关OJ题练习 2.1 图像渲染 2.2 岛屿数量 2.3 岛屿的最大面积 2.4 被围绕的区域 2.5 太平洋大西洋水流问题 2.6 扫雷游戏 2.7 衣橱整…

数据科学:Data+AI驾驭数据的智慧之旅

数据科学&#xff1a;DataAI驾驭数据的智慧之旅 前言一、数据存储计算二、数据治理三、结构化数据分析四、语音分析五、视觉分析六、文本分析七、知识图谱 前言 今天想和大家深入聊聊数据科学这个充满魅力又极具挑战的领域。在当今数字化时代&#xff0c;数据如同潮水般涌来&a…

掌握这一招,轻松用Vue和ECharts打造炫酷雷达图——详细教程指南

大家好&#xff0c;今天我要分享的是如何使用ECharts来绘制雷达图。雷达图是一种常用的数据可视化工具&#xff0c;特别适合展示多个量化指标的比较&#xff0c;也可以进行多维度用户行为分析。接下来&#xff0c;我将一步步教大家如何通过ECharts来实现这一效果。效果图如下&a…

mysql事务 -- 事务的隔离性(测试实验+介绍,脏读,不可重复读,可重复度读,幻读),如何实现(RR和RC的本质区别)

目录 事务的隔离性 引入 测试 读未提交 脏读 读提交 不可重复读 属于问题吗? 例子 可重复读 幻读 串行化 原理 总结 事务的隔离性 隔离性的理解 -- mysql事务 -- 如何理解事务,四个属性,查看是否支持事务,事务操作(提交方式,事务的开始和回滚,提交),事务的隔离…

(Django)初步使用

前言 Django 是一个功能强大、架构良好、安全可靠的 Python Web 框架&#xff0c;适用于各种规模的项目开发。它的高效开发、数据库支持、安全性、良好的架构设计以及活跃的社区和丰富的文档&#xff0c;使得它成为众多开发者的首选框架。 目录 安装 应用场景 良好的架构设计…

基于单片机的智能浇花系统

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于51单片机&#xff0c;采样DHT11温湿度传感器检测温湿度&#xff0c;通过LCD1602显示 4*4按键矩阵可以设置温度湿度阈值&#xff0c;温度大于阈值则开启水泵&#xff0c;湿度大于阈值则开启风扇…

从零开始讲PCIe(6)——PCI-X概述

一、概述 PCI-X 在硬件和软件上与 PCI 具有向后兼容性&#xff0c;同时提供了更高的性能和效率。它使用与 PCI 相同的连接器格式&#xff0c;因此 PCI-X 设备可以插入 PCI 插槽&#xff0c;反之亦然。而且&#xff0c;PCI-X 采用相同的配置模型&#xff0c;因此在 PCI 系统上运…

Apollo9.0 Planning2.0决策规划算法代码详细解析 (4): PlanningComponent::Proc()

&#x1f31f; 面向自动驾驶规划算法工程师的专属指南 &#x1f31f; 欢迎来到《Apollo9.0 Planning2.0决策规划算法代码详细解析》专栏&#xff01;本专栏专为自动驾驶规划算法工程师量身打造&#xff0c;旨在通过深入剖析Apollo9.0开源自动驾驶软件栈中的Planning2.0模块&am…

webpack插件 --- webpack-bundle-analyzer【查看包体积】

const UglifyJsPlugin require(uglifyjs-webpack-plugin) // 清除注释 const CompressionWebpackPlugin require(compression-webpack-plugin); // 开启压缩// 是否为生产环境 const isProduction process.env.NODE_ENV production; const { BundleAnalyzerPlugin } requi…

大数据可视化分析建模论

大数据可视化分析建模论 前言大数据可视化分析建模 前言 在这个信息爆炸的时代&#xff0c;数据如同潮水般涌来&#xff0c;我们每天都在与海量的数据打交道。数据已经成为了企业决策、科研创新以及社会发展的核心要素。如何从这些纷繁复杂的数据中提取有价值的信息&#xff0…

C++多态、虚函数以及抽象类

目录 1.多态的概念 2.多态的定义及实现 2.1多态的构成条件 2.1.1实现多态还有两个必要条件 2.1.2虚函数 2.1.3虚函数的重写/覆盖 2.1.4多态场景的题目 2.1.5虚函数重写的一些其他问题 2.1.5.1协变(了解) 2.1.5.2析构函数的重写 2.1.6override和final关键字 2.…