八、大模型之Fine-Tuning(1)

news2025/1/17 23:11:49

1 什么时候需要Fine-Tuning

  1. 有私有部署的需求
  2. 开源模型原生的能力不满足业务需求

2 训练模型利器Hugging Face

  1. 官网(https://huggingface.co/)
  2. 相当于面向NLP模型的Github
  3. 基于transformer的开源模型非常全
  4. 封装了模型、数据集、训练器等,资源下载方面
  5. 安装依赖
# pip 安装
pip install transformers # 安装最新版本
pip install transformers == 4.30 # 安装指定版本
# conda安装
conda install -c huggingface transformers  # 只4.0以后的版本

3 案例

3.1 操作流程

加载数据集—>数据预处理—>数据规整器—>训练器
在这里插入图片描述

3.2 实现

  1. 导包
import datasets
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForCausalLM
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from transformers import Trainer, Seq2SeqTrainer
import transformers
from transformers import DataCollatorWithPadding
from transformers import TextGenerationPipeline
import torch
import numpy as np
import os, re
from tqdm import tqdm
import torch.nn as nn
  1. 加载数据集
    通过HuggingFace,可以指定数据集名称,运行时自动下载
# 数据集名称
DATASET_NAME = "rotten_tomatoes" 

# 加载数据集
raw_datasets = load_dataset(DATASET_NAME)

# 训练集
raw_train_dataset = raw_datasets["train"]

# 验证集
raw_valid_dataset = raw_datasets["validation"]

在这里插入图片描述
3. 加载模型

# 模型名称
MODEL_NAME = "gpt2" 

# 加载模型 
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,trust_remote_code=True)

在这里插入图片描述
4. 加载Tokenizer
通过HuggingFace,可以指定模型名称,运行自动下载对应Tokenizer

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,trust_remote_code=True)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token_id = 0

# 设置随机种子:同个种子的随机序列可复现
transformers.set_seed(42)

# 标签集
named_labels = ['neg','pos']

# 标签转 token_id
label_ids = [
    tokenizer(named_labels[i],add_special_tokens=False)["input_ids"][0] 
    for i in range(len(named_labels))
]

在这里插入图片描述
5. 处理数据集:转成模型接受的输入格式

  • 拼接输入输出:<INPUT TOKEN IDS><EOS_TOKEN_ID><OUTPUT TOKEN IDS>
  • PAD成相等长度:
    • <INPUT 1.1><INPUT 1.2>…<EOS_TOKEN_ID><OUTPUT TOKEN IDS><PAD>…<PAD>
    • <INPUT 2.1><INPUT 2.2>…<EOS_TOKEN_ID><OUTPUT TOKEN IDS><PAD>…<PAD>
  • 标识出参与 Loss 计算的 Tokens (只有输出 Token 参与 Loss 计算)
    • <-100><-100>…<OUTPUT TOKEN IDS><-100>…<-100>
MAX_LEN=32   #最大序列长度(输入+输出)
DATA_BODY_KEY = "text" # 数据集中的输入字段名
DATA_LABEL_KEY = "label" #数据集中输出字段名

# 定义数据处理函数,把原始数据转成input_ids, attention_mask, labels
def process_fn(examples):
    model_inputs = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }
    for i in range(len(examples[DATA_BODY_KEY])):
        inputs = tokenizer(examples[DATA_BODY_KEY][i],add_special_tokens=False)
        label = label_ids[examples[DATA_LABEL_KEY][i]]
        input_ids = inputs["input_ids"] + [tokenizer.eos_token_id, label]
        
        raw_len = len(input_ids)
        input_len = len(inputs["input_ids"]) + 1

        if raw_len >= MAX_LEN:
            input_ids = input_ids[-MAX_LEN:]
            attention_mask = [1] * MAX_LEN
            labels = [-100]*(MAX_LEN - 1) + [label]
        else:
            input_ids = input_ids + [tokenizer.pad_token_id] * (MAX_LEN - raw_len)
            attention_mask = [1] * raw_len + [0] * (MAX_LEN - raw_len)
            labels = [-100]*input_len + [label] + [-100] * (MAX_LEN - raw_len)
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append(attention_mask)
        model_inputs["labels"].append(labels)
    return model_inputs

6.定义数据规整器:训练时自动将数据拆分成Batch

# 定义数据校准器(自动生成batch)
collater = DataCollatorWithPadding(
    tokenizer=tokenizer, return_tensors="pt",
)

7.定义训练超参

LR=2e-5         # 学习率
BATCH_SIZE=8    # Batch大小
INTERVAL=100    # 每多少步打一次 log / 做一次 eval

# 定义训练参数
training_args = TrainingArguments(
    output_dir="./output",              # checkpoint保存路径
    evaluation_strategy="steps",        # 按步数计算eval频率
    overwrite_output_dir=True,
    num_train_epochs=1,                 # 训练epoch数
    per_device_train_batch_size=BATCH_SIZE,     # 每张卡的batch大小
    gradient_accumulation_steps=1,              # 累加几个step做一次参数更新
    per_device_eval_batch_size=BATCH_SIZE,      # evaluation batch size
    eval_steps=INTERVAL,                # 每N步eval一次
    logging_steps=INTERVAL,             # 每N步log一次
    save_steps=INTERVAL,                # 每N步保存一个checkpoint
    learning_rate=LR,                   # 学习率
)

8.定义训练器

# 节省显存
model.gradient_checkpointing_enable()

# 定义训练器
trainer = Trainer(
    model=model, # 待训练模型
    args=training_args, # 训练参数
    data_collator=collater, # 数据校准器
    train_dataset=tokenized_train_dataset,  # 训练集
    eval_dataset=tokenized_valid_dataset,   # 验证集
    # compute_metrics=compute_metric,         # 计算自定义评估指标
)

8.训练

trainer.train()

总结

  1. 加载数据集
  2. 数据预处理
    • 将输入输出按特定格式拼接
    • 文本转Token IDs
    • 通过labels标识出哪部分是输出(只有输出的token参与loss计算)
  3. 加载模型、Tokenizer
  4. 定义数据规则整器
  5. 定义训练超参:学习率、批次大小
  6. 定义训练器
  7. 开始训练

4 大模型训练相关技术

  1. 神经网络
    在这里插入图片描述

  2. 常用的激活函数
    在这里插入图片描述

  3. 梯度下降
    在这里插入图片描述

  4. 学习率
    在这里插入图片描述

  5. 求解器

为了让训练过程更好的收敛,人们设计了很多更复杂的求解器

  • 比如:SGD、L-BFGS、Rprop、RMSprop、Adam、AdamW、AdaGrad、AdaDelta 等等
  • 但是,好在对于Transformer最常用的就是 Adam 或者 AdamW
  1. 一些常用的损失函数
  • 两个数值的差距,Mean Squared Error: ℓ M S E = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 \ell_{\mathrm{MSE}}=\frac{1}{N}\sum_{i=1}^N(y_i-\hat{y}_i)^2 MSE=N1i=1N(yiy^i)2 (等价于欧式距离,见下文)

  • 两个向量之间的(欧式)距离: ℓ ( y , y ^ ) = ∥ y − y ^ ∥ \ell(\mathbf{y},\mathbf{\hat{y}})=\|\mathbf{y}-\mathbf{\hat{y}}\| (y,y^)=yy^

  • 两个向量之间的夹角(余弦距离):
    在这里插入图片描述

  • 两个概率分布之间的差异,交叉熵: ℓ C E ( p , q ) = − ∑ i p i log ⁡ q i \ell_{\mathrm{CE}}(p,q)=-\sum_i p_i\log q_i CE(p,q)=ipilogqi ——假设是概率分布 p,q 是离散的

  • 这些损失函数也可以组合使用(在模型蒸馏的场景常见这种情况),例如 L = L 1 + λ L 2 L=L_1+\lambda L_2 L=L1+λL2,其中 λ \lambda λ是一个预先定义的权重,也叫一个「超参」

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1560069.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

工作究竟是谁的?

在近两年的就业环境中&#xff0c;普遍存在着挑战与不确定性&#xff0c;许多人追求的是一种稳定的工作和收入来源。在这样的背景下&#xff0c;我们来探讨一个核心问题&#xff1a;工作的归属是谁的&#xff1f; 根据《穷爸爸富爸爸》中提出的ESBI四象限理论&#xff0c;我们可…

【经典算法】LeetCode14:最长公共前缀(Java/C/Python3实现含注释说明)

最长公共前缀 题目思路及实现方式一&#xff1a;横向扫描思路代码实现Java版本C语言版本Python3版本 复杂度分析 方式二&#xff1a;纵向扫描思路代码实现Java版本C语言版本Python3版本 复杂度分析 方式三&#xff1a;分治思路代码实现Java版本C语言版本Python3版本 复杂度分析…

单元测试mockito(一)

1.单元测试 1.1 单元测试的特点 ●配合断言使用(杜绝System.out) ●可重复执行 。不依赖环境 ●不会对数据产生影响 ●spring的上下文环境不是必须的 ●一般都需要配合mock类框架来实现 1.2 mock类框架使用场景 要进行测试的方法存在外部依赖(如db,redis,第三方接口调用等),为…

3.31学习总结

(本次学习总结,总结了目前学习java遇到的一些关键字和零碎知识点) 一.static关键字 static可以用来修饰类的成员方法、类的成员变量、类中的内部类&#xff08;以及用static修饰的内部类中的变量、方法、内部类&#xff09;&#xff0c;另外可以编写static代码块来优化程序性…

权限问题(Windows-System)

方法&#xff1a;用命令来写一个注册表的脚本 &#xff1f;System是最高级用户&#xff0c;但不拥有最高级权限 编写两文档&#xff1a;system.reg 和 remove.reg,代码如下&#xff1a; system.reg&#xff1a; Windows Registry Editor Version 5.00[-HKEY_CLASSES_ROOT\*…

YOLOv5改进 | 低照度检测 | 2024最新改进CPA-Enhancer链式思考网络(适用低照度、图像去雾、雨天、雪天)

一、本文介绍 本文给大家带来的2024.3月份最新改进机制&#xff0c;由CPA-Enhancer: Chain-of-Thought Prompted Adaptive Enhancer for Object Detection under Unknown Degradations论文提出的CPA-Enhancer链式思考网络&#xff0c;CPA-Enhancer通过引入链式思考提示机制&am…

Proteus 12V to 5V buck电路仿真练习及遇到的一些问题汇总

基础电路仿真实验记录贴&#xff01;&#xff01;&#xff01;如有写的不对的地方欢迎交流指正&#xff01;&#xff01;&#xff01; 平台&#xff1a;PC win10 软件&#xff1a;Proteus8.10 仿真目标&#xff1a;buck降压电路&#xff08;PWM控制输出电压&#xff09; 写在…

《Lost in the Middle: How Language Models Use Long Contexts》AI 解读

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

MySQL核心命令详解与实战,一文掌握MySQL使用

文章目录 文章简介演示库表创建数据库表选择数据库删除数据库创建表删除表向表中插入数据更新数据删除数据查询数据WHERE 操作符聚合函数LIKE 子句分组 GROUP BY HAVINGORDER BY(排序) 语句LIMIT 操作符 分页查询多表查询-联合查询 UNION 操作符多表查询-连接的使用-JOIN语句编…

选择排序及其优化

目录 思想&#xff1a; 代码&#xff1a; 代码优化&#xff1a; 需要注意的特殊情况&#xff1a; 可能出现的所有特殊情况&#xff1a; 优化完成代码&#xff1a; 思想&#xff1a; 每一次遍历数组&#xff0c;选择出最大或最小的数&#xff0c;将其与数组末尾或首位进行…

Oracle Solaris 11.3开工失败问题处理记录

1、故障现像 起初是我这有套RAC有点问题&#xff0c;我想重启1个节点&#xff0c;结果发现重启后该节点的IP能PING通&#xff0c;但SSH连不上去&#xff0c;对应的RAC服务也没有自动启动。 操作系统是solaris 11.3。由于该IP对应的主机是LDOM&#xff0c;于是我去主域上telnet…

汇编语言第四版-王爽第2章 寄存器

二进制左移四位&#xff0c;相当于四进制左移一位。 debug命令实操&#xff0c;win11不能启动&#xff0c;需要配置文件 Windows64位系统进入debug模式_window10系统64位怎么使用debugger-CSDN博客

扫雷(蓝桥杯)

题目描述 小明最近迷上了一款名为《扫雷》的游戏。其中有一个关卡的任务如下&#xff0c; 在一个二维平面上放置着 n 个炸雷&#xff0c;第 i 个炸雷 (xi , yi ,ri) 表示在坐标 (xi , yi) 处存在一个炸雷&#xff0c;它的爆炸范围是以半径为 ri 的一个圆。 为了顺利通过这片土…

开源博客项目Blog .NET Core源码学习(13:App.Hosting项目结构分析-1)

开源博客项目Blog的App.Hosting项目为MVC架构的&#xff0c;主要定义或保存博客网站前台内容显示页面及后台数据管理页面相关的控制器类、页面、js/css/images文件&#xff0c;页面使用基于layui的Razor页面&#xff08;最早学习本项目就是想学习layui的用法&#xff0c;不过最…

网络安全 | 网络攻击介绍

关注wx&#xff1a;CodingTechWork 网络攻击 网络攻击定义 以未经授权的方式访问网络、计算机系统或数字设备&#xff0c;故意窃取、暴露、篡改、禁用或破坏数据、应用程序或其他资产的行为。威胁参与者出于各种原因发起网络攻击&#xff0c;从小额盗窃发展到战争行为。采用各…

C语言自定义类型

本篇文章主要介绍三种自定义类型&#xff0c;分别是&#xff1a;结构体、联合体、枚举。 一.结构体 1.结构体类型的声明 直接举一个例子&#xff1a; //一本书 struct s {char name[10]; //名称char a; //作者int p; //价格 }; 2.特殊的声明 结构体也可以不写结构体标…

NVIDIA Jetson Xavier NX入门-镜像为jetpack5(3)——pytorch和torchvision安装

NVIDIA Jetson Xavier NX入门-镜像为jetpack5&#xff08;3&#xff09;——pytorch和torchvision安装 镜像为jetpack5系列&#xff1a; NVIDIA Jetson Xavier NX入门-镜像为jetpack5&#xff08;1&#xff09;——镜像烧写 NVIDIA Jetson Xavier NX入门-镜像为jetpack5&#…

医院陪诊管理系统(源码+文档)

TOC) 文件包含内容 1、搭建视频 2、流程图 3、开题报告 4、数据库 5、参考文献 6、服务器接口文件 7、接口文档 8、任务书 9、功能图 10、环境搭建软件 11、十六周指导记录 12、答辩ppt模板 13、技术详解 14、前端后台管理&#xff08;管理端程序&#xff09; 15、项目截图 1…

CCIE-07-OSPF_TS

目录 实验条件网络拓朴逻辑拓扑实现目标 环境配置开始Troubleshooting问题1. R22的e0/0接口配置了网络类型问题2. R22和R21之间的IP地址子网掩码长度不一致问题3. R21的e0/0口配置了被动接口问题4. R3配置了不一致的hello-time问题5. R21配置了max-metric导致路由无效问题6. R3…

LLM大模型可视化-以nano-gpt为例

内容整理自&#xff1a;LLM 可视化 --- LLM Visualization (bbycroft.net)https://bbycroft.net/llm Introduction 介绍 Welcome to the walkthrough of the GPT large language model! Here well explore the model nano-gpt, with a mere 85,000 parameters. 欢迎来到 GPT 大…