第15周:RNN心脏病预测

news2024/11/17 3:58:14

目录

前言

二、前期准备

2.1 设置GPU

2.2 导入数据

2.2.1 数据介绍

2.2.2 导入代码

2.2.3 检查数据

三、数据预处理

3.1 划分训练集与测试集

3.2 标准化

四、构建RNN模型

4.1 基本概念

4.2 搭建代码

五、编译模型

六、训练模型

七、模型评估

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客
  • 🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)

说在前面

本周目标:本地读取并加载数据、了解循环神经网络(RNN)的构建过程、调整代码是的测试机acuuracy达到87%;拔高目标——测试集accuracy达到89%

我的环境:Python3.8、Pycharm2020、tensorflow2.4.0

数据来源:[K同学啊](https://mtyjkh.blog.csdn.net/)

代码的流程图如下:


一、RNN简介

传统神经网络结构比较简单是输入层——隐藏层——输出层,而RNN与传统神经网络最大的区别在于每次都会将前一次的输出结果,带到下一次的隐藏层中,一起训练。如下图所示,左图为传统神经网络,右图为RNN

 以一个案例具体分析RNN工作过程,用户说了一句“what time is it?”,我们的神经网络首先会将这句话分为五个基本单元(四个单词➕一个问号);然后按照顺序将5个基本单元输入RNN网络,what作为RNN的输入得到输出01,按照顺序将“time”输入RNN网络,得到输出02,这个过程中可以看到输入“time”的时候,前面“what”的输出也会对02的输出产生了影响(如下图中所示,隐藏层中有一半是黑色的),依次类推,前面所有的输入产生的结果都对后续的输出产生了印象(下图中最后的圆形中就包含了前面所有的颜色) 

当神经网络判断意图的时候,只需要最后一层的输出05,如下图所示

                               

循环神经网络(RNN)是一类用于处理序列数据的神经网络。不同于传统的前馈神经网络,RNN 能够处理序列长度变化的数据,如文本、语音等。RNN 的特点是在模型中引入循环,使得网络能够保持某种状态,从而在处理序列数据时表现出更好的性能。

上图左边简单描述 RNN 的原理,x 是输入层,o 是输出层,中间 s 是隐藏层,在 s 层进行一个循环,右边表示展开循环看到的逻辑,其实是和时间 t 相关的一个状态变化,也就是说神经网络在处理数据的时候,能看到前一时刻、后一时刻的状态,也就是常说的上下文

二、前期准备

2.1 设置GPU

代码如下:

#一、前期准备
#1.1 导入所需包和设置GPU
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 不显示等级2以下的提示信息
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN
import matplotlib.pyplot as plt

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
print(gpus)

2.2 导入数据

2.2.1 数据介绍

  • age:1)年龄
  • sex:2)性别
  • cp:3)胸痛类型(4 values)
  • trestbps:4)静息血压
  • chol:5)血清胆甾淳(mg/dl)
  • fbs:6)空腹血糖>120mg/dl
  • restecg:7)静息心电图结果(值0,1,2)
  • thalach:8)达到的最大心率
  • exang:9)运动诱发的心绞痛
  • olddpeak:10)相对静止状态,运动引起的ST段压低
  • slope:11)运动峰值ST段的斜率
  • ca:12)荧光透视着色的主要血管数量(0-3)
  • thal:13)0=正常,1=固定缺陷;2=可逆转的缺陷
  • target:14)0=心脏病发作的几率较小,1=心脏病发作的几率更大

2.2.2 导入代码

#1.2 导入数据
df = pd.read_csv('heart.csv')
print(df)

2.2.3 检查数据

检查是否存在空值

df.isnull().sum()  #检查是否有空值

数据打印显示如下

三、数据预处理

3.1 划分训练集与测试集

补充:测试集与验证集的关系——①验证集并没有参与训练中梯度下降的过程,狭义上来讲是没有参与模型的参数训练更新的;②但广义上来说,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后的模型在vaild data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等;③所以也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

代码如下:

#二、数据预处理
#2.1 数据集划分
x = df.iloc[:,:-1]
y = df.iloc[:,-1]

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=1)
print(x_train.shape, y_train.shape)

打印输出:(272, 13) (272,)

3.2 标准化

代码如下:

# 将每一列特征标准化为标准正态分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
x_train = sc.fit_transform(x_train)
x_test = sc.transform(x_test)

x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], 1)

四、构建RNN模型

4.1 基本概念

函数原型:tf.keras.layers.SimpleRNN(units,activation='tanh',use_bias=True,kernel_initializer='glorot_uniform',recurrent_initializer='orthogonal',bias_initializer='zeros',kernel_regularizer=Noe,recurrent_regularizer=Noe,bias_regularizer=None,activity_regularizer=None,keenel_constraint=None,recurrent_constraint=None,bias_constraint=None,dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,go_backwards=False,stateful=False,unroll=False,**kwargs)

关键参数说明:

  • units——正整数,输出空间的维度
  • activation——要使用的激活函数,默认为双曲正切(tanh),如果传入None,则不使用激活函数(即线性激活a(x)=x)
  • use_bias——布尔值,该层是否使用偏置向量
  • kernel_initializer——kernel权值矩阵的初始化器,用于输入的线性转换
  • recurrent_initializer——recurrent_kernel权值矩阵的初始化器,用于循环层状态的线性转换
  • bias_initializer——偏置向量的初始化器
  • dropout:在-0和1之间的浮点数,单元的丢弃比例,用于输入的线性转换

4.2 搭建代码

#三、构建RNN模型

model = Sequential()
model.add(SimpleRNN(128, input_shape= (13,1),return_sequences=True,activation='relu'))
model.add(SimpleRNN(64,return_sequences=True, activation='relu'))
model.add(SimpleRNN(32, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

模型输出如下:

五、编译模型

代码如下:

#四、编译模型
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(loss='binary_crossentropy', optimizer=opt,metrics=['accuracy'])

六、训练模型

代码如下:

#五、训练模型
epochs = 100
history = model.fit(x_train, y_train,
                    epochs=epochs,
                    batch_size=128,
                    validation_data=(x_test, y_test),
                    verbose=1)

训练过程:

七、模型评估

代码如下

#六、模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

scores = model.evaluate(x_test,y_test,verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

打印结果:

accuracy: 90.32%


总结

RNN实战应用,是一种用于处理序列数据的神经网络,了解了基于Tensorflow搭建RNN的过程;学习了对于文本类数据,是怎么将其数字化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1882935.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024年文化传播与对外交流国际学术会议(ICCCFE 2024)

2024年文化传播与对外交流国际学术会议(ICCCFE 2024) 2024 International Conference on Cultural Communication and Foreign Exchange(ICCCFE 2024) 会议简介: 2024年文化传播与对外交流国际学术会议(ICCCFE 2024)定…

Go线程实现模型-P

P 概述 P是G能够在M中运行关键。Go的运行时系统会适时地让P与不同的M建立或断开关联,以使P中的那些可运行的G能够及时获得,这与操作系统内核在CPU之上实时切换不同进程或线程的情况类似 改变P的数量 改变单个Go程序间拥有的P的最大数量有两种方法 调…

《塔瑞斯世界》国服震撼登场!AOC助力玩家开启游戏新征途!

一款真正高画质、重机制、轻数值的MMORPG大作! 你是否厌倦了在MMORPG游戏中被“氪金大佬”碾压?你是否渴望一个纯粹依靠技术和策略就能获得成就感的游戏世界?如果你对这两个问题的答案都是肯定的,那么《塔瑞斯世界》或许值得你一…

docker-compose搭建minio对象存储服务器

docker-compose搭建minio对象存储服务器 最近想使用oss对象存储进行用户图片上传的管理,了解了一下例如aliyun或者腾讯云的oss对象存储服务,但是呢涉及到对象存储以及经费有限的缘故,决定自己手动搭建一个oss对象存储服务器; 首先…

前端git约定式规范化提交-commitizen

当使用commitizen进行代码提交时,commitizen会提示你在提交代码时填写所必填的提交字段信息内容。 1、全局安装commitizen npm install -g commitizen4.2.4 2、安装并配置 cz-customizeable 插件 2.1 使用 npm 下载 cz-customizeable npm i cz-customizeable6.…

30秒就能完成3D翻页画册的工具

​在数字化时代,将传统画册转化为电子版,并赋予其3D翻页的动态效果,不仅能够增强视觉效果,还能提高资料的传播效率。对于需要在短时间内完成3D翻页画册制作的用户,这里推荐一款能迅速实现这一目标的在线工具。 首先&am…

Web3 游戏周报(6.23 - 6.29)

区块链游戏热度不减,你是否掌握了上周的重要动态? 回顾上周区块链游戏动态,查看 Footprint Analytics 与 ABGA 的最新数据报告。 【6.23 - 6.29】Web3 游戏行业动态: 继 Notcoin 之后,另一款 Telegram 游戏 Hamster …

沙箱在“一机两用”新规下的价值体现

在数字化时代,随着企业信息化建设的深入,数据安全问题愈发凸显其重要性。一机两用新规的出台,旨在通过技术创新和管理手段,实现终端设备的安全可控,确保敏感数据的安全存储与传输。SDC沙箱技术作为一种创新的安全防护手…

期权交易指南:为什么要交易场外个股期权?

今天带你了解期权交易指南:为什么要交易场外个股期权?随着金融市场的发展和创新,投资者寻求更多的工具来管理风险和获得更高的回报。场外期权交易应运而生,成为一种重要的金融衍生品交易方式。 简单来说就是期权是一种合约&#…

如何用程序批量下载小红书的图片?

如何使用MediaCrawler快速下载图片 作为一名图像算法工程师,怎么能没有图片资源呢?今天,我要介绍一个能快速下载图片的方法,仅供学习使用,请勿用于其他用途。 下载项目 首先,从GitHub下载项目&#xff1…

fastapi swagger在线接口文档报错

fastapi swagger在线接口文档报错 1、报错信息 Unable to render this definition The provided definition does not specify a valid version field. Please indicate a valid Swagger or OpenAPI version field. Supported version fields are swagger: “2.0” and those …

无线领夹麦克风哪个品牌音质最好,直播用领夹麦克风还是声卡麦

随着社交媒体的兴起,直播和Vlog已经成为内容创作的新趋势,这些变化不仅改变了人们分享生活的方式,也带动了音频设备市场的增长。无线领夹麦克风,以其便携性和卓越的录音品质,迅速成为视频制作者的重要工具。它们在直播…

<电力行业> - 《第11课:配电(1)》

1 配电 配电(power distribution)是在电力系统中直接与用户相连并向用户分配电能的环节。配电系统由配电变电所、高压配电线路、配电变压器、低压配电线路以及相应的控制保护设备组成。 1.1 概念 配电系统是由多种配电设备(或元件&#xf…

昂科烧录器支持BPS晶丰明源半导体的多相Buck控制器BPD93004E

芯片烧录行业领导者-昂科技术近日发布最新的烧录软件更新及新增支持的芯片型号列表,其中BPS晶丰明源半导体的多相Buck控制器BPD93004E已经被昂科的通用烧录平台AP8000所支持。 BPD93004E是一款多相Buck控制器,支持原生1~4相,数字方式控制&am…

经典的卷积神经网络模型 - AlexNet

经典的卷积神经网络模型 - AlexNet flyfish AlexNet 是由 Alex Krizhevsky、Ilya Sutskever 和 Geoffrey Hinton 在 2012 年提出的一个深度卷积神经网络模型,在 ILSVRC-2012(ImageNet Large Scale Visual Recognition Challenge 2012)竞赛中…

Domino应用中的HTML5

大家好,才是真的好。 在xpages多年不见有效更新,前景不明的时候,Domino传统Web应用开发方式还是受到了应有的青睐。毕竟,在Nomad Web时代,连最传统的Notes CS原生应用也突然焕发了勃勃生机一样。 但,对有…

51单片机嵌入式开发:STC89C52环境配置到点亮LED

STC89C52环境配置到点亮LED 1 环境配置1.1 硬件环境1.2 编译环境1.3 烧录环境 2 工程配置2.1 工程框架2.2 工程创建2.3 参数配置 3 点亮一个LED3.1 原理图解读3.2 代码配置3.3 演示 4 总结 1 环境配置 1.1 硬件环境 硬件环境采用“华晴电子”的MINIEL-89C开发板,这…

你还不会买智能猫砂盆吗?跟你们详细讲解今年最火的智能猫砂盆!

智能猫砂盆的坑,想必有很多养猫家庭都踩过吧。自己买回来的机器,不是空间不够大,导致猫咪拉到外面去,就是铲不干净,还得自己进行二次清理,搞得这个智能猫砂盆白买了。那如果我们想要购买合适自己家猫咪的智…

MyBatisPlus 基础数据表的增删改查 入门 简单查询

MyBatisPlus MyBatisPlus(简称MP)是一个基于MyBatis的增强工具库,简化了MyBatis的开发,提供了很多实用的功能和特性,如自动生成SQL、通用CRUD操作、分页插件、条件构造器、代码生成器等。它不仅简化了开发过程&#x…

Objective-C语法基础

新建一个XCode项目 新建一个类 1、成员变量、属性 1.1、类内使用成员变量&#xff0c;类外使用属性 Role.h #import <Foundation/Foundation.h>NS_ASSUME_NONNULL_BEGINinterface Role : NSObject {//成员变量&#xff1a;只能类内使用NSString *_name;int _age; }//属…