LLM - Baichuan7B Tokenizer 生成训练数据

news2024/11/28 6:40:53

目录

一.引言

二.Tokenizer 原始数据

1.原始数据样例

2.加载并 Token 原始数据

2.1 参数准备

2.2 单条样本处理逻辑

2.3 批量处理逻辑

2.4 主函数与完整代码

三.shell 执行

四.总结


一.引言

前面提到了自己在微调 Baichuan7B Lora 的过程中遇到了一些问题,后面通过调整已经调通。鉴于自己刚刚从推荐算法转 AIGC,所以用笔记的形式记录下用于后面查漏补缺以及对 API 的熟悉。本文主要介绍 LORA 微调时原始数据的处理与编码,即 encode By tokenizer,最终生成可用的 Dataset。

二.Tokenizer 原始数据

1.原始数据样例

{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
{"q": "题目:51/186的答案是什么?", "a": "这是简单的除法运算,51除以186大概为0.274"}
{"q": "鹿妈妈买了24个苹果,她想平均分给她的3只小鹿吃,每只小鹿可以分到几个苹果?", "a":"鹿妈妈买了24个苹果,平均分给3只小鹿吃,那么每只
小鹿可以分到的苹果数就是总苹果数除以小鹿的只数。\n24÷3=8\n每只小鹿可以分到8个苹果。所以,答案是每只小鹿可以分到8个苹果。"}
{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
{"q": "题目:51/186的答案是什么?", "a": "这是简单的除法运算,51除以186大概为0.274"}
{"q": "鹿妈妈买了24个苹果,她想平均分给她的3只小鹿吃,每只小鹿可以分到几个苹果?", "a": "鹿妈妈买了24个苹果,平均分给3只小鹿吃,那么每>只小鹿可以分到的苹果数就是总苹果数除以小鹿的只数。\n24÷3=8\n每只小鹿可以分到8个苹果。所以,答案是每只小鹿可以分到8个苹果。"}
{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
{"q": "题目:51/186的答案是什么?", "a": "这是简单的除法运算,51除以186大概为0.274"}
{"q": "鹿妈妈买了24个苹果,她想平均分给她的3只小鹿吃,每只小鹿可以分到几个苹果?", "a": "鹿妈妈买了24个苹果,平均分给3只小鹿吃,那么每>只小鹿可以分到的苹果数就是总苹果数除以小鹿的只数。\n24÷3=8\n每只小鹿可以分到8个苹果。所以,答案是每只小鹿可以分到8个苹果。"}

这里 q 可以理解为 question,a 可以理解为 answer,上面将基础的训练数据重复了几次生成原始的训练文件 simple.json。

2.加载并 Token 原始数据

2.1 参数准备

import argparse
import json
from tqdm import tqdm
import datasets
import transformers

# 1.参数准备
parser = argparse.ArgumentParser()
parser.add_argument("--model_checkpoint", type=str, help="checkpoint, like `THUDM/chatglm-6b`") # 必填
parser.add_argument("--input_file", type=str, help="Instruction 数据文件地址,文件中每一行都是json格式,包含一个输出和一个输出") # 必填
parser.add_argument("--prompt_key", type=str, default=f"prompt", help="你的jsonl文件里,Instruction 的输入字段是什么") # 选填
parser.add_argument("--target_key", type=str, default=f"target", help="你的jsonl文件里,Instruction 的输出字段是什么") # 必填
parser.add_argument("--save_name", type=str, default=f"temp", help="经过tokenize之后的数据集的存放位置") # 选填
parser.add_argument("--max_seq_length", type=int, default=2040) # 选填
parser.add_argument("--skip_overlength", type=bool, default=False) # 选填
args = parser.parse_args()

参数采用 argparse 类进行初始化:

- model_checkpoint : 预训练模型地址,这里我们提前把 Baichuan7B 或者 ChatGLM 下载好即可

- input_file : 原始训练数据,训练数据格式为 json,可以参考上面的数据示例

- prompt_key : 训练数据在 json 里 prompt 提示对应的 key,上例为 q

- target_key : 训练数据在 json 里 target 提示对应的 key,上例为 a

- save_name : 保存地址,数据最终会议 arrow 的数据将 dataset 保存

- max_seq_length : 最长阶段序列长度

- skip_overlength : 是否忽略超长的文本,True 时忽略,False 时采取截断

2.2 单条样本处理逻辑

以 json 里一条样本为例:

{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
def preprocess(tokenizer, config, example, max_seq_length, prompt_key, target_key):
    prompt = example[prompt_key]
    target = example[target_key]
    prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
    target_ids = tokenizer.encode(target, max_length=max_seq_length, truncation=True, add_special_tokens=False)
    # 最终还是将 instruction 的输入输出都拼在一起,使用经典的 causal-LM 的 next word prediction 方式来训练
    input_ids = prompt_ids + target_ids + [config.eos_token_id] # EOS 用于标识句子结束
    return {"input_ids": input_ids, "seq_len": len(prompt_ids)}

根据配置的 prompt_key 和 target_key 获取 json 里对应的 prompt 与 target 内容,本例下 prompt_key = "q",target_key = "a",通过加载预训练模型获取对应的 Tokenizer 对 q、a 的文本进行 encode 编码。

Q: 请计算:39 * 0 = 什么?
A: 这是简单的乘法运算,39乘以0得到的是0
TokenQ: [31010, 6184, 77, 55, 61, 1734, 31106, 52, 1147, 31106, 1534, 75]
TokenA: [31106, 3908, 14313, 32329, 31257, 31481, 31742, 72, 55, 61, 32329, 31187, 52, 5442, 2585, 52]

为什么把 QA 前后连接拼到一起,上面的注释也给出了原因,该样本用于使用 causal-LM 模型进行 next word 的预测即续写功能的训练。通过将 Q 放在 A 前面训练,学习 QA 的前后文字逻辑。未来模型训练完毕后,我们给出 Q,模型机会根据之前的训练续写出 A 的相关内容。

2.3 批量处理逻辑

上面 preprocess 的逻辑主要在 read_json 里调用,该方法主要用于加载预训练模型生成 Tokenizer 与 config,

def read_json(path, max_seq_length, prompt_key,target_key,skip_overlength=False):
    # 基于预训练模型加载获取 tokenizer 和 config
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_checkpoint, trust_remote_code=True)
    config = transformers.AutoConfig.from_pretrained(
        model_checkpoint, trust_remote_code=True, device_map='auto')
    with open(path, "r") as f:
        for line in tqdm(f.readlines()):
            example = json.loads(line)
            feature = preprocess(tokenizer, config, example, max_seq_length,prompt_key,target_key)
            if skip_overlength and len(feature["input_ids"]) > max_seq_length:
                continue
            # 截取最大长度
            feature["input_ids"] = feature["input_ids"][:max_seq_length]
            yield feature

json.loads 加载一条样本随后调用 preprocess 生成训练 json,这里会根据 skip_overlength 参数决定是否忽略超长样本,最后返回 feature json。这里 tqdm 用于为迭代器 iterator 生成一个可视化的进度条,是一个辅助类。

本着一个参数都不放过的原则,博主查阅了模型加载中用到的两个参数含义:

- trust_remote_code

该参数指示系统在执行远程或外部代码时如何处理安全性和信任性。如果 "trust_remote_code" 设置为 True,则系统将信任并执行远程或外部提供的代码,而不进行严格的安全检查或验证。反之则系统会采取更谨慎的做法,并对远程或外部提供的代码进行安全性检查和验证,以确保其不会造成潜在的风险或恶意操作。这是一种常见的安全策略,用于防止恶意代码或攻击者利用远程执行漏洞来入侵系统。由于我们一般加载的都是官方认可的预训练模型,例如 Baichuan7B、ChatGLM 等等,所以一般看到的代码里都是 True。

- device_map

该参数用于指定设备映射或设备配置的相关信息。可以使用 map 将任务分配给特定的硬件设备或资源,当然也可以像上面一样使用 auto。

2.4 主函数与完整代码

# 输入文件统一放在 data 文件夹下
# 输出文件统一放在 data/tokenized_data 文件夹下
input_file_path = f'data/{args.input_file}'
save_path = f"data/tokenized_data/{args.save_name}"
dataset = datasets.Dataset.from_generator(
    lambda: read_jsonl(input_file_path, args.max_seq_length, args.prompt_key,args.target_key,args.skip_overlength)
)

dataset.save_to_disk(save_path)

这里默认原始训练文件 json 存放在 data 文件夹下,经过 tokenizer 的样本放在 data/tokenized_data 目录下,当然也可以根据自己习惯调整,这个位置影响不大。根据路径调用 datasets.Dataset 的 API 进行 DataSet 的生成与存储。

- 完整代码

import argparse
import json
from tqdm import tqdm
import datasets
import transformers

# 1.参数准备
parser = argparse.ArgumentParser()
parser.add_argument("--model_checkpoint", type=str, help="checkpoint, like `THUDM/chatglm-6b`") # 必填
parser.add_argument("--input_file", type=str, help="Instruction 数据文件地址,文件中每一行都是json格式,包含一个输出和一个输出") # 必填
parser.add_argument("--prompt_key", type=str, default=f"prompt", help="你的jsonl文件里,Instruction 的输入字段是什么") # 选填
parser.add_argument("--target_key", type=str, default=f"target", help="你的jsonl文件里,Instruction 的输出字段是什么") # 必填
parser.add_argument("--save_name", type=str, default=f"temp", help="经过tokenize之后的数据集的存放位置") # 选填
parser.add_argument("--max_seq_length", type=int, default=2040) # 选填
parser.add_argument("--skip_overlength", type=bool, default=False) # 选填
args = parser.parse_args()
model_checkpoint = args.model_checkpoint

#. 2.处理逻辑
def preprocess(tokenizer, config, example, max_seq_length, prompt_key, target_key):
    prompt = example[prompt_key]
    target = example[target_key]
    prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
    target_ids = tokenizer.encode(target, max_length=max_seq_length, truncation=True, add_special_tokens=False)
    # 最终还是将 instruction 的输入输出都拼在一起,使用经典的 causal-LM 的 next word prediction 方式来训练
    input_ids = prompt_ids + target_ids + [config.eos_token_id] # EOS 用于标识句子结束
    return {"input_ids": input_ids, "seq_len": len(prompt_ids)}

# 3.读取训练 JSON
def read_jsonl(path, max_seq_length, prompt_key,target_key,skip_overlength=False):
    # 基于预训练模型加载获取 tokenizer 和 config
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_checkpoint, trust_remote_code=True)
    config = transformers.AutoConfig.from_pretrained(
        model_checkpoint, trust_remote_code=True, device_map='auto')
    with open(path, "r") as f:
        for line in tqdm(f.readlines()):
            example = json.loads(line)
            feature = preprocess(tokenizer, config, example, max_seq_length,prompt_key,target_key)
            if skip_overlength and len(feature["input_ids"]) > max_seq_length:
                continue
            # 截取最大长度
            feature["input_ids"] = feature["input_ids"][:max_seq_length]
            yield feature


# 输入文件统一放在 data 文件夹下
# 输出文件统一放在 data/tokenized_data 文件夹下
input_file_path = f'data/{args.input_file}'
save_path = f"data/tokenized_data/{args.save_name}"
dataset = datasets.Dataset.from_generator(
    lambda: read_jsonl(input_file_path, args.max_seq_length, args.prompt_key,args.target_key,args.skip_overlength)
)

dataset.save_to_disk(save_path)


三.shell 执行

simple.json 为我们的测试样例,tokenizer_data 为存储 token 后 DataSet 的地址。下面看下 tokenizer.sh 的 shell 脚本:

baichuan="/model/baichuan-7B"

input=simple.json

CUDA_VISIBLE_DEVICES=0 python tokenize_dataset_rows.py \
    --model_checkpoint $baichuan \
    --input_file $input \
    --prompt_key q \
    --target_key a \
    --save_name simple_token_by_baichuan-7B \
    --max_seq_length 2000 \
    --skip_overlength False

执行上述脚本即可得到 tokenizer 后的数据:

当使用 dataset.save_to_dist 方法保存数据集合时会生成三个文件: 

dataset.arrow: 这是主要的数据文件,其中包含数据集的实际内容。它以 Apache Arrow 格式存储,这种格式旨在高效地存储和处理大规模数据集。该文件可能包含数据样本、标签、特征、元数据等。

dataset.info.json: 这个 JSON 文件包含与数据集相关的元信息。它提供了关于数据集结构、列名称、数据类型、特征信息、统计摘要等详细信息。通过读取此文件,可以获得数据集的描述性信息,以便更好地理解数据的组织和特征。

dataset.state.json: 这个 JSON 文件包含数据集的状态信息,例如上次更新的时间戳、版本号、数据集大小等。它记录了数据集的状态和元数据,以便在后续操作中能够恢复到相同的点,并确保数据集的一致性。

四.总结

基于 10 条左右的样本基于 Baichuan7B 微调后,我们测试了原始模型与 Lora 后的效果:

可以看到原始模型存在续写混乱的问题,前面再说查字典,后面又在介绍字典,而经过 Lora 微调后,模型已经能够回答该问题,但是如果换个问法可能就又答不上来了,所以 prompt 工程和样本工程在 AIGC 中是重要组成部分。后面我们将基于 tokenized_data 介绍如何进行 Lora 模型训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/742012.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Cesium Token申请

一、什么是Cesium ion? Cesium ion是一个提供瓦片图和3D地理空间数据的平台,支持把数据添加到用户自己的应用程序中。 二、为什么需要access token? 使用Cesium ion需要申请access token,当用户将数据添加到自己的账户后,便可以…

Android-jar包方式连接本地sqlite并操作返回数据

背景: 数据库的创建及字段都是后端人员维护,Android端只是映射相关数据库到本地来操作。为了统一管理操作方式方法,所以提出,后端打jar包的方式封装对Android端数据库sqllite的连接、操作。 说明: 因为之前是后端打jar包,JDBC连接的驱动及方法有差异,导致连接Android…

PHP在线拨打电话的代码

这段代码包括一个HTML表单,用于收集用户的姓名,电子邮件和消息。当用户提交表单时,邮件将发送到指定的电子邮件地址,并显示一条消息,指示我们将在不久的将来拨打电话回复。请注意,在上面的代码中,电话号码硬编码为 $phone_number 变量,您需要将其更改为您想要的电话号码…

jmeter使用正则表达式匹配多个中的响应结果

目录 一、背景: 二、例如: 三、接口响应的所有结果: 四、正则表达式的写法: 五、调试的时候添加一个Debug PostProcessor 调试器 六、在Debug PostProcessor中可以查看到获取的参数结果 七、引用方式:${testValue…

【js前端去空格】

javascript字符串去空格 js去除字符串空格的方法 说到去除空格,首先都会想到的就是trim()方法,但是trim()只能去除字符串前后的空格,无法去除字符串中间的空格。 下面总结一下js去除字符串空格的几种方法: 1、trim() trim()是…

你应该知道的C语言干货(4)(strncpy,strncmp,strncat,strstr,strtok)

我们知道包含string.h头文件后,就可以使用strncpy,strncmp,strncat,strstr,strtok这些库函数,接下来让我们了解他们。 目录 #strncpy #strncmp #strncat #strstr #strtok #下期预告 #strncpy 该库函数作用和strcpy很相似,不同点在于 发现了吗…

《面试1v1》Redis分片集群

🍅 作者简介:王哥,CSDN2022博客总榜Top100🏆、博客专家💪 🍅 技术交流:定期更新Java硬核干货,不定期送书活动 🍅 王哥多年工作总结:Java学习路线总结&#xf…

Redis Brpop 命令

目录 一、作用二、demo演示 一、作用 Redis Brpop 命令拥有移出并获取list右边的最后一个元素, 如果列表没有元素会阻塞列表直到等待超时或发现可弹出元素为止。 二、demo演示 向 list1 中插入三个元素 a、b、c lpush list1 a b c查看list1中的元素 lrange lis…

操作系统13:中断处理程序和设备驱动程序

目录 1、中断处理程序 (1)中断和陷入 (2)中断处理程序的处理过程 2、设驱动程序 (1)设备驱动程序的功能 (2)设备驱动程序的处理过程 (4)对 I/O 设备的…

Mac 和 Win,到底用哪个系统学编程?

今天来聊一个老生常谈的问题,学编程时到底选择什么操作系统?Mac、Windows,还是别的什么。。 作为一个每种操作系统都用过很多年的程序员,我会结合我自己的经历来给大家一些参考和建议。 接下来先分别聊聊每种操作系统的优点和不…

React懒加载/动态加载lazy简单实例

两种页面嵌套的方式,一种是父子组件,一种是懒加载 1、父子组件(可略,只用来做例子对比) 原本需要用父子组件来实现页面嵌套,如果嵌套的组件不多,可以这样实现 父页面 import React,{Componen…

Android代码解读之渲染机制揭秘

问题 1.vsync如何协调应用和SurfaceFlinger配合来完成UI渲染、显示,App接收vsync后要做哪些工作? 2.requestLayout和invalidate区别? 3.performTraversals到底是干什么了? 4.surfaceflinger怎么分发vsync信号的? …

【Java】继承背后那些事---深扒继承基本原理|类加载、子类对象创建、方法调用、变量访问

博主简介:努力学习的预备程序媛一枚~博主主页: 是瑶瑶子啦所属专栏: Java岛冒险记【从小白到大佬之路】 学习了继承、多态 本节,将通过一个简单的例子,从概念上介绍原理(实际实现的细节与此有所差别)&#…

HuggingGPT解析:使用 ChatGPT及HuggingFace上的族系解决AI问题

HuggingGPT解析:使用 ChatGPT及HuggingFace上的族系解决AI问题 HuggingGPT是一个利用大型语言模型(LLMs)来解决复杂AI任务的框架。其基本理念是,考虑到LLMs(例如ChatGPT)在语言理解、生成、交互和推理方面展现出了卓越的能力&…

一个优质软件测试工程师简历的范文(答应我一定要收藏起来)

很多刚转行软件测试的小伙伴是不是不知道怎么写好一份优质的软件测试工程师的简历。今天呢,就给大家分享一下一个优质软件测试工程师简历的范文。记得收藏起来哦。 下面的案例:2-3年的软件测试工程的简历 姓 名:XXX 学历:本科 …

源码解析Collections.sort ——从一个逃过单测的 bug 说起

源码解析Collections.sort ——从一个逃过单测的 bug 说起 本文从一个小明写的bug 开始,讲bug的发现、排查定位,并由此展开对涉及的算法进行图解分析和源码分析。 事情挺曲折的,因为小明的代码是有单测的,让小明更加笃定自己写的…

第四节 配置SpringBootAdmin日志管理

本来想用一节就写完SpringBootAdmin的,但随着研究的深入发现一节应该是不够的,网上的资料也不会非常系统,官网的例子有些已经好几年没更新了,所以接下来还是系统性的来写下吧 第一节 完成基础配置,暴露所有端点 第二节…

uniapp App强制更新

需要使用DClound插件市场的一个插件挺好用的!app升级、整包更新和热更新组件 支持vue3 支持打开安卓、苹果应用市场,wgt静默更新https://ext.dcloud.net.cn/plugin?id7286 开始贴代码 // /utils/method.js/*** 获取当前app最新版本* param number ver…

【JAVA】这几个JAVA学习网站你绝不能错过(教学课程篇)

个人主页:【😊许思王】 文章目录 前言HOW2J.CNw3cschool菜鸟教程慕课网开课吧黑马程序员B站 前言 JAVA很难学?学不会怎么办?找对学习网站,让你轻松解决困难。 HOW2J.CN HOW2J.CN是我自认为最好的JAVA学习网站&#x…

df -h 查看Used+Avail != Size

问题描述: 在测试过程中发现,该机器的根目录空间 41G 5.7G ! 50G,即 Used Avail ! Size 问题原因: 经过搜索,了解到这种情况可能是Linux系统默认的文件保留块导致的(Linux系统默认保留5%的容量作为应急…