昇思25天学习打卡营第18天 | 基于MindSpore的GPT2文本摘要

news2024/11/13 14:44:05

昇思25天学习打卡营第18天 | 基于MindSpore的GPT2文本摘要

文章目录

  • 昇思25天学习打卡营第18天 | 基于MindSpore的GPT2文本摘要
    • 数据集
      • 创建数据集
      • 数据预处理
      • Tokenizer
    • 模型构建
      • 构建GPT2ForSummarization模型
      • 动态学习率
    • 模型训练
    • 模型推理
    • 总结
    • 打卡

数据集

实验使用nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

创建数据集

from mindnlp.utils import http_get

# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

from mindspore.dataset import TextFileDataset

# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()

数据预处理

原始数据:

article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]

处理后的数据:

[CLS] article_context [SEP] summary_context [SEP]
import json
import numpy as np

# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['input_ids']
    
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

Tokenizer

由于GPT2无中文tokenizer,使用BertTokenizer替代。

from mindnlp.transformers import BertTokenizer

# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)

模型构建

构建GPT2ForSummarization模型

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss

动态学习率

from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate

模型训练

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)

from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度

trainer.run(tgt_columns="labels")

模型推理

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    
    dataset = dataset.batch(batch_size)

    return dataset

test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)

model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    i += 1
    if i == 1:
        break

总结

这一节介绍了在MindSpore中使用GPT2LMHeadModel实现文本摘要的实验。实验使用nlpcc2017摘要数据,并使用BertTokenizer进行中文分词,此外还使用了动态学习率来调整模型收敛速度。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1935776.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

详解MLOps,从Jupyter开发到生产部署

大家好&#xff0c;Jupyter notebook 是机器学习的便捷工具&#xff0c;但在应用部署方面存在局限。为了提升其可扩展性和稳定性&#xff0c;需结合DevOps和MLOps技术。通过自动化的持续集成和持续交付流程&#xff0c;可将AI应用高效部署至HuggingFace平台。 本文将介绍MLOps…

网安播报|Python 的 GitHub 核心资源库 token 意外曝光,风险巨大

1、Python 的 GitHub 核心资源库 token 意外曝光&#xff0c;风险巨大 网络安全专家发现了意外泄露的 GitHub token&#xff0c;能以最高权限访问 Python 语言、Python 软件包索引&#xff08;PyPI&#xff09;和 Python 软件基金会&#xff08;PSF&#xff09;存储库。如果该 …

如何确保微型导轨电能表的精准计量?

微型导轨电能表是一种小型化的电表&#xff0c;通常用于精密仪器和设备中。采用模数化设计&#xff0c;精确度高&#xff0c;具有体积小&#xff0c;易安装&#xff0c;易组装等优点。易于实现终端照明电能计量&#xff0c;便于照明系统加装电度表的改造。 对于用户来说&#x…

2024北京国际智能工厂及自动化展览会亮点前瞻

随着“工业创新&#xff0c;智造未来”的浪潮席卷而来&#xff0c;2024年度北京国际智能工厂及自动化与工业装配展览会定于8月1日至3日在中国国际展览中心&#xff08;顺义新馆&#xff09;盛大开幕。本次展会汇聚了智能制造与自动化技术的最新成果&#xff0c;通过三展联动的创…

ozon计算器5.0版本,ozon计算器5.0定价产品价格

在跨境电商的浩瀚星海中&#xff0c;俄罗斯Ozon电商平台以其庞大的市场规模和快速增长的势头&#xff0c;成为了众多卖家竞相布局的蓝海。然而&#xff0c;在这片充满机遇的土地上&#xff0c;如何精准定价&#xff0c;确保利润最大化&#xff0c;同时又能吸引消费者&#xff0…

【Git分支管理】分支策略 | Bug分支

目录 1.分支策略 2.特殊场景-Bug分支 2.1 master出现bug ​2.2 dev2正在开发☞stash区域 2.3 dev2正在开发master出现bug 2.3.1 fix_bug修复bug和master合并 2.3.2 dev2分支开发完和master合并 合并冲突&#xff1a;merge☞手动解决☞提交没有合并冲突&#xff1a;mer…

基于Go语言快速开发抖音小程序 提高性能、效率和降低成本开发框架 让开发变得极简单 开箱即用省时又省钱

前言 用Go语言Arco Design集成后台框架&#xff0c;我们把日常开发常用的基础开发成基础包&#xff0c;等到有项目时安装、下载基础代码就可以马上开发业务&#xff0c;您可以快速搭建好抖音小程序应用开发&#xff0c;为大家搭建抖音后台管理、抖音原生开发模板、小程序登录、…

大模型-Bert+PET实战

PET&#xff08;Pattern-Exploiting Training&#xff09; 背景&#xff1a;预训练语言模型&#xff08;比如BERT&#xff09;知识全面&#xff0c;但是没有针对下游任务做针对训练&#xff0c;所以效果一般&#xff0c;所以需要根据任务做微调。 核心思想&#xff1a;根据先…

langchain 简介

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。 LangChain 是一个用于语言模型和应用程序开发的框架&#xff0c;它提供了一系列工具和组件&#xff0c; 帮助开发者更轻松地构建基于大型…

【JVM基础02】——组成-程序计数器解读

目录 1- 引言&#xff1a;程序计数器1-1 程序计数器是什么&#xff1f;为什么用程序计数器&#xff1f;(What)(Why) 2- 核心&#xff1a;程序计数器的原理&#xff08;How&#xff09;2-1 使用 javap 查看程序计数器的作用2-2 多线程下程序计数器原理举例 3- 小结&#xff1a;什…

Linux HOOK机制与Netfilter HOOK

一. 什么是HOOK&#xff08;钩子&#xff09; 在计算机中&#xff0c;基本所有的软件程序都可以通过hook方式进行行为拦截&#xff0c;hook方式就是改变原始的执行流。 二. Linux常见的HOOK方式 1、修改函数指针。 2、用户态动态库拦截。 ①利用环境变量LD_PRELOAD和预装载机…

Calibre:soft check

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 soft check检查在lvs检查中属于必看的类型,往往是因为衬底没有硬连接,接pg stripe造成的 下图是一个soft check的错误报告,重要信息有两个: 1)问题在ntap上,也就是着重检查power pin相关的连…

七款最佳公司电脑屏幕监控软件推荐|2024年屏幕监控软件超全盘点!

在当今企业管理中&#xff0c;电脑屏幕监控软件已成为保障数据安全和提升员工生产力的关键工具。选择一款合适的监控软件&#xff0c;可以帮助管理者有效监督员工的电脑使用行为&#xff0c;防止潜在的安全威胁和不当行为。小编分享七款备受好评的电脑屏幕监控软件&#xff0c;…

揭秘!SmartEDA何以成为新生代国产EDA领军者?

在当下科技日新月异的时代&#xff0c;EDA&#xff08;电子设计自动化&#xff09;作为集成电路设计的核心工具&#xff0c;其重要性不言而喻。而在这一领域中&#xff0c;SmartEDA凭借其卓越的性能和创新能力&#xff0c;迅速崭露头角&#xff0c;成为新生代国产EDA的领军者。…

2024可信数据库发展大会|存算分离架构驱动电信数据平台革新

7 月 16 日 - 17 日&#xff0c;由中国通信标准化协会和中国信息通信研究院主办&#xff0c;大数据技术标准推进委员会承办&#xff0c;InfoQ 联合主办的「2024 可信数据库发展大会」&#xff08;TDBC&#xff09;在北京召开。 酷克数据解决方案架构师吴昊受邀参与“电信行业数…

PyTorch使用细节

model.eval() &#xff1a;让BatchNorm、Dropout等失效&#xff1b; with torch.no_grad() &#xff1a; 不再缓存activation&#xff0c;节省显存&#xff1b; 这是矩阵乘法&#xff1a; y1 tensor tensor.T y2 tensor.matmul(tensor.T)y3 torch.rand_like(y1) torch.matm…

破解反爬虫策略 /_guard/auto.js(一) 原理

背景 当用代码或者postman访问一个网站的时候&#xff0c;访问他的任何地址都会返回<script src"/_guard/auto.js"></script>&#xff0c;但是从浏览器中访问显示的页面是正常的&#xff0c;这种就是网站做了反爬虫策略。本文就是带大家来破解这种策略&…

USB3200N模拟信号采集卡12位8路500K采样带DIO带计数器

1、概述&#xff1a; USB3200N多功能数据采集卡&#xff0c;LabVIEW无缝连接&#xff0c;提供图形化API函数&#xff0c;提供8通道&#xff08;RSE、NRSE&#xff09;、4通道&#xff08;DIFF&#xff09;模拟量输入&#xff0c;4路可编程数字I/O&#xff0c;1路计数器。 USB3…

C/C++蓝屏整人代码

文章目录 &#x1f4d2;程序效果 &#x1f4d2;具体步骤 1.隐藏任务栏 2.调整cmd窗口大小 3.调整cmd窗口屏幕颜色 4.完整代码 &#x1f4d2;代码详解 &#x1f680;欢迎互三&#x1f449;&#xff1a;程序猿方梓燚 &#x1f48e;&#x1f48e; &#x1f680;关注博主&a…

前端实现视频播放添加水印

一、效果如下 二、代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Document</title> </head> <body><style>.container {position: relative;}.base {width: 300px;hei…