LLM - 批量加载 dataset 并合并

news2024/12/23 23:51:50

目录

一.引言

二.Dataset 生成

1.数据样式

2.批量加载

◆ 主函数调用

◆ 基础变量定义

◆ 多数据集加载

3.数据集合并

◆ Concat

◆ interleave

◆ stopping_strategy

◆ interleave_probs

三.总结


一.引言

LLM 模型基于 transformer 进行训练,需要先生成 dataset,再将 dataset 根据任务需求生成对应的 input_ids、label_ids 等,本文介绍生成 dataset 的方法,即读取多个文件最终生成一个 dataset,后续介绍不同任务需求下 dataset 的转化。

Tips:

本文数据集与代码主要参考 Github LLaMA-Efficient-Tuning。

二.Dataset 生成

1.数据样式

 alpaca_data_zh_51k.json

◆ alpaca_gpt4_data_zh.json

数据集为 json 文件,其中每条 json 记录包含 3 个 key:

- instruction 可以理解为 prompt

- input 输入,即我们说的 Question

- output 输出,与 Question 对应的 Answer

上面的 3 个 key也可以简化,前面也提到过 LLM - Baichuan7B Tokenizer 生成训练数据,这里只用了 q、a 两个字段。 这里字段是什么其实并不重要,只要最后生成 input_ids 相关数据可以区分开就可以。

2.批量加载

def getBatchDataSet(_base_path, _data_files, _strategy):
    max_samples = 9999

    # support multiple datasets
    all_datasets: List[Union["Dataset", "IterableDataset"]] = []

    for input_path in _data_files:

        data_path = EXT2TYPE.get(input_path.split(".")[-1], None)

        dataset = load_dataset(
            data_path,
            data_files=[os.path.join(_base_path, input_path)],
            split="train",
            cache_dir=None,
            streaming=None,
            use_auth_token=True
        )

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        print(dataset.features)
        all_datasets.append(dataset)

    if len(all_datasets) == 1:
        return all_datasets[0]
    elif _strategy == "concat":
        return concatenate_datasets(all_datasets)
    elif _strategy == "interleave":
        # all_exhausted
        stopping_strategy = "first_exhausted"
        interleave_probs = [0.5, 0.5]
        return interleave_datasets(all_datasets, interleave_probs, stopping_strategy=stopping_strategy)
    else:
        raise ValueError("UnKnown mixing strategy")

下面分步骤拆解下代码:

主函数调用

import os.path
from datasets import load_dataset, concatenate_datasets, interleave_datasets
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union, Tuple
from transformers import GPT2Tokenizer
from itertools import chain
import tiktoken


if __name__ == '__main__':
    # 多文件地址
    base_path = "/Users/LLaMA-Efficient-Tuning-main/data"
    data_files = ['alpaca_data_zh_51k.json', 'alpaca_gpt4_data_zh.json']
    strategy = 'concat'
    train_dataset = getBatchDataSet(base_path, data_files, strategy)

这里给定我们需要遍历的两个 json 文件以及对应的合并策略,策略后面再说。

基础变量定义

EXT2TYPE = {
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    "txt": "text"
}

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

max_samples = 9999

# support multiple datasets
all_datasets: List[Union["Dataset", "IterableDataset"]] = []

EXT2TYPE 为文件格式对应的 map,第二个 tokenizer 我们为了演示直接使用 Transformer 自带的 gpt2,max_samples 定义数据集截断,最后的 all_datasets 用于存储多个数据集。

多数据集加载

    for input_path in _data_files:

        data_path = EXT2TYPE.get(input_path.split(".")[-1], None)

        dataset = load_dataset(
            data_path,
            data_files=[os.path.join(_base_path, input_path)],
            split="train",
            cache_dir=None,
            streaming=None,
            use_auth_token=True
        )

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        print(dataset.features)
        all_datasets.append(dataset)

遍历文件列表的文件与后缀,通过 from datasets import load_dataset 加载声称数据集,max_samples 配合 select 完成数据集的截断,最后将 dataset 添加到 all_datasets 中。这里 dataset.features 类似于 dataframe 的 schema,用于描述每一列的基础信息:

{'instruction': Value(dtype='string', id=None), 
 'input': Value(dtype='string', id=None), 
 'output': Value(dtype='string', id=None)}

下图为两个数据集记载打印的日志,由于之前已经做了 cache,所以直接读取 arrow 文件: 

3.数据集合并

    if len(all_datasets) == 1:
        return all_datasets[0]
    elif _strategy == "concat":
        return concatenate_datasets(all_datasets)
    elif _strategy == "interleave":
        # all_exhausted
        stopping_strategy = "first_exhausted"
        interleave_probs = [0.5, 0.5]
        return interleave_datasets(all_datasets, interleave_probs, stopping_strategy=stopping_strategy)
    else:
        raise ValueError("UnKnown mixing strategy")

由于训练只需要一个 dataset,所以多个文件读取的 dataset 需要合并为一个,上面展示了不同的合并策略,length == 1 的情况就不多说了,除此之外多数据集有两种合并策略:

Concat

cocnat 方法直接顺序拼接多个数据集

dataset-1 => A,B,C
dataset-2 => D,E,F
concat(dataset-1, dataset-2) => A,B,C,D,E,F

 interleave

interleave 方法用于实现数据交错从而防止过拟合。交错数据集是将两个或更多数据集混合在一起形成一个新的数据集。这样做的目的是使模型在训练时不会总是看到相同的数据顺序,从而提高模型的泛化能力。

dataset-1 => A,B,C
dataset-2 => D,E,F
interleave(dataset-1, dataset-2) => A,E,B,C,D,F

 stopping_strategy

stopping_strategy 用于定义数据集合并何时停止,有 first_exhausted 和 all_exhausted 两种交错策略:

- first_exhausted (先耗尽策略)

数据集会按照他被添加到 interleave 方法的顺序进行处理,当一个数据集被遍历完会停止生成数据,该方法适用于你希望遍历完第一个数据集就停止迭代。

- all_exhausted (全部耗尽策略)

数据集会按照他被添加到 interleave 方法的顺序进行处理,当全部数据集被遍历完会停止生成数据,该方法适用于你希望遍历完全部数据集就停止迭代。

这两种策略的主要区别在于何时停止迭代并抛出异常。first_exhausted 策略在遍历完第一个数据集后停止,而 all_exhausted 策略在遍历完所有数据集后停止。选择哪种策略取决于你的具体需求和数据集的特性。

 interleave_probs

在 interleave_datasets 方法中,interleave_probs 是一个可选参数,用于指定每个数据集的交错概率。当使用 interleave_datasets 方法交错多个数据集时,你可以通过 interleave_probs 参数为每个数据集指定一个概率。这个概率表示在生成交错数据集时,每个数据集被选择的概率。

例如,假设你有两个数据集 A 和 B,并且你设置 interleave_probs=[0.5, 0.5]。这意味着在生成交错数据集时,A 和 B 被选择的概率都是 0.5。

如果你设置 interleave_probs=[0.3, 0.7],则 A 被选择的概率是 0.3,而 B 被选择的概率是 0.7。

这个参数允许你根据需要对不同的数据集进行加权,以便在交错数据集时更倾向于选择某些数据集。

三.总结

LLM 大模型我们大部分时间是调用框架,调用现成模型去微调,熟悉一些工具的使用可以更方便我们在调优的时候对不同部分进行修改,本文主要用于加载原始数据生成 dataset,后续我们基于上面得到的 dataset 生成不同任务所需的数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1003734.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

双token实现无感刷新

accessToken:用户获取数据权限refreshToken:用来获取新的accessToken 双 token 验证机制&#xff0c;其中 accessToken 过期时间较短&#xff0c;refreshToken 过期时间较长。当 accessToken 过期后&#xff0c;使用 refreshToken 去请求新的 token。 引入依赖 <!-- …

多版本CUDA安装切换

系统中默认的安装CUDA为12.0&#xff0c;现在需要在个人用户下安装CUDA11.7。 CUDA 下载 CUDA官网下载 安装 Log file not open.Segmentation fault (core dumped)错误 将/tmp/cuda-installer.log删除即可。重新安装&#xff0c;去掉驱动的安装&#xff0c;设置Toolkit的安装…

ST-LINK 下载器的使用

这两天我也是被这个ST-LINK搞的非常头疼&#xff0c;第一个是固件不兼容&#xff08;也就是keil5的魔法棒中显示ST-LINK connection error&#xff09;&#xff0c;第二个就是STM32 ST-LINK Utility的使用显示dll文件损坏 ST-LINK connection error 先来解决这个问题&#xff…

【附安装包】2023最新版Python安装详细教程!一键安装,永久使用

一、python官网 Python官网主要有python的About (简介)、Downloads (下载)、Documentation(文档)、Community (团体)、Success Stories (成功案例)、News (新闻)、Events (事件动态)等栏目。 Python官网地址&#xff1a;https://www.python.org/ 【领取方式见文末】 二、在…

如何选择感测型离子风机

离子风机在生产车间使用越来越广&#xff0c;对产品的要求也越来越高&#xff0c;而感测型离子风机正好满足。 感测型离子风机:内置感测和反馈功能&#xff1b;2.能快速静电中和及消除&#xff1b;高要求控制离子平衡&#xff1b;3.集感测&#xff0c;联网&#xff0c;通讯数据…

2023年数维杯数学建模C题宫内节育器的生产求解全过程文档及程序

2023年数维杯数学建模 C题 宫内节育器的生产 原题再现&#xff1a; 宫内节育器&#xff08;IUD&#xff09;是一种相对安全、有效、经济、可逆、简便&#xff0c;广大妇女易接受的节育器具&#xff0c;目前已成为我国育龄妇女的主要避孕措施。据悉&#xff0c;我国约70%妇女选…

设定excel导出时单元格的格式

一、需求 要求excel导出时&#xff0c;对应列里面的内容格式为日期&#xff0c;数值格式并有精度要求 &#xff0c;如下图&#xff1a; 使用alibaba&#xff0c;easyexcel&#xff0c;默认的导出数据格式为文本&#xff0c;excel显示为常规&#xff0c;使用数据规范注解Number…

玩转 gpgpu sim 01记 —— try it

1. 短介绍 gpgpu-sim 是一个gpu模拟器&#xff0c;可以让cuda/openCL程序运行在一个软件模拟器上&#xff0c;而不需要硬件GPU&#xff1b; 2. 目标 用最简单省事的方式跑通一个gpgpu-sim的仿真 3. gpgpu-sim 一点项目特性 开发比较早&#xff0c;没有持续的维护&#xff0…

vscode搭建Django自带后台管理系统

文章目录 一、django自带的后台管理系统1. 建表2. 后台管理系统2.1 创建账号2.2 运行后台2.3 登录 二、模版渲染1. 直接将数据渲染到页面2. 数据传递给js 三、数据库1. 查看当前数据库2. 创建UserInfo数据表3. Django rest framework配置 四、vue前端搭建1. 在Django项目的根目…

vue 使用canvas 详细教程

Vue.js 中使用 Canvas Vue.js 是一个流行的 JavaScript 框架&#xff0c;用于构建用户界面。它提供了一种简洁的方式来管理和渲染数据&#xff0c;同时也支持与其他库和工具的集成。要在 Vue.js 中使用 Canvas&#xff0c;您可以按照以下步骤进行操作&#xff1a; 在 Vue.js …

Visual Studio 2022安装SVN插件教程

1. 第一步&#xff1a;避免踩坑&#xff0c;超级重要&#xff01;&#xff01;&#xff01;关闭Visual Studio 2022应用程序&#xff1b;&#xff08;不然插件装不上&#xff0c;一直转圈&#xff01;&#xff09; 2.第二步&#xff1a;下载Visual Studio 2022版本对应的SVN插件…

最新IDE流行度最新排名(每月更新)

2023年09月IDE流行度最新排名 顶级IDE排名是通过分析在谷歌上搜索IDE下载页面的频率而创建的 一个IDE被搜索的次数越多&#xff0c;这个IDE就被认为越受欢迎。原始数据来自谷歌Trends 如果您相信集体智慧&#xff0c;Top IDE索引可以帮助您决定在软件开发项目中使用哪个IDE …

Excel显示列号

默认表格打开列以字母显示 设置方法 文件 -> 工具 -> 选项 -> 常规与保存 设置后效果如下图

2023年在线教育行业研究报告

第一章 行业概况 1.1 定义 随着技术的飞速发展和互联网的普及&#xff0c;我们的学习方式正在经历一场革命。在线教育&#xff0c;作为这场变革的核心&#xff0c;已经成为全球教育领域的热门话题。但究竟什么是在线教育行业呢&#xff1f; 在线教育行业是指通过互联网平台提…

【vue2】data中数据赋值失败找不到、data数据不声明的影响

&#x1f609;博主&#xff1a;初映CY的前说(前端领域) ,&#x1f4d2;本文核心&#xff1a;vue2data作用 前言&#xff1a;当你看到这篇文章相比你已经对vue有了一定的了解&#xff0c;对data的有了一个基本的认识&#xff1a;data是存放我们当前页面数据地方。是的&#xff0…

【Python小项目之Tkinter应用】随机点名/抽奖工具大优化:新增选项窗口!可选是否重复点名以及随机点名!可以手动选择文件及文件类型并预览文件!

文章目录 前言一、实现思路窗口逻辑按钮逻辑二、关键代码设置窗口布局实现具体组件实现选择文件与预览文件重中之重:抽取模式三、完整代码总结前言 老规矩,先看效果: 我们为抽奖工具新增了一个设置按钮,点击设置按钮后会出现一个弹窗,弹窗中有各种组件以帮助我们完成初始…

C语言——qsort()函数_学习笔记

本文目录 一、qsort()介绍二、参数详解三、qsort()函数应用举例3.1 排序数组类型的数据3.2 排序结构体类型的数据 四、模拟实现qsort()函数4.1 冒泡排序简单介绍4.2 实现bubble_sort()函数 一、qsort()介绍 qsort()函数是一个库函数&#xff0c;包含在头文件 <stdliib.h>…

Nginx部署前后端分离项目(Linux)

Nginx代理前端页面、后端接口 一、前端打包二、后端打包三、Linux部署Nginx启动、暂停、重启服务器部署文件地址&#xff1a; 一、前端打包 npm run build二、后端打包 通过Maven 使用package打包 三、Linux部署 安装Nginx 安装环境 yum -y install gcc pcre pcre-devel z…

电脑更换硬盘的时候怎么迁移系统?

为什么需要迁移系统&#xff1f; 在一些关于电脑DIY或Windows相关的论坛社区中&#xff0c;有很多人发帖询问怎么迁移系统。那么这个系统迁移&#xff0c;究竟是何含义呢&#xff1f;通俗易懂地解释一下&#xff0c;就是创建一个完整无缺的操作系统复制品&#xff0c;它与系…

硬件总线基础07:PCIe总线基础-事务层(1)

说在开头&#xff1a;关于我的世界&#xff08;4&#xff09; 几年前追过一个综艺&#xff1a;《导演请指教》。不仅仅是因为节目中那一部部小电影的诱惑力&#xff0c;更让人上头的是各方的点评&#xff1a;制片人&#xff0c;学院派&#xff0c;影评人&#xff0c;发行人、大…