pyro ExponentialLR 如何设置优化器 optimizer的学习率 pytorch 深度神经网络 bnn,

news2024/9/22 5:25:19

 第一。pyro 不支持 “ReduceLROnPlateau” ,因为需要Loss作为输入数值,计算量大

第二 ,svi 支持 scheduler注意点,

属于  pyro.optim.PyroOptim的有三个
AdagradRMSProp ClippedAdam DCTAdam,但是还是会报错,类似下面的错误

optimizer = pyro.optim.SGD
# 指数下降学习率
pyro_scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': learn_rate}, 'gamma': 0.1})
Traceback (most recent call last):
  File "/home/aistudio/bnn_pyro_fso_middle_2_16__256.py", line 441, in <module>
    svi = SVI(model, mean_field_guide, optimizer, loss=Trace_ELBO())
  File "/home/aistudio/external-libraries/pyro/infer/svi.py", line 72, in __init__
    raise ValueError(
ValueError: Optimizer should be an instance of pyro.optim.PyroOptim class.

正确的方法是 

optimizer = torch.optim.SGD
# 指数下降学习率
pyro_scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': learn_rate}, 'gamma': 0.1})


# 设置ReduceLROnPlateau调度器
# 在这个例子中,`'min'`模式表示当验证集上的损失停止下降时,学习率会降低。`factor=0.1`表示学习率会乘以0.1,`
# patience=10`表示如果验证集损失在连续10个epochs内没有改善,则降低学习率。
# scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=20, verbose=True)

# svi = SVI(model, mean_field_guide, optimizer, loss=Trace_ELBO())
svi = SVI(model, mean_field_guide, pyro_scheduler, loss=Trace_ELBO())

ExponentialLR

ExponentialLR是指数型下降的学习率调节器,每一轮会将学习率乘以gamma,所以这里千万注意gamma不要设置的太小,不然几轮之后学习率就会降到0。

lr_scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# 按指数衰减调整学习率,调整公式:lr = lr*gamma**epoch

这是一个用于动态生成参数的调整学习率的包装器,用于 `torch.optim.lr_scheduler` 对象。

    :param scheduler_constructor: 一个 `torch.optim.lr_scheduler` 的类
    :param optim_args: 一个包含优化器学习参数的字典,或者是一个返回此类字典的可调用对象。必须包含键 'optimizer',其值为 PyTorch 优化器的值
    :param clip_args: 一个包含 `clip_norm` 和/或 `clip_value` 参数的字典,或者是一个返回此类字典的可调用对象。

    例子::

        optimizer = torch.optim.SGD
        scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
        svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
        for i in range(epochs):
            for minibatch in DataLoader(dataset, batch_size):
                svi.step(minibatch)
            scheduler.step()
 

"""
A wrapper for :class:`~torch.optim.lr_scheduler` objects that adjusts learning rates
for dynamically generated parameters.

:param scheduler_constructor: a :class:`~torch.optim.lr_scheduler`
:param optim_args: a dictionary of learning arguments for the optimizer or a callable that returns
    such dictionaries. must contain the key 'optimizer' with pytorch optimizer value
:param clip_args: a dictionary of clip_norm and/or clip_value args or a callable that returns
    such dictionaries.

Example::

    optimizer = torch.optim.SGD
    scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for i in range(epochs):
        for minibatch in DataLoader(dataset, batch_size):
            svi.step(minibatch)
        scheduler.step()

pyro.optim.lr_scheduler的源代码

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Iterable, List, Optional, Union, ValuesView

from torch import Tensor

from pyro.optim.optim import PyroOptim


class PyroLRScheduler(PyroOptim):
    """
    A wrapper for :class:`~torch.optim.lr_scheduler` objects that adjusts learning rates
    for dynamically generated parameters.

    :param scheduler_constructor: a :class:`~torch.optim.lr_scheduler`
    :param optim_args: a dictionary of learning arguments for the optimizer or a callable that returns
        such dictionaries. must contain the key 'optimizer' with pytorch optimizer value
    :param clip_args: a dictionary of clip_norm and/or clip_value args or a callable that returns
        such dictionaries.

    Example::

        optimizer = torch.optim.SGD
        scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
        svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
        for i in range(epochs):
            for minibatch in DataLoader(dataset, batch_size):
                svi.step(minibatch)
            scheduler.step()
    """

    def __init__(
        self,
        scheduler_constructor,
        optim_args: Union[Dict],
        clip_args: Optional[Union[Dict]] = None,
    ):
        # pytorch scheduler
        self.pt_scheduler_constructor = scheduler_constructor
        # torch optimizer
        pt_optim_constructor = optim_args.pop("optimizer")
        # kwargs for the torch optimizer
        optim_kwargs = optim_args.pop("optim_args")
        self.kwargs = optim_args
        super().__init__(pt_optim_constructor, optim_kwargs, clip_args)

    def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None:
        super().__call__(params, *args, **kwargs)

    def _get_optim(
        self, params: Union[Tensor, Iterable[Tensor], Iterable[Dict[Any, Any]]]
    ):
        optim = super()._get_optim(params)
        return self.pt_scheduler_constructor(optim, **self.kwargs)

    def step(self, *args, **kwargs) -> None:
        """
        Takes the same arguments as the PyTorch scheduler
        (e.g. optional ``loss`` for ``ReduceLROnPlateau``)
        """
        for scheduler in self.optim_objs.values():
            scheduler.step(*args, **kwargs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2100371.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

号外!软考刷题小工具助力软考和 PMP 等级考试

一. 背景 四年前&#xff0c;我通过培训机构学习了 PMP&#xff0c;系统的学习了项目管理知识体系&#xff0c;说实话&#xff0c;学完感觉确实是有用的&#xff0c;尤其在项目管理方面&#xff0c;一些管理思维确实能够帮助到自己。 如果说 PMP 是国外的项目管理知识体系认证…

数据安全新纪元:Ftrans跨网跨域数据安全交换创新方案

随着业务的不断扩张和发展&#xff0c;大型组织企业&#xff0c;需要在不同的地理区域建设分支机构或办事处&#xff0c;用以覆盖更广泛的市场和客户群体&#xff0c;因此必然存在跨网跨域数据安全交换的场景需求。企业内部会同时存在下述一个或多个跨域文件交换场景&#xff1…

淘宝/天猫的拍立淘API:taobao.item_search_img返回值

淘宝/天猫的拍立淘&#xff08;Taobao Image Search&#xff09;功能允许用户通过上传图片来搜索相似的商品。然而&#xff0c;直接通过API使用这个功能&#xff08;如taobao.item_search_img&#xff09;在淘宝/天猫的官方API中并不直接提供。淘宝/天猫的开放平台&#xff08;…

k8s中pod基础及https密钥、horber仓库

一、pod基础&#xff1a; 1.pod是k8s里面最小单位&#xff0c;pod也是最小化运行容器的资源对象&#xff1b;容器是基于pod在k8s集群中工作&#xff1b;在k8s集群当中&#xff0c;一个pod就代表着一个运行进程&#xff0c;k8s的大部分组件都是围绕pod进行的&#xff0c;对pod进…

转换器和其他运放电路(恒流源+电压-电流/电流-电压转换器+峰值检测器)+故障检测(比较器故障+求和器故障)

2024-9-3&#xff0c;星期二&#xff0c;7:20&#xff0c;天气&#xff1a;阴转多云&#xff0c;心情&#xff1a;目前晴。又是上班的一天&#xff0c;没啥说的&#xff0c;加油上班&#xff0c;加油学习。 今天完成了第八章的学习&#xff0c;主要学习内容为&#xff1a;转换…

探索LLM大模型奥秘,新书《大模型入门:技术原理与实战应用》助你快速上手(附PDF下载)

随着大模型技术的不断完善和普及&#xff0c;我们将进入一个由数据驱动、智能辅助的全新工作模式和生活模式。个人和企业将能够利用大模型来降本增效&#xff0c;并创造全新的用户体验。 人工智能是人类探索未来的重要领域之一&#xff0c;以GPT为代表的大模型应用一经推出在短…

网站建设完成后, 做seo必须知道的专业知识之--robots协议

robots协议&#xff0c;也称为爬虫协议或机器人排除标准&#xff0c;是一种用于指导搜索引擎蜘蛛如何在网站上抓取和访问内容的协议。 网站建设完成后&#xff0c; 做seo必须知道的专业知识之--robots协议 通过这个协议&#xff0c;网站可以告诉搜索引擎哪些页面可以抓取&…

轴承知识大全,详细介绍(附3D图纸免费下载)

轴承一般由内圈、外圈、滚动体和保持架组成。对于密封轴承&#xff0c;再加上润滑剂和密封圈&#xff08;或防尘盖&#xff09;。这就是轴承的全部组成。 根据轴承使用的工作状况来选用不同类型的轴承&#xff0c;才能更好的发挥轴承的功能&#xff0c;并延长轴承的使用寿命。我…

设计模式 -- 迭代器模式(Iterator Pattern)

1 问题引出 编写程序展示一个学校院系结构&#xff1a;需求是这样&#xff0c;要在一个页面中展示出学校的院系组成&#xff0c;一个学校有多个学院&#xff0c; 一个学院有多个系 传统方式实现 将学院看做是学校的子类&#xff0c;系是学院的子类&#xff0c;这样实际上是站在…

CPU性能对比 Intel 海光 鲲鹏920 飞腾2500

横向对比一下x86和ARM芯片&#xff0c;以及不同方案权衡下的性能比较 CPU基本信息 海光 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 #lscpu Architecture: x86_64 CPU…

ssm基于微信小程序的食堂线上预约点餐系统论文源码调试讲解

2系统相关技术 2.1 Java语言简介 Java是由SUN公司推出&#xff0c;该公司于2010年被oracle公司收购。Java本是印度尼西亚的一个叫做爪洼岛的英文名称&#xff0c;也因此得来java是一杯正冒着热气咖啡的标识。Java语言在移动互联网的大背景下具备了显著的优势和广阔的前景&…

优化SQL查询之了解SQL执行顺序

文章目录 优化SQL查询之了解SQL执行顺序举例的SQL语句SQL执行顺序 - 文字解释SQL执行顺序 - 图示SQL执行顺序 - 动画演示distinct 子句会在哪个位置除了6个主要的关键字&#xff0c;还有哪些关键字总结 优化SQL查询之了解SQL执行顺序 SQL 查询的执行过程并非是简单的按照语句的…

话费接口API对接流程是什么?又有哪些优势?

话费接口 API 对接流程 前期准备 找一家专业做话费充值的公司&#xff0c;联系其商务了解对接的具体情况&#xff0c;包括合作模式、话费价格、消耗及打款金额是否可以开票、对接时是否有技术配合等 开户与对接 确定合作后在话费充值平台进行开户&#xff0c;获取账户参数及…

14、Django Admin的“Action(动作)”中添加额外操作

如图红框增加操作 将以下代码添加到HeroAdmin类中 actions ["mark_immortal"] def mark_immortal(self, request, queryset):queryset.update(is_immortalTrue) 修改后完整代码如下&#xff1a; admin.register(Hero) class HeroAdmin(admin.ModelAdmin):list_di…

c++返回一个pair类型

前言 Under the new standard we can list initialize the return value. 代码测试 #include<iostream> #include<string> #include<vector>std::pair<std::string, int> process(std::vector<std::string>& v) {if (!v.empty()){return …

无用但有趣的R包之教你怎么科学地用R语言摸鱼

如果你觉得R只是用来科研的工具&#xff0c;那就太辜负广大开发者的良苦用心了。今天给大家介绍几个useless但fun的的R包&#xff0c;为大家工作学习之余提供一点微不足道的小乐趣。 All work and no play makes jack a dull boy. 话不多说&#xff0c;游戏开始。 1.Fun包&a…

Nginx: 进程结构和信号量管理

进程结构 需要首先关注多进程和多线程的一个区别Nginx采用的是一种多进程&#xff0c;这样一种进程结构, 为何不采用多线程这样一种进程结构Nginx 设计之初就是为了高可能性和高可靠性而设计的Nginx 通常是运行在边缘节点上&#xff0c;通常是用来接受用户的第一个请求的时候&…

第三届人工智能与智能信息处理国际学术会议(AIIIP 2024)

目录 大会介绍 基本信息 合作单位 主讲嘉宾 会议组委 征文主题 ​编辑 参会方式 会议日程 中国-天津 | 2024年10月25-27日 | 会议官网&#xff1a;www.aiiip.net 大会介绍 第三届人工智能与智能信息处理国际学术会议&#xff08;AIIIP 2024&#xff09;将…

如何通过PLM系统提升企业研发效率与市场竞争力?

在当今快速变化的商业环境中&#xff0c;企业的研发能力和市场响应速度成为了决定其竞争力的关键因素。PLM系统作为一种集成了产品设计、开发、制造、销售、服务及最终报废等全生命周期信息的集成化管理平台&#xff0c;正逐渐成为企业提升研发效率、加速产品创新、优化资源配置…

Java后端面试题(微服务相关)(day12)

目录 分布式与微服务区别&#xff1f;什么是CAP原则&#xff1f;Spring Cloud Alibaba 组件有哪些&#xff1f;Nacos配置中心动态刷新原理目前主流的负载方案有哪些&#xff1f;Nginx作为服务端负载均衡器&#xff0c;常见的负载均衡策略有哪些&#xff1f;Spring Ribbon相关Sp…