RNN实现精神分裂症患者诊断(pytorch)

news2025/3/6 3:57:02

RNN理论知识

RNN(Recurrent Neural Network,循环神经网络) 是一种 专门用于处理序列数据(如时间序列、文本、语音、视频等)的神经网络。与普通的前馈神经网络(如 MLP、CNN)不同,RNN 具有“记忆”能力,能够利用过去的信息来影响当前的计算结果。

1. RNN 的基本结构

RNN 的核心特点是 “循环”结构,它会将前一个时间步 ( t − 1 ) (t-1) t1计算出的隐藏状态 h t − 1 h_{t-1} ht1 传递给当前时间步 ( t ) (t) t,使得网络可以保留历史信息。

这种结构可以表示为:

h t = f ( W x X t + W h h t − 1 + b ) h_t=f(W_xX_t+W_hh_{t-1}+b) ht=f(WxXt+Whht1+b)

其中:

  • X t X_t Xt:当前时刻的输入数据。
  • h t h_t ht:当前时刻的隐藏状态 。
  • W x 、 W h 、 b W_x、W_h、b WxWhb:可训练的参数 。
  • f f f:激活函数(通常是 tanh 或ReLU)。

RNN 的展开结构:
在时间步(time step)上,RNN 结构可以展开成如下形式:
在这里插入图片描述
图示解释:

X 1 , X 2 , X 3 , . . . X_1,X_2,X_3,... X1,X2,X3,... 代表输入的 序列数据(如文本、时间序列信号)。
h 0 , h 1 , h 2 , h 3 , . . . h_0,h_1,h_2,h_3,... h0,h1,h2,h3,... 代表 隐藏状态,用于存储过去的信息。
Y 1 , Y 2 , Y 3 , . . . Y_1,Y_2,Y_3,... Y1,Y2,Y3,...代表 输出。
在每个时间步,RNN 使用当前输入 X t X_t Xt 和前一时刻的隐藏状态 h t − 1 h_{t-1} ht1来计算新的隐藏状态 h t h_t ht,然后生成输出 Y t Y_t Yt

2. RNN 的缺点

尽管 RNN 在处理序列数据方面有独特的优势,但它也存在一些明显的问题:
(1)梯度消失(Vanishing Gradient)
在长序列训练时,误差的梯度会随着时间步增多而逐渐变小,导致网络无法有效学习较远时间步的信息。
解决方案:使用 LSTM(长短时记忆网络) 或 GRU(门控循环单元) 结构。
(2)梯度爆炸(Exploding Gradient)
如果梯度在反向传播过程中不断累积,可能会变得 非常大,导致模型更新过快或无法收敛。
解决方案:使用 梯度裁剪(Gradient Clipping) 来防止梯度过大。
(3)无法并行计算
由于 RNN 依赖前一个时间步的计算结果,因此无法像 CNN 那样并行计算,这导致训练速度较慢。
解决方案:使用 Transformer 模型(如 BERT、GPT)来替代 RNN。

3. RNN 的改进版本

由于 RNN 存在梯度消失等问题,研究人员提出了更强大的 变种 RNN 结构:
(1)LSTM(Long Short-Term Memory)
在这里插入图片描述

  • LSTM 引入了 “记忆单元” 和 “门机制”,使得它能够保留长期信息,解决梯度消失问题。
  • 包含 遗忘门(Forget Gate)、输入门(Input Gate)、输出门(Output Gate) 三部分来控制信息流。

(2)GRU(Gated Recurrent Unit)

  • GRU 是 LSTM 的简化版本,只包含 更新门(Update Gate) 和 重置门(Reset Gate),计算效率更高。

数据集

精神分裂症数据集,是一个包含精神分裂症人口统计和临床数据的综合数据集。该数据集包括患者的诊断状态、症状评分、治疗史和社会因素。

代码目标

基于给定的特征(如性别、年龄、收入、症状评分等),预测一个人的诊断标签(是否患有精神分裂症),通过可视化训练损失和计算准确率,评估模型的训练效果与性能。

一、前期准备工作

我的环境:

  • 操作系统:windows10
  • 语言环境:Python3.9
  • 编译器:Jupyter notebook
  • 数据集:精神分裂症患者数据集(“schizophrenia_dataset.csv”)

1. 导入库,设置硬件设备

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
import torch

#设置GPU训练,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

代码输出:

device(type='cpu')

使用 torch.device() 方法检查当前系统是否有 GPU,并根据条件设置计算设备为 GPU(CUDA)或 CPU。

2. 导入数据

读取指定路径的 CSV 文件,并加载到 pandas 的 DataFrame 中,然后打印出数据框的前五行,用于检查数据的内容。

# 读取数据
file_path = 'schizophrenia_dataset.csv'     # 设置数据文件的路径
df = pd.read_csv(file_path)                 # 使用pandas的read_csv函数读取CSV文件,结果存储在DataFrame对象df中
print(df.head())            # 打印数据框的前五行,检查数据的结构和内容

代码输出:

   Hasta_ID  Yaş  Cinsiyet  Eğitim_Seviyesi  Medeni_Durum  Meslek  \
0         1   72         1                4             2       0   
1         2   49         1                5             2       2   
2         3   53         1                5             3       2   
3         4   67         1                3             2       0   
4         5   54         0                1             2       0   

   Gelir_Düzeyi  Yaşadığı_Yer  Tanı  Hastalık_Süresi  Hastaneye_Yatış_Sayısı  \
0             2             1     0                0                       0   
1             1             0     1               35                       1   
2             1             0     1               32                       0   
3             2             0     0                0                       0   
4             2             1     0                0                       0   

   Ailede_Şizofreni_Öyküsü  Madde_Kullanımı  İntihar_Girişimi  \
0                        0                0                 0   
1                        1                1                 1   
2                        1                0                 0   
3                        0                1                 0   
4                        0                0                 0   

   Pozitif_Semptom_Skoru  Negatif_Semptom_Skoru  GAF_Skoru  Sosyal_Destek  \
0                     32                     48         72              0   
1                     51                     63         40              2   
2                     72                     85         51              0   
3                     10                     21         74              1   
4                      4                     27         98              0   

   Stres_Faktörleri  İlaç_Uyumu  
0                 2           2  
1                 2           0  
2                 1           1  
3                 1           2  
4                 1           0  

二、构建数据集

1. 划分数据集

处理数据中的不必要列(唯一标识符)和缺失值,以准备好干净的数据进行模型训练。

df = df.drop(columns=['Hasta_ID'])      # 删除 'Hasta_ID' 列,因为该列是唯一标识符,不需要用作模型输入
df = df.fillna(df.mean())      # 使用每一列的均值填充数据框中的缺失值。这里使用 `df.mean()` 来计算均值,并用它来填充缺失值

数据处理流程:

  • 使用 LabelEncoder 将类别变量转换为数值。
  • 将数据划分为特征(X)和目标(y)。
  • 标准化特征数据。
  • 将数据划分为训练集和测试集。
  • 将数据转换为 PyTorch 张量。
  • 调整张量维度以符合 RNN 模型的要求。
label_encoder = LabelEncoder()     # 创建LabelEncoder实例,用于将类别变量转换为数值
df['Cinsiyet'] = label_encoder.fit_transform(df['Cinsiyet'])       # 将 'Cinsiyet'列中的类别值转化为数值
df['Medeni_Durum'] = label_encoder.fit_transform(df['Medeni_Durum'])     # 将 'Medeni_Durum'列中的类别值转化为数值
df['Yaşadığı_Yer'] = label_encoder.fit_transform(df['Yaşadığı_Yer'])     # 将 'Yaşadığı_Yer'列中的类别值转化为数值

# 将特征和目标分开
X = df.drop(columns=['Tanı'])     # 将数据框中的 'Tanı' 列移除,剩下的列作为特征(X)
y = df['Tanı']      # 'Tanı' 列作为目标变量(y),表示是否患有精神分裂症(二分类标签)


scaler = StandardScaler()     # 创建 StandardScaler 实例,用于标准化特征数据
X_scaled = scaler.fit_transform(X)     # 对特征进行标准化,使得每列的均值为0,标准差为1

# 使用 train_test_split 将数据随机划分为训练集和测试集,测试集占20%。random_state=42 设置随机种子,以确保每次划分结果相同
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

# 将数据转换为PyTorch的tensor
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)        # 将训练特征数据转换为PyTorch的tensor格式,并指定数据类型为float32
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)          # 将测试特征数据转换为PyTorch的tensor格式,并指定数据类型为float32
y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)    # 将训练目标数据转换为PyTorch的tensor格式,并指定数据类型为long(用于分类问题)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.long)      # 将测试目标数据转换为PyTorch的tensor格式,并指定数据类型为long(用于分类问题)

# 确保数据的形状符合RNN的要求: [batch_size, seq_len, features]
X_train_tensor = X_train_tensor.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]
X_test_tensor = X_test_tensor.unsqueeze(1)    # [batch_size, features] --> [batch_size, 1, features]

# 输出tensor的形状,确保数据正确
print(f"训练数据形状: {X_train_tensor.shape}")     # 打印训练数据的形状,检查是否正确
print(f"测试数据形状: {X_test_tensor.shape}")      # 打印测试数据的形状,检查是否正确

代码输出:

训练数据形状: torch.Size([8000, 1, 18])
测试数据形状: torch.Size([2000, 1, 18])

2. 构建数据加载器

将训练集和测试集的数据(特征和标签)封装成 TensorDataset 对象,并使用 DataLoader 创建数据加载器。
训练集和测试集被分批次加载,每个批次包含 64 个样本。
shuffle=False 表示数据在加载时不进行打乱,在评估的时候顺序保持一致。

from torch.utils.data import TensorDataset, DataLoader

train_dl = DataLoader(TensorDataset(X_train_tensor, y_train_tensor),     # 将训练数据、目标数据包装成一个数据集,并创建一个训练数据加载器
                      batch_size=64, 
                      shuffle=False)

test_dl  = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),      # 将测试数据、目标数据包装成一个数据集,并创建一个测试数据加载器
                      shuffle=False)

三、模型训练

1. 构建模型

import torch.nn as nn

#定义一个名为 _RNN_Base 的类,继承自 nn.Module。该类实现了 RNN(包括 RNN、LSTM 和 GRU)的基础结构
class _RNN_Base(nn.Module):
    def __init__(self, c_in, c_out, hidden_size=100, n_layers=1, 
                 bias=True, rnn_dropout=0, bidirectional=False, 
                 fc_dropout=0., init_weights=True):
        """
        RNN基础类,支持不同RNN单元(如RNN、LSTM、GRU)的实现。
        """
        super(_RNN_Base, self).__init__()  # 确保正确调用父类的构造函数
        # 定义RNN层,支持RNN、LSTM、GRU等
        self.rnn = self._cell(c_in, hidden_size, num_layers=n_layers, 
                              bias=bias, batch_first=True, 
                              dropout=rnn_dropout, 
                              bidirectional=bidirectional)
        
        # 定义全连接层的dropout,如果fc_dropout为0则直接用Identity
        self.dropout = nn.Dropout(fc_dropout) if fc_dropout else nn.Identity()
        self.fc = nn.Linear(hidden_size * (1 + bidirectional), c_out)

    def forward(self, x): 
        """        
        参数:
        - x: 形状为[batch_size, n_vars, seq_len]。
        
        返回:
        - output: 形状为[batch_size, c_out]。
        """
        # [batch_size, n_vars, seq_len] --> [batch_size, seq_len, n_vars]
        x = x.transpose(2,1)  
        # 输出形状为[batch_size, seq_len, hidden_size * (1 + bidirectional)]
        output, _ = self.rnn(x) 
        # 取最后一个时间步的输出,形状为[batch_size, hidden_size * (1 + bidirectional)]
        output = output[:, -1]  
        
        output = self.fc(self.dropout(output))
        return output

# 定义RNN类,继承自_RNN_Base
class RNN(_RNN_Base):
    _cell = nn.RNN  # 使用nn.RNN单元
    
# 定义LSTM类,继承自_RNN_Base
class LSTM(_RNN_Base):
    _cell = nn.LSTM  # 使用nn.LSTM单元
    
# 定义GRU类,继承自_RNN_Base
class GRU(_RNN_Base):
    _cell = nn.GRU  # 使用nn.GRU单元

定义名为 _RNN_Base 的类,继承自 nn.Module。该类实现了 RNN(包括 RNN、LSTM 和 GRU)的基础结构。

_RNN_Base 类的参数解释:

  • c_in:输入特征的维度,即每个时间步的特征数量。
  • c_out:输出类别数量,即模型的输出维度。
  • hidden_size:RNN隐藏层的大小。
  • n_layers:RNN的层数。
  • bias:是否在RNN层中使用偏置项。
  • rnn_dropout:RNN层中的dropout比例。
  • bidirectional:是否使用双向RNN。
  • fc_dropout:全连接层的dropout比例。
  • init_weights:是否初始化权重。

关于_cell ,定义 RNN 层。self._cell 是一个占位符,它将会被具体子类(RNN、LSTM、GRU)的 _cell 属性替代,相关参数解释:

  • c_in:输入特征的数量。
  • hidden_size:RNN单元的隐藏层大小。
  • num_layers:RNN的层数。
  • bias:是否使用偏置项。
  • batch_first=True:意味着输入和输出的格式为 [batch_size, seq_len,features]。
  • dropout=rnn_dropout:RNN中dropout的概率,用来防止过拟合。
  • bidirectional=bidirectional:是否使用双向RNN(即处理序列时同时考虑正向和反向的时间步)。
# 创建一个基于 RNN 的神经网络模型,并将模型移动到指定的设备(CPU 或 GPU)
model = RNN(c_in=X_train_tensor.shape[1], c_out=2).to(device)    
model 

代码输出:

RNN(
  (rnn): RNN(1, 100, batch_first=True)
  (dropout): Identity()
  (fc): Linear(in_features=100, out_features=2, bias=True)
)
from torchinfo import summary

rnn_model = RNN(c_in=3, c_out=5, hidden_size=100,n_layers=2,
                bidirectional=True, rnn_dropout=.5, fc_dropout=.5)    # 初始化一个 RNN 模型,并设置相关参数

summary(rnn_model, input_size=(16, 3, 5))    # 调用 summary 函数,输出 rnn_model 的结构和每一层的详细信息

代码输出:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
RNN                                      --                        --
├─RNN: 1-1                               [16, 5, 200]              81,400
├─Dropout: 1-2                           [16, 200]                 --
├─Linear: 1-3                            [16, 5]                   1,005
==========================================================================================
Total params: 82,405
Trainable params: 82,405
Non-trainable params: 0
Total mult-adds (M): 6.53
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.13
Params size (MB): 0.33
Estimated Total Size (MB): 0.46
==========================================================================================

2. 定义训练函数

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)   # 批次数目

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取数据及其标签
        X, y = X.to(device), y.to(device)
        
        # 1. 确保输入数据有三个维度,添加一个seq_len维度
        if X.dim() == 2:  # 如果是二维输入,添加一个序列长度维度
            X = X.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]

        # 2. 前向传播
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的损失
        
        # 3. 反向传播
        optimizer.zero_grad()  # 清零梯度
        loss.backward()        # 反向传播
        optimizer.step()       # 更新参数
        
        # 记录准确率和损失
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss

3. 定义测试函数

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)   # 批次数目
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            # 1. 确保输入数据有三个维度,添加一个seq_len维度
            if X.dim() == 2:  # 如果是二维输入,添加一个序列长度维度
                X = X.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]
            
            # 2. 计算损失
            pred = model(X)
            loss = loss_fn(pred, y)
            
            test_loss += loss.item()
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

4. 正式训练模型

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 2e-5   # 学习率
opt        = torch.optim.Adam(model.parameters(),lr=learn_rate)    # 使用 Adam 优化器,并将学习率 learn_rate 应用到优化器中
epochs     = 20     # 设置训练的总轮数为 20。每轮训练都将通过整个训练集一次

train_loss = []  # 初始化一个空列表用于记录每一轮的训练损失
train_acc  = []  # 初始化一个空列表用于记录每一轮的训练准确率
test_loss  = []  # 初始化一个空列表用于记录每一轮的测试损失
test_acc   = []  # 初始化一个空列表用于记录每一轮的测试准确率

# 循环遍历训练轮数
for epoch in range(epochs):
    model.train()    # 设置模型为训练模式
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
 
    model.eval()    # 设置模型为评估模式
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)  # 将当前训练轮的准确率添加到列表中
    train_loss.append(epoch_train_loss)  # 将当前训练轮的损失添加到列表中
    test_acc.append(epoch_test_acc)  # 将当前测试轮的准确率添加到列表中
    test_loss.append(epoch_test_loss)  # 将当前测试轮的损失添加到列表中

    
    # 获取当前的学习率
    lr = opt.state_dict()['param_groups'][0]['lr']

    # 格式化输出每一轮训练和测试的准确率、损失以及当前学习率
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
print("="*20, 'Done', "="*20)

代码输出:

Epoch: 1, Train_acc:70.1%, Train_loss:0.665, Test_acc:70.9%, Test_loss:0.636, Lr:2.00E-05
Epoch: 2, Train_acc:71.4%, Train_loss:0.596, Test_acc:70.3%, Test_loss:0.558, Lr:2.00E-05
Epoch: 3, Train_acc:72.7%, Train_loss:0.507, Test_acc:80.2%, Test_loss:0.442, Lr:2.00E-05
Epoch: 4, Train_acc:90.8%, Train_loss:0.337, Test_acc:95.7%, Test_loss:0.259, Lr:2.00E-05
Epoch: 5, Train_acc:95.9%, Train_loss:0.212, Test_acc:96.4%, Test_loss:0.179, Lr:2.00E-05
Epoch: 6, Train_acc:96.0%, Train_loss:0.161, Test_acc:96.4%, Test_loss:0.146, Lr:2.00E-05
Epoch: 7, Train_acc:96.2%, Train_loss:0.137, Test_acc:96.7%, Test_loss:0.128, Lr:2.00E-05
Epoch: 8, Train_acc:96.5%, Train_loss:0.121, Test_acc:96.7%, Test_loss:0.116, Lr:2.00E-05
Epoch: 9, Train_acc:96.6%, Train_loss:0.110, Test_acc:96.8%, Test_loss:0.107, Lr:2.00E-05
Epoch:10, Train_acc:96.8%, Train_loss:0.103, Test_acc:96.7%, Test_loss:0.100, Lr:2.00E-05
Epoch:11, Train_acc:96.9%, Train_loss:0.097, Test_acc:96.7%, Test_loss:0.095, Lr:2.00E-05
Epoch:12, Train_acc:96.9%, Train_loss:0.092, Test_acc:96.7%, Test_loss:0.091, Lr:2.00E-05
Epoch:13, Train_acc:97.0%, Train_loss:0.089, Test_acc:96.8%, Test_loss:0.088, Lr:2.00E-05
Epoch:14, Train_acc:97.1%, Train_loss:0.085, Test_acc:96.9%, Test_loss:0.084, Lr:2.00E-05
Epoch:15, Train_acc:97.2%, Train_loss:0.082, Test_acc:97.0%, Test_loss:0.081, Lr:2.00E-05
Epoch:16, Train_acc:97.3%, Train_loss:0.078, Test_acc:97.0%, Test_loss:0.077, Lr:2.00E-05
Epoch:17, Train_acc:97.4%, Train_loss:0.075, Test_acc:97.2%, Test_loss:0.073, Lr:2.00E-05
Epoch:18, Train_acc:97.5%, Train_loss:0.071, Test_acc:97.4%, Test_loss:0.070, Lr:2.00E-05
Epoch:19, Train_acc:97.6%, Train_loss:0.068, Test_acc:97.5%, Test_loss:0.065, Lr:2.00E-05
Epoch:20, Train_acc:97.9%, Train_loss:0.063, Test_acc:97.9%, Test_loss:0.061, Lr:2.00E-05
==================== Done ====================

四、模型评估

1. Loss与Accuracy图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 200        #分辨率

from datetime import datetime
current_time = datetime.now() # 获取当前时间

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))   # 创建一个新的图表,并设置图表的大小
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')   # 绘制训练准确率曲线
plt.plot(epochs_range, test_acc, label='Test Accuracy')    # 绘制测试准确率曲线
plt.legend(loc='lower right')      # 显示图例,位置为右下角
plt.title('Training and Validation Accuracy')     # 设置子图的标题
plt.xlabel(current_time)    # 将当前时间作为横坐标标签

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')   # 绘制训练损失曲线
plt.plot(epochs_range, test_loss, label='Test Loss')    # 绘制测试损失曲线
plt.legend(loc='upper right')     # 显示图例,位置为右上角
plt.title('Training and Validation Loss')     # 设置子图的标题
plt.show()    # 显示图表

代码输出:

在这里插入图片描述

2. 混淆矩阵

混淆矩阵(Confusion Matrix) 是一种常用的分类模型评估工具,特别适用于 二分类 和 多分类问题。它能够清晰地展示模型的 真实类别(True Labels) 与 预测类别(Predicted Labels) 之间的对应关系,深入分析模型的分类性能。

# 确保输入数据的维度为 [batch_size, seq_len, features]
print("==============输入数据Shape为==============")
print("X_test.shape:", X_test_tensor.shape)
print("y_test.shape:", y_test_tensor.shape)

# 获取预测结果
pred = model(X_test_tensor.to(device)).argmax(1).cpu().numpy()

print("\n==============输出数据Shape为==============")
print("pred.shape:", pred.shape)

代码输出:

==============输入数据Shape为==============
X_test.shape: torch.Size([2000, 1, 18])
y_test.shape: torch.Size([2000])

==============输出数据Shape为==============
pred.shape: (2000,)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns


# 计算混淆矩阵
cm = confusion_matrix(y_test, pred)

plt.figure(figsize=(6,5))    # 创建一个新的图形,设置图形的大小为 6x5 英寸
plt.suptitle('')     # 设置图形的总标题,这里设置为空字符串 '',即不显示总标题
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")    # 使用 seaborn 的热力图函数绘制混淆矩阵

# 修改字体大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label", fontsize=10)
plt.ylabel("True Label", fontsize=10)

# 显示图
plt.tight_layout()  # 调整布局防止重叠
plt.show()

代码输出:

在这里插入图片描述

3. 调用模型进行预测

# 选择单个样本并调整形状为 [batch_size, seq_len, features] 
test_X = X_test_tensor[0].reshape(1, 1, -1)  # 注意这里调整为三维的 [1, 1, features] 

# 获取模型的预测结果
pred = model(test_X.to(device)).argmax(1).item()

print("模型预测结果为:", pred)
print("==" * 20)
print("0:未患病")
print("1:已患病")

代码输出:

模型预测结果为: 0
========================================
0:未患病
1:已患病

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2310334.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

私有云基础架构

基础配置 使用 VMWare Workstation 创建三台 2 CPU、8G内存、100 GB硬盘 的虚拟机 主机 IP 安装服务 web01 192.168.184.110 Apache、PHP database 192.168.184.111 MariaDB web02 192.168.184.112 Apache、PHP 由于 openEuler 22.09 系统已经停止维护了&#xff…

rust学习笔记11-集合349. 两个数组的交集

rust除了结构体,还有集合类型,同样也很重要,常见的有数组(Array)、向量(Vector)、哈希表(HashMap) 和 集合(HashSet)字符串等,好意外呀…

超详细:数据库的基本架构

MySQL基础架构 下面这个图是我给出的一个MySQL基础架构图,可以清楚的了解到SQL语句在MySQL的各个模块进行执行过程。 然后MySQL可以分为两个部分,一个是server层,另一个是存储引擎。 server层 Server层涵盖了MySQL的大多数核心服务功能&am…

AI催化新一轮创业潮与创富潮:深圳在抢跑

作者:尺度商业大掌柜黄利明 2025年春节伊始至今,从DeepSeek R1开源模型持续引发全球围观,到腾讯混元Turbo S模型发布秀出了"秒回"绝活,再到国务院发布《新一代人工智能发展规划(2025-2030)》重磅…

Python:类型转换和深浅拷贝,可变与不可变对象

int():转换为一个整数,只能转换由纯数字组成的字符串 浮点型强转整型会去掉小数点及后面的数,只保留整数部分 #如果字符串中有数字和正负号以外的字符就会报错 float():整形转换为浮点型会自动添加一位小数 .0 如果字符串中有…

NAT 代理服务 内网穿透

🌈 个人主页:Zfox_ 🔥 系列专栏:Linux 目录 一:🔥 NAT 技术背景二:🔥 NAT IP 转换过程三:🔥 NAPT四:🔥 代理服务器🦋 正向…

高级课第五次作业

首先配置交换机,路由器 LSW1配置 [SW1]vlan batch 10 20 30 40 [SW1]int g0/0/2 [SW1-GigabitEthernet0/0/2]port link-type access [SW1-GigabitEthernet0/0/2]port default vlan 10 [SW1]int g0/0/3 [SW1-GigabitEthernet0/0/3]port link-type access […

51单片机编程学习笔记——动态数码管显示多个数字

大纲 视觉残留原理生理基础神经传导与处理 应用与视觉暂留相关的现象 频闪融合不好的实现好的效果 延伸 在《51单片机编程学习笔记——动态数码管》一文中,我们看到如何使用动态数码管显示数字。但是基于动态数码管设计的特点,每次只能显示1个数字。这就…

金蝶ERP星空对接流程

1.金蝶ERP星空OPENAPI地址: 金蝶云星空开放平台 2.下载金蝶云星空的对应SDK包 金蝶云星空开放平台 3.引入SDK流程步骤 引入Kingdee.CDP.WebApi.SDK 右键项目添加引用,在打开的引用管理器中选择浏览页签,点击浏览按钮,找到从官…

【随手笔记】利尔达NB模组

1.名称 移芯EC6263GPP 参数 指令备注 利尔达上电输出 [2025-03-04 10:24:21.379] I_AT_WAIT:i_len2 [2025-03-04 10:24:21.724] LI_AT_WAIT:i_len16 [2025-03-04 10:24:21.724] [2025-03-04 10:24:21.733] Lierda [2025-03-04 10:24:21.733] [2025-03-04 10:24:21.745] OK移…

Vue3的核心语法【未完】

Vue3的核心语法 OptionsAPI与CompositionAPI Options API(选项式) 和 Composition API (组合式)是 Vue.js 中用于构建组件的两种不同方式。Options API Options API Options API 是 Vue 2 中的传统模式,并在 Vue 3…

解决redis lettuce连接池经常出现连接拒绝(Connection refused)问题

一.软件环境 windows10、11系统、springboot2.x、redis 6 7 linux(centos)系统没有出现这问题,如果你是linux系统碰到的,本文也有一定大参考价值。 根本思路就是:tcp/ip连接的保活(keepalive)。 二.问题描述 在spr…

从DNS到TCP:DNS解析流程和浏览器输入域名访问流程

1 DNS 解析流程 1.1 什么是DNS域名解析 在生活中我们会经常遇到域名,比如说CSDN的域名www.csdn.net,百度的域名www.baidu.com,我们也会碰到IP,现在目前有的是IPV4,IPV6。那这两个有什么区别呢?IP地址是互联网上计算机…

解锁Egg.js:从Node.js小白到Web开发高手的进阶之路

一、Egg.js 是什么 在当今的 Web 开发领域,Node.js 凭借其事件驱动、非阻塞 I/O 的模型,在构建高性能、可扩展的网络应用方面展现出独特的优势 ,受到了广大开发者的青睐。它让 JavaScript 不仅局限于前端,还能在服务器端大展身手&…

JavaWeb后端基础(4)

这一篇就开始是做一个项目了,在项目里学习,我主要记录在学习过程中遇到的问题,以及一些知识点 Restful风格 一种软件架构风格 在REST风格的URL中,通过四种请求方式,来操作数据的增删改查。 GET : 查询 …

【文献阅读】The Efficiency Spectrum of Large Language Models: An Algorithmic Survey

这篇文章发表于2024年4月 摘要 大语言模型(LLMs)的快速发展推动了多个领域的变革,重塑了通用人工智能的格局。然而,这些模型不断增长的计算和内存需求带来了巨大挑战,阻碍了学术研究和实际应用。为解决这些问题&…

OpenGL ES -> GLSurfaceView纹理贴图

贴图 XML文件 <?xml version"1.0" encoding"utf-8"?> <com.example.myapplication.MyGLSurfaceViewxmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height…

DE2115实现4位全加器和3-8译码器(FPGA)

一、配置环境 1、Quartus 18.1安装教程 软件&#xff1a;Quartus版本&#xff1a;Quartus 18.1语言&#xff1a;英文大小&#xff1a;5.78G安装环境&#xff1a;Win11/Win10/Win8/Win7硬件要求&#xff1a;CPU2.0GHz 内存4G(或更高&#xff09; 下载通道①百度网盘丨64位下载…

【AI大模型】DeepSeek + Kimi 高效制作PPT实战详解

目录 一、前言 二、传统 PPT 制作问题 2.1 传统方式制作 PPT 2.2 AI 大模型辅助制作 PPT 2.3 适用场景对比分析 2.4 最佳实践与推荐 三、DeepSeek Kimi 高效制作PPT操作实践 3.1 Kimi 简介 3.2 DeepSeek Kimi 制作PPT优势 3.2.1 DeepSeek 优势 3.2.2 Kimi 制作PPT优…

run方法执行过程分析

文章目录 run方法核心流程SpringApplicationRunListener监听器监听器的配置与加载SpringApplicationRunListener源码解析实现类EventPublishingRunListener 初始化ApplicationArguments初始化ConfigurableEnvironment获取或创建环境配置环境 打印BannerSpring应用上下文的创建S…