昇思MindSpore学习总结十六 —— 基于MindSpore的GPT2文本摘要

news2024/9/22 23:38:45

1、mindnlp 版本要求

!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp

2、数据集加载与处理

2.1 数据集加载

 本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

from mindspore.dataset import TextFileDataset  # 从mindspore.dataset模块中导入TextFileDataset类

# load dataset  # 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)  # 创建一个TextFileDataset实例,参数是文件路径(path)转换成字符串格式,shuffle=False表示不打乱数据顺序
dataset.get_dataset_size()  # 获取数据集的大小,即数据集中样本的数量

# split into training and testing dataset  # 将数据集分割为训练集和测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)  # 将数据集按比例[0.9, 0.1]分割为训练集和测试集,randomize=False表示不随机打乱数据

 2.2 数据预处理

import json  # 导入json模块,用于处理JSON数据
import numpy as np  # 导入numpy模块,并简写为np,用于处理数组和矩阵

# preprocess dataset  # 预处理数据集
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    # 定义一个嵌套函数read_map,用于读取并解析JSON文本数据
    def read_map(text):
        data = json.loads(text.tobytes())  # 将文本数据转换为字节后用json.loads解析为Python字典
        return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要的numpy数组

    # 定义一个嵌套函数merge_and_pad,用于合并并填充数据
    def merge_and_pad(article, summary):
        # tokenization  # 进行分词操作
        # pad to max_seq_length, only truncate the article  # 填充到最大序列长度,仅截断文章部分
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)  # 使用tokenizer对文章和摘要进行分词,填充到最大长度,仅截断文章部分
        return tokenized['input_ids'], tokenized['input_ids']  # 返回分词后的输入ID(注意:这里的input_ids和labels是相同的)

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数对数据集进行映射,提取文章和摘要
    # change column names to input_ids and labels for the following training  # 更改列名为input_ids和labels,以便后续训练
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])  # 使用merge_and_pad函数对数据进行映射,生成input_ids和labels

    dataset = dataset.batch(batch_size)  # 将数据集按批次大小进行分批处理
    if shuffle:
        dataset = dataset.shuffle(batch_size)  # 如果shuffle为True,则对批次进行随机打乱

    return dataset  # 返回预处理后的数据集

 因GPT2无中文的tokenizer,我们使用BertTokenizer替代。

from mindnlp.transformers import BertTokenizer  # 从mindnlp.transformers模块中导入BertTokenizer类

# We use BertTokenizer for tokenizing Chinese context.  # 我们使用BertTokenizer对中文内容进行分词
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # 使用预训练的'bert-base-chinese'模型初始化BertTokenizer
len(tokenizer)  # 获取tokenizer的词汇表大小

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)  # 使用process_dataset函数对训练数据集进行预处理,传入参数包括训练数据集、分词器和批次大小为4
next(train_dataset.create_tuple_iterator())  # 创建一个tuple迭代器并获取其第一个元素

 3、模型构建

3.1 构建GPT2ForSummarization模型,注意shift right的操作。

from mindspore import ops  # 从mindspore模块导入ops操作
from mindnlp.transformers import GPT2LMHeadModel  # 从mindnlp.transformers模块中导入GPT2LMHeadModel类

# 定义一个用于摘要生成的GPT2模型类,继承自GPT2LMHeadModel
class GPT2ForSummarization(GPT2LMHeadModel):
    # 定义模型的构造函数
    def construct(
        self,
        input_ids=None,  # 输入ID
        attention_mask=None,  # 注意力掩码
        labels=None,  # 标签
    ):
        # 调用父类的construct方法,获取模型输出
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]  # 移动logits,使其与shift_labels对齐
        shift_labels = labels[..., 1:]  # 移动标签,使其与shift_logits对齐
        
        # Flatten the tokens  # 将tokens展平
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)  # 计算交叉熵损失,忽略填充的token
        
        return loss  # 返回计算的损失

 3.2 动态学习率

from mindspore import ops  # 从mindspore模块导入ops操作
from mindspore.nn.learning_rate_schedule import LearningRateSchedule  # 从mindspore.nn.learning_rate_schedule模块导入LearningRateSchedule类

# 定义一个线性学习率衰减与热身相结合的学习率调度类,继承自LearningRateSchedule
class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.  # 热身-衰减学习率。
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()  # 调用父类的构造函数
        self.learning_rate = learning_rate  # 初始化学习率
        self.num_warmup_steps = num_warmup_steps  # 初始化热身步数
        self.num_training_steps = num_training_steps  # 初始化训练步数

    # 定义构造函数
    def construct(self, global_step):
        # 如果当前步数小于热身步数
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate  # 线性增加学习率
        
        # 否则,学习率进行线性衰减
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate  # 计算并返回衰减后的学习率

 4、模型训练

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()
from mindspore import nn  # 从mindspore模块导入nn(神经网络)模块
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel  # 从mindnlp.transformers模块导入GPT2Config和GPT2LMHeadModel类

# 配置GPT2模型的配置
config = GPT2Config(vocab_size=len(tokenizer))  # 创建GPT2配置实例,设置词汇表大小为tokenizer的长度
model = GPT2ForSummarization(config)  # 使用配置实例创建一个GPT2ForSummarization模型

# 创建学习率调度器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)  # 创建线性热身-衰减学习率调度器

# 创建优化器
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)  # 使用AdamWeightDecay优化器,并传入模型的可训练参数和学习率调度器
# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))

from mindnlp._legacy.engine import Trainer  # 从mindnlp._legacy.engine模块导入Trainer类
from mindnlp._legacy.engine.callbacks import CheckpointCallback  # 从mindnlp._legacy.engine.callbacks模块导入CheckpointCallback类

# 创建一个CheckpointCallback实例,用于保存检查点
ckpoint_cb = CheckpointCallback(
    save_path='checkpoint',  # 检查点保存路径
    ckpt_name='gpt2_summarization',  # 检查点文件名
    epochs=1,  # 每个epoch保存一次检查点
    keep_checkpoint_max=2  # 最多保留两个检查点
)

# 创建一个Trainer实例,用于训练模型
trainer = Trainer(
    network=model,  # 要训练的模型
    train_dataset=train_dataset,  # 训练数据集
    epochs=1,  # 训练的epoch数
    optimizer=optimizer,  # 优化器
    callbacks=ckpoint_cb  # 回调函数,包括检查点回调
)

trainer.set_amp(level='O1')  # 开启混合精度训练,级别设置为'O1'

下面这段代码,运行时间较长,最好选择较高算力。 

trainer.run(tgt_columns="labels")  # 运行训练器,指定目标列为“labels”

配置不够,训练时间太长。 

5、模型推理

数据处理,将向量数据变为中文数据

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    # 定义一个嵌套函数read_map,用于读取并解析JSON文本数据
    def read_map(text):
        data = json.loads(text.tobytes())  # 将文本数据转换为字节后用json.loads解析为Python字典
        return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要的numpy数组

    # 定义一个嵌套函数pad,用于对文章进行分词和填充
    def pad(article):
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)  # 对文章进行分词,截断至最大长度减去摘要长度
        return tokenized['input_ids']  # 返回分词后的输入ID

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数对数据集进行映射,提取文章和摘要
    dataset = dataset.map(pad, 'article', ['input_ids'])  # 使用pad函数对文章进行分词和填充,生成input_ids

    dataset = dataset.batch(batch_size)  # 将数据集按批次大小进行分批处理

    return dataset  # 返回预处理后的数据集
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 创建一个tuple迭代器并获取其第一个元素,以NumPy数组的形式输出,并打印出来
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)  # 从预训练的检查点加载模型
model.set_train(False)  # 设置模型为评估模式(非训练模式)
model.config.eos_token_id = model.config.sep_token_id  # 设置模型的eos_token_id为sep_token_id
i = 0  # 初始化计数器为0

# 遍历测试数据集的迭代器,获取输入ID和原始摘要
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    # 使用模型生成新的摘要,参数包括最大新生成的token数量、束搜索的束数、不重复的ngram大小
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    # 将生成的ID转换为文本
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)  # 打印生成的摘要文本
    i += 1  # 计数器加1
    if i == 1:  # 如果计数器达到1
        break  # 跳出循环,仅生成并打印一个摘要

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1939773.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

提供代码!直接可以运行,Chatgpt代码分享

效果演示 安装依赖库 pip install openai粘贴如下代码 # 设置 API Key import openaiopenai.api_key "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"# 设置请求参数model_engine "text-davinci-002"prompt "python的应用领域"comp…

Bonree ONE赋能汽车行业 重塑可观测性体验

随着数字化、智能化浪潮的汹涌而至&#xff0c;全球汽车产业正站在一个崭新的历史起点上。新能源汽车&#xff0c;作为这场科技革命和产业变革的领跑者&#xff0c;其数智化发展正呈现出前所未有的蓬勃态势。7月18-19日&#xff0c;第四届中国新能源汽车产业数智峰会于上海举办…

《0基础》学习Python——第二十三讲__网络爬虫/<6>爬取哔哩哔哩视频

一、在B站上爬取一段视频&#xff08;B站视频有音频和视频两个部分&#xff09; 1、获取URL 注意&#xff1a;很多平台都有反爬取的机制&#xff0c;B站也不例外 首先按下F12找到第一条复制URL 2、UA伪装&#xff0c;下列图片中&#xff08;注意代码书写格式&#xff09; 3、Co…

【2024】springboot O2O生鲜食品订购

博主介绍&#xff1a;✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流✌ 技术范围&#xff1a;SpringBoot、Vue、SSM、HTML、Jsp、PHP、Nodejs、Python、爬虫、数据可视化…

sass版本更新,不推荐使用嵌套规则后的声明

目前在 Sass 中不推荐使用嵌套规则后的声明&#xff0c;在 为了通知用户即将进行的更改&#xff0c;并给他们时间进行更改 与之兼容的样式表。在未来的版本中&#xff0c;Dart Sass 将更改为 匹配纯 CSS 嵌套生成的顺序。Deprecation Warning: Sasss behavior for declarations…

视频点播项目

文章目录 视频点播技术栈与项目环境JsonCppMariaDBhttplib 工具类设计文件类Json类 数据管理模块视频信息管理&#xff08;数据表设计&#xff09;数据管理类设计 网络通信接口设计业务处理模块设计前端界面主页面播放页面 项目总结项目回顾项目结构关键技术点总结 视频点播 允…

el-table表头使用el-dropdown出现两个下拉框

问题描述&#xff1a;el-table在固定右边列时&#xff0c;表头使用el-dropdown会出现两个下拉框&#xff0c;如图所示&#xff1a; 解决方法&#xff1a; 1.只显示第一个下拉框&#xff0c;通过控制样式将其他的下拉框display:none; 2.如图所示&#xff0c;修改插槽写法&…

SpringBoot源码深度解析

今天&#xff0c;聊聊SpringBoot的源码&#xff0c;本博客聊的版本为v2.0.3.RELEASE。目前SpringBoot的最新版为v3.3.2&#xff0c;可能目前有些公司使用的SpringBoot版本高于我这个版本。但是没关系&#xff0c;因为版本越新&#xff0c;新增的功能越多&#xff0c;反而对Spri…

如何将几百兆的包优化到几十兆----记一次vue项目的打包优化过程

打包优化 现象 前段时间开发的时候遇到客户反馈的一个问题 界面无法打开&#xff0c;显示白屏&#xff0c;控制台无报错 经过我们在开发环境&#xff0c;测试环境反复测试都没复现出客户的问题&#xff0c;然后我们又不停的在生产环境上找问题&#xff0c;也没复现出来 最…

一文带你读懂MLIR论文,理解MLIR设计准则.

论文MLIR: Scaling Compiler Infrastructure for Domain Specific Computation MLIR&#xff1a;针对特定领域计算扩展编译器基础设施 文章目录 论文MLIR: Scaling Compiler Infrastructure for Domain Specific Computation1. 论文下载2. TVM关于MLIR的讨论3. 论文正文0. 摘要…

5个人人都应该学会的电脑小技巧

今天分享几个电脑常用的快捷键&#xff0c;可以让你的工作事半功倍&#xff0c;建议收藏&#xff01; 不想被突然来的人看到正在浏览的网站&#xff0c;用CtrlW或者AltF4可以关闭当前页面&#xff0c;另外如果你正在看小电影也可以用这个办法关闭。 2. 同事或大Boos走了之后想…

Anthropic的Claude安卓版能否赢得用户青睐?

Anthropic的Claude安卓版能否赢得用户青睐&#xff1f; 前言 Anthropic 就在7月18日&#xff0c;这家以"可控AI"著称的初创公司再次出手&#xff0c;推出了Claude的Android版本应用。这款APP不仅支持实时语言翻译&#xff0c;更传承了Anthropic一贯的隐私保护政策。C…

腾讯云COS托管静态网站,以及如何解决访问出现了下载网页的情况

腾讯云对象存储&#xff08;Cloud Object Storage&#xff0c;简称COS&#xff09;&#xff0c;与其他云厂商所提供的云对象存储都是面向非结构化数据&#xff0c;只是每个云厂商的叫法有别于他家&#xff0c;或许是更能彰显厂商的品牌吧&#xff01; 但不管云厂商怎么给云对象…

错误:PHP:Deprecated: Required parameter $xxx follows optional parameter $yyy

前言 略 错误 Deprecated: Required parameter $xxx follows optional parameter $yyy 解决办法 设置 error_reporting E_ALL & ~E_DEPRECATED & ~E_STRICT 参考 https://blog.csdn.net/lxw1844912514/article/details/100028023

MATLAB学习日志DAY13

13.矩阵索引&#xff08;1&#xff09; 13.1 下标 上图&#xff01; A 的行 i 和列 j 中的元素通过 A(i,j) 表示。 例如&#xff0c;A(4,2) 表示第四行和第二列中的数字。 在幻方矩阵中&#xff0c; A(4,2) 为 15。 A(1,4) A(2,4) A(3,4) A(4,4) 用来计算 A 第四列中的…

【PG】PostgreSQL高可用之repmgr事件通知

目录 描述 结合脚本 占位符 repmgr命令 生成的事件&#xff1a; repmgrd 生成的事件&#xff08;流复制模式&#xff09;&#xff1a; 描述 每次repmgr或repmgrd执行重大事件时&#xff0c;都会将该事件的记录连同时间戳、失败或成功的标识以及进一步的详细信息&#xff08…

黑马商城启动流程(微服务拆分项目)

1.虚拟机ssh&#xff08;docker中布置了 mysql nacos seata&#xff08;分布式事务&#xff09;三个容器在同一个网络hm-net中&#xff09; 2.idea&#xff08;启动所有的微服务项目 &#xff09; 3.nginx 在cmd上启动 4.navicat 连接数据库 5.登录前端页面 http://localho…

Stable Diffusion AI入门介绍

Stable Diffusion模型 在上一篇的文章中我们介绍了&#xff0c;AIGC的相关知识以及AI绘画的历史——AIGC是什么&#xff0c;与AI绘画有什么关系&#xff0c;一篇文章带你了解AI绘画的前世今生。 我们知道了Stable Diffusion是一种潜在扩散模型&#xff0c;由慕尼黑大学的Comp…

将github上的项目导入到vscode并创建虚拟环境

1、将github上的项目导入到vscode 直接从github上下载到本地&#xff0c;用vscode打开&#xff08;Open file&#xff09; 2、创建虚拟环境 python -m venv <name> <name>\Scripts\activate ps: 1、退出虚拟环境 deactivate 2、如果运行python -m venv <…

【数据结构】详解堆

一、堆的概念 堆(Heap)是计算机科学中一类特殊的数据结构的统称。堆通常是一个可以被看做一棵 完全二叉树的 数组对象。 堆是非线性数据结构&#xff0c;相当于一维数组&#xff0c;有两个直接后继。 如果有一个关键码的集合K { k₀&#xff0c;k₁&#xff0c;k₂ &#xff0…