BERT训练之数据集处理(代码实现)

news2024/9/25 5:23:16

目录

1读取文件数据

 2.生成下一句预测任务的数据

 3.预测下一个句子

 4.生成遮蔽语言模型任务的数据

 5.从词元中得到遮掩的数据

 6.将文本转化为预训练数据集

7.封装函数类

8.调用


import os
import random
import torch
import dltools

1读取文件数据

def _read_wiki(data_dir):
    #拼接文件路径
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    #将输入参数中的两个名字拼接成一个完整的文件路径。
    with open(file_name, 'r', encoding='utf-8') as f:
        #打开文件,逐行读取内容,并将每行作为一个元素添加到列表中。
        lines = f.readlines()
    #大写字母转换为小写字母,获取分句之后的段落列表
    paragraphs = [line.strip().lower().split('.') for line in lines if len(line.split('.')) >= 2]
    random.shuffle(paragraphs)  #大陆那段落列表中的元素
    return paragraphs


_read_wiki('./wikitext-2/')  #输出过长,不展示

 2.生成下一句预测任务的数据

def _get_next_sentence(sentence, next_sentence, paragraphs):
    if random.random() < 0.5: #若50%的概率发生时
        is_next = True
    else:
        #否则,next_sentence就不是下一个句子,是随机抽取的其他句子
        #paragraphs是三重列表的嵌套
        #从所有列表中随机抽取一个段落,从这个段落中又随机抽取一个句子
        next_sentence = random.choice(random.choice(paragraphs))
        is_next =False
    return sentence, next_sentence, is_next     

 3.预测下一个句子

def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    nsp_data_from_paragraph = []  #创建空列表,存放下一个句子的数据
    for i in range(len(paragraph) - 1):   #len(paragraph) - 1是因为索引是从0开始的,左闭右开,输出段落中的每一个句子的索引
        #调用函数,获取用于预测下一个句子任务的数据
        tokens_a, tokens_b , is_next = _get_next_sentence(paragraph[i], paragraph[i+1], paragraphs)
        #预测输入的两个句子结构是  -->    <cls> tokens_a  <sep> tokens_b <sep>
        # +3表示考虑 1个<cls>  +2个<sep>
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue   #这种情况超出了序列的最大长度,不需要
        #将文本数据分割成词元(tokens)和句子分段(segments)。
        #这个过程通常涉及到一系列的预处理步骤,如去除标点符号、转换为小写、数字处理等,以确保输入数据的标准化和一致性‌
        tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)
        nsp_data_from_paragraph.append((tokens, segments, is_next))  #三个数据以元祖的形式存放到列表中
    return nsp_data_from_paragraph

 4.生成遮蔽语言模型任务的数据

#Mask Language Modle
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab):
    """
    tokens:传入的词元
    candidate_pred_positions:等待预测的词元位置索引编号(若传入句子的序列长度为100,那么它就是0-99)
    num_mlm_preds:预测遮掩的数量
    vocab:整体词汇表
    """
    #为遮蔽语言模型的输入创建新的词元副本, 其中输入可能包含替换的<mask>或随机词元
    mlm_input_tokens = [token for token in tokens]  #复制词元数据,后期的替换不修改原数据
    pred_positions_and_labels = []  #用于存放预测的词元位置和目标标签
    #打乱顺序  等待预测的词元位置索引编号
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:  #遍历
        #判断存放预测词元的个数是否已经超过了需要预测的数量
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break  #若预测数量够了,就不预测了,直接退出当前for循环,  continue是退出当前if判断
        #否则,接着预测
        mask_token = None  #初始化变量:被15%抽中需要被替换的词元   为空
        #80%的概率, 将抽取的15%的词元,替换成<mask>词元
        if random.random() < 0.8:
            msaked_token = '<mask>'
        else:  #否则,将剩下的其中10%的词元保持不变      从剩下的20%中抽取50%来表示
            if random.random() < 0.5:
                mask_token = tokens[mlm_pred_position]
            else:  #将剩下的其中10%的词元,用随机词替换
                msaked_token = random.choice(vocab.idx_to_token)
        #将获取到的msaked_token按索引赋值替换原词元
        mlm_input_tokens[mlm_pred_position] = mask_token
        #mlm_pred_position需要被预测的词元位置索引,  tokens[mlm_pred_position]被遮掩预测的词元的标签(真实值是什么)
        pred_positions_and_labels.append((mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels
        

 5.从词元中得到遮掩的数据

# 
def _get_mlm_data_from_tokens(tokens, vocab):
    candidate_pred_positions = []
    # tokens是一个字符串列表
    for i, token in enumerate(tokens):
        # 在遮蔽语言模型任务中不会预测特殊词元
        if token in ['<cls>', '<sep>']:
            continue
        candidate_pred_positions.append(i)
    # 遮蔽语言模型任务中预测15%的随机词元
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    pred_positions_and_labels = sorted(pred_positions_and_labels,
                                       key=lambda x: x[0])
    pred_positions = [v[0] for v in pred_positions_and_labels]
    mlm_pred_labels = [v[1] for v in pred_positions_and_labels]
    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]

 6.将文本转化为预训练数据集

def _pad_bert_inputs(examples, max_len, vocab):
    #词源需要预测的最大数量
    max_num_mlm_preds = round(max_len * 0.15)
    all_tokens_ids, all_segments, valid_lens = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    nsp_labels = []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments, is_next) in examples:
        #对原有的tokens(每句话有长有短,补充《pad》使长度一致)
        all_tokens_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (max_len - len(token_ids)), dtype=torch.long))
        all_segments.append(torch.tensor(segments + [0] * (max_len - len(segments)), dtype=torch.long))
        #valid_lens不包括<pad>计数
        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))
        all_pred_positions.append(torch.tensor(pred_positions + [0] * (max_num_mlm_preds - len(pred_positions)), dtype=torch.long))
        #填充词元的预测将通过乘以0权重在损失中过滤掉
        all_mlm_weights.append(torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (max_num_mlm_preds - len(pred_positions)), dtype=torch.float32))
        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))
        nsp_labels.append(torch.tensor(is_next, dtype=torch.long))
    return (all_tokens_ids, all_segments, valid_lens, all_pred_positions, all_mlm_weights, all_mlm_labels, nsp_labels)
    

7.封装函数类

class WikiTextDataset(torch.utils.data.Dataset):
    def __init__(self, paragraphs, max_len):
        #输入paragraphs[i]是代表段落的句子字符串列表
        #输出paragraphs[i]是代表段落的句子列表,其中每个句子都是词元列表
        paragraphs = [dltools.tokenize(paragraph, token='word') for paragraph in paragraphs]
        #获取句子的词元列表
        sentences = [sentence for paragraph in paragraphs for sentence in paragraph]
        self.vocab = dltools.Vocab(sentences, min_freq=5, reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
        #获取下一句子预测任务的数据
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(paragraph, paragraphs, self.vocab, max_len))
        #获取遮蔽语言模型任务的数据
        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab) + (segments, is_next)) for tokens, segments, is_next in examples]
        #填充输入
        (self.all_token_ids, self.all_segments, self.valid_lens, self.all_pred_positions, self.all_mlm_weights, self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(examples, max_len, self.vocab)
        
    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.all_pred_positions[idx],
                self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)

8.调用

def load_data_wiki(batch_size, max_len):
    """加载WikiText-2数据集"""
    num_workers = dltools.get_dataloader_workers()  #快速获取或设置最佳的工作线程数
    data_dir = './wikitext-2/'
    paragraphs = _read_wiki(data_dir)
    train_set = WikiTextDataset(paragraphs, max_len)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers)
    
    return train_iter, train_set.vocab
    
batch_size, max_len = 512, 64
train_iter, vocab = load_data_wiki(batch_size, max_len)

for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
     mlm_Y, nsp_y) in train_iter:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)
    break
torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])
len(vocab)

 20228

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2162553.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java框架学习(Spring)(ioc)(01)

简介&#xff1a;以本片记录在尚硅谷学习ssm-spring-ioc时遇到的小知识 详情移步&#xff1a;想参考的朋友建议全部打开相互配合学习&#xff01; 视频&#xff1a; 014-spring-框架概念理解_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1AP411s7D7?p14&vd_sou…

SpringBoot框架在文档管理中的创新应用

第3章 系统分析 3.1 需求分析 在线文档管理系统主要是为了提高工作人员的工作效率和更方便快捷的满足员工&#xff0c;更好存储所有数据信息及快速方便的检索功能&#xff0c;对系统的各个模块是通过许多今天的发达系统做出合理的分析来确定考虑员工的可操作性&#xff0c;遵循…

峟思助力堤防工程安全:构建多功能防洪屏障

堤防工程&#xff0c;作为水利建设中至关重要的防护体系&#xff0c;不仅守护着江河、湖泊及滨海区域的安全&#xff0c;更是确保人民生命财产安全的坚固防线。在现代社会&#xff0c;随着技术的进步与安全意识的提升&#xff0c;堤防工程不仅限于传统的防洪功能&#xff0c;更…

SpringBoot和JPA初探

目录 SpringBoot和JPA初探0.准备条件1.创建JPA项目2.项目3.总结 SpringBoot和JPA初探 我们使用SpringBootJPA做一个简单的API接口演示&#xff0c;通过一个简单的例子让大家对Spring Data JPA有一个整体的认知。 0.准备条件 IntelliJ IDEAjdk 1.8mysql 8.0maven 3.8.x 1.创…

代码随想录算法训练营第三十九天 | 198.打家劫舍 ,213.打家劫舍II,337.打家劫舍III

第三十九天打卡&#xff0c;今天解决打家劫舍系列问题&#xff0c;树形dp比较难。 198.打家劫舍 题目链接 解题过程 dp[i]&#xff1a;考虑下标i&#xff08;包括i&#xff09;以内的房屋&#xff0c;最多可以偷窃的金额为dp[i]。 要么不偷这一间&#xff0c;那就是前面那间…

开源链动 2+1 模式、AI 智能名片与 S2B2C 商城小程序:以问题解决为导向的盈利新模式

摘要&#xff1a;本文探讨了问题解决盈利模式的重要性&#xff0c;并结合开源链动 21 模式、AI 智能名片以及 S2B2C 商城小程序等创新工具&#xff0c;阐述了如何以用户为中心&#xff0c;通过深刻洞察用户需求&#xff0c;解决用户问题&#xff0c;实现盈利增长。强调了在当今…

[利用python进行数据分析01] “来⾃Bitly的USA.gov数据” 分析出各个地区的 windows和非windows用户

2011 年&#xff0c; URL 缩短服务 Bitly 跟美国政府⽹站 USA.gov 合作&#xff0c;提供 了⼀份从⽣成 .gov 或 .mil 短链接的⽤户那⾥收集来的匿名数据。 在 2011 年&#xff0c;除实时数据之外&#xff0c;还可以下载⽂本⽂件形式的每⼩时 快照。 数据集下载&#xff1a;通…

LabVIEW项目编码器选择

在LabVIEW项目中&#xff0c;选择增量式&#xff08;Incremental Encoder&#xff09;和绝对式&#xff08;Absolute Encoder&#xff09;编码器取决于项目的具体需求。增量式编码器和绝对式编码器在工作原理、应用场景、精度和成本等方面存在显著差异。以下从多方面详细阐述两…

通过service访问Pod

假设Pod中的容器可能因为各种原因发生故障而死掉&#xff0c;Deployment等controller会通过动态创建和销毁Pod来保证应用整体的健壮性&#xff0c;换句话说&#xff0c;Pod是脆弱的&#xff0c;但应用是健壮的 每个Pod都有自己的Ip&#xff0c;当controller用新的Pod替代发生故…

SDK(2 note)

复习上一次内容&#xff1a; 把前一次笔记中的代码&#xff0c;简写一下 #include <windows.h> #include<tchar.h> #include <stdio.h> #include <strsafe.h> VOID showerrormassage() {LPVOID lpMsgBuf; FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFF…

TS-AI:一种用于多模态个体化脑区划分的深度学习管道,并结合任务对比合成|文献速递-Transformer架构在医学影像分析中的应用

Title 题目 TS-AI: A deep learning pipeline for multimodal subject-specific parcellation with task contrasts synthesis TS-AI&#xff1a;一种用于多模态个体化脑区划分的深度学习管道&#xff0c;并结合任务对比合成 01 文献速递介绍 人类大脑在结构和功能组织上表…

nfs版本问题导致挂载失败

一、系统环境 环境版本操作系统Linux Mint 22 Wilma内核版本6.8.0-44-genericgcc 版本arm-none-linux-gnueabihf-gcc (GNU Toolchain for the A-profile Architecture 9.2-2019.12 (arm-9.10)) 9.2.1 20191025uboot 版本2020.01开发板Linux版本5.4.31 二、问题描述 内核通过…

拒绝信息泄露!VMD滚动分解 + Informer-BiLSTM并行预测模型

前言 在时间序列预测任务中&#xff0c;像 EMD&#xff08;经验模态分解&#xff09;、CEEMDAN&#xff08;完全集合经验模态分解&#xff09;、VMD&#xff08;变分模态分解&#xff09; 等分解算法的使用有可能引入信息泄露&#xff0c;具体情况取决于这些方法的应用方式。信…

通过WebTopo在ARMxy边缘计算网关上实现系统集成

随着工业互联网技术的发展&#xff0c;边缘计算成为了连接物理世界与数字世界的桥梁&#xff0c;其重要性日益凸显。边缘计算网关作为数据采集、处理与传输的核心设备&#xff0c;在智能制造、智慧城市等领域发挥着关键作用。 1. BL340系列概述 BL340系列是基于全志科技T507-…

yolov8/9关键点检测模型检测俯卧撑动作并计数【源码免费+数据集+python环境+GUI系统】

yolov89模型检测俯卧撑动作并计数【源码免费数据集python环境GUI系统】 yolov8/9关键点检测模型检测俯卧撑动作并计数【源码免费数据集python环境GUI系统】 YOLO算法原理 YOLO&#xff08;You Only Look Once&#xff09;关键点检测的算法原理主要基于YOLO目标检测算法进行改进…

R包:VennDiagram韦恩图

加载R包 library(VennDiagram)数据 # Prepare character vectors v1 <- c("DKK1", "NPC1", "NAPG", "ERG", "VHL", "BTD", "MALL", "HAUS1") v2 <- c("SMAD4", "DKK1…

VMware虚拟网络的连接模式探究与实践

VMware安装完成虚拟机后&#xff0c;大多要进行网络配置&#xff0c;实现网络的互联互联&#xff0c;初学者往往感觉与一台实体主机的网络配置不同&#xff0c;局域网中一台实体主机一个物理网卡&#xff0c;配置一个IP地址&#xff1b;或直接通过WAN上网&#xff0c;比较直观&…

基于python的django微博内容网络分析系统,实现文本划分词结构

本项目旨在开发一个基于Python的Django框架的微博内容网络分析系统&#xff0c;聚焦于微博文本的分词处理、名词提取和主成分分析。该系统通过数据收集与预处理、分词及结构化文本分析&#xff0c;为舆情监测、话题分析和用户行为研究提供了一体化的解决方案。 主要功能包括&a…

数据分析学习之学习路线

前言 我们之前通过cda认证了解到数据分析行业&#xff0c;但是获取到证书&#xff0c;并不代表着&#xff0c;我们已经拥有的数据分析的能力&#xff0c;所以通过系统的学习数据分析需要掌握的能力&#xff0c;并学习大佬们的分析经验、分析思路&#xff0c;才是成为数据分析师…

TDengine 学习与使用经验分享:业务落地实践与架构升级探索

前言 随着物联网、工业互联网等行业的快速发展&#xff0c;时间序列数据的管理和处理需求急剧增加。传统的关系型数据库在处理大规模、高频次的时序数据时性能存在瓶颈&#xff0c;而专门针对时序数据设计的数据库系统则显示出其独特优势。TDengine 是其中的佼佼者&#xff0c;…