【Python】科研代码学习:八 FineTune PretrainedModel (用 trainer,用 script);LLM文本生成

news2024/10/5 16:27:54

【Python】科研代码学习:八 FineTune PretrainedModel [用 trainer,用 script] LLM文本生成

  • 自己整理的 HF 库的核心关系图
  • 用 trainer 来微调一个预训练模型
  • 用 script 来做训练任务
  • 使用 LLM 做生成任务
    • 可能犯的错误,以及解决措施

自己整理的 HF 库的核心关系图

  • 根据前面几期,自己整理的核心库的使用/继承关系
    在这里插入图片描述

用 trainer 来微调一个预训练模型

  • HF官网API:FT a PretrainedModel
    今天讲讲FT训练相关的内容吧
    这里就先不提用 keras 或者 native PyTorch 微调,直接看一下用 trainer 微调的基本流程
  • 第一步:加载数据集和数据集预处理
    使用 datasets 进行加载 HF 数据集
from datasets import load_dataset

dataset = load_dataset("yelp_review_full")

另外,需要用 tokenizer 进行分词。自定义分词函数,然后使用 dataset.map() 可以把数据集进行分词。

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

也可以先选择其中一小部分的数据单独拿出来,做测试或者其他任务

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
  • 第二步,加载模型,选择合适的 AutoModel 或者比如具体的 LlamaForCausalLM 等类。
    使用 model.from_pretrained() 加载
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
  • 第三步,加载 / 创建训练参数 TrainingArguments
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer")
  • 第四步,指定评估指标。trainer 在训练的时候不会去自动评估模型的性能/指标,所以需要自己提供一个
    ※ 这个 evaluate 之前漏了,放后面学,这里先摆一下 # TODO
import numpy as np
import evaluate

metric = evaluate.load("accuracy")
  • 第五步,使用 trainer 训练,提供之前你创建好的:
    model模型,args训练参数,train_dataset训练集,eval_dataset验证集,compute_metrics评估方法
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()
  • 完整代码,请替换其中的必要参数来是配置自己的模型和任务
from datasets import load_dataset
from transformers import (
    LlamaTokenizer,
    LlamaForCausalLM,
    TrainingArguments,
    Trainer,
    )
import numpy as np
import evaluate

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

"""
Load dataset, tokenizer, model, training args
preprosess into tokenized dataset
split training dataset and eval dataset
"""
dataset = load_dataset("xxxxxxxxxxxxxxxxxxxx")

tokenizer = LlamaTokenizer.from_pretrained("xxxxxxxxxxxxxxxxxxxxxxxxxx")
tokenized_datasets = dataset.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

model = LlamaForCausalLM.from_pretrained("xxxxxxxxxxxxxxx")

training_args = TrainingArguments(output_dir="xxxxxxxxxxxxxx")

"""
define metrics
set trainer and train
"""

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

用 script 来做训练任务

  • 我们在很多项目中,都会看到启动脚本是一个 .sh 文件,一般里面可能会这么写:
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
  • 或者最近看到的一个
OUTPUT_DIR=${1:-"./alma-7b-dpo-ft"}
pairs=${2:-"de-en,cs-en,is-en,zh-en,ru-en,en-de,en-cs,en-is,en-zh,en-ru"}
export HF_DATASETS_CACHE=".cache/huggingface_cache/datasets"
export TRANSFORMERS_CACHE=".cache/models/"
# random port between 30000 and 50000
port=$(( RANDOM % (50000 - 30000 + 1 ) + 30000 ))

accelerate launch --main_process_port ${port} --config_file configs/deepspeed_train_config_bf16.yaml \
     run_cpo_llmmt.py \
    --model_name_or_path haoranxu/ALMA-13B-Pretrain \
    --tokenizer_name haoranxu/ALMA-13B-Pretrain \
    --peft_model_id  haoranxu/ALMA-13B-Pretrain-LoRA \
    --cpo_scorer kiwi_xcomet \
    --cpo_beta 0.1 \
    --use_peft \
    --use_fast_tokenizer False \
    --cpo_data_path  haoranxu/ALMA-R-Preference \
    --do_train \
    --language_pairs ${pairs} \
    --low_cpu_mem_usage \
    --bf16 \
    --learning_rate 1e-4 \
    --weight_decay 0.01 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type inverse_sqrt \
    --warmup_ratio 0.01 \
    --ignore_pad_token_for_loss \
    --ignore_prompt_token_for_loss \
    --per_device_train_batch_size 2 \
    --evaluation_strategy no \
    --save_strategy steps \
    --save_total_limit 1 \
    --logging_strategy steps \
    --logging_steps 0.05 \
    --output_dir ${OUTPUT_DIR} \
    --num_train_epochs 1 \
    --predict_with_generate \
    --prediction_loss_only \
    --max_new_tokens 256 \
    --max_source_length 256 \
    --seed 42 \
    --overwrite_output_dir \
    --report_to none \
    --overwrite_cache 
  • 玛雅,这么多 --xxx ,看着头疼,也不知道怎么搞出来这么多参数作为启动文件的。
    这种就是通过 script 启动任务了
  • github:transformers/examples
    看一下 HF github 给的一些任务的 examples 学习例子,就会发现
    main 函数中,会有这样的代码
    这个就是通过 argparser 来获取参数
    貌似还有 parserHfArgumentParser,这些都可以打包解析参数,又是挖个坑 # TODO
    这样的话,就可以通过 .sh 来在启动脚本中提供相关参数了
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

# ....... 太长省略
  • 用脚本启动还有什么好处呢
    可以使用 accelerate launch run_summarization_no_trainer.py 进行加速训练
    再给 accelerate 挖个坑 # TODO
  • 所以,在 .sh script 启动脚本中具体能提供哪些参数,取决于这个入口 .py 文件的 parser 打包解析了哪些参数,然后再利用这些参数做些事情。

使用 LLM 做生成任务

  • HF官网API:Generation with LLMs
    官方都特地给这玩意儿单独开了一节,就说明其中有些很容易踩的坑…
  • 对于 CausalLM,首先看一下 next token 的生成逻辑:输入进行分词与嵌入后,通过多层网络,然后进入到一个LM头,最终获得下一个 token 的概率预测
  • 那么生成句子的逻辑,就是不断重复这个过程,获得 next token 概率预测后,通过一定的算法选择下一个 token,然后再重复该操作,就能生成整个句子了。
  • 那什么时候停止呢?要么是下一个token选择了 eos,要么是到达了之前定义的 max token length
    在这里插入图片描述
  • 接下来看一下代码逻辑
  • 第一步,加载模型
    device_map:控制模型加载在 GPUs上,不过一般我会使用 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 以及 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
    load_in_4bit 设置加载量化
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", load_in_4bit=True
)
  • 第二步,加载分词器和分词
    记得分词的向量需要加载到 cuda
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")
model_inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")
  • 但这个是否需要分词取决于特定的 model.generate() 方法的参数
    就比如 disc 模型的 generate() 方法的参数为:
    也就是说,我输入的 prompt 只用提供字符串即可,又不需要进行分词或者分词器了。
    在这里插入图片描述
  • 第三步,通常的 generate 方法,输入是 tokenized 后的数组,然后获得 ids 之后再 decode 变成对应的字符结果
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  • 当然我也可以批处理,一次做多个操作,批处理需要设置pad_token
tokenizer.pad_token = tokenizer.eos_token  # Most LLMs don't have a pad token by default
model_inputs = tokenizer(
    ["A list of colors: red, blue", "Portugal is"], return_tensors="pt", padding=True
).to("cuda")
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

可能犯的错误,以及解决措施

  • 控制输出句子的长度
    需要在 generate 方法中提供 max_new_tokens 参数
model_inputs = tokenizer(["A sequence of numbers: 1, 2"], return_tensors="pt").to("cuda")

# By default, the output will contain up to 20 tokens
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Setting `max_new_tokens` allows you to control the maximum length
generated_ids = model.generate(**model_inputs, max_new_tokens=50)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  • 生成策略修改
    有时候默认使用贪心策略来获取 next token,这个时候容易出问题(循环生成等),需要设置 do_sample=True
    在这里插入图片描述

  • pad 对齐方向
    如果输入不等长,那么会进行pad操作
    由于默认是右侧padding,而LLM在训练时没有学会从pad_token接下来的生成策略,所以会出问题
    所以需要设置 padding_side="left![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/6084ff91d85c49e28a4faf498b8e5997.png) "
    在这里插入图片描述

  • 如果没有使用正确的 prompt(比如训练时的prompt格式),得到的结果就会不如预期
    (in one sitting = 一口气) (thug = 暴徒)
    这里需要参考 HF对话模型的模板 以及 HF LLM prompt 指引
    在这里插入图片描述
    比如说,QA的模板就像这样。
    更高级的还有 few shotCOT 技巧。

torch.manual_seed(4)
prompt = """Answer the question using the context below.
Context: Gazpacho is a cold soup and drink made of raw, blended vegetables. Most gazpacho includes stale bread, tomato, cucumbers, onion, bell peppers, garlic, olive oil, wine vinegar, water, and salt. Northern recipes often include cumin and/or pimentón (smoked sweet paprika). Traditionally, gazpacho was made by pounding the vegetables in a mortar with a pestle; this more laborious method is still sometimes used as it helps keep the gazpacho cool and avoids the foam and silky consistency of smoothie versions made in blenders or food processors.
Question: What modern tool is used to make gazpacho?
Answer:
"""

sequences = pipe(
    prompt,
    max_new_tokens=10,
    do_sample=True,
    top_k=10,
    return_full_text = False,
)

for seq in sequences:
    print(f"Result: {seq['generated_text']}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1510647.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

8块硬盘故障的存储异常恢复案例一则

关键词 华为存储、硬盘域、LUN热备冗余、重构、预拷贝 oracle rac、多路径 There are many things that can not be broken! 如果觉得本文对你有帮助,欢迎点赞、收藏、评论! 一、问题现象 近期遇到的一个案例,现象是一套oracl…

PXI8540高速数据采集卡

XI高速数据采集卡,PXI8540卡是一种基于PXI总线的模块化仪器,可使用PXI系统,在一个机箱内实现一个综合的测试系统,构成实验室、产品质量检测中心等各种领域的数据采集、波形分析和处理系统。也可构成工业生产过程监控系统。它的主要…

ThreeWayBranch 优化阅读笔记

1. 优化目的 通过重排三分支的 BB 块减少比较指令的执行次数 代码路径: bolt/lib/Passes/ThreeWayBranch.cpp2. 效果 优化前: 注: 黄色数字表示BB块编号, 紫色表示该分支跳转的次数,绿色是代码里BB块的变量名 ThreeWayBranc…

P6327 区间加区间 sin 和 (线段树+数学)

传送门https://www.luogu.com.cn/problem/P6327 比较板子的一题,主要考察公式 //sin(ax)sinxcosasinacosx //cos(ax)cosacosx-sinasinx 直接贴代码吧 // Problem: // P6327 区间加区间 sin 和 // // Contest: Luogu // URL: https://www.luogu.com.cn/pr…

@Conditional注解详解

目录 一、Conditional注解作用 二、Conditional源码解析 2.1 Conditional源码 2.2 Condition源码 三、Conditional案例 3.1 Conditional作用在类上案例 3.1.1 配置文件 3.1.2 Condition实现类 3.1.3 Bean内容类 3.1.4 Config类 3.1.5 Controller类 3.1.6 测试结果 3…

ChatGPT GPT4科研应用、数据分析与机器学习、论文高效写作、AI绘图技术

原文链接:ChatGPT GPT4科研应用、数据分析与机器学习、论文高效写作、AI绘图技术https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247596849&idx3&sn111d68286f9752008bca95a5ec575bb3&chksmfa823ad6cdf5b3c0c446eceb5cf29cccc3161d746bdd9f2…

Lim接口测试平台开展自动化的优势

一、数据对比 使用Lim接口测试平台后,相比以往采用Postman或excel关键字驱动带来的效率提升: 编写效率提升300%,原来10个步骤的用例,一个工作日调试编写只能输出6条,现在一天能输出18条。维护成本复杂度降低100%&…

Vue3.0里为什么要用 Proxy API 替代 defineProperty API

一、Object.defineProperty 定义:Object.defineProperty() 方法会直接在一个对象上定义一个新属性,或者修改一个对象的现有属性,并返回此对象 为什么能实现响应式 通过defineProperty 两个属性,get及set get 属性的 getter 函…

北斗卫星助力海上风电厂:打造海上绿色能源新时代

北斗卫星助力海上风电厂:打造海上绿色能源新时代 近日,东海航海保障中心温州航标处在华能苍南海上风电场完成首套北斗水上智能感知综合预警系统现场安装调试工作,经现场效能测定,能有效保障海上风电场运行安全和海域船舶通航安全…

华为OD机试 - 模拟数据序列化传输(Java JS Python C C++)

题目描述 模拟一套简化的序列化传输方式,请实现下面的数据编码与解码过程 编码前数据格式为 [位置,类型,值],多个数据的时候用逗号分隔,位置仅支持数字,不考虑重复等场景;类型仅支持:Integer / String / Compose(Compose的数据类型表示该存储的数据也需要编码)编码后数…

光电容积脉搏波PPG信号分析笔记

1.脉搏波信号的PRV分析 各类分析参数记参数 意义 公式 参数意义 线性分析 时域分析 均值MEAN 反应RR间期的平均水平 总体标准差SDNN 评估24小时长程HRV的总体变化, SDNN < 50ms 为异常,SDNN>100ms 为正常;…

如何解决爬虫程序访问速度受限问题

目录 前言 一、代理IP的获取 1. 自建代理IP池 2. 购买付费代理IP 3. 使用免费代理IP网站 二、代理IP的验证 三、使用代理IP进行爬取 四、常见问题和解决方法 1. 代理IP不可用 2. 代理IP速度慢 3. 代理IP被封禁 总结 前言 解决爬虫程序访问速度受限问题的一种常用方…

群晖部署私人聊天服务器Vocechat并结合内网穿透实现公网远程访问

文章目录 1. 拉取Vocechat2. 运行Vocechat3. 本地局域网访问4. 群晖安装Cpolar5. 配置公网地址6. 公网访问小结 7. 固定公网地址 如何拥有自己的一个聊天软件服务? 本例介绍一个自己本地即可搭建的聊天工具,不仅轻量,占用小,且功能也停强大,它就是Vocechat. Vocechat是一套支持…

怎么把视频变成gif动图?一招在线生成gif动画

MP4是一种常见的视频文件格式,它是一种数字多媒体容器格式,可以用于存储视频、音频和字幕等多种媒体数据。MP4格式通常用于在计算机、移动设备和互联网上播放和共享视频内容。要将MP4视频转换为GIF格式,您可以使用专门的视频转gif工具。这个工…

中科数安|——如何防止别人复制文档内容?

#如何防止别人复制文档内容# 中科数安所提供的防止别人复制文档内容的措施主要包括但不限于以下几个方面: www.weaem.com 1. **文档加密与权限控制**: - 对关键文档进行加密处理,确保只有获得授权的人员才能解密并查看文档内容。 - 实施精…

Java项目:基于Springboot+vue实现的付费自习室系统设计与实现(源码+数据库+毕业论文)附含微信小程序端代码

一、项目简介 本项目是一套基于Springbootvue实现的付费自习室系统 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,eclipse或者idea 确保可以运行! 该系统功能完善、界面美观、操作简单、…

即时设计是什么?设计大佬在线讲解

即时设计是一种互联网产品设计工具。产品原型设计软件由以下四个部分介绍: 1、什么是即时设计? 2、即时设计产品和服务怎么样? 3、即时设计的优点是什么?优点是什么? 4、即时设计的客户是什么?哪些公司…

windows的vmdk文件转qcow2运行蓝屏

背景 使用qemu-img将做好的vmware虚拟机转为qcow2到gns3中运行,Linux、Win7、Win10都没出现蓝屏,但Win XP却在开机时蓝屏了,错误代码:0x0000007B 解决方案 最终在proxmox上找到方案:https://pve.proxmox.com/wiki/Ad…

(一区)基于模型的连续和离散全局优化方法

Model-based methods for continuous and discrete global optimization 1.摘要 本文综述了下基于模型的连续和离散全局优化方法,并提出了一种叠加替代信息的新方法。 2.介绍 比较水。。作者说,本文是首次尝试提供对连续和离散建模方法的可理解的调查…

微信自动回复的优势及设置方法

自动回复功能的优势: 1、可设置不重复触发时间和生效时间段,回复效果更智能,提升联系人体验; 2、可以多微信同时设置,可直接导入素材库内容,提高工作效率; 3、多个关键词、多条回复内容&…