R2:RNN-心脏病预测

news2024/11/27 2:41:33
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、实验目的:

  1. 本地读取并加载数据。
  2. 了解循环神经网络(RNN)的构建过程
  3. 测试集accuracy到达87%

拔高:测试集accuracy到达89%

二、实验环境:

  • 语言环境:python 3.8
  • 编译器:Jupyter notebook
  • 深度学习环境:TensorFlow

代码流程图:
Inception v3论文:**Rethinking the Inception Architecture for Computer Vision**

三、循环神经网络(RNN)

**循环神经网络(Recurrent Neural Network,RNN)**是一种具有记忆功能的神经网络,主要用于处理序列数据。

RNN 的独特之处在于它的循环结构,能够将上一时刻的信息传递到当前时刻,从而实现对序列数据的动态处理。它包含一个循环单元,这个单元在不同的时间步重复使用,使得网络能够记住过去的信息并影响当前的输出,我们可以认为RNN有一些“记忆”能力。理论上RNN能够利用任意长序列的信息,但是实际中它能记忆的长度是有限的。

随着技术的发展,出现了许多改进的 RNN 变体,如长短期记忆网络(LSTM)门控循环单元(GRU),它们在一定程度上缓解了传统 RNN 的问题,提高了对序列数据的处理能力。

传统神经网络的结构比较简单:输入层 – 隐藏层 – 输出层。

RNN 跟传统神经网络最大的区别在于每次都会将前一次的输出结果,带到下一次的隐藏层中,一起训练。如下图所示:

在这里插入图片描述

这里用一个具体的案例来看看 RNN 是如何工作的:用户说了一句“what time is it?”,我们的神经网络会先将这句话分为五个基本单元(四个单词+一个问号)

在这里插入图片描述

然后,按照顺序将五个基本单元输入RNN网络,先将 “what”作为RNN的输入,得到输出o1
在这里插入图片描述
随后,按照顺序将“time”输入到RNN网络,得到输出o2。

这个过程我们可以看到,输入 “time” 的时候,前面“what” 的输出也会对02的输出产生了影响(隐藏层中有一半是黑色的)。
在这里插入图片描述
以此类推,我们可以看到,前面所有的输入产生的结果都对后续的输出产生了影响(可以看到圆形中包含了前面所有的颜色)
在这里插入图片描述
当神经网络判断意图的时候,只需要最后一层的输出o5,如下图所示:
在这里插入图片描述

四、前期准备

1. 设置GPU

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
    
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2. 导入数据

每个数据的标签含义:

  • age:年龄
  • sex:性别
  • cp:胸痛类型 (4 values)
  • trestbps:静息血压
  • chol:血清胆固醇 (mg/dl)
  • fbs:空腹血糖 > 120 mg/dl
  • restecg:静息心电图结果 (值 0,1 ,2)
  • thalach:达到的最大心率
  • exang:运动诱发的心绞痛
  • oldpeak:相对于静止状态,运动引起的ST段压低
  • slope:运动峰值 ST 段的斜率
  • ca:荧光透视着色的主要血管数量 (0-3)
  • thal:0 = 正常;1 = 固定缺陷;2 = 可逆转的缺陷
  • target:0 = 心脏病发作的几率较小 1 = 心脏病发作的几率更大
import pandas as pd
import numpy as np

df = pd.read_csv("/heart.csv")
df

请添加图片描述

3. 检查数据

在进行数据预处理之前我们还需要对我们的数据进行检查,确保每一个标签内的数据没有空值。

# 检查是否有空值
df.isnull().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64

五、数据预处理

1. 划分训练集与测试集

🍺 测试集与验证集的关系:

  1. 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  2. 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。
  3. 我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集。
from sklearn.preprocessing import StandardScaler 
from sklearn.model_selection import train_test_split


X = df.iloc[:,:-1] 
y = df.iloc[:,-1] 

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1, random_state = 1)

X_train.shape, y_train.shape 
((272, 13), (272,))

2. 数据标准化

这里我们用到了 StandardScaler函数,它的作用是去均值和方差归一化。且是针对每一个特征维度来做的,而不是针对样本。

# 将每一列特征标准化为标准正态分布,注意,标准化是针对每一列而言的
sc      = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test  = sc.transform(X_test)

X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test  = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

六、构建RNN模型

这里我们构建模型要用到tf.keras.layers.SimpleRNN()函数,这个函数的模型如下:

tf.keras.layers.SimpleRNN(
    units,
    activation='tanh',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='orthogonal',
    bias_initializer='zeros',
    kernel_regularizer=None,
    recurrent_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    recurrent_constraint=None,
    bias_constraint=None,
    dropout=0.0,
    recurrent_dropout=0.0,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    unroll=False,
    **kwargs
)

函数参数的官方介绍:
在这里插入图片描述

import tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN

model = Sequential()
model.add(SimpleRNN(128, input_shape= (13,1),return_sequences=True,activation='relu'))
model.add(SimpleRNN(64,return_sequences=True, activation='relu'))
model.add(SimpleRNN(32, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

请添加图片描述

七、编译模型

原教案代码中 metrics=“accuracy” 报错,函数期望接收的是可迭代对象(如列表、元组等),改为 metrics=[‘accuracy’] 后解决问题。

opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(loss='binary_crossentropy',
       optimizer=opt,
       metrics=['accuracy'])

八、训练模型

epochs = 100

history = model.fit(X_train, y_train, 
                    epochs=epochs, 
                    batch_size=128, 
                    validation_data=(X_test, y_test),
                    verbose=1)

请添加图片描述

九、模型评估

import matplotlib.pyplot as plt

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

请添加图片描述

scores = model.evaluate(X_test, y_test, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
compile_metrics: 87.10%

十、总结

从上图结果中我们可以看出:

左图:训练与验证准确率
训练集的准确率(蓝色线):随着训练次数增加,呈现出平稳上升趋势,最终接近0.92左右,说明模型在训练数据上的拟合效果逐渐变好。
验证集的准确率(橙色线):一开始随着训练迭代次数增加,验证准确率也在提升,但在约20次迭代后,准确率趋于平稳,甚至有一些波动,特别在50次之后,表现出明显的下降和上升不稳定现象。

右图:训练与验证损失
训练集损失(蓝色线):损失随着迭代次数逐渐下降,这表明模型在训练集上不断优化,误差减少。
验证集损失(橙色线):最开始也在下降,但在大约20次迭代后开始变得平缓,甚至有轻微的波动。这与验证集准确率下降的现象一致,暗示模型在验证集上的表现没有持续改进。

问题
验证集准确率的波动与下降:尽管训练集的表现不断提升,验证集的表现却在特定迭代后停滞或波动,表明模型可能存在过拟合。训练集上模型表现越来越好,但它无法有效泛化到验证集数据。
验证集损失的波动:验证集损失在下降到一定程度后不再继续下降,甚至开始轻微波动,这进一步表明模型在面对新数据时的泛化能力不足。

改进方法

  1. 增加数据量:收集更多的心脏病相关数据,包括不同类型的病例和特征,以提高模型的泛化能力。更多的数据可以让模型学习到更广泛的模式,减少过拟合的风险。
  2. 数据增强:增加数据的多样性,多模态?这样可以让模型学习到不同角度和形态下的心脏病特征,提高其鲁棒性。
  3. 正则化方法
    • L1 和 L2(权重衰减) 正则化:在模型的损失函数中添加正则项,限制模型参数的大小来减少模型的复杂度,从而防止过拟合。
    • Dropout:在训练过程中随机丢弃一些神经元,增加模型的泛化能力。
  4. Early Stopping:监控验证集的损失或准确率,当验证集的表现不再提升时,提前终止训练,防止模型过度拟合训练数据。
  5. 调整模型结构:尝试不同的模型结构,如使用别的 RNN 变体或结合其他类型的神经网络,以找到更适合心脏病预测任务的模型。

我的研究更多的是关注人工智能+影像组学,以下是一些结合RNN的简单思路和原理:

  • CNN+RNN用于动态医学影像分析
    动态医学影像(如心脏超声、功能性MRI)包含时间序列信息,需要同时考虑空间和时间特征。
    ✅CNN用于特征提取:使用卷积神经网络(CNN)从每一帧图像中提取空间特征。这些特征可以是解剖结构、病灶特征等。
    ✅RNN用于时间序列分析:将CNN提取的特征序列输入到循环神经网络(RNN)中(如LSTM或GRU),以捕捉时间动态变化。这有助于分析心脏运动、血流动态等。
  • 多模态影像融合分析
    ✅ CNN用于单模态特征提取:为每种影像模态(如CT、MRI、PET)设计专门的CNN模型,提取特定模态的特征。
    ✅ 特征融合策略:使用全连接层或自注意机制将不同模态的特征进行融合,构建一个统一的特征表示。
    ✅ RNN或Transformer用于决策:在融合特征的基础上,使用RNN或Transformer进行最终的分类或回归任务
  • AI驱动的个性化治疗方案
    ✅ CNN用于影像特征提取:从患者的影像数据中提取潜在的病理特征。
    ✅ RNN用于序列数据分析:分析患者的时间序列数据(如治疗历史、病程进展)。
    ✅ 融合多源数据:将影像特征与基因组信息、电子健康记录(EHR)结合,使用多模态融合技术进行综合分析。
    .

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2207124.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

腾讯云Android 相关

集成遇到异常怎么办? 您可以使用 armeabi 和 armeabi-v7a 架构。 如上图所示,请在app的 build.gradle 中指定 abiFilters 为“armeabi”。 功能模块升级后,短视频 SDK 的功能不能使用? 1. 如果使用的是 androidstudio&#xff0…

2024Selenium自动化常见问题!

"NoSuchElementException"异常: 确保使用了正确的选择器来定位元素。可以使用id、class、XPath或CSS选择器等。 可以尝试使用find_elements方法来查找元素列表,并检查列表的长度来判断元素是否存在。 使用显式等待(WebDriverWait…

考研编程:10.11 回文数 水仙花 生成一定范围内的随机数 求二叉树宽度

回文数 #include <stdio.h>int main(){int a,b,c0,sum;scanf("%d",&a);ba;while(b!0){c b%10 c*10;b b/10;}if(ca){printf("yes");}return 0; } 水仙花 #include <stdio.h> #include <math.h> int main(){int a,b,c0,sum;scan…

内嵌服务器Netty Http Server

内嵌式服务器不需要我们单独部署&#xff0c;列如SpringBoot默认内嵌服务器Tomcat,它运行在服务内部。使用Netty 编写一个 Http 服务器的程序&#xff0c;类似SpringMvc处理http请求那样。举例&#xff1a;xxl-job项目的核心包没有SpringMvc的Controller层&#xff0c;客户端却…

css多层嵌套折叠

<!DOCTYPE html> <html lang"zh"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>美观的纯 CSS 折叠列表</title><style>b…

如何使用Python爬虫处理JavaScript动态加载的内容?

JavaScript已经成为构建动态网页内容的关键技术。这种动态性为用户带来了丰富的交互体验&#xff0c;但同时也给爬虫开发者带来了挑战。传统的基于静态内容的爬虫技术往往无法直接获取这些动态加载的数据。本文将探讨如何使用Python来处理JavaScript动态加载的内容&#xff0c;…

值类型和引用类型的使用

using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace ConsoleApp1 {class Program{static void Main(string[] args){/****值类型****/bool test;//必须赋值,否则报错test true;Console.WriteLin…

修改svg图片颜色(结合sass)

1.下载sass npm install sass -gnpm install --save-dev sass-loader 我使用的版本 2.使用步骤 1.新建style文件夹&#xff0c;以及新建variable.scss&#xff0c;mixin.scss&#xff0c;main.scss 2.variable.scss $color_1:#50E3C2; $color_2:#FFF; 3.mixin.scss char…

大规模出海!新松移动机器人大批量进驻欧洲本土新能源市场

秋日的沈阳&#xff0c;天空高远而湛蓝。曙光下的新松智慧园&#xff0c;百余台移动机器人在车间内整齐列阵、蓄势待发&#xff0c;等待着最后的检验与封装&#xff0c;即将横跨千山万水远赴欧洲大地&#xff0c;开启中国移动机器人大规模进驻欧洲本土新能源市场的崭新篇章&…

2022年黄河流域旅游资源空间分布数据(shp)

数据介绍 黄河是中华民族的母亲河。黄河流域旅游资源丰富且极具特色。黄河流域旅游资源空间分布数据是黄河流域旅游资源开发与决策的基础。本数据集以县&#xff08;区&#xff09;域行政边界为单元、以国家旅游资源分类标准为依据&#xff0c;收集整理了黄河流域各县&#xf…

STM32-DMA直接存储器存取

一、概述 DMA&#xff08;Direct Memory Access&#xff09;直接存储器存取 DMA可以提供外设和存储器或者存储器和存储器之间的高速数据传输&#xff0c;无须CPU干预&#xff0c;节省了CPU的资源12个独立可配置的通道&#xff1a;DMA1(7个通道)&#xff0c;DMA2&#xff08;5…

【自动化】Java Access Bridge 使用说明

【自动化】Java Access Bridge 使用说明 Java Access Bridge是一项在Microsoft Windows动态链接库(DLL)中公开Java Accessibility API的技术,使实现Java Accessibility API的 Java应用程序对Microsoft Windows系统上的辅助技术可见。 开启jab服务 1 、首先获取java版本信…

【自用视频笔记】25计算机基础综合408大纲新增考点 多处理机调度

文章目录 多处理机调度指标及性能多处理器分类&#xff1a;性能指标 调度的评价指标进程分配方式&#xff1a;静态分配和动态分配、进程的调度&#xff1a;通常采用FCFS 线程调度方式多处理机调度评价指标 25计算机基础综合 多处理机原视频1 多处理机原视频2 多处理机调度 先…

电子产品做高温老化性能测试可行性方案

1.1引言 1.2背景 1.3目的 2.系统概述 2.1 系统架构 2.2 功能模块 3. 接口 3.1硬件接口 3.3. 通信接口 3.4 软件接口 3.5 数据存储和处理 4. 功能需求 4.1 数据采集 4.1.1 采集和监控数据 4.2 实时监测和显示 4.2.1 实时显示电流电压曲线图 4.3…

打包上线不确定接口IP以及端口 如何处理

前言 本文主要讲述如何在vue项目打包后动态修改请求服务器接口的ip和端口的修改&#xff0c;其他的配置可参考此方法进行。 在Vue项目中一般都将配置文件写在 .env.development / .env.production 文件当中&#xff0c;但是如果仅仅是因为修改配置文件又重新打包一次就会很繁琐…

如何将数据输入到神经网络中(How to Input Data into a Neural Network)

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:Linux运维老纪的首页…

k8s中pod管理

一、Pod的基本概念 定义&#xff1a;Pod是Kubernetes中可以创建和管理的最小单元&#xff0c;是资源对象模型中由用户创建或部署的最小资源对象模型。 组成&#xff1a;Pod由一个或多个容器组成&#xff0c;这些容器共享网络、存储等资源&#xff0c;并作为一个整体被调度和管…

PPT电脑怎么录屏?多达4种录屏软件录制 PPT 指南

在日常的工作、学习以及知识分享领域&#xff0c;PPT 扮演着不可或缺的角色。而将 PPT 内容录制下来更是有诸多用途&#xff0c;比如教师制作线上教学课件、职场人士分享项目方案、培训师准备培训素材等。要想获得优质的 PPT 录制效果&#xff0c;合适的录屏软件必不可少。接下…

5G路由器工业物联网PLC模块通讯应用

工业物联网在计算机互联网的基础上&#xff0c;利用传感技术、数据通信等技术&#xff0c;构建一个覆盖世界万物的“Internet of Things”&#xff0c;其实质是利用传感技术&#xff0c;通过联网实现物的自动识别和信息的互联与共享。5G工业路由器连接现场传感设备等实施数据采…

微知-NVIDIA Bluefield DPU的E-Series和P-Series区别?(功率75vs150,是否需要ATX额外供电)

背景 本文介绍了NVIDIA的Bluefield的产品分裂E和P系列&#xff0c;了解这部分&#xff0c;可以快速获取CPU主频&#xff0c;还能根据产品型号字母快速获取数据。 区别 E 系列 DPU&#xff1a;通过 PCIe x16 接口提供最大 75W 的系统电源。 P 系列 DPU&#xff1a;通过 PCIe …