从0搭建ECG深度学习网络

news2024/9/27 17:24:47

本篇博客介绍使用Python语言的深度学习网络,从零搭建一个ECG深度学习网络。

任务

本次入门的任务是,筛选出MIT-BIH数据集中注释为[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]的数据作为本次数据集,然后按照8:2的比例划分为训练集,验证集。最后送入RCNN模型进行训练。

1. 数据集介绍

本次使用大名鼎鼎的MIT-BIH Arrhythmia Database数据集。下载地址:https://physionet.org/content/mitdb/1.0.0/

MIT系列有很多数据集,都可以在生理网:https://physionet.org/about/database/ 上找到下载地址。本次使用的MT-BIH心律失常数据库拥有48条心电记录,且每个记录的时长是30分钟。这些记录来自于47名研究对象。这些研究对象包括25名男性和22名女性,其年龄介于23到89岁(其中记录201与202来自于同一个人)。信号的采样率为360赫兹,AD分辨率为11比特。对于每条记录来说,均包含两个通道的信号。第一个通道一般为MLⅡ导联(记录102和104为V5导联);第二个通道一般为V1导联(有些为V2导联或V5导联,其中记录124号为Ⅴ4导联)。为了保持导联的一致性,往往在研究中采用MLⅡ导联。

在生理网:https://physionet.org/about/database/上,我们可以看到数据集更加详细的说明。比如:

MIT-BIH 数据集每个单独病人的说明:https://www.physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm

MIT-BIH 数据集每个单独病人的整个数据以及注释的可视化:https://www.physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm

下载MIT-BIH 数据集之后,我们需要知晓以下几点:

  1. 从100-234不连续号码,一共48个病人。每个病人有三个文件(.dat,.atr,*.hea),包含有两路心电信号,一个注释。
  2. 有专门库读取MIT-BIH 数据集,叫做 wfdb。所以不要担心文件后缀的陌生感。
  3. 对心电图的标注样式如上图,“A"代表心房早搏,”."代表正常。整个数据集标注有40多种符号,表示40多种心拍状态。

2. 提取数据集

提取之前,先安装必要的库wfdb。wfdb详细介绍

pip install wfdb

这个库非常强大,打印数据信息,读取数据,绘制心电波形图,都可以靠它完成。
现在我们的划分步骤是:

  1. 提取出所有心电图数据点,心电图注释点
  2. 筛选出所有心电图注释点中仅为[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]某一类的注释点
  3. 截取心电图数据中标记为[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]某一类的点,在点周围长度为300的数据
  4. 将得到的数据进行维度处理,送入DataLoader()函数,完成模型对数据的认可。

3. 定义模型

本次使用的模型是输入大小为300,3层循环,隐藏层大小50。

'''
模型搭建
'''
class RnnModel(nn.Module):
    def __init__(self):
        super(RnnModel, self).__init__()
        '''
        参数解释:(输入维度,隐藏层维度,网络层数)
        '''
        self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')
        self.linear = nn.Linear(50, 5)

    def forward(self, x):
        r_out, h_state = self.rnn(x)
        output = self.linear(r_out[:,-1,:])  # 将 RNN 层的输出 r_out 在最后一个时间步上的输出(隐藏状态)传递给线性层
        return output

model = RnnModel()

4. 全部训练代码

'''
导入相关包
'''
import wfdb
import pywt
import seaborn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
import torch.utils.data as Data
from torch import nn

'''
加载数据集
'''

# 测试集在数据集中所占的比例
RATIO = 0.2

# 小波去噪预处理
def denoise(data):
    # 小波变换
    coeffs = pywt.wavedec(data=data, wavelet='db5', level=9)
    cA9, cD9, cD8, cD7, cD6, cD5, cD4, cD3, cD2, cD1 = coeffs
    # 阈值去噪
    threshold = (np.median(np.abs(cD1)) / 0.6745) * (np.sqrt(2 * np.log(len(cD1))))
    cD1.fill(0)
    cD2.fill(0)
    for i in range(1, len(coeffs) - 2):
        coeffs[i] = pywt.threshold(coeffs[i], threshold)
    # 小波反变换,获取去噪后的信号
    rdata = pywt.waverec(coeffs=coeffs, wavelet='db5')
    return rdata

# 读取心电数据和对应标签,并对数据进行小波去噪
def getDataSet(number, X_data, Y_data):
    ecgClassSet = ['N', 'A', 'V', 'L', 'R']
    # 读取心电数据记录
    # print("正在读取 " + number + " 号心电数据...")
    # 读取MLII导联的数据
    record = wfdb.rdrecord('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, channel_names=['MLII'])
    data = record.p_signal.flatten()
    rdata = denoise(data=data)
    # 获取心电数据记录中R波的位置和对应的标签
    annotation = wfdb.rdann('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, 'atr')
    Rlocation = annotation.sample
    Rclass = annotation.symbol
    # 去掉前后的不稳定数据
    start = 10
    end = 5
    i = start
    j = len(annotation.symbol) - end
    # 因为只选择NAVLR五种心电类型,所以要选出该条记录中所需要的那些带有特定标签的数据,舍弃其余标签的点
    # X_data在R波前后截取长度为300的数据点
    # Y_data将NAVLR按顺序转换为01234
    while i < j:
        try:
            # Rclass[i] 是标签
            lable = ecgClassSet.index(Rclass[i])  # 这一步就是相当于抛弃了不在ecgClassSet的索引
            # 基于经验值,基于R峰向前取100个点,向后取200个点
            x_train = rdata[Rlocation[i] - 100:Rlocation[i] + 200]
            X_data.append(x_train)
            Y_data.append(lable)
            i += 1
        except ValueError:
            i += 1
    return

# 加载数据集并进行预处理
def loadData():
    numberSet = ['100', '101', '103', '105', '106', '107', '108', '109', '111', '112', '113', '114', '115',
                 '116', '117', '119', '121', '122', '123', '124', '200', '201', '202', '203', '205', '208',
                 '210', '212', '213', '214', '215', '217', '219', '220', '221', '222', '223', '228', '230',
                 '231', '232', '233', '234']
    dataSet = []
    lableSet = []
    for n in numberSet:
        getDataSet(n, dataSet, lableSet)
    # 转numpy数组,打乱顺序
    dataSet = np.array(dataSet).reshape(-1, 300)  # 转化为二维,一行有 300 个,行数需要计算
    lableSet = np.array(lableSet).reshape(-1, 1)  # 转化为二维,一行有   1 个,行数需要计算
    train_ds = np.hstack((dataSet, lableSet))  # 将数据集和标签集水平堆叠在一起,(92192, 300) (92192, 1) (92192, 301)
    # print(dataSet.shape, lableSet.shape, train_ds.shape)  # (92192, 300) (92192, 1) (92192, 301)
    np.random.shuffle(train_ds)
    # 数据集及其标签集
    X = train_ds[:, :300].reshape(-1, 1, 300)  # (92192, 1, 300)
    Y = train_ds[:, 300]  # (92192)

    # 测试集及其标签集
    shuffle_index = np.random.permutation(len(X))  # 生成0-(X-1)的随机索引数组

    # 设定测试集的大小 RATIO是测试集在数据集中所占的比例
    test_length = int(RATIO * len(shuffle_index))
    # 测试集的长度
    test_index = shuffle_index[:test_length]
    # 训练集的长度
    train_index = shuffle_index[test_length:]
    X_test, Y_test = X[test_index], Y[test_index]
    X_train, Y_train = X[train_index], Y[train_index]
    return X_train, Y_train, X_test, Y_test

X_train, Y_train, X_test, Y_test = loadData()

'''
数据处理
'''
train_Data = Data.TensorDataset(torch.Tensor(X_train), torch.Tensor(Y_train)) # 返回结果为一个个元组,每一个元组存放数据和标签
train_loader = Data.DataLoader(dataset=train_Data, batch_size=128)
test_Data = Data.TensorDataset(torch.Tensor(X_test), torch.Tensor(Y_test)) # 返回结果为一个个元组,每一个元组存放数据和标签
test_loader = Data.DataLoader(dataset=test_Data, batch_size=128)


'''
模型搭建
'''
class RnnModel(nn.Module):
    def __init__(self):
        super(RnnModel, self).__init__()
        '''
        参数解释:(输入维度,隐藏层维度,网络层数)
        '''
        self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')
        self.linear = nn.Linear(50, 5)

    def forward(self, x):
        r_out, h_state = self.rnn(x)
        output = self.linear(r_out[:,-1,:])  # 将 RNN 层的输出 r_out 在最后一个时间步上的输出(隐藏状态)传递给线性层
        return output

model = RnnModel()

'''
设置损失函数和参数优化方法
'''
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

'''
模型训练
'''
EPOCHS = 5
for epoch in range(EPOCHS):
    running_loss = 0
    for i, data in enumerate(train_loader):
        inputs, label = data
        y_predict = model(inputs)
        loss = criterion(y_predict, label.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # 预测
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, label = data
            y_pred = model(inputs)
            _, predicted = torch.max(y_pred.data, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

    print(f'Epoch: {epoch + 1}, ACC on test: {correct / total}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/889545.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【QT+ffmpeg】QT+ffmpeg 环境搭建

1.qt下载地址 download.qt.io/archive/ 2. win10sdk 下载 https://developer.microsoft.com/en-us/windows/downloads/windows-sdk/ 安装 debug工具路径 qtcreater会自动识别 调试器选择

微服务-Fegin

在之前我们两服务之间调用的时候用的是restTemplate,但是这个方式调用存在很多的问题 String url "http://userservice/user/" order.getUserId(); 代码可读性差&#xff0c;编码体验不统一参数复杂的url难以维护 所以我们大力推出我们今天的主角--Fegin Feign是…

大数据时代下的GIS:数据驱动的地理智能

在当今信息时代&#xff0c;大数据技术的飞速发展正深刻影响着各个领域&#xff0c;地理信息系统&#xff08;GIS&#xff09;作为其中之一&#xff0c;也在大数据的浪潮下展现出了新的活力和潜力。 大数据为GIS注入了强大的数据能量 过去&#xff0c;地理信息数据的收集和存储…

B站一个月内涨粉57万,情景剧再迎黑马UP主

飞瓜数据&#xff08;B站版&#xff09;【达人排行榜】-【涨粉榜】显示&#xff0c;继上次的【成长榜】第三之后&#xff0c;UP主-显眼宝-在本周登上了涨粉榜第三的位置&#xff0c;单周涨粉就超过22万。 从UP主的粉丝趋势图来看&#xff0c;自7月起UP主账号就多次出现涨粉峰值…

Sealos 国内集群正式上线,可一键运行 LLama2 中文版大模型!

2023 年 7 月 19 日&#xff0c;MetaAI 宣布开源旗下的 LLama2 大模型&#xff0c;Meta 首席科学家、图灵奖得主 Yann LeCun 在推特上表示 Meta 此举可能将改变大模型行业的竞争格局。一夜之间&#xff0c;大模型格局再次发生巨变。 不同于 LLama&#xff0c;LLama2 免费可商用…

【仿写框架之仿写Tomact】三、使用socket监听配置文件中的端口接收HTTP请求并创建线程池处理请求

文章目录 1、自定义配置文件2、使用DOM解析XML文件3、创建Tomcat启动方法&#xff08;解析配置文件、创建线程池、socket循环监听端口&#xff09; 1、自定义配置文件 首先在main文件下创建资源目录resources&#xff1a; 在resource目录下创建server.xml文件&#xff0c;并写…

第二章-自动驾驶卡车-自动驾驶卡车前装量产的要求

1、自动驾驶卡车的特点与挑战 重卡主要运行在相对封闭的高速公路&#xff0c;相较城市道路场景看似更简单。但是&#xff0c;由于重卡特有的物理特性、运行环境和商业运营要求&#xff0c;相较于乘用车的自动驾驶系统&#xff0c;重卡的自动驾驶系统对车辆的感知距离和精度、系…

【操作系统】24王道考研笔记——第一章 计算机系统概述

第一章 计算机系统概述 一、操作系统基本概念 1.1 定义 1.2 特征 并发 &#xff08;并行&#xff1a;指两个或多个事件在同一时刻同时发生&#xff09; 共享 &#xff08;并发性指计算机系统中同时存在中多个运行着的程序&#xff0c;共享性指系统中的资源可供内存中多个并…

Vue实现动态可视化

1、可视化效果&#xff1a; 水晶球环绕地球旋转&#xff1b; 2、实现 <template><div class"container"><div class"header-body"><div class"header-left"></div><div class"header-title">网…

删除了很久的备忘录怎么恢复?快速恢复误删备忘录的方法

习惯于经常使用手机来记录事情的网友&#xff0c;应该都对备忘录软件不陌生&#xff0c;因为我们可以直接在备忘录中记录心情日记、读书笔记、生活琐事以及其他重要的事情&#xff0c;这样就不用担心自己会忘记了。不过也有一些备忘录用户表示自己在整理无效备忘录时&#xff0…

gtsam使用-Pose2 SLAM

Pose2 SLAM Pose2 SLAM是一种昂进行同时定位与地图构建&#xff08;SLAM&#xff09;的一种简单方法是仅仅融合连续机器人姿态之间的相对姿态测量。这种不涉及地标的SLAM变体通常被称为 "Pose SLAM"。 %pip -q install gtbook # also installs latest gtsam pre-re…

安防监控视频汇聚平台EasyCVR视频平台调用iframe地址无法播放的问题解决方案

安防监控视频汇聚平台EasyCVR基于云边端一体化架构&#xff0c;具有强大的数据接入、处理及分发能力&#xff0c;可提供视频监控直播、云端录像、视频云存储、视频集中存储、视频存储磁盘阵列、录像检索与回看、智能告警、平台级联、云台控制、语音对讲、AI算法中台智能分析无缝…

Docker容器:docker基础概述、docker安装、docker网络

文章目录 一.docker容器概述1.什么是容器2. docker与虚拟机的区别2.1 docker虚拟化产品有哪些及其对比2.2 Docker与虚拟机的区别 3.Docker容器的使用场景4.Docker容器的优点5.Docker 的底层运行原理6.namespace的六项隔离7.Docker核心概念 二.Docker安装 及管理1.安装 Docker1.…

[C++] string类的介绍与构造的模拟实现,进来看吧,里面有空调

文章目录 1、string类的出现1.1 C语言中的字符串 2、标准库中的string类2.1 string类 3、string类的常见接口说明及模拟实现3.1 string的常见构造3.2 string的构造函数3.3 string的拷贝构造3.4 string的赋值构造 4、完整代码 1、string类的出现 1.1 C语言中的字符串 C语言中&…

【宝藏系列】嵌入式 C 语言代码优化技巧【超详细版】

【宝藏系列】嵌入式 C 语言代码优化技巧【超详细版】 文章目录 【宝藏系列】嵌入式 C 语言代码优化技巧【超详细版】前言整形数除法和取余数合并除法和取余数通过2的幂次进行除法和取余数取模的一种替代方法使用数组下标全局变量使用别名变量的生命周期分割变量类型局部变量指针…

Lnton羚通关于如何使用OpenCV-Python在直方图中查找显示分析

什么是直方图&#xff1f; 直方图是统计图像中像素亮度或颜色等分布的一种常用工具&#xff0c;几乎所有图像处理的工具都提供了这种工具&#xff0c;X轴表示 0~255&#xff08;刻度大小与Bin设置有关系&#xff09;&#xff0c;Y轴统计个数&#xff08;频率&#xff09;。 【…

VIOOVI:标准的作业规范要求是什么?标准化作业规范怎么写?

本文围绕“标准化作业”展开论述&#xff0c;分享一些关于标准化作业以及标准的作业规范等相关知识。 什么是标准化作业&#xff1f; 标准化作业是一种以人的行为为中心&#xff0c;在一个操作序列中有效地进行生产而没有浪费的操作方法。 标准化作业的前提即&#xff1a;关注…

从零开始打造家装预约咨询小程序

在如今互联网高度发达的时代&#xff0c;家装行业也逐渐意识到了线上渠道的重要性。为了更好地服务客户&#xff0c;提高用户体验&#xff0c;越来越多的家装公司开始寻找合适的小程序制作平台。本文将向大家介绍如何使用第三方制作平台&#xff0c;如乔拓云网&#xff0c;打造…

Android 9.0 Vold挂载流程解析(上)

Android 9.0 Vold挂载流程解析&#xff08;上&#xff09; 前言Android挂载模块整体框架Vold进程main函数详细分析总结 前言 我们分2篇文章来介绍Android 9.0中存储卡的挂载流程&#xff0c;本篇文章先介绍总体的挂载模块、Vold进程的入口main函数的详细分析&#xff0c;有了这…

文心一言最新重磅发布!

8月16日&#xff0c;由深度学习技术及应用国家工程研究中心主办的WAVE SUMMIT深度学习开发者大会2023举办。百度首席技术官、深度学习技术及应用国家工程研究中心主任王海峰以《大语言模型为通用人工智能带来曙光》为题&#xff0c;阐述了大语言模型具备理解、生成、逻辑、记忆…