基于huggingface peft进行qwen1.5-7b-chat训练/推理/服务发布

news2024/9/20 12:16:36

一、huggingface peft微调框架

1、定义

PEFT 是一个大型预训练模型提供多种高效微方法的Python库。

调传统范式是针对每个下游任模型参数。大模型参数大,种方式得极其昂和不切实际PEFT采用的高效做法是训练少量提示参数(Prompt Tuning)或使用低秩适(LORA)等重新参数化方法来减少微调时训练参数的数量。

二、qwen-1.5b-chat模型训练/推理/服务

1、基础环境准备

datasets==2.21.0

transformers==4.37.0

torch==1.13.0

accelerate==0.30.1

peft==0.4.0

numpy==1.26.4

Jinja2==3.1.4

2、人设定制数据准备

[

    {

        "instruction": "你是谁?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

    },

    {

        "instruction": "你是什么?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

    },

    {

        "instruction": "请问您是?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

    },

    {

        "instruction": "你叫什么?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

},

     {

        "instruction": "你的身份是?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

    }

]

2、模型训练

from datasets import Dataset

import pandas as pd

from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig

# JSON文件转换为CSV文件

df = pd.read_json('./train.json')

ds = Dataset.from_pandas(df)

model_path = './huggingface/model/Qwen1.5-7B-Chat'

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

def process_func(example):

    MAX_LENGTH = 384   

    input_ids, attention_mask, labels = [], [], []

    instruction = tokenizer(f"<|im_start|>system\n现在你要扮演人工智能智能客服助手--小飞同学<|im_end|>\n<|im_start|>user\n{example['instruction'] + example['input']}<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False

    response = tokenizer(f"{example['output']}", add_special_tokens=False)

    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]

    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1

    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]

    if len(input_ids) > MAX_LENGTH:  # 做一个截断

        input_ids = input_ids[:MAX_LENGTH]

        attention_mask = attention_mask[:MAX_LENGTH]

        labels = labels[:MAX_LENGTH]

    return {

        "input_ids": input_ids,

        "attention_mask": attention_mask,

        "labels": labels

    }

tokenized_id = ds.map(process_func, remove_columns=ds.column_names)

import torch

model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto",torch_dtype=torch.bfloat16)

model.enable_input_require_grads()

from peft import LoraConfig, TaskType, get_peft_model

config = LoraConfig(

    task_type=TaskType.CAUSAL_LM,

    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],

    inference_mode=False, # 训练模式

    r=8, # Lora

    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理

    lora_dropout=0.1# Dropout 比例

)

model = get_peft_model(model, config)

args = TrainingArguments(

    output_dir="./output",

    per_device_train_batch_size=4,

    gradient_accumulation_steps=4,

    logging_steps=10,

    num_train_epochs=10,

    save_steps=50,

    learning_rate=1e-4,

    save_on_each_node=True,

    gradient_checkpointing=True

)

trainer = Trainer(

    model=model,

    args=args,

    train_dataset=tokenized_id,

    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),

)

trainer.train()

模型输出目录截图:

3、模型推理

from transformers import AutoModelForCausalLM, AutoTokenizer

import torch

from peft import PeftModel

model_path = './huggingface/model/Qwen1.5-7B-Chat'

lora_path = './output/checkpoint-50'

# 加载tokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path)

# 加载模型

model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto",torch_dtype=torch.bfloat16)

from peft import LoraConfig, TaskType

config = LoraConfig(

    task_type=TaskType.CAUSAL_LM,

    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],

    inference_mode=True, # 训练模式

    r=8, # Lora

    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理

    lora_dropout=0.1# Dropout 比例

)

# 加载lora权重

model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)

prompt = "你是星火大模型吗?"

messages = [

    {"role": "system", "content": "现在你要扮演人工智能智能客服助手--小飞同学"},

    {"role": "user", "content": prompt}

]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

model_inputs = tokenizer([text], return_tensors="pt").to('cuda')

generated_ids = model.generate(

    input_ids=model_inputs.input_ids,

    max_new_tokens=512

)

generated_ids = [

    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)

]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)

模型推理日志截图:

4、基于FastAPI的sse协议模型服务

import uvicorn

from fastapi import FastAPI

from transformers import AutoModelForCausalLM, AutoTokenizer ,TextStreamer,TextIteratorStreamer

from threading import Thread

import torch

from peft import LoraConfig, TaskType, PeftModel

from sse_starlette.sse import EventSourceResponse

import json

# transfomershuggingface提供的一个工具,便于加载transformer结构的模型

app = FastAPI()

def load_model():

    model_path = './huggingface/model/Qwen1.5-7B-Chat'

    # 加载tokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # 加载模型(加速库attn_implementation="flash_attention_2"

    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto",torch_dtype=torch.bfloat16

    # 加载lora权重

    lora_path = './output/checkpoint-50'

    config = LoraConfig(

        task_type=TaskType.CAUSAL_LM,

        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],

        inference_mode=True, # 训练模式

        r=8, # Lora

        lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理

        lora_dropout=0.1# Dropout 比例

    )

    model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)

    return tokenizer,model

tokenizer,model = load_model()

def infer_model(tokenizer,model):

    prompt = "你是星火大模型吗?"

    messages = [

        {"role": "system", "content": "现在你要扮演人工智能智能客服助手--小飞同学"},

        {"role": "user", "content": prompt}

    ]

    #数据提取

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    model_inputs = tokenizer([text], return_tensors="pt").to('cuda')

    #streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    #模型推理

    from threading import Thread

    generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)

    thread.start()

    for res in streamer:

        yield json.dumps({"data":res},ensure_ascii=False)

@app.get('/predict')

async def predict():

    #return infer_model(tokenizer,model)

    return EventSourceResponse(infer_model(tokenizer,model))

if __name__ == '__main__':

    # 在调试的时候开源加入一个reload=True的参数,正式启动的时候可以去掉

    uvicorn.run(app, host="0.0.0.0", port=6605, log_level="info")

客户端调用示例:

import json

import requests

import time

def listen_sse(url):

    # 发送GET请求到SSE端点

    with requests.get(url, stream=True, timeout=20) as response:

        try:

            # 确保请求成功

            response.raise_for_status()

            # 逐行读取响应内容

            result = ""

            for line in response.iter_lines():

                if line:

                    event_data = line.decode('utf-8')

                    if event_data.startswith('data:'):

                        # 去除'data:'前缀,获取实际数据

                        line = event_data.lstrip('data:')

                        line_data = json.loads(line)

                        result += line_data["data"]

                        print(result)

       except requests.exceptions.HTTPError as err:

            print(f"HTTP error: {err}")

        except Exception as err:

            print(f"An error occurred: {err}")

            return

sse_url = 'http://127.0.0.1:6605/predict'

listen_sse(sse_url

服务推理流式输出截图:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2082458.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Datawhale X 李宏毅苹果书 AI夏令营 task2

《深度学习详解》 - 自适应学习率&#xff08;Task2&#xff09; 1. 自适应学习率的背景与重要性 学习率的挑战&#xff1a; 在训练深度学习模型时&#xff0c;选择合适的学习率至关重要。过大的学习率会导致训练过程中的震荡&#xff0c;使模型无法收敛&#xff1b;过小的学…

在 Navicat BI 中创建自定义字段:自定义排序顺序

在 Navicat BI 中&#xff0c;数据源引用你连接中的表或文件/ODBC 源中的数据&#xff0c;并可从不同服务器类型的中选择数据。数据集中的字段可用于构建图表。事实上&#xff0c;在构建图表时&#xff0c;你需要指定用于填充图表的数据源。 正如我们在整个系列中所看到的&…

html+css网页设计 个人网站模版 个人博客12个页面

htmlcss网页设计 个人网站模版 个人博客12个页面 网页作品代码简单&#xff0c;可使用任意HTML编辑软件&#xff08;如&#xff1a;Dreamweaver、HBuilder、Vscode 、Sublime 、Webstorm、Text 、Notepad 等任意html编辑软件进行运行及修改编辑等操作&#xff09;。 获取源码…

【Material-UI】Radio Group中的独立单选按钮详解

文章目录 一、Radio 组件概述1. 组件介绍2. 基本用法 二、Radio 组件的关键特性1. 选中状态控制2. 关联标签3. 自定义样式和图标4. 使用 FormControlLabel 提供标签支持 三、Radio 组件的实际应用场景1. 表单中的单选题2. 设置选项3. 导航选择 四、注意事项1. 无障碍支持2. 样式…

开源产品GeoMesa、MobilityDB存在哪些不足

友情链接&#xff1a; •时空数据库系列&#xff08;一&#xff09;什么是时空数据&#xff1f;特征和适用场景有哪些&#xff1f; •时空数据库系列&#xff08;二&#xff09;时空数据库介绍 了解数据模型与应用场景 •时空数据库系列&#xff08;三&#xff09;技术讲解&…

Linux网口指令

一 查看配置 ifconfig 二 修改IP sudo ifconfig ens33 192.168.150.100 netmask 255.255.255.0

一键复制模板,乔拓云助力小程序快速上线

选择乔拓云模板开发小程序&#xff0c;成本低且高效&#xff0c;适合各行业快速搭建。注册账号后&#xff0c;进入模板中心&#xff0c;轻松找到匹配行业的模板。模板内容自定义灵活&#xff0c;图片、文字随心修改&#xff0c;右侧编辑区操作直观。 小程序开发步骤概览&#x…

秋招复习笔记——嵌入式裸机开发

底层相关的内容&#xff0c;之前掌握的不扎实&#xff0c;现在重新把相关重点记录一下&#xff0c;做个笔记记诵。 相关基础知识 ST简单内容 用的F103ZET6&#xff0c;72MHz&#xff0c;FLASH是512KB&#xff0c;SRAM是64KB&#xff0c;144个引脚&#xff0c;2基本定时器&am…

Java 入门指南:Java IO流 —— 字符流

何为Java流 Java 中的流&#xff08;Stream&#xff09; 是用于在程序中读取或写入数据的抽象概念。流可以从不同的数据源&#xff08;输入流&#xff09;读取数据&#xff0c;也可以将数据写入到不同的目标&#xff08;输出流&#xff09;。流提供了一种统一的方式来处理不同…

【深入解析】最优控制中的Bellman方程——从决策到最优路径的探索

【深入解析】最优控制中的Bellman方程——从决策到最优路径的探索 关键词提炼 #Bellman方程 #最优控制 #动态规划 #值函数 #策略优化 #强化学习 第一节&#xff1a;Bellman方程的通俗解释与核心概念 1.1 通俗解释 Bellman方程是动态规划中的一个核心概念&#xff0c;它像是…

apache服务器的配置(服务名httpd,端口80 , 443)

目录 前言 配置文件 apache服务器的配置 安装apache服务器 配置防火墙 编辑配置文件 配置虚拟主机 基于域名的虚拟主机 配置dns服务器 将网站文件放到/var/www/目录下 修改主配置文件 新建vhost文件夹和xxx.conf文件 编辑 .conf 文件 检查配置 重启服务并访问网…

VS2022 QT环境显示中文乱码问题

1.问题描述 在VS2022中搭配QT6.2环境&#xff0c;在文本处设置中文&#xff0c;运行程序文本处显示乱码&#xff0c;未成功显示想要的中文。 2.VS2015解决方案 如果是VS2015的话&#xff0c;直接文件->高级保存选项可以设置编码格式。 修改编码格式如图所示&#xff1a;…

2024 Python3.10 系统入门+进阶(九):封装解构和集合Set常用操作详解

目录 一、封装和解构1.1 基本概念1.2 简单解构1.3 剩余变量解构1.4 嵌套解构1.5 其他解构1.6 序列模式匹配&#xff08;Python 3.10 最引人注目的新功能&#xff09;1.6.1 结构模式匹配的核心概念1.6.2 结构模式匹配的优势1.6.3 使用场景 二、集合Set2.1 初始化2.1.1 "{}&…

Java-数据结构-包装类和认识泛型 !!!∑(゚Д゚ノ)ノ

目录&#xff1a; 一、包装类&#xff1a; 1、基本数据类型所对应的包装类&#xff1a; 2、装箱和拆箱&#xff1a; 二、 泛型&#xff1a; 1、什么是泛型&#xff1a; 2、语法&#xff1a; 三、泛型类的使用&#xff1a; 四、裸类型&#xff1a; 五、泛型的擦除机制&…

82、k8s的service-NodePort端口开放和生命周期

0、单节点服务&#xff0c;以及k8s命令 [rootmaster01 ~]# kubectl create deployment nginx1 --imagenginx:1.22 --replicas3[rootmaster01 ~]# kubectl create deployment nginx1 --imagenginx:1.22 ##创建资源 deployment.apps/nginx1 created[rootmaster01 opt]# kubec…

软件设计原则之依赖倒置原则

依赖倒置原则&#xff08;Dependency Inversion Principle, DIP&#xff09;是软件设计中一个非常重要的原则&#xff0c;它属于面向对象设计的SOLID原则之一。这个原则的核心在于通过抽象来降低模块间的耦合度&#xff0c;使得系统更加灵活和可维护。 目录 依赖倒置原则的基本…

对标GPT-4o,科大讯飞正以大模型重塑语音产业

每个科技时代&#xff0c;都有每个时代的“入口”和“推手”。 在PC时代&#xff0c;浏览器和搜索引擎是主要入口&#xff0c;用户通过键盘和鼠标进行交互。移动互联时代&#xff0c;APP和应用商店成为典型入口&#xff0c;用户用手指和触摸屏进入互联网世界。而在眼下的AI时代…

8月27c++

提示并输入一个字符串&#xff0c;统计字符串中字母、数字、空格和其他字符的个数 代码 #include <iostream> #include <cstring> using namespace std;int main() {string str;cout<<"输入一个字符串";getline(cin,str);//输入字符串int lenstr…

【vulhub】Weblogic WLS Core Components 反序列化命令执行漏洞(CVE-2018-2628)

简单来说就是先用序列化工具ysoserial启动一个JRMP服务&#xff0c;加载先相关漏洞利用链&#xff0c;加载你要执行的恶意代码。 并将上述结果通过序列化工具ysoserial将我们的恶意代码进行一个序列化操作。 第二步就是将我们的exp去加载ysoserial序列化后的数据&#xff0c;后…

vue侧边栏

在Vue中创建一个侧边栏&#xff08;Sidebar&#xff09;是一个常见的需求&#xff0c;特别是在构建管理界面或需要导航菜单的应用时。侧边栏通常用于展示应用的导航链接或菜单项&#xff0c;用户可以通过点击这些链接来访问应用的不同部分。 <template><el-tree :data…