基于第二代 ChatGLM2-6B P-Tuning v2 微调训练医疗问答任务

news2025/1/25 4:44:18

今天是教师节,恭祝全体老师们节日快乐!😊

一、ChatGLM2-6B

在本专栏前面文章中实验了使用 ChatYuan-large-v2 Freeze 微调训练医疗问答任务,训练后效果整体还可以,这篇文章继续探索使用最近比较火的 ChatGLM 官方推出的 p-tuning-v2 的方式训练医疗问答任务。而对于 ChatGLM 模型则使用新出不久的 ChatGLM2-6B

ChatGLM2-6BChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,同时引入了许多新特性,如:更强大的性能、更长的上下文、更高效的推理、更开放的协议 等。

更多详细的介绍可参考官方 github

官方 github 地址:https://github.com/THUDM/ChatGLM2-6B

P-tuning v2 微调技术利用 deep prompt tuning,即对预训练 Transformer 的每一层输入应用 continuous promptsdeep prompt tuning 增加了 continuo us prompts 的能力,并缩小了跨各种设置进行微调的差距,特别是对于小型模型和困难任务。

在这里插入图片描述
上图左边为 P-Tuning,右边为P-Tuning v2P-Tuning v2 层与层之间的 continuous prompt 是相互独立的。

论文地址:https://arxiv.org/pdf/2110.07602.pdf

github地址:https://github.com/THUDM/P-tuning-v2

二、ChatGLM2-6B 模型下载

huggingface 地址:https://huggingface.co/THUDM/chatglm2-6b/tree/main

在这里插入图片描述

三、数据集处理

数据集还是使用 GitHub 上的 Chinese-medical-dialogue-data 中文医疗对话数据集。

GitHub 地址如下:

https://github.com/Toyhom/Chinese-medical-dialogue-data

数据分了 6 个科目类型:

在这里插入图片描述

数据格式如下所示:

在这里插入图片描述

其中 ask 为病症的问题描述,answer 为病症的回答。

整体加起来数据比较多,这里为了演示效果,只训练 内科、肿瘤科、儿科、外科 四个科目的数据,并且每个科目取前 10000 条数据进行训练、2000 条数据进行验证:

import json
import pandas as pd

data_path = [
    "./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]

train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000


def doHandler():
    train_f = open(train_json_path, "a", encoding='utf-8')
    val_f = open(val_json_path, "a", encoding='utf-8')
    for path in data_path:
        data = pd.read_csv(path, encoding='ANSI')
        train_count = 0
        val_count = 0
        for index, row in data.iterrows():
            ask = row["ask"]
            answer = row["answer"]
            line = {
                "content": ask,
                "summary": answer
            }
            line = json.dumps(line, ensure_ascii=False)
            if train_count < train_size:
                train_f.write(line + "\n")
                train_count = train_count + 1
            elif val_count < val_size:
                val_f.write(line + "\n")
                val_count = val_count + 1
            else:
                break
    print("数据处理完毕!")
    train_f.close()
    val_f.close()


if __name__ == '__main__':
    doHandler()

处理之后可以看到两个生成的文件:

在这里插入图片描述

四、P-Tuning v2 训练

拉取官网训练脚本:

git clone https://github.com/THUDM/ChatGLM2-6B

下载相应依赖:

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

此外还需安装:

pip install rouge_chinese nltk jieba datasets -i https://pypi.tuna.tsinghua.edu.cn/simple

修改 ptuning 下的 train.sh 文件:

PRE_SEQ_LEN=300
LR=2e-2
NUM_GPUS=1

torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
    --do_train \
    --train_file data/train.json \
    --validation_file data/val.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /home/chatglm2/chatglm-6b \
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 300 \
    --max_target_length 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

其中 参数解释如下:

–standalone` 以单机模式训练。

–nnodes` 节点数。这里只有一个节点,设置为 1。

–nproc-per-node` 每个节点上的进程数。

–do_train` 执行训练任务。

–train_file` 训练数据文件路径, 上面生成的 train.json 文件。

–validation_file` 验证数据文件路径, 上面生成的 val.json 文件。

–preprocessing_num_workers` 指定数据预处理时的 workers 数。

–prompt_column` 输入信息的字段名称。

–response_column` 输出信息的字段名称。

–overwrite_cache` 覆盖缓存文件。

–model_name_or_path` 预训练模型的名称或路径,注意这里我是用的下载后的模型存放地址,需要修改为你的。

–output_dir` 模型保存目录。

–overwrite_output_dir` 覆盖输出目录。

–max_source_length` 输入文本的最大长度。

–max_target_length` 输出文本的最大长度。

–per_device_train_batch_size` 训练时的批次大小。

–per_device_eval_batch_size` 验证时的批次大小。

–gradient_accumulation_steps` 累积多少个梯度之后再进行一次反向传播。

–predict_with_generate` 预测时使用生成模式。

–max_steps` 最大训练轮数。

–logging_steps` 多少轮打印一次日志。

–save_steps` 多少轮保存一次模型。

–learning_rate` 初始学习率。

–pre_seq_len` 预处理时选取的序列长度。

–quantization_bit` 量化位大小。

执行后可以看到如下打印日志:

在这里插入图片描述

训练过程:

在这里插入图片描述

训练结束:

在这里插入图片描述

最后在 output 目录下可以看到每 1000 步保存的模型。

五、模型测试

5.1 单独调用测试:

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel, AutoConfig
import uvicorn, json, datetime
import torch
import os


def main():
    pre_seq_len = 300
    # 训练权重地址
    checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"

    tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)
    config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)
    model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)
    prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    # 量化
    model = model.quantize(4)
    model.eval()

    # 问题
    question = "突然感到了不适,去检查后竟然得了这个病,请问:宝宝白天爱磨牙会是哪些情况呢"

    response, history = model.chat(tokenizer,
                                   question,
                                   history=[],
                                   max_length=2048,
                                   top_p=0.7,
                                   temperature=0.95)

    print("回答:", response)

    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


if __name__ == '__main__':
    main()

在这里插入图片描述

回答: 孩子磨牙可能会是缺钙引来的,建议带孩子去医院仔细检查下微量元素,明确病因后有针对性的治疗。平时要留意孩子的饮食卫生,防止排便辛辣刺激性食物,多给孩子喝温开水,多吃蔬菜水果,消化维生素,增进胃肠道扭动。对于家长朋友们来说,要尽可能的帮助孩子及时治疗疾病,另外宝宝在日常生活中饮食也要注意,要营养的均衡,不要过度进补也不要营养不良哦。

5.2 封装成 Api 测试

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel, AutoConfig
import uvicorn, json, datetime
import torch
import os

app = FastAPI()

# 允许所有域的请求
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    return answer


if __name__ == '__main__':
    pre_seq_len = 300
    checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"

    tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)
    config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)
    model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)
    prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    ## 量化
    model = model.quantize(4)
    model = model.cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)

使用 postMan 测试:

在这里插入图片描述

最后测试下原有知识的影响:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/995983.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

docker镜像详解

目录 什么是docker镜像镜像相关命令docker pulldocker imagesdocker searchdocker rmi导出 / 导入镜像 镜像分层镜像摘要镜像摘要的作用分发散列值 什么是docker镜像 Docker镜像是Docker容器的基础组件&#xff0c;它包含了运行一个应用程序所需的一切&#xff0c;包括代码、运…

Sharding-Jdbc(2):Sharding-Jdbc入门案例

1 前置条件 搭建读写分离的数据库环境,搭建方法如下文,目标数据库test Mysql性能优化(5):主从同步原理与实现_mysql主从配置优化_不死鸟.亚历山大.狼崽子的博客-CSDN博客 2 新建maven项目 3 pom引入依赖 <?xml version="1.0" encoding="UTF-8"…

《protobuf》基础语法2

文章目录 枚举类型ANY 类型oneof 类型map 类型改进通讯录实例 枚举类型 protobuf里有枚举类型&#xff0c;定义如下 enum PhoneType {string home_addr 0;string work_addr 1; }同message一样&#xff0c;可分为 嵌套定义&#xff0c;文件内定义&#xff0c;文件外定义。不…

二维码智慧门牌管理系统:智能化地址管理,提升社会治理效率

文章目录 前言一、地址管理挑战二、二维码智慧门牌管理系统解决方案 前言 随着科技的飞速发展&#xff0c;我们的生活正经历前所未有的变革。尤其是智能化技术&#xff0c;已经深刻影响了我们的日常生活和工作环境。然而&#xff0c;在某些领域&#xff0c;如地址管理和社会治…

十大免费好用的视频软件推荐,新手小白必备

很多人都在使用视频软件进行编辑&#xff0c;那么你们知道哪些视频软件是免费的而且还好用&#xff1f; 现在很多小伙伴比较喜欢用手机编辑视频&#xff0c;而这些剪辑软件具有领先的智能化AI技术&#xff0c;让你不用绿幕就能一键挖出图像&#xff0c;不仅可以一键挖出三维人…

获取板块分类并展示

板块分类也会变动&#xff0c;偶尔看下&#xff0c;利于总体分析大盘 https:dapanyuntu.com/ 该网站含有板块信息 分析接口 搜素关键字 拷贝curl到curl解析工具&#xff0c;去掉无用的参数&#xff0c;生成requests代码 尝试nginx反代接口 server {listen 443;loca…

汇率失守7.3关口

号外&#xff1a;9.8教链内参《被判入狱1万年》。 拉锯多时的离岸人民币汇率USDCNH失守7.3关口&#xff0c;隔夜冲破7.36的高位。 为什么7.3是一个关键关口&#xff1f;因为这里是2022年10月底、11月初时曾经测试过的支撑位&#xff08;从CNH的角度说&#xff09;。 如果支撑位…

微信视频号挂公众号文章链接新方法:不限次数,不限号

当看到自己身边的人&#xff0c;通过我分享的方法绑定成功&#xff0c;那是开心到起飞了。 因为我知道&#xff1a;外面不靠谱的人很多&#xff01;分享不靠谱方法的人&#xff0c;可谓是更多。 为什么我要主动分享视频号评论区挂公众号文章链接&#xff1f;总有人伸张正义&a…

java中log使用总结

目录 一、概述1.1. 核心日志框架1.2 门面日志框架 二、最佳实践2.1 核心日志框架API包2.2 门面日志框架依赖2.3 集成使用2.3.1 集成jcl2.3.2 集成slf4j2.3.2.1 slf4j集成单一框架2.3.2.2 slf4j整合混合框架 三、总结3.1 所有相关包3.1.1 核心日志框架包3.1.2 门面日志框架3.1.3…

Python元类(metaclass)

Python 是一种强大的编程语言&#xff0c;一部分得益于其语言设计中独特的元类&#xff08;Metaclass&#xff09;机制。尽管元类的概念在刚开始接触时可能会让人感到困惑&#xff0c;但一旦理解了它们的工作原理和应用方式&#xff0c;我们就可以用它们做出强大且灵活的抽象。…

无涯教程-JavaScript - COUPDAYS函数

描述 COUPDAYS函数返回包含结算日期的息票期限内的天数。 语法 COUPDAYS (settlement, maturity, frequency, [basis])争论 Argument描述Required/OptionalSettlement 证券的结算日期。 证券结算日期是指在发行日期之后将证券交易给买方的日期。 RequiredMaturity 证券的到…

Nginx重写功能

Nginx重写功能 一、Nginx常见模块二、访问路由location2.1location常用正则表达式2.2、location的分类2.3、location常用的匹配规则2.4、location优先级排列说明2.5、location示例2.6、location优先级总结2.7、实例2.7.1、location/{}与location/{}2.7.2、location/index.html{…

joplin更新后找不到文章

Joplin的数据默认是存储在C:\Users\Username.config\joplin-desktop下的。我修改为了D:\joplinnotes 这样就导致在升级覆盖安装的时候&#xff0c;笔记丢失路径。如果记不起之前笔记保存在哪里&#xff0c;也可以搜索类似文件来回忆之前自己保存笔记的位置 cache\ plugins\ re…

黑马JVM总结(三)

&#xff08;1&#xff09;栈内存溢出 方法的递归调用&#xff0c;没有设置正确的结束条件&#xff0c;栈会有用完的一天&#xff0c;导致栈内存溢出 可以修改栈的大小&#xff1a; 再次运行&#xff1a;减少了次数 案例二&#xff1a; 两个类的循环应用问题&#xff0c;导致Js…

注解生效激活(idea)

File---------settings-----------Build,Execution,Deployment-----------Compiler------- Annotation Processors

数据分析和可视化平台:Splunk Enterprise for mac v9.1.1激活版 兼容m1

Splunk Enterprise 是一个数据分析和可视化平台&#xff0c;可帮助企业理解其数据。虽然没有适用于 Mac OS 的 Splunk Enterprise 官方版本&#xff0c;但他们确实为 Mac OS 提供了一个名为“Splunk Light”的应用程序&#xff0c;它提供了基本的数据索引、搜索和仪表板。或者&…

「网页开发|前端开发|Vue」07 前后端分离:如何在Vue中请求外部数据

本文主要介绍两种在Vue中访问外部API获取数据的方式&#xff0c;通过让Vue通过项目外部的接口来获取数据&#xff0c;而不是直接由项目本身进行数据库交互&#xff0c;可以实现前端代码和后端代码的分离&#xff0c;让两个部分的代码编写更独立高效。 文章目录 本系列前文传送…

SpringMVC的常用注解,参数传递以及页面跳转的使用

目录 slf4j 常用注解 RequestMapping RequestParam RequestBody PathVariable 参数传递 首先在pom.xml配置文件中导入SLF4J的依赖 基础类型String 复杂类型 RequestParam PathVariable RequestBody 增删改查 返回值 void返回值 String返回值 modelString …

“高效记录收支明细,按时间轻松查找借款信息“

我们有时候要去查找借款信息&#xff0c;只记得住借款记录的日期&#xff0c;想通过日期来进行筛选出借款信息&#xff0c;要如何进行操作&#xff1f;今天就让小编来教教大家要如何操作。 第一步&#xff0c;我们要打开【晨曦记账本】&#xff0c;并登录账本。 第二步&#x…

弃用http改用https的缘故,与密钥的使用,证书意义

为何弃用http协议 在十几年前&#xff0c;我们的传输协议是http协议&#xff0c;为何到了如今改成了https协议呢&#xff1f;为了安全的考虑。 在http协议中&#xff0c;我们的内容是透明的&#xff0c;不被保护的&#xff0c;在黑客等恶意分子的面前&#xff0c;信息极其任意…