小白学Pytorch系列--Torch.optim API Base class(1)

news2025/1/18 9:01:55

小白学Pytorch系列–Torch.optim API Base class(1)


torch.optim是一个实现各种优化算法的包。大多数常用的方法都已得到支持,而且接口足够通用,因此将来还可以轻松集成更复杂的方法。

如何使用优化器

使用手torch.optim您必须构造一个优化器对象,该对象将保存当前状态,并将根据计算出的梯度更新参数。

构造它

要构造一个优化器,你必须给它一个包含参数(所有应该是变量)的可迭代对象来优化。然后,您可以指定特定于优化器的选项,如学习率、权重衰减等。

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

每个参数选项

优化器还支持指定每个参数的选项。要做到这一点,不是传递一个变量的迭代对象,而是传递一个dict的迭代对象。它们每个都将定义一个单独的参数组,并且应该包含一个params键,包含属于它的参数列表。其他键应该与优化器接受的关键字参数匹配,并将用作该组的优化选项。

注意: 您仍然可以将选项作为关键字参数传递。在没有覆盖它们的组中,它们将作为默认值使用。当您只想改变一个选项,同时在参数组之间保持所有其他选项一致时,这很有用。

例如,当想要指定每层的学习率时,这是非常有用的

optim.SGD([
                {'params': model.base.parameters()},
                {'params': model.classifier.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

这意味着这个model.base参数将使用默认的学习率1e-2model.classifier’参数将使用1e-3的学习率,所有参数将使用0.9的动量。

进行优化步骤

所有优化器都实现了一个step()方法,用于更新参数。它有两种用法
optimizer.step()
这是大多数优化器支持的简化版本。该函数可以在梯度计算完成后调用,例如backward()

例如:

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

optimizer.step(closure)
一些优化算法(如共轭梯度和LBFGS)需要多次重新计算函数,所以你必须传入一个闭包,允许它们重新计算你的模型。闭包应该清除梯度,计算损失并返回。

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

Base class

这部分参考了:https://zhuanlan.zhihu.com/p/87209990
PyTorch 的优化器基本都继承于 “class Optimizer”,这是所有 optimizer 的 base class。
下面是Optimizer的结构

class Optimizer(object):
    def __init__(self, params, defaults):
        self.defaults = defaults
        self._hook_for_profile()
        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]
        for param_group in param_groups:
            self.add_param_group(param_group)
    def state_dict(self):
    	...

    def load_state_dict(self, state_dict):
        ...

    def cast(param, value):
    	...

    def zero_grad(self, set_to_none: bool = False):
    	...
    
    def step(self, closure):
    	...
    
    def add_param_group(self, param_group):    
    	...

init 函数初始化

paramsdefaults是两个重要的参数,defaults定义了全局优化默认值,params定义了模型参数和局部优化默认值。

add_param_group

defaultdict的作用在于当字典里的 key 被查找但不存在时,返回的不是keyError而是一个默认值,此处defaultdict(dict)`返回的默认值会是个空字典。最后一行调用的self.add_param_group(param_group),其中param_group是个字典,Key 就是params,Value 就是param_groups = list(params)。

def add_param_group(self, param_group):
        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default) # 给参数设置默认参数

        params = param_group['params']
        if len(params) != len(set(params)):
            warnings.warn("optimizer contains ", stacklevel=3)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])): # 判断两个集合是否包含相同的元素
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)

zero_grad

就是将所有参数的梯度置为零p.grad.zero_()。detach_()的作用是Detaches the Tensor from the graph that created it, making it a leaf. self.param_groups是列表,其中的元素是字典。

def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

step

更新参数作用, 在父类 Optimizer 的 step 函数中只有一行代码raise NotImplementedError。网络模型参数和优化器的参数都保存在列表 self.param_groups 的元素中,该元素以字典形式存储和访问具体的网络模型参数和优化器的参数。所以,可以通过两层循环访问网络模型的每一个参数 p 。获取到梯度d_p = p.grad.data之后,根据优化器参数设置是否使用 momentum或者nesterov再对参数进行调整。最后一行 p.data.add_(-group['lr'], d_p)的作用是对参数进行更新。state用于保存本次更新是优化器第几轮迭代更新参数。

下面以SGD优化器为例

def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            maximize = group['maximize']
            lr = group['lr']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    d_p_list.append(p.grad)

                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        momentum_buffer_list.append(None)
                    else:
                        momentum_buffer_list.append(state['momentum_buffer'])

            F.sgd(params_with_grad,
                  d_p_list,
                  momentum_buffer_list,
                  weight_decay=weight_decay,
                  momentum=momentum,
                  lr=lr,
                  dampening=dampening,
                  nesterov=nesterov,
                  maximize=maximize,)

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p] ## 保存
                state['momentum_buffer'] = momentum_buffer
        return loss

F.sgd

def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    for i, param in enumerate(params):
        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)
        if momentum != 0:
            buf = momentum_buffer_list[i]
            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf
        alpha = lr if maximize else -lr
        param.add_(d_p, alpha=alpha)

SGD上引入了一个Momentum(又叫Heavy Ball)的改进。

load_state_dict

加载优化器状态。

def load_state_dict(self, state_dict):

     # deepcopy, to be consistent with module API
     state_dict = deepcopy(state_dict)
     # Validate the state_dict
     groups = self.param_groups
     saved_groups = state_dict['param_groups']

     if len(groups) != len(saved_groups):
         raise ValueError("loaded state dict has a different number of "
                          "parameter groups")
     param_lens = (len(g['params']) for g in groups)
     saved_lens = (len(g['params']) for g in saved_groups)
     if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
         raise ValueError("loaded state dict contains a parameter group "
                          "that doesn't match the size of optimizer's group")

     # Update the state
     id_map = {old_id: p for old_id, p in
               zip(chain.from_iterable((g['params'] for g in saved_groups)),
                   chain.from_iterable((g['params'] for g in groups)))}

     def cast(param, value):
         r"""Make a deep copy of value, casting all tensors to device of param."""
         if isinstance(value, torch.Tensor):
             # Floating-point types are a bit special here. They are the only ones
             # that are assumed to always match the type of params.
             if param.is_floating_point():
                 value = value.to(param.dtype)
             value = value.to(param.device)
             return value
         elif isinstance(value, dict):
             return {k: cast(param, v) for k, v in value.items()}
         elif isinstance(value, container_abcs.Iterable):
             return type(value)(cast(param, v) for v in value)
         else:
             return value

     # Copy state assigned to params (and cast tensors to appropriate types).
     # State that is not assigned to params is copied as is (needed for
     # backward compatibility).
     state = defaultdict(dict)
     for k, v in state_dict['state'].items():
         if k in id_map:
             param = id_map[k]
             state[param] = cast(param, v)
         else:
             state[k] = v

state_dict

以字典的形式返回优化器的状态。

def state_dict(self):    
     # Save order indices instead of Tensors
      param_mappings = {}
      start_index = 0

      def pack_group(group):
          nonlocal start_index
          packed = {k: v for k, v in group.items() if k != 'params'}
          param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                                 if id(p) not in param_mappings})
          packed['params'] = [param_mappings[id(p)] for p in group['params']]
          start_index += len(packed['params'])
          return packed
      param_groups = [pack_group(g) for g in self.param_groups]
      # Remap state to use order indices as keys
      packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                      for k, v in self.state.items()}
      return {
          'state': packed_state,
          'param_groups': param_groups,
      }

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/415150.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

下载和阅读Android源码

目录一、如何下载AOSP1.全量下载2.单个下载目录结构二、如何阅读AOSP1.要阅读哪些源码2.阅读源码的顺序和方式2.1 阅读顺序2.2 阅读方式3.用什么工具来阅读3.1 下载安装Source Insight3.2 导入AOSP源码3.3查看源码三、其他一、如何下载AOSP 源码下载是我们分析源码的开始&…

ctfshow愚人杯web复现

easy_signin 题目url base64解码是face.png&#xff0c;尝试flag.txt和flag.php&#xff0c;base64加密后传入都不对&#xff0c;用index.php加密后传入&#xff0c;看源码 将后面的base64解密得到flag 被遗忘的反序列化 源码 <?php# 当前目录中有一个txt文件哦 error_r…

Unity- 游戏结束以及重启游戏

文章目录游戏结束以及重启游戏建个游戏结束页面编写委托类 游戏主角 以及 ui管理类的脚本重启游戏游戏结束以及重启游戏 思路&#xff1a;利用Canvas创建好覆盖全屏的结束页面&#xff0c;默认关闭。游戏结束时&#xff0c;玩家控制的对象发起委托&#xff0c;ui管理收下委托&…

electron+vue3全家桶+vite项目搭建【六】集成vue-i18n 国际化

文章目录注意引入1.引入依赖2.集成vue i18n3.测试代码4.封装多语言切换组件5.测试多语言切换6.优化代码注意 已发现 9.2.2版本的vue-i18n 如果使用cnpm安装&#xff0c;打包会报错&#xff0c;使用npm或者pnpm安装依赖没有问题 引入 如果需要多语言支持&#xff0c;那么最好…

【100个 Unity实用技能】 | Lua中获取当前时间戳,时间戳和时间格式相互转换、时间戳转换为多久之前

Unity 小科普 老规矩&#xff0c;先介绍一下 Unity 的科普小知识&#xff1a; Unity是 实时3D互动内容创作和运营平台 。包括游戏开发、美术、建筑、汽车设计、影视在内的所有创作者&#xff0c;借助 Unity 将创意变成现实。Unity 平台提供一整套完善的软件解决方案&#xff…

【AI大比拼】文心一言 VS ChatGPT-4

摘要&#xff1a;本文将对比分析两款知名的 AI 对话引擎&#xff1a;文心一言和 OpenAI 的 ChatGPT&#xff0c;通过实际案例让大家对这两款对话引擎有更深入的了解&#xff0c;以便大家选择合适的 AI 对话引擎。 亲爱的 CSDN 朋友们&#xff0c;大家好&#xff01;近年来&…

Python自动录入ERP系统数据

大家好&#xff0c;我是毕加锁。 今天给大家带来的是用Python解决Excel问题的最佳姿势 文末送书&#xff01; 文末送书&#xff01; 文末送书&#xff01; 项目总体情况 软件&#xff1a;Pycharm 环境: Python 3.7.9(考虑到客户可能会有不同操作系统&#xff0c;为了兼容性…

【小程序】django学习笔记3

今天我们来做数据库和django的关联。 根据之前的代码应该看得出来我想做一个获取访客的ip地址并计算访问次数的app&#xff0c;所以必然会用到数据库。 这里选择用的是mysql(因为免费) 不一样的是这里我们打算用django提供的orm框架对数据库进行操作。 一. 环境准备 首先安…

SLAM面试笔记(3) - 视觉SLAM

目录 1 紧耦合、松耦合的区别 &#xff08;1&#xff09;紧耦合和松耦合的区别 &#xff08;2&#xff09;紧耦合和松耦合的分类 &#xff08;3&#xff09;为什么要使用紧耦合 2 SIFT和SUFT的区别 3 视差与深度的关系 4 闭环检测常用方法 5 描述PnP算法 6 梯度下降法…

SQL基础

目录 1.库操作 2.表操作 3.表操作--修改 4.表操作 --删表 5.添加数据 管理数据 查询表中数据&#xff08;重点&#xff09; 判空条件 1.模糊条件查询 2.聚合查询&#xff08;函数&#xff09; 3.排序查询 4.分页查询 5.分组查询&#xff08;配合聚合函数用于统计&a…

C++模拟实现读写锁

文章目录一、读者写者问题二、读写锁1.读写锁的概念2.读写锁的设计(1)成员变量(2)构造函数和析构函数(3)readLock函数(4)readUnlock函数(5)writeLock函数(6)writeUnlock函数3.RWLock类代码三、测试读写锁一、读者写者问题 在编写多线程的时候&#xff0c;有一种情况是非常常见…

为什么黑客不黑/攻击赌博网站?如何入门黑客?

攻击了&#xff0c;只是你不知道而已&#xff01; 同样&#xff0c;对方也不会通知你&#xff0c;告诉你他黑了赌博网站。 攻击赌博网站的不一定是正义的黑客&#xff0c;也可能是因赌博输钱而误入歧途的法外狂徒。之前看过一个警方破获的真实案件&#xff1a;28岁小伙因赌博…

Linux 操作系统原理作业 - 行人与机动车问题

大三上学期操作系统原理这门课中&#xff0c;老师给了一道作业《行人与机动车问题》&#xff1b; 即Linux多线程下处理行人与机动车谁优先的问题&#xff0c;需要用到多线程和互斥量&#xff1b; 行人 - 机动 车问题 假设有一个路口&#xff0c;有很多行人和机动车需要通过&a…

1673_MIT 6.828 Homework xv6 lazy page allocation要求翻译

全部学习汇总&#xff1a; GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 在计划表中看到了这样一份作业&#xff0c;做一个简单的翻译整理。原来的页面&#xff1a;Homework: xv6 lazy page allocation (mit.edu) 家庭作业&#xff1a;x…

代码版本M、RC、GA、Release等标识的区别

引言 最近听说spring framework有了重大版本调整&#xff0c;出了6.0的GA版本了 那GA是啥意思呢&#xff1f; 看了下spring 官网和代码仓库&#xff0c;除了GA&#xff0c;还有M、RC、Release等 Spring FrameworkLevel up your Java code and explore what Spring can do f…

[Java Web]element | 一个由饿了么公司开发的前端框架,让你快速构建现代化、美观的 Web 应用程序。

⭐作者介绍&#xff1a;大二本科网络工程专业在读&#xff0c;持续学习Java&#xff0c;努力输出优质文章 ⭐作者主页&#xff1a;逐梦苍穹 ⭐所属专栏&#xff1a;Java Web ⭐如果觉得文章写的不错&#xff0c;欢迎点个关注一键三连&#x1f609;有写的不好的地方也欢迎指正&a…

【mybatis】mybatis的工作原理

目录一、工作流程二、说明2.1 构建SqlSessionFactory2.2 SqlSession的获取2.3 SqlSession执行语句三、源码结构3.1 接口层3.2 核心处理3.3 核心处理层四、代码示例4.1 通过inputStream构建SqlSessionFactory4.2 通过configuration构建SqlSessionFactory4.3 mybatis-config.xml示…

groovy环境搭建

什么是DSL? 领域特定语言DSL&#xff08;全称&#xff1a;domain specific language&#xff09; 常见的DSL语言有&#xff1a;UML、HTML、SQL、XML、Groovy 作用&#xff1a;解决某一特定领域的问题 什么是groovy? groovy是一种基于JVM的敏捷开发语言。 结合了Python、Ruby和…

Vite4+Vuejs3项目初步搭建,并部署多个vue项目到nginx

前提条件 1、熟悉命令行 2、已安装 16.0 或更高版本的 Node.js 参照vuejs官网的步骤&#xff0c;创建一个vue前端项目 当前vuejs的版本&#xff1a;3.2.47 npm init vuelatestVue.js - The Progressive JavaScript Framework√ Project name: ... vuejs3-project√ Add Type…

BitDock桌面美化工具 一直在后台偷偷上传东西,具体上传什么东西不知,一天耗费我几十个G的流量

通过流量防火墙监控发现bitdock一直在上传东西&#xff0c;目前截止发现已上传了40G的流量 ――――――――――――――――――――――― 程序名称&#xff1a;SystemAudioDetection.exe 程序说明&#xff1a; 路径&#xff1a;D:\BitDock\AudioEngine\SystemAudioDetecti…