第R1周: RNN-心脏病预测

news2025/1/22 21:12:13
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、什么是RNN

RNN(Recurrent Neural Network)是一种特殊的神经网络,它能够处理序列数据,如时间序列、文本序列等。RNN与传统神经网络的主要区别在于其结构特点和计算机制。

  1. 结构特点
    • RNN由一系列相互连接的节点组成,每个节点代表一个状态。
    • 这些节点通过反馈回路连接在一起,使得网络能够记住之前的状态信息。
  2. 计算机制
    • RNN在计算当前状态时,不仅考虑当前输入,还考虑上一个状态的信息。
    • 这种机制使得RNN能够捕捉到数据的长期依赖性,适用于处理时间序列数据。
  3. 应用场景
    • RNN在自然语言处理领域有着广泛的应用,如机器翻译、文本生成等。
    • 此外,RNN还可以用于预测时间序列数据,如股票价格、天气预测等。
  4. 局限性
    • RNN也存在一些局限性,例如对长序列的处理能力有限,以及可能出现的梯度消失问题。
  5. 改进方法
    • 为了解决这些问题,研究者提出了各种改进方法,如LSTM(Long Short-Term Memory)、GRU(Gated Recurrent Unit)等,它们在RNN的基础上增加了更多的结构和机制,以提高性能和稳定性。

二、前期准备

1.设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0]
    tf.config.experimental.set_memory_growth(gpu0, True)
    tf.config.set_visible_devices([gpu0], "GPU")
gpus

输出
[PhysicalDevice(name=‘/physical_device:GPU:0’, device_type=‘GPU’)]

2.导入数据

这里使用的是训练营提供心脏病数据。(这是数据各个标签对应的意思)
age: 1)年龄
sex: 2)性别
cp: 3)胸痛类型(4 values)
trestbps: 4)静息血压
chol: 5)血清胆固醇(mg/dl)
fbs: 6)空腹血糖>120 mg/dl
restecg: 7)静息心电图结果(值0,1,2)
thalach: 8)达到的最大心率
exang: 9)运动诱发的心绞痛
oldpeak: 10)相对于静止状态,运动引起的ST段压低
slope: 11)运动峰值ST段的斜率
ca: 12)荧光透视着色的主要血管数量(0-3)
thal: 13)0=正常;1=固定缺陷;2=可逆转的缺陷
target: 14)0=心脏病发作的几率较小1=心脏病发作的几率更大

import pandas as pd
import numpy as np
df = pd.read_csv("heart.csv")
df

输出
在这里插入图片描述

3.检查数据

df.isnull().sum() 是在Python的pandas库中用于数据清洗和预处理的一个函数组合,它的作用和用法如下:

作用:

  • df.isnull(): 这个方法用于检测DataFrame中的缺失值,它会返回一个布尔型的DataFrame,其中每个单元格都会被判断是否为空(NaN),如果是空值,则对应的单元格为True,否则为False
  • .sum(): 这个方法用于对DataFrame中的数据进行求和。当应用于布尔型DataFrame时,它会将True视为1False视为0,然后对每一列进行求和,从而得到每列中True(即缺失值)的总数。

用法:

  • df: 这是你的DataFrame对象,通常是经过pandas读取数据后得到的。
  • df.isnull(): 调用这个方法后,会返回一个与原DataFrame形状相同的布尔型DataFrame,显示哪些元素是缺失值。
  • df.isnull().sum(): 在df.isnull()的基础上调用.sum()方法,会对每一列的缺失值进行计数,并返回一个Series对象,索引为原DataFrame的列名,值为对应列的缺失值数量。

示例:

假设有一个DataFrame df 如下:

import pandas as pd
import numpy as np
df = pd.DataFrame({
    'A': [1, 2, np.nan],
    'B': [4, np.nan, 6],
    'C': [7, 8, 9]
})
# 使用 df.isnull().sum() 检查每列的缺失值数量
missing_values = df.isnull().sum()
print(missing_values)

输出将会是:

A    1
B    1
C    0
dtype: int64

这表示列’A’有一个缺失值,列’B’也有一个缺失值,而列’C’没有缺失值。通过这种方式,可以快速了解数据集中哪些列存在缺失值,以及缺失值的数量。

df.isnull().sum()  

输出
age 0
sex 0
cp 0
trestbps 0
chol 0
fbs 0
restecg 0
thalach 0
exang 0
oldpeak 0
slope 0
ca 0
thal 0
target 0
dtype: int64

三、数据预处理

1.划分测试集与训练集

这一段代码的作用是进行数据预处理和分割,为机器学习模型的训练和测试做准备。

  1. 导入必要的库
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import train_test_split
    
    • StandardScaler:这是scikit-learn库中的一个类,用于数据标准化。标准化是一个将数据转换为具有零均值和单位方差的过程。
    • train_test_split:这是scikit-learn库中的一个函数,用于将数据集分割成训练集和测试集。
  2. 数据分割
    X = df.iloc[:, :-1]
    y = df.iloc[:, -1]
    
    • df.iloc[:, :-1]:这里使用pandas的.iloc索引器来选择DataFrame df 的所有行和除了最后一列的所有列,这些被选中的列作为特征集(X)。
    • df.iloc[:, -1]:同样使用.iloc索引器,这次选择DataFrame df 的所有行和最后一列,这一列通常是目标变量或标签(y)。
  3. 分割数据集
    x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)
    
    • train_test_split(X, y, test_size=0.1, random_state=1):这个函数接受特征集和目标变量,并将它们分割成训练集和测试集。
      • X:特征集。
      • y:目标变量。
      • test_size=0.1:指定测试集的大小为原始数据集的10%。
      • random_state=1:设置随机数生成器的种子,以确保每次分割得到的数据集都是一样的,这对于可重复的实验是重要的。
        综上所述,这段代码首先导入了用于数据预处理和分割的库,然后从DataFrame中提取了特征和目标变量,最后使用train_test_split函数将数据分割成训练集和测试集,以便后续可以进行模型训练和评估。
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

X = df.iloc[:, :-1]
y = df.iloc[:, -1]

x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)
x_train.shape,y_train.shape

输出
((272, 13), (272,))

2.标准化

这段代码的作用是对特征数据进行标准化处理,并将其重塑为适合输入到某些机器学习模型(特别是深度学习模型)的格式。

  1. 创建标准化器实例
    sc=StandardScaler()
    
    • StandardScaler():这是scikit-learn库中的一个类,用于标准化特征,通过去除均值(将数据中心的均值变为0)和缩放至单位方差(将数据的标准差变为1)来实现。这有助于使不同量级的特征具有相同的尺度,从而提高某些算法的收敛速度和性能。
  2. 标准化训练数据
    x_train=sc.fit_transform(x_train)
    
    • fit_transform(x_train):这是StandardScaler类的一个方法,它执行两个操作:
      • fit:计算训练数据集的均值和标准差(即每个特征的均值和标准差),这些统计量将被用于后续的标准化过程。
      • transform:使用计算得到的均值和标准差对训练数据进行标准化处理。
  3. 标准化测试数据
    x_test=sc.transform(x_test)
    
    • transform(x_test):这是StandardScaler类的一个方法,它使用在训练数据上计算得到的均值和标准差来对测试数据进行标准化。这样做可以保证训练集和测试集使用相同的标准化参数,从而避免数据泄露。
  4. 重塑数据形状
    x_train=x_train.reshape(x_train.shape[0],x_train.shape[1],1)
    x_test=x_test.reshape(x_test.shape[0],x_test.shape[1],1)
    
    • x_train.reshape(x_train.shape[0],x_train.shape[1],1)x_test.reshape(x_test.shape[0],x_test.shape[1],1):这些操作改变了数据的形状,使其具有三个维度。在很多深度学习模型中,特别是使用卷积神经网络(CNN)或循环神经网络(RNN)时,数据通常需要是三维的,其中三个维度分别代表样本数量、特征数量和时间步长(或通道数)。在这个例子中,第三个维度是1,表示每个样本只有一个时间步长(对于非序列数据)。
      总的来说,这段代码首先对特征数据进行标准化处理,然后将数据重塑为三维格式,以便可以输入到需要这种数据格式的模型中。

四、构建RNN模型

tf.keras.layers.SimpleRNN 是 TensorFlow 中 Keras API 的一部分,用于创建一个简单的循环神经网络(RNN)层。以下是该函数的作用及其参数的解释:

作用:

SimpleRNN 层实现了一个基本的循环神经网络,它在每个时间步长跟踪更新“隐藏状态”,并根据当前的输入和前一个时间步长的隐藏状态来计算输出。这种层可以用于处理序列数据,因为它能够记住之前的信息,并将其用于当前的操作。

参数:

  • units:整数,表示 RNN 层中的单元数(或神经元数)。这决定了输出空间的维度,即每个时间步长的输出向量的大小。
  • activation:要使用的激活函数。默认为 ‘tanh’。其他常见的激活函数包括 ‘relu’ 和 ‘sigmoid’。激活函数决定了每个时间步长输出的非线性变换。
  • use_bias:布尔值,表示是否在计算中添加偏置项(b)。默认为 True。如果设置为 False,则层不会使用偏置向量。
  • kernel_initializer:权值矩阵(内核)的初始化器。默认为 ‘glorot_uniform’,也称为 Xavier 初始化器。它根据输入和输出单元的数量来初始化权重,以保持每个神经元在前向传播和反向传播中的激活方差一致。
  • recurrent_initializer:用于循环内核权重的初始化器。默认为 ‘orthogonal’,它生成一个正交矩阵,这对于避免梯度消失问题很有帮助。
  • bias_initializer:偏置向量的初始化器。默认为 ‘zeros’,即将偏置初始化为零。
  • dropout:0 到 1 之间的浮点数,表示输入单元的丢弃比例。
  • recurrent_dropout:0 到 1 之间的浮点数,表示循环单元的丢弃比例。
  • return_sequences:布尔值,表示是否返回每个时间步的输出序列,还是只返回最后一个时间步的输出。默认为 False。
  • return_state:布尔值,表示是否返回最后一个时间步的隐藏状态。默认为 False。
  • go_backwards:布尔值,表示是否反向处理输入序列。默认为 False。
  • stateful:布尔值,表示是否在每个批次结束时重置状态。如果设置为 True,则层的状态将在连续批次的调用之间保持。默认为 False。
  • unroll:布尔值,表示是否展开 RNN。如果设置为 True,则 RNN 将被展开成一个全连接的网络,这在某些情况下可以提高效率,尤其是在使用长序列时。默认为 False。
    通过调整这些参数,可以定制 SimpleRNN 层以适应不同的序列处理任务。
import tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN
model = Sequential()
model.add(SimpleRNN(200, input_shape=(13,1), activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

输出
在这里插入图片描述

五、编译模型

opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(loss='binary_crossentropy', 
              optimizer=opt,
              metrics="accuracy")

六、训练模型

epochs =100
history = model.fit(x_train,y_train,
                    epochs=epochs,
                    batch_size=32,
                    validation_data=(x_test,y_test),
                    verbose=1)

输出
在这里插入图片描述
在这里插入图片描述

七、模型评估

import matplotlib.pyplot as plt
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(14, 4))
plt.subplot(1,2,1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1,2,2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

输出
在这里插入图片描述

scores = model.evaluate(x_test, y_test, verbose=0)
print("%s:% .2f%%" % (model.metrics_names[1], scores[1]*100))

输出
accuracy: 83.87%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2066655.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MedGraphRAG:医学版 GraphRAG

MedGraphRAG:医学版 GraphRAG 提出我的解法思路 MedGraphRAG 大纲解法大纲 解法拆解U-retrieve 双向检索 分析性关联图创意视角MedGraphRAG 对比 传统知识图谱大模型现在医疗知识图谱的问题MedGraphRAG的三层层级图结构,能不能让普通的医疗知识图谱&…

线程——函数式创建线程threading模块,继承式创建线程,Lock对象解决多线程不安全问题,线程模型中的生产者消费者模式

调度程序内的多任务使用多进程,调度一个进程内的多任务使用多线程 函数式创建线程的方式threading模块 在Python中,创建线程主要依赖于threading模块。 使用threading模块中的Thread类,你可以很容易地基于函数模式创建线程。基本步骤包括&…

Linux | 文件描述符fd详解及重定向技术的应用

多谢梅花,伴我微吟。 - 《高阳台除夜》(韩疁) 2024.8.23 目录 1、文件描述符fd 文件操作符概念(简单带过) 重点:如何理解文件操作符使得系统实现了设备无关性?(使得操作系统无需关心具体的硬件细节) 示例代码:标准输入…

SAP BW:QUERY数据结果写入ADSO

作者 idan lian 如需转载备注出处 如果对你有帮助,请点赞收藏~~~ 需求背景 客户基于QUERY进行报表展示,现需迁移到永洪报表平台,query中的变量参数,公式等无法直接生成视图,query相对复杂,不想直接在视图…

流动会场:便捷、经济与声学效果的理想融合—轻空间

在现代活动策划中,选择合适的场地至关重要。流动会场作为一种新型移动空间,不仅具备便捷性和高性价比,还以其优异的声学效果,成为各类会议、展览、演出等活动的理想选择。 便捷安装,快速搭建 流动会场的模块化设计使其…

P-One如何测试一个场景集包含多个接口

P-One是泽众软件自主研发的一站式性能测试平台,集管理、设计、压测、监控以及分析于一体的全方位性能测试解决方案,适用于各种非功能测试场景:压力测试、负载测试、稳定性测试、可靠性测试、容量测试等。 在实际业务场景中,如电商…

springsecurity 在web中如何获取用户信息(后端/前端)

一、SecurityContextHolder 是什么 是一个安全的上下文对象,用于获取经过身份验证的用户。 二、SecurityContextHolder 是何时被创建的 当我们经过表单UsernamePasswordAuthenticationFilter 过滤器后,会回调父类的AbstractAuthenticationProcessingFilt…

华为自研仓颉编程语言测试版上线,计划持续到10月21号

现如今,编程语言作为构建软件世界的基石,其重要性不言而喻。 而华为,作为全球领先的信息与通信技术(ICT)解决方案提供商,其在技术创新上的每一步都备受瞩目。最近,华为再次成为焦点&#xff0c…

OpenCompass 评测 InternLM-1.8B 实践

1. 环境安装 conda create -n opencompass python3.10 conda activate opencompass conda install pytorch2.1.2 torchvision0.16.2 torchaudio2.1.2 pytorch-cuda12.1 -c pytorch -c nvidia -y# 注意:一定要先 cd /root cd /root git clone -b 0.2.4 https://gith…

系统编程-lvgl

带界面的MP3播放器 -- lvgl 目录 带界面的MP3播放器 -- lvgl 一、什么是lvgl? 二、简单使用lvgl 在工程中编写代码 实现带界面的mp3播放器 main.c events_init.c events_init.h 补充1:glob函数 补充2:atexit函数 一、什么是lvgl&a…

GPT-4、Claude 3 Opus 和 Gemini 1.0 Ultra 挑战控制工程的新领域

介绍 论文地址:https://arxiv.org/abs/2404.03647 近年来,GPT-4、Claude 3 Opus 和 Gemini 1.0 Ultra 等大规模语言模型(LLM)迅速发展,展示了它们解决复杂问题的能力。LLM 的这些发展在多个领域都有潜在的应用前景。…

Postman接口测试 —— 设置全局变量、参数传递、断言

在能熟练使用postman运行接口请求后,会遇到一些问题。例如: 我们的web网站一共有几十个接口,测试的时候如果要切换环境,这个时候要每个接口都要修改url的根路径,一个一个的改也太麻烦了; 还有时候我们经常…

八、SPA单页面实现SEO优化之预渲染prerender-spa-plugin

文章目录 一、前言二、prerender-spa-plugin预渲染方式实现SEO插件介绍实现步骤 一、前言 关于SPA和SEO优化、SSR服务器渲染的介绍可以参考这里: 六、什么是SEO优化(搜索引擎优化)?SPA单页面应用如何实现SEO优化? 通…

C/C++语言基础--字符串(包括字符串与字符数组、字符串与指针、字符串处理函数等),代码均可运行

本专栏目的 更新C/C的基础语法,包括C的一些新特性 前言 无论什么语言,字符串都是最重要、最基础的数据类型,他对二进制有很好的对应关系在C语言中没有提供专门的处理字符串的类型,但是我们可以通过字符数组、开辟内存地址来处理…

Content-Encoding: br

爬虫的时候遇到了 Content-Encoding: br , 这可能会导致返回的数据有乱码,无法解析,也无法解码, 浏览器显示编码 按照这么写,还是乱码 查了很久,需要在请求头 Accept-Encoding 将这个改为gzip&#xff0c…

Swift 6.0 如何更优雅的抛出和处理特定类型的错误

概述 从 Swift 语言诞生那天儿起,它就不厌其烦一遍又一遍地向秃头码农们诉说着自己的类型安全和高雅品味。 不过遗憾的是,作为 Swift 语言中错误处理这最为重要的一环却时常让小伙伴们不得要领、满腹狐疑。 在本篇博文中,您将学到如下内容&…

企业数字化转型会面临哪些挑战,如何解决?

当前,数字技术发展迅速,已迈入 AI 人工智能时代。企业若不进行数字化转型,可能会被用户抛弃、被竞争对手超越。那么,传统企业在转型过程中会遇到哪些挑战呢? 一、企业数字化转型面临的挑战 1、缺乏明确的战略规划和转…

宠物空气净化器除臭吗?性价比高的宠物空气净化器十大排名分享

来来来,先带大家一睹我店里的小可爱们 是不是超级可爱呀~?这样的大卡车猫猫,在我这猫咖里可是还有好几十只!作为一位坐拥几十只猫咪的“猫咖掌门”,朋友们总是投来羡慕的目光。但这份光鲜背后,可是有我无数…

轻松制作 GIF 动图,你也可以!

你是否曾为找不到合适的动图而烦恼? 是否羡慕别人能制作出精彩的 GIF 动图? 现在,无需再羡慕!因为我们用以下图片中的方法,你自己也能轻松制作 GIF 动图。 这款工具,操作简单易懂, 即使你没有…

舞动奇迹,亨廷顿舞蹈症患者专属健身秘籍!

🌈 在小红书的温馨角落里,让我们一起探索一个特别的世界——为亨廷顿舞蹈症(HD)患者量身定制的健身之旅。HD,这个名字或许带着一丝沉重,但它绝不能定义我们生活的全部色彩。通过科学的锻炼方式,…