NLP(六十三)使用Baichuan-7b模型微调人物关系分类任务

news2024/11/15 21:41:46

任务介绍

  人物关系分类指的是对文本中的两个人物,在特定的关系列表中,判断他们之间的人物关系。以样本亲戚 1837年6月20日,威廉四世辞世,他的侄女维多利亚即位。为例,其中亲戚为人物关系,威廉四世为实体1,维多利亚为实体2。
  笔者自己利用业余时间标注的样本数据有3881条,分布如下图:
人物关系分布图
  对上述数据集进行划分,训练集与测试集的比例为8:2,其中训练集3105条,测试集776条。
  在文章NLP(二十一)人物关系抽取的一次实战中,当时的标注数据为2900多条,使用BERT向量提取+BiGRU+Attention模型,取得的平均F1值为78.97%.
  在文章NLP(四十二)人物关系分类的再次尝试中,借助BERT微调(当作分类任务),取得的平均F1值为82.69%.
  在文章NLP(四十五)R-BERT在人物关系分类上的尝试及Keras代码复现中,借助专用于关系分类任务的R-BERT模型,在Chinese Roberta模型上取得的F1值为85.35%.
  在本文中,将尝试使用大模型(Large Language Model, LLM)中的中文模型代表Baichuan-7b, 对人物关系分类任务进行微调,看看它的表现。

好的提示

  在开始模型微调之前,我们需要一个好的提示(Prompt),我们借助GPT-4:
GPT-4给出的关于关系分类的Prompt
  别小看了Prompt的威力,笔者在使用微调模型过程中,发现自己写的Prompt与GPT-4给出的Prompt,在训练结果F1值上可能相差3-4%。可见Prompt工程的重要性!

模型微调

  我们使用上述Prompt,加工数据集(当作多轮对话任务),格式如下:

{
  "conversation_id": 1,
  "category": "relation classification",
  "conversation": [
    {
      "human": "给定以下标签:['不确定', '夫妻', '父母', '兄弟姐妹', '上下级', '师生', '好友', '同学', '合作', '同一个人', '情侣', '祖孙', '同门', '亲戚'],请在以下句子中分析并分类实体之间的关系:'与李源澄论戴东原书'在这个句子中,戴东原和李源澄之间的关系应该属于哪个标签?",
      "assistant": "不知道"
    }
  ]
}

  使用Firefly框架进行模型微调,访问网址为:https://github.com/yangjianxin1/Firefly.本文基于Baichuan-7b为基座模型,采用QLora方式训练,训练参数如下:

{
    "output_dir": "output/firefly-baichuan-7b-people",
    "model_name_or_path": "/home/test/baichun_7b",
    "train_file": "./data/train.jsonl",
    "num_train_epochs": 5,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 2,
    "learning_rate": 2e-4,
    "max_seq_length": 256,
    "logging_steps": 100,
    "save_steps": 100,
    "save_total_limit": 1,
    "lr_scheduler_type": "constant_with_warmup",
    "warmup_steps": 100,
    "lora_rank": 64,
    "lora_alpha": 16,
    "lora_dropout": 0.05,

    "gradient_checkpointing": true,
    "disable_tqdm": false,
    "optim": "paged_adamw_32bit",
    "seed": 42,
    "fp16": true,
    "report_to": "tensorboard",
    "dataloader_num_workers": 0,
    "save_strategy": "steps",
    "weight_decay": 0,
    "max_grad_norm": 0.3,
    "remove_unused_columns": false
}

  使用命令行torchrun --nproc_per_node=2 train_qlora.py --train_args_file train_args/qlora/baichuan-7b-sft-qlora.json进行训练,训练时间大约20分钟,最终的train loss为0.0273。
  在Firefly框架中设置好merge_lora.py中的模型文件路径,将adapter的权重与Baichuan-7b模型合并,合并得到新文件firefly-baichuan-7b-people-merge
  在Firefly框架,仿造script/chat/single_chat.py文件,将其改写成API调用方式的文件single_chat_server.py,代码如下:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: single_chat_server.py
# @time: 2023/7/25 22:27
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# 单轮对话web服务
from flask import Flask, request, jsonify

app = Flask("single_chat_server")


@app.route('/people_rel_cls', methods=['POST'])
def predict():
    req_dict = request.json
    text, people1, people2 = req_dict["text"], req_dict["people1"], req_dict["people2"]
    text = text.strip()
    content = f"给定以下标签:['不确定', '夫妻', '父母', '兄弟姐妹', '上下级', '师生', '好友', '同学', " \
              f"'合作', '同一个人', '情侣', '祖孙', '同门', '亲戚']," \
              f"请在以下句子中分析并分类实体之间的关系:'{text}'" \
              f"在这个句子中,{people1}{people2}之间的关系应该属于哪个标签?"
    print(content)
    input_ids = tokenizer(content, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
            top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id
        )
    outputs = outputs.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs)
    print(outputs, response)
    response = response.strip().replace(text, "").replace('</s>', "").replace('<s>', "").strip()
    return jsonify({"result": response})


if __name__ == '__main__':
    model_name = "/home/test/Firefly/script/checkpoint/firefly-baichuan-7b-people-merge"
    max_new_tokens = 5
    top_p = 0.9
    temperature = 0.01
    repetition_penalty = 1.0
    device = 'cuda:0'
    input_pattern = '<s>{}</s>'
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # llama不支持fast
        use_fast=False if model.config.model_type == 'llama' else True
    )
    print("model loaded!")
    app.run(host="0.0.0.0", port=5000, threaded=True)

  使用API调用方式可以对测试集进行模型评估。

结果比对

  不同模型(包括BERT时代前后的模型方法)的评估结果(均为当时模型的SOTA结果或接近SOTA结果)如下:

模型方法基座模型F1值说明
BERT向量提取+BiGRU+AttentionBiGRU+Attention78.97%BERT模型作为特征提取处理
BERT cls finetuningBERT82.69%当作文本分类任务处理
R-BERTchinese-roberta-wwm-ext85.35%BERT时代的关系分类模型代表
R-BERTchinese-roberta-wwm-ext-large87.22%BERT时代的关系分类模型代表
QLoraBaichuan-7b88.25%其它参数上文给出,epoch=5
QLoraBaichuan-7b89.15%其它参数上文给出,epoch=10

存在问题

  在大模型时代中,大模型突破了以前NLP任务的范畴,走向了更加通用化,从上述结果中,我们也不难发现,大模型(Baichuan-7b)在传统的NLP任务(如笔者自己的人物关系数据集)上取得了更好的结果,达到了新的SOTA,这是符合我们的认知的。
  但在笔者模型微调过程中,也发现了不少的问题或有待于进一步验证的地方,记录如下:

  • Baichuan-7b模型取得了SOTA结果,但同样的训练框架和训练参数,Baichuan-13B-Base模型却表现惨淡,甚至很差
  • 不同的Prompt对于训练结果的影响,比如笔者自己写的Prompt和GPT-4写的Prompt对于最终结果差距较大,相差3-4%
  • 相同的模型,采用full, lora, qlora三种形式进行SFT,训练结果会有何不同

  后续笔者将尝试使用不同的训练框架进行Baichuan-13B-Base的微调。

总结分享

  本文主要介绍如何使用Baichuan-7b模型微调人物关系分类任务,并比BERT时代的模型取得了进步,达到了新的SOTA.
  本文的想法很朴素,主要是想测试下LLM在传统NLP人物上的表现,也是对于笔者自己的人物关系数据集的一次效果提升,这也是笔者一直在关注和构建的数据集。这一次,大模型再一次让我震惊!
  本文使用的人物关系数据集已开源至HuggingFace Datasets, 网址为: https://huggingface.co/datasets/jclian91/people_relation_classification .
  本人的个人博客网址为:https://percent4.github.io/ ,欢迎大家关注~

参考文献

  1. NLP(二十一)人物关系抽取的一次实战: https://percent4.github.io/2023/07/08/NLP%EF%BC%88%E4%BA%8C%E5%8D%81%E4%B8%80%EF%BC%89%E4%BA%BA%E7%89%A9%E5%85%B3%E7%B3%BB%E6%8A%BD%E5%8F%96%E7%9A%84%E4%B8%80%E6%AC%A1%E5%AE%9E%E6%88%98/
  2. NLP(四十二)人物关系分类的再次尝试: https://percent4.github.io/2023/07/10/NLP%EF%BC%88%E5%9B%9B%E5%8D%81%E4%BA%8C%EF%BC%89%E4%BA%BA%E7%89%A9%E5%85%B3%E7%B3%BB%E5%88%86%E7%B1%BB%E7%9A%84%E5%86%8D%E6%AC%A1%E5%B0%9D%E8%AF%95/
  3. NLP(四十五)R-BERT在人物关系分类上的尝试及Keras代码复现: https://percent4.github.io/2023/07/10/NLP%EF%BC%88%E5%9B%9B%E5%8D%81%E4%BA%94%EF%BC%89R-BERT%E5%9C%A8%E4%BA%BA%E7%89%A9%E5%85%B3%E7%B3%BB%E5%88%86%E7%B1%BB%E4%B8%8A%E7%9A%84%E5%B0%9D%E8%AF%95%E5%8F%8AKeras%E4%BB%A3%E7%A0%81%E5%A4%8D%E7%8E%B0/
  4. 微调百川Baichuan-13B保姆式教程,手把手教你训练百亿大模型: https://mp.weixin.qq.com/s/ZBY6kbogHjbCQvZBzNEqag
  5. HuggingFace Dataset people_relation_classification: https://huggingface.co/datasets/jclian91/people_relation_classification

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/810849.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vins调试的注意事项

1、摄像头的内参和畸变矫正系数 这个系数不对&#xff0c;没法做&#xff0c;因为下一步没法做对。这个会导致系统无法初始化。 2、对畸变的像素点&#xff0c;求得归一化坐标的方法 理解不同矫正模型的原理&#xff0c;确保矫正对了&#xff0c;得到z1平面的去畸变点。 3、摄…

python皮卡丘编程代码教程,用python打印皮卡丘

大家好&#xff0c;小编来为大家解答以下问题&#xff0c;如何用print函数打印一只皮卡丘&#xff0c;用python如何打印丘比特之心&#xff0c;现在让我们一起来看看吧&#xff01;

CCL 2023 电信网络诈骗案件分类评测-第一名方案

1 任务内容 1.1 任务背景 2022年12月1日起&#xff0c;新出台的《反电信网络诈骗犯罪法》正式施行&#xff0c;表明了我国治理当前电信网络诈骗乱象的决心。诈骗案件分类问题是打击电信网路诈骗犯罪过程中的关键一环&#xff0c;根据不同的诈骗方式、手法等将其分类&#xff…

13个ChatGPT类实用AI工具汇总

在ChatGPT爆火后&#xff0c;各种工具如同雨后春笋一般层出不穷。以下汇总了13种ChatGPT类实用工具&#xff0c;可以帮助学习、教学和科研。 01 / ChatGPT for google/ 一个浏览器插件&#xff0c;可搭配现有的搜索引擎来使用 最大化搜索效率&#xff0c;对搜索体验的提升相…

【机器学习】Linear Regression

Model Representation 1、问题描述2、表示说明3、数据绘图4、模型函数5、预测总结附录 1、问题描述 一套 1000 平方英尺 (sqft) 的房屋售价为300,000美元&#xff0c;一套 2000 平方英尺的房屋售价为500,000美元。这两点将构成我们的数据或训练集。面积单位为 1000 平方英尺&a…

C++ 类和对象篇(零) 面向过程 和 面向对象

目录 一、面向过程 二、面向对象 三、两种编程思想的比较 四、C和C 一、面向过程 1.是什么&#xff1f; 是一种以解决问题的过程为中心的编程思想。即先分析出解决问题所需要的步骤&#xff0c;然后用函数把这些步骤一步一步实现。 2.为什么&#xff1f; 面向过程就纯粹是分析…

基于x-scan扫描线的3D模型渲染算法

基于x-scan算法实现的z-buffer染色。c#语言&#xff0c;.net core framework 3.1运行。 模型是读取3D Max的obj模型。 x-scan算法实现&#xff1a; public List<Vertex3> xscan() {List<Vertex3> results new List<Vertex3>();SurfaceFormula formula g…

SAP 自定义BADI增强点

应用场景 标准化代码中预留客制化部分&#xff0c;保证代码主体完整性&#xff0c;可以在预留增强位置预留两种类型的增强处理&#xff0c;其一为标准增强类型的&#xff0c;增强部分代码属于增加的逻辑&#xff0c;其二对于部分多样化的逻辑&#xff0c;使用优先执行默认逻辑&…

Java常用API:Object、Objects、包装类

Object类API toString 返回字符串类型 equals 默认比较的是地址 此时返回的是 false 可以在类中重写equals 方法 比较内容 如果内容一样就返回true clone 不能在测试类中用&#xff0c;必须在创建的类中重写克隆方法 还必须要有接口&#xff0c;说明这个对象有这个能力克隆 …

增量预训练baichuan-13b-chat遇到的那些坑

文章目录 前言资源deepspeed一、训练的坑二、推理的坑三、继续训练的坑总结前言 资源 单机两4090,如图 单卡24G,baichuan-13b-chat单卡推理需要至少26G,因此仅用一张卡,我们是无法加载百川13B的模型,所以,无论是推理还是训练,我们都必须并行! deepspeed 核心思想…

主干网络篇 | YOLOv8 更换主干网络之 VanillaNet |《华为方舟实验室最新成果》

论文地址:https://arxiv.org/pdf/2305.12972.pdf 代码地址:https://github.com/huawei-noah/VanillaNet 在基础模型的核心是“多样性即不同”,这一哲学在计算机视觉和自然语言处理方面取得了惊人的成功。然而,优化和Transformer模型固有的复杂性带来了挑战,需要转向简洁性…

Python-Python基础综合案例--数据可视化 - 地图可视化

版本说明 当前版本号[20230729]。 版本修改说明20230729初版 目录 文章目录 版本说明目录知识总览图Python基础综合案例--数据可视化 - 地图可视化基础地图使用案例效果视觉映射器 疫情地图-国内疫情地图案例效果实操设置全局配置选项 疫情地图-省级疫情地图案例效果实操 知…

spring拦截器 与统一格式

目录 前言模拟拦截器拦截器的实现原理什么是动态代理? 什么是静态代理静态代理与动态代理的区别两种常用的动态代理方式基于接口的动态代理基于类的动态代理 JDK Proxy 与 CGlib的区别 其他 统⼀访问前缀添加统⼀异常处理统⼀数据返回格式 前言 之前博客讲述了 , 关于SpringA…

Kotlin~Memento备忘录模式

概念 备忘录模式是一种行为型设计模式&#xff0c;用于捕获和存储对象的内部状态&#xff0c;并在需要时将对象恢复到之前的状态。 备忘录模式允许在不暴露对象内部实现细节的情况下&#xff0c;对对象进行状态的保存和恢复。 角色介绍 Originator&#xff1a;原发器&#x…

7.事件类型

7.1鼠标事件 案例-轮播图点击切换 需求&#xff1a;当点击左右的按钮&#xff0c;可以切换轮播图 分析: ①右侧按钮点击&#xff0c;变量&#xff0c;如果大于等于8&#xff0c;则复原0 ②左侧按钮点击&#xff0c;变量–&#xff0c;如果小于0&#xff0c;则复原最后一张 ③鼠…

Zotero ubuntu2023安装 关联 ubuntu文献翻译

一、准备下载的软件&#xff1a; Zotero | Downloads 1. Zotero-6.0.26_linux-x86_64.tar.bz2 下面是插件 zotfile-5.1.2-fx.xpi zotero-pdf-translate.xpi jasminum-v0.2.6.xpi 2.2.5 Tampermonkey 4.11.crx 所准备的文件&#xff0c;都已经在这个链接的压缩包下面 …

【机器学习】Multiple Variable Linear Regression

Multiple Variable Linear Regression 1、问题描述1.1 包含样例的X矩阵1.2 参数向量 w, b 2、多变量的模型预测2.1 逐元素进行预测2.2 向量点积进行预测 3、多变量线性回归模型计算损失4、多变量线性回归模型梯度下降4.1 计算梯度4.2梯度下降 首先&#xff0c;导入所需的库 im…

【Maven】让maven更高效,优化maven构建项目速度

打开idea的setting&#xff0c;找到maven&#xff0c;设置它多线程数&#xff0c;重启后即可&#xff01; 我这里是8&#xff0c;你们可以随便设置。 如下图&#xff1a;

Android 之 使用 Camera 拍照

本节引言 本节给大家带来的是Android中Camera的使用&#xff0c;简单点说就是拍照咯&#xff0c;无非两种&#xff1a; 1.调用系统自带相机拍照&#xff0c;然后获取拍照后的图片 2.要么自己写个拍照页面 本节我们来写两个简单的例子体验下上面的这两种情况~ 1.调用系统自带…

《向量数据库指南》:向量数据库Pinecone如何集成LangChain(二)

目录 创建嵌入 向量数据库 索引 创建向量存储并查询 生成式问答 创建嵌入 使用LangChain的OpenAI嵌入功能构建嵌入非常简单。我们首先需要运行下一个单元格,以添加我们的OpenAI API密钥: Python from getpass import getpassOPENAI_API_KEY = getpass("OpenAI…