在Codelab对llama3做Lora Fine tune微调

news2024/10/6 10:42:37

Unsloth 高效微调大模型的工具,通过Unsloth微调Llama3, Mistral, Gemma 速度提升2-5倍,内存减少70%!

Codelab 创建一个jupyter notebook

在这里插入图片描述
选择 T4 GPU
在这里插入图片描述
安装Fine tune 相关的lib

%%capture
import torch
major_version, minor_version= torch.cuda.get_device_capability()
# Must install separately since Colab has torch 2.2.1, which breaks packages
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
if major_version >= 8:
  # Use this for new GPs like Ampere, Hopper GPUs(RTX 30xx. RIX 40xx, A100. H100. L40)
  !pip install -no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
else:
  # Use this for older GPUs (V100, Tesla T4, RTX 20xx)
  !pip install --no-deps xformers trl peft accelerate bitsandbytes
pass

下载llama3

from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False

# 4bit pre quantized models we support for 4x faster downloading + no OOMs
fourbit_models = [
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/mistral-7b-instruct-bnb-4bit",
    "unsloth/llama-2-7b-bnb-4bit",
    "unsloth/gemma-7b-bnb-4bit",
    "unsloth/gemma-7b-it-bnb-4bit",
    "unsloth/gemma-2b-bnb-4bit",
    "unsloth/gemma-2b-it-bnb-4bit",
    "unsloth/llama-3-8b-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf

)

在这里插入图片描述

model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none", # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False, # We support rank stabilized LoRA
    loftq_config = None # And LoftQ
)

在这里插入图片描述

加载hugging face数据集

alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}
"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
  instructions = examples["instruction"]
  inputs = examples["input"]
  outputs = examples["output"]
  texts = []
  for instruction, input, output in zip(instructions, inputs, outputs):
    # Must add EOS_TOKEN, otherwise your generation will go on forever!
    text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
    texts.append(text)
  return { "text": texts, }
pass

from datasets import load_dataset
dataset = load_dataset("pinzhenchen/alpaca-cleaned-zh", split="train")
dataset = dataset.map(formatting_prompts_func, batched=True,)



在这里插入图片描述
HuggingFace 官网, 点击数据集 Datasets

在这里插入图片描述
搜索数据集 alpaca-cleaned-zh
在这里插入图片描述
复制数据集的名字 pinzhenchen/alpaca-cleaned-zh
在这里插入图片描述
定义training 方法

from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

打印显存使用情况

#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = (gpu_stats.name). Max memory = (max_memory) GB.")
print(f"(start_gpu_memory) GB of memory reserved.")

在这里插入图片描述
开始FineTune

trainer_stats = trainer.train()

在这里插入图片描述

#@title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory*100, 3)
lora_percentage = round(used_memory_for_lora / max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} GB.")

在这里插入图片描述
用fineTune 过的model,做问答

# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
    [
        alpaca_prompt.format(
            "如何保持健康", # instruction
            "", # input
            "", # output - leave this blank for generation!
        )
    ], return_tensors = "pt"
).to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 64, use_cache=True)
tokenizer.batch_decode(outputs)

在这里插入图片描述
TextStreamer 流式一个字一个字地打印结果

# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
    [
        alpaca_prompt.format(
            "续写这段话", # instruction
            "天天向上,好好学习", # input
            "", # output - leave this blank for generation!
        )
    ], return_tensors = "pt"
).to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens=128)

在这里插入图片描述
保存model到google drive 和 HuggingFace

model.save_pretrained("lora_model") # local saving
model.push_to_hub("zgpeace/lora_model", token="####") # online saving

google drive
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1647109.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软考中级-软件设计师(九)数据库技术基础 考点最精简

一、基本概念 1.1数据库与数据库系统 数据:是数据库中存储的基本对象,是描述事物的符号记录 数据库(DataBase,DB):是长期存储在计算机内、有组织、可共享的大量数据集合 数据库系统(DataBas…

Navicat for MySQL Mac:数据库管理与开发的理想工具

Navicat for MySQL Mac是一款功能强大的数据库管理与开发工具,专为Mac用户设计,旨在提供高效、便捷的数据库操作体验。 它支持创建、管理和维护MySQL和MariaDB数据库,通过直观的图形界面,用户可以轻松进行数据库连接、查询、编辑和…

力扣:221. 最大正方形

221. 最大正方形 在一个由 0 和 1 组成的二维矩阵内,找到只包含 1 的最大正方形,并返回其面积。 示例 1: 输入:matrix [["1","0","1","0","0"],["1","0"…

数字藏品平台遭受科技攻击时的防护策略与攻击类型判定

随着区块链技术和数字经济的飞速发展,数字藏品平台逐渐成为炙手可热的投资领域。然而,这也使其成为了黑客攻击的重要目标。本文将深入探讨数字藏品平台可能遭遇的几种主要科技攻击类型,并提出相应的防护措施和判定方法。 一、51%攻击 攻击描…

太速科技-FMC377_双AD9361 射频收发模块

FMC377_双AD9361 射频收发模块 FEATURES: ◆ Coverage from 70M ~ 6GHz RF ◆ Flexible rate 12 bit ADC/DAC ◆ Fully-coherent 4x4 MIMO capability, TDD/FDD ◆ RF ports: 50Ω Matched ◆ support both internal reference and exter…

cmake进阶:变量的作用域说明二(从函数作用域方面)

一. 简介 前一篇文章从函数作用域方面学习了 变量的作用域。文章如下: cmake进阶:变量的作用域-CSDN博客 本文继续从函数作用域方面学习了 变量的作用域。 二. 变量的作用域 1. 函数内定义与外部同名的变量 向顶层 CMakeLists.txt添加如下代码&a…

leetcode-缺失的第一个正整数-96

题目要求 思路 1.这里的题目要求刚好符合map和unordered_map 2.创建一个对应map把元素添加进去,用map.find(res)进行查找,如果存在返回指向该元素的迭代器,否则返回map::end()。 代码实现 class Solution { public:int minNumberDisappeare…

项目管理-干系人管理

项目管理:每天进步一点点~ 活到老,学到老 ヾ(◍∇◍)ノ゙ 何时学习都不晚,加油 1.干系人管理-主要框架 重点内容: ①ITTO 输入,输出工具和技术。 ②问题和解决方案。 ③论文可以结合范围&am…

[Linux][网络][TCP][三][超时重传][快速重传][SACK][D-SACK][滑动窗口]详细讲解

目录 1.超时重传1.什么是超时重传?2.超时时间是如何确定的? 2.快速重传3.SACK4.D-SACK1.ACK丢失2.网络延迟 5.滑动窗口0.问题抛出1.发送方的滑动窗口2.如何表示发送方的四个部分?3.接收方的滑动窗口4.滑动窗口的完善理解 1.超时重传 1.什么是…

14【PS作图】像素画尺寸大小

【背景介绍】本节介绍像素图多大合适 下图是160*144像素大小,有一个显示文本的显示器,还有一个有十几个键的键盘 像素画布尺寸 电脑16像素,但还有一个显示屏 下图为240*160 在场景素材,和对话素材中,用的是不同尺寸的头像,对话素材中的头像会更清楚,尺寸会更大 远处…

五月节放假作业讲解

目录 作业1: 问题: 结果如下 作业2: 结果: 作业1: 初始化数组 问题: 如果让数组初始化非0数会有问题 有同学就问了,我明明已经初始化定义过了,为啥还有0呀 其实这种初始化只会改变第一个…

服务攻防-数据库安全RedisCouchDBH2database未授权访问CVE漏洞

#知识点: 1、数据库-Redis-未授权RCE&CVE 2、数据库-Couchdb-未授权RCE&CVE 3、数据库-H2database-未授权RCE&CVE#章节点: 1、目标判断-端口扫描&组合判断&信息来源 2、安全问题-配置不当&CVE漏洞&弱口令爆破 3、复现对象-数…

强化学习:时序差分法【Temporal Difference Methods】

强化学习笔记 主要基于b站西湖大学赵世钰老师的【强化学习的数学原理】课程,个人觉得赵老师的课件深入浅出,很适合入门. 第一章 强化学习基本概念 第二章 贝尔曼方程 第三章 贝尔曼最优方程 第四章 值迭代和策略迭代 第五章 强化学习实例分析:GridWorld…

《Video Mamba Suite》论文笔记(4)Mamba在时空建模中的作用

原文翻译 4.4 Mamba for Spatial-Temporal Modeling Tasks and datasets.最后,我们评估了 Mamba 的时空建模能力。与之前的小节类似,我们在 Epic-Kitchens-100 数据集 [13] 上评估模型在zero-shot多实例检索中的性能。 Baseline and competitor.ViViT…

解决Win10家庭版找不到组策略gpedit.msc的·方法

因为电脑出问题,一开机就会自动开启ie浏览器,所以就想找有没有方法解决,然后就了解到了gpedit.msc的作用以及相关的一些方法,也是为之后也许有人遇到相同的问题有个提供方法的途径。 首先我们直接运行gpedit.msc 是找不到的&…

Win10彻底关闭Antimalware Service Executable解决cpu内存占用过高问题

1,win键R打开运行输入gpedit.msc,即可打开本地组策略编辑器 2.依次打开:管理模板----windows组件----windows Defender-----实时保护 3.然后鼠标双击右侧的“不论何时启用实时保护,都会启用进程扫描。勾选 已禁用,就可…

embedding介绍和常用三家模型对比

Embedding(嵌入)是一种在计算机科学中常用的技术,尤其是在自然语言处理(NLP)领域。在NLP中,embedding通常指的是将文本中的单词、短语或句子转换为固定维度的向量(vector)。这些向量代表了文本中的语义和上下文信息。 1.embedding 介绍 1.1 为什么需要Embedding? 在…

基于鸢尾花数据集的四种聚类算法(kmeans,层次聚类,DBSCAN,FCM)和学习向量量化对比

基于鸢尾花数据集的四种聚类算法(kmeans,层次聚类,DBSCAN,FCM)和学习向量量化对比 注:下面的代码可能需要做一点参数调整,才得到所有我的运行结果。 kmeans算法: import matplotlib.pyplot a…

JavaScript之数据类型(2)——复杂类型(object)

object的介绍: 我对于object的理解是和C/C中的结构体一样,是一个自定义的数据类型,我们可以通过多个简单的数据类型来定义一个便于我们使用的新的数据类型。 在网上某佬对于其解释如下: Object类型,我们也称为一个对象…

ubuntu 安装单节点HBase

下载HBase mkdir -p /home/ellis/HBase/ cd /home/ellis/HBase/ wget https://downloads.apache.org/hbase/2.5.8/hbase-2.5.8-bin.tar.gz tar -xvf hbase-2.5.8-bin.tar.gz安装java jdk sudo apt install openjdk-11-jdksudo vim /etc/profileexport JAVA_HOME/usr/lib/jvm/…