基于ChatYuan-large-v2 微调训练 医疗问答 任务

news2024/9/9 3:14:53

一、ChatYuan-large-v2

上篇基于ChatYuan-large-v2 语言模型 Fine-tuning 微调训练了广告生成任务,总体生成效果还可以,但上篇文章的训练是微调的模型全部的参数,本篇文章还是以 ChatYuan-large-v2 作为基础模型,继续探索仅训练解码器层参数,并在医疗问答任务上的效果如何。

下面是上篇文章的地址:

基于ChatYuan-large-v2 语言模型 Fine-tuning 微调训练 广告生成 任务

二、数据集处理

数据集这里使用 GitHub 上的 Chinese-medical-dialogue-data 中文医疗对话数据集。

GitHub 地址如下:

https://github.com/Toyhom/Chinese-medical-dialogue-data

数据分了 6 个科目类型:

在这里插入图片描述

数据格式如下所示:

在这里插入图片描述

其中 ask 为病症的问题描述,answer 为病症的回答。

整体加起来数据比较多,这里为了演示效果,只训练 内科、肿瘤科、儿科、外科 四个科目的数据,并且每个科目取前 10000 条数据进行训练、2000 条数据进行验证:

import json
import pandas as pd

data_path = [
    "./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]

train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000


def doHandler():
    train_f = open(train_json_path, "a", encoding='utf-8')
    val_f = open(val_json_path, "a", encoding='utf-8')
    for path in data_path:
        data = pd.read_csv(path, encoding='ANSI')
        train_count = 0
        val_count = 0
        for index, row in data.iterrows():
            ask = row["ask"]
            answer = row["answer"]
            line = {
                "content": ask,
                "summary": answer
            }
            line = json.dumps(line, ensure_ascii=False)
            if train_count < train_size:
                train_f.write(line + "\n")
                train_count = train_count + 1
            elif val_count < val_size:
                val_f.write(line + "\n")
                val_count = val_count + 1
            else:
                break
    print("数据处理完毕!")
    train_f.close()
    val_f.close()


if __name__ == '__main__':
    doHandler()

处理之后可以看到两个生成的文件:

在这里插入图片描述

下面基于上面的数据格式构建 Dataset

from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch
import json


class SummaryDataSet(Dataset):

    def __init__(self, json_path: str, tokenizer, max_length=300):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.content_data = []
        self.summary_data = []
        with open(json_path, "r", encoding='utf-8') as f:
            for line in f:
                if not line or line == "":
                    continue
                json_line = json.loads(line)
                content = json_line["content"]
                summary = json_line["summary"]
                self.content_data.append(content)
                self.summary_data.append(summary)
        print("data load , size:", len(self.content_data))

    def __len__(self):
        return len(self.content_data)

    def __getitem__(self, index):
        source_text = str(self.content_data[index])
        target_text = str(self.summary_data[index])

        source = self.tokenizer.batch_encode_plus(
            [source_text],
            max_length=self.max_length,
            pad_to_max_length=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        target = self.tokenizer.batch_encode_plus(
            [target_text],
            max_length=self.max_length,
            pad_to_max_length=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        source_ids = source["input_ids"].squeeze()
        source_mask = source["attention_mask"].squeeze()
        target_ids = target["input_ids"].squeeze()
        target_mask = target["attention_mask"].squeeze()

        return {
            "source_ids": source_ids.to(dtype=torch.long),
            "source_mask": source_mask.to(dtype=torch.long),
            "target_ids": target_ids.to(dtype=torch.long)
        }

三、模型训练

构建训练过程,注意这里只训练解码层参数,因此需要将其他层的参数进行冻结:

# 只训练解码层
for name, param in model.named_parameters():
    if "decoder" not in name:
        param.requires_grad = False

整体训练过程如下:

# -*- coding: utf-8 -*-
import pandas as pd
import torch
from torch.utils.data import DataLoader
import os, time
from transformers import T5Tokenizer, T5ForConditionalGeneration
from gen_dataset import SummaryDataSet


def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    time1 = time.time()
    for _, data in enumerate(loader, 0):
        y = data["target_ids"].to(device, dtype=torch.long)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)

        outputs = model(
            input_ids=ids,
            attention_mask=mask,
            decoder_input_ids=y_ids,
            labels=lm_labels,
        )
        loss = outputs[0]
        # 每100步打印日志
        if _ % 100 == 0 and _ != 0:
            time2 = time.time()
            print(_, "epoch:" + str(epoch) + "-loss:" + str(loss) + ";each step's time spent:" + str(
                float(time2 - time1) / float(_ + 0.0001)))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def validate(tokenizer, model, device, loader, max_length):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for _, data in enumerate(loader, 0):
            y = data['target_ids'].to(device, dtype=torch.long)
            ids = data['source_ids'].to(device, dtype=torch.long)
            mask = data['source_mask'].to(device, dtype=torch.long)

            generated_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_length=max_length,
                num_beams=2,
                repetition_penalty=2.5,
                length_penalty=1.0,
                early_stopping=True
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
                     generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
            if _ % 100 == 0:
                print(f'Completed {_}')

            predictions.extend(preds)
            actuals.extend(target)
    return predictions, actuals


def T5Trainer(train_json_path, val_json_path, model_dir, batch_size, epochs, output_dir, max_length=300):
    tokenizer = T5Tokenizer.from_pretrained(model_dir)
    model = T5ForConditionalGeneration.from_pretrained(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # 只训练解码层
    for name, param in model.named_parameters():
        if "decoder" not in name:
            param.requires_grad = False

    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = SummaryDataSet(train_json_path, tokenizer, max_length=max_length)
    training_loader = DataLoader(training_set, **train_params)

    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }

    val_set = SummaryDataSet(val_json_path, tokenizer, max_length=max_length)
    val_loader = DataLoader(val_set, **val_params)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

    for epoch in range(epochs):
        train(epoch, tokenizer, model, device, training_loader, optimizer)
        print("保存模型")
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
    # 验证
    with torch.no_grad():
        predictions, actuals = validate(tokenizer, model, device, val_loader, max_length)
        # 验证结果存储
        final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
        final_df.to_csv(os.path.join(output_dir, "predictions.csv"))


if __name__ == '__main__':
    train_json_path = "./data/train.json"
    val_json_path = "./data/val.json"
    model_dir = "chatyuan_large_v2"
    batch_size = 5
    epochs = 5
    max_length = 300
    output_dir = "./model"

    T5Trainer(
        train_json_path,
        val_json_path,
        model_dir,
        batch_size,
        epochs,
        output_dir,
        max_length
    )
 

运行后可以看到如下日志打印,训练大概占用 21G 的显存,如果显存不够可以调低些 batch_size 的大小:

在这里插入图片描述

等待训练结束后:

在这里插入图片描述

可以在 model 下看到保存的模型:

在这里插入图片描述

这里可以看到 predictions.csv 验证集的效果:

在这里插入图片描述

四、模型测试

# -*- coding: utf-8 -*-
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# 这里是模型下载的位置
model_dir = './model'

tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir)

while True:
    text = input("请输入内容: \n ")
    if not text or text == "":
        continue
    if text == "q":
        break

    encoded_input = tokenizer(text, padding="max_length", truncation=True, max_length=300)
    input_ids = torch.tensor([encoded_input['input_ids']])
    attention_mask = torch.tensor([encoded_input['attention_mask']])

    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=300,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True
    )

    reds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
            generated_ids]

    print(reds)

1. 我准备过两天去看我叔叔,顺便带些人参,但是他有高血压,您好人参高血压可以吃吗?

模型回答:

人参具有补气的药物,高血压患者是可以吃人参的,但是需要有在医生的指导下口服,不能够盲目的口服,以免引来血压升高,平时要留意低盐低脂饮食,多吃许多新鲜的蔬菜和水果,不要熬夜,确保充足的深度睡眠,防止情绪震荡太大,不要生气或者是发火。调节身心健康,保持自身卫生.避免因抵抗力下降而导致细菌入侵。

在这里插入图片描述

2. 检测报告单五项指标乙肝表面抗原结果为阳性,抗体结果为阴性,e抗原为阳性e抗体为阴性,核心抗体为阳性。要有打预防工作疫苗么?

模型回答:

乙肝表面抗原阳性,抗体结果为阴性,e抗原为阳性,e抗体为阴性,核心抗体为阳性,证明既往传染过乙肝病毒,现在已经康复了,并且产生了保护性抗体。建议你到医院检查一下肝功能和HBV-DNA,如果都正常的话,就用不着打预防工作疫苗了,平时多留意歇息,别加班,别喝啤酒,不定期复检就可以了。患者直面乙肝困恼,我们应当积极采用救治措施,不要避忌就诊,如果乙肝久治不愈,那么就要及时对换治疗方案。

在这里插入图片描述

3. 男孩,11岁,上小学了,这一周,孩子一直扣耳朵,说痒,另外,偶尔还会有黄色的液体流出,另外,伴有一些低烧,请问:孩子中耳炎流黄水怎样治疗。

模型回答:

抗生素和消炎药是目前临床治疗小儿中耳炎的方法和药物,家长可对症进行选择,也可局部用药,比如给宝宝采取使用消炎类型的滴耳剂,孩子如果是以耳痛的症状为主的话,建议家长是可以适当的给点镇痛药缓解的,要是伴随发高烧的情况,那么根据孩子的症状使用药物,严重的情况请尽快去医院进行救治,以上方法都是目前治疗中耳炎比较常用的,此外要是出现耳膜穿孔的情况是需要及时的进行手术治疗的,与此同时也要注意给孩子安排合理的饮食和作息习惯,帮助疾病尽早康复。

在这里插入图片描述

4. 我家宝宝最近不知晓为什么脸上再次出现了湿疹的情况,而且好象还蛮严重的,宝宝看上去也不舒服,请问婴儿全身长湿疹怎么办

模型回答:

婴儿全身长湿疹,建议你先带孩子去正规的医院仔细检查下,根据医生指导恰当对症救治。平时要留意护理好孩子,始终保持皮肤的清洁干燥,防止搔抓,以免严重病情。患者日常要养成良好的生活习惯,规律作息,禁烟酒,多增加日常锻炼以此来提高患者免疫力,避免病情复发。

在这里插入图片描述

5. 我对我的身体情况一直都是比较重视的,所以隔段时间就会抽时间去了一医院检查一下,近来的时候患有了小肠气,小肠气是什么样的一种疾病,小肠气到哪治?

模型回答:

你好,您的情况是检查确诊疝气。需要手术治疗。疝气是临床上比较常见的疾病。多是由于腹壁存在薄弱环节,腹腔内压力增高引起。主要症状是腹部可复性肿物。直立或者腹内压增加肿物突出,平卧可以消失或者回纳。这种情况需要积极手术治疗。开腹手术或者腹腔镜微创手术都可以。同时避免腹内压增加因素,比如咳嗽,便秘,弯腰提重物,剧烈运动等,以免发生嵌顿或者术后复发。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/903196.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【100天精通python】Day38:GUI界面编程_PyQt 从入门到实战(中)_数据库操作与多线程编程

目录 专栏导读 4 数据库操作 4.1 连接数据库 4.2 执行 SQL 查询和更新&#xff1a; 4.3 使用模型和视图显示数据 5 多线程编程 5.1 多线程编程的概念和优势 5.2 在 PyQt 中使用多线程 5.3 处理多线程间的同步和通信问题 5.3.1 信号槽机制 5.3.2 线程安全的数据访问 Q…

Spring Boot整合RabbitMQ之发布与订阅模式

RabbitMQ的模式中&#xff0c;常用的模式有&#xff1a;简单模式&#xff0c;发布与订阅模式&#xff0c;工作模式&#xff0c;路由模式&#xff0c;主题模式。简单模式不太会运用到工作中&#xff0c;我们可以使用 RabbitMQ 的发布订阅模式&#xff0c;实现&#xff1a; 用户…

KUST_LI计算机视觉实验室服务器安装与管理

第一步&#xff1a;安装 Linux-Ubuntu系统 系统语言设置为英文 ENGLISH&#xff0c;防止系统 BUG&#xff1b;选择-清除整个磁盘并安装系统&#xff1b;设置用户名和密码&#xff0c;实验室统一其余全部默认设置 开机后设置磁盘挂载 在系统设置中找到 desk 打开&#xff0c;…

YOLOv7训练结果解析

前言&#xff1a; 已训练完模型&#xff0c;且把结果下载下来&#xff0c;以下某一次id识别训练结果为例&#xff0c;如下图所示。 注&#xff1a;YOLOv7每次train完成&#xff08;如果没有中途退出&#xff09;都会在run目录下生成expX目录&#xff08;X代表生成结果次数 第一…

CentOS7.9手工配置静态网络流程

进入网卡配置文件 vim /etc/sysconfig/network-scripts/ifcfg-ens33 配置 TYPE"Ethernet" PROXY_METHOD"none" BROWSER_ONLY"no" BOOTPROTO"static" //static 配置静态网络 DEFROUTE"yes" IPV4_FAILURE_FATAL"no…

电脑找不到MSVCR120.dll怎么办,三个完美解决方法

在计算机领域&#xff0c;MSVCR120.dll是一个非常重要的动态链接库文件。它是Microsoft Visual C 2010 Redistributable Package的一部分&#xff0c;用于支持某些程序的运行。然而&#xff0c;在某些情况下&#xff0c;我们可能会遇到MSVCR120.dll丢失的问题。在这篇文章中&am…

(详解踩坑)GIT版本回滚git stash、git reset、git reset --hard、git revert

目录 背景 一、&#xff08;git log、git reflog&#xff09;查看git提交日志及命令历史 1.1 git log&#xff08;提交日志&#xff09; 1.2 git reflog&#xff08;命令历史&#xff09; 二、git reset&#xff08;回退到指定的版本&#xff0c;并且保留更改&#xff09; …

LEADTOOLS Imaging SDK Crack

LEADTOOLS Imaging SDK Crack 高级开发人员工具包包括ActiveX和WPF/XAML控件。 LEADTOOLS Imaging SDK为文件格式导入/导出、图像压缩、图像显示和效果、颜色转换、图像处理、TWAIN扫描、图像通用对话框、数据库集成、打印和互联网提供了基本和高级的彩色图像功能。 LEADTOOLS …

【数据分享】2013-2023年全国370个城市逐月空气质量数据(Excel格式/无需转发)

空气质量的好坏反映了空气污染程度&#xff0c;它是依据空气中污染物浓度的高低来判断的。在各项涉及城市环境的研究与实际项目中&#xff0c;城市空气质量都是一个十分重要的指标。那么&#xff0c;去哪里能获取到各城市空气质量的历史数据呢&#xff1f; 之前我们分享了2014…

前端vue自定义柱形图 选中更改柱形图颜色及文字标注颜色

随着技术的发展&#xff0c;开发的复杂度也越来越高&#xff0c;传统开发方式将一个系统做成了整块应用&#xff0c;经常出现的情况就是一个小小的改动或者一个小功能的增加可能会引起整体逻辑的修改&#xff0c;造成牵一发而动全身。 通过组件化开发&#xff0c;可以有效实现…

Dubbo高手之路3,Dubbo服务消费详解

目录 引言1. 介绍 Dubbo 服务消费的详解的目的和背景2. 概述 Dubbo 服务消费的过程和核心概念 一、Dubbo 服务消费的基础知识1. Dubbo 服务消费的架构和流程2. Dubbo 服务消费的基本配置和使用方法 二、Dubbo 服务消费的注册与发现1. Dubbo 服务消费的注册中心和发布中心的基本…

09_Redlock算法和底层源码分析

Redlock算法和底层源码分析 一、当前代码为8.0版接上一步 自研分布式锁的重点&#xff1a; 按照juc里面Lock接口规范进行编写lock加锁关键逻辑 加锁&#xff1a;在redis中&#xff0c;加锁实际上是给key设置一个值&#xff0c;为避免死锁&#xff0c;并给key一个过期时间自旋…

01.Django入门

1.创建项目 1.1基于终端创建Django项目 打开终端进入文件路径&#xff08;打算将项目放在哪个目录&#xff0c;就进入哪个目录&#xff09; E:\learning\python\Django 执行命令创建项目 F:\Anaconda3\envs\pythonWeb\Scripts\django-admin.exe&#xff08;Django-admin.exe所…

protobuf+netty自定义编码解码

protobufnetty自定义编 项目背景 protobufnetty自定义编码解码 比如心跳协议&#xff0c;客户端请求的协议是10001&#xff0c;在java端如何解码&#xff0c;心跳返回协议如何编码&#xff0c;将协议号带过去 // 心跳包 //10001 message c2s_heartbeat { }//10002 message …

【数据分享】2006-2021年我国省份级别的市政公用设施建设固定资产投资相关指标(30多项指标)

《中国城市建设统计年鉴》中细致地统计了我国城市市政公用设施建设与发展情况&#xff0c;在之前的文章中&#xff0c;我们分享过基于2006-2021年《中国城市建设统计年鉴》整理的2006—2021年我国省份级别的市政设施水平相关指标、2006-2021年我国省份级别的各类建设用地面积数…

02.案列项目Demo

1.创建项目 1. 创建项目 用pycharm 选择对应的编译器&#xff0c;输入对应的文件名&#xff0c;点击创建项目。删除默认外层生成的template和DIRS 配置项&#xff1a; 2. 创建App 创建appo1的命令&#xff1a; python manage.py startapp app01 如果使用pycharm>tool>…

Qt项目报错:Cannot run compiler ‘clang++‘. /bin/sh: 1: clang++: not found

在一台旧电脑上装了深度系统&#xff0c;装了Qt&#xff0c;导入项目&#xff0c; build提示 clang找不到&#xff1a; Project ERROR: Cannot run compiler clang. Output: /bin/sh: 1: clang: not found Maybe you forgot to setup the environment? Error while parsing …

​8th参考文献:[8]许少辉.乡村振兴战略下传统村落文化旅游设计[M]北京:中国建筑工业出版社,2022.

​&#xff18;th参考文献&#xff1a;&#xff3b;&#xff18;&#xff3d;许少辉&#xff0e;乡村振兴战略下传统村落文化旅游设计&#xff3b;&#xff2d;&#xff3d;北京&#xff1a;中国建筑工业出版社&#xff0c;&#xff12;&#xff10;&#xff12;&#xff12;&a…

2022数学建模国赛C题官网展示论文C155论文复现

2022数学建模国赛C题C155论文复现 1.内容比对2.第一问第二小问复现代码2.1 页表合并2.2 数据的正态性检验2.2.1数据的正态性检验效果图 2.3不满足正态性&#xff0c;进行中心化对数比变换2.3.1 核心步骤-inf用0值替换2.3.2中心化对数比变换效果图 2.4描述性统计2.5 箱线图绘制 …