QLoRA代码实战

news2024/10/5 14:01:21

QLoRA原理参考:
BiliBili:4bit量化与QLoRA模型训练
zhihu:QLoRA(Quantized LoRA)详解

下载llama3-8b模型

from modelscope import snapshot_download
model_dir = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct')

设置quantization_config

from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

加载模型

加载量化后的llama3-8b模型,大概需要6G的GPU显存。

from transformers import AutoModelForCausalLM,AutoTokenizer,TrainingArguments,Trainer,DataCollatorForSeq2Seq
model = AutoModelForCausalLM.from_pretrained(model_dir,quantization_config=quantization_config,low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

一层的数据类型,可以看到除了layernorm,linear层都进行了量化。

model.layers.0.self_attn.q_proj.weight torch.uint8
model.layers.0.self_attn.k_proj.weight torch.uint8
model.layers.0.self_attn.v_proj.weight torch.uint8
model.layers.0.self_attn.o_proj.weight torch.uint8
model.layers.0.mlp.gate_proj.weight torch.uint8
model.layers.0.mlp.up_proj.weight torch.uint8
model.layers.0.mlp.down_proj.weight torch.uint8
model.layers.0.input_layernorm.weight torch.float16
model.layers.0.post_attention_layernorm.weight torch.float16

预处理模型

from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)

设置LoRA参数

这里使用了默认设置,参数target_modules和modules_to_save可以设置具体训练哪些模块。

config = LoraConfig(task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, config)
model.print_trainable_parameters()
#trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424

加载并处理数据

数据下载:AI-ModelScope/alpaca-gpt4-data-zh
需要把下载的数据中dataset_infos.json 重命名为datasets_info.json,这样才能正确加载。

from datasets import load_dataset

dataset = load_dataset("alpaca-data-zh")

def process_func(example):
    # print(example)
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    # 将prompt进行tokenize,这里我们没有利用tokenizer进行填充和截断
    # 这里我们自己进行截断,在DataLoader的collate_fn函数中进行填充
    input = example["input"] if example["input"] is not None else ''
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], input]).strip() + "\n\nAssistant: ")
    # 将output进行tokenize,注意添加eos_token
    response = tokenizer(example["output"] + tokenizer.eos_token)
    # 将instruction + output组合为input
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    # prompt设置为-100,不计算loss
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    # 设置最大长度,进行截断
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

tokenized_ds = dataset['train'].map(process_func, remove_columns=dataset['train'].column_names)

设置TrainingArguments

在per_device_train_batch_size=1的情况下,大概需要9G显存。

args = TrainingArguments(
    output_dir="./llama3_4bit",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=32,
    logging_steps=10,
    num_train_epochs=1,
    save_strategy='epoch',
    learning_rate=1e-4,
    # gradient_checkpointing=True,
    # optim="paged_adamw_32bit"

)

训练

trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train(resume_from_checkpoint=False)

加载qlora

from transformers import AutoModelForCausalLM,AutoTokenizer
model_path = model_dir #llama3-8b的路径
model = AutoModelForCausalLM.from_pretrained(model_path,quantization_config=config,low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model_qlora = PeftModel.from_pretrained(model=model,model_id="llama3_4bit/checkpoint-7") #qlora路径
#预测
ipt = tokenizer("Human: {}\n{}".format("怎么学习llm", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)
tokenizer.decode(model_qlora.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True)

合并LoRA

合并后的模型大概5.4G。

merge_model = model_qlora.merge_and_unload()
merge_model.save_pretrained("llama3")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2189971.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

wsl2 ubuntu 桥接以太网卡

注意:此方法需要至少 Windows 11 22H2。桥接模式就是将主机网卡与虚拟机虚拟的网卡利用虚拟网桥进行通信。 在桥接的作用下,类似于把宿主机虚拟为一个交换机,所有桥接设置的虚拟机连接到这个交换机的一个接口上,宿主机也同样插在这…

通信工程学习:什么是RARP反向地址解析协议

RARP:反向地址解析协议 RARP(Reverse Address Resolution Protocol,反向地址解析协议)是一种网络协议,其主要作用是在设备只知道物理地址(如MAC地址)时,允许其从网关服务器的地址解析…

致亲爱的Android studio

你的未来发展趋势: 可不可以把兼容性,什么的搞得更好。起因是我想写期末大作业,然后简单的把功能写的差不多了之后,我就想到处看看有没有一套比较好的类似于组件库的东西,但是没找到,然后就把目标锁定到了G…

Vue入门-Node.js安装

进入Node.js中文网 ​​​​​​​点击进入Node.js中文网 或者手动输入网址: https://www.nodejs.com.cn/download.html 点击下载64位安装包: 下载好之后双击进行安装 可选择个性化安装或默认安装 直接点【Next】按钮,此处可根据个人需求…

深度解析 HTTP

我的主页:2的n次方_ 1. HTTP 的简单介绍 HTTP :超文本传输协议,不仅能传输文本,还能传输图片,音频文件,视频 目前基本上都用的是 1.1 版本 https 可以认为是 http 的升级版,区别就是引入了…

【pytorch】张量求导4

再再接上文,看到作者有一个关于向量乘矩阵的描述。 经过搜索发现,现在的pytorch已经修复了这一问题,提供了mv()和matmul()两种方式实现矩阵和一维向量的乘积,可以参看这篇文章。 经过查阅pytorch的文件,找到了cuda侧…

如何利用 Kubernetes 取得成功

Kubernetes是一个开源编排平台,用于自动部署和管理容器化工作负载和服务,目前越来越受欢迎。 该平台由全球贡献者社区维护,其潜在优势包括提高资源效率、提高可扩展性和高可用性。 在过去几年中,Kubernetes 和相关的云原生技术已…

[Qt] 基于 Qt 的文件选择与图片显示功能实现

文章目录 基础版本:open1()功能解析:特点与限制: 增加路径记忆功能:open2()功能解析:特点与改进: 使用智能指针优化内存管理:open3()功能解析:特点与改进: 图片自适应窗口…

线程互斥函数的例子

代码 #include<stdio.h> #include<pthread.h> #include<sched.h> void *producter_f(void *arg); void *consumer_f(void *arg); int buffer_has_item0; pthread_mutex_t mutex; int running1; int main(void) {pthread_t consumer_t;pthread_t producter_t…

PCL 点云体素滤波

目录 一、概述 1.1原理 1.2实现步骤 1.3应用场景 二、代码实现 2.1关键函数 2.1.1 体素滤波实现 2.1.2 可视化函数 2.2完整代码 三、实现效果 PCL点云算法汇总及实战案例汇总的目录地址链接&#xff1a; PCL点云算法与项目实战案例汇总&#xff08;长期更新&#xf…

Session会话管理技术

Session会话管理技术 会话: 两个交互,在开发中是指浏览器和服务器它们两个的交互 会话管理: 管理会话中产生的数据,一般是记录登录状态 补充: 状态管理,就是管理数据 1、 Session概述 Session用于记录用户的状态。Session指的是在一段时间内&#xff0c;单个客户端与Web服务…

【C++】—— 类和对象(中)

【C】—— 类和对象(中) 文章目录 【C】—— 类和对象(中)前言1. 类的默认成员函数2. 构造函数3. 析构函数4. 拷贝构造函数5. 赋值运算符重载5.1 运算符重载5.2 赋值运算符重载 结语 前言 小伙伴们大家好呀&#xff0c;昨天的 【C】——类和对象(上) 大家理解的怎么样了 今天…

【AI人工智能】文心智能体,双人冒险游戏智能体创作分享

背景 最近半年&#xff0c;“AI agent”&#xff08;智能体&#xff09;这一词汇变得非常热门。许多人以为创建自己的智能体会很复杂&#xff0c;实际上&#xff0c;现有的平台已经大大降低了操作门槛。只要有创意&#xff0c;几乎每个人都可以轻松创建属于自己的智能体。今天…

WordPress响应式Git主题响应式CMS主题模板

兼容 IE9、谷歌 Chrome 、火狐 Firefox 等主流浏览器 扁平化的设计加响应式布局&#xff0c;兼容电脑、和各个尺寸手机的完美响应 主题设置面板新增多种AD位&#xff0c;PC端和移动设备各不相同 在主题设置选项中就可以进行基本的SEO设置&#xff1a;首页、分类、文章等页面…

跟我学C++中级篇——函数调用的本质

一、进程的执行过程 正常的情况下&#xff0c;程序会被计算机从硬盘加载到内存中&#xff0c;然后跳转到主入口函数进行执行。依次按照逻辑对相关的模块进行加载调用。其中&#xff0c;最常用的就是调用一个函数&#xff0c;可以说&#xff0c;函数是C/C程序中的一个重要的基础…

Python——异常处理机制

Python 异常处理机制 Python异常与异常处理机制针对 Traceback 的解读try-except-else-finallyexcept语句except语句的机制在 except 语句中引用当前被处理的 Python 异常 finally语句finally语句执行后才能抛出未被处理的异常finally中执行return会导致异常丢失 raise 语句rai…

集合框架02:Collection使用(1)

视频链接&#xff1a;13.05 Collection使用&#xff08;1&#xff09;_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1zD4y1Q7Fw?p5&vd_sourceb5775c3a4ea16a5306db9c7c1c1486b5 代码示例&#xff1a; package com.yundait.Demo01;import java.util.ArrayList; i…

hbuilderx+uniapp+Android健身房管理系统 微信小程序z488g

目录 项目介绍支持以下技术栈&#xff1a;具体实现截图HBuilderXuniappmysql数据库与主流编程语言java类核心代码部分展示登录的业务流程的顺序是&#xff1a;数据库设计性能分析操作可行性技术可行性系统安全性数据完整性软件测试详细视频演示源码获取方式 项目介绍 用户功能…

震撼!工业史上第一家万级别规模的工业数字化设备效果图平台

耗时八年打造&#xff0c;国内第一家万级别规模的工业数字化设备效果图平台 平台&#xff1a;www.kingview3d.cn 创作者&#xff1a;kingview3d郭工 行业&#xff1a;煤矿综合自动化、污水处理、净水处理、楼宇暖通、环保工程、医药废水处理、二供、无负压加压站、提升泵站、一…

模拟器GSN3之DHCP动态分配IP地址配置案例

前文《详解DHCP服务工作原理及配置案例》介绍了DHCP服务工作原理&#xff0c;要想彻底理解、应用DHCP服务&#xff0c;须通过实证案例学习&#xff0c;该文在GSN3虚拟环境下&#xff0c;构建DHCP服务的环境。 一、配置环境&#xff1a; 1、GSN3 2、路由器&#xff1a;R1、R2…