遗传算法与深度学习实战(16)——神经网络超参数优化

news2024/10/8 14:47:46

遗传算法与深度学习实战(16)——神经网络超参数优化

    • 0. 前言
    • 1. 深度学习基础
      • 1.1 传统机器学习
      • 1.2 深度学习
    • 2. 神经网络超参数调整
      • 2.1 超参数调整策略
      • 2.2 超参数调整对神经网络影响
    • 3. 超参数调整规则
    • 小结
    • 系列链接

0. 前言

我们已经学习了多种形式的进化计算,从遗传算法到粒子群优化,以及进化策略和差分进化等高级方法。在之后的学习中,我们将使用这些进化计算 (Evolutionary Computation, EC) 方法来改进深度学习 ( Deep learning, DL),通常称为进化深度学习 (Evolutionary Deep Learning, EDL)。
然而,在构建用于解决 DL 问题的 EDL 解决方案之前,我们必须了解要解决的问题以及如何在没有 EC 的情况下解决它们,毕竟 EC 只是用来改进 DL 的工具。因此,在应用 EC 方法于超参数优化 (Hyperparameter Optimization, HPO) 之前,我们首先介绍超参数优化的重要性和一些手动调整策略。

1. 深度学习基础

1.1 传统机器学习

传统应用程序中,系统是通过使用程序员编写的复杂算法来实现智能化的。例如,假设我们希望识别照片中是否包含狗。在传统的机器学习 (Machine Learning, ML) 中,需要机器学习研究人员首先确定需要从图像中提取的特征,然后提取这些特征并将它们作为输入传递给复杂算法,算法解析给定特征以判断图像中是否包含狗:

传统机器学习

然而,如果要为多种类别图像分类手动提取特征,其数量可能是指数级的,因此,传统方法在受限环境中效果很好(例如,识别证件照片),而在不受限制的环境中效果不佳,因为每张图像之间都有较大差异。
我们可以将相同的思想扩展到其他领域,例如文本或结构化数据。过去,如果希望通过编程来解决现实世界的任务,就必须了解有关输入数据的所有内容并编写尽可能多的规则来涵盖所有场景,并且不能保证所有新场景都会遵循已有规则。
传统机器学习的主要特点是以有限的特征集和显式规则为基础,从大量数据中学习模型,并利用学习到的模型对新数据进行预测或分类;主要方法包括:决策树、朴素贝叶斯分类、支持向量机、最近邻分类、线性回归、逻辑回归等,这些方法通常需要经过数据预处理、特征选择、模型训练和模型评估等一系列步骤,以达到更好的分类或预测效果。
传统机器学习的优点在于它们的理论基础比较成熟,训练和推理速度相对较快,并且可以适用于各种类型的数据,此外,对于一些小规模的数据集,传统机器学习方法的效果也相对不错。然而,传统机器学习方法也有相当明显的局限性,例如,由于传统机器学习方法依赖于手动选择的特征,因此难以捕捉数据中的复杂非线性关系;同时,这些方法通常不具备自适应学习能力,需要人工干预来调整模型。

1.2 深度学习

神经网络内含了特征提取的过程,并将这些特征用于分类/回归,几乎不需要手动特征工程,只需要标记数据(例如,哪些图片是狗,哪些图片不是狗)和神经网络架构,不需要手动提出规则来对图像进行分类,这减轻了传统机器学习技术强加给程序员的大部分负担。
训练神经网络需要提供大量样本数据。例如,在前面的例子中,我们需要为模型提供大量的狗和非狗图片,以便它学习特征。神经网络用于分类任务的流程如下,其训练与测试是端到端 (end-to-end) 的:

神经网络训练
深度学习(Deep Learning, DL)是一类基于神经网络的机器学习算法,其主要特点是使用多层神经元构成的深度神经网络,通过大规模数据训练模型并自动地提取、分析、抽象出高级别的特征,经典的深度神经网络架构示例如下所示:

深度神经网络架构

深度学习的优势在于它可以自动地从大量非结构化或半结构化的数据中学习,同时可以发现数据之间的隐含关系和规律,有效地处理语音、图像、自然语言等复杂的数据。常用的神经网络模型包括多层感知机 (Multilayer Perceptron, MLP)、卷积神经网络 (Convolutional Neural Network, CNN)、循环神经网络 (Recurrent Neural Network, RNN) 等。
深度学习目前已经广泛应用于图像识别、语音识别、自然语言处理等领域,如人脸识别、自动驾驶、智能客服、机器翻译等。虽然深度学习在很多领域取得了出色的成果,但是深度神经网络的训练和优化也存在一些难点和挑战,如梯度消失和梯度爆炸等问题,需要使用一系列优化算法和技巧来解决。

2. 神经网络超参数调整

深度学习模型面临的困难之一是如何调整模型选项和超参数来改进模型。DL 模型中通常都会涉及许多选项和超参数,但通常缺乏详细说明调整的效果,通常研究者仅仅展示最先进模型的效果,经常忽略模型达到最优性能所需的大量调整工作。
通常,学习如何使用不同选项和调整超参数需要大量的建模经验。如果没有进行调整,许多模型可能无法达到最优性能。这不仅是一个经验问题,而且也是 DL 领域本身的一个问题。我们首先学习使用 PyTorch 构建一个基础深度学习模型,用于逼近给定函数。

给定函数

2.1 超参数调整策略

在本节中,我们将介绍一些模型选项和调整 DL 模型超参数的技巧和策略。其中一些技巧是根据大量模型训练经验获得的,但这些策略也是需要不断发展的,随着 DL 的不断发展,模型选项也在不断的扩充。
接下来,我们介绍如何使用超参数和其他选项。添加超参数:batch_sizedata_step。超参数 batch_size 用于确定每次前向传递中输入到网络的数据样本数量;超参数 data_step 用于控制生成的训练数据量:

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
#plotting
from matplotlib import pyplot as plt
from IPython.display import clear_output
#for performance timing
import time

results = []

hp_test = "lr = 3.5e-01" #@param {type:"string"}
learning_rate = 3.5e-01
epochs = 500
middle_layer = 16
batch_size = 25
data_step = .5
data_min = -5
data_max = 5

def function(x):
  return (2*x + 3*x**2 + 4*x**3 + 5*x**4 + 6*x**5 + 10) 

Xi = np.reshape(np.arange(data_min,data_max, data_step), (-1, 1))
yi = function(Xi)
inputs = Xi.shape[1]
yi = yi.reshape(-1, 1)
plt.plot(Xi, yi, 'o', color='black')
plt.plot(Xi,yi, color="red")

tensor_x = torch.Tensor(Xi) # transform to torch tensor
tensor_y = torch.Tensor(yi)

dataset = TensorDataset(tensor_x,tensor_y) # create your datset
dataloader = DataLoader(dataset, batch_size= batch_size, shuffle=True) # create your dataloader

class Net(nn.Module):
    def __init__(self, inputs, middle):
        super().__init__()
        self.fc1 = nn.Linear(inputs,middle)    
        self.fc2 = nn.Linear(middle,middle)    
        self.out = nn.Linear(middle,1)
    def forward(self, x):
        x = F.relu(self.fc1(x))     
        x = F.relu(self.fc2(x))    
        x = self.out(x)
        return x

model = Net(inputs, middle_layer)
print(model)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

epoch_report = 25
history = []
start = time.time()
for i in range(epochs):        
    for X, y in iter(dataloader):
        # wrap the data in variables
        x_batch = Variable(torch.Tensor(X))
        y_batch = Variable(torch.Tensor(y))                   
        # forward pass
        y_pred = model(x_batch)        
        # compute and print loss
        loss = loss_fn(y_pred, y_batch)  
        history.append(loss.data/batch_size)         
        # reset gradients
        optimizer.zero_grad()        
        # backwards pass
        loss.backward()        
        # step the optimizer - update the weights
        optimizer.step()  
    if (i+1) % epoch_report == 0:
        clear_output()
        y_ = model(tensor_x)
        plt.plot(Xi, yi, 'o', color='black')
        plt.plot(Xi,y_.detach().numpy(), 'r')
        plt.show()
        print(f"[{i}] Loss = {loss.data}")
        time.sleep(1)
plt.plot(history)
end = time.time() - start

X_a = torch.rand(100,1).clone() * 10 - 5
y_a = model(X_a)
y_a = y_a.detach().numpy()
results.append([hp_test,end, X_a, y_a])

fig = plt.figure()
ax1 = fig.add_subplot(111)
for test,t,x,y in results:
    ax1.scatter(x, y, s=10, marker="s", label=f"{test} in {t:0.1f}s")

plt.legend(loc='upper left')
plt.plot(Xi, yi, 'o', color='black')
plt.show()

2.2 超参数调整对神经网络影响

middle_layer 的值由 5 修改为 25,运行代码后,观察两次测试的预测输出,可以看到 middle_layer25 时的模型性能更好。同时可以看到,模型的训练模型的时间略有不同,这是由于更大的模型需要更长的训练时间。

隐藏层数修改

可以修改超参数 batch_sizedata_step,观察不同超参数对模型性能的影响。但需要注意的是,这些值是相互关联的,如果通过将 data_step 减小到 0.1 来增加数据量,则同样需要增加 batch_size
下图展示了在增加数据量时改变和不改变批大小的结果,可以看到,训练 500epoch 所需的训练时间差异明显。

batch_size

继续修改其它超参数,将 learning_rate3.5e-06 修改为 3.5e-01,在调整超参数时,总体目标是创建一个能够快速训练并产生最佳结果的最小(参数量)模型。

学习率

3. 超参数调整规则

即使本节中仅仅只有五个超参数,超参数的调整仍然可能遇到困难,因此一个简单的方法是按照以下步骤进行操作:

  • 设置网络大小 - 在本节中,通过修改 middle_layer 的值实现,通常,首先调整网络大小或层数。但需要注意的是,增加线性层数通常并不如增加层内网络节点数量有效
    超参数训练规则#1:网络大小——增加网络层数能够从数据中提取更多特征,扩展或减小模型宽度(节点数)能够调整模型的拟合程度
  • 控制数据变化性 – 通常认为深度学习模型训练需要大量的数据,虽然深度学习模型可能会从更多的数据中获益,但这更多依赖于数据集中的变化性,在本节中,我们使用 data_step 值来控制数据的变化性,但通常情况下我们无法控制数据的变化性。因此,如果数据集包含较大的变化性,则很可能相应的需要增加模型的层数和宽度。与 MNIST 数据集中的手写数字图片相比,Fashion-MNIST 数据集中服饰图片的变化性要小得多
    超参数训练规则#2:数据变化性——更具变化性的数据集需要更大的模型,以提取更多特征并学习更复杂的解决方案
  • 选择批大小 - 调整模型的批大小可以显著提高训练效率,然而,批大小并不能解决训练性能问题,增加批大小可能对降低最终模型性能,批大小需要基于输入数据的变化性进行调优,输入数据变化性较大时,较小的批大小通常更有益(范围通常在 16-64 之间),而变化性较小的数据可能需要较大的批大小(范围通常在 64-256 之间,甚至更高)
    超参数训练规则#3:批大小——如果输入数据变化性较大,则减小批大小,对于变化较小且更统一的数据集,增加批大小
  • 调整学习率 - 学习率控制模型学习的速度,学习率与模型的复杂性相关,由输入数据的变化性驱动,数据变化性较高时,需要较小的学习率,而更均匀的数据可以采用较高的学习率,调整模型大小可能也可能需要调整学习率,因为模型复杂性发生了变化
    超参数训练规则#4:学习率——调整学习率以适应输入数据的变化性,如果需要增加模型的大小,通常也需要减小学习率
  • 调整训练迭代次数 - 处理较小规模的数据集时,模型通常会快速收敛到某个基准解,因此,可以简单地减少模型的 epochs (训练迭代次数),但是,如果模型较复杂且训练时间较长,则确定总的训练迭代次数可能更为困难。但多数深度学习框架提供了提前停止机制,它通过监视指定损失值,并在损失值不再变化时自动停止训练,因此,通常可以选择可能需要的最大训练迭代次数,另一种策略是定期保存模型的权重,然后在需要时,可以重新加载保存的模型权重并继续训练
    超参数训练规则#5:训练迭代次数——使用可能需要的最大训练迭代次数,使用提前停止等技术来减少训练迭代次数

使用以上五条策略能够更好的调整超参数,但这些技术只是一般规则,可能会有网络配置、数据集和其他因素改变这些一般规则。接下来,我们将进一步讨论构建稳定模型时可能需要决定的各种模型选项。
除了超参数外,模型改进的最大动力源于模型所选用的各种选项。DL 模型提供了多种选项,具体取决于实际问题和网络架构,但通常模型的细微改变就足以从根本上改变模型的拟合方式。
模型选项包括激活函数、优化器函数、以及网络层的类型和数量的选用。网络层的深度通常由模型需要提取和学习的特征数量所决定,网络层的类型(全连接、卷积或循环网络等)通常由需要学习的特征类型决定。例如,使用卷积层来学习特征的聚类,使用循环神经网络来确定特征的出现顺序。
因此,大多数 DL 模型的网络大小和层类型都受数据变化性和需要学习的特征类型的驱动。对于图像分类问题,卷积层用于提取视觉特征,例如眼睛或嘴巴,循环神经网络层用于处理语言或时间序列数据。
大多数情况下,模型选项需要关注的包括激活、优化器和损失函数。激活函数通常由问题类型和数据形式决定,避免在选项调整的最后阶段修改激活函数。通常,优化器和损失函数的选择决定了模型训练的好坏。下图显示了使用四种不同优化器来训练上一节模型得到的结果,其中超参数 middle_layer 值为 25,可以看出,与 AdamRMSprop 相比,随机梯度下降 (Stochastic Gradient Descent, SGD) 和 Adagrad 表现较差。

不同优化器

损失函数同样会对模型训练产生重大影响。在回归问题中,我们可以使用两个不同的损失函数:均方误差 (mean-squared error, MSE) 和平均绝对误差 (mean absolute error, MAE),下图显示了使用的两个不同损失函数的模型性能对比结果。可以看到,MAE 损失函数的效果更好一些。

不同损失函数

超参数训练规则#6:模型修改作为一般规则,更改模型架构或关键模型选项,都需要重新调整所有超参数

事实上,我们可能需要花费数天甚至数月的时间来调整模型的超参数,直到得到更好的损失和优化函数。超参数调整和模型选项选择并非易事,选用不合适的选项甚至得到更差的模型。在构建有效的 DL 模型时,通常会定义模型并选择最适合实际问题的选项。然后,通过调整各种超参数和选项优化模型。

小结

超参数优化的目标是通过调整模型的超参数,如学习率、正则化系数、网络架构、批大小等,来最大化模型的性能和泛化能力。选择合适的方法取决于问题的特性、计算资源和优化目标的复杂性。本节中,我们介绍了一些常见模型选项和调整DL模型超参数的技巧和策略。

系列链接

遗传算法与深度学习实战(1)——进化深度学习
遗传算法与深度学习实战(2)——生命模拟及其应用
遗传算法与深度学习实战(3)——生命模拟与进化论
遗传算法与深度学习实战(4)——遗传算法(Genetic Algorithm)详解与实现
遗传算法与深度学习实战(5)——遗传算法中常用遗传算子
遗传算法与深度学习实战(6)——遗传算法框架DEAP
遗传算法与深度学习实战(7)——DEAP框架初体验
遗传算法与深度学习实战(8)——使用遗传算法解决N皇后问题
遗传算法与深度学习实战(9)——使用遗传算法解决旅行商问题
遗传算法与深度学习实战(10)——使用遗传算法重建图像
遗传算法与深度学习实战(11)——遗传编程详解与实现
遗传算法与深度学习实战(12)——粒子群优化详解与实现
遗传算法与深度学习实战(13)——协同进化详解与实现
遗传算法与深度学习实战(14)——进化策略详解与实现
遗传算法与深度学习实战(15)——差分进化详解与实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2196579.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机找不到msvcr110.dll解决方法,详细解读三种靠谱方法

1. msvcr110.dll 简介 1.1 定义 msvcr110.dll 是 Microsoft Visual C 2012 Redistributable Package 的一部分,它是一个动态链接库(Dynamic Link Library)文件,对于运行使用 Visual C 2012 编译的应用程序至关重要。这个库文件包…

刷题 图论

面试经典 150 题 - 图 200. 岛屿数量 dfs 标记 visited class Solution { public:// dfs 染色const int direction[4][2] {{-1, 0}, {0, -1}, {1, 0}, {0, 1}};void dfs(vector<vector<char>>& grid, vector<vector<bool>>& visited, int x…

哪些人群适合参加六西格玛绿带培训?

六西格玛作为一种全球公认的质量管理方法论&#xff0c;凭借其强大的数据分析和流程改进能力&#xff0c;成为众多企业转型升级的重要工具。而六西格玛绿带培训&#xff0c;作为连接黄带和黑带之间的桥梁&#xff0c;更是吸引了来自不同行业和职位的众多人士。那么&#xff0c;…

理解C语言之深入理解指针(五)

目录 1. sizeof和strlen的对⽐ 1.1 sizeo 1.2 strlen 1.3 sizeof和strlen的对⽐ 2. 数组和指针笔试题解析 2.1 ⼀维数组 2.2 字符数组 2.3 ⼆维数组 3. 指针运算笔试题解析 3.1 题⽬1&#xff1a; 3.2 题⽬2 3.3 题⽬3 3.4 题⽬4 3.5 题⽬5 3.6 题⽬6 3.7 题⽬…

鸿蒙开发之ArkUI 界面篇 二十一 人气卡片综合案例

要实现如下图效果&#xff1a; 仔细分析效果&#xff0c;整体分为三个区域&#xff0c;分别是1、2、3&#xff0c;如图所示 我们整体分析&#xff0c;区域1是观察到的是图片&#xff0c;自然是Image组件&#xff0c;区域2有个背景&#xff0c;左边是Image&#xff0c;水平方向…

《Spring Microservices in Action, 2nd Edition》读后总结

总体来说有种时过境迁的感觉&#xff0c;有些章节的内容已经跟不上现在&#xff0c;特别对于云原生大行其道的当下&#xff0c; 越来越多东西下沉到基础设施层&#xff0c;然后应用层尽量轻量化成了一种新趋势&#xff1b;当然任何事物都具有多面性&#xff0c;云原生那套也要投…

21世纪现代国学四大泰斗颜廷利教授:一位多面兼具深度的思想家

颜廷利&#xff0c;出生于1971年10月15日的这位杰出人物&#xff0c;来自中国山东省济南市的一个平凡家庭。他在北京大学接受了高等教育&#xff0c;专攻哲学和教育学&#xff0c;深入探索了东西方哲学理论。他的研究领域涵盖了哲学、文化、经济等多个领域&#xff0c;并在易经…

【element-tiptap】报错Duplicate use of selection JSON ID cell at Selection.jsonID

我是下载了element-tiptap 给出的示例项目&#xff0c;在本地安装依赖、运行报错了&#xff0c; 报错截图&#xff1a; 在项目目录下找 node_modules/tiptap-extensions/node-modules&#xff0c;把最后的 node-modules 目录名字修改一下&#xff0c;例如修改为 node-modules–…

亨廷顿舞蹈症患者必知的营养补充指南

在生活的舞台上&#xff0c;每个人都是自己故事的主角&#xff0c;即使面对如亨廷顿舞蹈症&#xff08;HD&#xff09;这样的挑战&#xff0c;我们依然可以通过科学的饮食管理&#xff0c;为健康之路增添更多希望与色彩。今天&#xff0c;就让我们一起探索亨廷顿舞蹈症患者应该…

【汇编语言】寄存器(CPU工作原理)(四)—— “段地址x16 + 偏移地址 = 物理地址”的本质含义以及段的概念和小结

文章目录 前言1. "段地址x16 偏移地址 物理地址"的本质含义2. 段的概念3. 内存单元地址小结结语 前言 &#x1f4cc; 汇编语言是很多相关课程&#xff08;如数据结构、操作系统、微机原理&#xff09;的重要基础。但仅仅从课程的角度出发就太片面了&#xff0c;其实…

单片机教案 1.1 ATmega2560单片机概述

第一章 迈进单片机的大门 Arduino是一款便捷灵活、方便上手的开源电子原型平台&#xff0c;为迈进单片机的大门提供了良好的入门途径。以下是对Arduino的详细介绍&#xff1a; 一、Arduino简介 Arduino是一个能够用来感应和控制现实物理世界的一套工具&#xff0c;它由一个基…

C++ 基于SDL库的 Visual Studio 2022 环境配置

系统&#xff1a;w10、编辑器&#xff1a;Visual Studio 2022、 下载地址 必要库&#xff1a; SDL https://github.com/libsdl-org/SDL 字体 https://github.com/libsdl-org/SDL_ttf 图片 https://github.com/libsdl-org/SDL_image 音频 https://github.com/libsdl-org/SDL_m…

连续点击三次用户

有用户点击日志记录表 t2_click_log&#xff0c;包含user_id(用户ID),click_time(点击时间)&#xff0c;请查询出连续点击三次的用户数&#xff0c; 连续点击三次&#xff1a;指点击记录中同一用户连续点击&#xff0c;中间无其他用户点击&#xff1b; CREATE TABLE t2_click…

两个div中间有缝隙

两个div中间有缝隙效果图&#xff1a; 这种是display:inline-block造成的 在父元素中加入font-size:0px;&#xff0c;再在相应的子div中加入font-size:12px;就可以了 调整后效果图&#xff1a;

Pandas和Seaborn数据可视化

Pandas数据可视化 学习目标 本章内容不需要理解和记忆&#xff0c;重在【查表】&#xff01; 知道数据可视化的重要性和必要性知道如何使用Matplotlib的常用图表API能够找到Seaborn的绘图API 1 Pandas数据可视化 一图胜千言&#xff0c;人是一个视觉敏感的动物&#xff0c;大…

数据库-分库分表

什么是分库分表 分库分表是一种数据库优化策略。 目的&#xff1a;为了解决由于单一的库表数据量过大而导致数据库性能降低的问题 分库&#xff1a;将原来独立的数据库拆分成若干数据库组成 分表&#xff1a;将原来的大表(存储近千万数据的表)拆分成若干个小表 什么时候考虑分…

Web 性能优化|了解 HTTP 协议后才能理解的预加载

作者&#xff1a;谦行 一、前言 在性能优化过程中&#xff0c;开发者通常会集中精力在以下几个方面&#xff1a;服务器响应时间&#xff08;RT&#xff09;优化、服务端渲染&#xff08;SSR&#xff09;与客户端渲染优化、以及静态资源体积的减少。然而&#xff0c;对于许多用…

C(十五)函数综合(一)--- 开公司吗?

在这篇文章中&#xff0c;杰哥将带大家 “开公司”。 主干内容部分&#xff08;你将收获&#xff09;&#xff1a;&#x1f449; 为什么要有函数&#xff1f;函数有哪些&#xff1f;怎么自定义函数以及获得函数的使用权&#xff1f;怎么对函数进行传参&#xff1f;函数中变量的…

[嵌入式Linux]—STM32MP1启动流程

STM32MP1启动流程 1.启动模式 STM32MP1等SOC支持从多种设备中启动&#xff0c;如EMMC、SD、NAND、NOR、USB、UART等。其中USB、UART是作为烧录进行启动的。 STM32MP1内部ROM中存储有一段出厂代码来进行判断从哪种设备中启动&#xff0c;上电后这段代码会被执行&#xff0c;这…

使用java函数逆序一个单链表

代码功能 定义了一个ListNode类&#xff0c;用于表示单链表的节点&#xff0c;每个节点包含一个整数值和一个指向下一个节点的引用。 在ReverseLinkedList类的main方法中&#xff0c;创建了一个包含从1到10的整数的单链表。 定义了一个printList方法&#xff0c;用于打印链表的…