pytorch之诗词生成--2

news2025/1/10 1:41:35

先上代码:

# -*- coding: utf-8 -*-
# @File    : dataset.py
# @Author  : AaronJny
# @Time    : 2019/12/30
# @Desc    : 构建数据集
from collections import Counter
import math
import numpy as np
import tensorflow as tf
import settings


class Tokenizer:
    """
    分词器
    """

    def __init__(self, token_dict):
        # 词->编号的映射
        self.token_dict = token_dict
        # 编号->词的映射
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        # 词汇表大小
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id):
        """
        给定一个编号,查找词汇表中对应的词
        :param token_id: 带查找词的编号
        :return: 编号对应的词
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token):
        """
        给定一个词,查找它在词汇表中的编号
        未找到则返回低频词[UNK]的编号
        :param token: 带查找编号的词
        :return: 词的编号
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens):
        """
        给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
        :param tokens: 待编码字符串
        :return: 编号序列
        """
        # 加上开始标记
        token_ids = [self.token_to_id('[CLS]'), ]
        # 加入字符串编号序列
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # 加上结束标记
        token_ids.append(self.token_to_id('[SEP]'))
        return token_ids

    def decode(self, token_ids):
        """
        给定一个编号序列,将它解码成字符串
        :param token_ids: 待解码的编号序列
        :return: 解码出的字符串
        """
        # 起止标记字符特殊处理
        spec_tokens = {'[CLS]', '[SEP]'}
        # 保存解码出的字符的list
        tokens = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # 拼接字符串
        return ''.join(tokens)


# 禁用词
disallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZE

# 加载数据集
with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    # 将冒号统一成相同格式
    lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []
# 逐行处理读取到的数据
for line in lines:
    # 有且只能有一个冒号用来分割标题
    if line.count(':') != 1:
        continue
    # 后半部分不能包含禁止词
    __, last_part = line.split(':')
    ignore_flag = False
    for dis_word in disallowed_words:
        if dis_word in last_part:
            ignore_flag = True
            break
    if ignore_flag:
        continue
    # 长度不能超过最大长度
    if len(last_part) > max_len - 2:
        continue
    poetry.append(last_part.replace('\n', ''))

# 统计词频
counter = Counter()
for line in poetry:
    counter.update(line)
# 过滤掉低频词
_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]

# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)


class PoetryDataGenerator:
    """
    古诗数据集生成器
    """

    def __init__(self, data, random=False):
        # 数据集
        self.data = data
        # batch size
        self.batch_size = batch_size
        # 每个epoch迭代的步数
        self.steps = int(math.floor(len(self.data) / self.batch_size))
        # 每个epoch开始时是否随机混洗
        self.random = random

    def sequence_padding(self, data, length=None, padding=None):
        """
        将给定数据填充到相同长度
        :param data: 待填充数据
        :param length: 填充后的长度,不传递此参数则使用data中的最大长度
        :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
        :return: 填充后的数据
        """
        # 计算填充长度
        if length is None:
            length = max(map(len, data))
        # 计算填充数据
        if padding is None:
            padding = tokenizer.token_to_id('[PAD]')
        # 开始填充
        outputs = []
        for line in data:
            padding_length = length - len(line)
            # 不足就进行填充
            if padding_length > 0:
                outputs.append(np.concatenate([line, [padding] * padding_length]))
            # 超过就进行截断
            else:
                outputs.append(line[:length])
        return np.array(outputs)

    def __len__(self):
        return self.steps

    def __iter__(self):
        total = len(self.data)
        # 是否随机混洗
        if self.random:
            np.random.shuffle(self.data)
        # 迭代一个epoch,每次yield一个batch
        for start in range(0, total, self.batch_size):
            end = min(start + self.batch_size, total)
            batch_data = []
            # 逐一对古诗进行编码
            for single_data in self.data[start:end]:
                batch_data.append(tokenizer.encode(single_data))
            # 填充为相同长度
            batch_data = self.sequence_padding(batch_data)
            # yield x,y
            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
            del batch_data

    def for_fit(self):
        """
        创建一个生成器,用于训练
        """
        # 死循环,当数据训练一个epoch之后,重新迭代数据
        while True:
            # 委托生成器
            yield from self.__iter__()

下面我们逐行分析该代码:我们首先定义一个分词器类:

class Tokenizer:
    """
    分词器
    """

    def __init__(self, token_dict):
        # 词->编号的映射
        self.token_dict = token_dict
        # 编号->词的映射
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        # 词汇表大小
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id):
        """
        给定一个编号,查找词汇表中对应的词
        :param token_id: 带查找词的编号
        :return: 编号对应的词
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token):
        """
        给定一个词,查找它在词汇表中的编号
        未找到则返回低频词[UNK]的编号
        :param token: 带查找编号的词
        :return: 词的编号
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens):
        """
        给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
        :param tokens: 待编码字符串
        :return: 编号序列
        """
        # 加上开始标记
        token_ids = [self.token_to_id('[CLS]'), ]
        # 加入字符串编号序列
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # 加上结束标记
        token_ids.append(self.token_to_id('[SEP]'))
        return token_ids

    def decode(self, token_ids):
        """
        给定一个编号序列,将它解码成字符串
        :param token_ids: 待解码的编号序列
        :return: 解码出的字符串
        """
        # 起止标记字符特殊处理
        spec_tokens = {'[CLS]', '[SEP]'}
        # 保存解码出的字符的list
        tokens = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # 拼接字符串
        return ''.join(tokens)

看第一段:

def __init__(self, token_dict):
    # 词->编号的映射
    self.token_dict = token_dict
    # 编号->词的映射
    self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
    # 词汇表大小
    self.vocab_size = len(self.token_dict)

首先我们接受一个名为token_dict的参数,将其存储为类的属性,然后创建一个名为token_dict_rev的属性,这是token_dict的反向映射,最后,计算词汇表的大小并将其存储为vocab_size属性。

看下一段:

def id_to_token(self, token_id):
    """
    给定一个编号,查找词汇表中对应的词
    :param token_id: 带查找词的编号
    :return: 编号对应的词
    """
    return self.token_dict_rev[token_id]

这段代码定义一个方法id_to_token,接受一个名为token_id的参数,然后在词汇表中查找对应的词并返回,这个方法实际上是通过token_dict_rev属性实现的反向查找。明显,该字典中的键词的编号,值是词。

接着往下看:

def token_to_id(self, token):
    """
    给定一个词,查找它在词汇表中的编号
    未找到则返回低频词[UNK]的编号
    :param token: 带查找编号的词
    :return: 词的编号
    """
    return self.token_dict.get(token, self.token_dict['[UNK]'])

这段代码与上一段的由键到值差不多,是由值找到对应的键。接受名为token作为参数,然后在词汇表中查找对应词的编号并返回。如果词不在词汇表中,则返回低频词[UNK]的编号,注意我们的token_dict字典的键是词,值是编号,我们可以通过词来找到对应的编号,而token_dict_rev的键是编号,值是词,我们可以通过编号找到对应的值。

return self.token_dict.get(token, self.token_dict['[UNK]'])这段代码中,我们使用get方法,我们尝试在self.token_dict中获取键为token的值,也就是找到对应的编号,第二个参数表示如果没找到对应的键,则返回self.token_dict中键为[UNK]的值。(第二个参数表示字典找不到对应键时返回的默认值)。这样可以确保即使词不在词表中,也能返回一个默认值,避免了出现KeyError。

继续看代码:

def encode(self, tokens):
    """
    给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
    :param tokens: 待编码字符串
    :return: 编号序列
    """
    # 加上开始标记
    token_ids = [self.token_to_id('[CLS]'), ]
    # 加入字符串编号序列
    for token in tokens:
        token_ids.append(self.token_to_id(token))
    # 加上结束标记
    token_ids.append(self.token_to_id('[SEP]'))
    return token_ids

我们的开始标记调用了我们刚刚定义的token_to_id方法,显然,不可能出现[CLS]这个词,所以得到的是[UNK]对应的编号,显然是一个特殊的编号。
(我们看一下错误的输出,也不算错误,就是对应我们的处理词输出。)

而后遍历tokens中的每个词,将词转化为对应的编号加入到编号序列中,这样我们就可以将我们的汉字类型转化为数字,从而可以进行卷积层的处理。

随后加上结束标记的符号,显然也是对应[UNK]。最后我们返回完整的编号序列。(是一个由数字组成的列表)。

相对应的是解码:

def decode(self, token_ids):
    """
    给定一个编号序列,将它解码成字符串
    :param token_ids: 待解码的编号序列
    :return: 解码出的字符串
    """
    # 起止标记字符特殊处理
    spec_tokens = {'[CLS]', '[SEP]'}
    # 保存解码出的字符的list
    tokens = []
    for token_id in token_ids:
        token = self.id_to_token(token_id)
        if token in spec_tokens:
            continue
        tokens.append(token)
    # 拼接字符串
    return ''.join(tokens)

我们先将特殊字符,也就是开始与结束对应的字符组成一个集合。而后我们创建了一个名为tokens的空列表,用于保存由token_ids中token_id对应词。最后我们使用join方法,将tokens列表中的字符串元素链接起来,形成一个新的字符串,在这里,''表示以空字符串作为连接符,也就是将tokens中的词无缝衔接。

接下来我们定义一些参数,这些参数在setting中已经定义,这里我们直接拿来用:

isallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZEr

然后我们就可以开始加载数据集了:

with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    # 将冒号统一成相同格式
    lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []

通过在setting中已经定义好的路径用只读的方式加载我们的数据,解码的类型是utf-8。f是一个对象,表示被打开的文件。文件对象f会在with代码块结束的时候自动关闭。

lines=f.readlines():这段代码从打开的文件对象f中读取所有行,并将它们存储在名为lines的列表中。(因为我们的数据集很大,所以这一步很耗时间)。
而后我们对我们的诗词进行处理,将所有行中的‘:’转化为‘:’,即格式统一,但是这里其实我们都转化为“:”也是不影响的。
然后我们创建一个数据集列表,也就是空列表。

接着我们开始对每一行(也就是一首诗)进行处理:

for line in lines:
    # 有且只能有一个冒号用来分割标题
    if line.count(':') != 1:
        continue
    # 后半部分不能包含禁止词
    __, last_part = line.split(':')
    ignore_flag = False
    for dis_word in disallowed_words:
        if dis_word in last_part:
            ignore_flag = True
            break
    if ignore_flag:
        continue
    # 长度不能超过最大长度
    if len(last_part) > max_len - 2:
        continue
    poetry.append(last_part.replace('\n', ''))

这里我们首先要参考一下数据的格式:

可见我们的每首诗在:的前面部分是诗词名,后半部分是内容,如果该行不包含:则表示是数据出现错误,这时我们直接跳过该数据,使用continue。对于没有问题的数据,我们使用split方法将数据分为前半部分诗词名(当然,直接丢掉),和第二部分内容(是我们需要的精华)。

我们定义一个布尔类型的变量ignore_flag用来判断是否将这个数据忽视。我们将禁词一一取出,如果禁词在我们的数据中出现,我们将该布尔变量设置为true,也就是要去除该数据,嵌套遍历完成后,我们通过判断布尔变量值来确定是否进行下一步处理,当然没有问题的数据,我们将其保存并进行下一步处理。

我们在进行下一步处理的时候也要进行判断,显然,当我们的数据长度较长的时候,比如(长恨歌),我们也是不需要的,这属于异常数据,我们用它作为参考生成小篇幅诗词无异于读圣经来学习小学的看图写话。

剩下的部分也就是符合我们要求的数据了,对于这些数据,我们直接将他们放进我们的列表中。注意小细节,我们将换行符转化为空格。(官方解释是确保诗词文本在处理之后仍然保持连续的完整性,而不会因为换行符被分割为很多行,有利于后期对文本的处理和分析)(但是我认为这是多余的,因为对于一行数据来代表一首诗词来说,完全没必要考虑换行符的问题)。

嗯嗯...也不是完全没用。

可见,我们生成诗词的时候,如果考虑到换行符的话,我们可以拉开我们生成的诗词的距离。

继续:

counter = Counter()
for line in poetry:
    counter.update(line)

这段代码创建了一个Counter类(计数器对象),它是collections模块中的一个数据结构,用于统计可哈希对象的出现次数。然后循环迭代poetry中的每一行,其中poetry是一个包含多行诗歌的列表。在每次迭代中,counter.update(line)都会被调用,它会将line中的字符添加到计数器中,并更新它们的出现次数,update()方法接受一个可迭代对象作为参数,它会遍历该对象并更新计数器。

最终,counter对象将会包含整个数据集中每个字符出现的次数。我们将通过一个简单的案例来说明counter函数的用法:

from collections import Counter

poetry = [
    "Roses are red,",
    "Violets are blue,",
    "Sugar is sweet,",
    "And so are you."
]

counter = Counter()
for line in poetry:
    counter.update(line)

print(counter)

输出结果如下:

Counter({' ': 15, 'e': 10, 's': 7, 'a': 6, 'o': 5, 'r': 4, 'u': 4, 't': 3, 'd': 2, 'n': 2, 'y': 2, 'w': 2, 'A': 1, 'R': 1, 'V': 1, 'i': 1, 'l': 1, 'b': 1, 'g': 1, ',': 1, 'S': 1, 'I': 1, '.': 1})

输出的是一个Counter对象。

接下来我们接着对词进行处理:

_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]

# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)

我们首先来看第一行,创建了一个列表_tokens,用来包含计数器counter中词频大于等于min_word_frequency的词和它们的出现次数,counter.item返回的是一个键值对,键是词,值是对应的频数。

接下来,我们对_tokens列表进行排序,按照词频从高到低进行降序排序,key=lambda x:-x[1]表示使用每个元素的第二个值,即词频作为进行排序的依据。

_tokens = [token for token, count in _tokens]之后我们将排序后的列表中提取词汇,生成一个只包含词汇的列表,这里丢弃了词频信息,只包含了词汇。

而后我们将一些特殊字符,_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens 添加到_tokens列表中,即在词汇列表的最前面。

token_id_dict = dict(zip(_tokens, range(len(_tokens))))然后我们创建一个字典,字典是从词汇到ID的映射关系,当然,前几个索引对应的是特殊词汇,后面按照词汇出现的频率一次对应索引。当然,得到的结果是一个字典。(由词汇到索引)

我们将这个字典传到我们的分词器中,会自动生成由索引到词的映射,以及得到该字典的长度(即词的个数)。

然后我们将我们的诗词的列表进行混洗。

而后我们又定义了一个古诗数据集生成器:

class PoetryDataGenerator:
    """
    古诗数据集生成器
    """

    def __init__(self, data, random=False):
        # 数据集
        self.data = data
        # batch size
        self.batch_size = batch_size
        # 每个epoch迭代的步数
        self.steps = int(math.floor(len(self.data) / self.batch_size))
        # 每个epoch开始时是否随机混洗
        self.random = random

    def sequence_padding(self, data, length=None, padding=None):
        """
        将给定数据填充到相同长度
        :param data: 待填充数据
        :param length: 填充后的长度,不传递此参数则使用data中的最大长度
        :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
        :return: 填充后的数据
        """
        # 计算填充长度
        if length is None:
            length = max(map(len, data))
        # 计算填充数据
        if padding is None:
            padding = tokenizer.token_to_id('[PAD]')
        # 开始填充
        outputs = []
        for line in data:
            padding_length = length - len(line)
            # 不足就进行填充
            if padding_length > 0:
                outputs.append(np.concatenate([line, [padding] * padding_length]))
            # 超过就进行截断
            else:
                outputs.append(line[:length])
        return np.array(outputs)

    def __len__(self):
        return self.steps

    def __iter__(self):
        total = len(self.data)
        # 是否随机混洗
        if self.random:
            np.random.shuffle(self.data)
        # 迭代一个epoch,每次yield一个batch
        for start in range(0, total, self.batch_size):
            end = min(start + self.batch_size, total)
            batch_data = []
            # 逐一对古诗进行编码
            for single_data in self.data[start:end]:
                batch_data.append(tokenizer.encode(single_data))
            # 填充为相同长度
            batch_data = self.sequence_padding(batch_data)
            # yield x,y
            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
            del batch_data

    def for_fit(self):
        """
        创建一个生成器,用于训练
        """
        # 死循环,当数据训练一个epoch之后,重新迭代数据
        while True:
            # 委托生成器
            yield from self.__iter__()

我们从头进行分析:

def __init__(self, data, random=False):
    # 数据集
    self.data = data
    # batch size
    self.batch_size = batch_size
    # 每个epoch迭代的步数
    self.steps = int(math.floor(len(self.data) / self.batch_size))
    # 每个epoch开始时是否随机混洗
    self.random = random

我们接受data和可选的random参数,在方法内部,我们将传入的data赋值给self.data,并确定了batch_size属性,我们之后通过数据集的长度和每个批次的长度来计算每一轮训练多少个批次(也就是步数)。
self.random=random表示每个epoch开始时是否随机混洗数据,它的值等于传入的random参数,默认为不随机混洗。

继续看代码:

def sequence_padding(self, data, length=None, padding=None):
    """
    将给定数据填充到相同长度
    :param data: 待填充数据
    :param length: 填充后的长度,不传递此参数则使用data中的最大长度
    :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
    :return: 填充后的数据
    """
    # 计算填充长度
    if length is None:
        length = max(map(len, data))
    # 计算填充数据
    if padding is None:
        padding = tokenizer.token_to_id('[PAD]')
    # 开始填充
    outputs = []
    for line in data:
        padding_length = length - len(line)
        # 不足就进行填充
        if padding_length > 0:
            outputs.append(np.concatenate([line, [padding] * padding_length]))
        # 超过就进行截断
        else:
            outputs.append(line[:length])
    return np.array(outputs)

我们使用sequence_padding方法,用于将给定的数据填充到相同的长度:

我们传入参数分别是数据,长度,填充的字符。
我们默认填充后的长度是我们数据中的最大长度,这也是我们为什么使用的是64作为最大长度,而将诗词较长的数据进行去除。(不太适合生成长恨歌那样的诗词)。
我们填充的数据编号是PAD对应的编号,即解码的时候对应的也是PAD。
之后我们进行填充,计算出每一行需要填充的长度(归一化长度后的长度减去当前的长度),如果需要进行填充,我们将原数据拼接填充内容作为填充之后的数据。将填充之后的数据放入我们的outputs列表中,否则的话(数据大于我们的最大数据,虽然理论上是不可能的,但是我们也是写一下吧,就只留下到最大长度为止的数据。)

这里值得注意的是,我们传入的是由索引组成的列表。我们得到的也是由数据列表组成的列表,我们通过np.array(outputs)将列表outputs转化为一个numpy数组,其中每个元素对应列表中子列表。便于进一步处理数据。

接下来我们通过__len__来返回步长:

def __len__(self):
    return self.steps

即每轮训练多少个批次,在这里,初始化的时候已经计算好了。

继续哈:

def __iter__(self):
    total = len(self.data)
    # 是否随机混洗
    if self.random:
        np.random.shuffle(self.data)
    # 迭代一个epoch,每次yield一个batch
    for start in range(0, total, self.batch_size):
        end = min(start + self.batch_size, total)
        batch_data = []
        # 逐一对古诗进行编码
        for single_data in self.data[start:end]:
            batch_data.append(tokenizer.encode(single_data))
        # 填充为相同长度
        batch_data = self.sequence_padding(batch_data)
        # yield x,y
        yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
        del batch_data

首先我们获取总样本数,也就是我们的诗词个数,如果self.random=True表示每个epoch开始时需要随机混洗数据集,因此使用np.random.shuffle随机打乱self.data。

然后使用for循环进入每个批次进行训练,(以批次大小为步长遍历数据集,每次迭代都产生一个批次的数据)。用start和end分别表示训练数据开始和结束对应的索引,这里我们要考虑当用累加计算结束位置的时候,不要超过数据的长度。

然后我们逐一对古诗进行编码,将编码得到的结果送入空列表batch_data中,这里要注意我们得到的tokenizer.encode(single_data)是一个由数字组成的列表设为A,然后送入batch_data得到的是一个由A组成的列表,对这个列表进行padding处理,将这个列表中每首诗对应的列表进行扩充。(填充到相同的长度)。

yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
        del batch_data

这一行代码使用yield语句生成一个批次的数据。它返回两个值:batch_data[:,:-1]输入数据,是经过填充的故事序列编码,去掉每个序列的最后一个词,它的形状是(batch_size,sequence_length-1)。(最后一个是句号哦)。
tf.one_hot(batch_data[;,1:],tokenizer.vocab_size)这段代码的目的是将目标数据进行编码,并在这个过程中去掉每个序列中的第一个词,进行独热编码。tokenizer.vocab_size是词汇表的大小,用于确定独热编码的维度。它的形状是(batch_size,sequence_length-1,tokenizer.vocab_size)。

最后del batch_data:

删除批次数据batch_data释放内存,在每次迭代后我们就不需要存储整个批次的数据,因此可以通过删除来释放内存。

为什么删除第一个和最后一个呢?因为我们的起始位置和结束都使用特殊字符进行编码。

最后我们使用:

def for_fit(self):
    """
    创建一个生成器,用于训练
    """
    # 死循环,当数据训练一个epoch之后,重新迭代数据
    while True:
        # 委托生成器
        yield from self.__iter__()

这里我们创建一个死循环,表示生成器会无限制的生成数据,这是为了在训练过程中能持续获取数据,这里使用yield from语法来委托另外一个生成器,即self.__iter__()方法生成的数据,委托生成器的作用是将self.__iter__()生成的数据直接传递给外部的迭代器,作为训练数据。

通过这种方式,当调用for_fit方法时,会得到一个生成器对象,每次迭代该生成器,会从self.__iter__()生成的数据中获取一个批次的训练数据,并将其作为生成器的输出,由于采用了死循环的设置,这个生成器会持续的生成数据,直到外部的训练过程停止或中断。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1512131.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MIT 6.S081---Lab: locks

Memory allocator (moderate) 修改kernel/kalloc.c,修改kmem声明并定义结构体数组: 修改kernel/kalloc.c中的kinit函数,对kmemList进行初始化: 修改kernel/kalloc.c中的kfree函数,获取当前的cpuid并将释放的内存添加到…

互联网架构与通信机制:从边缘到核心的深度解析

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢,在这里我会分享我的知识和经验。&am…

vscode使用npm命令无反应,而终端可以的解决办法

如若你遇到这种情况 使用命令 get-command npm 去下面这个路径把它删掉就可以了

MyBatis拦截器四种类型和自定义拦截器的使用流程

文章目录 MyBatis拦截器四种类型和自定义拦截器的使用流程一、MyBatis拦截器四种类型的详细解释:1. **ParameterHandler 拦截器**:2. **ResultSetHandler 拦截器**:3. **StatementHandler 拦截器**:4. **Interceptor Chain 拦截器…

24-Java策略模式 ( Strategy Pattern )

Java策略模式 摘要实现范例 策略模式的重心不是如何实现算法,而是如何组织、调用这些算法,从而让程序结构更加灵活,具有更好的维护性和扩展性。 策略模式属于行为型模式 摘要 1. 意图 针对一组算法,将每一个算法封装到具有共…

基于Springboot的代驾管理系统(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的代驾管理系统(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构&…

从零搭建Vue项目

目录 环境准备 NodeJS安装 ​编辑 2. 选择安装目录 3. 验证NodeJS环境变量 4. 配置npm的全局安装路径 5. 切换npm的淘宝镜像 6. 安装Vue-cli Vue项目创建 1. 打开UI界面 2. 打开项目管理器 3. 创建项目 vue项目目录结构介绍 运行vue项目 Vue项目开发流程 Vue组…

工具篇--分布式定时任务springBoot 整合 elasticjob使用(3)

文章目录 前言一、Springboot 整合:1.1 引入jar:1.2 配置zookeeper 注册中心:1.3 定义job 业务类:1.4 job 注册到zookeeper:1.5 项目启动:1.5.1 zookeeper 注册中心实例:1.5.2 任务执行日志输出…

RANDOMIZE-IN-PLACE随机排列算法

给定一个长度为 n n n的数组,如何构造出一个随机排列呢?《算法导论》给了我们一个名为RANDOMIZE-IN-PLACE的随机算法,该算法在数组原址上进行排序,时间复杂度为 O ( n ) O(n) O(n)。下面本文将介绍RANDOMIZE-IN-PLACE的设计思想及…

代码随想录(day3)——链表

Leetcode.203 移除链表元素: 203. 移除链表元素 - 力扣(LeetCode) 对于本题,难点就在于对于头部结点的删除,以及给定链表为空时,如何进行遍历。因为需要遍历链表,假设访问链表下一个结点所对应…

开源绘图工具 PlantUML 入门教程(常用于画类图、用例图、时序图等)

文章目录 一、类图二、用例图三、时序图 一、类图 类的UML图示 startuml skinparam classAttributeIconSize 0 class Dummy {-field1 : String#field2 : int~method1() : Stringmethod2() : void } enduml定义能见度(可访问性) startumlclass Dummy {-f…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的行人跌倒检测系统(深度学习+UI界面+完整训练数据集)

摘要:开发行人跌倒检测系统在确保老年人安全方面扮演着至关重要的角色。本篇文章详尽地阐述了如何利用深度学习技术构建一个行人跌倒检测系统,并附上了完整的代码实现。该系统采用了先进的YOLOv8算法,并对YOLOv7、YOLOv6、YOLOv5等先前版本进…

ARM64汇编05 - MOV系列指令

MOV(wide immediate) MOV 可以将一个立即数移动到寄存器中。 .text:0000000000000834 80 46 82 D2 MOV X0, #0x1234 ; Keypatch modified this from:MOV X0, #0x1234 对应的汇编代码为:80 46 82 D2 看手册可知&#xf…

多维时序 | Matlab实现VMD-CNN-BiLSTM变分模态分解结合卷积神经网络结合双向长短期记忆神经网络多变量时间序列预测

多维时序 | Matlab实现VMD-CNN-BiLSTM变分模态分解结合卷积神经网络结合双向长短期记忆神经网络多变量时间序列预测 目录 多维时序 | Matlab实现VMD-CNN-BiLSTM变分模态分解结合卷积神经网络结合双向长短期记忆神经网络多变量时间序列预测预测效果基本介绍程序设计参考资料 预测…

利用“定时执行专家”软件的25种任务与12种触发器,提升IT系统管理自动化水平

在IT系统管理中,自动化是提高工作效率、减少人为错误的关键。而《定时执行专家》这款软件,以其强大的功能、易用性和毫秒级的执行精度,成为了IT系统管理员的得力助手。今天,我们就来探讨一下如何利用这款软件的25种任务类型和12种…

如何在Linux系统安装SVN并配置固定公网地址远程访问【内网穿透】

文章目录 前言1. Ubuntu安装SVN服务2. 修改配置文件2.1 修改svnserve.conf文件2.2 修改passwd文件2.3 修改authz文件 3. 启动svn服务4. 内网穿透4.1 安装cpolar内网穿透4.2 创建隧道映射本地端口 5. 测试公网访问6. 配置固定公网TCP端口地址6.1 保留一个固定的公网TCP端口地址6…

2024-03-11,12(HTML,CSS)

1.HTML的作用就是在浏览器摆放内容。 2.HTML基本骨架 head:网页头部,是给浏览器看的代码,例如CSS body:网页主体,是给用户看的代码,例如图片,文字。 title:网页标题 3.标签的两种…

Redis中set,zset

集合类型set中的数据是无序的,不能重复的 SET SADD key value [value....] 将一个或者多个元素添加到集合set中,重复的元素是无法进行添加的 返回值为添加成功的数字smembers key 获取set中所有的元素,返回元素的顺序是无序的sismember key…

React Hooks 那些事儿

翻了波之前写的文章还有笔记,发现关于前端的文章并不多(好歹也划水做过点前端开发)。巧了,最近没什么好话题可写,做下 React Hooks 学习笔记吧。 Effect Hook 不得不说 Hook 的出现降低了我们在 React 中处理副作用&…

【漏洞复现】SpringBlade error/list SQL 注入漏洞

免责声明:文章来源互联网收集整理,请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该…