R7:糖尿病预测模型优化探索

news2024/11/28 4:36:05
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、实验目的:

探索本案例是否还有进一步优化的空间

二、实验环境:

  • 语言环境:python 3.8
  • 编译器:Jupyter notebook
  • 深度学习环境:Pytorch
    • torch==2.4.0+cu124
    • torchvision==0.19.0+cu124

三、数据集标准化处理

对比R6的代码,本案例的改进部分主要是取消了两行代码的注释# 数据集标准化处理

from sklearn.preprocessing import StandardScaler   

# '高密度脂蛋白胆固醇'字段与糖尿病负相关,故而在 X 中去掉该字段   
X = DataFrame.drop(['是否糖尿病','高密度脂蛋白胆固醇'],axis=1)   
y = DataFrame['是否糖尿病']   

# 数据集标准化处理   
sc_X    = StandardScaler()   
X = sc_X.fit_transform(X)   

X = torch.tensor(np.array(X), dtype=torch.float32)   
y = torch.tensor(np.array(y), dtype=torch.int64)   

train_X, test_X, train_y, test_y = train_test_split(X, y,
                                                    test_size=0.2,   
                                                    random_state=1)   
train_X.shape, train_y.shape   

代码输出结果如下:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

对比R6的结果可以看到测试集准确率有了明显的提高,而原因我也有在上一次总结中提及。结果差异大的可能原因分析如下:

数据处理差异

  1. 特征选择不同
    • 在R6中,除了删除'高密度脂蛋白胆固醇''是否糖尿病',还删除了'卡号'。这意味着R6代码使用的特征集比R7少了一个额外的特征。不同的特征集会对模型的学习和预测能力产生影响。'卡号'这个特征本身包含了一些与糖尿病相关的潜在信息(不同卡号对应的人群有不同的糖尿病发病倾向等),那么去掉它会改变模型的行为。
  2. 标准化处理的有无
    • R7代码对数据进行了标准化处理(sc_X.fit_transform(X)),而R6没有。标准化可以使数据的特征具有相似的尺度,这对于一些基于距离或梯度的算法(如神经网络中的梯度下降)非常重要。如果数据没有标准化,可能会导致某些特征在模型训练过程中对损失函数的影响过大或过小,从而影响模型的收敛速度和性能。例如,在神经网络中,如果某个特征的值域很大,而另一个特征的值域很小,那么在计算梯度时,可能会使模型过度关注值域大的特征,而忽略值域小的特征。

模型训练相关因素

  1. 初始条件差异
    • 模型的初始参数(如神经网络的权重初始化)可能会因为数据的不同而导致不同的训练轨迹。不同的数据分布(由于数据处理的不同)可能会使模型在相同的初始化策略下朝着不同的方向优化。
  2. 训练动态变化
    • 由于数据的改变,模型在训练过程中的梯度更新情况也会不同。R6代码中没有标准化的数据可能会导致梯度在某些方向上变化剧烈,影响模型收敛。而在R7中,标准化后的数据可能使梯度更新更加稳定和合理,导致训练和测试准确率的变化趋势不同。而且,不同的数据特征可能在不同的训练阶段对模型的贡献不同,进而影响模型在每个 epoch 的准确率和损失值。

在这里插入图片描述
在这里插入图片描述

四、总结

标准化处理对模型训练的具体影响:

  1. 加快收敛速度

    • 原理:在未进行标准化处理时,不同特征的取值范围可能差异巨大。例如,一个特征的取值范围是0 - 1,另一个特征的取值范围是0 - 1000。在基于梯度下降的优化算法中,如随机梯度下降(SGD),损失函数的梯度更新会受到特征尺度的影响。对于取值范围大的特征,其梯度更新步长可能会过大,导致模型难以收敛到最优解;而对于取值范围小的特征,其梯度更新步长可能过小,模型学习这些特征的速度会很慢。标准化将数据变换到均值为0、标准差为1的分布,使得不同特征具有相似的尺度。这样,在梯度更新时,各个特征能够以相对平衡的步长进行学习,从而加快模型收敛速度。
    • 示例:假设我们有一个简单的线性回归模型 y = w 1 x 1 + w 2 x 2 + b y = w_1x_1+w_2x_2 + b y=w1x1+w2x2+b,其中 x 1 x_1 x1的取值范围是[0, 1], x 2 x_2 x2的取值范围是[0, 100]。如果不进行标准化,当更新 w 1 w_1 w1 w 2 w_2 w2的梯度时,由于 x 2 x_2 x2的取值范围大, w 2 w_2 w2的更新步长会比 w 1 w_1 w1大很多。经过标准化后, x 1 x_1 x1 x 2 x_2 x2的尺度变得相似, w 1 w_1 w1 w 2 w_2 w2可以以更合理的步长进行更新,模型能够更快地找到合适的权重组合。
  2. 提高模型精度

    • 原理:标准化可以减少数据中异常值和离群点对模型训练的影响。在一些机器学习算法中,如支持向量机(SVM)和神经网络,异常值可能会导致模型过度拟合这些异常数据点,从而降低模型在其他正常数据点上的泛化能力。标准化通过将数据压缩到一个相对稳定的范围内,降低了异常值的影响,使模型能够更好地学习数据的整体分布规律,进而提高模型的精度。
    • 示例:考虑一个K - 近邻(K - NN)分类器,它基于数据点之间的距离进行分类。如果存在一个特征有很大的取值范围,并且有一些离群点,那么这些离群点会在距离计算中占据主导地位,导致分类错误。通过标准化,所有特征的取值范围变得相对均匀,离群点的影响被减弱,K - NN分类器能够更准确地根据数据的真实分布进行分类。
  3. 增强模型稳定性和鲁棒性

    • 原理:在模型训练过程中,数据的微小变化可能会导致模型性能的大幅波动。标准化后的数据在一定程度上减少了这种波动,因为它使得数据的分布更加稳定。对于神经网络等复杂模型,这种稳定性尤为重要,因为它们通常具有大量的参数,容易受到数据变化的影响。标准化可以帮助模型在不同的数据集(如训练集和测试集)上保持相对一致的性能,提高模型的鲁棒性。
    • 示例:在深度学习中,当使用小批量梯度下降(Mini - Batch Gradient Descent)训练模型时,每个小批量数据的分布可能会有所不同。如果没有标准化,模型可能会因为小批量数据的差异而产生较大的性能波动。而标准化后,小批量数据的分布更加稳定,模型的训练过程更加平稳,对数据变化的敏感度降低,从而增强了模型的鲁棒性。

除了标准化处理,还有下述数据预处理方法:

  1. 数据清洗
    • 缺失值处理
      • 删除法:当数据集中的缺失值占比较小,且缺失是完全随机的情况下,可以直接删除含有缺失值的行或列。例如,在一个包含1000条记录的客户信息数据集中,如果只有少数几条记录的“联系电话”字段缺失,且这些缺失看起来是随机发生的,那么可以考虑删除这些记录。
      • 插补法
        • 均值/中位数/众数插补:对于数值型特征的缺失值,可以使用该特征的均值、中位数来填充。比如,在一个学生成绩数据集中,如果某个学生的数学成绩缺失,可以用全班数学成绩的均值来填充。对于分类型特征,则可以使用众数填充,例如,在一个调查问卷数据集中,“性别”字段有缺失值,就可以用出现频率最高的性别(众数)来填充。
        • 回归插补:通过建立一个回归模型,利用数据集中其他相关特征来预测缺失值。假设在一个房屋价格数据集中,“房屋面积”字段有缺失值,可以建立一个以房屋价格、房间数量等其他特征为自变量,房屋面积为因变量的回归模型,然后利用该模型预测缺失的房屋面积。
    • 异常值处理
      • 盖帽法:将超出一定范围(如大于某个分位数加上几倍的四分位距,或小于某个分位数减去几倍的四分位距)的异常值替换为该范围的边界值。例如,在一个员工工资数据集中,若发现工资数据有异常高的值,可以将这些异常高的值替换为上四分位数加上1.5倍四分位距的值。
      • 删除法:当异常值对模型训练有严重干扰,且其出现是由于数据录入错误等原因时,可以直接删除异常值。例如,在一个体温测量数据集中,如果出现了明显不符合人体正常体温范围(如50℃)的异常值,且确定是测量错误导致的,就可以删除。
  2. 数据编码
    • 独热编码(One - Hot Encoding):主要用于处理分类型数据。对于一个具有 n n n个类别(如颜色有红、绿、蓝三种)的分类变量,会创建 n n n个新的二进制变量(0或1)来表示。例如,对于“颜色”这个分类变量,会创建“红色”、“绿色”、“蓝色”三个新变量,当原始数据为红色时,“红色”变量为1,其他两个为0。这种编码方式可以避免模型将分类变量的类别顺序错误地理解为数值大小关系。
    • 标签编码(Label Encoding):将分类变量的类别转换为整数。例如,对于一个包含“低”、“中”、“高”三个类别的变量“风险等级”,可以将“低”编码为0,“中”编码为1,“高”编码为2。不过这种编码方式可能会让模型误解类别之间的顺序关系,所以在一些对类别顺序不敏感的模型(如决策树)中可以使用,而在对顺序敏感的模型(如线性回归)中可能需要谨慎使用。
  3. 数据变换
    • 对数变换:对于一些具有正偏态分布(数据大部分集中在左侧,右侧有较长的尾巴)的数据,如收入数据、人口数据等,可以使用对数变换将其转换为近似正态分布。例如,对一组公司营业收入数据进行对数变换后,数据的分布会更加对称,这样更符合一些模型(如基于正态分布假设的线性回归)的假设,有助于提高模型性能。
    • 平方根变换:和对数变换类似,主要用于改善数据的分布形态。当数据的偏态程度不是特别严重时,平方根变换可能是一种有效的方法。例如,对于一些生物实验中的细胞计数数据,其分布可能存在一定的偏态,通过平方根变换可以使其分布更加合理,方便后续模型的训练。

五、模型优化探索

1. 调整模型结构

  • 增加更多的层:可以考虑增加LSTM层的数量。更多的层可能能够捕捉到更复杂的数据模式。例如,可以增加到3层或4层的LSTM,但要注意防止过拟合,可以结合正则化方法。

  • 添加Dropout层:在LSTM层之间或者在全连接层之前添加Dropout层。Dropout在训练过程中随机忽略一些神经元,有助于减少过拟合。比如,可以在每个LSTM层之后添加Dropout层,设置Dropout概率为0.2 - 0.5之间。

class model_lstm(nn.Module):
    def __init__(self):
        super(model_lstm, self).__init__()
        self.lstm0 = nn.LSTM(input_size=13, hidden_size=200, num_layers=1, batch_first=True)
        self.dropout1 = nn.Dropout(0.2)  # 添加Dropout层,概率为0.2
        self.lstm1 = nn.LSTM(input_size=200, hidden_size=200, num_layers=1, batch_first=True)
        self.dropout2 = nn.Dropout(0.2)  # 添加另一个Dropout层
        self.fc0 = nn.Linear(200, 2)

    def forward(self, x):
        out, hidden1 = self.lstm0(x)
        out = self.dropout1(out)  # 使用Dropout
        out, _ = self.lstm1(out, hidden1)
        out = self.dropout2(out)  # 使用Dropout
        out = self.fc0(out)
        return out

2. 优化超参数

  • 调整隐藏层大小:当前LSTM的隐藏层大小是200,可以尝试不同的值,如128、256、300等。通过交叉验证等方法来确定最佳的隐藏层大小,以提高模型的性能。
  • 调整学习率:使用不同的学习率策略,如学习率衰减。可以从一个相对较大的学习率开始,随着训练的进行逐渐降低学习率。例如,使用Adam优化器,并设置初始学习率为0.001,然后每几个epoch按照一定比例(如0.9)降低学习率。
from torch.optim.lr_scheduler import StepLR

loss_fn = nn.CrossEntropyLoss()  # 创建损失函数
learn_rate = 1e-3   # 初始学习率增大为1e - 3
opt = torch.optim.Adam(model.parameters(), lr=learn_rate)
scheduler = StepLR(opt, step_size=10, gamma=0.1)  # 添加学习率调度器
epochs = 30

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = opt.state_dict()['param_groups'][0]['lr']
    scheduler.step()  # 更新学习率

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, 
                          epoch_test_acc * 100, epoch_test_loss, lr))

print("=" * 20, 'Done', "=" * 20)

3. 改进输入数据处理

  • 特征工程:进一步分析输入的13个特征,看是否可以通过特征组合、特征变换等方式提高数据的质量。例如,如果某些特征之间存在非线性关系,可以尝试添加一些新的特征来表示这种关系。
  • 数据增强(如果适用):如果数据具有一定的可变性,可以考虑数据增强方法。例如,如果数据是时间序列数据,可以通过对数据进行小幅度的时间平移、缩放等操作来增加数据量和多样性。

4. 模型集成

  • 使用多个模型进行集成:训练多个不同初始化或者不同超参数的LSTM模型,然后将它们的预测结果进行集成,如简单平均或者加权平均。这样可以利用不同模型的优势,提高整体的预测准确性。

经过上述调整模型结构、优化超参数后(特征工程效果不明显),效果有所提升,结果输出如下:
在这里插入图片描述
在这里插入图片描述

5. 改进空间

  • 进一步调整模型结构

    • 增加更多层或调整层大小:虽然已经添加了一些改进,但可以继续尝试增加LSTM的层数或者调整隐藏层大小。例如,可以尝试将LSTM的层数增加到3层,同时适当调整每层的隐藏单元数量,如将隐藏单元数量从200增加到256或300等,然后观察模型性能的变化。
    • 尝试不同类型的层:可以考虑在模型中添加其他类型的神经网络层,如卷积层(如果数据结构适合)。对于时间序列数据或具有局部相关性的数据,卷积层可以提取数据中的局部特征,与LSTM层结合可能会提高模型性能。
  • 优化超参数

    • 更精细的学习率调整:目前使用了简单的学习率衰减策略,可以尝试更复杂的学习率调整方法,如余弦退火学习率(Cosine Annealing Learning Rate)。这种方法可以使学习率在训练过程中按照余弦函数的规律变化,在前期可以有较大的学习率快速收敛,后期学习率逐渐减小以更精细地调整模型参数。
    • 调整Dropout概率:当前设置的Dropout概率为0.2,可以尝试不同的值,如0.3、0.4等,观察对模型过拟合情况的影响。不同的Dropout概率可能会改变模型的复杂度和泛化能力。
  • 改进特征工程

    • 深入分析数据相关性:对新添加的特征进行更深入的分析,检查它们与目标变量的相关性。如果某些新特征与目标变量相关性较低,可以考虑删除或进一步修改这些特征的构建方式。同时,可以探索更多的特征组合和变换方法,以挖掘数据中更有价值的信息。
    • 添加领域相关特征(如果可能):根据数据的领域知识,添加更多有意义的特征。例如,如果是医疗数据,可以考虑添加一些与疾病相关的衍生特征,如不同指标之间的比率等。
  • 模型集成

    • 使用多种模型集成:除了单一的LSTM模型,可以尝试将其与其他类型的模型(如决策树、支持向量机等)进行集成。例如,可以使用堆叠(Stacking)或混合(Blending)等集成方法,将不同模型的预测结果进行组合,以提高模型的泛化能力和稳定性。
    • 训练多个不同初始化的LSTM模型并集成:训练多个不同初始化参数的LSTM模型,然后通过平均或加权平均等方式将它们的预测结果集成在一起。这种方法可以利用不同初始化模型的多样性,降低模型的方差,提高预测性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2235469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Django 框架:全方位技术分析

Django 框架:全方位技术分析 介绍 Django 是一个高级 Python Web 框架,鼓励快速开发和遵循设计的最佳实践。由经验丰富的开发人员打造,开源并可扩展,Django 旨在让构建 Web 应用更快、更轻松。 历史背景 Django 始于 2003 年,最初是 Lawrence Journal-World 报社的一个内…

如何用 ChatPaper.ai 打造完美的 AI 课堂笔记系统

作为学生,我们都遇到过这样的困扰:上课时记笔记太投入就听不进讲解,专注听讲又担心错过重要知识点。有了AI助手,这个问题就可以优雅地解决了。今天跟大家分享如何用ChatPaper.ai构建个人的智能课堂笔记系统。 为什么需要AI辅助记笔…

《手写Spring渐进式源码实践》实践笔记(第十六章 三级缓存解决循环依赖)

文章目录 第十六章 通过三级缓存解决循环依赖背景技术背景Spring循环依赖循环依赖类型三级缓存解决循环依赖 业务背景 目标设计一级缓存实现方案设计思路代码实现测试结果 三级缓存实现方案 实现代码结构类图实现步骤 测试事先准备属性配置文件测试用例测试结果: 总…

Java8新特性/java

1.lambda表达式 区别于js的箭头函数,python、cpp的lambda表达式,java8的lambda是一个匿名函数,java8运行把函数作为参数传递进方法中。 语法格式 (parameters) -> expression 或 (parameters...) ->{ statements; }实战 替代匿名内部类…

ubuntu双屏只显示一个屏幕另一个黑屏

简洁的结论: 系统环境 ubuntu22.04 nvidia-535解决方案 删除/etc/X11/xorg.conf 文件 记录一下折腾大半天的问题。 ubuntu系统是22.04,之前使用的时候更新驱动导致桌面崩溃,重新安装桌面安装不上,请IT帮忙,IT一番操作过后也表示…

Oracle OCP认证考试考点详解082系列15

题记: 本系列主要讲解Oracle OCP认证考试考点(题目),适用于19C/21C,跟着学OCP考试必过。 71. 第71题: 题目 解析及答案: 对于数据库,使用数据库配置助手(DBCA)可以执行…

为什么说SQLynx是链接国产数据库的最佳选择?

第一是因为SQLynx提供了广泛的国产数据库支持,除了市面上的主流数据库MYSQL、Oracle、PostgreSQL 、SQL Server、SQLite、MongoDB外、还支持达梦人大金仓等国产数据源! 近几年随着国产数据库市场的不断发展和成熟,越来越多的企业和机构开始选…

一个基于强大的 PostgreSQL 数据库构建的无代码数据库平台,快速构建应用程序,而无需编写一行代码(带私活源码)

随着企业和个人开发需求的不断增加,无代码平台成为了现代开发的重要组成部分,帮助那些没有技术背景的用户也能轻松创建和管理数据库应用程序。今天,我将向大家推荐一个非常出色的开源项目——Teable,它不仅支持无代码开发&#xf…

软件开发项目管理:实现目标的实用指南

由于软件项目多数是复杂且难以预测的,对软件开发生命周期的深入了解、合适的框架以及强大的工作管理平台是必不可少的。项目管理系统在软件开发中通常以监督为首要任务,但优秀的项目计划、管理框架和软件工具可以使整个团队受益。 软件开发项目管理的主要…

计算网络信号

题目描述: 网络信号经过传递会逐层衰减,且遇到阻隔物无法直接穿透,在此情况下需要计算某个位置的网络信号值。注意:网络信号可以绕过阻隔物 array[m][n]的二维数组代表网格地图, array[i][j]0代表i行j列是空旷位置&…

ESP8266 自定义固件烧录-Tcpsocket固件

一、固件介绍 固件为自定义开发的一个适配物联网项目的开源固件,支持网页配网、支持网页tcpsocket服务器配置、支持串口波特率设置。 方便、快捷、稳定! 二、烧录说明 固件及工具打包下载地址: https://download.csdn.net/download/flyai…

数据结构与算法——Java实现 52.力扣98题——验证二叉搜索树

我将一直向前,带着你给我的淤青 —— 24.11.5 98. 验证二叉搜索树 给你一个二叉树的根节点 root ,判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下: 节点的左子树只包含 小于 当前节点的数。节点的右子树只包含 大于 当前节点的…

[mysql]DDL,DML综合案例,

综合案例 题目如下 目录 综合案例 ​编辑 ​编辑 # 1、创#1建数据库test01_library # 2、创建表 books,表结构如下: # 3、向books表中插入记录库存 # 4、将小说类型(novel)的书的价格都增加5。 # 5、将名称为EmmaT的书的价格改为40,并将…

书生实战营第四期-基础岛第三关-浦语提示词工程实践

一、基础任务 任务要求:利用对提示词的精确设计,引导语言模型正确回答出“strawberry”中有几个字母“r”。 1.提示词设计 你是字符计数专家,能够准确回答关于文本中特定字符数量的问题。 - 技能: - 📊 分析文本&…

国药准字生发产品有哪些?这几款不错

头秃不知道怎么选的朋友们看这,基本上市面上火的育发精华我都用了个遍了,陆陆续续也花了有大几w了,都是真金白银总结出来的,所以必须要给掉发人分享一些真正好用的育发产品,大家可以根据自己实际情况来选择。 1. 露卡菲…

构建基于 DCGM-Exporter, Node exporter,PROMETHEUS 和 GRAFANA 构建算力监控系统

目录 引言工具作用概述DCGM-ExporterNode exporterPROMETHEUSGRAFANA小结 部署单容器DCGM-ExporterNode exporterPROMETHEUSGRAFANANode exporterDCGM-Exporter 多容器Node exporterDCGM-ExporterDocker Compose 参考 引言 本文的是适用对象,是希望通过完全基于Doc…

Java入门14——动态绑定(含多态)

大家好,我们今天来学动态绑定和多态,话不多说,开始正题~ 但是要学动态绑定之前,我们要学习一下向上转型,方便后续更好地理解~ 一、向上转型 1.什么是向上转型 网上概念有很多,但其实通俗来讲&#xff0c…

Request 和 Response 万字详解

文章目录 1.Request和Response的概述2.Request对象2.1 Request 继承体系2.2 Request获取请求数据2.2.1 获取请求行数据2.2.2 获取请求头数据2.2.3 获取请求体数据2.2.4 获取请求参数的通用方式 2.3 解决post请求乱码问题 掌握内容讲解内容小结 2.4 Request请求转发 3.HTTP响应详…

Qt QCustomplot 在采集信号领域的应用

文章目录 一、常用的几种开源库:1、QCustomPlot:2、QChart:3、Qwt:QCustomplot 在采集信号领域的应用1、应用实例时域分析频谱分析2.数据筛选和处理其他参考自然界中的物理过程、传感器和传感器网络、电路和电子设备、通信系统等都是模拟信号的来源。通过可视化模拟信号,可以…

【初阶数据结构与算法】沉浸式刷题之顺序表练习(顺序表以及双指针两种方法)

文章目录 顺序表练习1.移除数组中指定的元素方法1(顺序表)方法2(双指针) 2.删除有序数组中的重复项方法1(顺序表)方法2(双指针) 3.双指针练习之合并两个有序数组方法1(直…