文本摘要,基于Pytorch和Hugging Face Transformers构建示例,有源码

news2024/9/27 17:37:10

文本摘要的常见问题和解决方法概述,以及使用Hugging Face Transformers库构建基于新浪微博数据集的文本摘要示例。

作 者丨程旭源
学习笔记

1 前言简介

文本摘要旨在将文本或文本集合转换为包含关键信息的简短文本。主流方法有两种类型,抽取式和生成式。常见问题:抽取式摘要的内容选择错误、语句连贯性差、灵活性差。生成式摘要受未登录词、词语重复等问题影响。
文本摘要的分类有很多,比如单文档多文档摘要、多语言摘要、论文生成(摘要、介绍、重点陈述等每个章节的生成)、医学报告生成、情感类摘要(观点、感受、评价等的摘要)、对话摘要等。主流解决方法主要是基于深度学习、强化学习、迁移学习等方法,有大量的相关论文可以解读和研究。
抽取式的代表方法有TextRank、BertSum[1],从原文中抽取出字词组合成新的摘要。TextRank仿照PageRank,句子作为节点,构造无向有权边,权值为句子相似度。
生成式摘要方法是有PGN[2]、GPT、BART[3]、BRIO[4]、GSum[5]、SimCLS[6]、CIT+SE[7]等。
对于生成式摘要,不得不简单再提一下PGN模型:
Pointer Generator Network结构[2]

生成式存在的一个大问题是OOV未登录词,Google的PGN使用 copy mechanism和Coverage mechanism,能对于未遇见过的词直接复制用于生成部分,比如上面的“2-0”,同时避免了重复生成。

2 数据集

文本摘要的数据集有很多,这里使用的是Lcstsm[10]大规模中文短文本摘要语料库,取自于新浪微博,训练集共有240万条,为了快速得到结果和理解过程,可以自己设置采用数据的大小。比如训练集设置10万条。

3 模型选型

预训练模型选的是"csebuetnlp/mT5_multilingual_XLSum",是ACL-IJCNLP 2021的论文XL-Sum[8]中开源的模型,他们也是基于多语言T5模型 (mT5)在自己准备的44中语言的摘要数据集上做了fine-tune得到的模型。mT5[9]是按照与T5类似的方法进行训练的大规模多语言预训练模型。
如何加载模型?使用Hugging Face Transformers库可以两行代码即可加载。
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_checkpoint = “csebuetnlp/mT5_multilingual_XLSum”
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

4 数据预处理

加载模型和tokenizer后对自定义文本分词:

mT5 模型采用的是基于 Unigram 切分的 SentencePiece 分词器,Unigram 对于处理多语言语料库特别有用。SentencePiece 可以在不知道重音、标点符号以及没有空格分隔字符(例如中文)的情况下对文本进行分词。

摘要任务的输入和标签都是文本,所以我们要做这几件事:
1、使用 as_target_tokenizer() 函数并行地对输入和标签进行分词,并标签序列中填充的 pad 字符设置为 -100 ,这样可以在计算交叉熵损失时忽略掉;

2、对标签进行移位操作,来准备 decoder_input_ids,有封装好的prepare_decoder_input_ids_from_labels函数。

3、将每一个 batch 中的数据都处理为模型可接受的输入格式:包含 ‘attention_mask’、‘input_ids’、‘labels’ 和 ‘decoder_input_ids’ 键的字典。

在编程的时候,可以打印出一个 batch 的数据看看:

5 模型训练和评测

因为Transformers包已经帮我们封装好了模型、损失函数等内容,我们只需调用并定义好训练循环即可:

def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)
    
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = batch_data.to(device)
        outputs = model(**batch_data)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    
    return total_loss

训练的过程中,不断的优化模型权重参数,那么如何评估模型的性能、如何确保模型往更好的方向优化呢?用什么评估指标呢?

我们定义一个测试循环负责评估模型的性能,指标就使用ROUGE(Recall-Oriented Understudy for Gisting Evaluation),它可以度量两个词语序列之间的词语重合率。

ROUGE 值的召回率,表示被参考摘要 reference summary有多大程度上被生成摘要 generated summary覆盖,精确率则表示生成摘要中有多少词语与被参考摘要相关。

那么,如果我们只比较词语,召回率是:
Recall = 重叠词数/被参考摘要总词数
Precision=重叠词数/生成摘要的总词数

已经有rouge的Python包[11],使用pip安装即可

pip install rouge

rouge库官方示例,“f” stands for f1_score, “p” stands for precision, “r” stands for recall.

ROUGE-1 度量 uni-grams 的重合情况,ROUGE-2 度量 bi-grams 的重合情况,而 ROUGE-L 则通过在生成摘要和参考摘要中寻找最长公共子串来度量最长的单词匹配序列。

另外要注意的是,rouge 库默认使用空格进行分词,中文需要按字切分,也可以使用分词器分词后再计算。

Transformers对解码过程也进行了封装,我们只需要调用 generate() 函数就可以自动地逐个生成预测 token。

例如我们使用马来亚大学的介绍生成一个摘要:“马来亚大学是马来西亚历史最悠久的高等教育学府。“

所以,在测试过程中,我们通过generate() 函数获取预测结果,然后将预测结果和正确标签都处理为 rouge 库接受的格式,最后计算各项的ROUGE值即可:

def test_loop(dataloader, model, tokenizer):
    preds, labels = [], []
    rouge = Rouge()
    
    model.eval()
    with torch.no_grad():
        for batch_data in tqdm(dataloader):
            batch_data = batch_data.to(device)
            # 获取预测结果
            generated_tokens = model.generate(batch_data["input_ids"],
                                              attention_mask=batch_data["attention_mask"],
                                              max_length=max_target_length,
                                              num_beams=beam_search_size,
                                              no_repeat_ngram_size=no_repeat_ngram_size,
                                              ).cpu().numpy()
            if isinstance(generated_tokens, tuple):
                generated_tokens = generated_tokens[0]
            decoded_preds = tokenizer.batch_decode(generated_tokens, 
                                                   skip_special_tokens=True, 
                                                   clean_up_tokenization_spaces=False)
            
            label_tokens = batch_data["labels"].cpu().numpy()
            # 将标签序列中的 -100 替换为 pad token ID 以便于分词器解码
            label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(label_tokens, 
                                                    skip_special_tokens=True, 
                                                    clean_up_tokenization_spaces=False)
            # 处理为 rouge 库接受的文本列表格式
            preds += [' '.join(pred.strip()) for pred in decoded_preds]
            labels += [' '.join(label.strip()) for label in decoded_labels]
    # rouge 库计算各项 ROUGE 值
    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    return result


6 模型保存

优化器我们选使用AdamW,并且通过 get_scheduler()函数定义学习率调度器。

在每一个epoch中,调用上面定义的train_loop和test_loop,模型在验证集上的rouge分数用来调整超参数和选出最好的模型,最后使用最好的模型跑测一下测试集来评估最终的性能。

""" Train the model """
total_steps = len(train_dataloader) * num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": weight_decay},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
]
warmup_steps = int(total_steps * warmup_proportion)
optimizer = AdamW(
    optimizer_grouped_parameters, 
    lr=learning_rate, 
    betas=(adam_beta1, adam_beta2), 
    eps=adam_epsilon
)
lr_scheduler = get_scheduler(
    'linear',
    optimizer, 
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Train!
logger.info("***** Running training *****")
logger.info(f"Num examples - {len(train_data)}")
logger.info(f"Num Epochs - {num_train_epochs}")
logger.info(f"Total optimization steps - {total_steps}")

total_loss = 0.
best_avg_rouge = 0.
for epoch in range(num_train_epochs):
    print(f"Epoch {epoch+1}/{num_train_epochs}\n" + 30 * "-")
    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss)
    dev_rouges = test_loop(dev_dataloader, model, tokenizer)
    logger.info(f"Dev Rouge1: {dev_rouges['rouge-1']:>0.2f} Rouge2: {dev_rouges['rouge-2']:>0.2f} RougeL: {dev_rouges['rouge-l']:>0.2f}")
    rouge_avg = dev_rouges['avg']
    if rouge_avg > best_avg_rouge:
        best_avg_rouge = rouge_avg
        logger.info(f'saving new weights to {output_dir}...\n')
        save_weight = f'epoch_{epoch+1}_rouge_{rouge_avg:0.4f}_weights.bin'
        torch.save(model.state_dict(), os.path.join(output_dir, save_weight))
logger.info("Done!")


7 总结

文本摘要和文本生成是自然语言处理中非常重要和常见的任务,本文使用生成式方法做文本摘要,文本生成还可以应用于其他场景下,比如给定一个句子,生成多个与其相似、语义相同的句子,这里也Mark一下,后面再写一篇相关文章。

文本摘要还存在很多问题,有待研究者们进一步探索,比如:长程依赖、新颖性、暴露偏差、评估和损失不匹配、缺乏概括性、虚假事实、不连贯等,此外,还有摘要任务特定的问题,比如:如何确定关键信息,而不仅仅是简单地句子压缩。每个问题的解决或方法优化都能发表论文,相关的论文有太多了。后续可以抽关键的解决方案聊聊。

代码近期整理后上传Github,链接见文末留言处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nodejs模块的封装(数据库Mysql)

文章目录项目结构本次演示需要使用的第三方包为1.app.js相关配置2.router下的user.js相关配置3.db/index.js文件相关操作4.router_handler下的user.js相关操作项目结构 后面的项目相关文件的创建步骤按照我写的博客从上往下一步一步来 本次演示需要使用的第三方包为 "cor…

【操作系统实验/Golang】实验4:虚拟内存页面置换算法

目录 1 实验问题描述 2 测试数据 3 流程图 4 实验结果 4 实验代码 1 实验问题描述 设计程序模拟先进先出FIFO,最佳置换OPT和最近最久未使用LRU页面置换算法的工作过程。 假设内存中分配给每个进程的最小物理块数为m,在进程运行过程中要访问的页面个…

【Leetcode面试常见题目题解】1. 两数相加

题目描述 本文是leetcode第2题的题解,题目描述摘自leetcode。如下 给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。 请你将两个数相加,并以相同形式返回一个…

海外服务器提供商选择中存在哪些风险?

开展海外业务时,毫无疑问,选择一个高质量的海外服务器提供商可以省去不少麻烦。但是,同时有一些海外服务商需要避开。毕竟一个服务器不靠谱,这跟提供商有很大的原因。下面主要是关于低于标准的海外服务器提供商的一些潜在风险。 1…

ES6中字符串和数组新增的方法

ES6中字符串和数组新增的方法一、字符串中新增的方法1、模板字符串 (表达式、函数的调用、变量)2、repeat(次数)函数 : 将目标字符串重复N次,返回一个新的字符串,不影响目标字符串3、includes()函数 :判断字符串中是否含有指定的子字符串,返回…

mysql 8 新旧密码可以同时生效

在MySQL8.0以前版本,给MySQL更改密码,明确写到开发规范中,拒绝更在线更改更密码,因为在8.0以前操作非常麻烦且不太完美。 MySQL 8.0之前的处理方法: 1. 创建一个同样权限的帐号通过 show grants for ‘user_name’1…

通用vue编辑按钮和新建按钮事件逻辑

一、编辑按钮对话框 1.首先先创建一个文件夹page-model,在里面使用elemengt-plus提供的对话框组件el-dialog。 2.在page-model里面去使用之前封装好的form表单,就是之前封装好的搜索组件的hy-form 3.在form组件里面加一个插槽,对应 page-m…

微信小程序:会议OA项目-首页

目录 一、flex布局 Flex布局简介 什么是flex布局? flex属性 flex的属性 二、轮播图组件及mockjs的使用 三、会议OA小程序首页布局 一、flex布局 Flex布局简介 布局的传统解决方案,基于盒状模型,依赖 display属性 position属性 float…

CompletableFuture详解

CompletableFuture详解 概要 RunnableThread虽然提供了多线程的能力但是没有返回值。CallableThread的方法提供多线程和返回值的能力但是在获取返回值的时候会阻塞主线程。 上述的情况只适合不关心返回值,只要提交的Task执行了就可以。另外的就是能够容忍等待。 C…

Java并发容器

一、并发容器总结1、大部分在 java.util.concurrent 包中。ConcurrentHashMap: 线程安全的HashMapCopyOnWriteArrayList: 线程安全的List,在读多写少的场合性能非常好,远远好于Vector.ConcurrentLinkedQueue: 高效的并发队列,使用链表实现。可…

[ 数据结构 ] 平衡二叉树(AVL)--------左旋、右旋、双旋

0 引出 数列{1,2,3,4,5,6},要求创建一颗二叉排序树(BST), 并分析问题所在 回顾:二叉搜索树 左子树全部为空,从形式上看,更像一个单链表.插入速度没有影响查询速度明显降低(因为需要依次比较), 不能发挥 BST的优势,因为每次还需要…

第四十二章 动态规划——数字三角形模型

第四十二章 动态规划——数字三角形模型一、数字三角形模型1、什么是数字三角形模型二、例题1、AcWing 1015. 摘花生(1)问题(2)思路状态表示状态转移循环设计初末状态(3)代码2、AcWing 1018. 最低通行费&am…

第二章:感知机

文章目录2.1感知机模型2.2感知机策略2.3梯度下降法2.4感知机-原始算法形式2.5感知机-原始算法实例2.6感知机-对偶形式2.7感知机-对偶形式实例2.8感知机算法收敛性证明如下:2.1感知机模型 2.2感知机策略 损失函数非负,策略是选择使其最小的模型参数 2.3梯…

【Linux工具】-gcc/g++

gcc/g一,简介二,代码的翻译过程1,预处理2,编译3,汇编4,链接(1)静态库(2)动态库(3)动静态库比较三,常见选项一,…

模板(C++)

模板(C)泛型编程函数模板概念格式原理实例化匹配原则类模板定义格式实例化泛型编程 如何实现一个通用的交换函数呢? void Swap(int& left, int& right) {int tmp left;left right;right tmp; }void Swap(double& left, doub…

【Spring6源码・IOC】BeanDefinition的加载

哎呀,又是午夜时分,又是一个失眠的夜晚,和去年一样,记得去年今日,也是睡不着觉,就注册了csdn的账号,开始写东西,csdn真是深夜最好的安魂剂。 Spring都发布了6.0,这不赶紧…

2022.12青少年软件编程(Python)等级考试试卷(三级)

2022.12.10青少年软件编程(Python)等级考试试卷(三级) 一、单选题(共25题,共50分) 1.列表L1中全是整数,小明想将其中所有奇数都增加1,偶数不变,于是编写了如下图所示的代码。 请问,图中红线处,代码应该是?(D) A. x || 2 B. x ^ 2 C. x && 2 D. x %…

JS日期与字符串相互转换(时间格式化YYYY-MM-DD,Dayjs的使用)

JS日期与字符串相互转换——JS封装函数,Dayjs转换时间格式相关文章调用场景复现一、JS封装函数1、日期转字符串2、字符串转日期二、 Dayjs转换时间格式1、Dayjs快速安装与使用2、Dayjs格式化日期相关文章调用 文章内容文章链接JS数组对象——根据日期进行排序&…

南邮数据结构考研常用算法

1.链表 在带头结点的链表中,删除所有值为x的结点 void Del_X(Linklist &L,ElemType x){LNode *pL->next, *preL,*q;while (p!null){if (p->datax){qp;pp->next;pre->nextp;free(q);}else{prep;pp->next;}} }使用单链表进行插入排序 ListNode*…

数组常用方法总结 (1) :pop / push / shift / unshift

pop 从尾部删除一项。改变原有数组。返回删除掉的内容。 <template><div class"myBlock"><div class"tableBlock"><div class"title">{{ newObject ? "操作后的数组" : "操作前的数组" }}</d…