贝叶斯神经网络用于学习曲线的概率预测【ICLR 2017】

news2024/11/20 10:36:36

论文下载地址:Excellent-Paper-For-Daily-Reading/hyper-parameters at main

类别:超参数

时间:2023/10/30

摘要

面对不同的神经网络结构、超参数和训练协议,通常需要检查生成学习曲线,以快速终止超参数设置不佳的运行,从而大大加快手动超参数优化。通过跨超参数设置的学习曲线的概率模型,可以在自动超参数优化中利用相同的信息。论文研究了贝叶斯神经网络的使用,并通过一个专门的学习曲线层来提高它们的性能。

论文完成的成果

  • 研究贝叶斯神经网络如何很好地适应各种架构和超参数设置的学习曲线,以及它们的不确定性估计有多可靠。
  • 开发了一个带有学习曲线层的专用神经网络架构,以改进学习曲线预测。
  • 比较了生成贝叶斯神经网络的不同方法:概率反向传播和两种不同的基于随机梯度的马尔可夫链蒙特卡罗(MCMC)方法。
  • 评估了全新学习曲线和外推部分观察曲线的预测质量,在学习曲线尚未收敛的阶段。
  • 扩展了多臂强盗策略(multi-armed bandit strategy),使用我们的模型进行采样,而不是均匀随机采样,从而使其能够比传统的贝叶斯优化更快地接近最优配置。

实验

学习曲线预测的实验

在实验部分,采用了不同的神经网络架构和学习曲线预测方法,并在不同数据集上进行了评估。实验结果表明,新模型的性能表现出良好的均方误差和平均对数似然,特别是使用随机梯度汉密尔顿MCMC方法(SGHMC)时表现更佳。此外,文章还比较了其他用于学习曲线预测的方法,包括随机森林、高斯过程、概率反向传播和简单的“最后一个观察到的值”方法。

左图显示了不同方法在CNN基准上的平均预测。所有模型都观察到真实学习曲线(黑色)的前12个epoch的验证误差。右图,绘制了40个epoch值的后验分布。

结论

论文研究了一种基于贝叶斯神经网络的学习曲线建模方法,为解决超参数优化和性能改进问题提供了新的思路和工具。贝叶斯神经网络的引入以及新型学习曲线层的设计为未来的研究和实践提供了有趣的方向。

这篇论文为深度学习领域的研究者提供了一个全新的视角,强调了贝叶斯神经网络在学习曲线预测和超参数优化中的重要性。通过结合不同领域的知识,我们有望进一步提高机器学习算法的性能。

学习率范围测试

学习率范围测试,又被称为LR Finder,是机器学习领域的一个重要实践工具。在深度学习模型训练中,学习率的选择通常是一个挑战,因为一个合适的学习率可以加速收敛并提高性能,但不合适的学习率可能导致训练不稳定或收敛缓慢。传统上,学习率的设定是基于经验和试错的,这篇论文介绍了一种更科学、更系统的方法,即学习率范围测试。

学习率范围测试的主要思想是在训练过程中逐渐增加学习率,然后观察模型的损失如何随学习率的增加而变化。通过分析损失与学习率之间的关系,可以找到一个合适的学习率范围,其中学习率既不会过高导致模型发散,也不会过低导致训练速度过慢。这种方法有助于为模型选择一个更有科学依据的初始学习率。

下面我根据模型、dataloader、损失函数和学习率进行调整:


import torch.nn as nn
import torch

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from utils.util import get_network
from utils.datasets import get_train_loader
from utils.lr_scheduler import _LRScheduler
from pyzjr.dlearn.learnrate import get_optimizer

class FindLR(_LRScheduler):
    """
    exponentially increasing learning rate

    Args:
        optimizer: optimzier(e.g. SGD)
        num_iter: totoal_iters
        max_lr: maximum  learning rate
    """
    def __init__(self, optimizer, max_lr=10, num_iter=100, last_epoch=-1):

        self.total_iters = num_iter
        self.max_lr = max_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * (self.max_lr / base_lr) ** (self.last_epoch / (self.total_iters + 1e-32)) for base_lr in self.base_lrs]

# class lr_finder():
#     def __init__(self,net, training_loader, loss_function,optimizer_type="sgd",num_iter=100,batch_size=4):
#         self.net = net
#         self.training_loader = training_loader
#         self.loss_function = loss_function
#         self.optimizer_type = optimizer_type
#         self.num_iter = num_iter
#         self.batch_size = batch_size
# 
#     def update(self, init_lr=1e-7, max_lr=10):
#         n = 0
#         learning_rate = []
#         losses = []
#         optimizer = get_optimizer(self.net, self.optimizer_type, init_lr)
#         lr_scheduler = FindLR(optimizer, max_lr=max_lr, num_iter=self.num_iter)
#         epoches = int(args.num_iter / len(self.training_loader)) + 1
# 
#         for epoch in range(epoches):
#             net.train()
#             for batch_index, (images, labels) in enumerate(self.training_loader):
#                 if n > self.num_iter:
#                     break
#                 if torch.cuda.is_available():
#                     images = images.cuda()
#                     labels = labels.cuda()
# 
#                 optimizer.zero_grad()
#                 predicts = net(images)
#                 loss = loss_function(predicts, labels)
#                 if torch.isnan(loss).any():
#                     n += 1e8
#                     break
#                 loss.backward()
#                 optimizer.step()
#                 lr_scheduler.step()
# 
#                 print('Iterations: {iter_num} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.8f}'.format(
#                     loss.item(),
#                     optimizer.param_groups[0]['lr'],
#                     iter_num=n,
#                     trained_samples=batch_index * self.batch_size + len(images),
#                     total_samples=len(self.training_loader),
#                 ))
# 
#                 learning_rate.append(optimizer.param_groups[0]['lr'])
#                 losses.append(loss.item())
#                 n += 1
# 
#         self.learning_rate = learning_rate[10:-5]
#         self.losses = losses[10:-5]
# 
#     def plotshow(self, show=True):
#         import matplotlib
#         matplotlib.use("TkAgg")
#         fig, ax = plt.subplots(1, 1)
#         ax.plot(self.learning_rate, self.losses)
#         ax.set_xlabel('learning rate')
#         ax.set_ylabel('losses')
#         ax.set_xscale('log')
#         ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
#         if show:
#             plt.show()
# 
#     def save(self, path='result.jpg'):
#         self.plotshow(show=False)
#         plt.savefig(path)




if __name__ == '__main__':
    class parser_args():
        def __init__(self):
            self.net = "vgg16"
            self.batch_size = 64
            self.base_lr = 1e-7
            self.max_lr = 10
            self.num_iter = 100
            self.Cuda = True
            self.num_class = 4
            
    from pyzjr.dlearn.learnrate import lr_finder
    
    args = parser_args()

    txt_path = r"D:\PythonProject\Torchproject\classification\dataset\train.txt"
    train_loader = get_train_loader(txt_path, batch_size=4, train=True)

    net = get_network(args)

    loss_function = nn.CrossEntropyLoss()
    lrfinder = lr_finder(net, train_loader, loss_function)
    lrfinder.update()
    lrfinder.plotshow()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1154858.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《数据安全与流通:技术、架构与实践》新书发布

随着数据成为关键生产资料和要素,国内外数据安全相关的法律法规在快速完善,数据安全技术也在快速发展。5月25-26日,由星环科技、上海数据交易所、上海大数据联盟、财联社联合主办的向星力未来数据技术峰会 (FDTC)上&am…

Data Stream 复习(考试向)

Data Stream Review OverallUniform SamplingBloom FilterMisra-Gries AlgorithmCountMin Sketch AlgorithmCount Sketch Algorithm Overall Uniform Sampling Bloom Filter 一个箱子没有球的概率可以表示为 (1 - 1/n)^m 的原因是基于以下逻辑: 对于第一个球&#x…

Vue+OpenLayers6入门到实战进阶案例汇总目录,兼容OpenLayers7和OpenLayers8

本篇作为《VueOpenLayers入门教程》和《VueOpenLayers实战进阶案例》所有文章的二合一汇总目录,方便查找。 本专栏源码是由OpenLayers结合Vue框架编写。 本专栏从Vue搭建脚手架到如何引入OpenLayers依赖的每一步详细新手教程,再到通过各种入门案例和综合…

[双指针] (三) LeetCode LCR 179. 查找总价格为目标值的两个商品 和 15. 三数之和

[双指针] (三) LeetCode LCR 179. 查找总价格为目标值的两个商品 和 15. 三数之和 文章目录 [双指针] (三) LeetCode LCR 179. 查找总价格为目标值的两个商品 和 15. 三数之和查找总价格为目标值的两个商品题目分析解题思路代码实现总结 三数之和题目分析解题思路代码实现总结 …

Web APIs——其他事件

1、页面加载事件 加载外部资源(如图片、外联CSS和JavaScript等)加载完毕时触发的事件 为什么要学? 有些时候需要等页面资源全部处理完了做一些事情老代码喜欢把script写head中,这时候直接找dom元素找不到 事件名:load …

内存DMA及设备内存控制详解

序言 对于PCIe 设备(PCIe Endpoint)来说,其和CPU CORE、DRAM 的交互,主要涉及两种类型的内存访问: 设备内存访问:PCIe 设备的 Device Memory(设备内存)的访问,例如CPU …

java.net.URISyntaxException: Illegal character in query at index

现象 现象调用httpGet请求时,报错,如下: 原因 因为调用的url里有特殊字符 如单引号,双引号,等号,& | 等 解决方案 使用url带参构造方法,会对url里面的特殊字符进行转义处理 URL url n…

Python-常用的量化交易代码片段

算法交易正在彻底改变金融世界。通过基于预定义标准的自动化交易,交易者可以以闪电般的速度和比以往更精确的方式执行订单。如果您热衷于深入了解算法交易的世界,本指南提供了帮助您入门的基本代码片段。从获取股票数据到回溯测试策略,我们都能满足您的需求! 1. 使用 YFina…

k8s从私有仓库拉取镜像

从私有仓库拉取镜像 | Kubernetes 准备开始 你必须拥有一个 Kubernetes 的集群,同时你必须配置 kubectl 命令行工具与你的集群通信。 建议在至少有两个不作为控制平面主机的节点的集群上运行本教程。可以通过 Minikube 构建一个你自己的集群,或者你可以…

网管的利器之NMap

在进行网络管理过程中,网管会借助很多的工具比如付费的一些产品,比如漏洞扫描、安全隐患发现、网络设备管理、上网行为管理等。 更多的情况下,网管员使用一些DOS命令或者免费的工具进行,比如前面介绍过的PingInfoView.exe、WinMTR…

机器学习(六)构建机器学习模型

1.9构建机器学习模型 我们使用机器学习预测模型的工作流程讲解机器学习系统整套处理过程。 整个过程包括了数据预处理、模型学习、模型验证及模型预测。其中数据预处理包含了对数据的基本处理,包括特征抽取及缩放、特征选择、特征降维和特征抽样;我们将…

lambda表达式 - c++11

文章目录: lambda表达式概念lambda表达式语法函数对象与lambda表达式 lambda表达式概念 lambda 表达式是 c11 中引入的一种匿名函数,它可以在需要函数对象的地方使用,可以用作函数参数或返回值。lambda 表达式可以看作是一种局部定义的函数对…

mysql之用户管理、权限管理、密码管理

用户管理 创建用户create user 杨20.0.0.13 identified by 123; 用户重命名rename user 杨20.0.0.13 to yang20.0.0.13; 删除用户drop user 杨20.0.0.13; 权限管理 查看用户权限show grants for 杨20.0.0.13; 赋予用户权限grant all privileges on *.* to 杨localhost id…

文章导读助你高效成长

文章目录 Java基础篇MySQL数据库篇Redis缓存篇 📕我是廖志伟,一名Java开发工程师、Java领域优质创作者、CSDN博客专家、51CTO专家博主、阿里云专家博主、清华大学出版社签约作者、产品软文创造者、技术文章评审老师、问卷调查设计师、个人社区创始人、开…

超低直流电阻测试仪

KDZD5510半导体体积电阻率测试仪是一款针对超低直流电阻测试专门设计开发的一款高精度测试仪,界面清爽、操作便捷;量程范围为:0.01uΩ~10MΩ;显示位数为五位半;自动双向电流测试, 同时脉冲式的测试方式避免…

医院室内地图导航技术分析与作用

随着科技的不断发展,医疗行业的服务水平也在逐步提高。为了方便患者和医务人员,医院室内地图导航技术应运而生。这种技术运用了多种元素,包括模型地图、室内3D电子地图、路线指引、对接医院系统、位置分享和寻车导航等,为医院提供…

Three.js 开发引擎的特点

Three.js 是一个流行的开源 3D 游戏和图形引擎,用于在 Web 浏览器中创建高质量的三维图形和互动内容。以下是 Three.js 的主要特点和适用场合,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流合作…

Python3,区区5行代码,制作期待的图表,这技能值得拥有(二)。

1、引言 小屌丝:鱼哥,这次按脚还不错? 小鱼:你说呢~ 小屌丝:那seabornde还记得? 小鱼:昂, 有印象 小屌丝:那咱开始整? 小鱼:这个… 行吧 小屌丝&…

ctfshow-web入门37-52

include($c);表达式包含并运行指定文件。 使用data伪协议 ?cdata://text/plain;base64,PD9waHAgc3lzdGVtKCdjYXQgZmxhZy5waHAnKTs/Pg PD9waHAgc3lzdGVtKCdjYXQgZmxhZy5waHAnKTs/Pg 是<?php system(cat flag.php);?> base64加密 源代码查看得到flag 38 多禁用了ph…

订水商城实战教程-06店铺信息

目录 1 创建数据源2 生成管理后台3 创建腾讯地图API4 配置小程序5 地址组件配置地图API6 显示店铺名称总结 上一篇我们介绍了权限控制&#xff0c;本篇我们就开始首页开发了。首页先需要显示店铺的名称&#xff0c;我们需要将店铺的信息存入数据源中。 1 创建数据源 打开控制台…