【Fine-Tuning】大模型微调理论及方法, PytorchHuggingFace微调实战

news2024/11/20 4:18:23

Fine-Tuning: 大模型微调理论及方法, Pytorch&HuggingFace微调实战

文章目录

  • Fine-Tuning: 大模型微调理论及方法, Pytorch&HuggingFace微调实战
    • 1. 什么是微调
      • (1) 为什么要进行微调
      • (2) 经典简单例子:情感分析
        • 任务
        • 背景
        • 微调
      • (3) 为什么微调work, 理论解释下
    • 2. 详细介绍微调的流程
      • (1) 准备数据, 预处理
      • (2) 微调策略
        • **前三种都差不多的逻辑, 古早**
        • 1. 冻结, 逐层微调
        • 2. 部分参数微调
        • 3. 全参数微调
        • 4. LoRA(低秩适应)
        • 5. Prompt Tuning
        • 6. RLHF(基于人类反馈的强化学习)
        • 7. Prefix Tuning
        • 8. Adapter微调
      • (3) 设置微调超参数
      • (4) 训练, 评估
    • 3. 具体怎么做
      • 常用的微调框架
      • HuggingFace版
      • Pytorch版
      • Pytorch vs HuggingFace
        • 易用性:
        • 灵活性
        • 性能

1. 什么是微调

大模型微调是指在预训练的大型模型基础上,使用特定数据集进行进一步训练,以适应特定任务或领域。

在这里插入图片描述

(1) 为什么要进行微调

  1. 大模型虽然知识丰富(由于其极大批量的预训练任务),但在特定领域可能不够准确。微调能让模型更好地理解特定任务。
  2. 相比从头开始训练一个新模型,微调节省了大量时间和计算资源(站在前人的肩膀上), 只需少量的数据就能有效提升模型在特定领域的性能。

(2) 经典简单例子:情感分析

任务

训练一个情感分析模型

背景

硬件很烂, 不可能从头训练一个情感分析大模型

但已经有预训练的语言模型比如BERT,已经在大量文本上进行过训练(这叫预训练)。

微调

BERT本身没有直接判断情感的能力, 但由于其在大量文本上进行的预训练任务, 其具有很多自然语言领域的 知识(预训练的权重), 通过少量的情感分析数据, 和合适的微调策略, 就能低成本的(数据, 算力)来微调出一个能进行情感分析的BERT

(3) 为什么微调work, 理论解释下

  1. 迁移学习: 深度学习模型有分层学习特征的特点, 底层学习通用特征, 高层学习任务相关特征, 将通用特征的知识迁移到相关的特定领域, 合理
  2. 统计学: 预训练可以看作为参数分布的先验估计, 微调就是在已有先验知识的基础上结合新数据

2. 详细介绍微调的流程

(1) 准备数据, 预处理

首先收集数据, 分成训练验证测试, 老生常谈, 都2024年了就不多说了

预处理: 每种大模型都有特定的输入格式, 要把原始数据转换成预训练大模型认识的数据输入

(2) 微调策略

策略有很多, 也有很多新冒出来的策略, 说一些常见的

前三种都差不多的逻辑, 古早
1. 冻结, 逐层微调

冻结就是权重固定, 不会再反向传播调整了

在这种策略中,模型的一部分参数被冻结,仅对特定层进行微调。逐层解冻的方法允许从顶层开始逐步释放冻结状态,以平衡预训练知识与新任务学习之间的关系.

2. 部分参数微调

和逐层微调本质上类似, 仅选择性地更新模型中的某些权重,通常是顶层或最后几层,而保持底层的大部分权重不变(冻结).

3. 全参数微调

全部参数都会反向传播, 这种方法资源消耗很大, 对数据要求也很高, 而且容易导致灾难性遗忘

灾难性遗忘(Catastrophic Forgetting): 微调模型在学习新任务时,突然或彻底忘记其预训练所学到的知识

在这里插入图片描述

4. LoRA(低秩适应)

LoRA通过在模型的每一层引入可训练的低秩矩阵来进行微调, 自适应的调整部分参数.

5. Prompt Tuning

轻量级的微调方法,不改变模型的主参数(全部冻结),通过为特定任务设计可学习的提示(prompt)来引导模型生成期望的输出。

6. RLHF(基于人类反馈的强化学习)

利用人类的反馈来纠正模型, 生成符合期望的结果

7. Prefix Tuning

在输入的前面前拼一些可训练的参数,使得模型在处理任务时能够更好地理解输入意图

8. Adapter微调

模型层之间插入小型可训练模块,这些模块可以适应新任务,而不影响原始模型的参数

(3) 设置微调超参数

设置/调整 学习率, BatchSize等参数, 让模型能收敛和防止拟合不好, 后面介绍

(4) 训练, 评估

用现成的框架训练, 验证, 测试, 后面介绍

3. 具体怎么做

由于深度学习技术的不断成熟, 各种稳定易用的框架逐渐出现, 让微调过程仅需要少许代码就能实现, 下面看看例子

常用的微调框架

  • Hugging Face Transformer

  • Pytorch

HuggingFace版

在这里插入图片描述

用HuggingFace对GraphCodeBERT进行微调

import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments

# 加载预训练模型和tokenizer
tokenizer = RobertaTokenizer.from_pretrained("microsoft/graphcodebert-base")
model = RobertaForSequenceClassification.from_pretrained("microsoft/graphcodebert-base")

# 准备数据, 数据的预处理一般比较复杂
train_data = [...]  # 训练数据
train_encodings = tokenizer(train_data, truncation=True, padding=True)

# 定义训练参数
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    save_steps=10_000,
    save_total_limit=2,
)

# 创建Trainer实例并开始训练
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_encodings,
)

trainer.train()

实际肯定不止这么简单, 细节比较多, 比如数据的预处理, 和自定义的训练和评估.

由于各种下游任务的多样性, 不同任务的数据/标签差异非常大,这里没办法根据每种任务详细介绍预处理流程, 故在此略过. 我们一般需要自己写很多数据预处理的代码, 构造数据, 使得预训练模型能够接受数据输入.

再或是训练和评估, 由于使用者对模型的需求不同, 训练和评估过程也不一定相同, 自定义的流程往往需要写一些代码, 但是基本的训练和评估流程是封装好的. 代码中给出来了

一些基本的东西在HuggingFace中都有稳定的接口, 比如微调的策略, 参数定义, 基本的训练评估流程, 都是即插即用的

Pytorch版

用pytorch对BERT 使用RLHF策略 在情感分析任务上进行微调

import torch
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, Dataset

# 假设我们有一个简单的数据集
class CustomDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(self.texts[idx], padding='max_length', truncation=True, return_tensors='pt', max_length=128)
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# 1. 监督微调(SFT)
def supervised_fine_tuning(model, dataloader):
    model.train()
    optimizer = AdamW(model.parameters(), lr=5e-5)

    for epoch in range(3):  # 假设训练3个epoch
        for batch in dataloader:
            optimizer.zero_grad()
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f"Loss: {loss.item()}")

# 2. 奖励模型训练
def train_reward_model(model, reward_data):
    # 假设reward_data包含文本和对应的奖励分数
    model.train()
    optimizer = AdamW(model.parameters(), lr=5e-5)

    for epoch in range(3):  # 假设训练3个epoch
        for text, reward in reward_data:
            inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            optimizer.zero_grad()
            
            outputs = model(**inputs)
            reward_loss = nn.MSELoss()(outputs.logits.squeeze(), torch.tensor(reward, dtype=torch.float32))
            reward_loss.backward()
            optimizer.step()
            print(f"Reward Loss: {reward_loss.item()}")

# 3. RLHF训练
def rl_training(actor_model, critic_model, dataloader):
    actor_model.train()
    critic_model.eval()  # 奖励模型在评估模式

    for epoch in range(3):  # 假设训练3个epoch
        for batch in dataloader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']

            # 使用actor模型生成输出
            actor_outputs = actor_model(input_ids=input_ids, attention_mask=attention_mask)
            
            # 使用critic模型评估输出的奖励
            with torch.no_grad():
                critic_outputs = critic_model(input_ids=input_ids, attention_mask=attention_mask)

            # 根据奖励调整actor模型的参数(PPO等算法可在此实现)
            # 此处省略具体的PPO实现,需根据具体需求添加

# 示例数据集和模型初始化
texts = ["I love this!", "This is terrible."]
labels = [1, 0]  # 假设1为正面,0为负面

dataset = CustomDataset(texts, labels)
dataloader = DataLoader(dataset, batch_size=2)

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# 执行微调流程
supervised_fine_tuning(model, dataloader)

# 假设我们有一些奖励数据用于训练奖励模型
reward_data = [("I love this!", 1.0), ("This is terrible.", 0.0)]
train_reward_model(model, reward_data)

# 最后,进行RLHF训练(需实现具体的PPO算法)
rl_training(model, model)  # 此处使用同一模型作为示例

Pytorch vs HuggingFace

易用性:

​ HuggingFace的API非常简洁, 并且有丰富的涵盖多个领域的预训练模型库, 集成了多种常用的微调策略, 比如上面提到的LoRA等, 还有活跃的社区和丰富的文档
​ Pytorch缺乏高层封装, 在比如数据处理, 模型保存上需要用户手动实现更多的功能, 学习曲线陡峭

灵活性

​ HuggingFace灵活性不如Pytorch, 在高度自定义场景下, Pytorch表现更佳

性能

​ 在一些情况下, Pytorch在计算上设计了专门的优化, HuggingFace的高层API不如Pytorch的性能优化高效

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2199831.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《2024世界机器人报告》:中国为全球最大市场

国际机器人联合会(IFR)在9月24日最新发布的《世界机器人报告》中表示,全球有约428万台机器人在工厂运行,同比增长10%。工业机器人年安装量连续第三年超过50万台,在2023年新部署的工业机器人中,有70%在亚洲&…

基于Springboot+Vue的物业智慧系统 (含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 在这个…

WordPress添加https协议致使后台打不开解决方法

由于删除WordPress缓存插件后操作不当,在加上升级处理,致使茹莱神兽博客的首页出现了https不兼容问题,WordPress后台也无法登陆,链接被误认为是定向重置次数过多,在网上找了好久的答案。 还有就是求助了好些人&#xf…

C++ —— 优先级队列(priority queue)的模拟实现

目录 杂谈 vector和list的区别 1. 优先级队列的定义 2. 优先级队列的模拟实现 3. 仿函数 链接: priority_queue - C Reference (cplusplus.com)https://legacy.cplusplus.com/reference/queue/priority_queue/?kwpriority_queue 杂谈 vector和list的区别 在…

Elastic Stack--16--ES三种分页策略

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 方式一:from size实现原理使用方式优缺点 方式二:scroll实现原理使用方式优缺点 方式三:search_after实现原理使用方式优缺点 三…

两个pdf怎么合并成一个pdf?超简单的合并方法分享

在日常工作和学习中,我们经常会遇到需要将多个PDF文件合并成一个文件的情况,以便更好地管理和分享。今天,将为大家详细介绍5种实用的方法,能够一键合并多个PDF文件,有需要的小伙伴快来一起学习下吧。 方法一&#xff1…

双十一买什么?双十一买什么东西最划算?超全双十一购物指南!

双十一即将到来,一年一度的购物狂欢盛宴再度开启!在海量的商品面前,怎样挑选出既心仪又实惠的好物,已然成为大家关注的重点。下面为您呈上一份极为全面的2024年双十一必买清单,助力您轻松购物,收获满满&…

详解Xilinx JESD204B PHY层端口信号含义及动态切换线速率(JESD204B五)

Xilinx官方提供了两个用于开发JESD204B的IP,其中一个完成PHY层设计,另一个完成传输层的逻辑,两个IP必须一起使用才能正常工作。 7系列FPGA只能使用最多12通道的JESD204B协议,线速率为1.0至12.5 Gb/s;而UltraScale和Ult…

胤娲科技:AI评估新纪元——LightEval引领透明化与定制化浪潮

AI评估的迷雾,LightEval能否拨云见日? 想象一下,你是一位AI模型的开发者,精心打造了一个智能助手,却在最终评估阶段遭遇了意外的“滑铁卢”。 问题出在哪里?是模型本身不够聪明,还是评估标准太过…

新手如何打造抖音矩阵账号,矩阵账号的优势有哪些?如何搭建矩阵系统的源码开发oem部署

抖音新手如何打造爆款矩阵账号? 在当前数字媒体盛行的时代,抖音作为一个领先的短视频分享平台,为品牌和个人提供了展示自己的舞台。对于初学者而言,构建一个有效的抖音账号矩阵是提升影响力的关键策略!今天&#xff0c…

mysql内置函数查询

聚合函数 :聚合函数查询时纵向查询,它是对一列的 值进行计算,然后返回一个单一的值,聚合函数会忽略空值。 namedescriptionavg()返回参数的平均值bit_and()按位返回andbit_or()按位返回orbit_xor()按位返回异或count()返回返回的…

Uos-Uos使用Remmina通过VNC远程连接到另一台Uos

Uos使用Remmina通过VNC远程连接到另一台Uos 一、概述二、对端机器安装 VNC服务器三、本机远程对端服务器 一、概述 这里记录一下使用Remmina通过VNC远程连接到另一台Uos系统,环境均是Linux操作系统 本机ip:10.8.11.64 对端ip:10.20.42.17 …

进程的状态的理解(概念+Linux)

文章目录 进程的状态并行和并发物理和逻辑 时间片进程具有独立性等待的本质运行阻塞标记挂起等待 Linux下的进程状态(一)运行状态(R - running)(二)睡眠状态(S - sleeping)&#xff…

银河麒麟V10中启用SELinux

银河麒麟V10中启用SELinux 1、启用SELinux1.1 切换到strict模式1.2 注意 2、验证SELinux状态 💖The Begin💖点点关注,收藏不迷路💖 在银河麒麟高级服务器操作系统V10中,可以使用security-switch工具来启用SELinux&…

springboot邮件群发功能的开发与优化策略?

springboot邮件配置指南?如何实现spring邮件功能? SpringBoot框架因其简洁、高效的特点,成为了开发邮件群发功能的理想选择。AokSend将深入探讨SpringBoot邮件群发功能的开发过程,并提出一系列优化策略,以确保邮件发送…

香山南湖架构分析--FE

总体架构 分支预测和指令缓存,通过FTQ达到解耦的目的;FTQ将请求送给ICache,进行取指;取出的指令码通过预译码初步检查分支预测的错误并及时冲刷预测流水线;检查后的指令送入指令缓冲并传给译码模块,最终形成后端的指令…

抓住最后机会!24年PMP认证报名今日开始,流程详解助你成功

为减少同一时间集中报名造成的网络拥堵,本次报名将采取以下形式分地区、分批次开放报名。 一、考试安排 考试时间:2024年11月30日 第一批报名城市 2024年10月9日10:00至10月16日16:00,以下城市的考点将开通报名&…

城市交通场景分割系统源码&数据集分享

城市交通场景分割系统源码&数据集分享 [yolov8-seg-C2f-Faster&yolov8-seg-GhostHGNetV2等50全套改进创新点发刊_一键训练教程_Web前端展示] 1.研究背景与意义 项目参考ILSVRC ImageNet Large Scale Visual Recognition Challenge 项目来源AAAI Glob…

FineReport打开报错“配置数据库出错“怎么解决?

配置数据库被锁住,是否重置?将在embed文件夹生成备份并重置 我直接用管理员身份证打开就完美解决了!

fmql之Linux下AXI GPIO、MISC

AXI GPIO 正点原子第41章。 要使用AXI GPIO,就要在vivado工程中,添加相关的IP。 然后dts会自动生成相关的AXi GPIO的设备树内容。 MISC 正点原子第42章。 /***************************************************************Copyright © ALIENTE…