PyTorch-Lightning:trining_step的自动优化

news2025/1/18 3:17:22

文章目录

  • PyTorch-Lightning:trining_step的自动优化
      • 总结:
    • class _ AutomaticOptimization()
      • def run
      • def _make_closure
      • def _training_step
        • class ClosureResult():
          • def from_training_step_output
      • class Closure

PyTorch-Lightning:trining_step的自动优化

使用PyTorch-Lightning时,在trining_step定义损失,在没有定义损失,没有任何返回的情况下没有报错,在定义一个包含loss的多个元素字典返回时,也可以正常训练,那么到底lightning是怎么完成训练过程的。

总结:

在自动优化中,training_step必须返回一个tensor或者dict或者None(跳过),对于简单的使用,在training_step可以return一个tensor会作为Loss回传,也可以return一个字典,其中必须包括key"loss",字典中的"loss"会提取出来作为Loss回传,具体过程主要包含在lightning\pytorch\loop\sautomatic.py中的_ AutomaticOptimization()类。

在这里插入图片描述

class _ AutomaticOptimization()

实现自动优化(前向,梯度清零,后向,optimizer step)

在training_epoch_loop中会调用这个类的run函数。

def run

首先通过 _make_closure得到一个closure,详见def _make_closure,最后返回一个字典,如果我们在training_step只return了一个loss tensor则字典只有一个’loss’键值对,如果return了一个字典,则包含其他键值对。

可以看到调用了_ optimizer_step,_ optimizer_step经过层层调用,最后会调用torch默认的optimizer.zero_grad,而我们通过 make_closure得到的closure作为参数传入,具体而言是调用了closure类的_ call __()方法。

def run(self, optimizer: Optimizer, batch_idx: int, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
        closure = self._make_closure(kwargs, optimizer, batch_idx)

        if (
            # when the strategy handles accumulation, we want to always call the optimizer step
            not self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate()
        ):
            # For gradient accumulation

            # -------------------
            # calculate loss (train step + train step end)
            # -------------------
            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
            with _block_parallel_sync_behavior(self.trainer.strategy, block=True):
                closure()

        # ------------------------------
        # BACKWARD PASS
        # ------------------------------
        # gradient update with accumulated gradients
        else:
            self._optimizer_step(batch_idx, closure)

        result = closure.consume_result()
        if result.loss is None:
            return {}
        return result.asdict()

def _make_closure

创建一个closure对象,来捕捉给定的参数并且运行’training_step’和可选的其他如backword和zero_grad函数

比较重要的是step_fn,在这里调用了_training_step,得到的是一个存储我们在定义模型时重写的training step的输出所构成ClosureResult数据类。具体见def _training_step

def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer, batch_idx: int) -> Closure:

        step_fn = self._make_step_fn(kwargs)
        backward_fn = self._make_backward_fn(optimizer)
        zero_grad_fn = self._make_zero_grad_fn(batch_idx, optimizer)
        return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

def _training_step

通过hook函数实现真正的训练step,返回一个存储training step输出的ClosureResult数据类。

将我们在定义模型时定义的lightning.pytorch.core.LightningModule.training_step的输出作为参数传入存储容器class ClosureResult的from_training_step_output方法,见class Closure

class ClosureResult():

一个数据类,包含closure_loss,loss,extra

    closure_loss: Optional[Tensor]
    loss: Optional[Tensor] = field(init=False, default=None)
    extra: Dict[str, Any] = field(default_factory=dict)
def from_training_step_output

一个类方法,如果我们在training_step定义的返回是一个字典,则我们会将key值为"loss"的value赋值给closure_loss,而其余的键值对赋值给extra字典,如果返回的既不是包含"loss"的字典也不是tensor,则会报错。当我们在training_step不设定返回,则自然为None,但是不会报错。

class Closure

闭包是将外部作用域中的变量绑定到对这些变量进行计算的函数变量,而不将它们明确地作为输入。这样做的好处是可以将闭包传递给对象,之后可以像函数一样调用它,但不需要传入任何参数。

在lightning,用于自动优化的Closure类将training_step和backward, zero_grad三个基本的闭包结合在一起。

这个Closure得到training循环中的结果之后传入torch.optim.Optimizer.step。

参数:

  • step_fn: 这里是一个存储了training step输出的ClosureResult数据类,见def _training_step
  • backward_fn: 梯度回传函数
  • zero_grad_fn: 梯度清零函数

按照顺序,会先检查得到loss,之后调用梯度清零函数,最后调用梯度回传函数

class Closure(AbstractClosure[ClosureResult]):

    warning_cache = WarningCache()

    def __init__(
        self,
        step_fn: Callable[[], ClosureResult],
        backward_fn: Optional[Callable[[Tensor], None]] = None,
        zero_grad_fn: Optional[Callable[[], None]] = None,
    ):
        super().__init__()
        self._step_fn = step_fn
        self._backward_fn = backward_fn
        self._zero_grad_fn = zero_grad_fn

    @override
    @torch.enable_grad()
    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
        step_output = self._step_fn()

        if step_output.closure_loss is None:
            self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

        if self._zero_grad_fn is not None:
            self._zero_grad_fn()

        if self._backward_fn is not None and step_output.closure_loss is not None:
            self._backward_fn(step_output.closure_loss)

        return step_output

    @override
    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
        self._result = self.closure(*args, **kwargs)
        return self._result.loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1588156.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法 分割字符串为实体类

题目 String userData "10000:张三:男:1998-01-01#10001:张三:男:1998-01-01#10002:李四:女:1999-02-02#10003:王五:男:2000-03-03#10004:赵六:女:2001-04-04"; String[] usersArray userData.split("#"); // 使用Stream API将字符串数组转换为SysUser对…

ALV合并单元格

1、文章说明 在开发一些报表时,需要显示双层的标题,或者合并单元格的数据,归根结底就是要实现类似EXCEL合并单元格的需求。如图所示 网上的资料,很多根据国外某大神的方法实现:https://tricktresor.de/blog/zellen-ver…

JavaScript(四)-Web APIS

文章目录 日期对象实例化时间对象方法时间戳 节点操作DOM节点查找节点增加节点删除节点 M端事件JS插件Window对象BOM(浏览器对象模型)定时器-延时函数JS执行机制location对象navigator对象history对象 本地存储本地存储介绍本地存储分类localStoragesess…

Play Module Factory:Codigger系统上的高效Module开发工具

Play Module Factory,这款在Codigger系统上独树一帜的Play-Module开发工具,为广大的开发者们提供了一个全新的、高效的插件开发平台。它汇集了丰富的模板资源、基础库、API接口以及语言支持,这些功能强大的工具组合在一起,使得开发…

14亿美元!德国默克与AI生物科技公司合作;马斯克Neuralink首位脑机接口植入者用意念打游戏;黄仁勋在俄勒冈州立大学开讲

AI for Science 的新成果、新动态、新视角—— 日本第一 IT 公司富士通:生成式 AI 加速药物研发 马斯克:Neuralink 首位脑机接口植入者用「意念」打游戏 默克与 AI 生物科技公司 Caris 达成合作 AI 蛋白质设计服务提供商「天鹜科技」完成数千万元 Pre…

VE、希喂、PR猫咪主食冻干怎么样?测评品控最强、配方最好主食冻干!

我发现还是有不少铲屎官局限于“进口最高贵”,盲目的迷信进口产品。看到进口粮就盲买,甚至过分的贬低国产品牌,将国产粮贴上“不靠谱”“不合格”等标签。 最近,我针对主食冻干的国内、国际标准,相关规范文件&#xf…

C++资源重复释放问题

这不是自己释放了2次&#xff1b; 可能是类互相引用&#xff0c;有类似现象释放资源时引起&#xff1b;还不太了解&#xff1b; 类对象作为函数参数也会引起&#xff1b; 下面是一个简单示例&#xff1b; #include <iostream> #include <string.h> #include &l…

如何快速写一份简历

文章目录 如何快速写一份简历一些写简历的技巧 最近一段时间一直在忙简历相关的事情&#xff0c;起初是有一个其他行业的朋友问我&#xff0c;说这些简历我写了好久真难写&#xff0c;我说你可以借助AI&#xff0c;现在这种工具多了去了&#xff0c;为什么不借助呢&#xff1f;…

2024年 Mathorcup高校数学建模竞赛(B题)| 甲骨文识别 | 特征提取,图像分割解析,小鹿学长带队指引全代码文章与思路

我是鹿鹿学长&#xff0c;就读于上海交通大学&#xff0c;截至目前已经帮200人完成了建模与思路的构建的处理了&#xff5e; 本篇文章是鹿鹿学长经过深度思考&#xff0c;独辟蹊径&#xff0c;通过神经网络解决甲骨文识别问题。结合特征提取&#xff0c;图像分割等多元算法&…

Apabi Reader软件:打开ceb文件

Apabi Reader软件&#xff1a;打开ceb文件 软件下载软件安装 打开ceb文件参考 软件下载 下载官网-Apabi Reader软件 软件安装 打开ceb文件 ceb文件目录如下&#xff1a; 打开文件如下&#xff1a; 参考

从“黑箱”到“透明”:云里物里电子标签助力汽车总装数字化转型

“汽车总装”指“汽车产品&#xff08;包括整车及总成等&#xff09;的装配”&#xff0c;是把经检验合格的数以百计、或数以千计的各种零部件按照一定的技术要求组装成整车及发动机、变速器等总成的工艺过程&#xff0c;是汽车产品制造过程中最重要的工艺环节之一。 其中&…

社交网络与Web3:数字社交的下一阶段

随着信息技术的飞速发展&#xff0c;人们的社交方式也发生了巨大的变化。从最初的互联网聊天室到如今的社交网络平台&#xff0c;我们已经见证了数字社交的不断演变和发展。而随着区块链技术的兴起&#xff0c;Web3时代的到来将为数字社交带来全新的可能性和挑战。本文将探讨社…

.NET MAUI使用Visual Studio Android Emulator(安卓模拟器)运行

Android Emulator&#xff08;安卓模拟器&#xff09;运行&#xff1a; 安卓模拟器一直卡在不动&#xff1a; 在某些情况下&#xff0c;在“打开或关闭 Windows 功能”对话框中启用 Hyper-V 和 Windows 虚拟机监控程序平台后可能无法正确启用Hyper-V。 我就是开启Hyper-V才把安…

Python网络爬虫中JSON格式数据存储详解

目录 一、引言 二、JSON格式数据简介 三、Python中处理JSON数据 四、网络爬虫中获取JSON数据 五、存储JSON数据到文件 六、从文件中读取JSON数据 七、注意事项和常见问题 八、总结 一、引言 在网络爬虫的应用中&#xff0c;JSON格式数据以其轻量级、易读易写的…

Redis性能管理和集群的三种模式(二)

一、Redis集群模式 1.1 redis的定义 redis 集群 是一个提供高性能、高可用、数据分片、故障转移特性的分布式数据解决方案 1.2 redis的功能 数据分片&#xff1a;redis cluster 实现了数据自动分片&#xff0c;每个节点都会保存一份数据故障转移&#xff1a;若个某个节点发生故…

2024年第十四届 Mathorcup (B题)| 甲骨文智能识别 | 深度学习 计算机视觉 |数学建模完整代码+建模过程全解全析

当大家面临着复杂的数学建模问题时&#xff0c;你是否曾经感到茫然无措&#xff1f;作为2022年美国大学生数学建模比赛的O奖得主&#xff0c;我为大家提供了一套优秀的解题思路&#xff0c;让你轻松应对各种难题。 让我们来看看Mathorcup (B题&#xff09;&#xff01; CS团队…

基于springboot+vue的汽车租赁管理系统

背景介绍: 网络发展的越来越迅速&#xff0c;它深刻的影响着每一个人生活的各个方面。每一种新型事务的兴起都是为了使人们的生活更加方便。汽车租赁管理系统是一种低成本、更加高效的电子商务方式&#xff0c;它已慢慢的成为一种全新的管理模式。人们不再满足于在互联网上浏览…

自动化测试提速必备 - 并发编程

在自动化测试领域&#xff0c;多线程和多进程技术被广泛应用于提高测试的执行效率和性能。通过并发运行测试用例&#xff0c;可以显著缩短测试周期&#xff0c;特别是在面对大量测试用例或者需要在多个环境中进行测试时尤为重要。 在实际的自动化测试中&#xff0c;我们经常碰…

为什么选择成为程序员?

目录 兴趣和热爱高薪和就业机会持续学习和不断成长挑战和乐趣 兴趣和热爱 许多人选择成为程序员可能是热爱&#xff0c;对计算机&#xff0c;以及编程和科技产生了浓厚的兴趣&#xff0c;并且享受着解决每一个技术问题&#xff0c;构建应用程序和探索新技术所带来的乐趣。 谈到…

【多线程】 synchronized关键字 | 可重入锁 | 死锁 | volatile关键字 | 内存可见性问题 |waitnotify方法|

文章目录 synchronized和volatile关键字一、加锁互斥二、synchronized的使用1.修饰实例方法2.修饰类方法 三、可重入锁1.死锁关于死锁哲学家就餐问题如何避免死锁死锁成因的四个必要条件 四、volatile关键字1.保证内存可见性什么是内存可见性问题如何解决Java内存模型JMMvolati…