transformer上手(10)—— 文本摘要任务

news2024/11/25 10:55:32

文本摘要是一个 Seq2Seq 任务,尽可能保留文本语义的情况下将长文本压缩为短文本。

文本摘要可以看作是将长文本“翻译”为捕获关键信息的短文本,因此大部分文本摘要模型同样采用 Encoder-Decoder 框架。当然,也有一些非 Encoder-Decoder 框架的摘要模型,例如 GPT 家族也可以通过小样本学习 (few-shot) 进行文本摘要。

下面是一些目前流行的可用于文本摘要的模型:

  • GPT-2:虽然是自回归 (auto-regressive) 语言模型,但是可以通过在输入文本的末尾添加 TL;DR 来使 GPT-2 生成摘要;
  • PEGASUS:与大部分语言模型通过预测被遮掩掉的词语来进行训练不同,PEGASUS 通过预测被遮掩掉的句子来进行训练。由于预训练目标与摘要任务接近,因此 PEGASUS 在摘要任务上的表现很好;
  • T5:将各种 NLP 任务都转换到 text-to-text 框架来完成的通用 Transformer 架构,要进行摘要任务只需在输入文本前添加 summarize: 前缀;
  • mT5:T5 的多语言版本,在多语言通用爬虫语料库 mC4 上预训练,覆盖 101 种语言;
  • BART:包含一个 Encoder 和一个 Decoder stack 的 Transformer 架构,训练目标是重构损坏的输入,同时还结合了 BERT 和 GPT-2 的预训练方案;
  • mBART-50:BART 的多语言版本,在 50 种语言上进行了预训练。

T5 模型通过模板前缀 (prompt prefix) 将各种 NLP 任务都转换到 text-to-text 框架进行预训练,例如摘要任务的前缀就是 summarize:,模型以前缀作为条件生成符合模板的文本,这使得一个模型就可以完成多种 NLP 任务:

在这里插入图片描述

这里,我们微调多语言 mT5 模型用于中文摘要任务,mT5 模型不使用前缀,但是具备 T5 模型大部分的多功能性。

1 准备数据

我们选择大规模中文短文本摘要语料库 LCSTS 作为数据集,该语料基于新浪微博短新闻构建,规模超过 200 万。

我们简单地将新闻的标题作为摘要来微调 mT5 模型以完成文本摘要任务。

该语料已经划分好了训练集、验证集和测试集,分别包含 2400591 / 10666 / 1106 个样本,一行是一个“标题!=!正文”的组合:

媒体融合关键是以人为本!=!受众在哪里,媒体就应该在哪里,媒体的体制、内容、技术就应该向哪里转变。媒体融合关键是以人为本,即满足大众的信息需求,为受众提供更优质的服务。这就要求媒体在融合发展的过程中,既注重技术创新,又注重用户体验。
1.1 构建数据集

编写继承自 Dataset 类的自定义数据集用于组织样本和标签。考虑到使用 LCSTS 两百多万条样本进行训练耗时过长,这里我们只抽取训练集中的前 20 万条数据:

from torch.utils.data import Dataset

max_dataset_size = 200000

class LCSTS(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)
    
    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                items = line.strip().split('!=!')
                assert len(items) == 2
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

train_data = LCSTS('data/lcsts_tsv/data1.tsv')
valid_data = LCSTS('data/lcsts_tsv/data2.tsv')
test_data = LCSTS('data/lcsts_tsv/data3.tsv')

下面我们输出数据集的尺寸,并且打印出一个训练样本:

print(f'train set size: {len(train_data)}')
print(f'valid set size: {len(valid_data)}')
print(f'test set size: {len(test_data)}')
print(next(iter(train_data)))

# train set size: 200000
# valid set size: 10666
# test set size: 1106
# {'title': '修改后的立法法全文公布', 'content': '新华社受权于18日全文播发修改后的《中华人民共和国立法法》,修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章,共计105条。'}

1.2 数据预处理

接下来,我们就需要通过 DataLoader 库按 batch 加载数据,将文本转换为模型可以接受的 token IDs。与翻译任务类似,我们需要运用分词器对原文和摘要都进行编码,这里我们选择 BUET CSE NLP Group 提供的 mT5 摘要模型:

from transformers import AutoTokenizer

model_checkpoint = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

先尝试使用 mT5 tokenizer 对文本进行分词:

inputs = tokenizer("我叫张三,在苏州大学学习计算机。")
print(inputs)
print(tokenizer.convert_ids_to_tokens(inputs.input_ids))

# {'input_ids': [259, 3003, 27333, 8922, 2092, 261, 1083, 117707, 9792, 24920, 123553, 306, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
# ['▁', '我', '叫', '张', '三', ',', '在', '苏州', '大学', '学习', '计算机', '。', '</s>']

特殊的 Unicode 字符 以及序列结束 token </s> 表明 mT5 模型采用的是基于 Unigram 切分算法的 SentencePiece 分词器。Unigram 对于处理多语言语料库特别有用,它使得 SentencePiece 可以在不知道重音、标点符号以及没有空格分隔字符(例如中文)的情况下对文本进行分词。

与翻译任务类似,摘要任务的输入和标签都是文本,这里我们同样使用分词器提供的 as_target_tokenizer() 函数来并行地对输入和标签进行分词,并且同样将标签序列中填充的 pad 字符设置为 -100 以便在计算交叉熵损失时忽略它们,以及通过模型自带的 prepare_decoder_input_ids_from_labels 函数对标签进行移位操作以准备好 decoder input IDs:

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM

max_input_length = 512
max_target_length = 64

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)

def collote_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs, 
        padding=True, 
        max_length=max_input_length,
        truncation=True, 
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets, 
            padding=True, 
            max_length=max_target_length,
            truncation=True, 
            return_tensors="pt"
        )["input_ids"]
        batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
        end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx+1:] = -100
        batch_data['labels'] = labels
    return batch_data

train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=4, shuffle=False, collate_fn=collote_fn)

由于本文直接使用 Transformers 库自带的 AutoModelForSeq2SeqLM 函数来构建模型,因此我们将每一个 batch 中的数据都处理为该模型可接受的格式:一个包含 'attention_mask''input_ids''labels''decoder_input_ids' 键的字典。

下面我们尝试打印出一个 batch 的数据,以验证是否处理正确:

batch = next(iter(train_dataloader))
print(batch.keys())
print('batch shape:', {k: v.shape for k, v in batch.items()})
print(batch)
dict_keys(['input_ids', 'attention_mask', 'decoder_input_ids', 'labels'])
batch shape: {
    'input_ids': torch.Size([4, 78]), 
    'attention_mask': torch.Size([4, 78]), 
    'decoder_input_ids': torch.Size([4, 23]), 
    'labels': torch.Size([4, 23])
}
{'input_ids': tensor([
        [   259,  46420,   1083,  73451,    493,   3582,  14219,  98660, 111234,
           9455,  10139,    261,  11688,  56462,   7031,  71079,  31324,  94274,
           2037, 203743,   9911,  16834,   1107,   6929,  31063,    306,   2372,
            891,    261, 221805,   1455,  31571, 118447,    493,  56462,   7031,
          71079, 124732,   3937,  23224,   2037, 203743,   9911, 199662,  22064,
          31063,    261,   7609,   5705,  18988, 160700, 154547,  43803,  40678,
           3519,    306,      1,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0],
        [   259, 101737,  36059,    261, 157186,  47685,   8854, 124583, 218664,
           5705,   8363,   7216,  30921,  27032,  59754, 127646,  62558,  98901,
            261,   8868,   4110,   5705,  73334,  25265,  26553,   4153,    261,
           7274,  58402,   5435,  12914,    591,   2991, 162028,  22151,   4925,
         157186,  34499, 101737,  36059,  14520,  11201,  89746,  11017,    261,
           3763,   8868, 157186,  47685,   8854, 150259,  90707,   4417,  35388,
           3751,   2037,   3763, 194391,  81024,    261, 124025, 239583,  72939,
            306,   4925,  28216,  11242,  51563,   3094,    261, 157186, 142783,
           8868,  51191,  43239,   3763,    306,      1],
        [   259,  13732,   5705, 165437,  36814,  29650,    261, 120834, 201540,
          64493,  36814,  69169,    306,  13381,   5859,  14456,  21562,  16408,
         201540,   9692,   1374, 116772,  35988,   2188,  36079, 214133,    261,
          13505,   9127,   2542, 161781, 101017,    261, 101737,  36059,   7321,
          14219,   7519,  21929,    460, 100987,    261,   9903,   5848,  72308,
         101017,    261,   2123,  19394, 164872,   5162, 125883,  21562,  43138,
          37575,  15937,  66211,   5162,   3377,    848,  27349,   2446, 198562,
         154832,    261,  11883,  65386,    353, 106219,    261,  27674,    939,
          76364,   5507,  31568,   9809,  54172,      1],
        [   259,  77554,   1193,  74380,    493,    590,    487,    538,    495,
         198437,   8041,   6907, 219169, 122000,  10220,  28426,   6994,  36236,
          74380,  30733,    306,  40921, 218505,   1083,   5685,  14469,   2884,
           1637, 198437,  17723,  94708,  22695,    306,  12267,   1374,  13733,
           1543, 224495, 164497,  17286, 143553,  30464, 198437,  17723, 113940,
         176540, 143553,    306,  36017,   1374,  13733,  13342,  88397,  94708,
          22695,    261,   1083,   5685,  14469,  10458,   9692,   4070,  13342,
         115813,  27385,    306,      1,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0]]), 
 'attention_mask': tensor([
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]]), 
 'decoder_input_ids': tensor([
        [     0,    259,  11688,  56462,   7031,  71079,  73451,   3592,   3751,
           9911,  17938,  16834,   1107,   6929,  31063,  63095,    291,      1,
              0,      0,      0,      0,      0],
        [     0,    259, 157186,  47685,   8854, 107850,  14520,  11201,  89746,
          11017,  10973,   2219, 239583,  72939, 108358,    267,   1597,  43239,
          11242,  51563,   3094,      1,      0],
        [     0,    259,  13732,   2123,  19394,  94689,   2029,  26544,  17684,
           4074,  33119,  62428,  76364,      1,      0,      0,      0,      0,
              0,      0,      0,      0,      0],
        [     0,    447,    487,    538,    495, 198437,   8041,   6907,  86248,
          74380, 100644,  12267,    338, 225859,    261,  40921,    353,   3094,
          53737,   1083,  16311,  58407,  23616]]), 
 'labels': tensor([
        [   259,  11688,  56462,   7031,  71079,  73451,   3592,   3751,   9911,
          17938,  16834,   1107,   6929,  31063,  63095,    291,      1,   -100,
           -100,   -100,   -100,   -100,   -100],
        [   259, 157186,  47685,   8854, 107850,  14520,  11201,  89746,  11017,
          10973,   2219, 239583,  72939, 108358,    267,   1597,  43239,  11242,
          51563,   3094,      1,   -100,   -100],
        [   259,  13732,   2123,  19394,  94689,   2029,  26544,  17684,   4074,
          33119,  62428,  76364,      1,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100],
        [   447,    487,    538,    495, 198437,   8041,   6907,  86248,  74380,
         100644,  12267,    338, 225859,    261,  40921,    353,   3094,  53737,
           1083,  16311,  58407,  23616,      1]])}

可以看到,DataLoader 按照我们设置的 batch_size=4 对样本进行编码,并且填充 pad token 对应的标签都被设置为 -100。我们构建的 Decoder 的输入 decoder input IDs 尺寸与标签序列完全相同,且通过向后移位在序列头部添加了特殊的“序列起始符”,例如第一个样本:

'labels': 
        [   259,  11688,  56462,   7031,  71079,  73451,   3592,   3751,   9911,
          17938,  16834,   1107,   6929,  31063,  63095,    291,      1,   -100,
           -100,   -100,   -100,   -100,   -100]
'decoder_input_ids': 
        [     0,    259,  11688,  56462,   7031,  71079,  73451,   3592,   3751,
           9911,  17938,  16834,   1107,   6929,  31063,  63095,    291,      1,
              0,      0,      0,      0,      0]

至此,数据预处理部分就全部完成了!

在大部分情况下,即使我们在 batch 数据中没有包含 decoder input IDs,模型也能正常训练,它会自动调用模型的 prepare_decoder_input_ids_from_labels 函数来构造 decoder_input_ids

2 训练模型

本文直接使用 Transformers 库自带的 AutoModelForSeq2SeqLM 函数来构建模型,因此下面只需要实现 Epoch 中的”训练循环”和”验证/测试循环”。

2.1 优化模型参数

使用 AutoModelForSeq2SeqLM 构造的模型已经封装好了对应的损失函数,并且计算出的损失会直接包含在模型的输出 outputs 中,可以直接通过 outputs.loss 获得,因此训练循环为:

from tqdm.auto import tqdm

def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)
    
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = batch_data.to(device)
        outputs = model(**batch_data)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

验证/测试循环负责评估模型的性能。对于文本摘要任务,常用评估指标是 ROUGE 值 (short for Recall-Oriented Understudy for Gisting Evaluation),它可以度量两个词语序列之间的词语重合率。ROUGE 值的召回率表示参考摘要在多大程度上被生成摘要覆盖,如果我们只比较词语,那么召回率就是:

R e c a l = N u m b e r   o f   o v e r l a p i n g   w o r d s T o t a l   n u m b e r   o f   w o r d s   i n   r e f e r e n c e   s u m m a r y Recal = \frac{Number \ of \ overlaping \ words}{Total \ number \ of \ words \ in \ reference \ summary} Recal=Total number of words in reference summaryNumber of overlaping words

准确率则表示生成的摘要中有多少词语与参考摘要相关:

P r e c i s i o n = N u m b e r   o f   o v e r l a p i n g   w o r d s T o t a l   n u m b e r   o f   w o r d s   i n   g e n e r a t e d   s u m m a r y Precision = \frac{Number \ of \ overlaping \ words}{Total \ number \ of \ words \ in \ generated \ summary} Precision=Total number of words in generated summaryNumber of overlaping words

最后再基于准确率和召回率来计算 F1 值。实际操作中,我们可以通过 rouge 库来方便地计算这些 ROUGE 值,例如 ROUGE-1 度量 uni-grams 的重合情况,ROUGE-2 度量 bi-grams 的重合情况,而 ROUGE-L 则通过在生成摘要和参考摘要中寻找最长公共子串来度量最长的单词匹配序列,例如:

from rouge import Rouge

generated_summary = "I absolutely loved reading the Hunger Games"
reference_summary = "I loved reading the Hunger Games"

rouge = Rouge()

scores = rouge.get_scores(
    hyps=[generated_summary], refs=[reference_summary]
)[0]
print(scores)

# {
#  'rouge-1': {'r': 1.0, 'p': 0.8571428571428571, 'f': 0.9230769181065088}, 
#  'rouge-2': {'r': 0.8, 'p': 0.6666666666666666, 'f': 0.7272727223140496}, 
#  'rouge-l': {'r': 1.0, 'p': 0.8571428571428571, 'f': 0.9230769181065088}
# }

rouge 库默认使用空格进行分词,因此无法处理中文、日文等语言,最简单的办法是按字进行切分,当然也可以使用分词器分词后再进行计算,否则会计算出不正确的 ROUGE 值:

from rouge import Rouge

generated_summary = "我在苏州大学学习计算机,苏州大学很美丽。"
reference_summary = "我在环境优美的苏州大学学习计算机。"

rouge = Rouge()

TOKENIZE_CHINESE = lambda x: ' '.join(x)

# from transformers import AutoTokenizer
# model_checkpoint = "csebuetnlp/mT5_multilingual_XLSum"
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# TOKENIZE_CHINESE = lambda x: ' '.join(
#     tokenizer.convert_ids_to_tokens(tokenizer(x).input_ids, skip_special_tokens=True)
# )

scores = rouge.get_scores(
    hyps=[TOKENIZE_CHINESE(generated_summary)], 
    refs=[TOKENIZE_CHINESE(reference_summary)]
)[0]
print('ROUGE:', scores)
scores = rouge.get_scores(
    hyps=[generated_summary], 
    refs=[reference_summary]
)[0]
print('wrong ROUGE:', scores)

# ROUGE: {
#  'rouge-1': {'r': 0.75, 'p': 0.8, 'f': 0.7741935433922998}, 
#  'rouge-2': {'r': 0.5625, 'p': 0.5625, 'f': 0.562499995}, 
#  'rouge-l': {'r': 0.6875, 'p': 0.7333333333333333, 'f': 0.7096774143600416}
# }
# wrong ROUGE: {
#  'rouge-1': {'r': 0.0, 'p': 0.0, 'f': 0.0}, 
#  'rouge-2': {'r': 0.0, 'p': 0.0, 'f': 0.0}, 
#  'rouge-l': {'r': 0.0, 'p': 0.0, 'f': 0.0}
# }

AutoModelForSeq2SeqLM 模型对 Decoder 的解码过程也进行了封装,我们只需要调用模型的 generate() 函数就可以自动地逐个生成预测 token。例如,我们可以直接调用预训练好的 mT5 摘要模型生成摘要(使用柱搜索解码,num_beams=4,并且不允许出现 2-gram 重复):

import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

model_checkpoint = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)

article_text = """
受众在哪里,媒体就应该在哪里,媒体的体制、内容、技术就应该向哪里转变。
媒体融合关键是以人为本,即满足大众的信息需求,为受众提供更优质的服务。
这就要求媒体在融合发展的过程中,既注重技术创新,又注重用户体验。
"""

input_ids = tokenizer(
    article_text,
    return_tensors="pt",
    truncation=True,
    max_length=512
)
generated_tokens = model.generate(
    input_ids["input_ids"],
    attention_mask=input_ids["attention_mask"],
    max_length=32,
    no_repeat_ngram_size=2,
    num_beams=4
)
summary = tokenizer.decode(
    generated_tokens[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)
print(summary)

# Using cpu device
# 媒体融合发展是当下中国面临的一大难题。

当然了,摘要多个句子也没有问题:vvvvvvvvvvvvv

article_texts = [
"""
受众在哪里,媒体就应该在哪里,媒体的体制、内容、技术就应该向哪里转变。
媒体融合关键是以人为本,即满足大众的信息需求,为受众提供更优质的服务。
这就要求媒体在融合发展的过程中,既注重技术创新,又注重用户体验。
""",
"""
新华社受权于18日全文播发修改后的《中华人民共和国立法法》,
修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、
自治条例和单行条例、规章”“适用与备案审查”“附则”等6章,共计105条。
"""
]

input_ids = tokenizer(
    article_texts,
    padding=True, 
    return_tensors="pt",
    truncation=True,
    max_length=512
)
generated_tokens = model.generate(
    input_ids["input_ids"],
    attention_mask=input_ids["attention_mask"],
    max_length=32,
    no_repeat_ngram_size=2,
    num_beams=4
)
summarys = tokenizer.batch_decode(
    generated_tokens,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)
print(summarys)

# [
#  '媒体融合发展是当下中国面临的一大难题。', 
#  '中国官方新华社周一(18日)全文播发修改后的《中华人民共和国立法法》。'
# ]

在验证/测试循环中,我们首先通过 model.generate() 函数获取预测结果,然后将预测结果和正确标签都处理为 rouge 库接受的文本列表格式(这里我们将标签序列中的 -100 替换为 pad token ID 以便于分词器解码),最后送入到 rouge 库计算各项 ROUGE 值:

import numpy as np
from rouge import Rouge

rouge = Rouge()

def test_loop(dataloader, model):
    preds, labels = [], []
    
    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=4,
                no_repeat_ngram_size=2,
            ).cpu().numpy()
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        preds += [' '.join(pred.strip()) for pred in decoded_preds]
        labels += [' '.join(label.strip()) for label in decoded_labels]
    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    print(f"Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
    return result

为了方便后续保存验证集上最好的模型,我们还在验证/测试循环中返回评估出的 ROUGE 值。

2.2 保存模型

根据模型在验证集上的性能来调整超参数以及选出最好的模型,然后将选出的模型应用于测试集以评估最终的性能。这里我们继续使用 AdamW 优化器,并且通过 get_scheduler() 函数定义学习率调度器:

from transformers import AdamW, get_scheduler

learning_rate = 2e-5
epoch_num = 10

optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_avg_rouge = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)
    valid_rouge = test_loop(valid_dataloader, model)
    print(valid_rouge)
    rouge_avg = valid_rouge['avg']
    if rouge_avg > best_avg_rouge:
        best_avg_rouge = rouge_avg
        print('saving new weights...\n')
        torch.save(model.state_dict(), f'epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_weights.bin')
print("Done!")

在开始训练之前,我们先评估一下没有经过微调的模型在 LCSTS 测试集上的性能。

test_data = LCSTS('lcsts_tsv/data3.tsv')
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collote_fn)

test_loop(test_dataloader, model)

# Using cuda device
# 100%|███████████| 35/35 [01:07<00:00,  1.92s/it]
# Rouge1: 23.71 Rouge2: 12.20 RougeL: 20.78

可以看到预训练模型在我们测试集上的 ROUGE-1、ROUGE-2、ROUGE-L 值分别为 23.71、12.2 和 20.78,说明该模型具备文本摘要的能力,但是在“短文本新闻摘要”任务上表现不佳。然后,我们正式开始训练,完整代码如下:

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AdamW, get_scheduler
from tqdm.auto import tqdm
from rouge import Rouge
import random
import numpy as np
import os

max_dataset_size = 200000
max_input_length = 512
max_target_length = 32
train_batch_size = 8
test_batch_size = 8
learning_rate = 2e-5
epoch_num = 3
beam_size = 4
no_repeat_ngram_size = 2

seed = 5
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

class LCSTS(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)
    
    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                items = line.strip().split('!=!')
                assert len(items) == 2
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

train_data = LCSTS('lcsts_tsv/data1.tsv')
valid_data = LCSTS('lcsts_tsv/data2.tsv')
test_data = LCSTS('lcsts_tsv/data3.tsv')

model_checkpoint = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)

def collote_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs, 
        padding=True, 
        max_length=max_input_length,
        truncation=True, 
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets, 
            padding=True, 
            max_length=max_target_length,
            truncation=True, 
            return_tensors="pt"
        )["input_ids"]
        batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
        end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx+1:] = -100
        batch_data['labels'] = labels
    return batch_data

train_dataloader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=test_batch_size, shuffle=False, collate_fn=collote_fn)

def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)
    
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = batch_data.to(device)
        outputs = model(**batch_data)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

rouge = Rouge()

def test_loop(dataloader, model, mode='Test'):
    assert mode in ['Valid', 'Test']
    preds, labels = [], []
    
    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=beam_size,
                no_repeat_ngram_size=no_repeat_ngram_size,
            ).cpu().numpy()
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        preds += [' '.join(pred.strip()) for pred in decoded_preds]
        labels += [' '.join(label.strip()) for label in decoded_labels]
    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    print(f"{mode} Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
    return result

optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_avg_rouge = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)
    valid_rouge = test_loop(valid_dataloader, model, mode='Valid')
    rouge_avg = valid_rouge['avg']
    if rouge_avg > best_avg_rouge:
        best_avg_rouge = rouge_avg
        print('saving new weights...\n')
        torch.save(model.state_dict(), f'epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_weights.bin')
print("Done!")

# Epoch 1/3
# -------------------------------
# loss: 3.544795: 100%|██████████| 6250/6250 [41:40<00:00,  2.50it/s]
# 100%|██████████████████████████| 334/334 [06:18<00:00,  1.13s/it]
# Rouge1: 33.47 Rouge2: 20.87 RougeL: 30.50

# saving new weights...

# Epoch 2/3
# -------------------------------
# loss: 3.448048: 100%|██████████| 6250/6250 [41:38<00:00,  2.50it/s]
# 100%|██████████████████████████| 334/334 [06:13<00:00,  1.12s/it]
# Rouge1: 33.87 Rouge2: 21.18 RougeL: 30.85

# saving new weights...

# Epoch 3/3
# -------------------------------
# loss: 3.398337: 100%|██████████| 6250/6250 [41:40<00:00,  2.50it/s]
# 100%|██████████████████████████| 334/334 [06:11<00:00,  1.11s/it]
# Rouge1: 33.95 Rouge2: 21.24 RougeL: 30.93

# saving new weights...

# Done!

可以看到,随着训练的进行,模型在验证集上 ROUGE 值稳步提升。因此 3 轮 Epoch 结束后,会在目录下保存 3 个模型权重:

# epoch_1_valid_rouge_28.2808_model_weights.bin
# epoch_2_valid_rouge_28.6322_model_weights.bin
# epoch_3_valid_rouge_28.7044_model_weights.bin

至此,我们对 mT5 摘要模型的训练(微调)过程就完成了。

3 测试模型

训练完成后,我们加载在验证集上性能最优的模型权重,汇报其在测试集上的性能,并且将模型的预测结果保存到文件中。

由于 AutoModelForSeq2SeqLM 对整个解码过程进行了封装,我们只需要调用 generate() 函数就可以自动通过 beam search 找到最佳的 token ID 序列,因此最后只需要再使用分词器将 token ID 序列转换为文本就可以获得生成的摘要:

test_data = LCSTS('data/lcsts_tsv/data3.tsv')
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collote_fn)

import json

model.load_state_dict(torch.load('epoch_1_valid_rouge_6.6667_model_weights.bin'))

model.eval()
with torch.no_grad():
    print('evaluating on test set...')
    sources, preds, labels = [], [], []
    for batch_data in tqdm(test_dataloader):
        batch_data = batch_data.to(device)
        generated_tokens = model.generate(
            batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            max_length=max_target_length,
            num_beams=beam_size,
            no_repeat_ngram_size=no_repeat_ngram_size,
        ).cpu().numpy()
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_sources = tokenizer.batch_decode(
            batch_data["input_ids"].cpu().numpy(), 
            skip_special_tokens=True, 
            use_source_tokenizer=True
        )
        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        sources += [source.strip() for source in decoded_sources]
        preds += [pred.strip() for pred in decoded_preds]
        labels += [label.strip() for label in decoded_labels]
    scores = rouge.get_scores(
        hyps=[' '.join(pred) for pred in preds], 
        refs=[' '.join(label) for label in labels], 
        avg=True
    )
    rouges = {key: value['f'] * 100 for key, value in scores.items()}
    rouges['avg'] = np.mean(list(rouges.values()))
    print(f"Test Rouge1: {rouges['rouge-1']:>0.2f} Rouge2: {rouges['rouge-2']:>0.2f} RougeL: {rouges['rouge-l']:>0.2f}\n")
    results = []
    print('saving predicted results...')
    for source, pred, label in zip(sources, preds, labels):
        results.append({
            "document": source, 
            "prediction": pred, 
            "summarization": label
        })
    with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
        for exapmle_result in results:
            f.write(json.dumps(exapmle_result, ensure_ascii=False) + '\n')

# Using cuda device
# evaluating on test set...
# 100%|██████████████████████████| 35/35 [00:42<00:00,  1.22s/it]
# Test Rouge1: 33.71 Rouge2: 20.30 RougeL: 30.42

# saving predicted results...

可以看到,经过我们的微调,模型在测试集上的 ROUGE-1、ROUGE-2 和 ROUGE-L 值分别从 23.71、12.2、20.78 提升到了 33.71、20.30、30.42,证明了我们对模型的微调是成功的。

我们打开保存预测结果的 test_data_pred.json,其中每一行对应一个样本,document 对应原文,prediction 对应模型生成的摘要,summarization 对应参考摘要。

{
  "document": "本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代人", 
  "prediction": "可穿戴产品设计原则", 
  "summarization": "可穿戴技术十大设计原则"
}

至此,我们使用 Transformers 库进行文本摘要任务就全部完成了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1632560.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

低代码技术在构建质量管理系统中的应用与优势

引言 在当今快节奏的商业环境中&#xff0c;高效的质量管理系统对于组织的成功至关重要。质量管理系统帮助组织确保产品或服务符合客户的期望、符合法规标准&#xff0c;并持续改进以满足不断变化的需求。与此同时&#xff0c;随着技术的不断进步&#xff0c;低代码技术作为一…

Linux系统编程---线程池并发服务器

模型原理分析&#xff1a; 线程池的关键优势在于它减少了每次任务执行时创建和销毁线程的开销 线程池的组成主要分为 3 个部分&#xff0c;这三部分配合工作就可以得到一个完整的线程池&#xff1a; 1. 任务队列&#xff0c;存储需要处理的任务&#xff0c;由工作的线程来处理…

python代码实现kmeans对鸢尾花聚类

导入第三方库和模型 from sklearn import datasets import numpy as np import matplotlib.pyplot as plt from sklearn.cluster import KMeans2、创建画图函数 def draw_result(train_x, labels, cents, title):n_clusters np.unique(labels).shape[0]#获取类别个数color …

esp32s3使用psram后音频播报不了的问题解决记录

idf.py menuconfig开启psram后会报错 提示需要打补丁&#xff1a; 根据提示切换到IDF_PATH目录&#xff0c;然后执行git apply %ADF_PATH%/ida_patches/idf5.0_freertos.patch打补丁。 再次编译提示如下错误&#xff1a; assert failed: spi_flash_disable_interrupts_cach…

嵌入式学习,方法、交流很重要

关注、星标公众号&#xff0c;直达精彩内容 ID&#xff1a;技术让梦想更伟大 整理&#xff1a;李肖遥 Who Am I 大家好&#xff0c;我是「逍遥的小蜜圈」星球的星主&#xff0c;如果大家关注我早一点&#xff0c;一定看了我的简单的自我介绍&#xff0c;关于我 — 聊聊自己的经…

【Python网络爬虫】python爬虫用正则表达式进行数据清洗与处理

&#x1f517; 运行环境&#xff1a;PYTHON &#x1f6a9; 撰写作者&#xff1a;左手の明天 &#x1f947; 精选专栏&#xff1a;《python》 &#x1f525; 推荐专栏&#xff1a;《算法研究》 #### 防伪水印——左手の明天 #### &#x1f497; 大家好&#x1f917;&#x1f91…

QT学习篇—qt软件安装

qt下载网址http://download.qt.io/new_archive/qt/ QT官网Qt | Tools for Each Stage of Software Development LifecycleAll the essential Qt tools for all stages of Software Development Lifecycle: planning, design, development, testing, and deployment.https:…

CSS样式特异性5层次详解

你好&#xff0c;我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃-大专生&#xff0c;一枚程序媛&#xff0c;感谢关注。回复 “前端基础题”&#xff0c;可免费获得前端基础 100 题汇总&#xff0c;回复 “前端工具”&#xff0c;可获取 Web 开发工具合…

FANUC机器人SOCKET断开KAREL程序编写

一、添加一个.KL文件创建编辑断开指令 添加一个KL文件用来创建karel程序中socket断开指令 二、断开连接程序karel代码 PROGRAM SOC_DIS %COMMENT SOCKET断开 %INCLUDE klevccdf VAR str_input,str_val : STRING[20] status,data_type,int_val : INTEGER rel_val : REALBEGING…

全球首发!龙蜥社区助力 Intel SPR 加速器上云

编者按&#xff1a;云原生平台下芯片的竞争力日渐增强&#xff0c;加速器如何在赛道上体现竞争力。龙蜥社区开发者、阿里云高级研发工程师易兴睿介绍运用龙蜥操作系统提供的解决方案&#xff0c;依靠 Intel SPR 平台专用硬件加速器&#xff0c;实现云原生场景下 Envoy 网关加速…

微信小程序 request 配置了服务器域名后 发布体验版无法访问

问题描述 在微信小程序公众平台配置了测试服务器域名后&#xff0c;发布了体验版进行测试&#xff0c;发现网络请求不通&#xff0c;打开调试也依然无法访问。 解决步骤&#xff1a; 1.首先根据小程序文档网络模块的使用说明&#xff0c;一步步排查域名证书是否符合规范&…

Llama3 mac本地部署教程

1.下载的软件清单&#xff1a; ollama下载&#xff1a; Download Ollama on macOS nodejs下载&#xff1a; Node.js — Download Node.js 2.安装 安装Ollama 下载之后打开&#xff0c;直接点击Next以及Install安装ollama到命令行。安装完成后界面上会提示ollama run llam…

在Docker中部署Java应用:Java版本隔离的实践案例

在Docker中部署Java应用&#xff1a;Java版本隔离的实践案例 人生就是一场又一场的相遇&#xff0c;一个明媚&#xff0c;一个忧伤&#xff0c;一个华丽&#xff0c;一个冒险&#xff0c;一个倔强&#xff0c;一个柔软&#xff0c;最后那个正在成长。 背景需求 在软件开发和部…

18 python定制篇-开发平台Ubuntu

第 18 章Linux 之 Python 定制篇-Python 开发平台 Ubuntu 18.1 Ubuntu 介绍 Ubuntu&#xff08;友帮拓、优般图、乌班图&#xff09;是一个以桌面应用为主的开源 GNU/Linux 操作系统&#xff0c;Ubuntu 是基于 GNU/Linux&#xff0c; 支持 x86、amd64&#xff08;即 x64&…

PG修改端口号与error: could not connect to server: could not connect to server 问题解决

刚开始学习PG修改端口号之后数据库端口号没变。 修改端口号&#xff1a;/usr/local/pgsql/data中的postgresql.conf中 修改后并不能直接生效需要重启PG&#xff1a; /usr/local/pgsql/bin/pg_ctl -D /usr/local/pgsql/data -l /usr/local/pgsql/data/logfile restart重启后新…

如何免费生成文本二维码?文字生成二维码的方法

随着信息技术的不断发展&#xff0c;文本二维码作为一种简便、高效的信息分享方式&#xff0c;受到了越来越多人的关注和应用。文本二维码是将文本信息编码成二维码的形式&#xff0c;通过扫描二维码即可快速获取文本内容&#xff0c;为信息分享和传播提供了全新的可能性。 便…

美富特 | 邀您参加2024全国水科技大会暨技术装备成果展览会

王涛 四川美源环能科技有限公司 技术总监 报告题目&#xff1a;绿色智慧水岛如何助力工业园区污水及再生水资源化利用降碳增效 拥有十余年的环保行业从业经验&#xff0c;对各类前沿物化、生化及膜技术均有丰富的研发、设计及应用经验&#xff0c;先后参与多项重点核心技术…

跨境电商亚马逊、虾皮等平台做测评要用什么IP?

IP即IP地址&#xff0c;IP地址是指互联网协议地址&#xff08;英语&#xff1a;Internet Protocol Address&#xff0c;又译为网际协议地址&#xff09;&#xff0c;是IP Address的缩写&#xff0c;IP地址是IP协议提供的一种统一的地址格式 功能&#xff1a;它为互联网上的每一…

密码学python库PBC安装使用

初始化 使用环境云服务器&#xff08;移动云可以免费使用一个月&#xff09; 选择ubuntu18.04-64位 第一次进入linux命令行之后是没有界面显示的&#xff0c;需要在命令行下载。 这里按照其他云平台操作即可&#xff1a;Ubuntu18.04 首次使用配置教程(图形界面安装) 记录好登录…

软件工程物联网方向嵌入式系统复习笔记--嵌入式系统基础

1 嵌入式系统基础 1.1 嵌入式系统基础 1.1.1 嵌入式系统概念 嵌入式系统一般定义 是指以应用为中心、以计算机技术为基础、软件硬件可裁剪、适应应用系统对功能、可靠性、成本、体积、功耗严格要求的专用计算机系统。 就像一般的计算机系统包括软件和硬件一样&#xff0c;…