昇思25天学习打卡营第10天|基于MindSpore的GPT2文本摘要

news2025/1/12 22:50:46

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

基于MindSpore的GPT2文本摘要

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp

数据集加载与处理

  • 数据集加载

    本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

from mindnlp.utils import http_get

# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')
from mindspore.dataset import TextFileDataset

# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()
# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

使用http_get函数从指定URL下载数据集。

使用TextFileDataset加载数据集,并将其分割为训练集和测试集。

  • 数据预处理

    原始数据格式:

    article: [CLS] article_context [SEP]
    summary: [CLS] summary_context [SEP]
    

    预处理后的数据格式:

    [CLS] article_context [SEP] summary_context [SEP]
import json
import numpy as np

# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['input_ids']
    
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset

因GPT2无中文的tokenizer,我们使用BertTokenizer替代。

from mindnlp.transformers import BertTokenizer

# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
next(train_dataset.create_tuple_iterator())

定义process_dataset函数来预处理数据集,包括读取数据、进行分词和填充等操作。

采用BertTokenizer进行中文文本的预处理。

模型构建

  • 构建GPT2ForSummarization模型,注意shift right的操作。
from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss

创建GPT2ForSummarization类,继承自GPT2LMHeadModel,用于文本摘要任务。

construct方法中,实现标签的右移操作,匹配序列到序列的需求。

  • 动态学习率
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate

定义LinearWithWarmUp类,实现线性预热衰减学习率策略,在训练初期逐步增加学习率以帮助模型快速收敛,随后在训练后期逐渐降低学习率以进行精细调整。

模型训练

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)

# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))
from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度

注:建议使用较高规格的算力,训练时间较长

trainer.run(tgt_columns="labels")

设置训练参数,如学习率、预热步数和总训练步数。

使用Trainer进行模型训练,设置检查点回调保存模型。

原数据集50000样本,训练时间较长,实际用了10000样本减少训练时间。

没有尝试静态图mindnlp/mindnlp/transformers/models/gpt2 · mindnlp · GitHub

模型推理

数据处理,将向量数据变为中文数据

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def pad(article):
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    
    dataset = dataset.batch(batch_size)

    return dataset
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    i += 1
    if i == 1:
        break

定义process_test_dataset函数来处理测试数据集。

加载训练好的模型,使用generate方法生成摘要。

代码基于MindSpore的GPT2文本摘要模型。首先,导入了所需的库和模块,下载并加载数据集。接着,对数据集进行预处理,包括分词、填充等操作。构建GPT2ForSummarization模型,模型继承自GPT2LMHeadModel,重写construct方法。模型训练部分,设置学习率调度器、优化器和检查点回调,使用Trainer进行训练。最后,对测试数据集进行处理,使用训练好的模型进行推理,输出摘要结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1871631.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c#学习日志用CLI(命令行窗口)创建c#工程

创建Helloworld.Proj和Program.cs两个文件然后运行即可&#xff0c;一种方法是用记事本创建&#xff0c;写入代码&#xff0c;这种比较费劲&#xff0c;主要代码如下 Program.cs中代码如下 System.Console.WriteLine("Hello World!!"); Helloworld.Proj中的代码如…

【Linux】初识操作系统

一、冯•诺依曼体系结构 在学习操作系统之前&#xff0c;我们先来认识一下冯•诺依曼体系结构&#xff0c;我们常见的计算机&#xff0c;如笔记本。我们不常见的计算机&#xff0c;如服务器&#xff0c;大部分都遵守冯诺依曼体系。 截至目前&#xff0c;我们所认识的计算机&am…

怎么在vite项目中全局导入一个scss文件

怎么在vite项目中全局导入一个scss文件 &#x1f389;&#x1f389;&#x1f389;欢迎来到我的博客,我是一名自学了2年半前端的大一学生,熟悉的技术是JavaScript与Vue.目前正在往全栈方向前进, 如果我的博客给您带来了帮助欢迎您关注我,我将会持续不断的更新文章!!!&#x1f64…

VUE项目安全漏洞扫描和修复

npm audit 1、npm audit是npm 6 新增的一个命令,可以允许开发人员分析复杂的代码并查明特定的漏洞。 2、npm audit名称执行&#xff0c;需要包package.json和package-lock.json文件。它是通过分析 package-lock.json 文件&#xff0c;继而扫描我们的包分析是否包含漏洞的。 …

C++初学者指南-2.输入和输出---从输入流错误中恢复

C初学者指南-2.输入和输出—从输入流错误中恢复 文章目录 C初学者指南-2.输入和输出---从输入流错误中恢复怎么了&#xff1f;解决方案&#xff1a;出错后重置输入流 怎么了&#xff1f; 示例&#xff1a;连续输入 int main () {cout << "i? ";int i 0;cin…

Web服务器与Apache(虚拟主机基于ip、域名和端口号)

一、Web基础 1.HTML概述 HTML&#xff08;Hypertext Markup Language&#xff09;是一种标记语音,用于创建和组织Web页面的结构和内容&#xff0c;HTML是构建Web页面的基础&#xff0c;定义了页面的结构和内容&#xff0c;通过标记和元素来实现 2.HTML文件结构 <html>…

【Linux进阶】基础IO函数详解

1.函数open和openat 调用open或openat函数可以打开或创建一个文件。 #include <fcntl.h> int open(const char *path, int ofag, ... /* mode_t mode */);int openat (int fd, const char *path, int oflag, ... /* mode_t mode */); 我们将最后一个参数写为...&#x…

Windows下activemq开启jmx

1.activemq版本信息 activemq&#xff1a;apache-activemq-5.18.4 2.Windows下activemq开启jmx 1.进入activemq conf目录&#xff0c;备份activemq.xml文件 2.编辑activemq.xml文件&#xff0c;在broker节点增加useJmx"true" <broker xmlns"http://active…

《Windows API每日一练》6.4 程序测试

前面我们讨论了鼠标的一些基础知识&#xff0c;本节我们将通过一些实例来讲解鼠标消息的不同处理方式。 本节必须掌握的知识点&#xff1a; 第36练&#xff1a;鼠标击中测试1 第37练&#xff1a;鼠标击中测试2—增加键盘接口 第38练&#xff1a;鼠标击中测试3—子窗口 第39练&…

独立付费进群系统

无授权源码 有独立分销代理端 域名防封 可对接码支付 易支付 需要用到回调 源码链接

腾讯实时语音编码大突破 电梯、地库里通话也不卡顿

腾讯宣布&#xff0c;腾讯主导的新一代实时语音编码行业标准AVS3P10&#xff0c;即将正式对外发布。由腾讯会议天籁实验室携手腾讯AI Lab研发的Penguins编解码器&#xff08;即AVS3P10行业标准的原型&#xff09;&#xff0c;把经典信号处理和最新的深度学习技术结合在一起&…

pppd 返回错误码 含义

错误码 00&#xff1a; pppd已经断开&#xff0c;或者已经成功建立连接后请求方又中 断了。 01&#xff1a; 发成了一个严重错误&#xff0c;例如系统调用失败或者访问非法内存。 02&#xff1a; 处理给定操作是检测到错误&#xff0c;例如使用两个互斥的操作。 03&#xff1a;…

测试报告-HTMLTestRunner报告优化(中/英文)

引用原始的HTMLTestRunner.py文件生成的测试报告在美观性不是很好&#xff0c;使用在此文件基础上优化后的HTMLTestReportCN.py文件(生成的报告为中文)、HTMLTestReportEN.py文件(生成的报告为英文)。 1 首先新建一个Python项目 例如&#xff1a;testHtmlReport 创建case包&am…

【JVM】Java虚拟机运行时数据分区介绍

JVM 分区&#xff08;运行时数据区域&#xff09; 文章目录 JVM 分区&#xff08;运行时数据区域&#xff09;前言1. 程序计数器2. Java 虚拟机栈3. 本地方法栈4. Java 堆5. 方法区6. 运行时常量池7. 直接内存 前言 之前在说多线程的时候&#xff0c;提到了JVM虚拟机的分区内存…

# 音频处理4_傅里叶变换

1.离散傅里叶变换 对于离散时域信号 x[n]使用离散傅里叶变换&#xff08;Discrete Fourier Transform, DFT&#xff09;进行频域分析。 DFT 将离散信号 x[n] 变换为其频谱表示 X[k]&#xff0c;定义如下&#xff1a; X [ k ] ∑ n 0 N − 1 x [ n ] e − j 2 π k n N X[k]…

Qt 使用代码布局,而不使用UI布局

一、工程的建立&#xff1a; 1、打开Qt Creator&#xff0c;文件&#xff0c;新建文件或项目 2、选择Application&#xff0c;Qt Widgets Application 3、写入名称&#xff0c;选择qmake 4、选择基类Base class&#xff0c;去除Generate form 务必选择QWidget&#xff0c;若…

读AI新生:破解人机共存密码笔记14逆强化学习算法

1. 数学保证 1.1. 如果我们要沿着新的路线重建人工智能&#xff0c;那么它的基础必须是坚实的 1.2. 通过精确的定义和一步步的严格数学证明来提供无可辩驳的保证 1.3. 希望证明一个定理&#xff1a;设计人工智能系统的一种特殊方式可以确保它…

Linux如何安装openjdk1.8

文章目录 Centosyum安装jdk和JRE配置全局环境变量验证ubuntu使用APT(适用于Ubuntu 16.04及以上版本)使用PPA(可选,适用于需要特定版本或旧版Ubuntu)Centos yum安装jdk和JRE yum install java-1.8.0-openjdk-devel.x86_64 安装后的目录 配置全局环境变量 vim /etc/pr…

Python | Leetcode Python题解之第201题数字范围按位与

题目&#xff1a; 题解&#xff1a; class Solution:def rangeBitwiseAnd(self, m: int, n: int) -> int:while m < n:# 抹去最右边的 1n n & (n - 1)return n

spring-boot-starter-json配置对象属性为空不显示

问题背景 在Spring Boot中使用spring-boot-starter-json&#xff08;通常是通过jackson实现的&#xff09;时&#xff0c;如果你希望在序列化对象时&#xff0c;如果某个属性为空&#xff0c;则不显示该属性&#xff0c;你可以使用JsonInclude注解来实现这一点。 pom.xml <…