CBLUE_中文生物医学语言理解评估基准_源码详解

news2025/1/15 18:03:11

CBLUE_中文生物医学语言理解评估基准_源码详解

源码链接:https://github.com/CBLUEbenchmark/CBLUE

项目中包括八个不同的中文医学NLP任务:1.中文医学命名实体识别(CMeEE)、2.中文医学文本实体关系抽取(CMeIE)、3.临床术语标准化任务(CHIP-CDN)、4.临床试验筛选标准短文本分类(CHIP-CTC)、5.平安医疗科技疾病问答迁移学习(CHIP-STS)、6.医疗搜索检索词意图分类(KUAKE-QIC)、7.医疗搜索查询词-页面标题相关性(KUAKE-QTR)、8.医疗搜索查询词-查询词相关性(KUAKE-QQR)

本文以医疗搜索检索词意图分类(KUAKE-QIC)任务为例

一、介绍项目目录结构

├── CBLUE         
|  └── baselines
|     └── run_classifier.py  # 程序入口
|     └── ...
|  └── examples
|     └── run_qic.sh  # QIC任务快捷启动脚本
|     └── ...

|  └── cblue
|     └── data 
|        └── data_process.py # 训练数据处理
|        └── dataset.py # 从json文件中获取数据字段
|     └── metrics 
|        └── cblue_commit.py # 测试数据保存
|        └── cblue_metrics.py # 度量方法
|     └── models
|        └── zen # zen模型实现
|        └── model.py # zen模型实现
|     └── trainer # 
|        └── train.py # 模型训练类
|     └── utils.py # 工具类

|  └── CBLUEDatasets
|     └── KUAKE-QIC # 数据集存放路径
|        └── KUAKE-QIC_dev.json # 验证集
|        └── KUAKE-QIC_test.json # 测试集
|        └── KUAKE-QIC_train.json # 训练集

|  └── data
|     └── output # 训练好生成模型的存放路径
|        └── qic
|           └── chinese-bert-wwm-ext
|              └── config.json # 模型参数配置文件
|              └── pytorch_model.bin # 训练完成的模型
|              └── training_args.bin # 
|              └── qic_chinese-bert-wwm-ext.log # 训练日志文件
|              └── vocab.txt # 词汇表
|     └── model_data
|        └── chinese-bert-wwm # 在运行之前,下载pytorch版本的bert模型
|           └── bert_config.json # 预训练模型的参数配置文件
|           └── pytorch_model.bin # 预训练模型
|           └── vocab.txt # 词汇表
|     └── result_output
|        └── KUAKE-QIC_test.json # 传入do_predict参数,利用训练好的模型预测测试集的样本标签,生成的文件

二、介绍KUAKE-QIC数据集

KUAKE-Query Intent Classification Dataset (KUAKE-QIC)
医疗意图标签共有11类,包括诊断、病因分析、治疗计划、医疗建议、检测结果分析、疾病描述、后果预测、注意事项、预期效果、治疗费用和其它。

数据集示例:训练集(左)、测试集(右)
在这里插入图片描述

三、项目运行环境搭建

推荐使用服务器训练模型,或者用笔记本自带的GPU
本人租用AutoDL,配置如下图:
在这里插入图片描述
切记选择PyTorch 1.7.0 Python 3.8(ubuntu18.04) Cuda 11.0
其他版本的运行程序会出错。
安装其他第三方库:
本项目使用的第三方库有:torch 1.7 / transformers 4.5.1 / jieba / gensim / scikit-learn等,安装时切记版本对应,否则模型训练报错,版本不兼容。
或者在运行时看程序报错缺失哪个包,再pip安装即可。

-----至此前期工作全部完成,接下来可以运行看看效果!-----

四、运行项目

直接运行sh脚本文件 bash examples/run_qic.sh,若使用服务器,请在终端进入到examples的父目录再运行。
脚本文件中,已经配置了模型训练需要的参数,可以根据不同任务进行微调。

#!/usr/bin/env bash
DATA_DIR="CBLUEDatasets"                  # 数据集总目录
TASK_NAME="qic"                           # 具体任务 医疗搜索检索词意图分类(KUAKE-QIC)
MODEL_TYPE="bert"                         # 预训练模型类型
MODEL_DIR="data/model_data"               # 预训练模型保存路径
MODEL_NAME="chinese-bert-wwm-ext"         # 预训练模型名称
OUTPUT_DIR="data/output"                  # 模型保存目录
RESULT_OUTPUT_DIR="data/result_output"    # 结果保存目录

MAX_LENGTH=50  # 最大长度

echo "Start running"

if [ $# == 0 ]; then
    python baselines/run_classifier.py \
        --data_dir=${DATA_DIR} \
        --model_type=${MODEL_TYPE} \
        --model_dir=${MODEL_DIR} \
        --model_name=${MODEL_NAME} \
        --task_name=${TASK_NAME} \
        --output_dir=${OUTPUT_DIR} \
        --result_output_dir=${RESULT_OUTPUT_DIR} \
        --do_train \
        --max_length=${MAX_LENGTH} \
        --train_batch_size=16 \  # 训练的batch-size
        --eval_batch_size=16 \   # 验证的batch-size
        --learning_rate=3e-5 \   # 学习率
        --epochs=3 \             # 训练的迭代次数
        --warmup_proportion=0.1 \ # 慢热学习的比例
        --earlystop_patience=3 \ #当使用提前终止训练策略时,如果验证集精度在earlystop_patience个epoch内连续下降或持平,则终止训练。默认值为5。
        --logging_steps=200 \ # 日志打印的间隔 steps,默认为 20
        --save_steps=200 \  # 保存模型参数的间隔 steps,默认为 100
        --seed=2021 # 随机种子,默认为1000
elif [ $1 == "predict" ]; then
    python baselines/run_classifier.py \
        --data_dir=${DATA_DIR} \
        --model_type=${MODEL_TYPE} \
        --model_name=${MODEL_NAME} \
        --model_dir=${MODEL_DIR} \
        --task_name=${TASK_NAME} \
        --output_dir=${OUTPUT_DIR} \
        --result_output_dir=${RESULT_OUTPUT_DIR} \
        --do_predict \
        --max_length=${MAX_LENGTH} \
        --eval_batch_size=16 \
        --seed=2021
fi

一些参数介绍:
Warmup:Warmup是在ResNet论文中提到的一种学习率预热的方法,由于刚开始训练时,模型的权重(weights)是随机初始化的,此时若选择一个较大的学习率,可能带来模型的不稳定(振荡),选择Warmup预热学习率的方式,可以使得开始训练的几个epoches或者一些steps内学习率较小,在预热的小学习率下,模型可以慢慢趋于稳定,等模型相对稳定后再选择预先设置的学习率进行训练,使得模型收敛速度变得更快,模型效果更佳。
warmup_proportion:慢热学习的比例。比如warmup_proportion=0.1,总步数=100,那么warmup步数就为10。在1到10步中,学习率会比10步之后低,10步之后学习率恢复正常。

五、获取加载数据

class QICDataset(Dataset):
    def __init__(
            self,
            samples,
            data_processor,
            mode='train'
    ):
        super(QICDataset, self).__init__()

        self.text = samples['text']
        self.ids = samples['id']

        if mode != 'test':
            self.labels = samples['label']  # 非测试数据集都有label

        self.data_processor = data_processor
        self.mode = mode

    def __getitem__(self, item):
        if self.mode != 'test':
            return self.text[item], self.labels[item]
        else:
            return self.text[item]

    def __len__(self):
        return len(self.text)
class QICDataProcessor(object):
    def __init__(self, root):
        self.task_data_dir = os.path.join(root, 'KUAKE-QIC') # 获取数据集名称
        # 对训练集 测试集 和 验证集的路径进行拼接
        self.train_path = os.path.join(self.task_data_dir, 'KUAKE-QIC_train.json')
        self.dev_path = os.path.join(self.task_data_dir, 'KUAKE-QIC_dev.json')
        self.test_path = os.path.join(self.task_data_dir, 'KUAKE-QIC_test.json')
        
        # 11种意图标签
        self.label_list = ['疾病表述', '指标解读', '医疗费用', '治疗方案', '功效作用', '病情诊断',
                           '其他', '注意事项', '病因分析', '就医建议', '后果表述']
        self.label2id = {label: idx for idx, label in enumerate(self.label_list)}
        self.id2label = {idx: label for idx, label in enumerate(self.label_list)}
        self.num_labels = len(self.label_list)

    def get_train_sample(self):
        return self._pre_process(self.train_path, is_predict=False)

    def get_dev_sample(self):
        return self._pre_process(self.dev_path, is_predict=False)

    def get_test_sample(self):
        return self._pre_process(self.test_path, is_predict=True)

    def _pre_process(self, path, is_predict):
        #拿到json文件中标签对应的值
        samples = load_json(path)
        outputs = {'text': [], 'label': [], 'id': []}
        for sample in samples:
            outputs['text'].append(sample['query'])
            outputs['id'].append(sample['id'])
            if not is_predict:
                outputs['label'].append(self.label2id[sample['label']])
        return outputs

六、模型训练

class QICTrainer(Trainer):
    def __init__(
            self,
            args,
            model,
            data_processor,
            tokenizer,
            logger,
            model_class,
            train_dataset=None,
            eval_dataset=None,
            ngram_dict=None

    ):
        super(QICTrainer, self).__init__(
            args=args,
            model=model,
            data_processor=data_processor,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            logger=logger,
            model_class=model_class,
            ngram_dict=ngram_dict
        )

    def training_step(self, model, item):
        model.train()

        text1 = item[0]
        labels = item[1].to(self.args.device)

        if self.args.model_type == 'zen':
            inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
                                                  tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
                                                  return_tensors=True)
        else:
            inputs = self.tokenizer(text1, padding='max_length', max_length=self.args.max_length,
                                    truncation=True, return_tensors='pt')

        if self.args.model_type == 'zen':
            inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
            inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
            inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
            inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)

        inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
        inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
        inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)

        # default using 'Transformers' library models.
        outputs = model(labels=labels, **inputs)
        loss = outputs[0]
        loss.backward()

        return loss.detach()

    def evaluate(self, model):
        args = self.args
        logger = self.logger
        eval_dataloader = self.get_eval_dataloader()
        num_examples = len(eval_dataloader.dataset)

        preds = None
        eval_labels = None

        logger.info("***** Running evaluation *****")
        logger.info("Num samples %d", num_examples)
        for step, item in enumerate(eval_dataloader):
            model.eval()

            text1 = item[0]
            labels = item[1].to(args.device)

            if self.args.model_type == 'zen':
                inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
                                                      tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
                                                      return_tensors=True)
            else:
                inputs = self.tokenizer(text1, return_tensors='pt', padding='max_length',
                                        truncation='longest_first', max_length=self.args.max_length)
            inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
            inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
            inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)

            if self.args.model_type == 'zen':
                inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
                inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
                inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
                inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)

            with torch.no_grad():
                outputs = model(labels=labels, **inputs)
                loss, logits = outputs[:2]

            if preds is None:
                preds = logits.detach().cpu().numpy()
                eval_labels = labels.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                eval_labels = np.append(eval_labels, labels.detach().cpu().numpy(), axis=0)

        preds = np.argmax(preds, axis=1)
        acc = qic_metric(preds, eval_labels)
        logger.info("%s-%s acc: %s", args.task_name, args.model_name, acc)
        return acc

    def predict(self, test_dataset, model):
        args = self.args
        logger = self.logger
        test_dataloader = self.get_test_dataloader(test_dataset)
        num_examples = len(test_dataloader.dataset)
        model.to(args.device)

        preds = None

        logger.info("***** Running prediction *****")
        logger.info("Num samples %d", num_examples)
        pbar = ProgressBar(n_total=len(test_dataloader), desc='Prediction')
        for step, item in enumerate(test_dataloader):
            model.eval()

            text1 = item

            if self.args.model_type == 'zen':
                inputs = convert_examples_to_features(text1=text1, ngram_dict=self.ngram_dict,
                                                      tokenizer=self.tokenizer, max_seq_length=self.args.max_length,
                                                      return_tensors=True)
            else:
                inputs = self.tokenizer(text1, return_tensors='pt', padding='max_length',
                                        truncation='longest_first', max_length=self.args.max_length)
            if self.args.model_type == 'zen':
                inputs['input_ngram_ids'] = inputs['input_ngram_ids'].to(self.args.device)
                inputs['ngram_position_matrix'] = inputs['ngram_position_matrix'].to(self.args.device)
                inputs['ngram_attention_mask'] = inputs['ngram_attention_mask'].to(self.args.device)
                inputs['ngram_token_type_ids'] = inputs['ngram_token_type_ids'].to(self.args.device)

            inputs['input_ids'] = inputs['input_ids'].to(self.args.device)
            inputs['attention_mask'] = inputs['attention_mask'].to(self.args.device)
            inputs['token_type_ids'] = inputs['token_type_ids'].to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if self.args.model_type == 'zen':
                    logits = outputs
                else:
                    logits = outputs[0]

            if preds is None:
                preds = logits.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)

            pbar(step=step, info="")
        preds = np.argmax(preds, axis=1)
        qic_commit_prediction(dataset=test_dataset, preds=preds, output_dir=args.result_output_dir,
                              id2label=self.data_processor.id2label)

        return preds

    # 保存模型
    def _save_checkpoint(self, model, step):
        output_dir = os.path.join(self.args.output_dir, 'checkpoint-{}'.format(step))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        if self.args.model_type == 'zen':
            save_zen_model(output_dir, model=model, tokenizer=self.tokenizer,
                           ngram_dict=self.ngram_dict, args=self.args)
        else:
            model.save_pretrained(output_dir)
            torch.save(self.args, os.path.join(output_dir, 'training_args.bin')) # 保存训练参数
            self.tokenizer.save_vocabulary(save_directory=output_dir)
        self.logger.info('Saving models checkpoint to %s', output_dir)

    def _save_best_checkpoint(self, best_step): # 保存最佳的模型
        model = self.model_class.from_pretrained(os.path.join(self.args.output_dir, f'checkpoint-{best_step}'),
                                                 num_labels=self.data_processor.num_labels)

        if self.args.model_type == 'zen':
            save_zen_model(self.args.output_dir, model=model, tokenizer=self.tokenizer,
                           ngram_dict=self.ngram_dict, args=self.args)
        else:
            model.save_pretrained(self.args.output_dir)
            torch.save(self.args, os.path.join(self.args.output_dir, 'training_args.bin'))
            self.tokenizer.save_vocabulary(save_directory=self.args.output_dir)
        self.logger.info('Saving models checkpoint to %s', self.args.output_dir)

七、模型评估

# 计算准确率
def simple_accuracy(preds, labels):
    return (preds == labels).mean()

def qic_metric(preds, labels):
    return simple_accuracy(preds, labels)
def qic_commit_prediction(dataset, preds, output_dir, id2label):
    text1 = dataset.text
    label = preds
    ids = dataset.ids

    pred_result = []
    for item in zip(ids, text1, label):
        tmp_dict = {'id': item[0], 'query': item[1],
                    'label': id2label[item[2]]}
        pred_result.append(tmp_dict)
    with open(os.path.join(output_dir, 'KUAKE-QIC_test.json'), 'w', encoding='utf-8') as f:
        f.write(json.dumps(pred_result, indent=2, ensure_ascii=False))

八、运行程序

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/557679.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

英国 VM600 CPUR2 机架控制器和通信接口卡

英国 VM600 CPUR2 机架控制器和通信接口卡VM600 CPUR2/IOCR2机架控制器和通信接口卡对,支持Modbus TCP和PROFIBUS DP使用以太网连接到运行VM600 MPSx和VibroSight软件的计算机,对VM600机架中的保护卡(MPC4)进行“一次性”配置管理对通过现场总线共享的数…

基于 Python 长时间序列遥感数据处理及在全球变化、物候提取、植被变绿与固碳分析、生物量估算与趋势分析等领域中的应用

植被是陆地生态系统中最重要的组分之一,也是对气候变化最敏感的组分,其在全球变化过程中起着重要作用,能够指示自然环境中的大气、水、土壤等成分的变化,其年际和季节性变化可以作为地球气候变化的重要指标。此外,由于…

【CANN训练营0基础赢满分秘籍】 应用开发深入讲解→端到端案例

1 样例调试 1.1 日志文件 运行应用程序后,若出现报错或异常,需录取日志进一步定位问题。日志文件的默认目录为$HOME/ascend/log。 可通过环境变量指定日志文件的落盘路径 export ASCEND_PROCESS_LOG_PATH/$HOME/xxx但需要确保该目录为任意有读写权限…

文档图像智能分析与处理:CCIG技术论坛的思考与展望

文档图像智能分析与处理:CCIG技术论坛的思考与展望 文档识别与理解的发展趋势视觉-语言预训练模型在文档处理中的应用篡改文本图像的生成与检测的研究进展华为云OCR技术的进展与行业实践智能文档处理技术的应用与挑战文档图像预处理的整体架构弯曲矫正摩尔纹去除版面…

【Linux】普通用户无法使用sudo指令的方法

​ ​📝个人主页:Sherry的成长之路 🏠学习社区:Sherry的成长之路(个人社区) 📖专栏链接:Linux 🎯长路漫漫浩浩,万事皆有期待 上一篇博客:【Linux】…

计算机视觉的应用6-利用VGG模型做毕加索风格图像迁移

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用5-利用VGG模型做毕加索风格图像迁移,本文将利用VGG模型实现毕加索风格图像迁移的方法。首先,我们将简要说明图像风格迁移的原理,然后使用PyTorch框架&#xff0c…

chatgpt赋能Python-python_fig

Python中的fig:简介和应用 什么是fig? fig是Python中一个高效且易用的图形库,它支持大量的图像绘制功能,包括2D图形绘制、曲线和图像处理,以及3D图形和动画绘制等应用。fig可以在多个平台上运行,包括Wind…

客户体验|审美体验与体验管理

Guofu 第 93⭐️ 篇原创文章分享 (点击👆🏻上方卡片关注我,加⭐️星标⭐️~) 🚏 写在前面 伽达默尔说:“如果某个东西被经历过,而且它的经历存在还获得一种使自身继续存在意义的特征…

chatgpt赋能Python-python_har

Python HAR:一种高效的网络监测工具 Python HAR(HTTP Archive)是一个用于监测网络资源的强大工具,它能够记录网络请求、响应和资源加载的细节信息,并以可视化和格式化的方式呈现出来。Python HAR的应用范围广泛&#…

单模光纤二维模场分布的MATLAB仿真

在上一篇文章中,我们介绍了单模光纤的一维模场分布,能看出沿着径向的光场分布情况,并分析能量的分布 这一篇中,我们绘制光纤横截面上的二维光场分布:代码如下: clear close all V 2.4000; U 1.6453; W …

C4D R26 渲染学习笔记(1):C4D版本选择和基础知识(更新中)

C4D版本知识 C4D通过R来进行版本区分,现在2023年5月22日最新版的是R26。说一下特殊版本。 C4D版本介绍特点R19OC快乐版3.07最高版本,OC是C4D最具性价比的渲染器,OC学习成本低,渲染速度快,但是注意OC 3.07只支持10系N…

如何提取微信公众号的链接?非常简单!

今天在公众号里面想要复制公众号链接,用于小程序里面引导用户关注,因为小程序里面的关注公众号只能是扫码小程序才能使用,想起以前使用的原始链接跳转方法,就想试一试,结果公众号后台居然没有链接可以复制了&#xff0…

代码随想录算法训练营day49 | 121. 买卖股票的最佳时机,122.买卖股票的最佳时机II

代码随想录算法训练营day49 | 121. 买卖股票的最佳时机,122.买卖股票的最佳时机II 121. 买卖股票的最佳时机解法一:动态规划解法二:贪心算法 122.买卖股票的最佳时机II解法一:动态规划解法二:贪心算法 121. 买卖股票的…

数据要素流通使用的安全风险分析及应对策略

数据要素流通使用的安全风险分析及应对策略 刘业政1,2, 宗兰芳1, 金斗1,袁昆1,2 1 合肥工业大学管理学院,安徽 合肥 230009 2 大数据流通与交易技术国家工程实验室,上海 201203 摘要:系统地分析了数据要素流通使用过程中存在的安全…

直播预告 | 医疗人工智能入行经验分享

(本文阅读时间:2 分钟) 或许大家从不同程度上听说或使用过智能导诊机器人、语音电子病历或是智能问诊?这些高端大气又便利的产物都是人工智能在医疗健康领域的重要应用场景产品及服务。 “AI医疗”是人工智能技术赋能医疗健康产业…

综述 | 基于 Transformer 网络的多模态学习

关注公众号,发现CV技术之美 Transformer 网络结构作为一种性能卓越的神经网络学习器,已经在各类机器学习问题中取得了巨大的成功。伴随着近年来多模态应用和多模态大数据的蓬勃发展,基于Transformer 网络的多模态学习已经成为了人工智能领域的…

chatgpt赋能Python-python_for_loop

Python For Loop: 了解循环结构控制的重要性 在Python编程中,循环结构控制是必备技能之一。它允许程序员重复执行指定的代码块,而不需要手动多次输入。Python提供了几种类型的循环结构,其中for循环是最常用的之一。我们将在本文中讨论for循环…

( 回溯算法) 332. 重新安排行程 ——【Leetcode每日一题】

❓332. 重新安排行程 难度:困难 给你一份航线列表 tickets ,其中 tickets[i] [fromi, toi] 表示飞机出发和降落的机场地点。请你对该行程进行重新规划排序。 所有这些机票都属于一个从 JFK(肯尼迪国际机场)出发的先生&#xf…

【泛微ecology_oracle】如何把查询到的单列人力资源id合并成多人力资源格式

如何把查询到的单列人力资源id合并成多人力资源格式 在泛微ecology中,单列人力资源id合并成多人力资源的使用场景在泛微ecology中,在数据库里人员姓名存储形式那如何实现人力资源字段合并多人力资源字段呢? 在泛微ecology中,单列人…

Node.js博客项目开发思路笔记

博客项目介绍 1. 目标 开发一个博客系统,具备博客基本功能只开发 server 端,不关心前端 2. 需求 首页、作者页、博客详情页登陆页管理中心、新建页、编辑页 3. 技术方案 数据如何存储 博客 idtitlecontentcreatetimeauthor1标题 1内容 11111112z…