使用微调OneKE来实现识别人名、公司名、和产品名称

news2024/9/17 8:18:15

简介

       公司做个大模型助手,需要提取用户query中的人名、公司名和产品名称来进行问答。目前我使用的是bert+crf模型 开源cluer数据+自造的数据,训练数据18w,测试数据1.3w。

目前这个方案有些瓶颈,主要表现如下:

1、产品名称识别错误 有的时候会把产品名称识别很长

2、产品名称简称识别不了,比如招白,这块数据训练集里面是没有的,训练集里面的产品名称是基金全称、基金简称、以及基金代码

加上目前都是用大模型方案来做,OneKE是阿里联合浙江大学开源的,我打算尝试一下。

ONEKE 魔搭社区

OnKE

我通过官方示例改成我的样本,尝试了几个效果不太行。我就打算进行微调。

首先我的数据集太多了18w,我打算进行过滤。

我使用Dmeta-embedding-zh-small向量Embedding模型,计算所有句子(训练集+测试集)的Embedding,然后使用以下代码进行过滤。

import pickle
import numpy as np
from numpy.linalg import norm

embedding_dict = pickle.load(open('embedding_dict.pkl', 'rb'))

texts = []
embedding = []
for k, v in embedding_dict.items():
    texts.append(k)
    embedding.append(v)

need_set = set()
not_need_set = set()
A = np.array(embedding)

for i in range(len(embedding)):
    if i in not_need_set:
        continue
    B = embedding[i]
    cosine = np.dot(A, B)
    need_set.add(i)
    for index in np.where(cosine > 0.80)[0]:
        not_need_set.add(index)
    if i %1000==0:
        print(i,len(not_need_set),len(need_set))

print(len(need_set))
print(len(not_need_set))

texts_set=set()
for i,text in enumerate(texts):
    if i not in need_set:
        continue
    texts_set.add(text)

print(len(texts_set))
pickle.dump(texts_set, open('texts_set.pkl', 'wb'))

只取 在texts_set里面的句子作为训练集

最后得到train 14498  dev 6214 ( train_size=0.7,来切割原始训练集的)

微调

数据格式按照 DeepKE/example/llm/OneKE.md at main · zjunlp/DeepKE · GitHub  数据准备 部分进行准备

{"text": "相比之下,青岛海牛队和广州松日队的雨中之战虽然也是0∶0,但乏善可陈。", "entity": [{"entity": "广州松日队", "entity_type": "组织机构"}, {"entity": "青岛海牛队", "entity_type": "组织机构"}]}

然后使用转换脚本进行转换 训练数据

python ie2instruction/convert_func.py \
    --src_path data/NER/sample.json \
    --tgt_path data/NER/train.json \
    --schema_path data/NER/schema.json \
    --language zh \
    --task NER \
    --split_num 6 \       
    --random_sort \
    --split train

测试数据

python ie2instruction/convert_func.py \
    --src_path data/NER/sample.json \
    --tgt_path data/NER/test.json \
    --schema_path data/NER/schema.json \
    --language zh \
    --task NER \
    --split_num 6 \
    --split test

我是在OneKE模型上微调的,我参照github上的

output_dir='lora/oneke-continue'
mkdir -p ${output_dir}
CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1287 src/finetune.py \
    --do_train --do_eval \
    --overwrite_output_dir \
    --model_name_or_path 'models/OneKE' \
    --stage 'sft' \
    --model_name 'llama' \
    --template 'llama2_zh' \
    --train_file 'data/train.json' \
    --valid_file 'data/dev.json' \
    --output_dir=${output_dir} \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --preprocessing_num_workers 16 \
    --num_train_epochs 10 \
    --learning_rate 5e-5 \
    --max_grad_norm 0.5 \
    --optim "adamw_torch" \
    --max_source_length 400 \
    --cutoff_len 700 \
    --max_target_length 300 \
    --evaluation_strategy "epoch" \
    --save_strategy "epoch" \
    --save_total_limit 10 \
    --lora_r 64 \
    --lora_alpha 64 \
    --lora_dropout 0.05 \
    --bf16 \
    --bits 4

将example/llm/InstructKGC/ft_scripts/fine_continue_full.bash改成了

output_dir='lora/oneke-continue'
mkdir -p ${output_dir}
CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1287 src/finetune.py \
    --do_train --do_eval \
    --overwrite_output_dir \
    --model_name_or_path '/data/lifengxin/OneKE' \
    --stage 'sft' \
    --model_name 'llama' \
    --template 'llama2_zh' \
    --train_file 'data/NER/train.json' \
    --valid_file 'data/NER/dev.json' \
    --output_dir=${output_dir} \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --preprocessing_num_workers 16 \
    --num_train_epochs 10 \
    --learning_rate 5e-5 \
    --max_grad_norm 0.5 \
    --optim "adamw_torch" \
    --max_source_length 400 \
    --cutoff_len 700 \
    --max_target_length 300 \
    --evaluation_strategy "epoch" \
    --save_strategy "epoch" \
    --save_total_limit 10 \
    --lora_r 16 \
    --lora_alpha 32 \
    --lora_dropout 0.05 \
    --bf16 \
    --bits 4

然后进行训练

训练的镜像是egslingjun-registry.cn-wulanchabu.cr.aliyuncs.com/egslingjun/training-nv-pytorch:24.07  从阿里云镜像仓库里面拉的。

训练完成

-rw-r--r-- 1 root root 5.0K Sep  6 17:16 README.md
-rw-r--r-- 1 root root  726 Sep  6 17:16 adapter_config.json
-rw-r--r-- 1 root root 239M Sep  6 17:16 adapter_model.safetensors
-rw-r--r-- 1 root root  371 Sep  6 17:18 all_results.json
drwxr-xr-x 2 root root 4.0K Sep  6 15:28 checkpoint-1359
drwxr-xr-x 2 root root 4.0K Sep  6 15:44 checkpoint-1813
drwxr-xr-x 2 root root 4.0K Sep  6 15:59 checkpoint-2266
drwxr-xr-x 2 root root 4.0K Sep  6 16:15 checkpoint-2719
drwxr-xr-x 2 root root 4.0K Sep  6 16:30 checkpoint-3172
drwxr-xr-x 2 root root 4.0K Sep  6 16:45 checkpoint-3626
drwxr-xr-x 2 root root 4.0K Sep  6 17:01 checkpoint-4079
drwxr-xr-x 2 root root 4.0K Sep  6 14:58 checkpoint-453
drwxr-xr-x 2 root root 4.0K Sep  6 17:14 checkpoint-4530
drwxr-xr-x 2 root root 4.0K Sep  6 15:13 checkpoint-906
-rw-r--r-- 1 root root  179 Sep  6 17:18 eval_results.json
-rw-r--r-- 1 root root  226 Sep  6 17:16 train_results.json
-rw-r--r-- 1 root root 388K Sep  6 17:16 trainer_state.json
-rw-r--r-- 1 root root 5.3K Sep  6 17:16 training_args.bin

进行合并

 python src/export_model.py     --model_name_or_path '/data/OneKE'     --checkpoint_dir '/data/lora/oneke-continue/checkpoint-4530'     --export_dir 'data/OneKE_v1'     --stage 'sft'     --model_name 'llama'     --template 'llama2_zh'     --output_dir 'data/OneKE_v1_test'

合并后进行效果验证

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    BitsAndBytesConfig
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = '/data/OneKE_v1/'
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)


# 4bit量化OneKE
quantization_config=BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model.eval()


system_prompt = '<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n'

sintruct="{\"instruction\": \"你是专门进行实体抽取的专家。请从input中抽取出符合schema定义的实体,不存在的实体类型返回空列表。请按照JSON字符串的格式回答。\", \"schema\": [\"人名\",
\"公司名称\", \"产品名称\"], \"input\": \"164105的投资风格是什么?\"}"


sintruct = '[INST] ' + system_prompt + sintruct + '[/INST]'

input_ids = tokenizer.encode(sintruct, return_tensors="pt").to(device)
input_length = input_ids.size(1)
generation_output = model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_length=1024, max_new_tokens=512, return_dict_in_generate=True), pad_token_id=tokenize
r.eos_token_id)
generation_output = generation_output.sequences[0]
generation_output = generation_output[input_length:]
output = tokenizer.decode(generation_output, skip_special_tokens=True)

print(output)

量化

由于模型占用内存较大,因此产生了量化的想法。

使用llama.cpp进行量化

python llama.cpp/convert.py /data/OneKE_v1   --outfile OneKE_v1.gguf   --outtype q8_0

使用llama-server启动服务进行测试

  ./llama-server  -m OneKE_v1.gguf --port 80 --host 0.0.0.0

效果还可以,需要8s左右推断完

想快速推断完 尝试 tq1_0 tq2_0进行压缩

python3 convert_hf_to_gguf.py /data/OneKE_v1   --outfile OneKE_v1_tq1_0.gguf   --outtype tq1_0
python3 convert_hf_to_gguf.py /data/OneKE_v1   --outfile OneKE_v1_tq2_0.gguf   --outtype tq2_0
-rw-r--r--  1 root root  14G Sep  7 03:24 OneKE_v1.gguf
-rw-r--r--  1 root root 3.6G Sep  7 03:40 OneKE_v1_tq1_0.gguf
-rw-r--r--  1 root root 4.2G Sep  7 03:47 OneKE_v1_tq2_0.gguf
./llama-server  -m OneKE_v1_tq1_0.gguf --port 80  --host 0.0.0.0
./llama-server  -m OneKE_v1_tq2_0.gguf --port 80  --host 0.0.0.0

tq1_0 tq2_0 出不来结果,放弃了。

用autoawq进行量化 然后用GPU推断

pip install autoawq
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = '/data/OneKE_v1'
quant_path = '/data/OneKE_v1-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
~

OneKE_v1-awq 7.2G 

OneKE_v1 25G

发布

使用vllm进行发布

镜像使用 egslingjun-registry.cn-wulanchabu.cr.aliyuncs.com/egslingjun/llm-inference:vllm0.5.4-deepgpu-llm24.7-pytorch2.4.0-cuda12.4-ubuntu22.04  阿里云的

OneKE_v1-awq  占用显卡15G 推断500ms

python3 -m vllm.entrypoints.openai.api_server --model /data/OneKE_v1-awq --host 0.0.0.0 --port 80 --max-model-len 1024 --gpu-memory-utilization 0.3

OneKE_v1 模型 占用显卡 30G,推断1317ms

python3 -m vllm.entrypoints.openai.api_server --model /data/OneKE_v1-awq --host 0.0.0.0 --port 80 --max-model-len 1024 --gpu-memory-utilization 0.5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2116848.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Pyramid: Real-Time LoRa Collision Decoding with Peak Tracking技术思考与解读

一点点个人的论文解读、技术理解&#xff0c;难免会有错误&#xff0c;欢迎大家一起交流和学习~~ &#x1f600;作者关于lora的系列文章从问题陈述到方法论的提出&#xff0c;再到实验评估&#xff0c;文章结构条理清晰&#xff0c;逻辑性强&#xff0c;并深入分析了LoRa信号处…

力扣刷题(5)

整数转罗马数字 整数转罗马数字-力扣 思路&#xff1a; 把各十百千位可能出现的情况都列出来&#xff0c;写成一个二维数组找出该数的各十百千位&#xff0c;与数组中的罗马元素对应 const char* ch[4][10]{{"", "I", "II", "III"…

webpack - 五大核心概念和基本配置(打包一个简单HTML页面)

// 五大核心概念 1. entry&#xff08;入口&#xff09; 指示Webpack从哪个文件开始打包2. output&#xff08;输出&#xff09; 指示Webpack打包完的文件输出到哪里去&#xff0c;如何命名等3. loader&#xff08;加载器&#xff09; webpack本身只能处理js&#xff0c;json等…

Bev pool 加速(2):自定义c++扩展

文章目录 1. c++扩展2. 案例2.1 案例12. 1.1 代码实现(1) c++ 文件(2) setup.py编写(3) python 代码编写2.1 案例1在bevfusion论文中,将bev_pooling定义为view transform中的效率瓶颈,bevfusion 主要就是对bev_pooling进行了加速,使得视图转换的速度提高了40倍,延迟从500ms…

charles配置安卓抓包(避坑版)

下载Charleshttps://www.charlesproxy.com/安装&#xff0c;疯狂点击下一步即可注册&#xff1a;打开Charles&#xff0c;选择“Help”菜单中的“Register Charles”&#xff0c;进网站生成密钥&#xff1a;https://www.zzzmode.com/mytools/charles/,将生成的密钥填入注册重启…

JavaScript练手小技巧:利用鼠标滚轮控制图片轮播

近日&#xff0c;在浏览网站的时候&#xff0c;发现了一个有意思的效果&#xff1a;一个图片轮播&#xff0c;通过上下滚动鼠标滚轮控制图片的上下切换。 于是就有了自己做一个的想法&#xff0c;顺带复习下鼠标滚轮事件。 鼠标滚轮事件&#xff0c;参考这篇文章&#xff1a;…

Vue 3 + Element Plus 封装单列控制编辑的可编辑表格组件

在Web应用开发中&#xff0c;经常需要提供表格数据的编辑功能。本文将介绍如何使用Vue 3结合Element Plus库来实现一个支持单列控制编辑功能的表格&#xff0c;并通过封装组件的形式提高代码的复用性。通过本教程&#xff0c;你将学会如何构建一个具备单列控制编辑功能的表格组…

Cloudways搭建WordPress外贸独立站完整教程(1)

验证邮件发送完成后&#xff0c;就等待Cloudways的回复邮件&#xff0c;一般24小时之内就会收到激活的邮件。 Cloudways账号升级 激活成功后还需要账户升级&#xff0c;Cloudways提供了为期3天的免费试用体验。如果在试用期结束之前未绑定信用卡以升级账户&#xff0c;试用期…

UE5学习笔记21-武器的射击功能

一、创建C类 创建武器子弹的类&#xff0c;创建生产武器子弹的类&#xff0c;创建弹壳的类&#xff0c;生产武器子弹的类的父类是武器的类 创建后如图&#xff0c;ProjectileMyWeapon类(产生子弹的类)继承自weapon类&#xff0c;Projectile(子弹的类)&#xff0c;Casing(弹壳声…

Claude 3.5:如何高效辅助编程——全面入门指南

在现代编程世界中&#xff0c;AI的角色越来越重要&#xff0c;尤其是在代码生成、调试、文档生成等领域中&#xff0c;AI工具的运用让开发者可以更高效地完成任务。Claude 3.5是一个这样的AI助手&#xff0c;凭借其强大的自然语言处理能力&#xff0c;在编程中提供了大量的支持…

Sui Narwhal and Tusk 共识协议笔记

一、Overwiew [ 整体流程: Client提交transaction到Narwhal Mempool。(Narwhal Mempool由一组worker和一个primary组成) Mempool接收到的Transaction->以Certificate的形式进行广播 由worker将交易打包为Batch,worker将Batch的hash发送给primary primary上运行了mempo…

mysql笔记4(数据类型)

数据库的数据类型应该是数据库架构师(DBA)和产品经理沟通后依据公司的项目、业务而定的&#xff0c;而且会不停地变化。数据类型的选择方面没有一个统一的标准&#xff0c;但是应该符合业务、项目的逻辑标准。 菜鸟教程 Mysql 数据类型 文章目录 1. int类型2. 浮点数3. 定点数4…

C# Dotfuscator加密dll设置流程

按照以下步骤处理后&#xff0c;反编译基本只能看到函数名&#xff0c;看不到源代码 1.Input 2.Setting 3.Rename 4.Rename 5.Control Flow 6.String Encryption 7.Output

【stata】自写命令分享dynamic_est,一键生成dynamic effect

1. 命令简介 dynamic_est 是一个用于可视化动态效应&#xff08;dynamic effect&#xff09;的工具。它特别适用于事件研究&#xff08;event study&#xff09;或双重差分&#xff08;Difference-in-Differences, DID&#xff09;分析。通过一句命令即可展示动态效应&#xf…

EasyPlayer.js网页H5 Web js播放器能力合集

最近遇到一个需求&#xff0c;要求做一款播放器&#xff0c;发现能力上跟EasyPlayer.js基本一致&#xff0c;满足要求&#xff1a; 需求 功性能 分类 需求描述 功能 预览 分屏模式 单分屏&#xff08;单屏/全屏&#xff09; 多分屏&#xff08;2*2&#xff09; 多分屏…

JVM面试(七)G1垃圾收集器剖析

概述 上一章我们说了&#xff0c;G1收集器&#xff0c;它属于里程碑式的发展&#xff0c;开创了面向局部收集垃圾的概念。专门针对多核处理器以及大内存的机器。在JDK9中&#xff0c;更是呗指定为官方的GC收集器。满足高吞吐的通知满足GC的STW停顿时间尽可能的短。 虽然现在我…

恶意代码分析-Lab01-01

实验一 这个实验使用Lab01-01.exe和Lab01-01.d文件,使用本章描述的工具和技术来获取关于这些文件的信息。 问题: 将文件上传至 http:/www.VirusTotal.com/进行分析并查看报告。文件匹配到了已有的反病毒软件特征吗?这些文件是什么时候编译的?这两个文件中是否存在迹象说明它…

如何在docker容器中导入.sql文件

一、准备工作 确保容器运行&#xff1a; 首先确认包含 MySQL 服务的 Docker 容器正在运行。可以通过 docker ps 命令查看正在运行的容器列表。如果容器未运行&#xff0c;使用 docker start [container_id] 命令启动容器。 准备数据库文件&#xff1a; 将需要导入的数据库文件&…

VMware安装Ubuntu虚拟机

Ubuntu镜像下载 https://ubuntu.com/download/desktop 创建虚拟机 1.典型配置 2.稍后安装操作系统 3.选择操作系统&#xff0c;Linux&#xff0c;ubuntu64位 3.设置虚拟机名称和安装位置 4.磁盘大小&#xff0c;存储为单个文件 安装系统 1.选择镜像 2.开启虚拟机 2.安装Ub…

CTFHub技能树-Git泄漏-Log

目录 一、前提知识 1.git泄漏原理 ​编辑 2.git文件泄漏造成后果 3.利用方法 (1) GitHack是一个.git泄露利用脚本&#xff0c;通过泄露的.git文件夹下的文件&#xff0c;还原重建工程源代码。渗透测试人员、攻击者&#xff0c;可以进一步审计代码&#xff0c;挖掘&#x…