基于MindSpore的GPT2文本摘要
数据集加载
使用nlpcc2017摘要数据,共包含5万个样本,内容是新闻正文及其摘要。
from mindnlp.utils import http_get
from mindspore.dataset import TextFileDataset
# 下载数据集
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')
# 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)
数据预处理
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)
按9:1划分测试集与训练集,randomize表示不对数据进行随机排序,按照原顺序直接拆分。
原始数据格式是:
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]
期望的预处理后的数据格式是:
[CLS] article_context [SEP] summary_context [SEP]
import json
import numpy as np
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
# 加载json格式的数据并转成numpy数组
def read_map(text):
data = json.loads(text.tobytes())
return np.array(data['article']), np.array(data['summarization'])
# 使用分词器处理artical和summary。
# text=article表示主文本输入,text_pair=summary表示辅助文本
# padding指将输入序列填充或截断的最大长度
# truncation指定截断策略,only_first表示指仅截断主文本(article)。
# 通常是主文本较长需要截断,而辅助文本较短并且需要完整保留。
def merge_and_pad(article, summary):
tokenized = tokenizer(text=article, text_pair=summary,
padding='max_length', truncation='only_first', max_length=max_seq_len)
return tokenized['input_ids'], tokenized['input_ids']
# 提取article和summary
dataset = dataset.map(read_map, 'text', ['article', 'summary'])
# 将列名修改为input_ids和labels
dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])
# 将数据进行分批
dataset = dataset.batch(batch_size)
# 如果shuffle是true,则打乱数据
if shuffle:
dataset = dataset.shuffle(batch_size)
return dataset
这里的tokenizer使用BertTokenizer,因为GPT2没有中文的分词器。
from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# 使用刚刚定义的函数分成四批进行处理。
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)
模型构建
from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel
class GPT2ForSummarization(GPT2LMHeadModel):
def construct(
self,
input_ids = None,
attention_mask = None,
labels = None,
):
outputs=super().construct(input_ids=input_ids, attention_mask=attention_mask)
shift_logits=outputs.logits[..., :-1, :]
shift_labels=labels[..., 1:]
loss=ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
return loss
具体解释一下以上代码:
shift 的作用使为了对齐预测和标签,使模型输出和标签对应,从而得到每个位置的预测误差。
举例 I love program
在自回归模型中,模型会根据前文逐步预测(不熟悉的盆友可以看一下上一篇文章:自回归模型与文本生成方法)
- 输入 “I”,输出 “love”
- 输入 “I love”,输出 “programming”
也就是位置1的输出对应的是位置2(的标签),位置2的输出对应位置3(的标签)
对应到实际的数据就是
- outputs.logits[…, :-1, :]
去除 logits 的最后一个时间步,因为没有标签与之对应。继续以上面的为例,就是去掉programming,因为programing没有后面输出了。 - labels[…, 1:]
去除 labels 的第一个时间步,因为没有预测值与之对应。去掉I,因为I前面没有输入,因此也不是输出的一部分。
由此完成了shift错位操作。
搞好数据结构之后再计算损失
- 将shift_logits形状调整成二维张量,使每一行对应一个token的预测分布。
- 将shift_labels形状调整为一维张量,使每个元素对应一个标签。
- 使用cross_entropy()计算交叉熵损失,忽略填充token的损失。
定义学习率warmup
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
class LinearWithWarmUp(LearningRateSchedule):
"""
Warmup-decay learning rate.
"""
def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
super().__init__()
self.learning_rate = learning_rate
self.num_warmup_steps = num_warmup_steps
self.num_training_steps = num_training_steps
def construct(self, global_step):
if global_step < self.num_warmup_steps:
return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
return ops.maximum(
0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
) * self.learning_rate
这里定义了LinearWithWarmUp作为自定义学习率调度类。它可以在训练的初始阶段进行学习率的线性预热,再在剩余的训练步骤中线性衰减。
初始化 预热步数、训练步数、学习率。
构建时如果步数小于预热步数则开始进行线性增长,从0增长到learning rate。
如果步数大于等于时,则进行线性衰减,由learning rate变回0。maximum保证不会降低到0以下。
模型训练
内容详见注释内容
# 初始化参数
num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4
#训练的步数=数据集的大小 乘以 准备完整遍历数据集的次数(epoch)
num_training_steps = num_epochs * train_dataset.get_dataset_size()
from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel
config = GPT2Config(vocab_size=len(tokenizer))
# 初始化一个用于文本摘要的GPT2模型
model = GPT2ForSummarization(config)
# 初始化学习率调度器和adam优化器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)
from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback
# 设置检查点回调,用于在训练过程中保存模型检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
epochs=1, keep_checkpoint_max=2)
# 初始化训练器,设置模型,训练集,训练轮次,优化器和回调函数
trainer = Trainer(network=model, train_dataset=train_dataset,
epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
# 开启混合精度训练,以提高训练速度和节省显存
trainer.set_amp(level='O1')
# 开始训练,并指定目标列(labels)作为标签。
trainer.run(tgt_columns="labels")
模型推理
def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
# 处理测试集的过程和训练集差不多
# 依然是提取出article和sumarization
# 再进行分词,制定最大长度和截断
def read_map(text):
data = json.loads(text.tobytes())
return np.array(data['article']), np.array(data['summarization'])
def pad(article):
tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
return tokenized['input_ids']
dataset = dataset.map(read_map, 'text', ['article', 'summary'])
dataset = dataset.map(pad, 'article', ['input_ids'])
dataset = dataset.batch(batch_size)
return dataset
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 加载已经训练好的模型
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
# 设为非训练模式,可以禁用掉一些训练相关的操作如dropout
model.set_train(False)
# 模型的结束标记设置成分隔符标记,这样生成的文本遇到分隔符就会终止
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():
output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
output_text = tokenizer.decode(output_ids[0].tolist())
i += 1
总结
本章介绍了使用GPT2进行文本总结任务的基本流程,包括数据导入、数据预处理、模型训练、和模型推理。