昇思25天学习打卡营第7天 | 基于MindSpore的GPT2文本摘要

news2024/9/20 20:21:16

本次打卡基于gpt2的文本摘要

数据加载及预处理

from mindnlp.utils import http_get

# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

from mindspore.dataset import TextFileDataset

# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()

# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

import json
import numpy as np

# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['input_ids']
    
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

from mindnlp.transformers import BertTokenizer

# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)

 模型构建¶

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss

from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)

# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))

from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度

trainer.run(tgt_columns="labels")

 

 结论

gpt2相较bert等模型,在文本识别、文本摘要、命名体识别中有着优秀的表现,但其模型规模相对较大,训练时间较长,打卡中展示的没有完成训练,这里需要更好的gpu来辅助训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1927958.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

比华为、特斯拉更大的野心

作者 | 艾泊宇 最近百度自动驾驶的网约车萝卜快跑在武汉大规模上路了。 同样是做自动驾驶&#xff0c;你看百度、华为、特斯拉&#xff0c;他们三家的思路完全不同。 但是可以看出来&#xff0c;各自完全不同的用意和意图&#xff0c;以及格局的高低。 华为很稳&#xff0c;…

BernNet Learning Arbitrary Graph Spectral Filters via Bernstein Approximation

发表于:neurips21 推荐指数: #paper/⭐⭐ 设定:在本文中,h是过滤器. bernstein 多项式逼近(这个证明有点稀里糊涂的,反正我觉得一点点问题,可能因为我水平低) p K ( t ) : ∑ k 0 K θ k ⋅ b k K ( t ) ∑ k 0 K f ( k K ) ⋅ ( K k ) ( 1 − t ) K − k t k . p_K(t):…

太牛了!从来没想到加密软件这么好用

还在为无法保证重要信息安全烦恼吗&#xff1f;金刚钻信息网站&#xff0c;一个集数据防泄密系统、企业数据云盘存储为一身的多个安全产品网站&#xff0c;为企业文件保驾护航&#xff01; 一、全方位防护&#xff0c;无懈可击 数据防泄密系统从电脑内部&#xff0c;电脑外部多…

AV1 编码标准熵编码技术概述

AV1熵编码 AV1编码技术是一种开源的视频编解码标准&#xff0c;由开放媒体联盟&#xff08;AOMedia&#xff09;开发&#xff0c;旨在提供高效的视频压缩&#xff0c;同时避免复杂的专利授权问题。在熵编码方面&#xff0c;AV1采用了一种多符号上下文自适应算术编码技术&#x…

EMR 集群时钟同步问题及解决方案An error occurred (InvalidSignatureException)

目录 1. 问题描述2. 问题原因3. 解决过程4. 时钟同步的重要性5. Linux 系统中的时钟同步方式6. 检查 Linux 系统时钟同步状态7. EMR 集群中的时钟同步配置8. 时钟同步对大数据组件的影响9. 监控和告警策略10. 故障排除和最佳实践11. 自动化时钟同步管理12. 时钟同步与数据一致性…

每日复盘-20240715

20240715 六日涨幅最大: ------1--------300807--------- 天迈科技 五日涨幅最大: ------1--------300807--------- 天迈科技 四日涨幅最大: ------1--------300807--------- 天迈科技 三日涨幅最大: ------1--------300713--------- 英可瑞 二日涨幅最大: ------1--------3007…

AV1技术学习:Translational Motion Compensation

编码块根据运动矢量在参考帧中找到相应的预测块&#xff0c;如下图所示&#xff0c;当前块的左上角的位置为(x0, y0)&#xff0c;在参考帧中找到同样位置(x0, y0)的块&#xff0c;根据运动矢量移动到目标参考块&#xff08;左上角位置为&#xff1a;(x1, y1)&#xff09;。 AV1…

【java】力扣 买卖股票的最佳时机 动态规划

文章目录 题目链接题目描述思路代码 题目链接 121.买卖股票的最佳时机 题目描述 思路 本题主要用到了动态规划 1.先定义dp数组的含义 先定义一个二维数组dp 然后dp[i][0]来表示第i天持有股票的现金 dp[i][1]代表第i天不持有股票的现金 刚开始的现金为0&#xff0c;当第i天买…

mysql索引值

mysql 索引值生成规则 MySQL索引值是如何生成的取决于具体的数据类型和列的具体定义。对于大多数数据类型&#xff0c;MySQL会为索引键值使用原始的数据。对于字符串类型&#xff08;如VARCHAR, CHAR, TEXT&#xff09;&#xff0c;索引键值可能是字符串的前缀&#xff0c;这是…

二.1 信息存储(1.1-1.3)

大多数计算机使用8位的块&#xff0c;或者字节&#xff08;byte&#xff09;&#xff0c;作为最小的可寻址的内存单位&#xff0c;而不是访问内存中单独的位。机器级程序将内存视为一个非常大的字节数组&#xff0c;称为虚拟内存&#xff08;virtual memory&#xff09;。内存的…

Home Assistant在windows环境安装

Home Assistant是什么&#xff1f; Home Assistant 是一个开源的智能家居平台&#xff0c;旨在通过集成各种智能设备和服务&#xff0c;提供一个统一的、可自定义的家庭自动化解决方案。它可以允许用户监控、控制和自动化家中的各种设备&#xff0c;包括灯光、温度、安全系统、…

C语言学生成绩管理系统源程序+设计报告

资料下载地址&#xff1a;C语言学生成绩管理系统源程序设计报告 目录 1.设计目的与要求 2.系统需求分析 3.总体设计 4、运行界面 5、资料清单 1.设计目的与要求 设计目的&#xff1a;学生成绩管理系统是为了在这个信息时代高速发展的今天&#xff0c;通过计算机取代传统…

Python从0到100(三十九):数据提取之正则(文末免费送书)

前言&#xff1a; 零基础学Python&#xff1a;Python从0到100最新最全教程。 想做这件事情很久了&#xff0c;这次我更新了自己所写过的所有博客&#xff0c;汇集成了Python从0到100&#xff0c;共一百节课&#xff0c;帮助大家一个月时间里从零基础到学习Python基础语法、Pyth…

加油机税控装置:功能、原理、挑战与发展趋势全解析

加油机税控装置是现代加油机的重要组成部分&#xff0c;它不仅确保销售数据的真实性和合法性&#xff0c;还大大提高了税收管理的效率和质量。 以下是对加油机税控装置的详细解析&#xff1a; 一、功能与作用 1、确保数据真实性&#xff1a;税控装置能够实时、准确地采集加油…

隧道调频广播信号覆盖系统改造-泄漏电缆隧道全线无盲区调频覆盖解决方法探究

隧道调频广播信号覆盖系统改造-泄漏电缆隧道全线无盲区调频覆盖解决方法探究 由北京海特伟业科技有限公司任洪卓发布于2024年7月15日 随着城市交通的不断发展&#xff0c;隧道作为城市交通的重要组成部分&#xff0c;承担着日益增长的交通压力。为了确保行驶在隧道中的车辆能够…

Unity最新第三方开源插件《Stateful Component》管理中大型项目MonoBehaviour各种序列化字段 ,的高级解决方案

上文提到了UIState, ObjectRefactor等,还提到了远古的NGUI, KBEngine-UI等 这个算是比较新的解决方法吧,但是抽象出来,问题还是这些个问题 所以你就说做游戏是不是先要解决这些问题? 而不是高大上的UiImage,DoozyUI等 Mono管理引用基本用法 ① 添加Stateful Component …

书生大模型实战营--L0关卡-Git

任务一、自我介绍 一、使用vscode链接git并提交代码 二、提交新的pr

Linux目录网络设置远程工具的使用

文章目录 Linux目录虚拟机⽹络配置查看⽹络信息修改⽹络配置信息 虚拟机管理操作远程⼯具的使⽤ Linux目录 Linux的⽬录结构 Linux中的常⻅⽬录 Linux常⻅的⽬录结构&#xff0c;不同版本的Linux⽬录结构可能略有不同 Centos7的⽂件⽬录结构 Linux根⽬录下的常⻅⽬录及作⽤ …

windows下安装和使用nacos

概述 Nacos致力于帮助您发现、配置和管理微服务。Nacos提供了一组简单易用的特性集&#xff0c;帮助您快速实现动态服务发 现、服务配置、服务元数据及流且管理 Nacos官方文档&#xff1a;https://nacos.io/zh-cn/docs/quick-start.html Nacos下载地址&#xff1a;https://n…

ArkUI-X视频播放App初出茅庐

前言; 各位同学大家好之前写了一些基于 OpenHarmony 系统写arkui的项目。所以移植到arkui-x上面来 效果图 OpenHarmony os 设备效果图 : 安卓设备效果图