BERT解析

news2024/11/29 1:39:58

BERT项目
我在BERT添加注释和部分推理代码
main.py

vocab = WordVocab.load_vocab(args.vocab_path)#加载vocab

请添加图片描述
那么这个加载的二进制是什么呢?

1. 加载数据集

继承关系:TorchVocab --> Vocab --> WordVocab

  • TorchVocab

该类主要是定义了一个词典对象,包含如下三个属性:

freqs:是一个collections.Counter对象,能够记录数据集中不同token所出现的次数

stoi:是一个collections.defaultdict对象,将数据集中的token映射到不同的index

itos:是一个列表,保存了从index到token的映射信息

  • Vocab

Vocab继承TorchVocab,该类主要定义了一些特殊token的表示

这里用到了一个装饰器,简单地说:装饰器就是修改其他函数的功能的函数。这里包含了一个序列化的操作

  • WordVocab

WordVocab继承自Vocab,里面包含了两个方法to_seqfrom_seq分别是将token转换成index和将index转换成token表示

2.datasets.py

BERTDataset这个类有——个方法

  1. init():初始化BERTDataset类。
  2. len():返回数据集的大小,即语料库的行数。
  3. getitem():根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。
  4. random_word():对给定句子中的单词进行随机处理,用于生成BERT的输入和标签。
  5. random_sent():随机决定是否交换下一句,用于训练BERT的下一句预测任务。
  6. get_corpus_line():根据索引获取语料库中的句子对。
  7. get_random_line():获取语料库中的一个随机句子。
  • init()
    def __init__(
        self,
        corpus_path,
        vocab,
        seq_len,
        encoding="utf-8",
        corpus_lines=None,
        on_memory=False,
    ):
        """
        初始化BERTDataset类。     
        如果on_memory为True,则将语料库加载到内存中。否则,计算语料库的行数。
        参数:
            corpus_path (str): 语料库文件的路径。
            vocab (Vocab): 词汇表对象,用于将单词映射到索引。
            seq_len (int): 序列长度,BERT输入序列的最大长度。
            encoding (str, optional): 文件编码,默认为'utf-8'。
            corpus_lines (int, optional): 语料库行数,如果提供,则在初始化时不会计算。
            on_memory (bool, optional): 是否将整个语料库加载到内存中,默认为True。
        """
        self.vocab = vocab
        self.seq_len = seq_len

        self.on_memory = on_memory
        self.corpus_lines = corpus_lines
        self.corpus_path = corpus_path
        self.encoding = encoding

        with open(corpus_path, "r", encoding=encoding) as f:
            if self.corpus_lines is None and not on_memory:
                for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    self.corpus_lines += 1

            if on_memory:
                self.lines = [
                    line[:-1].split("\t")
                    for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)
                ]
                print('第一行的内容',self.lines[0])
                print("第一行的数量:",len(self.lines[0]))
                self.corpus_lines = len(self.lines)
        """
            第一行的内容 ['Cuba to Get Rid of Dollars After a Decade', " HAVANA (Reuters) Cubans rushed to change dollars into  local pesos on Tuesday as President Fidel Castro's communist  government prepared to pull the U.S. currency from circulation  more than a decade after it was legalized here."]
            第一行的数量: 2
        """
        if not on_memory:
            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

            for _ in range(
                random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)
            ):
                self.random_file.__next__()

corpus.small中前三行的内容:

请添加图片描述

所以__init__方法就是将句子分为上下两句,存储在列表中

  • len()
    def __len__(self):
        """
        返回数据集的大小,即语料库的行数。
        
        返回:
            int: 语料库的行数。
        """
        return self.corpus_lines

介绍getitem()之前,我们要先看**random_sent()random_word()**这两个函数

    def __getitem__(self, item):
        """
        根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。
        
        参数:
            item (int): 数据项的索引。
            
        返回:
            dict: 包含BERT输入序列、标签、段标签和下一句预测标签的字典。
        """
        t1, t2, is_next_label = self.random_sent(item)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)
  • random_sent()

在这个函数里面我们先看get_corpus_line()函数和get_random_line()

  • get_corpus_line()

t1和t2是表示上半个句子和下半个句子,用于训练BERT的下一句预测任务。

    def get_corpus_line(self, item):
        """
        根据索引获取语料库中的句子对。
        
        参数:
            item (int): 数据项的索引。
            
        返回:
            tuple: 一个句子对。
        """
        if self.on_memory:
            #返回一句话中的上半和下半
            return self.lines[item][0], self.lines[item][1]
        else:
            line = self.file.__next__()
            if line is None:
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)
                line = self.file.__next__()

            t1, t2 = line[:-1].split("\t")
            return t1, t2
  • get_random_line()

随机返回一个句子的下半

现在我们来看**random_sent()**函数

    def random_sent(self, index):
        """
        随机决定是否交换下一句,用于训练BERT的下一句预测任务。
        
        参数:
            index (int): 数据项的索引。
            
        返回:
            tuple: 两个句子和一个标签,指示是否是下一句。
        """
        t1, t2 = self.get_corpus_line(index)

        # output_text, label(isNotNext:0, isNext:1)
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0

现在我们来看**random_word()**函数

  • random_word()

85%的概率返回单词和单词在vocab字典中的映射,如果找不到返回unk的index

剩下的概率分别返回mask,随机值,不改变

    def random_word(self, sentence):
        """
        对给定句子中的单词进行随机处理,用于生成BERT的输入和标签。
        
        参数:
            sentence (str): 输入的句子。
            
        返回:
            tuple: 处理后的单词列表和对应的标签列表。
        """
        tokens = sentence.split()
        output_label = []

        for i, token in enumerate(tokens):
            prob = random.random()
            if prob < 0.15:
                prob /= 0.15

                # 80% randomly change token to mask token
                if prob < 0.8:
                    tokens[i] = self.vocab.mask_index

                # 10% randomly change token to random token
                elif prob < 0.9:
                    tokens[i] = random.randrange(len(self.vocab))

                # 10% randomly change token to current token
                else:
                    tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)

                output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))

            else:
                # self.vocab.stoi:单词到index的映射
                # 查询不到返回unk的index
                tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
                output_label.append(0)

        return tokens, output_label

现在回头再来看getitem

  • getitem()
    def __getitem__(self, item):
        """
        根据索引获取数据项,包括BERT输入序列、标签、段标签和下一句预测标签。
        
        参数:
            item (int): 数据项的索引。
            
        返回:
            dict: 包含BERT输入序列、标签、段标签和下一句预测标签的字典。
        """
        t1, t2, is_next_label = self.random_sent(item)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        """
            一个句子的头部加cls,尾部加eos标志
            label前面加pad,后面加pad
        """
        # [CLS] tag = SOS tag, [SEP] tag = EOS tag
        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
        t2 = t2_random + [self.vocab.eos_index]

        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
        t2_label = t2_label + [self.vocab.pad_index]
        # segment_label表示当前是第一句话还是第二句话,position的一部分
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[
            : self.seq_len
        ]
        bert_input = (t1 + t2)[: self.seq_len]
        bert_label = (t1_label + t2_label)[: self.seq_len]

        padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(
            padding
        )
        
        output = {
            "bert_input": bert_input,
            "bert_label": bert_label,
            "segment_label": segment_label,
            "is_next": is_next_label,
        }

        return {key: torch.tensor(value) for key, value in output.items()}

2. 训练代码

模型代码网上有很多,不再赘述

3. 推理代码

补充推理部分的代码

import torch
from bert_pytorch.model.bert import BERT
from bert_pytorch.model import BERTLM
from bert_pytorch.dataset import WordVocab
import numpy as np

def process(sentence: str) -> tuple:
    """输入预处理,将句子转化为对应的id

    :param str sentence: 输入的句子
    :return tuple: [token,label]
    """
    sentence = "hello world"
    tokens = sentence.split()
    output_label = []
    for i, token in enumerate(tokens):
        tokens[i] = vocab.stoi.get(token, vocab.unk_index)
        output_label.append(0)
    return tokens, output_label
def infer(s1:str, s2:str)->tuple:
    input_t1, _ = process("hello")
    input_t2, _ = process("World")
    t1 = [vocab.sos_index] + input_t1 + [vocab.eos_index]
    t2 = [vocab.sos_index] + input_t2 + [vocab.eos_index]
    # segment_label表示当前是第一句话还是第二句话,position的一部分
    segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:seq_len]
    bert_input = (t1 + t2)[:seq_len]
    bert_input=np.array(bert_input)
    bert_input=bert_input[None,:]
    bert_input=torch.from_numpy(bert_input)
    segment_label=np.array(segment_label)
    segment_label=torch.from_numpy(segment_label)
    next_sent_output, mask_lm_output = model(bert_input, segment_label)
    mask_lm_output = mask_lm_output.transpose(1, 2)
    # 待补充

# 2020 256 8 8
# 2020是len(vocab),自己去看vocab.py里面的
seq_len = 20
f = torch.load("./output/bert.model.ep0")
bert = BERT(2020, 256, 8, 8)
model = BERTLM(bert, 2020)
model.load_state_dict(f)
# 加载词汇表,方便将预测的词转化为对应的id
vocab = WordVocab.load_vocab(r".\data\vocab.small")
infer("hello","world")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249433.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

连接共享打印机0X0000011B错误多种解决方法

打印机故障一直是一个热门话题&#xff0c;特别是共享打印机0x0000011b错误特别头疼&#xff0c;有很多网友经常遇到共享打印机0x0000011b错误。0x0000011b有更新补丁导致的、有访问共享打印机服务异常、有访问共享打印机驱动异常等问题导致的&#xff0c;针对共享打印机0x0000…

spring +fastjson 的 rce

前言 众所周知&#xff0c;spring 下是不可以上传 jsp 的木马来 rce 的&#xff0c;一般都是控制加载 class 或者 jar 包来 rce 的&#xff0c;我们的 fastjson 的高版本正好可以完成这些&#xff0c;这里来简单分析一手 环境搭建 <dependency><groupId>org.spr…

jeecgbootvue2重新整理数组数据或者添加合并数组并遍历背景图片或者背景颜色

想要实现处理后端返回数据并处理&#xff0c;添加已有静态数据并遍历快捷菜单背景图 遍历数组并使用代码 需要注意&#xff1a; 1、静态数组的图片url需要的格式为 require(../../assets/b.png) 2、设置遍历背景图的代码必须是: :style"{ background-image: url( item…

15分钟做完一个小程序,腾讯这个工具有点东西

我记得很久之前&#xff0c;我们都在讲什么低代码/无代码平台&#xff0c;这个概念很久了&#xff0c;但是&#xff0c;一直没有很好的落地&#xff0c;整体的效果也不算好。 自从去年 ChatGPT 这类大模型大火以来&#xff0c;各大科技公司也都推出了很多 AI 代码助手&#xff…

jenkins 2.346.1最后一个支持java8的版本搭建

1.jenkins下载 下载地址&#xff1a;Index of /war-stable/2.346.1 2.部署 创建目标文件夹&#xff0c;移动到指定位置 创建一个启动脚本&#xff0c;deploy.sh #!/bin/bash set -eDATE$(date %Y%m%d%H%M) # 基础路径 BASE_PATH/opt/projects/jenkins # 服务名称。同时约定部…

3D建筑模型的 LOD 规范

LOD&#xff08;细节层次&#xff09; 是3D城市建模中用于表示建筑模型精细程度的标准化描述不同的LOD适用于不同的应用场景 LOD是3D建模中重要的分级标准&#xff0c;不同层级适合不同精度和用途的需求。 从LOD0到LOD4&#xff0c;细节逐渐丰富&#xff0c;复杂性和精度也逐…

解锁 Vue 项目中 TSX 配置与应用简单攻略

在 Vue 项目中配置 TSX 写法 在 Vue 项目中使用 TSX 可以为我们带来更灵活、高效的开发体验&#xff0c;特别是在处理复杂组件逻辑和动态渲染时。以下是详细的配置步骤&#xff1a; 一、安装相关依赖 首先&#xff0c;我们需要在命令行中输入以下命令来安装 vitejs/plugin-v…

【UE5 C++课程系列笔记】04——创建可操控的Pawn

根据官方文档创建一个可以控制前后左右移动、旋转视角、缩放视角的Pawn 。 步骤 一、创建Pawn 1. 新建一个C类&#xff0c;继承Pawn类&#xff0c;这里命名为“PawnWithCamera” 2. 在头文件中申明弹簧臂、摄像机和静态网格体组件 3. 在源文件中引入组件所需库 在构造函数…

多目标优化算法——多目标粒子群优化算法(MOPSO)

Handling Multiple Objectives With Particle Swarm Optimization&#xff08;多目标粒子群优化算法&#xff09; 一、摘要&#xff1a; 本文提出了一种将帕累托优势引入粒子群优化算法的方法&#xff0c;使该算法能够处理具有多个目标函数的问题。与目前其他将粒子群算法扩展…

HTML5好看的音乐播放器多种风格(附源码)

文章目录 1.设计来源1.1 音乐播放器风格1效果1.2 音乐播放器风格2效果1.3 音乐播放器风格3效果1.4 音乐播放器风格4效果1.5 音乐播放器风格5效果 2.效果和源码2.1 动态效果2.2 源代码 源码下载万套模板&#xff0c;程序开发&#xff0c;在线开发&#xff0c;在线沟通 作者&…

通用网络安全设备之【防火墙】

概念&#xff1a; 防火墙&#xff08;Firewall&#xff09;&#xff0c;也称防护墙&#xff0c;它是一种位于内部网络与外部网络之间的网络安全防护系统&#xff0c;是一种隔离技术&#xff0c;允许或是限制传输的数据通过。 基于 TCP/IP 协议&#xff0c;主要分为主机型防火…

c++趣味编程玩转物联网:基于树莓派Pico控制有源蜂鸣器

有源蜂鸣器是一种简单高效的声音输出设备&#xff0c;广泛应用于电子报警器、玩具、计时器等领域。在本项目中&#xff0c;我们结合树莓派Pico开发板&#xff0c;通过C代码控制有源蜂鸣器发出“滴滴”声&#xff0c;并解析其中涉及的关键技术点和硬件知识。 一、项目概述 1. 项…

jar包打成exe安装包

打包exe并设置管理员权限 前言打包可执行文件exe准备jre环境运行exe4j并配置 设置执行文件exe管理员权限设置打包工具管理员权限打包exe安装包创建脚本打包 前言 准备安装包&#xff1a; jar包&#xff1a;自己的程序&#xff1b;exe4j&#xff1a;将jar打包可执行的exe&…

NAT:连接私有与公共网络的关键技术(4/10)

一、NAT 的工作原理 NAT 技术的核心功能是将私有 IP 地址转换为公有 IP 地址&#xff0c;使得内部网络中的设备能够与外部互联网通信。其工作原理主要包括私有 IP 地址到公有 IP 地址的转换、端口号映射以及会话表维护这几个步骤。 私有 IP 地址到公有 IP 地址的转换&#xff1…

多模态大型语言模型(MLLM)综述

目录 多模态大语言模型的基础 长短期网络结构(LSTM) 自注意力机制 基于Transformer架构的自然语言处理模型 多模态嵌入概述 多模态嵌入关键步骤 多模态嵌入现状 TF-IDF TF-IDF的概念 TF-IDF的计算公式 TF-IDF的主要思路 TF-IDF的案例 训练和微调多模态大语言模…

学习ASP.NET Core的身份认证(基于Cookie的身份认证3)

用户通过验证后调用HttpContext.SignInAsync函数将用户的身份信息保存在认证Cookie中,以便后续的请求可以验证用户的身份,该函数原型如下所示&#xff0c;其中properties参数的主要属性已在前篇文章中学习&#xff0c;本文学习scheme和principal的意义及用法。 public static …

C++设计模式-模板模式,Template Method

动机&#xff08;Motivation&#xff09; 在软件构建过程中&#xff0c;对于某一项任务&#xff0c;它常常有稳定的整体操作结构&#xff0c;但各个子步骤却有很多改变的需求&#xff0c;或者由于固有的原因&#xff08;比如框架与应用之间的关系&#xff09;而无法和任务的整…

Jenkins流水线 Allure JUnit5 自动化测试

目录 一、Jenkins Allure配置 1.1 安装Allure插件 1.2 安装Allure工具 1.3 配置测试报告路径 1.4 JenkinsFile 二、Jenkins 邮箱配置 2.1 安装Email Extension Plugin插件 2.2 邮箱配置 2.3 JenkinsFile 三、项目pom.xml 配置 3.1 引入allure-junit5依赖 3.2 引入m…

计算机网络 实验七 NAT配置实验

一、实验目的 通过本实验理解网络地址转换的原理和技术&#xff0c;掌握扩展NAT/NAPT设计、配置和测试。 二、实验原理 NAT配置实验的原理主要基于网络地址转换&#xff08;NAT&#xff09;技术&#xff0c;该技术用于将内部私有网络地址转换为外部公有网络地址&#xff0c;从…

shell脚本基础学习_总结篇(完结)

细致观看可以&#xff0c;访问shell脚本学习专栏&#xff0c;对应章节会有配图https://blog.csdn.net/2201_75446043/category_12833287.html?spm1001.2014.3001.5482 导语 一、shell脚本简介 1. 定义&#xff1a; 2. 主要特点&#xff1a; 3. shell脚本的基本结构 4. S…