RNN实现阿尔茨海默症的诊断识别

news2025/1/28 11:06:46

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

 一 导入数据

import torch.nn as nn
import torch.nn.functional as F
import torchvision,torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
import warnings
warnings.filterwarnings('ignore')

# 设置硬件设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

df = pd.read_excel('dia.xls')
df

二 数据处理分析

# 删除第一列和最后一列
df = df.iloc[:,1:-1]
print(df)
 Age  Gender  Ethnicity  EducationLevel        BMI  Smoking   
0      73       0          0               2  22.927749        0  \
1      89       0          0               0  26.827681        0   
2      73       0          3               1  17.795882        0   
3      74       1          0               1  33.800817        1   
4      89       0          0               0  20.716974        0   
...   ...     ...        ...             ...        ...      ...   
2144   61       0          0               1  39.121757        0   
2145   75       0          0               2  17.857903        0   
2146   77       0          0               1  15.476479        0   
2147   78       1          3               1  15.299911        0   
2148   72       0          0               2  33.289738        0   

      AlcoholConsumption  PhysicalActivity  DietQuality  SleepQuality  ...   
0              13.297218          6.327112     1.347214      9.025679  ...  \
1               4.542524          7.619885     0.518767      7.151293  ...   
2              19.555085          7.844988     1.826335      9.673574  ...   
3              12.209266          8.428001     7.435604      8.392554  ...   
4              18.454356          6.310461     0.795498      5.597238  ...   
...                  ...               ...          ...           ...  ...   
2144            1.561126          4.049964     6.555306      7.535540  ...   
2145           18.767261          1.360667     2.904662      8.555256  ...   
2146            4.594670          9.886002     8.120025      5.769464  ...   
2147            8.674505          6.354282     1.263427      8.322874  ...   
2148            7.890703          6.570993     7.941404      9.878711  ...   

      FunctionalAssessment  MemoryComplaints  BehavioralProblems       ADL   
0                 6.518877                 0                   0  1.725883  \
1                 7.118696                 0                   0  2.592424   
2                 5.895077                 0                   0  7.119548   
3                 8.965106                 0                   1  6.481226   
4                 6.045039                 0                   0  0.014691   
...                    ...               ...                 ...       ...   
2144              0.238667                 0                   0  4.492838   
2145              8.687480                 0                   1  9.204952   
2146              1.972137                 0                   0  5.036334   
2147              5.173891                 0                   0  3.785399   
2148              6.307543                 0                   1  8.327563   

      Confusion  Disorientation  PersonalityChanges   
0             0               0                   0  \
1             0               0                   0   
2             0               1                   0   
3             0               0                   0   
4             0               0                   1   
...         ...             ...                 ...   
2144          1               0                   0   
2145          0               0                   0   
2146          0               0                   0   
2147          0               0                   0   
2148          0               1                   0   

      DifficultyCompletingTasks  Forgetfulness  Diagnosis  
0                             1              0          0  
1                             0              1          0  
2                             1              0          0  
3                             0              0          0  
4                             1              0          0  
...                         ...            ...        ...  
2144                          0              0          1  
2145                          0              0          1  
2146                          0              0          1  
2147                          0              1          1  
2148                          0              1          0  

[2149 rows x 33 columns]

三 探索性数据分析 

1.得病分布

res = df.groupby('Diabetes')['Age'].count()
print(res)

plt.figure(figsize=(8, 6))
plt.pie(res.values, labels=res.index, autopct='%1.1f%%', startangle=90,
        colors=['#ff9999','#66b3ff','#99ff99'], explode=(0.1,  0),
        wedgeprops={'edgecolor': 'black', 'linewidth': 1, 'linestyle': 'solid'})
plt.title('是否得阿尔茨海默症', fontsize=16, fontweight='bold')
plt.show()

 2.BMI分布直方图

# BMI分布直方图
sns.displot(df['BMI'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('BMI Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('BMI', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

3.年龄分布直方图 

# Age分布直方图
sns.displot(df['Age'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('Age Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('Age', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

 

四 构建划分数据集 

X = df.iloc[:,:-1]
y = df.iloc[:,-1]

sc = StandardScaler()
X = sc.fit_transform(X)

# 划分数据集
X = torch.tensor(np.array(X),dtype=torch.float32)
y = torch.tensor(np.array(y),dtype=torch.int64)

X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=1)

# 构建数据加载器
train_dl = DataLoader(TensorDataset(X_train,y_train),batch_size=64,shuffle=False)
test_dl = DataLoader(TensorDataset(X_test,y_test),batch_size=64,shuffle=False)

 五 训练模型

1.构建模型

# 构建模型
class model_rnn(nn.Module):
        def __init__(self):
                super(model_rnn, self).__init__()
                self.rnn0 = nn.RNN(input_size=32,hidden_size=200,num_layers=1,batch_first=True)
                self.fc0 = nn.Linear(200,50)
                self.fc1 = nn.Linear(50,2)

        def forward(self,x):
                out , hidden1 = self.rnn0(x)
                out = self.fc0(out)
                out = self.fc1(out)
                return out

model = model_rnn().to(device)
print(model)
model_rnn(
  (rnn0): RNN(32, 200, batch_first=True)
  (fc0): Linear(in_features=200, out_features=50, bias=True)
  (fc1): Linear(in_features=50, out_features=2, bias=True)
)

2.训练函数 

'''
训练模型
'''
# 训练循环
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset)   # 训练集的大小
    num_batches = len(dataloader)      # 批次数目,(size/batchsize,向上取整)

    train_acc,train_loss = 0,0  # 初始化训练损失和正确率

    for x,y in dataloader:    # 获取数据
        X,y = x.to(device),y.to(device)

        # 计算预测误差
        pred = model(X)   # 网络输出
        loss = loss_fn(pred,y)   # 计算误差

        # 反向传播
        optimizer.zero_grad()    # grad属性归零
        loss.backward()   # 反向传播
        optimizer.step()   # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc,train_loss

3.测试函数 

# 测试循环
def valid(dataloader,model,loss_fn):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)  # 批次数目,(size/batchsize,向上取整)

    test_loss, test_acc = 0, 0  # 初始化训练损失和正确率

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target = imgs.to(device),target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred,target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc,test_loss

4.正式训练 

loss_fn = nn.CrossEntropyLoss()   # 创建损失函数
learn_rate = 1e-4   # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)

    model.eval()
    epoch_test_acc,epoch_test_loss = valid(test_dl,model,loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = opt.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))

print("="*20,'Done',"="*20)
Epoch: 1,Train_acc:52.9%,Train_loss:0.688,Test_acc:67.0%,Test_loss:0.658,lr:1.00E-04
Epoch: 2,Train_acc:68.7%,Train_loss:0.612,Test_acc:67.4%,Test_loss:0.600,lr:1.00E-04
Epoch: 3,Train_acc:68.7%,Train_loss:0.566,Test_acc:70.7%,Test_loss:0.567,lr:1.00E-04
Epoch: 4,Train_acc:74.4%,Train_loss:0.526,Test_acc:72.6%,Test_loss:0.533,lr:1.00E-04
Epoch: 5,Train_acc:77.9%,Train_loss:0.487,Test_acc:78.1%,Test_loss:0.501,lr:1.00E-04
Epoch: 6,Train_acc:81.1%,Train_loss:0.451,Test_acc:79.5%,Test_loss:0.473,lr:1.00E-04
Epoch: 7,Train_acc:82.3%,Train_loss:0.421,Test_acc:80.0%,Test_loss:0.451,lr:1.00E-04
Epoch: 8,Train_acc:83.4%,Train_loss:0.397,Test_acc:78.6%,Test_loss:0.434,lr:1.00E-04
Epoch: 9,Train_acc:84.7%,Train_loss:0.378,Test_acc:80.0%,Test_loss:0.422,lr:1.00E-04
Epoch:10,Train_acc:85.2%,Train_loss:0.365,Test_acc:80.0%,Test_loss:0.414,lr:1.00E-04
Epoch:11,Train_acc:85.6%,Train_loss:0.354,Test_acc:80.0%,Test_loss:0.408,lr:1.00E-04
Epoch:12,Train_acc:85.9%,Train_loss:0.347,Test_acc:80.0%,Test_loss:0.405,lr:1.00E-04
Epoch:13,Train_acc:86.3%,Train_loss:0.341,Test_acc:78.6%,Test_loss:0.403,lr:1.00E-04
Epoch:14,Train_acc:87.0%,Train_loss:0.335,Test_acc:78.1%,Test_loss:0.403,lr:1.00E-04
Epoch:15,Train_acc:87.1%,Train_loss:0.331,Test_acc:78.6%,Test_loss:0.404,lr:1.00E-04
Epoch:16,Train_acc:87.1%,Train_loss:0.327,Test_acc:78.1%,Test_loss:0.405,lr:1.00E-04
Epoch:17,Train_acc:87.1%,Train_loss:0.324,Test_acc:78.6%,Test_loss:0.407,lr:1.00E-04
Epoch:18,Train_acc:87.3%,Train_loss:0.321,Test_acc:78.6%,Test_loss:0.409,lr:1.00E-04
Epoch:19,Train_acc:87.4%,Train_loss:0.318,Test_acc:77.7%,Test_loss:0.412,lr:1.00E-04
Epoch:20,Train_acc:87.7%,Train_loss:0.315,Test_acc:78.1%,Test_loss:0.415,lr:1.00E-04
Epoch:21,Train_acc:87.8%,Train_loss:0.312,Test_acc:77.7%,Test_loss:0.418,lr:1.00E-04
Epoch:22,Train_acc:88.1%,Train_loss:0.309,Test_acc:78.1%,Test_loss:0.422,lr:1.00E-04
Epoch:23,Train_acc:88.6%,Train_loss:0.306,Test_acc:78.1%,Test_loss:0.425,lr:1.00E-04
Epoch:24,Train_acc:88.6%,Train_loss:0.303,Test_acc:79.1%,Test_loss:0.429,lr:1.00E-04
Epoch:25,Train_acc:88.6%,Train_loss:0.301,Test_acc:79.5%,Test_loss:0.433,lr:1.00E-04
Epoch:26,Train_acc:88.6%,Train_loss:0.298,Test_acc:79.5%,Test_loss:0.437,lr:1.00E-04
Epoch:27,Train_acc:88.8%,Train_loss:0.295,Test_acc:80.0%,Test_loss:0.440,lr:1.00E-04
Epoch:28,Train_acc:89.1%,Train_loss:0.292,Test_acc:79.5%,Test_loss:0.444,lr:1.00E-04
Epoch:29,Train_acc:89.1%,Train_loss:0.290,Test_acc:79.1%,Test_loss:0.449,lr:1.00E-04
Epoch:30,Train_acc:89.2%,Train_loss:0.287,Test_acc:79.1%,Test_loss:0.453,lr:1.00E-04
==================== Done ====================

六 结果可视化 

1.Loss和Acc图

epochs_range = range(30)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='training accuracy')
plt.plot(epochs_range,test_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='training loss')
plt.plot(epochs_range,test_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()

 2.调用模型预测

test_X = X_test[0].reshape(1,-1)
pred = model(test_X.to(device)).argmax(1).item()
print('模型预测结果:',pred)
print('=='*20)
print('0:未患病')
print('1:已患病')
模型预测结果: 0
========================================
0:未患病
1:已患病

3.绘制混淆矩阵 

'''
绘制混淆矩阵
'''
print('=============输入数据shape为==============')
print('X_test.shape:',X_test.shape)
print('y_test.shape:',y_test.shape)

pred = model(X_test.to(device)).argmax(1).cpu().numpy()

print('\n==========输出数据shape为==============')
print('pred.shape:',pred.shape)

from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay

# 计算混淆矩阵
cm = confusion_matrix(y_test,pred)

plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm,annot=True,fmt='d',cmap='Blues')
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix',fontsize=12)
plt.xlabel('Pred Label',fontsize=10)
plt.ylabel('True Label',fontsize=10)
plt.tight_layout()
plt.show()
=============输入数据shape为==============
X_test.shape: torch.Size([215, 32])
y_test.shape: torch.Size([215])

==========输出数据shape为==============
pred.shape: (215,)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2283568.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HackTheBox靶机:Sightless;NodeJS模板注入漏洞,盲XSS跨站脚本攻击漏洞实战

HackTheBox靶机:Sightless 渗透过程1. 信息收集常规探测深入分析 2. 漏洞利用(CVE-2022-0944)3. 从Docker中提权4. 信息收集(michael用户)5. 漏洞利用 Froxlor6. 解密Keepass文件 漏洞分析SQLPad CVE-2022-0944 靶机介…

docker安装elk6.7.1-搜集java日志

docker安装elk6.7.1-搜集java日志 如果对运维课程感兴趣,可以在b站上、A站或csdn上搜索我的账号: 运维实战课程,可以关注我,学习更多免费的运维实战技术视频 0.规划 192.168.171.130 tomcat日志filebeat 192.168.171.131 …

XML实体注入漏洞攻与防

JAVA中的XXE攻防 回显型 无回显型 cve-2014-3574

【问题解决】el-upload数据上传成功后不显示成功icon

el-upload数据上传成功后不显示成功icon 原因 由于后端返回数据与要求形式不符,使用el-upload默认方法调用onSuccess钩子失败,上传文件的状态并未发生改变,因此数据上传成功后并未显示成功的icon标志。 解决方法 点击按钮,调用…

spring框架之IoC学习与梳理(1)

目录 一、spring-IoC的基本解释。 二、spring-IoC的简单demo(案例)。 (1)maven-repository官网中找依赖坐标。 (2).pom文件中通过标签引入。 (3)使用lombok帮助快速开发。 &#xff…

150 Linux 网络编程6 ,从socket 到 epoll整理。listen函数参数再研究

一 . 只能被一个client 链接 socket例子 此例子用于socket 例子, 该例子只能用于一个客户端连接server。 不能用于多个client 连接 server socket_server_support_one_clientconnect.c /* 此例子用于socket 例子, 该例子只能用于一个客户端连接server。…

PCIe 个人理解专栏——【2】LTSSM(Link Training and Status State Machine)

前言: 链路训练和状况状态机LTSSM(Link Training and Status State Machine)是整个链路训练和运行中状态的状态转换逻辑关系图,总共有11个状态。 正文: 包括检测(Detect),轮询&…

《DiffIR:用于图像修复的高效扩散模型》学习笔记

paper:2303.09472 GitHub:GitHub - Zj-BinXia/DiffIR: This project is the official implementation of Diffir: Efficient diffusion model for image restoration, ICCV2023 目录 摘要 1、介绍 2、相关工作 2.1 图像恢复(Image Rest…

[笔记] 极狐GitLab实例 : 手动备份步骤总结

官方备份文档 : 备份和恢复极狐GitLab 一. 要求 为了能够进行备份和恢复,请确保您系统已安装 Rsync。 如果您安装了极狐GitLab: 如果您使用 Omnibus 软件包,则无需额外操作。如果您使用源代码安装,您需要确定是否安装了 rsync。…

switch组件的功能与用法

文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了PageView这个Widget,本章回中将介绍Switch Widget.闲话休提,让我们一起Talk Flutter吧。 1 概念介绍 我们在这里介绍的Switch是指左右滑动的开关,常用来表示某项设置是打开还是关闭。Fl…

mac 电脑上安装adb命令

在Mac下配置android adb命令环境,配置方式如下: 1、下载并安装IDE (android studio) Android Studio官网下载链接 详细的安装连接请参考 Mac 安装Android studio 2、配置环境 在安装完成之后,将android的adb工具所在…

Couchbase UI: Dashboard

以下是 Couchbase UI Dashboard 页面详细介绍,包括页面布局和功能说明,帮助你更好地理解和使用。 1. 首页(Overview) 功能:提供集群的整体健康状态和性能摘要 集群状态 节点健康状况:绿色(正…

[极客大挑战 2019]Knife1

题目 蚁剑直接连接密码是Syc 拿下flag flag{1d373584-fc74-4a2c-a6d4-3691314be4ab}

【Maui】提示消息的扩展

文章目录 前言一、问题描述二、解决方案三、软件开发(源码)3.1 消息扩展库3.2 消息提示框使用3.3 错误消息提示使用3.4 问题选择框使用 四、项目展示 前言 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架,用于使用 C# 和 XAML 创建本机移…

【2025最新计算机毕业设计】基于SSM房屋租赁平台【提供源码+答辩PPT+文档+项目部署】(高质量源码,可定制,提供文档,免费部署到本地)

作者简介:✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流。✌ 主要内容:🌟Java项目、Python项目、前端项目、PHP、ASP.NET、人工智能…

(开源)基于Django+Yolov8+Tensorflow的智能鸟类识别平台

1 项目简介(开源地址在文章结尾) 系统旨在为了帮助鸟类爱好者、学者、动物保护协会等群体更好的了解和保护鸟类动物。用户群体可以通过平台采集野外鸟类的保护动物照片和视频,甄别分类、实况分析鸟类保护动物,与全世界各地的用户&…

【转帖】eclipse-24-09版本后,怎么还原原来版本的搜索功能

【1】原贴地址:eclipse - 怎么还原原来版本的搜索功能_eclipse打开类型搜索类功能失效-CSDN博客 https://blog.csdn.net/sinat_32238399/article/details/145113105 【2】原文如下: 更新eclipse-24-09版本后之后,新的搜索功能(CT…

git基础指令大全

版本控制 git管理文件夹 进入要管理的文件夹 — 进入 初始化(提名) git init 管理文件夹 生成版本 .git ---- git在管理文件夹时,版本控制的信息 生成版本 git status 检测当前文件夹下的文件状态 (检测,检测之后就要管理了…

Android实训九 数据存储和访问

实训9 数据存储和访问 一、【实训目的】 1、 SharedPreferences存储数据; 2、 借助Java的I/O体系实现文件的存储, 3、使用Android内置的轻量级数据库SQLite存储数据; 二、【实训内容】 1、实现下图所示的界面,实现以下功能: 1&#xff…

[STM32 标准库]定时器输出PWM配置流程 PWM模式解析

前言: 本文内容基本来自江协,整理起来方便日后开发使用。MCU:STM32F103C8T6。 一、配置流程 1、开启GPIO,TIM的时钟 /*开启时钟*/RCC_APB1PeriphClockCmd(RCC_APB1Periph_TIM2, ENABLE); //开启TIM2的时钟RCC_APB2PeriphClockC…