昇思25天学习打卡营第12天|基于MindSpore的GPT2文本摘要

news2024/9/24 13:14:53

基于MindSpore的GPT2文本摘要

数据集加载

使用nlpcc2017摘要数据,共包含5万个样本,内容是新闻正文及其摘要。

from mindnlp.utils import http_get
from mindspore.dataset import TextFileDataset

# 下载数据集
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')


# 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)

数据预处理


train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

按9:1划分测试集与训练集,randomize表示不对数据进行随机排序,按照原顺序直接拆分。

原始数据格式是:
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]

期望的预处理后的数据格式是:
[CLS] article_context [SEP] summary_context [SEP]

import json
import numpy as np


def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
	# 加载json格式的数据并转成numpy数组
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

	# 使用分词器处理artical和summary。
	# text=article表示主文本输入,text_pair=summary表示辅助文本
	# padding指将输入序列填充或截断的最大长度
	# truncation指定截断策略,only_first表示指仅截断主文本(article)。
	# 通常是主文本较长需要截断,而辅助文本较短并且需要完整保留。
    def merge_and_pad(article, summary):
          tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['input_ids']
    
	# 提取article和summary
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
	#  将列名修改为input_ids和labels
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])
	# 将数据进行分批
    dataset = dataset.batch(batch_size)
	# 如果shuffle是true,则打乱数据
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

这里的tokenizer使用BertTokenizer,因为GPT2没有中文的分词器。

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# 使用刚刚定义的函数分成四批进行处理。
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)

模型构建

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel
​
class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs=super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits=outputs.logits[..., :-1, :]
        shift_labels=labels[..., 1:]
       
        loss=ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss

具体解释一下以上代码:

shift 的作用使为了对齐预测和标签,使模型输出和标签对应,从而得到每个位置的预测误差。
举例 I love program
在自回归模型中,模型会根据前文逐步预测(不熟悉的盆友可以看一下上一篇文章:自回归模型与文本生成方法)

  • 输入 “I”,输出 “love”
  • 输入 “I love”,输出 “programming”

也就是位置1的输出对应的是位置2(的标签),位置2的输出对应位置3(的标签)

对应到实际的数据就是

  • outputs.logits[…, :-1, :]
    去除 logits 的最后一个时间步,因为没有标签与之对应。继续以上面的为例,就是去掉programming,因为programing没有后面输出了。
  • labels[…, 1:]
    去除 labels 的第一个时间步,因为没有预测值与之对应。去掉I,因为I前面没有输入,因此也不是输出的一部分。

由此完成了shift错位操作。

搞好数据结构之后再计算损失

  1. 将shift_logits形状调整成二维张量,使每一行对应一个token的预测分布。
  2. 将shift_labels形状调整为一维张量,使每个元素对应一个标签。
  3. 使用cross_entropy()计算交叉熵损失,忽略填充token的损失。

定义学习率warmup

from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate

这里定义了LinearWithWarmUp作为自定义学习率调度类。它可以在训练的初始阶段进行学习率的线性预热,再在剩余的训练步骤中线性衰减

初始化 预热步数、训练步数、学习率
构建时如果步数小于预热步数则开始进行线性增长,从0增长到learning rate。
如果步数大于等于时,则进行线性衰减,由learning rate变回0。maximum保证不会降低到0以下。

模型训练

内容详见注释内容

# 初始化参数
num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

#训练的步数=数据集的大小 乘以 准备完整遍历数据集的次数(epoch)
num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
# 初始化一个用于文本摘要的GPT2模型
model = GPT2ForSummarization(config)

# 初始化学习率调度器和adam优化器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)

from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

# 设置检查点回调,用于在训练过程中保存模型检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

# 初始化训练器,设置模型,训练集,训练轮次,优化器和回调函数
trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
# 开启混合精度训练,以提高训练速度和节省显存
trainer.set_amp(level='O1')  
# 开始训练,并指定目标列(labels)作为标签。
trainer.run(tgt_columns="labels")

模型推理

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
# 处理测试集的过程和训练集差不多
# 依然是提取出article和sumarization
# 再进行分词,制定最大长度和截断
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    
    dataset = dataset.batch(batch_size)

    return dataset

test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 加载已经训练好的模型
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
# 设为非训练模式,可以禁用掉一些训练相关的操作如dropout
model.set_train(False)
# 模型的结束标记设置成分隔符标记,这样生成的文本遇到分隔符就会终止
model.config.eos_token_id = model.config.sep_token_id
i = 0

for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    i += 1

总结

本章介绍了使用GPT2进行文本总结任务的基本流程,包括数据导入、数据预处理、模型训练、和模型推理。

打卡凭证

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1910367.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

简过网:2024年一级造价工程师正在报名中,看看你有报考资格吗?

计划报考2024一级造价工程师的小伙伴要注意了&#xff0c;现在一造报名正在启动中&#xff0c;想考试的小伙伴一定要看清楚这些报考条件和考试内容哦&#xff0c;今天&#xff0c;小编和大家一块来分享一下&#xff0c;希望对你有帮助。 几个简单的问题&#xff0c;让你彻底了解…

67.SAP FICO-凭证类型学习

目录 SAP凭证类型 凭证类型的作用 - OBA7 SAP默认的凭证类型更改 FI相应事务代码默认凭证类型 - OBU1 对FB50、60、70默认凭证类型的更改 - OBZO 后勤货物移动默认凭证类型 - OMBA 发货凭证类型 收货凭证类型 自动移动凭证类型 存货盘点凭证类型 发票默认的凭证类…

vue项目本地开启https协议访问(vite)

官网介绍&#xff1a;vite官方文档 1、根据官方文档安装依赖vitejs/plugin-basic-ssl npm install -D vitejs/plugin-basic-ssl2、在vite.config.js或者vite.config.ts中配置&#xff1a;server中的https和plugins import { defineConfig } from "vite"; import b…

AURORA仿真

AURORA 仿真验证 定义&#xff1a;AURORA是一种高速串行通信协议&#xff0c;通常用于在数字信号处理系统和其他电子设备之间传输数据。它提供了一种高效的方式来传输大量数据&#xff0c;通常用于需要高带宽和低延迟的应用中。AURORA协议通常由Xilinx公司的FPGA器件支持&#…

KBPC5010-ASEMI逆变焊机专用KBPC5010

编辑&#xff1a;ll KBPC5010-ASEMI逆变焊机专用KBPC5010 型号&#xff1a;KBPC5010 品牌&#xff1a;ASEMI 封装&#xff1a;KBPC-4 正向电流&#xff08;Id&#xff09;&#xff1a;50A 反向耐压&#xff08;VRRM&#xff09;&#xff1a;1000V 正向浪涌电流&#xff…

【YOLO格式的数据标签,目标检测】

标签为 YOLO 格式&#xff0c;每幅图像一个 *.txt 文件&#xff08;如果图像中没有对象&#xff0c;则不需要 *.txt 文件&#xff09;。*.txt 文件规格如下: 每个对象一行 每一行都是 class x_center y_center width height 格式。 边框坐标必须是 归一化的 xywh 格式&#x…

AutoMQ 产品动态 | 发布 1.1.0,兼容至 Apache Kafka 3.7,支持 Kaf

2024年05-06月动态 01 云原生 Kafka 内核 AutoMQ 发布 1.1.0 版本&#xff0c;兼容 Apache Kafka 到 3.7 版本。 AutoMQ 采用存算分离的思路重构 Apache Kafka&#xff0c;计算层 100% 保留 Kafka 代码&#xff0c;可以更好、更快地适配 Apache Kafka 新版本。经过代码 Revi…

三菱FX3U转OPC UA服务器的配置指南

在工业4.0的浪潮中&#xff0c;传统企业正面临着前所未有的转型压力&#xff0c;其中最为紧迫的挑战之一便是如何高效便捷地将PLC数据无缝集成到PC端上位软件中。我们知道&#xff0c;如果直接从PLC采集数据&#xff0c;不仅涉及复杂的技术难题&#xff0c;如繁琐的软件开发和冗…

MySQL之备份与恢复(五)

备份与恢复 备份数据 符号分隔文件备份 可以使用SQL命令SELECT INTO OUTFILE以符号分隔文件格式创建数据的逻辑备份。(可以用mysqldump的 --tab选项导出到符号分隔文件中)。符号分隔文件包含以ASCII展示的原始数据&#xff0c;没有SQL、注释和列名。下面是一个导出为逗号分隔…

请编写函数,判断一字符串是否是回文,若是回文函数返回值为1,否则返回值为0,回文是顺读和倒读都一样的字符串

int gets_arr(char* p) {int i 0;int j strlen(p) - 1;while (i < j && p[i] p[j]){i;j--;}if (i<j){return 0;}else {return 1;}} int main() {printf("请输入一串字符串\n");char arr[100];gets(arr);int ret gets_arr(arr);if (ret 1){printf(…

如何降低pdf内存,如何降低pdf內存大小,如何降低pdf内存占用

在现代办公环境中&#xff0c;pdf文件已经成为了一种不可或缺的文档格式。然而&#xff0c;pdf内存太大文件常常给我们的工作带来困扰&#xff0c;本文将为你揭秘几种简单有效的方法&#xff0c;帮助你轻松降低 pdf 内存&#xff0c;提高工作效率。 方法一、安装pdf转换软件 打…

能源电子领域2区SCI,版面稀缺,即将截稿,无版面费!

【SciencePub学术】今天小编给大家推荐1本能源电子领域的SCI&#xff01;影响因子1.0-2.0之间&#xff0c;最重要的是审稿周期较短&#xff0c;对急投的学者较为友好&#xff01; 能源电子类SCI 01 / 期刊概况 【期刊简介】IF&#xff1a;1.0-2.0&#xff0c;JCR2区&#xf…

安防监控视频平台LntonCVS视频监控管理平台智慧仓储应用方案

在当前许多大型工厂和物流基地中&#xff0c;仓库是存放物品的核心地点。这些仓库不仅存放大量货物&#xff0c;还配备大量辅助设备&#xff0c;需要全面监管以避免安全事故和财产损失。传统的人工巡检方式已经无法满足现代大规模监管的需求&#xff0c;尤其是面对仓储物品种类…

【区块链+跨境服务】粤澳健康码跨境互认系统 | FISCO BCOS应用案例

2020 年突如其来的新冠肺炎疫情&#xff0c;让社会治理体系面临前所未见的考验&#xff0c;如何兼顾疫情防控与复工复产成为社会 各界共同努力的目标。区块链技术作为传递信任的新一代信息基础设施&#xff0c;善于在多方协同的场景中发挥所长&#xff0c;从 而为粤澳两地的疫情…

C基础day6

1、思维导图 2、 #include<myhead.h> #define MAX 10 int main(int argc, const char *argv[]) {//定义一个数组&#xff0c;用于存储班级所有成员的成绩int score[MAX] {0};//完成对成员成绩的输入for(int i0;i<MAX;i){//任意一个元素score[i]printf("请输入第…

数据库查询基础:单表查询与多表查询

❤❤前言 &#x1f44d;&#x1f44d;点关注&#xff0c;编程梦想家&#xff08;大学生版&#xff09;-CSDN博客&#xff0c;不迷路❤❤ 数据库是现代软件开发中不可或缺的一部分&#xff0c;它帮助我们存储、检索和管理大量数据。在这篇文章中&#xff0c;我们将探讨数据库查…

智慧科技照亮水利未来:深入剖析智慧水利解决方案如何助力水利行业实现高效、精准、可持续的管理

目录 一、智慧水利的概念与内涵 二、智慧水利解决方案的核心要素 1. 物联网技术&#xff1a;构建全面感知网络 2. 大数据与云计算&#xff1a;实现数据高效处理与存储 3. GIS与三维可视化&#xff1a;提升决策支持能力 4. 人工智能与机器学习&#xff1a;驱动决策智能化 …

RockYou2024 发布史上最大密码凭证

参与 CTF 的每个人都至少使用过一次臭名昭著的rockyou.txt单词表&#xff0c;主要是为了执行密码破解活动。 该文件是一份包含1400 万个唯一密码的列表。 源自 2009 年的 RockYou 黑客攻击&#xff0c;创造了计算机安全历史。 多年来&#xff0c;“rockyou 系列”不断发展。…

Java基础-组件及事件处理(上)

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 Swing 概述 MVC 架构 Swing 特点 控件 SWING UI 元素 JFrame SWING 容器 说明 常用方法 示例&a…

Python的异常处理(与C++对比学习)

一、C语言中错误的处理方式 用assert来判断一个表达式是否出错&#xff1b;在调用接口函数时&#xff0c;接口函数会设置errno&#xff0c;我们可以通过errno&#xff0c;strerror(errno)来拿到错误码和错误信息。在自定义函数中&#xff0c;我们设置函数错误信息处理的时候&a…