BERT+TextCNN实现医疗意图识别项目

news2025/1/13 19:55:29

BERT+TextCNN实现医疗意图识别项目

一、说明

本项目采用医疗意图识别数据集CMID传送门
数据集示例:

{
 "originalText": "间质性肺炎的症状?", 
 "entities": [{"label_type": "疾病和诊断", "start_pos": 0, "end_pos": 5}], 
 "seg_result": ["间质性肺炎", "的", "症状", "?"], 
 "label_4class": ["病症"], 
 "label_36class": ["临床表现"]
}

模型使用BERT、TextCNN实现意图分类

二、BERT模型加载

使用苏建林开发的bert4keras深度学习框架加载BERT模型

from bert4keras.backend import keras,set_gelu
from bert4keras.models import build_transformer_model # 加载BERT的方法
from bert4keras.optimizers import Adam # 优化器
set_gelu('tanh')

1.定义函数加载BERT

def build_bert_model(config_path , checkpoint_path , class_nums) : # config_path配置文件的路径 checkpoint_path预训练路径 class_nums类别的数量
    bert = build_transformer_model(
        config_path = config_path ,
        checkpoint_path = checkpoint_path ,
        model = 'bert' ,
        return_keras_model= False)
    # 在BERT模型输出中抽取[CLS]
    cls_features = keras.layers.Lambda(lambda x:x[:,0],name='cls-token')(bert.model.output) # [:,0]选取输出的第一列,BERT模型的输出中[CLS]在第一个位置 shape = [batch_size ,768]
    all_token_embedding = keras.layers.Lambda(lambda x:x[:,1:-1],name='all-token')(bert.model.output) # 获取第2列至倒数第二列的所有token  shape = [batch_size ,maxlen-2,768] 除去CLS、SEP

    # textcnn抽取特征
    cnn_features = textcnn(all_token_embedding, bert.initializer) # 输入all_token_embedding  shape = [batch_size,cnn_output_dim]
    # 将cls_features 与 cnn_features 进行拼接
    concat_features = keras.layers.concatenate([cls_features,cnn_features] ,axis= -1)

    # 全连接层
    dense = keras.layers.Dense (
        units= 512, # 输出维度
        activation = 'relu' , # 激活函数
        kernel_initializer= bert.initializer # bert权重初始化
    )(concat_features) # 输入

    # 输出
    output = keras.layers.Dense (
        units= class_nums, # 输出类别数量
        activation= 'softmax', # 激活函数 (多分类输出层最常用的激活函数)
        kernel_initializer= bert.initializer # bert权重初始化
    )(dense) # 输入

    model = keras.models.Model(bert.model.input,output) # (bert.model.input输入,output输出)
    print(model.summary())

    return model

2.实现TextCNN

def textcnn(input,kernel_initializer) :
    # 3,4,5
    cnn1 = keras.layers.Conv1D(
        256, # 卷积核数量
        3, # 卷积核大小
        strides= 1, # 步长
        padding= 'same', # 输出与输入维度一致
        activation='relu',  # 激活函数
        kernel_initializer = kernel_initializer # 初始化器
    )(input) # shape = [batch_size ,maxlen-2,256]
    cnn1 = keras.layers.GlobalAvgPool1D()(cnn1) # 全局最大池化操作 shape = [batch_size ,256]

    cnn2 = keras.layers.Conv1D(
        256,  # 卷积核数量
        4,  # 卷积核大小
        strides=1,  # 步长
        padding='same',  # 输出与输入维度一致
        activation='relu',  # 激活函数
        kernel_initializer=kernel_initializer  # 初始化器
    )(input)
    cnn2 = keras.layers.GlobalAvgPool1D()(cnn2)  # 全局最大池化操作 shape = [batch_size ,256]

    cnn3 = keras.layers.Conv1D(
        256,  # 卷积核数量
        5,  # 卷积核大小
        strides=1,  # 步长
        padding='same',  # 输出与输入维度一致
        kernel_initializer=kernel_initializer  # 初始化器
    )(input)
    cnn3 = keras.layers.GlobalAvgPool1D()(cnn3) # 全局最大池化操作 shape = [batch_size ,256]

    # 将三个卷积结果进行拼接
    output = keras.layers.concatenate([cnn1,cnn2,cnn3],
                                   axis= -1)
    output = keras.layers.Dropout(0.2)(output) # 最后接Dropout

    return output

3.程序入口

if __name__ == '__main__':
    config_path = '.\chinese_L-12_H-768_A-12\\bert_config.json'
    checkpoint_path = '.\chinese_L-12_H-768_A-12\\bert_model.ckpt'
    class_nums = 13
    build_bert_model(config_path , checkpoint_path , class_nums)

其中BERT模型文件可以自行在Github中下载,也可私信。
在这里插入图片描述
当程序开始加载模型时,表示运行成功。
切记!运行代码前,检查TensorFlow、bert4keras等第三方库的版本是否一致,否则容易报错!
4.本项目第三方库以及对应的版本

pyahocorasick==1.4.2
requests==2.25.1
gevent==1.4.0
jieba==0.42.1
six==1.15.0
gensim==3.8.3
matplotlib==3.1.3
Flask==1.1.1
numpy==1.16.0
bert4keras==0.9.1
tensorflow==1.14.0
Keras==2.3.1
py2neo==2020.1.1
tqdm==4.42.1
pandas==1.0.1
termcolor==1.1.0
itchat==1.3.10
ahocorasick==0.9
flask_compress==1.9.0
flask_cors==3.0.10
flask_json==0.3.4
GPUtil==1.4.0
pyzmq==22.0.3
scikit_learn==0.24.1

三、数据预处理

抽取CMID.json中的数据,并划分为训练集与测试集
从中选取13个类别作为最终意图分类的标签

定义
病因
预防
临床表现(病症表现)
相关病症
治疗方法
所属科室
传染性
治愈率
禁忌
化验/体检方案
治疗时间
其他

1.抽取数据

def gen_training_data(row_data_path) :
    label_list = [line.strip() for line in open('./dataset/label', 'r' ,encoding='utf8')]
    print(label_list)

    # 映射id,为每一条数据添加id
    label2id = {label : idx for idx, label in enumerate(label_list)}
    data = []
    with open('./dataset/CMID.json','r',encoding='utf8') as f :
        origin_data = f.read()
        origin_data = eval(origin_data)

    label_set = set()
    for item in origin_data :
        text = item['originalText']

        label_class = item['label_4class'][0].strip("'")
        if label_class == '其他' :
            data.append([text , label_class ,label2id[label_class]])
            continue
        label_class = item["label_36class"][0].strip("'") # 所有的意图标签都从label_36class中取出
        label_set.add(label_class)
        if label_class not in label_list:
            continue
        data.append([text, label_class ,label2id[label_class]])
    print(label_set)

    data = pd.DataFrame(data , columns=['text','label_class','label'])
    print(data['label_class'].value_counts())
    data['text_len'] = data['text'].map(lambda x : len(x)) # 序列长度
    print(data['text_len'].describe())
    plt.hist(data['text_len'], bins=30, rwidth= 0.9, density=True)
    plt.show()

    del data['text_len']

    data = data.sample(frac = 1.0)
    # 将数据集拆分为测试集和训练集
    train_num = int(0.9*len(data))
    train , test = data[:train_num],data[train_num:]
    train.to_csv('./dataset/train.csv', index=False)
    test.to_csv('./dataset/test.csv', index = False)

2.加载训练数据集

# 加载训练数据集
def load_data(filename) :
    df = pd.read_csv(filename , header= 0 )
    return df[['text','label']].values

3.数据集信息可视化
在这里插入图片描述
数据样本长度基本上在100以内,此时在BERT模型中可以设置样本最大长度为128.
4.划分的训练集与测试集示例
训练集
在这里插入图片描述
测试集
在这里插入图片描述

四、模型训练

1.定义配置文件以及超参数

# 定义超参数和配置文件
class_nums = 13
maxlen = 128
batch_size = 32

config_path = './chinese_rbt3_L-3_H-768_A-12/bert_config_rbt3.json'
checkpoint_path = './chinese_rbt3_L-3_H-768_A-12/bert_model.ckpt'
dict_path = './chinese_rbt3_L-3_H-768_A-12/vocab.txt'
tokenizer = Tokenizer(dict_path)

2.定义数据生成器,将样本传递到模型中

# 定义数据生成器 将数据传递到模型中
class data_generator(DataGenerator) :
    """
    数据生成器
    """
    def __iter__(self , random = False):
        batch_token_ids , batch_segment_ids , batch_labels = [] , [] , [] # 对于每一个batchsize的训练,包括 token  分隔符segment 标签label三者的序列
        for is_end, (text , label ) in self.sample(random):
            token_ids , segments_ids = tokenizer.encode(text , maxlen=maxlen) # [1,3,2,5,9,12,243,0,0,0]  编码token和分隔符segment序列,按照最大长度进行padding
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segments_ids)
            batch_labels.append([label])

            if len(batch_token_ids) == self.batch_size or is_end :
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids =sequence_padding(batch_segment_ids)
                batch_labels = sequence_padding(batch_labels)
                yield  [batch_token_ids , batch_segment_ids] ,batch_labels
                batch_token_ids,batch_segment_ids,batch_labels = [],[],[]

3.程序入口

if __name__ == '__main__':
    # 加载数据集
    train_data = load_data('./dataset/train.csv')
    test_data = load_data('./dataset/test.csv')
    # 转换数据集
    train_generator = data_generator(train_data,batch_size)
    test_generator = data_generator(test_data,batch_size)

    model = build_bert_model(config_path, checkpoint_path ,class_nums)
    print(model.summary())

    model.compile(
        loss='sparse_categorical_crossentropy', # 离散值损失函数 交叉熵损失
        optimizer=Adam(5e-6),
        metrics=['accuracy']
    )
    earlystop = keras.callbacks.EarlyStopping(
        monitor='var_loss',
        patience= 3,
        verbose=2,
        mode='min'
    )

    bast_model_filepath = './chinese_L-12_H-768_A-12/best_model.weights'
    checkpoint = keras.callbacks.ModelCheckpoint(
        bast_model_filepath ,
        monitor = 'val_loss',
        verbose= 1,
        save_best_only=True,
        mode='min'
    )
    model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=10,
        validation_data=test_generator.forfit(),
        validation_steps=len(test_generator),
        shuffle=True,
        callbacks=[earlystop,checkpoint]
    )

    model.load_weights(bast_model_filepath)
    test_pred = []
    test_true = []
    for x, y in test_generator:
        p = model.predict(x).argmax(axis=1)
        test_pred.extend(p)

    test_true = test_data[:1].tolist()
    print(set(test_true))
    print(set(test_pred))

    target_names = [line.strip() for line in open('label','r',encoding='utf8')]
    print(classification_report(test_true , test_pred ,target_names=target_names))

五、运行

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/491696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt-数据库开发-用户登录、后台管理用户

Qt-数据库开发-用户登录、后台管理用户 [1] Qt-数据库开发-用户登录、后台管理用户1、概述2、实现效果 [2] Qt使用SqlLite实现权限管理初始化数据库创建数据表插入数据可使用结构体对数据信息进行封装数据库查询函数为数据库更新数据函数为删除数据函数为 [3] 测试效果 [1] Qt-…

最详细的静态路由的原理和配置

第四章:静态路由 转发数据包是路由器的最主要功能。路由器转发数据包时需要查找路由表,管理员可以通过手工的方法在路由器中直接配置路由表,这就是静态路由。虽然静态路由不适合于在大的网络中使用,但是由于静态路由简单、路由器…

亚马逊云科技让数十亿的数据清洗、转移和查询只需要10分钟

随着数字经济对经济社会的发展贡献愈渐增多,数字金融作为数字经济的有机组成部分和重要支撑,也正成为金融领域竞争与合作的制高点。但在数据要素推动数字金融高速发展的同时,逐渐复杂的互联网环境与日益增强的金融监管力度,对数字…

FL Studio21最新中文版本下载及详细安装教程

FL Studio21最新中文版本是一款专业的音乐制作软件,软件支持录音、音频剪辑、混音、编曲等众多实用功能,可以让你的电脑化身为专业的录音室,进行音乐的录制和剪辑工作,帮助用户轻松创作出各种优秀的音乐作品。 FL Studio21中文版…

亚马逊云科技携手普华永道,推出健康及生命科学行业出海合规指南

自2022年起,国内医疗健康行业的投融资热度略有降低。但医疗行业投资结构正面临转型,投资逐步呈现全球化布局的趋势,在跨境合作领域仍持续释放活力。同时,随着药品集采、医保谈判持续推进,越来越多的中国健康及生命科学…

基于RK3588+TensorFlow的人工智能跨模态行人重识别方法及应用

摘要: 跨模态行人重识别技术(cm-ReID)旨在可见光、红外等不同模态图像中识别出同一个人,其在人 机协同、万物互联、跨界融合、万物智能的智能系统与装备中有重要应用。提出一种数据增强的跨模态行人 重识别方法,在波长…

【Vue 基础】尚品汇项目-10-Search模块中商品分类与过渡动画

一、商品导航的显示与隐藏 打开“src/componetnts/TypeNav/index.vue”,让商品导航默认为显示 在TypeNav组件挂载完毕时,判断当前的路由是否是“/home”,如果不是“/home”,就将分类导航隐藏 当鼠标移入时 移入时让商品导航显示 …

如何在Windows上轻松安全的将数据从HDD迁移到SSD?

当你打算升级硬盘时,如何将数据从HDD迁移到SSD?你可以使用一款免费的软件将所有数据从一个硬盘克隆到另一个硬盘。 为什么要将数据从HDD迁移到SSD? HDD(机械硬盘)和SSD(固态硬盘)是目前常用…

java 学习日记

今天先搞题目 给你一个points 数组,表示 2D 平面上的一些点,其中 points[i] [xi, yi] 。 连接点 [xi, yi] 和点 [xj, yj] 的费用为它们之间的 曼哈顿距离 :|xi - xj| |yi - yj| ,其中 |val| 表示 val 的绝对值。 请你返回将所…

DS1302芯片介绍

低功耗时钟芯片DS1302可以对年、月、日、时、分、秒进行计时,且具有闰年补偿等多种功能。 DS1302的性能特性: 实时时钟,可对秒、分、时、日、周、月以及带闰年补偿的年进行计数; 用于高速数据暂存的318位RAM; 最少引脚…

Redis --- 持久化、主从

一、Redis持久化 Redis有两种持久化方案: RDB持久化 AOF持久化 1.1、RDB持久化 RDB全称Redis Database Backup file(Redis数据备份文件),也被叫做Redis数据快照。简单来说就是把内存中的所有数据都记录到磁盘中。当Redis实例故…

第 5 章 HBase 优化

5.1 RowKey 设计 一条数据的唯一标识就是 rowkey,那么这条数据存储于哪个分区,取决于 rowkey 处于 哪个一个预分区的区间内,设计 rowkey的主要目的 ,就是让数据均匀的分布于所有的 region 中,在一定程度上防止数据倾斜…

年前无情被裁,我面试大厂的这3个月....

春招接近尾声,即将远去的“金三银四”今年也变成了“铜三铁四”。 大厂不断缩招,不容忽视的疫情影响,加上不断攀升的毕业生人数,各种需要应对的现实问题让整个求职季难上加难。 在这个异常残酷的求职季,很多人的困惑…

阿里系App抓包详细分析

InnerMtopInitTask OpenMtopInitTask ProductMtopInitTask 三个实现分别对应的instanceId为:OPEN、INNER、PRODUCT,咱们主要看InnerMtopInitTask这个实现,分析里面重要的初始化步骤,最后再使用Charles完成抓包。 IMtopInitTas…

发帖引蜘蛛:让你的网站在搜索引擎中的曝光率翻倍!

在当今的数字时代,SEO已成为提高网站曝光率和流量的重要手段。发帖引蜘蛛是一种有效的SEO技术,它可以让您的网站在搜索引擎中的曝光率翻倍,从而为您的业务带来更多的流量和潜在客户。 发帖引蜘蛛是一种简单易学的技术,它需要您在…

SPSS如何进行信度分析之案例实训?

文章目录 0.引言1.信度分析2.多维刻度分析 0.引言 因科研等多场景需要进行绘图处理,笔者对SPSS进行了学习,本文通过《SPSS统计分析从入门到精通》及其配套素材结合网上相关资料进行学习笔记总结,本文对信度分析进行阐述。 1.信度分析 &#…

【ROS】如何让ROS中节点实现数据交换Ⅰ--ROS话题通信

Halo,这里是Ppeua。平时主要更新C语言,C,数据结构算法…感兴趣就关注我吧!你定不会失望。 目录 0.ROS文件系统及常用指令1.话题通信概念2.利用标准消息类型实现话题通信实现(python)2.1发布方实现2.2订阅方实现 3.利用自定义消息类…

[Dubbo] 重要接口与类

文章目录 1.dubbo的整体调用链路2.dubbo的源码整体设计3.重要接口和类 1.dubbo的整体调用链路 消费者通过Interface进行方法调用,统一交由消费者的Proxy处理(Proxy通过ProxyFactory来进行代理对象的创建) Proxy调用Filter模块,做…

搞懂 API ,API 分类全知道

API,即应用程序编程接口,是为了方便应用程序之间的数据和功能交互而设计的一些标准方法。API 的分类可以从多个维度进行,我会对 API 的分类维度进行简单的介绍。 根据使用方式的不同 通常情况下,API 可以分为两种使用方式&#…

【LeetCode】1143. 最长公共子序列

1.问题 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。 一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字符的相对顺序的情况下删除某些字符&#xff0…