Llama 3.2 微调指南

news2024/10/6 0:17:30

让我们通过微调 Llama 3.2 来找到一些精神上的平静。

我们需要安装 unsloth,以更小的尺寸实现 2 倍的快速训练

!pip install unsloth

!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
我们将使用 Unsloth,因为它显著提高了微调大型语言模型 (LLM) 的效率,特别是 LLaMA 和 Mistral。使用 Unsloth,我们可以使用高级量化技术(例如 4 位和 16 位量化)来减少内存并加快训练和推理速度。这意味着我们甚至可以在资源有限的硬件上部署强大的模型,而不会影响性能。

此外,Unsloth 广泛的兼容性和定制选项允许执行量化过程以满足产品的特定需求。这种灵活性加上其将 VRAM 使用量减少高达 60% 的能力,使 Unsloth 成为 AI 工具包中必不可少的工具。它不仅仅是优化模型,而是让尖端 AI 更易于访问,更高效地应用于现实世界。

对于微调,我使用了以下设置:

  • Torch 2.1.1 - CUDA 12.1 可实现高效计算。
  • Unsloth 可实现大型语言模型 (LLM) 的 2 倍更快的训练速度。
  • H100 NVL GPU 可满足密集处理要求,但你可以使用功率较低的 GPU,即 Kaggle GPU。

为什么是 LLaMA 3.2?

它是开源且可访问的,并提供了根据特定需求进行自定义和微调的灵活性。由于 Meta 的模型权重是开源的,因此可以非常轻松地对任何问题进行微调,我们将在 Hugging Face 的心理健康数据集上对其进行微调

 NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - AI模型在线查看 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割

1、Python库

数据处理和可视化

import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
plt.style.use('ggplot')

LLM模型训练:

import torch
from trl import SFTTrainer
from transformers import TrainingArguments, TextStreamer
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel
from datasets import Dataset
from unsloth import is_bfloat16_supported

# Saving model
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Warnings
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline

2、调用数据集

data = pd.read_json("hf://datasets/Amod/mental_health_counseling_conversations/combined_dataset.json", lines=True)

3、探索性数据分析

让我们检查一下每个上下文中的单词长度:

data['Context_length'] = data['Context'].apply(len)
plt.figure(figsize=(10, 3))
sns.histplot(data['Context_length'], bins=50, kde=True)
plt.title('Distribution of Context Lengths')
plt.xlabel('Length of Context')
plt.ylabel('Frequency')
plt.show()

注意:如上所示,单词数最少为 1500 个,而且存在显著差异,因此我们只使用 1500 个或更少单词的数据。

filtered_data = data[data['Context_length'] <= 1500]

ln_Context = filtered_data['Context'].apply(len)
plt.figure(figsize=(10, 3))
sns.histplot(ln_Context, bins=50, kde=True)
plt.title('Distribution of Context Lengths')
plt.xlabel('Length of Context')
plt.ylabel('Frequency')
plt.show()

注意:现在可以使用这些数据。

现在让我们检查一下每个回复的单词长度:

ln_Response = filtered_data['Response'].apply(len)
plt.figure(figsize=(10, 3))
sns.histplot(ln_Response, bins=50, kde=True, color='teal')
plt.title('Distribution of Response Lengths')
plt.xlabel('Length of Response')
plt.ylabel('Frequency')
plt.show()

注意:这也是 4000 字长度的回应之后,出现了明显的下降。

filtered_data = filtered_data[ln_Response <= 4000]

ln_Response = filtered_data['Response'].apply(len)
plt.figure(figsize=(10, 3))
sns.histplot(ln_Response, bins=50, kde=True, color='teal')
plt.title('Distribution of Response Lengths')
plt.xlabel('Length of Response')
plt.ylabel('Frequency')
plt.show()

注意:不需要进行这样的数据准备来处理 LLM 模型的文本长度,但为了保持字数的一致性,我仅以 4000 个字以下的字为例,以便你可以根据需要进行任何数据预处理。

4、模型训练

让我们深入研究 Llama 3.2 模型并在我们的数据上进行训练。

4.1 加载模型

我们将使用只有 10 亿个参数的 Llama 3.2,但你也可以使用 30 亿、110 亿或 900 亿个版本。

也可以根据你的要求遵循以下关键方面:

  • 最大序列长度

我们使用了 max_seq_length 5020,这是模型中可以使用的最大标记数,可以在单个输入序列中处理。这对于需要处理长文本的任务至关重要,可确保模型在每次传递中都能捕获更多上下文。可以根据要求使用它。

  • 加载 Llama 3.2 模型

使用 FastLanguageModel.from_pretrained 和特定的预训练模型 unsloth/Llama-3.2-1B-bnb-4bitt 加载模型和标记器。这针对 4 位精度进行了优化,可减少内存使用量并提高训练速度,而不会显着影响性能。 load_in_4bit=True 参数可实现这种高效的 4 位量化,使其更适合在性能较弱的硬件上进行微调。

  • 应用 PEFT(参数高效微调)

然后我们使用 get_peft_model 配置模型,它应用了 LoRA(低秩自适应)技术。这种方法侧重于仅微调模型的特定层或部分,而不是整个网络,从而大大减少了所需的计算资源。

参数r=16 和 lora_alpha=16 等可调整这些自适应的复杂性和缩放比例。使用 target_modules 指定应调整模型的哪些层,其中包括涉及注意机制的关键组件,如 q_proj、 k_proj 和 v_proj

use_rslora=True 可激活 Rank-Stabilized LoRA,从而提高微调过程的稳定性。 use_gradient_checkpointing="unsloth" 确保通过选择性地仅存储必要的计算来优化训练期间的内存使用,从而进一步提高模型的效率。

  • 验证可训练参数

最后,我们使用 model.print_trainable_parameters() 打印出将在微调期间更新的参数数量,从而验证是否只训练了模型的预期部分。

这种技术组合不仅使微调过程更加高效,而且更易于访问,即使在计算资源有限的情况下,你也可以部署此模型。

将 tokenz 的最大长度设置为 5020 足以作为低秩自适应 (LoRA) 进行训练,但您可以根据你的数据和要求使用。

max_seq_length = 5020
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-1B-bnb-4bit",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    dtype=None,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj", "o_proj", "gate_proj"],
    use_rslora=True,
    use_gradient_checkpointing="unsloth",
    random_state = 32,
    loftq_config = None,
)
print(model.print_trainable_parameters())

4.2 为模型提要准备数据

现在是时候设计用于心理健康分析的格式提示了。此功能从心理学角度分析输入文本,识别情绪困扰、应对机制或整体心理健康的指标。它还强调潜在的担忧或积极方面,为每个观察结果提供简要解释。我们将准备这些数据以供模型进一步处理,确保每个输入输出对都具有清晰的格式,以便进行有效分析。

要记住的要点:

  • 数据提示结构

data_prompt 是一个格式化的字符串模板,旨在指导模型分析提供的文本。它包括输入文本(上下文)和模型响应的占位符。该模板专门提示模型识别心理健康指标,使模型更容易微调心理健康相关任务。

  • 序列结束标记

从标记器中检索 EOS_TOKEN 以表示每个文本序列的结束。此标记对于模型识别提示何时结束至关重要,有助于在训练或推理期间维护数据的结构。

  • 格式化函数

formatting_prompt 用于获取一批示例并根据 data_prompt 对其进行格式化。它遍历输入和输出对,将它们插入模板并在末尾附加 EOS 标记。然后,该函数返回一个包含格式化文本的字典,可用于模型训练或评估。

  • 函数输出

该函数输出一个字典,其中键为“文本”,值是格式化字符串的列表。每个字符串代表模型的完整准备提示,结合了上下文、响应和结构化提示模板。

data_prompt = """Analyze the provided text from a mental health perspective. Identify any indicators of emotional distress, coping mechanisms, or psychological well-being. Highlight any potential concerns or positive aspects related to mental health, and provide a brief explanation for each observation.

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token
def formatting_prompt(examples):
    inputs       = examples["Context"]
    outputs      = examples["Response"]
    texts = []
    for input_, output in zip(inputs, outputs):
        text = data_prompt.format(input_, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }

4.3 格式化数据以进行训练

training_data = Dataset.from_pandas(filtered_data)
training_data = training_data.map(formatting_prompt, batched=True)

4.4 使用自定义参数和数据进行模型训练

使用 sudo apt-get update 刷新可用软件包列表,使用 sudo apt-get install build-essential 安装必备工具。如果出现任何错误,请在 shell 上运行此命令。

#sudo apt-get update
#sudo apt-get install build-essential

4.5 训练设置开始微调!

我们将使用模型和标记器以及训练数据集初始化 SFTTrainer。 dataset_text_field 参数指定数据集中包含我们上面准备的用于训练的文本的字段。训练器负责管理微调过程,包括数据处理和模型更新。

训练参数如下:

TrainingArguments 类用于定义训练过程的关键超参数。这些包括:

  • learning_rate=3e-4:设置优化器的学习率。
  • per_device_train_batch_size=32:定义每个设备的批次大小,优化 GPU 使用率。
  • num_train_epochs=20:指定训练周期数。
  • fp16=not is_bfloat16_supported() 和 bf16=is_bfloat16_supported():启用混合精度训练以减少内存使用量,具体取决于硬件支持。
  • optim="adamw_8bit":使用 8 位 AdamW 优化器来高效使用内存。
  • weight_decay=0.01:应用权重衰减以防止过度拟合。
  • output_dir="output":指定将保存训练模型和日志的目录。

最后,我们调用 trainer.train() 方法来启动训练过程。它使用我们定义的参数来微调模型,调整权重并从提供的数据集中学习。训练器还处理数据打包和梯度累积,优化训练管道以获得更好的性能。

有时 pytorch 会保留内存并且不会释放回来。设置此环境变量可以帮助避免内存碎片。你可以在运行模型之前在环境或脚本中设置它

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

如果 GPU 中不再需要变量,可以使用 del 删除它们,然后调用

torch.cuda.empty_cache()
trainer=SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=training_data,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=True,
    args=TrainingArguments(
        learning_rate=3e-4,
        lr_scheduler_type="linear",
        per_device_train_batch_size=16,
        gradient_accumulation_steps=8,
        num_train_epochs=40,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        warmup_steps=10,
        output_dir="output",
        seed=0,
    ),
)

trainer.train()

4.6 推理

text="I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here. I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it. How can I change my feeling of being worthless to everyone?"

注意:让我们使用微调模型进行推理,以便根据与心理健康相关的提示生成反应!

以下是需要注意的一些要点:

model = FastLanguageModel.for_inference(model) 专门为推理配置模型,优化其生成响应的性能。

使用 tokenizer 对输入文本进行标记,它将文本转换为模型可以处理的格式。我们使用 data_prompt 来格式化输入文本,而将响应占位符留空以从模型获取响应。 return_tensors = "pt" 参数指定输出应为 PyTorch 张量,然后使用 .to("cuda") 将其移动到 GPU 以加快处理速度。

model.generate 方法根据标记化的输入生成响应。参数 max_new_tokens = 5020 和 use_cache = True 确保模型可以通过利用来自先前层的缓存计算来有效地生成长而连贯的响应。

model = FastLanguageModel.for_inference(model)
inputs = tokenizer(
[
    data_prompt.format(
        #instructions
        text,
        #answer
        "",
    )
], return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 5020, use_cache = True)
answer=tokenizer.batch_decode(outputs)
answer = answer[0].split("### Response:")[-1]
print("Answer of the question is:", answer)

问题的答案如下:

I'm sorry to hear that you are feeling so overwhelmed. It sounds like you are trying to figure out what is going on with you. I would suggest that you see a therapist who specializes in working with people who are struggling with depression. Depression is a common issue that people struggle with. It is important to address the issue of depression in order to improve your quality of life. Depression can lead to other issues such as anxiety, hopelessness, and loss of pleasure in activities. Depression can also lead to thoughts of suicide. If you are thinking of suicide, please call 911 or go to the nearest hospital emergency department. If you are not thinking of suicide, but you are feeling overwhelmed, please call 800-273-8255. This number is free and confidential and you can talk to someone about anything. You can also go to www.suicidepreventionlifeline.org to find a local suicide prevention hotline.<|end_of_text|>

注意:以下是我们如何安全地将经过微调的模型及其标记器推送到 Hugging Face Hub,以便任何人都可以使用: ImranzamanML/1B_finetuned_llama3.2 。

os.environ["HF_TOKEN"] = "hugging face token key, you can create from your HF account."
model.push_to_hub("ImranzamanML/1B_finetuned_llama3.2", use_auth_token=os.getenv("HF_TOKEN"))
tokenizer.push_to_hub("ImranzamanML/1B_finetuned_llama3.2", use_auth_token=os.getenv("HF_TOKEN"))

注意:我们还可以在机器本地保存微调后的模型及其标记器。

model.save_pretrained("model/1B_finetuned_llama3.2")
tokenizer.save_pretrained("model/1B_finetuned_llama3.2")

下面的代码展示了如何加载已保存的模型并使用它!

model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "model/1B_finetuned_llama3.2",
max_seq_length = 5020,
dtype = None,
load_in_4bit = True)

原文链接:Llama 3.2 微调指南 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2190794.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenCV马赛克

#马赛克 import cv2 import numpy as np import matplotlib.pyplot as pltimg cv2.imread(coins.jpg,1) imgInfo img.shape height imgInfo[0] width imgInfo[1]for m in range(200,400): #m,n表示打马赛克区域for n in range(200,400):# pixel ->10*10if m%10 0 and …

初识Linux · 文件(1)

目录 前言&#xff1a; 回顾语言层面的文件 理解文件的预备知识 文件和磁盘 使用和认识系统调用函数 前言&#xff1a; 本文以及下篇文章&#xff0c;揭露的都是Linux中文件的奥秘&#xff0c;对于文件来说&#xff0c;初学Linux第一节课接触的就是文件&#xff0c;对于C…

Windows删除service服务

Windows删除service服务 找到命令提示符&#xff1a; 右键&#xff0c;以管理员身份运行 输入&#xff1a; sc delete 服务名 Windows根据TCP端口号查找进程PID再kill进程_windows tcpkill-CSDN博客文章浏览阅读5.3k次&#xff0c;点赞42次&#xff0c;收藏104次。Windows根据…

【408计算机考研课程】数据结构-数据结构在学什么?

前言 数据结构在学什么&#xff1f; 如何用程序代码把现实世界的问题信息化如何用计算机高效地处理这些信息从而创造价值 第一章&#xff1a;数据结构在学什么&#xff1f; 总览 什么是数据&#xff1f; 简介&#xff1a;数据是信息的载体&#xff0c;是描述客观事物属性的数、…

【在Linux世界中追寻伟大的One Piece】进程信号

目录 1 -> 信号入门 1.1 -> 生活角度的信号 1.2 -> 技术应用角度的信号 1.3 -> 注意 2 -> 信号的概念 2.1 -> 用kill -l命令可以查看系统定义的信号列表 2.2 -> 信号处理常见方式 3 -> 产生信号 3.1 -> Core Dump 3.2 -> 调用系统函数…

已解决-Nacos明明成功运行,但Spring报错连接不上

这天使用windows本地nacos的时候&#xff0c;一直报错&#xff1a; Caused by: com.alibaba.nacos.api.exception.NacosException: Request nacos server failed: Caused by: com.alibaba.nacos.api.exception.NacosException: Client not connected, current status:STARTIN…

(计算机组成原理)

计算机的发展 计算机系统硬件&#xff08;计算机的实体&#xff0c;如主机&#xff0c;外设等&#xff09;软件&#xff08;由具有各种特殊功能的程序组成&#xff09; 硬件是计算机系统的物理基础&#xff0c;硬件决定瓶颈&#xff0c;软件决定性能发挥的程度 第一台电子数字计…

YOLOv4和Darknet实现坑洼检测

关于深度实战社区 我们是一个深度学习领域的独立工作室。团队成员有&#xff1a;中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等&#xff0c;曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万粉丝&#xff0c;拥有2篇国家级人工智能发明专利。 社区特色…

IDEA如何自定义创建类的文档注释

说明&#xff1a;在IDEA中&#xff0c;创建一个Java类文件&#xff0c;会在类上面自动生成文档注释&#xff0c;如下&#xff1a; 看样子&#xff0c;默认是计算机的用户名&#xff0c;然后加上当前的创建时间。可以在IDEA中的Setting中设置&#xff0c;如下&#xff1a; /*** …

汽车追尾为什么是后车的责任?

简单点说&#xff1a;因为人后面没有长眼睛。 结论 在汽车追尾事故中&#xff0c;通常情况下后车被认为是责任方的原因在于交通法规对驾驶安全标准的约定和实践中的责任识别原则。虽然追尾事故常见地被归责于后车&#xff0c;但具体判断并不是绝对的&#xff0c;仍需综合多种…

C++11中的特性

这里主要讲解一些C11相较于C98所新增的比较实用的新特性。 C11的官方文档&#xff1a;C11 - cppreference.comhttps://en.cppreference.com/w/cpp/11 一、列表初始化&#xff08;List-initialization&#xff09; &#xff08;一&#xff09;、使用“{}”进行初始化 在C98中&…

有关自连接表的统一封装

表结构 RecursionBean Getter Setter ToString JsonInclude(JsonInclude.Include.NON_EMPTY) public class RecursionBean<T> extends BaseVO {/*** 编号*/private T id;/*** 父权限ID&#xff0c;根节点的父权限为空* 注释掉JsonIgnore&#xff0c;是为了前端判断是否…

Linux驱动开发常用调试方法汇总

引言&#xff1a;在 Linux 驱动开发中&#xff0c;调试是一个至关重要的环节。开发者需要了解多种调试方法&#xff0c;以便能够快速定位和解决问题。 1.利用printk 描述&#xff1a; printk 是 Linux 内核中的一个调试输出函数&#xff0c;类似于用户空间中的 printf。它用于…

CE找CSGO人物坐标和视角基址-幽络源原创

前言 幽络源站长本次免费分享的是CE找CSGO人物坐标和视角基址 本教程分为两篇&#xff0c;当前为上篇->找基址 所具备的知识 CE的使用 教程目的 通过CE找到一些基地址&#xff0c;然后结合Python实现CSGO的透视绘制&#xff0c;这里我们是纯手写透视。 第一步&#x…

如何使用CMD命令启动应用程序(二)

说明&#xff1a;去年1024发布了一篇博客&#xff0c;介绍如何使用CMD命令启动应用程序&#xff0c;但实际情况&#xff0c;有些程序可能无法用配置环境变量的方式来启动&#xff0c;本文针对两种情况下的程序&#xff0c;如何使用CMD命令来启动&#xff0c;算是对上一篇博客的…

Java开发必知必会的一些工具

本文主要介绍 Java 程序员应该学习的一些基本和高级工具。 如果你想成为一名更好的程序员&#xff0c;最重要的技巧之一就是学习你的编程工具。 Java 世界中存在着如此多的工具&#xff0c;从 Eclipse、NetBeans 和 IntelliJ IDEA 等著名的 IDE 到 JConsole、VisualVM、Eclipse…

class 004 选择 冒泡 插入排序

我感觉这个真是没有什么好讲的, 这个是比较简单的, 感觉没有什么必要写一篇博客, 而且这个这么简单的排序问题肯定有人已经有写好的帖子了, 肯定写的比我好, 所以我推荐大家直接去看“左程云”老师的讲解就很好了, 一定是能看懂的, 要是用文字形式再写一遍, 反而有点画蛇添足了…

计算机视觉算法知识详解(含代码示例)

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…

算法: 二分查找题目练习

文章目录 二分查找二分查找在排序数组中查找元素的第一个和最后一个位置搜索插入位置x 的平方根山脉数组的峰顶索引寻找峰值寻找旋转排序数组中的最小值点名 总结精华模版 二分查找 二分查找 没啥可说的,轻轻松松~ class Solution {public int search(int[] nums, int target…

IDEA 配置 Git 详解

本文将介绍在IntelliJ IDEA 中如何配置Git 没有安装配置 Git 的可以参考我的这篇文章&#xff1a;安装配置 Git 一、操作环境及准备 1.win 10 2.已安装且配置了Git 3.有Gitee账户 4.安装了IntelliJ IDEA 2023.2.1 5.全程联网 二、配置步骤 2.1 配置git 1.采用全局设置&…