微调小型Llama 3.2(十亿参数)模型取代GPT-4o

news2024/10/19 12:15:36
微调Llama VS GPT-4o

别忘了关注作者,关注后您会变得更聪明,不关注就只能靠颜值了 ^_^。

一位年轻的儿科医生与一位经验丰富的医师,谁更能有效治疗婴儿的咳嗽?

两者都具备治疗咳嗽的能力,但儿科医生由于专攻儿童医学,或许在诊断婴儿疾病方面更具优势。这也正如小模型在某些特定任务上的表现,往往经过微调后能够比大型模型更为出色,尽管大型模型号称可以处理任何问题。

最近,我面临了一个必须在两者之间做出选择的场景。

我正在开发一个查询路由系统,用于将用户的请求引导至合适的部门,然后由人工继续对话。从技术角度看,这是一个文本分类任务。虽然GPT-4o及其小版本在这类任务上表现优秀,但它的使用成本较高,且由于是封闭模型,我无法在自己的环境中进行微调。尽管OpenAI提供了微调服务,但对我来说,成本仍然过于昂贵。

每百万个Token的训练费用为25美元,而我的训练数据量很快就达到了数百万个Token。再加上微调后的模型使用费用比普通模型高50%,这对我的小型项目而言,预算无疑是无法承受的。因此,我必须寻找一个替代方案。

相比之下,开源模型在处理分类任务时同样表现不俗,且训练成本相对较低,尤其是在使用GPU时。经过慎重考虑,我决定转向小型模型。小型LLM通过微调可以在有限的预算下实现令人满意的效果,这是我目前最为理想的选择。

小型模型可以在普通硬件上运行,微调所需的GPU也不必过于昂贵。更为重要的是,小模型的训练和推理速度远快于大型LLM。

经过一番调研,我挑选了几款候选模型——Phi3.5、DistillBERT和GPT-Neo,但最终选择了Meta Llama 3.2的1B模型。这个选择并非完全理性,部分原因可能是最近关于这个模型的讨论较多。不过,实践出真知,我决定通过实测来检验效果。

在接下来的部分,我将分享我微调Llama 3.2–1B指令模型与使用少样本提示的GPT-4o的对比结果。

微调Llama 3.2 1B模型(免费实现微调)

微调模型的确可能需要较高的成本,但如果选择合适的策略,还是能够大幅降低开支。针对我的情况,我采用了参数优化的微调(PEFT)策略,而不是完全参数微调。完全微调会重新训练模型中的全部1B参数,成本太高,且可能导致“灾难性遗忘”,即模型丢失预训练时学到的部分知识。而PEFT策略则聚焦于仅微调部分参数,大大减少了时间和资源的消耗。

其中,“低秩适应”(LORA)技术是目前较为流行的微调方法。LORA允许我们仅对某些特定层的部分参数进行微调,这样的训练不仅高效且效果明显。

此外,通过模型量化,我们可以将模型的参数压缩为float16甚至更小的格式,这不仅减少了内存消耗,还能提高计算速度。当然,精度可能会有所下降,但对于我的任务来说,这一折衷是可以接受的。

接下来,我将在免费的Colab和Kaggle平台上进行了微调。这些平台提供的GPU资源虽然有限,但对于像我这样的小模型训练任务已经足够,关键它们免费。

Llama-3.2微调与GPT-4o少样本提示的对比

微调Llama 3.2 1B模型的过程相对简单。我参考了Unsloth提供的Colab笔记本,并做了部分修改。原笔记本微调的是3B参数的模型,而我将其改为1B参数的Llama-3.2–Instruct,因为我想测试较小模型在分类任务上的表现。接着,我将数据集替换为我自己的数据,用于训练。

# Before

from unsloth.chat_templates import standardize_sharegpt

dataset = standardize_sharegpt(dataset)

dataset = dataset.map(formatting_prompts_func, batched = True,)

# After

from datasets import Dataset

dataset = Dataset.from_json("/content/insurance_training_data.json")

dataset = dataset.map(formatting_prompts_func, batched = True,)

最稳妥的做法是选择一个与笔记本初始设计相符的数据集,例如下面的这个。

{

"conversations": [

{'role': 'user', 'content': <user_query>}

{'role': 'assistant', 'content': <department>}

]

}

到这里为止,这两处调整已经足够让你用自己的数据微调模型了。

评估微调后的模型

接下来是关键的一步:评估测试。

评估LLM是一项广泛且富有挑战性的工作,也是LLM开发中最为重要的技能之一。我将再出一篇文章,在其中详细讨论过如何评估LLM应用,别忘了关注作者,关注后您会变得更聪明,不关注就只能靠颜值了 ^_^

不过,为了简洁起见,这次我会采用经典的混淆矩阵方式进行评估。只需在笔记本的末尾添加下面的代码即可。

from langchain.prompts import FewShotPromptTemplate

from langchain_openai import ChatOpenAI

from langchain_core.prompts import PromptTemplate

from pydantic import BaseModel

# 1. A function to generate response with the fine-tuned model

def generate_response(user_query):

# Enable faster inference for the language model

FastLanguageModel.for_inference(model)

# Define the message template

messages = [

{"role": "system", "content": "You are a helpful assistant who can route the following query to the relevant department."},

{"role": "user", "content": user_query},

]

# Apply the chat template to tokenize the input and prepare for generation

tokenized_input = tokenizer.apply_chat_template(

messages,

tokenize=True,

add_generation_prompt=True, # Required for text generation

return_tensors="pt"

).to("cuda") # Send input to the GPU

# Generate a response using the model

generated_output = model.generate(

input_ids=tokenized_input,

max_new_tokens=64,

use_cache=True, # Enable cache for faster generation

temperature=1.5,

min_p=0.1

)

# Decode the generated tokens into human-readable text

decoded_response = tokenizer.batch_decode(generated_output, skip_special_tokens=True)[0]

# Extract the assistant's response (after system/user text)

assistant_response = decoded_response.split("\n\n")[-1]

return assistant_response

# 2. Generate Responeses with OpenAI GPT-4o

# Define the prompt template for the example

example_prompt_template = PromptTemplate.from_template(

"User Query: {user_query}\n{department}"

)

# Initialize OpenAI LLM (ensure the OPENAI_API_KEY environment variable is set)

llm = ChatOpenAI(temperature=0, model="gpt-4o")

# Define few-shot examples

examples = [

{"user_query": "I recently had an accident and need to file a claim for my vehicle. Can you guide me through the process?", "department": "Claims"},

...

]

# Create a few-shot prompt template

few_shot_prompt_template = FewShotPromptTemplate(

examples=examples,

example_prompt=example_prompt_template,

prefix="You are an intelligent assistant for an insurance company. Your task is to route customer queries to the appropriate department.",

suffix="User Query: {user_query}",

input_variables=["user_query"]

)

# Define the department model to structure the output

class Department(BaseModel):

department: str

# Function to predict the appropriate department based on user query

def predict_department(user_query):

# Wrap LLM with structured output

structured_llm = llm.with_structured_output(Department)

# Create the chain for generating predictions

prediction_chain = few_shot_prompt_template | structured_llm

# Invoke the chain with the user query to get the department

result = prediction_chain.invoke(user_query)

return result.department

# 3. Read your evaluation dataset and predict departments

import json

with open("/content/insurance_bot_evaluation_data (1).json", "r") as f:

eval_data = json.load(f)

for ix, item in enumerate(eval_data):

print(f"{ix+1} of {len(eval_data)}")

item['open_ai_response'] = generate_response(item['user_query'])

item['llama_response'] = item['open_ai_response']

# 4. Compute the precision, recall, accuracy, and F1 scores for the predictions.

# 4.1 Using Open AI

from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score

true_labels = [item['department'] for item in eval_data]

predicted_labels_openai = [item['open_ai_response'] for item in eval_data]

# Calculate the scores for open_ai_response

precision_openai = precision_score(true_labels, predicted_labels_openai, average='weighted')

recall_openai = recall_score(true_labels, predicted_labels_openai, average='weighted')

accuracy_openai = accuracy_score(true_labels, predicted_labels_openai)

f1_openai = f1_score(true_labels, predicted_labels_openai, average='weighted')

print("OpenAI Response Scores:")

print("Precision:", precision_openai)

print("Recall:", recall_openai)

print("Accuracy:", accuracy_openai)

print("F1 Score:", f1_openai)

# 4.2 Using Fine-tuned Llama 3.2 1B Instruct

true_labels = [item['department'] for item in eval_data]

predicted_labels_llama = [item['llama_response'] for item in eval_data]

# Calculate the scores for llama_response

precision_llama = precision_score(true_labels, predicted_labels_llama, average='weighted', zero_division=0)

recall_llama = recall_score(true_labels, predicted_labels_llama, average='weighted', zero_division=0)

accuracy_llama = accuracy_score(true_labels, predicted_labels_llama)

f1_llama = f1_score(true_labels, predicted_labels_llama, average='weighted', zero_division=0)

print("Llama Response Scores:")

print("Precision:", precision_llama)

print("Recall:", recall_llama)

print("Accuracy:", accuracy_llama)

print("F1 Score:", f1_llama)

以上代码非常清晰明了。我们编写了一个函数,利用微调后的模型进行部门预测。同时,也为OpenAI GPT-4o构建了一个类似的函数。

接着,我们使用这些函数对评估数据集生成预测结果。

评估数据集中包含了预期的分类,现在我们也获得了模型生成的分类,这为接下来的指标计算提供了基础。

接下来,我们将进行这些计算。

以下是结果:

OpenAI Response Scores:

Precision: 0.9

Recall: 0.75

Accuracy: 0.75

F1 Score: 0.818

Llama Response Scores:

Precision: 0.88

Recall: 0.73

Accuracy: 0.79

F1 Score: 0.798

结果显示,微调后的模型表现几乎接近GPT-4o。对于一个只有1B参数的小型模型来说,这已经相当令人满意了。

尽管GPT-4o的表现确实更好,但差距非常微小。

此外,如果在少样本提示中提供更多示例,GPT-4o的结果可能会进一步提升。不过,由于我的示例有时比较长,甚至包括几段文字,这会显著增加成本,毕竟OpenAI是按输入Token计费的。

总结

我现在对小型LLM非常认可。它们运行速度快,成本低,而且在大多数使用场景中都能满足需求,尤其是在不进行微调的情况下。

在这篇文章中,我讨论了如何微调Llama 3.2 1B模型。该模型可以在较为普通的硬件上运行,而且微调成本几乎为零。我当前的任务是文本分类。

当然,这并不意味着小型模型能够全面超越像GPT-4o这样的巨型模型,甚至也不一定能胜过Meta Llama的8B、11B或90B参数的模型。较大的模型拥有更强的多语言理解能力、视觉指令处理能力,以及更加广泛的世界知识。

我的看法是,如果这些“超级能力”不是你当前的需求,为什么不选择一个小型LLM呢?”

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2218459.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据中台业务架构图

数据中台的业务架构是企业实现数据驱动决策和业务创新的关键支撑。它主要由数据源层、数据存储与处理层、数据服务层以及数据应用层组成。 数据源层涵盖了企业内部各个业务系统的数据&#xff0c;如 ERP、CRM 等&#xff0c;以及外部数据来源&#xff0c;如社交媒体、行业数据…

2024年9月中国电子学会青少年软件编程(Python)等级考试试卷(一级)答案 + 解析

一、单选题 1、下列选项中关于 turtle.color(red) 语句的作用描述正确的是&#xff1f;&#xff08; &#xff09; A. 只设置画笔的颜色为红色 B. 只设置填充的颜色为红色 C. 设置画笔和填充的颜色为红色 D. 设置画笔的颜色为红色&#xff0c;设置画布背景的颜色为红色 正…

基于langchain.js快速搭建AI-Agent

基于langchain.js快速搭建AI-Agent 什么是AIAgent? 1. 替换默认请求地址为自定义API 构建基础会话大模型 import { ChatOpenAI } from langchain/openai;const chat new ChatOpenAI({model: gpt4o,temperature: 0,apiKey: ****,configuration: {baseURL: https://www.xx.co…

[含文档+PPT+源码等]精品大数据项目-基于python实现的社交媒体用户活跃时间预测系统

大数据项目——基于Python实现的社交媒体用户活跃时间预测系统的背景可以从以下几个方面进行详细阐述&#xff1a; 一、项目背景与意义 随着互联网技术的快速发展和社交媒体的普及&#xff0c;社交媒体平台已经成为人们日常生活中不可或缺的一部分。每天都有数以亿计的用户在…

Golang笔记_day08

Go面试题&#xff08;一&#xff09; 1、空切片 和 nil 切片 区别 空切片&#xff1a; 空切片是指长度和容量都为0的切片。它不包含任何元素&#xff0c;但仍然具有切片的容量属性。在Go语言中&#xff0c;可以使用内置的make函数创建一个空切片&#xff0c;例如&#xff1a;…

[Godot4] 水底气泡的 gdshader

水底气泡的 gdshader 来自 shadertoy 的代码 在这里&#xff0c;我添加了 x 方向和 y 方向上的 uv 位移 但是还是感觉太弱智 shader_type canvas_item; // Created by greenbird10 // License Creative Commons Attribution-NonCommercial-ShareAlike 3.0uniform float bub…

C语言笔记(指针的进阶)

目录 1.字符指针 2.指针数组 3.数组指针 3.1.创建数组指针 3.2.&数组名和数组名 1.字符指针 int main() { char ch w;char* pc &ch;const char *p "abcdef";//常量字符串 产生的值就是首元素的地址//常量字符串不能被修改 因此需要加上一个…

go 环境安装

最近搭建AIGC大模型聚合平台&#xff0c;涉及到了go语言&#xff0c;随手整理一下环境安装步骤分享给大家。 1、安装 官网下载地址&#xff1a;https://go.dev/ 1.1 Linux 安装 yum install git -y yum install golang -y yum install gcc -y # 日志工具&#xff0c;如需要…

Web保存状态的手段(请求转发,Cookie的使用)

一&#xff0c;掌握请求转发 请求转发与重定向技术都是跳转页面的途径&#xff0c;但是这两个技术之间也有不同之处。 请求转发更倾向于servlet跳转jsp&#xff0c;而重定向更倾向于servlet跳转到servlet。 1. 常用页面跳转方法2:请求转发(重写URL) RequestDispatcher接口对…

基于SpringBoot+Vue+uniapp微信小程序的教学质量评价系统的详细设计和实现

项目运行截图 技术框架 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念&#xff0c;提供了一套默认的配置&#xff0c;让开发者可以更专注于业务逻辑而不是配置文件。Spring Boot 通过自动化配置和约…

细胞力学需测量,多种方法齐上场,优劣互补要明了

大家好&#xff01;今天我们来了解细胞力学方法的比较研究——《A comparison of methods to assess cell mechanical properties》发表于《Nature Methods》。细胞力学对细胞的多种功能至关重要&#xff0c;然而不同测量方法得到的结果差异较大。本次研究选取了MCF-7细胞&…

用Java爬虫API,轻松获取taobao商品SKU信息

在电子商务的世界里&#xff0c;SKU&#xff08;Stock Keeping Unit&#xff0c;库存单位&#xff09;是商品管理的基础。对于商家来说&#xff0c;SKU的详细信息对于库存管理、价格策略制定、市场分析等都有着重要作用。taobao作为中国最大的电子商务平台之一&#xff0c;提供…

JavaSE——集合4:List接口实现类—LinkedList

目录 一、LinkedList的全面说明 二、LinkedList的底层操作机制 (一)LinkedList添加结点源码 (二)LinkedList删除结点源码 三、LinkedList常用方法 四、ArrayList与LinkedList的选择 一、LinkedList的全面说明 LinkedList底层实现了双向链表和双端队列的特点可以添加任意…

【热门】用ChatGPT做智慧农业云平台——农业ERP管控系统

随着科技的进步,原有农业种植方式已经不能满足社会发展的需要,必须对传统的农业进行技术更新和改造。经过多年的实践,人们总结出一种新的种植方法——温室农业,即“用人工设施控制环境因素,使作物获得最适宜的生长条件,从而延长生产季节,获得最佳的产出”。这种农业生产方式…

“智改数转”转了什么?

万界星空科技专门针对数字化改造申报的MES系统具有显著的技术优势和实施效果&#xff0c;能够为制造型企业提供全方位、高效、可靠的数字化转型支持。项目合作可以私信或者百度上海万界星空科技官网。 “智改数转”是一个综合性的过程&#xff0c;涉及企业多个方面的转型和升…

随机抽取学号

idea 配置 抽学号 浏览器 提交一个100 以内的整数。&#xff0c;后端接受后&#xff0c;根据提供的整数&#xff0c;产生 100 以内的 随机数&#xff0c;返回给浏览器&#xff1f; 前端&#xff1a;提供 随机数范围 &#xff0c;病发送请求后端&#xff1a;处理随机数的产生&…

C 408—《数据结构》算法题基础篇—链表(下)

目录 Δ前言 一、两个升序链表归并为一个降序链表 0.题目&#xff1a; 1.算法设计思想&#xff1a; 2.C语言描述&#xff1a; 3.算法的时间和空间复杂度&#xff1a; 二、两个链表的所有相同值结点生成一个新链表 0.题目&#xff1a; 1.算法设计思想&#xff1a; 2.C语言描述…

DDD重构-实体与限界上下文重构

DDD重构-实体与限界上下文重构 概述 DDD 方法需要不同类型的类元素&#xff0c;例如实体或值对象&#xff0c;并且几乎所有这些类元素都可以看作是常规的 Java 类。它们的总体结构是 Name: 类的唯一名称 Properties&#xff1a;属性 Methods: 控制变量的变化和添加行为 一…

MySQL中 truncate、drop和delete的区别

MySQL中 truncate、drop和delete区别 truncate 执行速度快&#xff0c;删除所有数据&#xff0c;但是保留表结构不记录日志事务不安全&#xff0c;不能回滚可重置自增主键计数器 drop 执行速度较快&#xff0c;删除整张表数据和结构不记录日志事务不安全&#xff0c;不能回…

JavaWeb——Maven(3/8):配置Maven环境(当前工程,全局),创建Maven项目

目录 配置Maven环境 当前工程 全局 创建Maven项目 配置Maven环境 当前工程 选择 IDEA中 File --> Settings --> Build,Execution,Deployment --> Build Tools --> Maven 设置 IDEA 使用本地安装的 Maven&#xff0c;并修改配置文件及本地仓库路径 首先在 IDE…