【深度学习】5-5 与学习相关的技巧 - 超参数的验证

news2024/9/25 9:38:14

超参数指的是,比如各层的神经元数量、batch大小、参数更新时的学习率或权值衰减等。如果这些超参数没有设置合适的值,模型的性能就会很差。
那么如何能够高效地寻找超参数的值的方法

验证数据
之前我们使用的数据集分成了训练数据和测试数据,训练数据用于学习测试数据用于评估泛化能力。
下面要对超参数设置各种各样的值以进行验证。这里要注意的是不能使用测试数据评估超参数的性能。这一点非常重要,但也容易被忽视。为什么不能使用测试数据评估超参数的性能,因为如果使用测试数需调整超参数,超参数的值会对测试数据发生过拟合

因此,调整超参数时,必须使用超参数专用的确认数据。用于调整超参也的数据,一般称为验证数据。一般使用这个验证数据评估超参数的好坏。

根据不同的数据集,有的会事先分成训练数据、验证数据、测试数据三部分,有的只分成训练数据和测试数据两部分,有的则不进行分割。如果是MNIST数据集,获得验证数据最简单的方法就是从训练数据中事先分割20%作为验证数据

(x_train, t_train),(x_test, t_test) = load_mnist()

# 打乱训练数据
x_train, t_train = shuffle_dataset(x_train, t_train)

# 分割验证数据
validation_rate = 0.28
validation_num = int(x_train.shape[0]*validation_rate)

x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]

这里,分割训练数据前,先打乱了输人数据和监督标签。这是因为数据集的数据可能存在偏向(洗牌)

超参数的最优化
进行超参数的最优化时,逐渐缩小超参数的“好值”的存在范围非常重要。所谓逐渐缩小范围,是指一开始先大致设定一个范围,从这个范围中随机出一个超参数(采样),用这个采样到的值进行识别精度的评估;然后,多重复该操作,观察识别精度的结果,根据这个结果缩小超参数的“好值”的范围通过重复这一操作,就可以逐渐确定超参数的合适范围。
(“好值”)

有报告叫显示,在进行神经网络的超参数的最优化时,与网格搜索等有规律的搜索相比,随机采样的搜索方式效果更好。这是因为在多个超参数中,各个超参数对最终的识别精度的影响程度不同。

超参数的范围只要“大致地指定”就可以了。所谓“大致地指定”,是指像0.001到1000这样,以“10的阶乘”的尺度指定范围

在超参数的最优化中,要注意的是深度学习需要很长时间。因此,在超参数的搜索中,需要尽早放弃那些不符合逻辑的超参数。于是,在超参数的最优化中,减少学习的epoch,缩短一次评估所需的时间是一个不错的办法。

下面简单归纳下:

  1. 设定超参数的范围
  2. 从设定的超参数范围中随机采样
  3. 使用步骤1中采样到的超参数的值进行学习,通过验证数据评估识别精度(但是要将epoch设置得很小)
  4. 重复步骤1和步骤2,根据它们的识别精度的结果,缩小超参数的范围

反复进行上述操作,不断缩小超参数的范围,在缩小到一定程度时,从施围中选出一个超参数的值。这就是进行超参数的最优化的一种方法

在参数的最优化中,如果需要更精炼的方法,可以使用贝叶斯最优化。贝叶斯最优化运用以贝叶斯定理为心的数学理论,能够更加严密、高效地进行最优化。

超参数最优化的实现
现在,我们使用MNIST数据集进行超参数的最优化。这里我们将学习率和控制权值衰减强度的系数(下文称为“权值衰减系数”)这两个超参数的搜索问题作为对象。

在该实验中,权值衰减系数的初始范围为10的负8次方到10的负4次方学习率的初始范围为10的负6次方到10的负2次方。此时,超参数的随机采样的代码如下所示:

weight_decay = 10 ** np.random.uniform(-8,-4)
lr = 10 ** np.random.uniform(-6,-2)

超参数最优化的源代码如下:

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.util import shuffle_dataset
from common.trainer import Trainer

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

# 为了实现高速化,减少训练数据
x_train = x_train[:500]
t_train = t_train[:500]

# 分割验证数据
validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate)
x_train, t_train = shuffle_dataset(x_train, t_train)
x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]


def __train(lr, weight_decay, epocs=50):
    network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],
                            output_size=10, weight_decay_lambda=weight_decay)
    trainer = Trainer(network, x_train, t_train, x_val, t_val,
                      epochs=epocs, mini_batch_size=100,
                      optimizer='sgd', optimizer_param={'lr': lr}, verbose=False)
    trainer.train()

    return trainer.test_acc_list, trainer.train_acc_list


# 超参数的随机搜索======================================
optimization_trial = 100
results_val = {}
results_train = {}
for _ in range(optimization_trial):
    # 指定搜索的超参数的范围===============
    weight_decay = 10 ** np.random.uniform(-8, -4)
    lr = 10 ** np.random.uniform(-6, -2)
    # ================================================

    val_acc_list, train_acc_list = __train(lr, weight_decay)
    print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))
    key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)
    results_val[key] = val_acc_list
    results_train[key] = train_acc_list

# 绘制图形========================================================
print("=========== Hyper-Parameter Optimization Result ===========")
graph_draw_num = 20
col_num = 5
row_num = int(np.ceil(graph_draw_num / col_num))
i = 0

for key, val_acc_list in sorted(results_val.items(), key=lambda x:x[1][-1], reverse=True):
    print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)

    plt.subplot(row_num, col_num, i+1)
    plt.title("Best-" + str(i+1))
    plt.ylim(0.0, 1.0)
    if i % 5: plt.yticks([])
    plt.xticks([])
    x = np.arange(len(val_acc_list))
    plt.plot(x, val_acc_list)
    plt.plot(x, results_train[key], "--")
    i += 1

    if i >= graph_draw_num:
        break

plt.show()

运行结果如下:

在这里插入图片描述

按识别精度从高到低的顺序排列了验证数据的学习的变化从图中可知,直到“Best-5”左右,学习进行得都很顺利。

“Best-5”的超参数的值如下:
Best-5 (val acc:0.73) | lr:0.0052, weight decay:8.97e-06

从这个结果可以看出,学习率在0.001到0.01、权值衰减系数在10的负8次方到10的负4次方之间时,学习可以顺利进行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/672187.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WorkPlus AI助理正式上线!为企业打造定制化的AI私有助理

毋庸置疑,ChatGPT的应用充满无限的想象空间。但对于企业来说,使用时面临的最核心的问题就是“存在回答准确性不足”的弊端。那企业都想要通过GPT构建内容生态,在数字化时代保持行业领先地位。 企业都想要结合行业属性、业务需求等自身特点打…

【Flutter】Flutter 数据存储 Hive 的简要使用说明

文章目录 一、前言二、Hive 包的版本号三、Hive 简介1. Hive 是什么?2. Hive 的特点 四、Hive 的基本使用1. Hive 的安装2. Hive 的初始化3. 创建和打开 Hive 数据库4. 数据的存储和读取5. 数据的删除 五、总结 一、前言 🎉想要精通 Flutter&#xff0c…

是时候扔掉cmder, 换上Windows Terminal

作为一个Windows的长期用户,一直没有给款好用的终端,知道遇到了 cmder,它拯救一个习惯用Windows敲shell命令的人。 不用跟我安利macOS真香!公司上班一直用macOS,一方面确实更加习惯windows下面学习, 另一方面是上课需要…

Phantomjs实现后端将URL转换为图片

PhantomJS简介 PhantomJS is a command-line tool. – 其实就是一个命令行工具 PhantomJS的下载地址: Windows:phantomjs-2.1.1-windows.zip Linux:phantomjs-2.1.1-linux-x86_64.tar.bz2;phantomjs-2.1.1-linux-i686.tar.bz2 MacOS:phantomjs-2.1.1-macosx.zip…

西门子Mendix 入门 2

今天还是一直下载失败,就算成功了,速度也只有几K,于是使用翻墙软件,最终下载成功 下载成功后重新点击edit in studio pro 出现如下页面 首先先关闭安全性 进行添加任务和管理任务 点击上方绿色箭头后点击View App 出现如下页面…

ESP32-WROOM-32 UDP单播透传AT指令例程

ESP32-WROOM-32 AT指令配置TCP通讯 ESP32-WROOM-32前言固件烧录测试AT指令UDP单播通讯\透传ESP32配置SoftAPESP32与手机间的UDP通讯与透传普通传输模式演示UDP透传演示 ESP32-WROOM-32 前言 上次演示了ESP32与手机的三种TCP连接与数据传输方法,现在接着上一篇“ESP…

第二章 数据结构(一)——链表,栈和队列与kmp

文章目录 链表栈和队列表达式运算 单调栈单调队列kmp链表练习题826. 单链表827. 双链表 栈和队列练习题828. 模拟栈3302. 表达式求值829. 模拟队列830. 单调栈154. 滑动窗口 kmp练习题831. KMP字符串 kmp虐我一下午 链表 若用链式结构实现链表,效率低,因…

软件开发流程

目录 软件软件开发流程的演变 瀑布模型敏捷模型 XPSCRUMDevOps 1.软件 与计算机系统操作有关的计算机程序、可能有的文件、文档及数据。 软件可以分为两种主要类型: 独立软件:独立软件是一种完整的应用程序,可以直接在计算机或移动设备上…

Android系统安全 — 6.2 Ethernet安卓架构

1. Android Ethernet架构介绍 整个Ethernet系统架构如下图所示: 以太网服务(EthernetService)的启动与注册流程;应用层调用使能ethernet功能的方法流程来分析,从应用层如何将指令一步一步传到底层kernel;…

SAAS-HRM系统概述与搭建环境

SAAS-HRM系统概述与搭建环境 学习目标: 理解SaaS的基本概念 了解SAAS-HRM的基本需求和开发方式掌握Power Designer的用例图 完成SAAS-HRM父模块及公共模块的环境搭建完成企业微服务中企业CRUD功能 初识SaaS 云服务的三种模式 IaaS(基础设施即服务…

使用Windows To Go工具制作你的U盘系统【含下载Windows10系统镜像】亲测已成功23.06.21

WinToGo是一款辅助工具:专为能够让你将系统装进U盘,移动硬盘里,让你在任意电脑都能运行U盘里装的系统! 一、下载,安装“Windows To Go”工具 1、下载Windows To Go工具 口袋系统WinToGo: 安装Win 10到U盘 2、双击Wi…

从0到1精通自动化测试,pytest自动化测试框架,assert断言(七)

目录 一、前言 二、assert 三、异常信息 四、异常断言 五、常用断言 一、前言 断言是写自动化测试基本最重要的一步,一个用例没有断言,就失去了自动化测试的意义了。什么是断言呢? 简单来讲就是实际结果和期望结果去对比,符…

三分钟学习一个python小知识2-----------我的对python的类(Class)和对象(Object)的理解

文章目录 一、类(Class)和对象(Object)是什么?二、Python类和对象的实现1.定义类2.创建对象3.调用类的属性和方法 三、利用python实现了一个动物的类(Animal)和其两个子类(Cat和Dog&…

年轻人存款难,要攒够多少存款才可以体面的养老,结论亮了

这个情况确实值得我们思考。年轻人的经济压力比较大,所以他们普遍存款比较少。而10万元确实是一个比较大的数目,对于一些年轻人来说可能确实很难达到。 然而,我认为这并不是一个“坎”。我们应该鼓励年轻人理财,增加存款,以便应对未来可能出现的各种经济问题。同时,我们…

定义一个一维数组存放10个整数,要求从键盘输入10个数,对其进行求和、求平均、求最大值/最小值及其位置的下标

目录 题目 分析思路 法一:在主函数直接编程 法二:用 调用函数 实现 代码 法一:在主函数直接编程 法二:用 调用函数 实现 题目 定义一个一维数组存放10个整数,要求从键盘输入10个数,对其进行求和、求…

新华三H3C无线控制器AC对接网络准入实现定制化Portal短信认证

随着企业办公信息化的不断发展,企业内网安全也面临着诸多挑战。在包含了无线 WiFi、有线网络的混合网络环境中,员工或访客、外包人员、合作伙伴等用户在接入网络时,如果无需进行身份验证及访问权限的管理,则很可能给不法分子可乘之…

一起Talk Android吧(第五百四十八回:如何创建垂直版SeekBar)

文章目录 概念介绍创建方法示例程序 各位看官们大家好,上一回中咱们说的例子是"蓝牙广播中的厂商数据",本章回中介绍的例子是" 如何创建垂直版SeekBar"。闲话休提,言归正转,让我们一起Talk Android吧! 概念介…

基于深度学习的高精度绵羊检测识别系统(PyTorch+Pyside6+YOLOv5模型)

摘要:基于深度学习的高精度绵羊检测识别系统可用于日常生活中或野外来检测与定位绵羊目标,利用深度学习算法可实现图片、视频、摄像头等方式的绵羊目标检测识别,另外支持结果可视化与图片或视频检测结果的导出。本系统采用YOLOv5目标检测模型…

Java基础知识之异常处理

目录 1.Java 异常处理 2.Exception 类的层次 3.Java 内置异常类 4.异常方法 5.捕获异常 6.多重捕获块 7.throws/throw 关键字 7.1 throw 关键字 7.2 throws 关键字 8.finally关键字 8.1 实例--ExcepTest.java 文件代码: 9.try-with-resources 9.1 try-…

外设驱动库开发笔记54:外设库驱动设计改进的思考

不知不觉中我们已经发布了五十多篇外设驱动的文章。前段时间有一位网友提出了一些非常中肯的建议,这也让我们开始考虑怎么优化驱动程序设计的问题。在这一篇中我们将来讨论这一问题。 1、问题分析 首先我们来分析一下网友提出的几点问题。第一点是说在驱动设计时&a…