PhysioNet2017分类的代码实现

news2024/12/28 5:33:18

PhysioNet2017数据集介绍可参考文章:https://wendy.blog.csdn.net/article/details/128686196。本文主要介绍利用PhysioNet2017数据集对其进行分类的代码实现。

目录

    • 一、数据集预处理
    • 二、训练
      • 2.1 导入数据集并进行数据裁剪
      • 2.2 划分训练集、验证集和测试集
      • 2.3 设置训练网络和结构
      • 2.4 开始训练
      • 2.5 查看训练结果
    • 三、测试

一、数据集预处理

首先需要进行数据集预处理。

train2017文件夹中存放相应的训练集,其中REFERENCE.csv文件存放分类结果。分类结果有四种,分别是:N(Normal,正常),A(AF,心房颤动),O(Other,其他节律),~(Noisy,噪声记录)

首先需要划分训练集、验证集和测试集:

# 加载数据集,默认80%训练集和20%测试集
def load_physionet(dir_path, test=0.2,vali=0, shuffle=True):
    "return train_X, train_y, test_X, test_y, valid_X, valid_y"
    if dir_path[-1]!='/': dir_path = dir_path+'/'
    ref = pd.read_csv(dir_path+'REFERENCE.csv',header=None) # 分类结果
    label_id = {'N':0, 'A':1, 'O':2, '~':3 }#Normal, AF, Other, Noisy
    X = []
    y = []
    test_X = None
    test_y = None
    valid_X = None
    valid_y = None
    
    for index, row in ref.iterrows():
        file_prefix = row[0]
        mat_file = dir_path+file_prefix+'.mat'
        hea_file = dir_path+file_prefix+'.hea'
        data = loadmat(mat_file)['val']

        data = data.squeeze()
        data = np.nan_to_num(data)
        data = data-np.mean(data)
        data = data/np.std(data)

        
        X.append( data )
        y.append( label_id[row[1]] )
    data_n = len(y)
    print(data_n)

    X = np.array(X)
    y = np.array(y)
        
    if shuffle:
        shuffle_idx = list(range(data_n))
        random.shuffle(shuffle_idx)
        X = X[shuffle_idx]
        y = y[shuffle_idx]
   
    valid_n = int(vali*data_n)  
    test_n = int(test*data_n)
    assert (valid_n+test_n <= data_n) , "Dataset has no enough samples!"

    if vali>0:
        valid_X = X[0:valid_n]
        valid_y = y[0:valid_n]
        
    if test>0:
        test_X = X[valid_n: valid_n+test_n]
        test_y = y[valid_n: valid_n+test_n]
    
    if vali>0 or test>0:
        X = X[valid_n+test_n: ]
        y = y[valid_n+test_n: ]
        
    #print('Train: %d, Test: %d, Validation: %d   (%s)'%((data_n-valid_n-test_n), test_n, valid_n, 'shuffled' if shuffle else 'unshuffled'))
    return np.squeeze(X), np.squeeze(y), np.squeeze(test_X), np.squeeze(test_y), np.squeeze(valid_X), np.squeeze(valid_y)

加载数据集并将其保存为mat文件:

def merge_data(dir_path, test=0.2, train_file='train',test_file='test',shuffle=True):
    train_X, train_y, test_X, test_y, _, _ = load_physionet(dir_path=dir_path, test=test, vali=0, shuffle=True) # 划分训练集、验证集和测试集
    # 数据集8528个记录  8528*0.8=6823,8528*0.2=1705
    train_data = {'data': train_X, 'label':train_y} # 6823
    test_data = {'data': test_X, 'label':test_y}    # 1705
    # 保存训练集和测试集为mat文件
    savemat(train_file,train_data)
    savemat(test_file, test_data)
    
    print("[!] Train set saved as %s"%(train_file))
    print("[!] Test set saved as %s"%(test_file))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir',type=str,default='training2017',help='the directory of dataset')
    parser.add_argument('--test_set',type=float,default=0.2,help='The percentage of test set')
    args = parser.parse_args()

    merge_data(args.dir, test=args.test_set)

if __name__=='__main__':
    main()

运行之后将PhysioNet2017心电图数据集保存为train.mat和test.mat。
在这里插入图片描述

二、训练

2.1 导入数据集并进行数据裁剪

时序数据都需要进行相应的数据裁剪。裁剪函数如下:

def cut_and_pad(X, cut_size):
    n = len(X)
    X_cut = np.zeros(shape=(n, cut_size))   # (6823,300*30)
    for i in range(n):
        data_len = X[i].squeeze().shape[0]  # 每个数据的长度
        # cut if too long / padd if too short
        X_cut[i, :min(cut_size, data_len)] = X[i][0,  :min(cut_size, data_len)] # 每个长度裁剪为cut_size=9000个
    return X_cut

首先需要将处理后的数据集导入并进行数据裁剪。
训练集的数据尺寸为:(1, 6823);训练集的标签尺寸为:(1, 6823);【总数据量为8528个数据,训练集数据占比80%,即8528*80%=6823】
加载训练集train.mat,进行数据裁剪,裁剪长度为300x30=9000,即前9000个数据。代码如下:

training_set = loadmat('train.mat') # 加载训练集
X = training_set['data'][0]
y = training_set['label'][0].astype('int32')

#cut_size_start = 300 * 3
cut_size = 300 * 30

X = cut_and_pad(X, cut_size) 

裁剪后可以查看第一个数据的图像:
代码如下:

import matplotlib.pyplot as plt
plt.plot(range(cut_size),X[0])
plt.show()

效果图如下:
在这里插入图片描述

2.2 划分训练集、验证集和测试集

首先需要判断是否进行k折交叉验证,若进行k折交叉验证,下界为0上界为5(5折);若不进行k折交叉验证则下界为0上界为1(默认不进行交叉验证)。

# k-fold / train
if args.k_folder:
    low_border = 0
    high_border = 5
    F1_valid = np.zeros(5)
else:
    low_border = 0
    high_border = 1

然后利用get_sub_set函数根据是否进行交叉验证划分训练集和验证集,90%为训练集,10%为验证集。

# 划分训练集和验证集
def get_sub_set(X, y, k, K_folder_or_not):
    if not K_folder_or_not:     # False
        k_dataset_len = int(len(X) * 0.9)   # 6823*0.9=6140
        train_X = X[ : k_dataset_len]   # 6140
        train_y = y[ : k_dataset_len]
        valid_X = X[ k_dataset_len:]    # 683
        valid_y = y[ k_dataset_len:]
    else:
        k_dataset_len = int(len(X) / 5)
        if k == 0:
            valid_X = X[ : k_dataset_len ]
            valid_y = y[ : k_dataset_len ]
            train_X = X[ k_dataset_len :]
            train_y = y[ k_dataset_len :]
        else:
            print(k*k_dataset_len)
            valid_X = X[ k*k_dataset_len : (k+1)*k_dataset_len ]
            valid_y = y[ k*k_dataset_len : (k+1)*k_dataset_len ]
            train_X = np.concatenate((X[ : k*k_dataset_len] , X[(k+1)*k_dataset_len: ]), axis=0)
            train_y = np.concatenate((y[ : k*k_dataset_len] , y[(k+1)*k_dataset_len: ]), axis=0)
    return train_X, train_y, valid_X, valid_y

输出训练集长度和验证集长度查看信息。
在这里插入图片描述

2.3 设置训练网络和结构

网络架构利用ResNet实现,损失函数使用交叉熵损失函数softmax_cross_entropy,优化器利用Adam优化器实现。

加载模型时,如果有已经训练好的模型,则恢复模型:Model restored from checkpoints;否则,重新训练模型:Restore failed, training new model!

2.4 开始训练

开始训练代码如下:

    # 开始训练
    while True:
        total_loss = []
        ep = ep + 1
        for itr in range(0,len(train_X),batch_size):
            # prepare data batch
            if itr+batch_size>=len(train_X):
                cat_n = itr+batch_size-len(train_X)
                cat_idx = random.sample(range(len(train_X)),cat_n)
                batch_inputs = np.concatenate((train_X[itr:],train_X[cat_idx]),axis=0)
                batch_labels = np.concatenate((y_onehot[itr:],y_onehot[cat_idx]),axis=0)
            else:
                batch_inputs = train_X[itr:itr+batch_size]        
                batch_labels = y_onehot[itr:itr+batch_size]

            _, summary, cur_loss = sess.run([opt, merge, loss], {data_input: batch_inputs, label_input: batch_labels})
            total_loss.append(cur_loss)
            #if itr % 10==0:
            #    print('   iter %d, loss = %f'%(itr, cur_loss))
            #    saver.save(sess, args.ckpt)
            # 将所有日志写入文件
            summary_writer.add_summary(summary, global_step=ep)  # 将训练过程数据保存在summary中[train_loss]
        print('[*] epoch %d, average loss = %f'%(ep, np.mean(total_loss)))
        if not args.k_folder:
            saver.save(sess, 'checkpoints/model')

        # validation
        if ep % 5 ==0: #and ep!=0:
            err = 0
            n = np.zeros(class_num)
            N = np.zeros(class_num)
            correct = np.zeros(class_num)
            valid_n = len(valid_X)
            for i in range(valid_n):
                res = sess.run([logits], {data_input: valid_X[i].reshape(-1, cut_size,1)})
                # print(valid_y[i])
                # print(res)
                predicts  = np.argmax(res[0],axis=1)
                n[predicts] = n[predicts] + 1   
                N[valid_y[i]] = N[valid_y[i]] + 1
                if predicts[0]!= valid_y[i]:
                    err+=1
                else:
                    correct[predicts] = correct[predicts] + 1
            print("[!] %d validation data, accuracy = %f"%(valid_n, 1.0 * (valid_n - err)/valid_n))
            res = 2.0 * correct / (N + n)
            print("[!] Normal = %f, Af = %f, Other = %f, Noisy = %f" % (res[0], res[1], res[2], res[3]))
            print("[!] F1 accuracy = %f" % np.mean(2.0 * correct / (N + n)))
            if args.k_folder:
                F1_valid[k] = np.mean(res)
        
        if np.mean(total_loss) < 0.2 and ep % 5 == 0:
            # 保存内容
            summary_writer.close()
            # 将total_loss保存为csv
            tl = pd.DataFrame(data=total_loss)
            tl.to_csv('loss.csv')
            break

2.5 查看训练结果

利用tensorboard可以查看训练的loss损失,损失图像如下:
在这里插入图片描述
loss阈值设置为0.2,最后的准确率如下:
在这里插入图片描述

三、测试

训练完成后,开始测试。
首先需要将处理后的测试集导入并进行数据裁剪。
测试集的数据尺寸为:(1, 1705);测试集的标签尺寸为:(1, 1705);【总数据量为8528个数据,测试集数据占比20%,即8528*20%=1705】
加载测试集test.mat,进行数据裁剪,裁剪长度为300x30=9000,即前9000个数据。代码如下:

training_set = loadmat('test.mat')
X = training_set['data'][0]     # (1705,)
y = training_set['label'][0].astype('int32')    # (1705,)

cut_size = 300 * 30
n = len(X)
X_cut = np.zeros(shape=(n, cut_size))
for i in range(n):
    data_len = X[i].squeeze().shape[0]
    X_cut[i, :min(cut_size, data_len)] = X[i][0, :min(cut_size, data_len)]
X = X_cut

然后将数据输入训练好的网络进行测试:

# reconstruct model
test_input = tf.placeholder(dtype='float32',shape=(None,cut_size,1))
res_net = ResNet(test_input, class_num=class_num)

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
sess = tf.Session(config=tf_config)

sess.run(tf.global_variables_initializer())
saver =  tf.train.Saver(tf.global_variables())

# restore model
if os.path.exists(args.check_point_folder + '/'):
    saver.restore(sess, args.check_point_folder + '/model')
    print('Model successfully restore from ' + args.check_point_folder + '/model')
else: print('Restore failed. No model found!')

测试结束后,需要查看测试准确率,F1-score等诸多指标,这里首先需要定义三个变量:

PreCount = np.zeros(class_num)  # 每种类型的预测数量
RealCount = np.zeros(class_num) # 每种类型的数量
CorrectCount = np.zeros(class_num)  # 每种类型预测正确数量

PreCount用于存放每种类型的预测结果,RealCount用于存放每种类型的数量,CorrectCount用于存放每种类型预测正确的数量。

最后查看所有结果,F1-score、Accuracy,Precision,Recall,Time结果如下:(这是loss为0.2时的结果)
在这里插入图片描述


ok,以上便是本文的全部内容了,如果想要获取完整代码,可以参考资源:https://download.csdn.net/download/didi_ya/87444631

如果想重新训练,请删除checkpoints文件夹内所有文件和logs文件夹内所有文件(不要删除logs文件夹)并重新运行train.py程序,若不删除,则继续使用之前模型训练,logs文件夹主要用于存放tensorboard可视化图像,若不删除重新运行程序,可能会重新生成可视化图像,影响效果。188行可以指定最终的loss,如果想精确度高,请将loss尽量调小。tensorflow版本:1.x。(我使用的是tensorflow1.15)
遇到任何问题欢迎私信咨询~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/341753.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言(C语言结构基础使用)

目录 一.结构 1.结构声明 2.初始化结构 3.访问结构成员 4.结构的初始化器 5.定义无结构标记 6.结构数组 7.嵌套结构 8.复合字面量和结构&#xff08;C99&#xff09; 9.伸缩性数组成员 10.伸缩性数组得特殊处理请求 11.匿名结构&#xff08;C11&#xff09; 12.使用结构数组得函…

RiProRiProV2主题美化顶部增加一行导航header导航通知

背景: 有些网站的背景顶部有一行罪行公告,样式不错,希望自己的网站也借鉴过来,本教程将指导如何操作,并调整成自己想要的样式。 比如网友搭的666资源站 xd素材中文网

【C语言必经之路——第11节】初阶指针(2)

五、指针的运算1、指针与整数相加减看一下下面的代码&#xff1a;#include<stdio.h> int my_strlen(char* str) {int count0;while(*str!\0){count;str;//指针加减整数}return count; } int main() {int lenmy_strlen("abcdef");printf("%d\n",len);…

OpenCV实战(10)——积分图像详解

OpenCV实战&#xff08;10&#xff09;——积分图像详解0. 前言1. 积分图像计算2. 自适应阈值2.1 固定阈值的缺陷2.2 使用自适应阈值2.3 其它自适应阈值计算方法2.4 完整代码3. 使用直方图进行视觉跟踪3.1 查找目标对象3.2 完整代码小结系列链接0. 前言 我们知道直方图是通过遍…

方法递归调用

&#x1f3e1;个人主页 &#xff1a; 守夜人st &#x1f680;系列专栏&#xff1a;Java …持续更新中敬请关注… &#x1f649;博主简介&#xff1a;软件工程专业&#xff0c;在校学生&#xff0c;写博客是为了总结回顾一些所学知识点 ✈️推荐一款模拟面试&#xff0c;刷题神器…

【C++设计模式】学习笔记(4):观察者模式 Observer

目录 简介动机(Motivation)模式定义结构(Structure)要点总结笔记结语简介 Hello! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出~ ଘ(੭ˊᵕˋ)੭ 昵称:海轰 标签:程序猿|C++选手|学生 简介:因C语言结识编程,随后转入计算机专业,获得过国家奖学金…

渣土车智能识别检测 yolov5

渣土车智能识别检测通过yolov5网络模型深度学习技术&#xff0c;渣土车智能识别检测对禁止渣土车通行现场画面中含有渣土车时进行自动识别监测&#xff0c;并自动抓拍告警。YOLOv5是一种单阶段目标检测算法&#xff0c;该算法在YOLOv4的基础上添加了一些新的改进思路&#xff0…

【Redis场景3】缓存穿透、击穿问题

场景问题及原因 缓存穿透&#xff1a; 原因&#xff1a;客户端请求的数据在缓存和数据库中不存在&#xff0c;这样缓存永远不会生效&#xff0c;请求全部打入数据库&#xff0c;造成数据库连接异常。 解决思路&#xff1a; 缓存空对象 对于不存在的数据也在Redis建立缓存&a…

spark01-内存数据分区数量个数原理

原始代码如下&#xff1a;val conf: SparkConf new SparkConf().setMaster("local[*]").setAppName("wordcount")val scnew SparkContext(conf)val rdd: RDD[Int] sc.makeRDD(List(1,2,3,4)//将处理的数据保存分区文件rdd.saveAsTextFile("output2&…

分布式数据库(ShardingSphere)

单库单表数据量过大导致的问题与应对传统的将数据集中存储至单一数据节点的解决方案&#xff0c;在容量、性能、可用性和运维成本这三方面已经难于满足互联网的海量数据场景。我们在单库单表数据量超过一定容量水位的情况下&#xff0c;索引树层级增加&#xff0c;磁盘 IO 也很…

数据库(六): MySQL的主从复制和读写分离

文章目录一、为什么要使用主从复制和读写分离二、主从复制的原理三、如何实现主从复制3.1 master配置3.2 slave配置3.3 测试主从复制四、读写分离五、缺点一、为什么要使用主从复制和读写分离 注意到主从复制和读写分离一般是一起使用的。目的很简单&#xff0c;就是提高数据库…

Python:路径之谜(DFS剪枝)

题目描述 小张冒充 X 星球的骑士&#xff0c;进入了一个奇怪的城堡。 城堡里边什么都没有&#xff0c;只有方形石头铺成的地面。 假设城堡地面是 nn 个方格。如下图所示。 按习俗&#xff0c;骑士要从西北角走到东南角。可以横向或纵向移动&#xff0c;但不能斜着走&#xf…

Java类和对象超详细整理,适合新手入门

目录 一、驼峰命名法 二、Java注释 三、转义符 四、Java程序它的基本结构是什么&#xff1f; 五、Java中的类 六、创建类 七、定义main方法 八、执行代码输出语句 九、Java中的对象 十、创建对象 十一、类与对象的关系 一、驼峰命名法 包名&#xff1a;多单词组成所…

常用类详解(二)StringBuffer

StringBuffer类&#xff1a; 基本介绍&#xff1a; java.lang.StringBuffer代表可变的字符序列&#xff0c;可以对字符串内容进行增删 很多方法与String相同&#xff0c;但StringBuffer是可变长度的。 StringBuffer是一个容器。 我们进行查看StringBuffer&#xff0c;如下…

fpga设计中如何防止信号被优化

本文分别对quartus和vivado防止信号被优化的方法进行介绍。 为什么要防止信号被优化 ​ 在FPGA开发调试阶段&#xff0c;经常遇到这样的情况&#xff0c;需要临时添加信号&#xff0c;观察信号变化&#xff0c;用来定位代码中存在的问题&#xff0c;很多时候这些临时添加的信…

sg3_utils arm64 静态编译

需求背景 在进行ufs等scsi device测试时&#xff0c;需要进行power mode切换等测试&#xff0c;因此需要有一个简单地工具集来向scsi device&#xff08;ufs接口&#xff09;发送scsi命令&#xff0c;比如 scsi reset命令等。在网上调研后发现sg3_utils是一个比较全面的工具。…

本地代码提交至gitee仓库

1、新建仓库 新建一个私人访问的仓库 2、创建公钥 点开cmd 输入ssh-keygen -t rsa -C "xxxxxxxxxx.com" 邮箱填入自己使用的即可。 输入完毕后&#xff0c;连按三次enter。 命令就会执行完毕&#xff0c;会出现这个界面 此时已经代表ssh公钥已经创建完毕。 公…

自动驾驶TPM技术杂谈 ———— 摄像头标定

文章目录介绍摄像头内参标定摄像头模型的建立摄像头坐标系与环境坐标系的转换图像坐标系与图像像素坐标系小孔成像与图像物理坐标系环境坐标系与图像像素坐标系的转换摄像头畸变矫正常见内参标定方法平面标定自标定摄像头间外参标定介绍 标定传感器是自动驾驶感知系统中不可缺少…

Springboot集成工作流Activity

介绍 官网&#xff1a;https://www.activiti.org/ 一 、工作流介绍 1.工作流&#xff08;workflow&#xff09; 就是通过计算机对业务流程自动化执行管理&#xff0c;它主要解决的是“使在多个参与这之间按照某种预定义规则自动化进行传递文档、信息或任务的过程&#xff0c…

儿童绘本馆图书借阅租赁知识付费小程序源码交流

1.分类图书 2.书单推荐 4.会员卡次、期限购买 5.借阅时间选择 6.积分签到 7.优惠Q领取 前端uniapp开发 后端thinkphp开发 完全开源 <template> <view class"sp-section sp-index"> <!-- search --> <view class&qu…