昇思25天学习打卡营第11天|基于MindSpore的GPT2文本摘要

news2024/9/24 13:26:08

数据集

准备nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

数据需要预处理,如下

原始数据格式:
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]

预处理后的数据格式:
[CLS] article_context [SEP] summary_context [SEP]

代码示例:

# 下载依赖
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
!pip install mindnlp

from mindnlp.utils import http_get

# 下载数据集
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

from mindspore.dataset import TextFileDataset

# 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()

# 按9:1比例拆分训练集、测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

import json
import numpy as np

# 数据集数据预处理
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['input_ids']
    
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

from mindnlp.transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
next(train_dataset.create_tuple_iterator())

模型构建

代码示例:

# 构建GPT2ForSummarization模型
from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss

# 动态学习率
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate

模型训练

代码示例:

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)

# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))

from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度

trainer.run(tgt_columns="labels")

运行结果:
模型训练结果

模型推理

将向量数据变为中文数据。
代码示例:

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    
    dataset = dataset.batch(batch_size)

    return dataset

test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))

model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    i += 1
    if i == 1:
        break

截图时间
截图时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1910379.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EHS管理体系,重塑造企业竞争力的关键密码

在当今这个快速发展的时代&#xff0c;企业面临着前所未有的挑战与机遇。随着全球环保意识的普遍觉醒&#xff0c;以及社会各界对企业社会责任的日益关注&#xff0c;EHS&#xff08;环境&#xff0c;健康&#xff0c;安全&#xff09;管理体系成为了企业稳健前行的重要基石。它…

GPU发展史(二):改变游戏规则的3Dfx Voodoo

小伙伴们&#xff0c;大家好呀&#xff0c;我是老猫。 在上一篇GPU发展史&#xff08;一&#xff09;文章中&#xff0c;我们介绍了1976-1995期间早期显卡的发展故事&#xff0c;今天我们将介绍在1995-1999年这段时间显卡的故事&#xff0c;而这段故事的主角就是——3Dfx 提起…

人工智能+病理组学的交叉课题,患者的临床特征如何收集与整理|顶刊专题汇总·24-07-09

小罗碎碎念 本期文献主题&#xff1a;人工智能病理组学的交叉课题&#xff0c;患者的临床特征如何收集与整理 我们在阅读文献的时候会发现&#xff0c;有的文章会详细给出自己的数据集分析表&#xff0c;分别列出训练集、验证集的数量&#xff0c;以及每个特征对应的患者人数。…

Java版Flink使用指南——从RabbitMQ中队列中接入消息流

大纲 创建RabbitMQ队列新建工程新增依赖编码设置数据源配置读取、处理数据完整代码 打包、上传和运行任务测试 工程代码 在《Java版Flink使用指南——安装Flink和使用IntelliJ制作任务包》一文中&#xff0c;我们完成了第一个小型Demo的编写。例子中的数据是代码预先指定的。而…

uniapp easycom组件冲突

提示信息 ​easycom组件冲突&#xff1a;[/components/uni-icons/uni-icons.vue,/uni_modules/uni-icons/components/uni-icons/uni-icons.vue]​ 问题描述 老项目&#xff0c;在uniapp插件商城导入了一个新的uniapp官方开发的组件》uni-data-picker 数据驱动的picker选择器 …

Java中List集合介绍

一.List集合的概述和特点 List集合的概述 有序集合,这里的有序指的是存取顺序 用户可以精确控制列表中每个元素的插入位置,用户可以通过整数索引访问元素,并搜索列表中的元素 与Set集合不同,列表通常允许重复的元素 List集合的特点 存取有序 可以重复 有索引 二.List集合…

Unity海面效果——6、反射和高光

Unity引擎制作海面效果 大家好&#xff0c;我是阿赵。 上一篇的结束时&#xff0c;海面效果已经做成这样了&#xff1a; 这个Shader的复杂程度已经比较高了&#xff1a; 不过还有一些美中不足的地方。 1、 海平面没有反射到天空球 2、 在近岸边看得到水底的部分&#xff0c;水…

GitLab介绍,以及add an SSH key

GitLab GitLab 是一个用于仓库管理系统的开源项目&#xff0c;现今并在国内外大中型互联网公司广泛使用。 git,gitlab,github区别 git 是一种基于命令的版本控制系统&#xff0c;全命令操作&#xff0c;没有可视化界面&#xff1b; gitlab 是一个基于git实现的在线代码仓库…

昇思25天学习打卡营第12天|基于MindSpore的GPT2文本摘要

基于MindSpore的GPT2文本摘要 数据集加载 使用nlpcc2017摘要数据&#xff0c;共包含5万个样本&#xff0c;内容是新闻正文及其摘要。 from mindnlp.utils import http_get from mindspore.dataset import TextFileDataset# 下载数据集 url https://download.mindspore.cn/t…

简过网:2024年一级造价工程师正在报名中,看看你有报考资格吗?

计划报考2024一级造价工程师的小伙伴要注意了&#xff0c;现在一造报名正在启动中&#xff0c;想考试的小伙伴一定要看清楚这些报考条件和考试内容哦&#xff0c;今天&#xff0c;小编和大家一块来分享一下&#xff0c;希望对你有帮助。 几个简单的问题&#xff0c;让你彻底了解…

67.SAP FICO-凭证类型学习

目录 SAP凭证类型 凭证类型的作用 - OBA7 SAP默认的凭证类型更改 FI相应事务代码默认凭证类型 - OBU1 对FB50、60、70默认凭证类型的更改 - OBZO 后勤货物移动默认凭证类型 - OMBA 发货凭证类型 收货凭证类型 自动移动凭证类型 存货盘点凭证类型 发票默认的凭证类…

vue项目本地开启https协议访问(vite)

官网介绍&#xff1a;vite官方文档 1、根据官方文档安装依赖vitejs/plugin-basic-ssl npm install -D vitejs/plugin-basic-ssl2、在vite.config.js或者vite.config.ts中配置&#xff1a;server中的https和plugins import { defineConfig } from "vite"; import b…

AURORA仿真

AURORA 仿真验证 定义&#xff1a;AURORA是一种高速串行通信协议&#xff0c;通常用于在数字信号处理系统和其他电子设备之间传输数据。它提供了一种高效的方式来传输大量数据&#xff0c;通常用于需要高带宽和低延迟的应用中。AURORA协议通常由Xilinx公司的FPGA器件支持&#…

KBPC5010-ASEMI逆变焊机专用KBPC5010

编辑&#xff1a;ll KBPC5010-ASEMI逆变焊机专用KBPC5010 型号&#xff1a;KBPC5010 品牌&#xff1a;ASEMI 封装&#xff1a;KBPC-4 正向电流&#xff08;Id&#xff09;&#xff1a;50A 反向耐压&#xff08;VRRM&#xff09;&#xff1a;1000V 正向浪涌电流&#xff…

【YOLO格式的数据标签,目标检测】

标签为 YOLO 格式&#xff0c;每幅图像一个 *.txt 文件&#xff08;如果图像中没有对象&#xff0c;则不需要 *.txt 文件&#xff09;。*.txt 文件规格如下: 每个对象一行 每一行都是 class x_center y_center width height 格式。 边框坐标必须是 归一化的 xywh 格式&#x…

AutoMQ 产品动态 | 发布 1.1.0,兼容至 Apache Kafka 3.7,支持 Kaf

2024年05-06月动态 01 云原生 Kafka 内核 AutoMQ 发布 1.1.0 版本&#xff0c;兼容 Apache Kafka 到 3.7 版本。 AutoMQ 采用存算分离的思路重构 Apache Kafka&#xff0c;计算层 100% 保留 Kafka 代码&#xff0c;可以更好、更快地适配 Apache Kafka 新版本。经过代码 Revi…

三菱FX3U转OPC UA服务器的配置指南

在工业4.0的浪潮中&#xff0c;传统企业正面临着前所未有的转型压力&#xff0c;其中最为紧迫的挑战之一便是如何高效便捷地将PLC数据无缝集成到PC端上位软件中。我们知道&#xff0c;如果直接从PLC采集数据&#xff0c;不仅涉及复杂的技术难题&#xff0c;如繁琐的软件开发和冗…

MySQL之备份与恢复(五)

备份与恢复 备份数据 符号分隔文件备份 可以使用SQL命令SELECT INTO OUTFILE以符号分隔文件格式创建数据的逻辑备份。(可以用mysqldump的 --tab选项导出到符号分隔文件中)。符号分隔文件包含以ASCII展示的原始数据&#xff0c;没有SQL、注释和列名。下面是一个导出为逗号分隔…

请编写函数,判断一字符串是否是回文,若是回文函数返回值为1,否则返回值为0,回文是顺读和倒读都一样的字符串

int gets_arr(char* p) {int i 0;int j strlen(p) - 1;while (i < j && p[i] p[j]){i;j--;}if (i<j){return 0;}else {return 1;}} int main() {printf("请输入一串字符串\n");char arr[100];gets(arr);int ret gets_arr(arr);if (ret 1){printf(…

如何降低pdf内存,如何降低pdf內存大小,如何降低pdf内存占用

在现代办公环境中&#xff0c;pdf文件已经成为了一种不可或缺的文档格式。然而&#xff0c;pdf内存太大文件常常给我们的工作带来困扰&#xff0c;本文将为你揭秘几种简单有效的方法&#xff0c;帮助你轻松降低 pdf 内存&#xff0c;提高工作效率。 方法一、安装pdf转换软件 打…