深度学习笔记:dropout和调优超参数方法

news2024/10/7 6:47:44

1 Dropout

Dropout是一个常用于深度学习的减轻过拟合的方法。该方法在每一轮训练中随机删除部分隐藏层神经元。被删除的神经元不会进行正向或反向信号传递。在测试阶段所有神经元都会传递信号,但对各个神经元的输出要乘以训练时删除比例。

在这里插入图片描述
Dropout实现程序:

class Dropout:
    """
    http://arxiv.org/abs/1207.0580
    """
    def __init__(self, dropout_ratio=0.5):
        self.dropout_ratio = dropout_ratio
        self.mask = None

    def forward(self, x, train_flg=True):
        if train_flg:
            self.mask = np.random.rand(*x.shape) > self.dropout_ratio
            return x * self.mask
        else:
            return x * (1.0 - self.dropout_ratio)

    def backward(self, dout):
        return dout * self.mask

在该程序中,我们初始化dropout的比例为0.5。在正向传播中,如果train_flg为True(神经网络在训练状态),会生成一个和输入x形状相同的boolean矩阵mask,mask为False的位置(即被删除的神经元)在正向传播和反向传播的结果都为0。当train_flg为False时(神经网络在预测状态),会将正向传播的结果乘以(1.0 - self.dropout_ratio)输出

实验:利用dropout抑制过拟合

在该测试程序中,我们使用7层神经网络,每层神经元个数100,权重更新方法SGD,学习率0.01,进行300轮训练。每一轮训练样本量仅为300以增大过拟合概率。

此时我们在网络中使用dropout,并测试训练准确度和测试准确度的差距

# coding: utf-8
import os
import sys
sys.path.append("D:\AI learning source code")  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net_extend import MultiLayerNetExtend
from common.trainer import Trainer

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

# 为了再现过拟合,减少学习数据
x_train = x_train[:300]
t_train = t_train[:300]

# 设定是否使用Dropuout,以及比例 ========================
use_dropout = True  # 不使用Dropout的情况下为False
dropout_ratio = 0.2
# ====================================================

network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],
                              output_size=10, use_dropout=use_dropout, dropout_ration=dropout_ratio)
trainer = Trainer(network, x_train, t_train, x_test, t_test,
                  epochs=301, mini_batch_size=100,
                  optimizer='sgd', optimizer_param={'lr': 0.01}, verbose=True)
trainer.train()

train_acc_list, test_acc_list = trainer.train_acc_list, trainer.test_acc_list

# 绘制图形==========
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, marker='o', label='train', markevery=10)
plt.plot(x, test_acc_list, marker='s', label='test', markevery=10)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

结果:
在这里插入图片描述
可以看到dropout可以一定程度上抑制过拟合(没有抑制过拟合的对照组实验数据在之前的文章http://t.csdn.cn/b7c9S)

dropout抑制过拟合的原理:

dropout可以看做是对集成学习的模拟。集成学习即为运行多个相同或类似结构的神经网络进行训练,然后预测时求各个网络预测结果的平均值。dropout删除神经元的操作类似于在每一轮训练一个略有不同的神经网络,最后将每一个网络的结果在预测阶段叠加,实现利用一个神经网络模拟集成学习。

超参数的验证

1 对数据分类:

一般来说,我们将神经网络数据集分为训练数据,测试数据,和验证数据。训练数据用于训练网络权重和偏置,测试数据用于验证网络泛化能力,防止过拟合,而验证数据用于评估超参数。这里注意不可使用测试数据验证超参数,否则可能导致超参数对测试数据过拟合,即调试的超参数只适用于测试数据。

对于没有进行分类的数据集,如这里测试用的mnist数据集,我们可以先在训练数据中分出20%作为验证数据

(x_train, t_train), (x_test, t_test) = load_mnist()

x_train, t_train = shuffle_dataset(x_train, t_train)

validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate)

x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]

这里我们在分割数据集前先随机打乱数据集以进一步保障数据随机性,然后将数据集前20%作为验证数据,后80%作为训练数据

2 超参数最优化步骤

1 确定一个大致的超参数取值范围
2 在设置的范围内随机取得超参数值
3 利用该超参数值进行学习,使用验证数据评估学习精度(要将学习的epoch数设置很小以节省时间)
4 多次重复2,3步骤,并根据结果缩小取值范围

有研究表明在范围内随机取样比规律性搜索效果更好,因为在多个超参数中,不同超参数对识别精度影响程度不同

一般来说,找到合适的超参数数量级即可,不需要精确到具体某个值

超参数优化实验

# coding: utf-8
import sys, os
sys.path.append("D:\AI learning source code")  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.util import shuffle_dataset
from common.trainer import Trainer

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

# 为了实现高速化,减少训练数据
x_train = x_train[:500]
t_train = t_train[:500]

# 分割验证数据
validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate)
x_train, t_train = shuffle_dataset(x_train, t_train)
x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]


def __train(lr, weight_decay, epocs=50):
    network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],
                            output_size=10, weight_decay_lambda=weight_decay)
    trainer = Trainer(network, x_train, t_train, x_val, t_val,
                      epochs=epocs, mini_batch_size=100,
                      optimizer='sgd', optimizer_param={'lr': lr}, verbose=False)
    trainer.train()

    return trainer.test_acc_list, trainer.train_acc_list


# 超参数的随机搜索======================================
optimization_trial = 100
results_val = {}
results_train = {}
for _ in range(optimization_trial):
    # 指定搜索的超参数的范围===============
    weight_decay = 10 ** np.random.uniform(-8, -4)
    lr = 10 ** np.random.uniform(-6, -2)
    # ================================================

    val_acc_list, train_acc_list = __train(lr, weight_decay)
    print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))
    key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)
    results_val[key] = val_acc_list
    results_train[key] = train_acc_list

# 绘制图形========================================================
print("=========== Hyper-Parameter Optimization Result ===========")
graph_draw_num = 20
col_num = 5
row_num = int(np.ceil(graph_draw_num / col_num))
i = 0

for key, val_acc_list in sorted(results_val.items(), key=lambda x:x[1][-1], reverse=True):
    print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)

    plt.subplot(row_num, col_num, i+1)
    plt.title("Best-" + str(i+1))
    plt.ylim(0.0, 1.0)
    if i % 5: plt.yticks([])
    plt.xticks([])
    x = np.arange(len(val_acc_list))
    plt.plot(x, val_acc_list)
    plt.plot(x, results_train[key], "--")
    i += 1

    if i >= graph_draw_num:
        break

plt.show()

这一次我们利用mnist数据集测试寻找学习率和权值衰减率这两个超参数的最优值。我们对学习率初始取值10e-6值10e-3 (python程序里表示为 10 ** np.random.uniform(-6, -2)).权重衰减值范围10e-8到10e-4.

注:np.random.uniform范围左闭右开

我们使用7层神经网络,每隐藏层神经元个数100.为了加快训练速度,我们每次训练样本量500,每轮训练100次。在训练完成后,我们选出其中测试准确度最高的前20组,并绘制其训练和测试准确度变化图象。

第一轮前20组结果如下:
在这里插入图片描述

=========== Hyper-Parameter Optimization Result ===========
Best-1(val acc:0.79) | lr:0.007167252581661965, weight decay:4.46411213720861e-08
Best-2(val acc:0.76) | lr:0.009590540451785262, weight decay:3.4336503187454634e-06
Best-3(val acc:0.75) | lr:0.007574924636516419, weight decay:1.8694988419536705e-08
Best-4(val acc:0.75) | lr:0.008691535159159042, weight decay:5.358524004570447e-05
Best-5(val acc:0.74) | lr:0.008869169716698419, weight decay:5.906144381013852e-07
Best-6(val acc:0.74) | lr:0.007231956339661699, weight decay:5.1648136815500515e-08
Best-7(val acc:0.73) | lr:0.006823929935149825, weight decay:6.587792640131085e-06
Best-8(val acc:0.72) | lr:0.005506621708230906, weight decay:6.907800794800708e-06
Best-9(val acc:0.71) | lr:0.009160313797949956, weight decay:5.047627735555578e-08
Best-10(val acc:0.7) | lr:0.006506168259111944, weight decay:4.886725890956679e-05
Best-11(val acc:0.69) | lr:0.009539273908181656, weight decay:1.1043227008270985e-05
Best-12(val acc:0.69) | lr:0.007798786280457097, weight decay:2.566925672778291e-06
448464, weight decay:9.47276313549097e-06
Best-18(val acc:0.56) | lr:0.004121973859588465, weight decay:1.278735236820568e-07
Best-19(val acc:0.54) | lr:0.005094840375624678, weight decay:1.3168432394782485e-07
Best-20(val acc:0.52) | lr:0.0021804464282101947, weight decay:1.4357750811527073e-07

根据第一轮测试结果,我们将学习率范围缩小到10e-4至10e-3,将权值衰减率范围缩小到10e-8至10e-6,开始第二轮实验

第二轮实验结果
在这里插入图片描述

=========== Hyper-Parameter Optimization Result ===========
Best-1(val acc:0.8) | lr:0.007860334709347631, weight decay:3.5744103167252943e-07
Best-2(val acc:0.79) | lr:0.009101764902840044, weight decay:7.304973331423018e-08
Best-3(val acc:0.79) | lr:0.0077682308379276544, weight decay:8.08776787254994e-08
Best-4(val acc:0.78) | lr:0.007695432026250778, weight decay:2.4699443793745103e-07
Best-5(val acc:0.77) | lr:0.005844523564130418, weight decay:1.1571928419657492e-08
Best-6(val acc:0.76) | lr:0.008100878064710397, weight decay:9.921361407618477e-07
Best-7(val acc:0.74) | lr:0.007225166920829242, weight decay:1.1863229062662078e-08
Best-8(val acc:0.74) | lr:0.005706941972517327, weight decay:4.595973192460211e-06
Best-9(val acc:0.73) | lr:0.006445581641493332, weight decay:2.9452804347731977e-07
Best-10(val acc:0.73) | lr:0.00693450459896848, weight decay:2.0186451326501946e-06
Best-11(val acc:0.65) | lr:0.004925747032013345, weight decay:5.919824221197248e-07
Best-12(val acc:0.64) | lr:0.007030579937109058, weight decay:3.6013880654217633e-06
Best-13(val acc:0.64) | lr:0.005998218080578915, weight decay:2.6194581561983115e-06
Best-14(val acc:0.63) | lr:0.004333280515717759, weight decay:3.293391789330398e-06
Best-15(val acc:0.57) | lr:0.003608522576771299, weight decay:8.316798762001302e-08
Best-16(val acc:0.56) | lr:0.0036538430517432698, weight decay:8.428957748502751e-07
Best-17(val acc:0.56) | lr:0.002554863407218098, weight decay:2.398791245915158e-06
Best-18(val acc:0.54) | lr:0.003487090693530097, weight decay:3.1726798158025344e-08
Best-19(val acc:0.54) | lr:0.00494510123268657, weight decay:1.1596593557212992e-07
Best-20(val acc:0.53) | lr:0.004066867353963753, weight decay:3.0451709733114034e-08

可以看到排名靠前的模型学习率基本上都固定在10e-3这个数量级,而对于weight decay则主要集中在10e-7和10e-8中。下一轮实验中我们将学习率固定为0.007,而weight decay定在10e-7和10e-8间

第三轮实验结果:
在这里插入图片描述

=========== Hyper-Parameter Optimization Result ===========
Best-1(val acc:0.9) | lr:0.007, weight decay:2.0423654363915898e-07
Best-2(val acc:0.88) | lr:0.007, weight decay:2.0464597011239193e-08
Best-3(val acc:0.86) | lr:0.007, weight decay:2.1418827188785024e-07
Best-4(val acc:0.85) | lr:0.007, weight decay:2.568854551694536e-08
Best-5(val acc:0.85) | lr:0.007, weight decay:2.5102331031174626e-08
Best-6(val acc:0.84) | lr:0.007, weight decay:1.641132728663186e-08
Best-7(val acc:0.84) | lr:0.007, weight decay:2.2668599269481673e-08
Best-8(val acc:0.84) | lr:0.007, weight decay:1.2593373600583017e-07
Best-9(val acc:0.83) | lr:0.007, weight decay:5.965681876085778e-07
Best-10(val acc:0.83) | lr:0.007, weight decay:3.184447130202175e-08
Best-11(val acc:0.83) | lr:0.007, weight decay:6.429143459339116e-07
Best-12(val acc:0.83) | lr:0.007, weight decay:9.829436450746176e-07
Best-13(val acc:0.83) | lr:0.007, weight decay:1.475800382782156e-08
Best-14(val acc:0.83) | lr:0.007, weight decay:4.4960017860216687e-08
Best-15(val acc:0.82) | lr:0.007, weight decay:2.6357274929242283e-08
Best-16(val acc:0.82) | lr:0.007, weight decay:5.25981898629044e-08
Best-17(val acc:0.82) | lr:0.007, weight decay:3.165929380917094e-08
Best-18(val acc:0.81) | lr:0.007, weight decay:4.0322974353983454e-08
Best-19(val acc:0.81) | lr:0.007, weight decay:5.27878837562846e-07
Best-20(val acc:0.81) | lr:0.007, weight decay:8.988618606625146e-07

结论:学习率设置在10e-3这个数量级,weight decay在10e-7到10e-8直接都可以,影响不大

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/390788.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

毕业设计 基于STM32单片机无线ZIGBEE智能大棚土壤湿度光照检测

基于STM32单片机无线ZIGBEE智能大棚土壤湿度光照检测1、项目简介1.1 系统构成1.2 系统功能2、部分电路设计2.1 STM32F103C8T6核心系统电路设计2.2 光敏采集电路设计2.3 温度采集电路设计3、部分代码展示3.1 读取DS18B20温度值3.2 定时器初始化1、项目简介 选题指导&#xff0c…

Learning Typescript and React in ts

目录 Basic typescript What is typescript? Configuring the TypeScript Compiler Fundamental build in types TypeScript Simple Types TypeScript Special Types Type: unknown Type: never Type: undefined & null Arrays Tuple Enums functions Ob…

Java集合专题

文章目录框架体系CollectionListArrayListLinkedListVectorSetHashSetLinkedHashSetTreeSetMapHashMapHashtableLinkedHashMapTreeMapPropertiesCollections框架体系 1、集合主要分了两组(单列集合,双列集合) 2、Collection接口有两个重要的子…

2.SpringSecurity认证

2.1登录校验流程 2.2认证原理 *源码流程: *自定义认证流程: *校验流程: *认证和校验连接: 2.3思路分析 *登录:

SQLI-Labs通关(2)5-7关

跟之前一样首先传参,然后查看注入点以及闭合 用and 11 and 12都没问题,接下来测试单引号 利用 and 12的时候会报错 利用order by来判断列数 得出一共三列 接下来就是联合查询 但是这个并不会回显 那么就利用盲注或者报错注入 在这里我们利用报错来测…

Vue3的生命周期函数

文章目录🌟 写在前面🌟 生命周期钩子函数🌟 组合式API生命周期🌟 写在最后🌟 写在前面 专栏介绍: 凉哥作为 Vue 的忠实 粉丝输出过大量的 Vue 文章,应粉丝要求开始更新 Vue3 的相关技术文章&am…

OPPO 数据恢复:如何从 OPPO 手机恢复已删除的文件?

Oppo 手机以其精美的外观和拍摄的精美照片和视频而闻名。如果您不小心丢失了 OPPO 手机中珍贵的照片、视频等重要文件,并且为如何找回而苦恼,那么您来对地方了。我们其实有很多OPPO数据恢复方案,现在最重要的是尽快尝试这些方法,防…

Git 相关内容

目录 Git 相关流程和常用命令 Git workflow Git hooks Git 相关流程和常用命令 Git远程操作详解 - 阮一峰的网络日志 Git 使用规范流程 - 阮一峰的网络日志 常用 Git 命令清单 - 阮一峰的网络日志 Git workflow 啥玩意: 就是一个工作流程。可以比喻成一个河流…

用逻辑回归制作评分卡

目录 一.评分卡 二.导库,获取数据 三.探索数据与数据预处理 1.去除重复值 2.填补缺失值 3.描述性统计处理异常值 4.为什么不统一量纲,也不标准化数据分布 5.样本不均衡问题 6.分训练集和测试集 三.分箱 1.分多少个箱子才合适 2.分箱要达成什么…

Antlr4: 为parser rule添加label

1. parser rule中的label 1.1 简介 Antrl4语法文件Calculator.g4,stat和expr两个parser rule含有多个rule element,我们这两个parse rule的每个rule element添加了Alternative labels(简称label) 按照Antlr4的语法规则&#xff…

2022年显卡性能跑分排名表

2022年显卡性能跑分排名表(数据来源于快科技)这个版本的电脑显卡跑分榜第一的是NVIDIA GeForce RTX 3090 Ti显卡。由于显卡跑分受不同的测试环境、不同的显卡驱动版本以及不同散热设计而有所不同,所以显卡跑分会一直变化。 前二十名的台式电…

Linux(进程概念详解)

进程是如今编程领域非常重要的一个概念,进程是比较抽象的,不容易直接理解。因为进程与操作系统息息相关,因此在介绍进程之前,笔者打算先简易讲一下操作系统的工作流程,理解操作系统是如何管理软件和硬件的,…

垃圾收集器和内存分配(第五章)

《实战Java虚拟机:JVM故障诊断与性能优化 (第2版)》 Java 平台,标准版热点虚拟机垃圾回收调优指南 垃圾收集器虽然看起来数量比较多,但其实总体逻辑都是因为硬件环境的升级而演化出来的产品,不同垃圾收集器的产生总体可以划分为几…

智能优化算法应用:基于蚁狮优化算法的工程优化案例-附代码

智能优化算法应用:基于蚁狮算法的工程优化案例 文章目录智能优化算法应用:基于蚁狮算法的工程优化案例1.蚁狮算法2.压力容器设计问题3.三杆桁架设计问题4.拉压弹簧设计问题5.Matlab代码6.python代码摘要:本文介绍利用蚁狮搜索算法&#xff0c…

191、【动态规划】AcWing —— 900. 整数划分:完全背包解法+加减1解法(C++版本)

题目描述 参考文章:900. 整数划分 解题思路 因为本题中规定了数字从大到小,其实也就是不论是1 2 1 4,还是2 1 1 4,都会被看作是2 1 1 4这一种情况,因此本题是在遍历中不考虑结果顺序。 背包问题中只需考虑…

AcWing:并查集

并查集理论基础并查集的作用是什么:将两个集合合并。询问两个元素是否在一个集合当中。如果不使用并查集,要完成上述两个操作,我们需要:创建一个数组来表示某个元素在某个集合之中,如belong[x] a,即x元素在…

0201基础-组件-React

1 组件和模块 1.1 模块 对外提供特定功能的js程序,一般就是一个js文件 为什么拆分模块呢?随着业务逻辑增加,代码越来越多,越来越复杂。作用:复用js,简化js,提高js运行效率 1.2 模块化 当应用…

用gdb.attach()在gdb下断点但没停下的情况及解决办法

在python中,如果导入了pwntools,就可以使用里面的gdb.attach(io)的命令来下断点。 但是这一次鼠鼠遇到了一个情况就是下了断点,但是仍然无法在断点处开始运行,奇奇怪怪。 这是我的攻击脚本 我们运行一下。 可以看到其实已经运行起…

计算机网络模型、协议

ARP(IP->MAC)RARP(MAC->IP)TFTPHTTPDHCPNATARP(IP->MAC) 主机建立自己的ARP缓冲区存ARP列表 广播ARP请求,单播ARP响应 RARP(MAC->IP) 用于无盘工作站&am…

Java分布式全局ID(一)

随着互联网的不断发展,互联网企业的业务在飞速变化,推动着系统架构也在不断地发生变化。 如今微服务技术越来越成熟,很多企业都采用微服务架构来支撑内部及对外的业务,尤其是在高 并发大流量的电商业务场景下,微服务…