【llm 微调code-llama 训练自己的数据集 一个小案例】

news2024/9/22 5:41:06

这也是一个通用的方案,使用peft微调LLM。

准备自己的数据集

根据情况改就行了,jsonl格式,三个字段:context, answer, question

import pandas as pd
import random
import json


data = pd.read_csv('dataset.csv')
train_data = data[['prompt','Code']]
train_data = train_data.values.tolist()

random.shuffle(train_data)


train_num = int(0.8 * len(train_data))

with open('train_data.jsonl', 'w') as f:
    for d in train_data[:train_num]:
        d = {
            'context':'',
            'question':d[0],
            'answer':d[1]
        }
        f.write(json.dumps(d)+'\n')
with open('val_data.jsonl', 'w') as f:
    for d in train_data[train_num:]:
        d = {
            'context':'',
            'question':d[0],
            'answer':d[1]
        }
        f.write(json.dumps(d)+'\n')

初始化

from datetime import datetime
import os
import sys

import torch

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)
from transformers import (AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM,
                          TrainingArguments, Trainer, DataCollatorForSeq2Seq)

# 加载自己的数据集
from datasets import load_dataset

train_dataset = load_dataset('json', data_files='train_data.jsonl', split='train')
eval_dataset = load_dataset('json', data_files='val_data.jsonl', split='train')

# 读取模型
base_model = 'CodeLlama-7b-Instruct-hf'

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True
)

tokenizer = AutoTokenizer.from_pretrained(base_model)

微调前的效果

tokenizer.pad_token = tokenizer.eos_token
prompt = """You are programming coder.

Now answer the question:

{}"""
prompts = [prompt.format(train_dataset[i]['question']) for i in [1,20,32,45,67]]

model_input = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")


model.eval()
with torch.no_grad():
    outputs = model.generate(**model_input, max_new_tokens=300)
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

print(outputs)

进行微调

tokenizer.add_eos_token = True
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"


def tokenize(prompt):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        padding=False,
        return_tensors=None,
    )

    # "self-supervised learning" means the labels are also the inputs:
    result["labels"] = result["input_ids"].copy()

    return result


def generate_and_tokenize_prompt(data_point):
    full_prompt =f"""You are a powerful programming model. Your job is to answer questions about a database. You are given a question.

You must output the code that answers the question.

### Input:
{data_point["question"]}

### Response:
{data_point["answer"]}
"""
    return tokenize(full_prompt)


tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)


model.train() # put model back into training mode
model = prepare_model_for_int8_training(model)

config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=[
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
if torch.cuda.device_count() > 1:
    model.is_parallelizable = True
    model.model_parallel = True



batch_size = 128
per_device_train_batch_size = 32
gradient_accumulation_steps = batch_size // per_device_train_batch_size
output_dir = "code-llama-ft"

training_args = TrainingArguments(
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=100,
        max_steps=400,
        learning_rate=3e-4,
        fp16=True,
        logging_steps=10,
        optim="adamw_torch",
        evaluation_strategy="steps", # if val_set_size > 0 else "no",
        save_strategy="steps",
        eval_steps=20,
        save_steps=20,
        output_dir=output_dir,
        load_best_model_at_end=False,
        group_by_length=True, # group sequences of roughly the same length together to speed up training
        report_to="none", # if use_wandb else "none", wandb
        run_name=f"codellama-{datetime.now().strftime('%Y-%m-%d-%H-%M')}", # if use_wandb else None,
    )

trainer = Trainer(
    model=model,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    args=training_args,
    data_collator=DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)

开始训练

model.config.use_cache = False

old_state_dict = model.state_dict
model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(
    model, type(model)
)
if torch.__version__ >= "2" and sys.platform != "win32":
    print("compiling the model")
    model = torch.compile(model)
trainer.train()

进行测试

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

base_model = 'CodeLlama-7b-Instruct-hf'
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_model)


output_dir = "code-llama-ft"
model = PeftModel.from_pretrained(model, output_dir)


eval_prompt = """You are a powerful programming model. Your job is to answer questions about a database. You are given a question.

You must output the code that answers the question.

### Input:
Write a function in Java that takes an array and returns the sum of the numbers in the array, or 0 if the array is empty. Except the number 13 is very unlucky, so it does not count any 13, or any number that immediately follows a 13.

### Response:
"""

model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

model.eval()
with torch.no_grad():
    outputs = model.generate(**model_input, max_new_tokens=100)[0]
print(tokenizer.decode(outputs, skip_special_tokens=True))

主要参考icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/660933421

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1398400.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度解析 Compose 的 Modifier 原理 -- DrawModifier

" Jetpack Compose - - Modifier 系列文章 " 📑 《 深入解析 Compose 的 Modifier 原理 - - Modifier、CombinedModifier 》 📑 《 深度解析 Compose 的 Modifier 原理 - - Modifier.composed()、ComposedModifier 》 📑 《 深入解…

牛客小白月赛86 解题报告 | 珂学家 | 最大子数组和变体 + lazy线段树动态区间树

前言 整体评价 终于回归小白月赛的内核了&#xff0c;希望以后也继续保持&#xff0c;_. A. 水盐平衡 思路: 模拟 题目保证没有浓度相等的情况 盐度 a/b&#xff0c; c/d 的比较关系 演变为 ad, bc 两者的大小关系 #include <bits/stdc.h>using namespace std;int …

【北京】买套二手房需要多少钱?

上次我们看了苏州和上海的二手房&#xff0c;这次我们一起来看下北京的二手房价格如何。 数据来源 数据来自贝壳二手房&#xff0c;每个区最多获取了3千条房源信息&#xff0c;数据共计4万条左右 对数据感兴趣的朋友&#xff0c;公众号后台发送北京二手房获取数据文件 各区房…

面试之Glide如何绑定Activity的生命周期

Glide绑定Activity生命周期 Glide.with() 下面都是它的重载方法&#xff0c;Context&#xff0c;Activity&#xff0c;FragmentActivity, Fragment, android.app.Fragment fragment,View都可以作为他的参数&#xff0c;内容大同小异&#xff0c;都是先getRetriever&#xff0…

【C++入门到精通】智能指针 shared_ptr 简介及C++模拟实现 [ C++入门 ]

阅读导航 引言一、简介二、成员函数三、使用示例四、C模拟实现五、std::shared_ptr的线程安全问题六、总结温馨提示 引言 在 C 动态内存管理中&#xff0c;除了 auto_ptr 和 unique_ptr 之外&#xff0c;还有一种智能指针 shared_ptr&#xff0c;它可以让多个指针共享同一个动…

关于大模型学习中遇到的3

来源&#xff1a;网络 Embedding模型 随着大型语言模型的发展&#xff0c;以ChatGPT为首&#xff0c;涌现了诸如ChatPDF、BingGPT、NotionAI等多种多样的应用。公众大量地将目光聚焦于生成模型的进展之快&#xff0c;却少有关注支撑许多大型语言模型应用落地的必不可少的Embed…

STM32407用汇顶的GT911触摸芯片调试实盘

这个配置很关键 代码 #include "stm32f4xx.h" #include "GT9147.h" #include "Touch.h" #include "C_Touch_I2C.h" #include "usart.h" #include "delay.h" #include "LCD.h" #incl…

HarmonyOS 页面跳转控制整个界面的转场动画

好 本文 我们来说 页面间的转场动画 就是 第一个界面到另一个界面 第一个界面的退场和第二个界面的进场效果 首先 我这里 创建了两个页面文件 Index.ets和AppView.ets index组件 编写代码如下 import router from "ohos.router" Entry Component struct Index {b…

视频监控需求记录

记录一下最近要做的需求&#xff0c;我个人任务还是稍微比较复杂的 需求&#xff1a;需要实现一个视频实时监控、视频回放、视频设备管理&#xff0c;以上都是与组织架构有关 大概的界面长这个样子 听着需求好像很简单&#xff0c;但是~我们需要在一个界面上显示两个厂商的视…

STM32标准库开发——串口发送/单字节接收

USART基本结构 串口发送信息 启动串口一的时钟 RCC_APB2PeriphClockCmd(RCC_APB2Periph_USART1,ENABLE);初始化对应串口一的时钟&#xff0c;引脚&#xff0c;将TX引脚设置为复用推挽输出。 RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOA,ENABLE); GPIO_InitTypeDef GPIO_In…

我们应该了解的⽤户画像

当我们谈⽤户画像时&#xff0c;到底在谈什么 对于互联⽹公司来说&#xff0c;企业的增⻓、内容、活动、产品等⼯作基本上都是围绕着“⽤户”来做的&#xff0c;可以说都是在做“⽤户运营”这个⼯作&#xff0c;⽽⽤户画像是⽤户运营⼯作中⾮常重要的⼀环 ⽤户画像的主要特征是…

Linux命令手册

简介 Multics&#xff08;大而全&#xff09;项目失败&#xff0c;吸取教训启动Unix&#xff08;小而精&#xff09;&#xff0c;Linus Benedict Torvalds受Unix启发开发初始版本Linux内核&#xff0c;Git也由其开发&#xff0c;目的是为了更好的管理Linux内核开发。Unix是商业…

Windows如何部署TortoiseSVN客户端

文章目录 前言1. TortoiseSVN 客户端下载安装2. 创建检出文件夹3. 创建与提交文件4. 公网访问测试 前言 TortoiseSVN是一个开源的版本控制系统&#xff0c;它与Apache Subversion&#xff08;SVN&#xff09;集成在一起&#xff0c;提供了一个用户友好的界面&#xff0c;方便用…

专业130+总分380+哈尔滨工程大学810信号与系统考研经验水声电子信息与通信

今年专业课810信号与系统130&#xff0c;总分380顺利考上哈尔滨工程大学&#xff0c;一年的努力终于换来最后的录取&#xff0c;期中复习有得有失&#xff0c;以下总结一下自己的复习经历&#xff0c;希望对大家有帮助&#xff0c;天道酬勤&#xff0c;加油&#xff01;专业课&…

Java找二叉树的公共祖先

描述&#xff1a; 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;一个节…

[AI]文心一言出圈的同时,NLP处理下的ChatGPT-4.5最新资讯

前言 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家&#xff1a;https://www.captainbed.cn/z ChatGPT体验地址 文章目录 前言4.5key价格泄漏ChatGPT4.0使用地址ChatGPT正确打开方式最新功能语音助手存档…

【C语言】编译和链接深度剖析

文章目录 &#x1f4dd;前言&#x1f320; 翻译环境和运行环境&#x1f309;翻译环境 &#x1f320;预处理&#xff08;预编译&#xff09;&#x1f309;编译 &#x1f320;词法分析&#x1f320;语法分析 &#x1f309;语义分析&#x1f320;汇编 &#x1f309; 链接&#x1f…

动态闪图怎么在线合成?仅需三秒在线合成

GIF闪图是一种常见的动态图像格式&#xff0c;它由多个静态图像帧组成&#xff0c;以连续的方式播放&#xff0c;形成动画效果。每个图像帧都可以包含不同的颜色和透明度&#xff0c;因此GIF闪图通常用于展示简单的动画、表情符号或者短视频片段。这种格式在网络上广泛应用&…

论rtp协议的重要性

rtp ps流工具 rtp 协议&#xff0c;实时传输协议&#xff0c;为什么这么重要&#xff0c;可以这么说&#xff0c;几乎所有的标准协议都是国外创造的&#xff0c;感叹一下&#xff0c;例如rtsp协议&#xff0c;sip协议&#xff0c;webrtc&#xff0c;都是以rtp协议为基础&#…

C++中的static(静态)

2014年1月19日 内容整理自The Cherno:C系列 2014年1月20日 内容整理自《程序设计教程&#xff1a;用C语言编程 第三版》 陈家骏 郑滔 -----------------------------------------------------------------------------------------------------------------------------…