【mmengine】优化器封装(OptimWrapper)(入门)优化器封装 vs 优化器

news2024/10/2 18:35:36
  • MMEngine 实现了优化器封装,为用户提供了统一的优化器访问接口。优化器封装支持不同的训练策略,包括混合精度训练、梯度累加和梯度截断。用户可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程,用户可以基于这一套流程,实现同一套代码,不同训练策略的切换。
  • 分别基于 Pytorch 内置的优化器和 MMEngine 的优化器封装(OptimWrapper)进行单精度训练混合精度训练梯度累加,对比二者实现上的区别。

一、 基于 Pytorch 的 SGD 优化器实现单精度训练

import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as F

inputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

二、 使用 MMEngine 的优化器封装实现单精度训练

from mmengine.optim import OptimWrapper

optim_wrapper = OptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    optim_wrapper.update_params(loss)

优化器封装的 update_params 实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。
在这里插入图片描述

三、 基于 Pytorch 的 SGD 优化器实现混合精度训练

在这里插入图片描述

  • 混合精度训练:单精度 float和半精度 float16 混合,其优势为:
    • 内存占用更少
    • 计算更快
from torch.cuda.amp import autocast

model = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10

for input, target in zip(inputs, targets):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

四、 基于 MMEngine 的 优化器封装实现混合精度训练

from mmengine.optim import AmpOptimWrapper

optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述

  • 混合精度训练需要使用 AmpOptimWrapper,他的 optim_context 接口类似 autocast,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍

五、 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加

for idx, (input, target) in enumerate(zip(inputs, targets)):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    if idx % 2 == 0:
        optimizer.step()
        optimizer.zero_grad()

六、基于 MMEngine 的优化器封装实现混合精度训练和梯度累加

optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述
只需要配置 accumulative_counts 参数,并调用 update_params 接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 optim_wrapper 上下文,可以避免梯度累加阶段不必要的梯度同步。

七、 获取学习率/动量

优化器封装提供了 get_lrget_momentum 接口用于获取优化器的一个参数组的学习率:

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)
# 封装器
optim_wrapper = OptimWrapper(optimizer)

print("get info from optimizer ------")
print(optimizer.param_groups[0]['lr'])  # 0.01
print(optimizer.param_groups[0]['momentum'])  # 0
print("get info from wrapper ------")
print(optim_wrapper.get_lr())  # {'lr': [0.01]}
print(optim_wrapper.get_momentum())  # {'momentum': [0]}

在这里插入图片描述

八、 导出/加载状态字典

优化器封装和优化器一样,提供了 state_dictload_state_dict 接口,用于导出/加载优化器状态,对于 AmpOptimWrapper,优化器封装还会额外导出混合精度训练相关的参数:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, AmpOptimWrapper

model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)

# ---- 导出 ---- #
print("print state_dict")
# 单精度封装器
optim_wrapper = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

# 导出状态字典
optim_state_dict = optim_wrapper.state_dict()
amp_optim_state_dict = amp_optim_wrapper.state_dict()
print(optim_state_dict)
print(amp_optim_state_dict)



# ---- 加载 ---- #
print("load state_dict")
# 单精度封装器
optim_wrapper_new = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper_new = AmpOptimWrapper(optimizer=optimizer)

# 加载状态字典
amp_optim_wrapper_new.load_state_dict(amp_optim_state_dict)
optim_wrapper_new.load_state_dict(optim_state_dict)

在这里插入图片描述

九、 使用多个优化器

OptimWrapperDict 的核心功能是支持批量导出/加载所有优化器封装的状态字典;支持获取多个优化器封装的学习率、动量。如果没有 OptimWrapperDict,MMEngine 就需要在很多位置对优化器封装的类型做 if else 判断,以获取所有优化器封装的状态。

from torch.optim import SGD
import torch.nn as nn

from mmengine.optim import OptimWrapper, OptimWrapperDict

# model1
gen = nn.Linear(1, 1)
# model2
disc = nn.Linear(1, 1)

# optimizer1
optimizer_gen = SGD(gen.parameters(), lr=0.01)
# optimizer2
optimizer_disc = SGD(disc.parameters(), lr=0.01)

# wrapper1
optim_wapper_gen = OptimWrapper(optimizer=optimizer_gen)
# wrapper2
optim_wapper_disc = OptimWrapper(optimizer=optimizer_disc)

# wrapper_dict = wrapper1 + wrapper2
optim_dict = OptimWrapperDict(gen=optim_wapper_gen, disc=optim_wapper_disc)

print("wrapper_dict = wrapper1 + wrapper2")
print(optim_dict.get_lr())  # {'gen.lr': [0.01], 'disc.lr': [0.01]}
print(optim_dict.get_momentum())  # {'gen.momentum': [0], 'disc.momentum': [0]}

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2184850.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SWAP、AquaCrop、FVCOM、Delft3D、SWAT、R+VIC、HSPF、HEC-HMS......

全流程SWAP农业模型数据制备、敏感性分析及气候变化影响实践技术应用 SWAP模型是由荷兰瓦赫宁根大学开发的先进农作物模型,它综合考虑了土壤-水分-大气以及植被间的相互作用;是一种描述作物生长过程的一种机理性作物生长模型。它不但运用Richard方程&…

2024最新软件测试八股文(含答案+文档)

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 一、软件测试基础面试题 1、阐述软件生命周期都有哪些阶段? 常见的软件生命周期模型有哪些? 软件生命周期是指一个计算机软件从功能确定设计,到…

系统安全 - Linux 安全模型及实践

文章目录 导图Linux 安全模型用户层权限管理的细节多用户环境中的权限管理文件权限与目录权限 最小权限原则的应用Linux 系统中的认证、授权和审计机制认证机制授权机制审计机制 主机入侵检测系统(HIDS)_ Host-based Intrusion Detection SystemHIDS 的概…

Android问题笔记五十:构建错误-AAPT2 aapt2-7.0.2-7396180-windows Daemon

Unity3D特效百例案例项目实战源码Android-Unity实战问题汇总游戏脚本-辅助自动化Android控件全解手册再战Android系列Scratch编程案例软考全系列Unity3D学习专栏蓝桥系列ChatGPT和AIGC 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分…

jmeter中token测试

案例: 网站:http://shop.duoceshi.com 讲解:用三个接口来讲解 第一个接口code:GET http://manage.duoceshi.com/auth/code 第二个登录接口:http://manage.duoceshi.com/auth/login 第三个接口:http://…

iOS中的链表 - 双向链表

iOS中的链表 - 单向链表_ios 链表怎么实现-CSDN博客​​​​​​​ 引言 在数据结构中,链表是一种常见的且灵活的线性存储方式。与数组不同,链表的元素在内存中不必连续存储,这使得它们在动态内存分配时更加高效。其中,双向链表…

Pikachu-Cross-Site Scripting-DOM型xss_x

查看代码&#xff0c;输入的内容&#xff0c;通过get请求方式&#xff0c;用text 参数带过去&#xff1b; 获取text内容&#xff0c;赋值给xss 然后拼接到 dom 里&#xff1b;构造payload的关键语句&#xff1a; <a href"xss">就让往事都随风,都随风吧</a&…

【SQL】DDL语句

文章目录 1.SQL通用语法2.SQL的分类3.DDL3.1数据库操作3.2 表操作3.2.1 表操作--数据类型3.2.2 表操作--修改3.2.3 表操作--删除 SQL 全称 Structured Query Language&#xff0c;结构化查询语言。操作关系型数据库的编程语言&#xff0c;定义了一套操作关系型数据库统一标准 。…

【Python语言初识(六)】

一、网络编程入门 1.1、TCP/IP模型 实现网络通信的基础是网络通信协议&#xff0c;这些协议通常是由互联网工程任务组 &#xff08;IETF&#xff09;制定的。所谓“协议”就是通信计算机双方必须共同遵从的一组约定&#xff0c;例如怎样建立连接、怎样互相识别等&#xff0c;…

解决MySQL报Incorrect datetime value错误

目录 一、前言二、问题分析三、解决方法 一、前言 欢迎大家来到权权的博客~欢迎大家对我的博客进行指导&#xff0c;有什么不对的地方&#xff0c;我会及时改进哦~ 博客主页链接点这里–>&#xff1a;权权的博客主页链接 二、问题分析 这个错误通常出现在尝试将一个不…

基于C++和Python的进程线程CPU使用率监控工具

文章目录 0. 概述1. 数据可视化示例2. 设计思路2.1 系统架构2.2 设计优势 3. 流程图3.1 C录制程序3.2 Python解析脚本 4. 数据结构说明4.1 CpuUsageData 结构体 5. C录制代码解析5.1 主要模块5.2 关键函数5.2.1 CpuUsageMonitor::Run()5.2.2 CpuUsageMonitor::ComputeCpuUsage(…

Python库matplotlib之五

Python库matplotlib之五 小部件(widget)RangeSlider构造器APIs应用实列 TextBox构造器APIs应用实列 小部件(widget) 小部件(widget)可与任何GUI后端一起工作。所有这些小部件都要求预定义一个Axes实例&#xff0c;并将其作为第一个参数传递。 Matplotlib不会试图布局这些小部件…

【数学分析笔记】第4章第1节 微分和导数(1)

4. 微分 4.1 微分和导数 考虑一个函数 y f ( x ) yf(x) yf(x)&#xff0c;当 x x x做一些微小的变动&#xff0c;函数值也会有微小的变动&#xff0c;比如&#xff1a; x → x △ x x\to x\bigtriangleup x x→x△x&#xff0c;则 f ( x ) → f ( x △ x ) f(x)\to f(x\bi…

足球青训俱乐部管理:Spring Boot技术驱动

摘 要 随着社会经济的快速发展&#xff0c;人们对足球俱乐部的需求日益增加&#xff0c;加快了足球健身俱乐部的发展&#xff0c;足球俱乐部管理工作日益繁忙&#xff0c;传统的管理方式已经无法满足足球俱乐部管理需求&#xff0c;因此&#xff0c;为了提高足球俱乐部管理效率…

数据结构——“AVL树”的四种数据旋转的方法

因为上次普通的二叉搜索树在极端情况下极容易造成我们的链式结构&#xff08;这会导致我们查询的时间复杂度变为O(n)&#xff09;&#xff0c;然而AVL树就很好的解决了这一问题&#xff08;归功于四种旋转的方法&#xff09;&#xff0c;它让我们的树的查询的时间复杂度变得接近…

TIM(Timer)定时器的原理

一、介绍 硬件定时器的工作原理基于时钟信号源提供稳定的时钟信号作为计时器的基准。计数器从预设值开始计数&#xff0c;每当时钟信号到达时计数器递增。当计数器达到预设值时&#xff0c;定时器会触发一个中断信号通知中断控制器处理相应的中断服务程序。在中断服务程序中&a…

无人化焦炉四大车系统 武汉正向科技 工业机车无人远程控制系统

焦炉四大车无人化系统介绍 采用格雷母线光编码尺双冗余定位技术&#xff0c;炉门视觉定位自学习技术&#xff0c;wifi5G无线通讯技术&#xff0c;激光雷达安全识别技术&#xff0c;焦化智慧调度&#xff0c;手机APP监控功能。 焦炉四大车无人化系统功能 该系统能自动生成生产…

遥感图像垃圾处理场分割,北京地区高分2图像,3500张图像,共2GB,分割为背景,空地,垃圾,垃圾处理设施四类

遥感图像垃圾处理场分割&#xff0c;北京地区高分2图像&#xff0c;3500张图像&#xff0c;共2GB&#xff0c;分割为背景&#xff0c;空地&#xff0c;垃圾&#xff0c;垃圾处理设施四类 遥感图像垃圾处理场分割数据集 规模 图像数量&#xff1a;3500张数据量&#xff1a;2G…

黑科技!Llama 3.2多模态AI震撼发布

黑科技&#xff01;Llama 3.2多模态AI震撼发布 Meta发布Llama 3.2模型&#x1f680;&#xff0c;引领AI新潮流&#xff01;它能处理文字、图片、视频&#x1f4f8;&#xff0c;满足不同需求&#xff0c;性能媲美大牌选手✨。一键启动包已准备好&#xff0c;让你轻松体验AI的魔…

模版and初识vector

一、引言 在C语言中&#xff0c;不论是数组&#xff0c;还是结构体定义的数组&#xff0c;功能都比较欠缺&#xff0c;不是单纯的添加几个变量就能够解决的。缺少增删查改的功能&#xff0c;为了解决这个问题&#xff0c;C决定填上C语言这个坑&#xff0c;但是填过坑的人都知道…