【古诗生成AI实战】之四——模型包装器与模型的训练

news2024/11/28 18:42:26

  在上一篇博客中,我们已经利用任务加载器task成功地从数据集文件中加载了文本数据,并通过预处理器processor构建了词典和编码器。在这一过程中,我们还完成了词向量的提取。

  接下来的步骤涉及到定义模型、加载数据,并开始训练过程。

  为了确保项目代码能够快速切换到不同的模型,并且能够有效地支持transformers库中的预训练模型,我们不仅仅是定义模型那么简单。为此,我们采取了进一步的措施:在模型外面再套上一个额外的层,我称之为模型包装器NNModelWrapper。此外,为了提高配置的灵活性和可维护性,我们将所有的配置项(如批量大小、数据集地址、训练周期数、学习率等)抽取出来,统一放置在一个名为WrapperConfig的配置容器中。通过这种方式,我们就可以避免直接在代码中修改配置参数,而是通过更改配置文件来实现,从而使得整个项目更加模块化和易于管理。

  本章内容属于模型训练阶段,将分别介绍包装器配置WrapperConfig、模型包装器NNModelWrapper和模型Model

在这里插入图片描述

[1] 包装器配置WrapperConfig

  我们把配置全部放在yaml文件里,然后读取里面的配置,赋值给WrapperConfig类。定义如下:

class WrapperConfig(object):
    """A configuration for a :class:`NNModelWrapper`."""

    def __init__(
            self,
            tokenizer,
            max_seq_len: int,
            vocab_num: int,
            word2vec_path: str,
            batch_size: int = 1,
            epoch_num: int = 1,
            learning_rate: float = 0.001
    ):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.epoch_num = epoch_num
        self.learning_rate = learning_rate
        self.word2vec_path = word2vec_path
        self.vocab_num = vocab_num

   WrapperConfig 类用于配置神经网络模型包装器(NNModelWrapper)。类的构造函数接受多个参数来初始化配置:

  tokenizer: 分词器对象,用于文本处理或文本转换为模型可理解的格式。其实就是预处理器processor提供的tokenizer

  max_seq_len (int): 模型可以处理的最大序列长度。

  vocab_num (int): 词汇表的大小。

  word2vec_path (str):预训练的词向量模型的文件路径。即上文提取的词向量。

  batch_size (int): 每个批次处理的数据样本数量。

  epoch_num (int): 训练轮次。

  learning_rate (float): 学习率。

[2] 模型包装器NNModelWrapper

  模型包装器NNModelWrapper接受2个参数,一个是包装器配置WrapperConfig,另外一个是自定义模型Model。代码如下:

class NNModelWrapper:
    """A wrapper around a Transformer-based language model."""

    def __init__(self, config: WrapperConfig, model):
        """Create a new wrapper from the given config."""
        self.config = config
        self.model = model(self.config)

    def generate_dataset(self, data, labeled=True):
        """Generate a dataset from the given examples."""
        features = self._convert_examples_to_features(data)

        feature_dict = {
            'input_ids': torch.tensor([f.input_ids for f in features], dtype=torch.long),
            'labels': torch.tensor([f.labels for f in features], dtype=torch.long),
        }

        if not labeled:
            del feature_dict['labels']

        return DictDataset(**feature_dict)

    def _convert_examples_to_features(self, examples) -> List[InputFeatures]:
        """Convert a set of examples into a list of input features."""
        features = []
        for (ex_index, example) in tqdm(enumerate(examples)):
            if ex_index % 5000 == 0:
                logging.info("Writing example {}".format(ex_index))
            input_features = self.get_input_features(example)
            features.append(input_features)
        # logging.info(f"最终数据构造形式:{features[0]}")
        return features

    def get_input_features(self, example) -> InputFeatures:
        """Convert the given example into a set of input features"""
        text = example.text
        input_ids = self.config.tokenizer(text)
        labels = np.copy(input_ids)
        labels[:-1] = input_ids[1:]

        assert len(input_ids) == self.config.max_seq_len

        return InputFeatures(input_ids=input_ids, attention_mask=None, token_type_ids=None, labels=labels)

   NNModelWrapper 类是围绕一个神经网络语言模型的封装器,提供了模型的初始化和数据处理的方法。

  · 类初始化 (init):
  config: 接收一个 WrapperConfig 类的实例,包含模型的配置信息。
  model: 接收一个模型构造函数,该函数使用配置信息来初始化模型。

  · 生成数据集 (generate_dataset):从给定的数据样本中生成一个数据集。首先把数据样本转换为特征(通过 _convert_examples_to_features 方法),然后根据这些特征创建一个 DictDataset 对象。如果数据未标记(labeled=False),则从特征字典中删除 labels 键。

  · 转换样本为特征 (_convert_examples_to_features):这是个私有方法,把数据样本转换为模型可以理解的输入特征。对于每个样本,使用 get_input_features 方法来生成输入特征。使用 tqdm 显示处理进度,并利用 logging 记录处理信息。

  · 获取输入特征 (get_input_features):此方法将单个数据样本转换为输入特征。首先获取文本内容,然后使用配置中的分词器(tokenizer)将文本转换为 input_ids。标签(labels)是 input_ids 的一个变体,其中每个元素都向右移动一个位置。用断言确保 input_ids 的长度与配置中的 max_seq_len 相等。

[3] 模型Model

  模型包装器NNModelWrapper里面的第二个参数Model才是我们真正的模型。

  在古诗生成AI任务中,RNN是比较适配任务的模型,我们定义的RNN模型如下:

class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()

        V = config.vocab_num  # vocab_num
        E = 300  # embed_dim
        H = 256  # hidden_size

        embedding_pretrained = torch.tensor(np.load(config.word2vec_path)["embeddings"].astype('float32'))
        self.embedding = nn.Embedding.from_pretrained(embedding_pretrained, freeze=False)
        self.lstm = nn.LSTM(E, H, 1, bidirectional=False, batch_first=True, dropout=0.1)
        self.fc = nn.Linear(H, V)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input_ids, labels=None):
        embed = self.embedding(input_ids)  # [batch_size, seq_len, embed_dim]
        out, _ = self.lstm(embed)  # [batch_size, seq_len, hidden_size]
        logit = self.fc(out)  # [batch_size, seq_len, vocab_num]

        if labels is not None:
            loss = self.loss(logit.view(-1, logit.shape[-1]), labels.view(-1))
            return loss, logit
        else:
            return logit[None, :]

  在我们的模型中,特别值得一提的是嵌入层(embedding layer)。在初始化这一层时,我们使用的是之前提取出的词向量。这种做法有助于模型更好地理解和处理文本数据。

  此次我们定义的模型是一个基于RNN的结构,它包括三个主要部分:embedding层、lstm层和fc(全连接)层。

  在模型的前向传播(forward)过程中,输入input_ids的形状为[batch_size, seq_len],即每个批次有多少文本,每个文本的序列长度是多少。输入数据首先通过嵌入层处理,输出的embed的形状为[batch_size, seq_len, embed_dim],即每个单词都被转换成了对应的嵌入向量。接着,数据通过一个单层的lstm网络,得到输出out,最后经过全连接层fc,得到最终的概率分布logit

  这个概率分布logit的含义是:对于每个批次中的文本,每个文本在序列的每个位置上,都有vocab_num个可能的词可以填入,而logit中存储的正是这些词的概率。为了生成文本,我们提取每个位置上概率最高的词的索引,然后根据这些索引在词典中查找对应的词。这就是我们通过模型运行文本生成得到的结果。

[4] 训练

  所有的工作都准备好了,下面我们正式开始模型的训练。

  对于神经网络的训练、验证、测试、优化等等操作,我采用了transformersTrainer极大的简化了项目操作。

  第一步,加载yaml配置文件,读取所有配置项:

    with open('config.yaml', 'r', encoding='utf-8') as f:
        conf = yaml.load(f.read(),Loader=yaml.FullLoader)
        conf_train = conf['train']
        conf_sys = conf['sys']

  第二步,初始化任务加载器,加载数据集:

	Task = TASKS[conf_train['task_name']]()

    data = Task.get_train_examples(conf_train['dataset_url'])
    index = int(len(data) * conf_train['rate'])
    train_data, dev_data = data[:index], data[index:]

  第三步,初始化数据预处理器,并向外提供tokenizer

	Processor = PROCESSORS[conf_train['task_name']](data, conf_train['max_seq_len'], conf_train['vocab_path'])
    tokenizer = lambda text: Processor.tokenizer(text)

  第四步,初始化模型包装配置:

	wrapper_config = WrapperConfig(
        tokenizer=tokenizer,
        max_seq_len=conf_train['max_seq_len'],
        batch_size=conf_train['batch_size'],
        epoch_num=conf_train['epoch_num'],
        learning_rate=conf_train['learning_rate'],
        word2vec_path=conf_train['word2vec_path'],
        vocab_num=len(Processor.vocab)
    )

  第五步,加载模型,初始化模型包装器:

	x = import_module(f'main.model.{conf_train["model_name"]}')
    wrapper = NNModelWrapper(wrapper_config, x.Model)

    print(f'模型有 {sum(p.numel() for p in wrapper.model.parameters() if p.requires_grad):,} 个训练参数')

  第六步,使用模型包装器生成数据集向量:

train_dataset = wrapper.generate_dataset(train_data)
    val_dataset = wrapper.generate_dataset(dev_data)

  第七步,创建训练器:

# 构建trainer
def create_trainer(wrapper, train_dataset, val_dataset):
    # 模型
    model = wrapper.model

    args = TrainingArguments(
        './checkpoints',  # 模型保存的输出目录
        save_strategy=IntervalStrategy.STEPS,  # 模型保存策略
        save_steps=50,  # 每n步保存一次模型  1步表示一个batch训练结束
        evaluation_strategy=IntervalStrategy.STEPS,
        eval_steps=50,
        overwrite_output_dir=True,  # 设置overwrite_output_dir参数为True,表示覆盖输出目录中已有的模型文件
        logging_dir='./logs',  # 可视化数据文件存储地址
        log_level="warning",
        logging_steps=50,  # 每n步保存一次评价指标  1步表示一个batch训练结束 | 还控制着控制台的打印频率 每n步打印一下评价指标 | n过大时,只会保存最后一次的评价指标
        disable_tqdm=True,  # 是否不显示数据训练进度条
        learning_rate=wrapper.config.learning_rate,
        per_device_train_batch_size=wrapper.config.batch_size,
        per_device_eval_batch_size=wrapper.config.batch_size,
        num_train_epochs=wrapper.config.epoch_num,
        dataloader_num_workers=2,  # 数据加载的子进程数
        weight_decay=0.01,
        save_total_limit=2,
        load_best_model_at_end=True
    )

    # 早停设置
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=3,  # 如果8次验证集性能没有提升,则停止训练
        early_stopping_threshold=0,  # 验证集的性能提高不到0时也停止训练
    )

    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        callbacks=[early_stopping],  # 添加EarlyStoppingCallback回调函数
    )
    return trainer
    
trainer = create_trainer(wrapper, train_dataset, val_dataset)

  第八步,开始训练并设置保存模型:

	trainer.train()
    trainer.save_model(conf_train['model_save_dir'] + conf_train['task_name'] + '/' + conf_train['model_name'])

  训练的整体代码如下:

# 构建trainer
def create_trainer(wrapper, train_dataset, val_dataset):
    # 模型
    model = wrapper.model

    args = TrainingArguments(
        './checkpoints',  # 模型保存的输出目录
        save_strategy=IntervalStrategy.STEPS,  # 模型保存策略
        save_steps=50,  # 每n步保存一次模型  1步表示一个batch训练结束
        evaluation_strategy=IntervalStrategy.STEPS,
        eval_steps=50,
        overwrite_output_dir=True,  # 设置overwrite_output_dir参数为True,表示覆盖输出目录中已有的模型文件
        logging_dir='./logs',  # 可视化数据文件存储地址
        log_level="warning",
        logging_steps=50,  # 每n步保存一次评价指标  1步表示一个batch训练结束 | 还控制着控制台的打印频率 每n步打印一下评价指标 | n过大时,只会保存最后一次的评价指标
        disable_tqdm=True,  # 是否不显示数据训练进度条
        learning_rate=wrapper.config.learning_rate,
        per_device_train_batch_size=wrapper.config.batch_size,
        per_device_eval_batch_size=wrapper.config.batch_size,
        num_train_epochs=wrapper.config.epoch_num,
        dataloader_num_workers=2,  # 数据加载的子进程数
        weight_decay=0.01,
        save_total_limit=2,
        load_best_model_at_end=True
    )

    # 早停设置
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=3,  # 如果8次验证集性能没有提升,则停止训练
        early_stopping_threshold=0,  # 验证集的性能提高不到0时也停止训练
    )

    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        callbacks=[early_stopping],  # 添加EarlyStoppingCallback回调函数
    )
    return trainer



def main():
    # ##
    # @通用配置
    # ##
    with open('config.yaml', 'r', encoding='utf-8') as f:
        conf = yaml.load(f.read(),Loader=yaml.FullLoader)
        conf_train = conf['train']
        conf_sys = conf['sys']

    # 系统设置初始化
    System(conf_sys).init_system()

    # 初始化任务加载器
    Task = TASKS[conf_train['task_name']]()

    data = Task.get_train_examples(conf_train['dataset_url'])
    index = int(len(data) * conf_train['rate'])
    train_data, dev_data = data[:index], data[index:]

    # 初始化数据预处理器
    Processor = PROCESSORS[conf_train['task_name']](data, conf_train['max_seq_len'], conf_train['vocab_path'])
    tokenizer = lambda text: Processor.tokenizer(text)

    # 初始化模型包装配置
    wrapper_config = WrapperConfig(
        tokenizer=tokenizer,
        max_seq_len=conf_train['max_seq_len'],
        batch_size=conf_train['batch_size'],
        epoch_num=conf_train['epoch_num'],
        learning_rate=conf_train['learning_rate'],
        word2vec_path=conf_train['word2vec_path'],
        vocab_num=len(Processor.vocab)
    )

    x = import_module(f'main.model.{conf_train["model_name"]}')
    wrapper = NNModelWrapper(wrapper_config, x.Model)

    print(f'模型有 {sum(p.numel() for p in wrapper.model.parameters() if p.requires_grad):,} 个训练参数')

    # 生成数据集
    train_dataset = wrapper.generate_dataset(train_data)
    val_dataset = wrapper.generate_dataset(dev_data)

    # 训练与保存
    trainer = create_trainer(wrapper, train_dataset, val_dataset)
    trainer.train()
    trainer.save_model(conf_train['model_save_dir'] + conf_train['task_name'] + '/' + conf_train['model_name'])

if __name__ == '__main__':
    main()

  运行之后,看到下面输出代表项目成功运行:

在这里插入图片描述

[5] 进行下一篇实战

  【古诗生成AI实战】之五——加载模型进行古诗生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1256731.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何申请永久免费的SSL证书

首先,让我们了解什么是SSL证书。 SSL(Secure Socket Layer)证书是一种数字证书,它提供了一种在互联网上安全地传输数据的方法。 这是一个必须的安全工具,可以加密您的网站和客户之间的所有信息。为了保护用户数据和确保…

消息队列使用场景、概念和原理

文章目录 1 使用消息队列的场景1.1 消息队列的异步处理1.2 消息队列的流量控制(削峰)1.3 消息队列的服务解耦1.4 消息队列的发布订阅1.5 消息队列的高并发缓冲 2 消息队列的基本概念和原理2.1 消息的生产者和消费者2.2 Broker2.3 点对点消息队列模型 ---…

059-第三代软件开发-巧用工控板LED指示灯引脚

第三代软件开发-巧用工控板LED指示灯引脚 文章目录 第三代软件开发-巧用工控板LED指示灯引脚项目介绍巧用工控板LED指示灯引脚第一种方式第二种方式 总结 关键字: Qt、 Qml、 Power、 继电器、 IO 项目介绍 欢迎来到我们的 QML & C 项目!这个项…

【Bootloader学习理解学习--加强版】

笔者在接着聊一下bootloader,主要针对MCU的Bootloader。 笔者之前介绍过一篇Bootloader文章,主要是其概念、一些升级包的格式和升级流程,本次接着来说一下。 1、MCU代码运行方式 之前文章也介绍过,MCU的代码运行方式有两种&…

二级分类菜单及三级分类菜单的层级结构返回

前言 在开发投诉分类功能模块时,遇到过这样一个业务场景:后端需要按层级结构返回二级分类菜单所需数据,换言之,将具有父子关系的List结果集数据转为树状结构数据来返回 二级分类菜单 前期准备 这里简单复刻下真实场景中 出现的…

二十六、搜索结果处理(排序、分页、高亮)

目录 一、排序 二、分页 1、深度分页问题 2、三种方案的优缺点 (1)fromsize 优点: 缺点: 场景: (2)after search 优点: 缺点: 场景: &#xff0…

hive杂谈

数据仓库是一个面向主题的、集成的、非易失的、随时间变化的,用来支持管理人员决策的数据集合,数据仓库中包含了粒度化的企业数据。 数据仓库的主要特征是:主题性、集成性、非易失性、时变性。 数据仓库的体系结构通常包含4个层次&#xff…

03-《人月神话》巴赫、UML和领域驱动设计伪创新:中译本纠错及联想

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 2001年,我们翻译《人月神话》的时候,由于水平有限,译文中存在不少错误。 这些年,随着阅历的增长,在重读的时候偶尔也会有“…

推动卓越创新:了解 4 种研发团队架构如何优化您的组织

揭示敏捷实践中常犯的12大错误,了解如何避免这些敏捷失败 陷阱,找出问题根源并采取有效改进措施,提高项目成功率。立即连线 Runwise.co 社区敏捷专家获得专业建议,或 Runwise.co 在线学习敏捷方法实战课程,提升您和团队…

go当中的channel 无缓冲channel和缓冲channel的适用场景、结合select的使用

Channel Go channel就像Go并发模型中的“胶水”,它将诸多并发执行单元连接起来,或者正是因为有channel的存在,Go并发模型才能迸发出强大的表达能力。 无缓冲channel 无缓冲channel兼具通信和同步特性,在并发程序中应用颇为广泛。…

代码随想录算法训练营 ---第四十五天

前言: 昨天的题做过之后,今天的题基本上都很简单,但是要注重一下细节。 第一题: 简介: 动态规划五部曲: 1.确定dp数组的含义 dp[i]:爬到有i个台阶的楼顶,有dp[i]种方法 2.确定dp…

CSS之弹性盒子Flexible Box

我想大家在做布局的时候,没接触flex布局之前,大家都是用浮动来布局的,但现在我们接触了flex布局之后,我只能说:“真香”。让我为大家介绍一下弹性盒子模型吧! Flexible Box 弹性盒子 在我们使用弹性盒子时&…

泛型你掌握多少?包装类你深入了解过吗?快进来看看吧~

目录 1、泛型是什么——引出泛型 2、泛型的使用 2.1、语法 2.2泛型类的使用 2.3、裸类型 3、泛型如何编译 3.1、擦除机制 3.2、为什么不能实例化泛型类型数组 4、泛型的上界 5、泛型方法 5.1、语法 5.2、举例 6、通配符 6.1、什么是通配符 6.2、统配符解决了什么…

2017年五一杯数学建模C题宜居城市问题值解题全过程文档及程序

2017年五一杯数学建模 C题 宜居城市问题 原题再现 城市宜居性是当前城市科学研究领域的热点议题之一,也是政府和城市居民密切关注的焦点。建设宜居城市已成为现阶段我国城市发展的重要目标,对提升城市居民生活质量、完善城市功能和提高城市运行效率具有重要意义。…

正则表达式回溯陷阱

一、匹配场景 判断一个句子是不是正规英文句子 text "I am a student" 一个正常的英文句子如上,英文单词 空格隔开 英文单词 多个英文字符 [a-zA-Z] 空格用 \s 表示 那么一个句子就是单词 空格(一个或者多个,最后那个单词…

An example of a function uniformly continuous on R but not Lipschitz continuous

See https://math.stackexchange.com/questions/69457/an-example-of-a-function-uniformly-continuous-on-mathbbr-but-not-lipschitz?noredirect1

十七、事件组

1、事件组是什么 1.1、举例说明 (1)学校组织秋游,组长在等待: 张三:我到了李四:我到了王五:我到了组长说:好,大家都到齐了,出发! (2)秋游回来第二天就要提交一篇心得…

leetcode刷题详解五

117. 填充每个节点的下一个右侧节点指针 II 关键点:先递归右子树 画一下就知道了,画一个四层的二叉树,然后右子树多画几个节点就知道为啥了 Node* connect(Node* root) {if(!root || (!root->left && !root->right)){return ro…

针对操作系统漏洞的反馈方法

一、针对操作系统漏洞的反馈方法 漏洞扫描指基于漏洞数据库,通过扫描等手段对指定的远程或者本地计算机系统的安全脆弱性进行检测,发现可利用漏洞的一种安全检测(渗透攻击)行为。在进行漏洞扫描后,需先确定哪些是业务…

二叉树:leetcode1457. 二叉树中的伪回文路径

给你一棵二叉树,每个节点的值为 1 到 9 。我们称二叉树中的一条路径是 「伪回文」的,当它满足:路径经过的所有节点值的排列中,存在一个回文序列。 请你返回从根到叶子节点的所有路径中 伪回文 路径的数目。 给定二叉树的节点数目…