BigBird:大鸟模型中文生成式长文本摘要实践

news2024/9/28 1:19:09

1、介绍

BigBird 是一种基于稀疏注意力的Transformer,可将基于Transformer的模型(例如 BERT)扩展到更长的序列。

论文:https://arxiv.org/abs/2007.14062

代码:https://github.com/google-research/bigbird

BigBird模型实现了三种注意力机制,分别为随机注意力窗口注意力全局注意力,这与LongFormer几乎相似,详细原理见论文。
在这里插入图片描述

2、中文Big Bird获取

目前没有好的BigBird开源权重,但是,通过实践,我们可以将开源的中文BART模型转换并得到BigBird的权重。

bart-chinese-base地址:https://huggingface.co/fnlp/bart-base-chinese

详细操作代码如下:

#!/usr/bin/env python
# _*_coding:utf-8_*_
# Author   :    Junhui Yu
# Time     :    2023/2/27 14:47

import logging

from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration, BertTokenizer
from transformers import BartForConditionalGeneration

logger = logging.getLogger("YUNLP")
logging.basicConfig(level=logging.INFO)

max_position_embeddings = 4096

led_config = BigBirdPegasusConfig(
    vocab_size=51271,
    max_position_embeddings=max_position_embeddings,
    encoder_layers=6,
    encoder_ffn_dim=3072,
    encoder_attention_heads=12,
    decoder_layers=6,
    decoder_ffn_dim=3072,
    decoder_attention_heads=12,
    encoder_layerdrop=0.0,
    decoder_layerdrop=0.0,
    use_cache=True,
    is_encoder_decoder=True,
    activation_function="gelu_new",
    d_model=768,
    dropout=0.1,
    attention_dropout=0.0,
    activation_dropout=0.0,
    init_std=0.02,
    decoder_start_token_id=102,
    classifier_dropout=0.0,
    scale_embedding=True,
    pad_token_id=0,
    bos_token_id=101,
    eos_token_id=102,
    attention_type="block_sparse",
    block_size=64,
    num_random_blocks=3,
    use_bias=False,
)
bigbirdpegasus_model = BigBirdPegasusForConditionalGeneration(led_config)
print(bigbirdpegasus_model)
model_path = '/remote-home/TCCI01/bert/bart-base-chinese'
bart_model = BartForConditionalGeneration.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)

current_max_pos, embed_size = bart_model.model.encoder.embed_positions.weight.shape
new_encoder_pos_embed = bart_model.model.encoder.embed_positions.weight.new_empty(max_position_embeddings, embed_size)

k = 0
step = current_max_pos - 2
encoder_position_embeddings = bart_model.model.encoder.embed_positions.weight[2:]
while k < max_position_embeddings:
    new_encoder_pos_embed[k:(k + step)] = encoder_position_embeddings
    k += step
bigbirdpegasus_model.base_model.encoder.embed_positions.weight.data = new_encoder_pos_embed

current_max_pos, embed_size = bart_model.model.decoder.embed_positions.weight.shape
new_decoder_pos_embed = bart_model.model.decoder.embed_positions.weight.new_empty(max_position_embeddings, embed_size)

k = 0
step = current_max_pos - 2
decoder_position_embeddings = bart_model.model.decoder.embed_positions.weight[2:]
while k < max_position_embeddings:
    new_decoder_pos_embed[k:(k + step)] = decoder_position_embeddings
    k += step
bigbirdpegasus_model.base_model.decoder.embed_positions.weight.data = new_decoder_pos_embed

for i, (bart_encoder_layer, bigbirdpegasus_encoder_layer) in enumerate(
        zip(bart_model.model.encoder.layers, bigbirdpegasus_model.base_model.encoder.layers)):
    bigbirdpegasus_encoder_layer.self_attn.self.key.weight = bart_encoder_layer.self_attn.k_proj.weight
    bigbirdpegasus_encoder_layer.self_attn.self.query.weight = bart_encoder_layer.self_attn.q_proj.weight
    bigbirdpegasus_encoder_layer.self_attn.self.value.weight = bart_encoder_layer.self_attn.v_proj.weight
    bigbirdpegasus_encoder_layer.self_attn.output.weight = bart_encoder_layer.self_attn.out_proj.weight
    bigbirdpegasus_encoder_layer.self_attn_layer_norm = bart_encoder_layer.self_attn_layer_norm
    bigbirdpegasus_encoder_layer.fc1 = bart_encoder_layer.fc1
    bigbirdpegasus_encoder_layer.fc2 = bart_encoder_layer.fc2
    bigbirdpegasus_encoder_layer.final_layer_norm = bart_encoder_layer.final_layer_norm

for i, (bart_decoder_layer, bigbirdpegasus_decoder_layer) in enumerate(
        zip(bart_model.model.decoder.layers, bigbirdpegasus_model.base_model.decoder.layers)):
    bigbirdpegasus_decoder_layer.self_attn.k_proj.weight = bart_decoder_layer.self_attn.k_proj.weight
    bigbirdpegasus_decoder_layer.self_attn.q_proj.weight = bart_decoder_layer.self_attn.q_proj.weight
    bigbirdpegasus_decoder_layer.self_attn.v_proj.weight = bart_decoder_layer.self_attn.v_proj.weight
    bigbirdpegasus_decoder_layer.self_attn.out_proj.weight = bart_decoder_layer.self_attn.out_proj.weight
    bigbirdpegasus_decoder_layer.self_attn_layer_norm = bart_decoder_layer.self_attn_layer_norm
    bigbirdpegasus_decoder_layer.encoder_attn.k_proj.weight = bart_decoder_layer.encoder_attn.k_proj.weight
    bigbirdpegasus_decoder_layer.encoder_attn.q_proj.weight = bart_decoder_layer.encoder_attn.q_proj.weight
    bigbirdpegasus_decoder_layer.encoder_attn.v_proj.weight = bart_decoder_layer.encoder_attn.v_proj.weight
    bigbirdpegasus_decoder_layer.encoder_attn_layer_norm = bart_decoder_layer.encoder_attn_layer_norm

    bigbirdpegasus_decoder_layer.fc1 = bart_decoder_layer.fc1
    bigbirdpegasus_decoder_layer.fc2 = bart_decoder_layer.fc2
    bigbirdpegasus_decoder_layer.final_layer_norm = bart_decoder_layer.final_layer_norm

bigbirdpegasus_model.lm_head = bart_model.lm_head

logger.info("convert bart-base-chinese to bigbird success")
bigbirdpegasus_model.save_pretrained("./bigbird")
tokenizer.save_pretrained("./bigbird")

3、训练数据

长文本摘要数据集:NLPCC共50000条数据,title长度:最大长度128,最小长度17;content:最大长度 22312,最小长度52。

数据样例:

[
  {
    "title": "知情人透露章子怡怀孕后,父母很高兴。章母已开始悉心照料。据悉,预产期大概是12月底",
    "content": "四海网讯,近日,有媒体报道称:章子怡真怀孕了!报道还援引知情人士消息称,“章子怡怀孕大概四五个月,预产期是年底前后,现在已经不接工作了。”这到底是怎么回事?消息是真是假?针对此消息,23日晚8时30分,华西都市报记者迅速联系上了与章子怡家里关系极好的知情人士,这位人士向华西都市报记者证实说:“子怡这次确实怀孕了。她已经36岁了,也该怀孕了。章子怡怀上汪峰的孩子后,子怡的父母亲十分高兴。子怡的母亲,已开始悉心照料女儿了。子怡的预产期大概是今年12月底。”当晚9时,华西都市报记者为了求证章子怡怀孕消息,又电话联系章子怡的亲哥哥章子男,但电话通了,一直没有人接听。有关章子怡怀孕的新闻自从2013年9月份章子怡和汪峰恋情以来,就被传N遍了!不过,时间跨入2015年,事情却发生着微妙的变化。2015年3月21日,章子怡担任制片人的电影《从天儿降》开机,在开机发布会上几张合影,让网友又燃起了好奇心:“章子怡真的怀孕了吗?”但后据证实,章子怡的“大肚照”只是影片宣传的噱头。过了四个月的7月22日,《太平轮》新一轮宣传,章子怡又被发现状态不佳,不时深呼吸,不自觉想捂住肚子,又觉得不妥。然后在8月的一天,章子怡和朋友吃饭,在酒店门口被风行工作室拍到了,疑似有孕在身!今年7月11日,汪峰本来在上海要举行演唱会,后来因为台风“灿鸿”取消了。而消息人士称,汪峰原来打算在演唱会上当着章子怡的面宣布重大消息,而且章子怡已经赴上海准备参加演唱会了,怎知遇到台风,只好延期,相信9月26日的演唱会应该还会有惊喜大白天下吧。"
  },
  ...
]

4、训练代码

#!/usr/bin/env python
# _*_coding:utf-8_*_
# Author   :    Junhui Yu
# Time     :    2023/2/27 14:55

import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

import logging
import datasets
import numpy as np
import lawrouge
from transformers import (
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BigBirdPegasusForConditionalGeneration,
    BertTokenizer,
    BigBirdConfig
)

from datasets import load_dataset

logger = logging.getLogger("YUNLP")
logging.basicConfig(level=logging.INFO)

dataset = load_dataset('json', data_files="./data/nlpcc_data/nlpcc_data.json")
dataset = dataset.shuffle(seeds=42)

model_path = "./bigbird"

config = BigBirdConfig.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BigBirdPegasusForConditionalGeneration.from_pretrained(model_path, config=config)


def flatten(example):
    return {
        "text": example["content"],
        "summary": example["title"],
    }


dataset = dataset["train"].map(flatten, remove_columns=["title", "content"])  # , remove_columns=["title", "content"]

max_input_length = 2048
max_target_length = 1024


def preprocess_function(examples):
    inputs = [doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


dataset = dataset.shuffle()

train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.1, shuffle=True, seed=42).values()
tokenized_datasets = datasets.DatasetDict({
    "train": train_data_txt,
    "validation": validation_data_txt
}).map(preprocess_function, batched=True)

args = Seq2SeqTrainingArguments(
    output_dir="./bigbird",
    num_train_epochs=5,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-04,
    warmup_steps=1000,
    weight_decay=0.0001,
    label_smoothing_factor=0.15,
    predict_with_generate=True,
    logging_dir="logs",
    logging_strategy="epoch",
    logging_steps=1,
    save_total_limit=2,
    evaluation_strategy="epoch",
    eval_steps=500,
    gradient_accumulation_steps=1,
    generation_max_length=64,
    generation_num_beams=1,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = ["".join(pred.replace(" ", "")) for pred in decoded_preds]
    decoded_labels = ["".join(label.replace(" ", "")) for label in decoded_labels]
    labels_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in labels]

    for i, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
        if pred == "":
            decoded_preds[i] = "decoding error,skipping..."
    rouge = lawrouge.Rouge()
    result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
    result = {'rouge-1': result['rouge-1']['f'], 'rouge-2': result['rouge-2']['f'], 'rouge-l': result['rouge-l']['f']}
    result = {key: value * 100 for key, value in result.items()}
    result["gen_len"] = np.mean(labels_lens)
    return result


trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

train_result = trainer.train()
print(train_result)
trainer.save_model()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

总结

本文介绍了一种用于中文长文本摘要的生成式模型-BigBird大鸟模型,通过实践将开源的中文生成预训练bart-chinese-base转换成可以用于BigBird中文权重并用于训练中文长文本生成式摘要,也通过实践验证了其可行性。

效果预览

在这里插入图片描述

参考文献

[1] https://arxiv.org/abs/2007.14062

[2] https://huggingface.co/fnlp/bart-base-chinese

[3] https://github.com/google-research/bigbird

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/375899.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Day902.Memory存储引擎 -MySQL实战

Memory存储引擎 Hi&#xff0c;我是阿昌&#xff0c;今天学习记录的是关于Memory存储引擎的内容。 两个 group by 语句都用了 order by null&#xff0c;为什么使用内存临时表得到的语句结果里&#xff0c;0 这个值在最后一行&#xff1b; 而使用磁盘临时表得到的结果里&…

ARM的工作模式和37个寄存器

一、ARM的工作模式 ARM一共有7种工作模式 模式含义User非特权模式&#xff0c;大部分任务执行在这种模式FIQ当一个高优先级&#xff08;fast) 中断产生时将会进入这种模式IRQ当一个低优先级&#xff08;normal) 中断产生时将会进入这种模式Supervisor当复位或软中断指令执行时…

巨头混战,抢着“兜底”自动驾驶安全

诚然&#xff0c;中国汽车行业的发展绝对不会拘泥于电动化&#xff0c;必定会在电动化的基础上&#xff0c;迎接下半场的快速智能化。 2021年6月&#xff0c;长城汽车线控底盘全球首次发布。 彼时&#xff0c;长城汽车技术副总裁宋东先宣布&#xff0c;整合了线控转向、线控制…

基于海鸥算法改进的DELM分类-附代码

海鸥算法改进的深度极限学习机DELM的分类 文章目录海鸥算法改进的深度极限学习机DELM的分类1.ELM原理2.深度极限学习机&#xff08;DELM&#xff09;原理3.海鸥算法4.海鸥算法改进DELM5.实验结果6.参考文献7.Matlab代码1.ELM原理 ELM基础原理请参考&#xff1a;https://blog.c…

【数据结构与算法】单链表的增删查改(附源码)

这么可爱的猫猫不值得点个赞吗&#x1f63d;&#x1f63b; 目录 一.链表的概念和结构 二.单链表的逻辑结构和物理结构 1.逻辑结构 2.物理结构 三.结构体的定义 四.增加 1.尾插 SListpushback 2.头插 SListpushfront 五.删除 1.尾删 SListpopback 2.头删 SListpo…

浅谈音视频开发,如何更好的去学习?

Android音视频跟普通的应用层开发相比&#xff0c;的确更花费精力。期间为了学习音视频的录制&#xff0c;编码&#xff0c;处理也看过大大小小的几十个项目。总体感觉就是知识比较零散&#xff0c;对刚入门的朋友比较不友好。所以才萌生了整理一个Android音视频学习路线的想法…

Qss自定义属性

QSS自定义属性 更多精彩内容&#x1f449;个人内容分类汇总 &#x1f448;&#x1f449;QSS样式学习 &#x1f448;文章目录QSS自定义属性[toc]前言一、实现效果二、使用方式1.QSS设置Q_PROPERTY属性样式2.QSS设置动态属性样式3.qproperty-<属性名称>语法14.qproperty-&…

如何在报表生成工具 Stimulsoft 中自定义报告查看器?

Stimulsoft Reports 是一款报告编写器&#xff0c;主要用于在桌面和Web上从头开始创建任何复杂的报告。可以在大多数平台上轻松实现部署&#xff0c;如ASP.NET, WinForms, .NET Core, JavaScript, WPF, Angular, Blazor, PHP, Java等&#xff0c;在你的应用程序中嵌入报告设计器…

华为OD机试模拟题 用 C++ 实现 - 猜字谜(2023.Q1)

最近更新的博客 【华为OD机试模拟题】用 C++ 实现 - 最多获得的短信条数(2023.Q1)) 文章目录 最近更新的博客使用说明猜字谜题目输入输出描述备注示例一输入输出示例二输入输出思路Code使用说明 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,

IDEA配置 External Tool

有一些文件无法使用IDEA自带的工具打开&#xff0c;这时候就需要借助电脑上安装的第三方软件协助打开。比如使用电脑上安装的Notepad打开项目中的*.ppm文件。 一、配置External Tool 参数说明&#xff1a; 名称(Name)&#xff1a;将在IntelliJ IDEA界面&#xff08;“ 工具”菜…

什么是api接口?(基本介绍)

API:应用程序接口(API:Application Program Interface) 应用程序接口是一组定义、程序及协议的集合&#xff0c;通过 API 接口实现计算机软件之间的相互通信。API 的一个主要功能是提供通用功能集。程序员通过调用 API 函数对应用程序进行开发&#xff0c;可以减轻编程任务。 …

ROS1/2机器人操作系统与时间Time的不解之缘

时间对于机器人操作系统非常重要。所有机器人类的编程中所涉及的变量如果需要在网络中传输都需要这个数据结构的时间戳。宏观上&#xff0c;ROS1、ROS2各版本都有官方支持的时间节点。ROS时钟--支持时间倒计时小工具效果如下&#xff1a;如果要部署机器人操作系统&#xff0c;R…

IP路由原理、静态路由及动态路由区分

1、什么是路由? 路由器 在互联网中进行路由选择所使用的设备&#xff0c;或者说&#xff0c;实现路由的设备&#xff0c;我们称之为路由器。 路由器关键功能&#xff1a;检查数据包的目的地确定信息源发现可能的路由选择最佳路由验证和维护路由信息 什么是路由 路由是指导I…

洗地机什么牌子最好?洗地机品牌排行榜前十名

洗地机是近年来火热的清洁电器&#xff0c;凭借强劲的顽渍污渍去除和高效的地面清洁性能&#xff0c;成为了保持洁净家居环境、为生活做减法的黑科技&#xff0c;颇受生活达人们的追捧和青睐。洗地机的品牌、种类众多&#xff0c;一时间令人眼花缭乱。今天&#xff0c;我想为大…

fastadmin:在新增页面,打开弹窗单选,参数回传

样式&#xff1a;核心代码&#xff1a;一、弹窗的控制器中&#xff1a;// 定义一个公共函数select()&#xff0c;如果这个请求是Ajax&#xff0c;则返回index()函数&#xff0c;否则返回view对象的fetch()函数。 public function select() {if ($this->request->isAjax(…

C 学习笔记 —— 函数指针

函数指针 上面的第二个char (* f) (int);写法就是函数指针的声明&#xff1b; 首先&#xff0c;什么是函数指针&#xff1f;假设有一个指向 int类型变量的指针&#xff0c;该指针储存着这个int类型变量储存在内存位置的地址。 同样&#xff0c;函数也有地址&#xff0c;因为函…

LC-3—MIO、MMIO、Caller Save、Callee Save

LC-3—MIO、MMIOMMIOMIOCaller Save、Callee Save举个例子MMIO MMIO&#xff08;Memory Mapped I/O&#xff09;是一种在系统内存中映射I/O端口的技术&#xff0c;它允许设备直接访问内存中的特定地址&#xff0c;从而实现I/O操作。MMIO技术可以提高I/O操作的效率&#xff0c;…

Mysql 事务版本链(RR 与 RC 的本质区别)

版本链其实就是CURD的历史记录&#xff0c;回滚的本质也是用版本链中的最近一条历史记录覆盖当前记录。版本链针对的是每个表中的记录&#xff0c;只要表中有任意一条记录被修改&#xff0c;版本链中就会新增一条历史记录。 目录 1、为什么需要版本链&#xff1f; 2、有关版本…

《爆肝整理》保姆级系列教程python接口自动化(二十四)--unittest断言——中(详解)

简介 上一篇通过简单的案例给小伙伴们介绍了一下unittest断言&#xff0c;这篇我们将通过结合和围绕实际的工作来进行unittest的断言。这里以获取城市天气预报的接口为例&#xff0c;设计了 2 个用例&#xff0c;一个是查询北京的天气&#xff0c;一个是查询 南京为例&#xf…

接口幂等性的通用解决方案golang版

文章目录简介幂等性如何实现前端应当处理后端基于 token redis 处理简介 接口的幂等性是指&#xff1a; 用户对同一个操作发起多次请求&#xff0c;系统的设计需要保证其多次请求后结果是一致的。常见的如支付场景&#xff0c;连续快速点击两次支付 10 元&#xff0c;不应该扣…