【Instruction Tuning】ChatGLM 微调实战(附源码)

news2025/1/10 15:14:20

在之前的文章中,我们已经讲过了 ChatGPT 的三个主要流程:

  1. SFT:通过 Instruction Tuning 来微调一个监督学习模型
  2. Reward Model:通过排序序列来训练一个打分模型
  3. Reinforcement Learning:通过强化学习来进一步优化模型。

何枝:【RLHF】想训练ChatGPT?得先弄明白Reward Model怎么训(附源码)409 赞同 · 37 评论文章正在上传…重新上传取消

前两篇文章主要对 RM 和 RL 两部分进行了讲解和实验,

但无数的经验向我们证明 —— 拥有一个好的 SFT 的模型对后两步的训练至关重要。

由于在 RL 训练过程中会加入与 SFT 模型的相似度(KL-Divergence)惩罚,

这意味着 RL 模型的上限很大程度上取决于 SFT 模型

为此,我们今天来重点讲一讲如何通过 ChatGLM 来微调一个读懂我们指令的模型。

1. GLM Backbone

Paper Link:  https://arxiv.org/pdf/2103.10360.pdf

在讲微调代码之前,我们先来看看 GLM 的基本架构。

我们都知道,目前主流的两种 Backbone:一类是以 BERT 为首的 Encoder 架构(双向注意力),另一种是以 GPT 为首的 Decoder 架构(单向注意力)。

这两种架构各有各的好处,一个更适合做理解,一个更适合做生成

那么如何将这两种模型做合并,集二者优势于一身,是近年来人们一直在尝试的努力(如:T5、BART等)。

不同于 Encoder-Decoder 的堆叠,GLM 通过一种巧妙的 2D Position Embedding,并通过 Attention MASK 来使得模型在训练时 「既能在部分内容上存在双向注意力」「又能在生成任务中保持单向注意力」。

以下是 GLM 示意图:

GLM Position Embedding 示意图

  1. 首先,从原始句子中 Random Sample 出来一些 Span 用于并 [MASK] 掉(该思想源自 BERT),注意:这里是以 Span 维度进行 MASK 的。
  2. 将原句子分为两组,PART A 是原句子,只不过句子中被挑选出来的 Span 用 [MASK] 符号代替;PART B 是挑选出来的 Span 集合
  3. 将挑选出来的 MASK Span 集合(PART B)拼接在原句子(PART A)后面,注意:这里是先对 PART B 做乱序后,再拼接到句子后面(目的是为了训练 Position Embedding)。
  4. 设计 2D Position:这是我认为比较有趣的设定,位置编码分成了两组。一组用于表征「全局位置」,被挑选出的「MASK SPAN」中的所有 token 的位置索引都等于整个 Span 在原句子中的位置(例如:x5, x6 的索引都是 5);而另一组用来专门表征 MASK Span 内部 token 的相对位置编码(例如:x5, x6 的索引这两个 token 在 Mask Span 中的相对位置)。
  5. 通过设置 Attention MASK,使得 PART A 中的内容是双向可见的,且 PART B 中所有 token 也可以看到 Part A 中的内容;而对于 PART B 中的内容保持单向可见。
  6. 通过对 Part B 中的内容做「生成任务」来进行模型迭代。

以上便是我认为 GLM 中最关键的几个点。

2. Finetune GLM

2.1 数据集准备

我们以信息抽取任务为例,将一个信息抽取数据集(DuIE)添加上 Instruction,以此来教会 ChatGLM 根据我们的指令来完成抽取任务。

我们仿照 Alpaca 数据集,将数据结构设为以下形式:

{
    "instruction": "你现在是一个很厉害的阅读理解器,找到句子中的三元组信息并输出成json给我。",
    "input": "九玄珠是在纵横中文网连载的一部小说,作者是龙马。",
    "target": "```json\n[{\"predicate\": \"连载网站\", \"object_type\": \"网站\", \"subject_type\": \"网络小说\", \"object\": \"纵横中文网\", \"subject\": \"九玄珠\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"龙马\", \"subject\": \"九玄珠\"}]\n```"
}

进一步的,我们将 instruction 和 input 字段合并,得到如下数据:

{
    "context": "Instruction: 你现在是一个很厉害的阅读理解器,找到句子中的三元组信息并输出成json给我:。\nInput: 九玄珠是在纵横中文网连载的一部小说,作者是龙马。\nAnswer: ", 
    "target": "```json\n[{\"predicate\": \"连载网站\", \"object_type\": \"网站\", \"subject_type\": \"网络小说\", \"object\": \"纵横中文网\", \"subject\": \"九玄珠\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"龙马\", \"subject\": \"九玄珠\"}]\n```"
}

其中,

  • Instruction:存放我们希望模型做的任务的指令
  • Input:存放我们喂给模型的任务数据
  • Target:存放模型的输出标签

2.2 Label 构建

将数据集解析为训练 label 的代码如下:

def convert_example(
        examples: dict, 
        tokenizer,
        max_source_seq_len: int,
        max_target_seq_len: int,
    ):
    """
    将样本数据转换为Ptuning模型接收的输入数据。

    Args:
        examples (dict): 训练数据样本, e.g. -> {
                                                "text": [
                                                            '{"context": "年基准利率4.35%。从实际看...", "target": "2017年银行贷款基准利率"}',
                                                            ...
                                                ]
                                            }
        max_source_seq_len (int): prompt最大长度
        max_target_seq_len (int): 答案最大长度

    Returns:
        dict (str: np.array) -> tokenized_output = {
                            'input_ids': [[1525, 10, ...], [758, 2345, ...]], 
                            'labels': [[822, 10, ...], [125, 58...]]
                        }
    """
    tokenized_output = {
        'input_ids': [],
        'labels': []
    }

    max_seq_length = max_source_seq_len + max_target_seq_len

    for example in examples['text']:
        try:
            example = json.loads(example)
            context = example["context"]
            target = example["target"]

            prompts_ids = tokenizer.encode(
                text=context,
                add_special_tokens=False
            )

            target_ids = tokenizer.encode(
                text=target,
                add_special_tokens=False
            )                    

            if len(prompts_ids) >= max_source_seq_len:                                          # source 需要留一个 [gMASK] token 在结尾
                prompts_ids = prompts_ids[:max_source_seq_len - 1]

            if len(target_ids) >= max_target_seq_len - 1:                                       # target 需要留一个 <sop> 在开头和一个 <eop> token 在结尾
                target_ids = target_ids[:max_target_seq_len - 2]

            input_ids = tokenizer.build_inputs_with_special_tokens(prompts_ids, target_ids)     # source_ids + [gMASK] + <sop> + target_ids + <eop>
            context_length = input_ids.index(tokenizer.bos_token_id)                            # bos 在 target 的第一位
            mask_position = context_length - 1                                                  # [gMASK] 在 source 的最后一位
            labels = [-100] * context_length + input_ids[mask_position + 1:]                    # 从 bos 开始到后面所有的 target 到 eos 都为 label

            pad_len = max_seq_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            labels = labels + [-100] * pad_len

            tokenized_output['input_ids'].append(input_ids)
            tokenized_output['labels'].append(labels)
        except:
            print(f'"{example}" -> {traceback.format_exc()}')
            continue

    for k, v in tokenized_output.items():
        tokenized_output[k] = np.array(v)

    return tokenized_output

其中,

  • max_source_seq_len 用于设定模型接收的最大输入长度
  • max_target_seq_len 用于设定模型输出的最大长度

2.3 模型训练

ChatGLM 的微调存在 LoRA Finetune 和 P-Tuning 两种微调方式。

P-Tuning V.S. LoRA

这两种方式都可以使得 ChatGLM-6B 的模型能在 32G 的 V100 上进行微调训练。

通过以下两种参数配置即可选择使用 P-Tuning 还是 LoRA:

# LoRA Finetune
python train.py \
    --train_path data/mixed_train_dataset.jsonl \
    --dev_path data/mixed_dev_dataset.jsonl \
    --use_lora True \
    --lora_rank 8 \
    --batch_size 1 \
    --num_train_epochs 2 \
    --save_freq 1000 \
    --learning_rate 3e-5 \
    --logging_steps 100 \
    --max_source_seq_len 400 \
    --max_target_seq_len 300 \
    --save_dir checkpoints/finetune \
    --img_log_dir "log/fintune_log" \
    --img_log_name "ChatGLM Fine-Tune" \
    --device cuda:0


# P-Tuning
python train.py \
    --train_path data/mixed_train_dataset.jsonl \
    --dev_path data/mixed_dev_dataset.jsonl \
    --use_ptuning True \
    --pre_seq_len 128 \
    --batch_size 1 \
    --num_train_epochs 2 \
    --save_freq 200 \
    --learning_rate 2e-4 \
    --logging_steps 100 \
    --max_source_seq_len 400 \
    --max_target_seq_len 300 \
    --save_dir checkpoints/ptuning \
    --img_log_dir "log/fintune_log" \
    --img_log_name "ChatGLM P-Tuning" \
    --device cuda:0

其中,pre_seq_len 是指在每个层前面添加多少个可学习的前缀 token,该值设置的越大显存占用也会越大。

在我们的实验下,两种方式的效果差异不大:

P-Tuning v.s. LoRA Finetune

模型最终的训练结果如下:

模型训练结果

好啦,以上就是 ChatGLM 的全部内容,感谢观看~

完整源码在这里:

ChatGLM Finetune Code​github.com/HarderThenHarder/transformers_tasks/blo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/644445.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL新手入门系列一】:手把手教你入门MySQL

如果您是一位刚刚开始学习MySQL的新手&#xff0c;本文将为您提供一些实用的入门知识和技巧&#xff0c;帮助您快速上手。 本篇文章将以windows为例&#xff0c;介绍MySQL的基础知识&#xff0c;以及如何安装、卸载、配置和使用它。 导读 一、概览1.1 MySQL是什么1.2 为什么要学…

一文掌握linux网络相关命令

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和…

阿里企业邮箱登录入口

阿里企业邮箱登录入口&#xff1a;https://qiye.aliyun.com/ 阿里企业邮箱可以使用邮箱账号登录&#xff0c;也可以使用钉钉账号登录&#xff0c;打开登录入口&#xff0c;如下图&#xff1a; 阿里企业邮箱登录入口 企业邮箱购买页面&#xff1a;aliyunbaike.com/go/mail免费企…

归并排序和快速排序(C++)

归并排序是一种经典的排序算法&#xff0c;也被称为“归并算法”。它的基本思想是将待排序数组分成若干个子数组&#xff0c;每个子数组都是有序的&#xff0c;然后将这些子数组合并成一个大的有序数组。 具体实现过程如下&#xff1a; 将待排序数组不断划分为左右两个子数组&…

IMX6ULL裸机篇之SPI实验-SPI主控代码实现

一. SPI 实验 SPI实验&#xff1a;学习如何使用 I.MX6U 的 SPI 接口来驱动 ICM-20608&#xff0c;读取 ICM-20608 的六轴数据。 本文学习 SPI主控芯片的代码编写。其中&#xff0c;包括SPI工作模式设置&#xff0c;主从模式设置&#xff0c;时钟配置等实现。 二. SPI 主控芯…

光学介质材料——光学膜

手机、平板、智能电视等设备之所以能够发光发亮离不开一个重要的组成材料——光学膜。那光学膜是什么回事呢&#xff1f; 光学膜是指在光学元件或独立基板上&#xff0c;制镀或涂布一层或多层介电质膜或金属膜或这两类膜的组合&#xff0c;以改变光波的传递特性&#xff0c;包…

MySQL索引:让你的数据库查询快到起飞!

&#x1f495;世界上最美好的东西之一&#xff0c;就是你每天都有机会开始全新的一天。&#x1f495; &#x1f43c;作者&#xff1a;不能再留遗憾了&#x1f43c; &#x1f386;专栏&#xff1a;MySQL学习&#x1f386; &#x1f697;本文章主要内容&#xff1a;详细介绍如何查…

SSD、内存和 L1 Cache 相比速度差多少倍

一道面试题&#xff1a;SSD、内存和 L1 Cache 相比速度差多少倍&#xff1f; 其实比起复杂的技术问题&#xff0c;我更喜欢在面试中提问这种像生活常识一样的简单问题。因为我觉得&#xff0c;复杂的问题是由简单的问题组成的&#xff0c;如果你把简单的问题学扎实了&#xff…

自动化运维工具—Ansible

一、Ansible概述 1.1 Ansible是什么 Ansible是一个基于Python开发的配置管理和应用部署工具&#xff0c;现在也在自动化管理领域大放异彩。它融合了众多老牌运维工具的优点&#xff0c;Pubbet和Saltstack能实现的功能&#xff0c;Ansible基本上都可以实现。 Ansible能批量配…

面试问题总结---嵌入式部分和项目部分

1、本栏用来记录社招找工作过程中的内容,包括基础知识学习以及面试问题的记录等,以便于后续个人回顾学习; 暂时只有2023年3月份,第一次社招找工作的过程; 2、个人经历: 研究生期间课题是SLAM在无人机上的应用,有接触SLAM、Linux、ROS、C/C++、DJI OSDK等; 3、参加工作后…

面试问题总结----ROS部分

1、本栏用来记录社招找工作过程中的内容,包括基础知识学习以及面试问题的记录等,以便于后续个人回顾学习; 暂时只有2023年3月份,第一次社招找工作的过程; 2、个人经历: 研究生期间课题是SLAM在无人机上的应用,有接触SLAM、Linux、ROS、C/C++、DJI OSDK等; 3、参加工作后…

Python3.9使用最新版pyinstaller将项目或程序打包成exe或者mac中的可执行文件

1、pyinstaller的说明&#xff1a; pyinstaller 能够在 Windows、Linux、Mac 等操作系统下将 Python 源文件打包&#xff0c;通过对源文件打包&#xff0c; Python 程序可以在没有安装 Python 的环境中运行&#xff0c;也可以作为一个独立文件方便传递和管理。 PyInstaller 支…

NLP-基于bertopic工具的新闻文本分析与挖掘

NLP-基于bertopic工具的新闻文本分析与挖掘 一&#xff0c;前言 最近简单接触了一些NLP的内容&#xff0c;练一下如何结合ChatGPT进行学习。 二&#xff0c;具体过程 &#xff08;1&#xff09;预处理文本&#xff0c;记录处理过程。 在使用Bertopic进行主题建模之前&…

【数据库一】MySQL数据库初体验

MySQL数据库初体验 1.数据库基本概念1.1 数据Data1.2 表1.3 数据库1.4 数据库管理系统1.5 数据库系统 2.数据库的发展3.主流的数据库介绍3.1 SQL Server&#xff08;微软公司产品&#xff09;3.2 Oracle &#xff08;甲骨文公司产品&#xff09;3.3 DB2&#xff08;IBM公司产品…

MySQL-索引详解(五)

♥️作者&#xff1a;小刘在C站 ♥️个人主页&#xff1a; 小刘主页 ♥️努力不一定有回报&#xff0c;但一定会有收获加油&#xff01;一起努力&#xff0c;共赴美好人生&#xff01; ♥️学习两年总结出的运维经验&#xff0c;以及思科模拟器全套网络实验教程。专栏&#xf…

【K8S系列】深入解析k8s网络之—网络故障

序言 你只管努力&#xff0c;其他交给时间&#xff0c;时间会证明一切。 文章标记颜色说明&#xff1a; 黄色&#xff1a;重要标题红色&#xff1a;用来标记结论绿色&#xff1a;用来标记一级论点蓝色&#xff1a;用来标记二级论点 Kubernetes (k8s) 是一个容器编排平台&#x…

FasterTransformer 005 初始化:如何将参数传给模型?

cpp的例子 device_malloc cpp没有用具体数值初始化 float *d_from_tensor NULL;device_malloc(&d_from_tensor, batch_size * seq_len * hidden_dim);https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/sample/cpp/transformer_fp32.cc#L35-L38 直接用的cudaMal…

【电子学会】2023年03月图形化四级 -- 绘制直尺

绘制直尺 编写一段程序&#xff0c;绘制一段7厘米的直尺。 1. 准备工作 &#xff08;1&#xff09;保留小猫角色&#xff0c;隐藏&#xff1b; &#xff08;2&#xff09;白色背景。 2. 功能实现 &#xff08;1&#xff09;点击绿旗&#xff0c;设置笔的颜色为红色&#…

事务和事务的隔离级别

一、事务 &#xff08;一&#xff09;为什么需要事务 事务是数据库管理系统&#xff08;DBMS&#xff09;执行过程中的一个逻辑单位&#xff08;不可再进行分割&#xff09;&#xff0c;由一个有限的数据库操作序列构成&#xff08;多个DML语句&#xff0c;select语句不包含事…

数字图像处理期末复习习题 SCUEC part1

1.在利用LoG算子做边缘检测的时候&#xff0c;作为一种经验法则&#xff0c;当滤波器空间参数为a7时&#xff0c;LoG滤波器空域模板大小应为 答&#xff1a;4343 理由是&#xff1a;n大于等于6a1 2.空间域方法主要分为灰度变换和空间滤波两类&#xff0c;灰度变换在图像的单…