【深度学习实战(25)】搭建训练框架之ModelEMA

news2025/2/24 17:20:08

一、什么是ModelEMA:

在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。

二、如何实现ModelEMA

创建EMA eval mode,去并行化

self.ema = deepcopy(de_parallel(model)).eval() 

EMA更新次数

self.updates = updates

根据更新次数,获取衰减系数

self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)

去掉梯度,ema不需要梯度

for p in self.ema.parameters():
    p.requires_grad_(False)

EMA更新次数+1

self.updates += 1

根据更新次数,获取衰减系数

d = self.decay(self.updates)

根据衰减系数,当前模型(去并行化)来修改当前ema模型

msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

三、ModelEMA完整实现

#----------------------#
#   判断是否并行训练模式
#----------------------#
def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

#----------------------#
#   去并行训练模式
#----------------------#
def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model

#----------------------#
#   模型拷贝
#----------------------#
def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        #----------------------#
        #   创建EMA eval mode,去并行化
        #----------------------#
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        #----------------------#
        #   EMA更新次数
        #----------------------#
        self.updates = updates
        #----------------------#
        #   根据更新次数,获取衰减系数
        #----------------------#
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        #----------------------#
        #   去掉梯度,ema不需要梯度
        #----------------------#
        for p in self.ema.parameters():
            p.requires_grad_(False)

    #----------------------#
    #   根据ema更新次数获取衰减系数,再根据衰减系数和当前模型(去并行化)来修改当前ema模型
    #----------------------#
    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            #----------------------#
            #   EMA更新次数+1
            #----------------------#
            self.updates += 1
            #----------------------#
            #   根据更新次数,获取衰减系数
            #----------------------#
            d = self.decay(self.updates)
            print('decay:',d)
            dict_decay.append(d)

            #----------------------#
            #   根据衰减系数,当前模型(去并行化)来修改当前ema模型
            #----------------------#
            msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)

四、ModelEMA在训练框架中的使用

#----------------------#
#   搭建训练框架
#----------------------#
model = models.AlexNet()
model = model.train()
#----------------------#
#   创建EMA模型
#----------------------#
ema = ModelEMA(model)


num_train_data = 50
batch_size = 10
epoch_step = num_train_data // batch_size

Init_epoch = 50
Total_epoch = 60

#----------------------#
#   记录EMA更新次数
#----------------------#
ema.updates = Init_epoch * epoch_step

#----------------------#
#   训练
#----------------------#
for epoch in range(Init_epoch, Total_epoch):
    dict_epoch.append(epoch)
    for iter in range(epoch_step):
        #----------------------#
        #   根据ema更新次数获取衰减系数,再根据衰减系数和当前模型(去并行化)来修改当前ema模型
        #----------------------#
        ema.update(model)

#----------------------#
#   验证
#----------------------#
#----------------------#
#   获取EMA eval mode,去并行化
#----------------------#
model = ema.ema
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        pass

#----------------------#
#   保存权重
#----------------------#
#----------------------#
#   获取EMA模型的权重
#----------------------#
save_state_dict = ema.ema.state_dict()
path = "yourpath"
torch.save(save_state_dict,path)
print('dene')

五、完整代码

import torch
import math
import torch.nn as nn
from copy import deepcopy
from torchvision import models
import matplotlib.pyplot as plt

dict_decay = []
dict_update_num = []

#----------------------#
#   判断是否并行训练模式
#----------------------#
def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

#----------------------#
#   去并行训练模式
#----------------------#
def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model

#----------------------#
#   模型拷贝
#----------------------#
def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        #----------------------#
        #   创建EMA eval mode,去并行化
        #----------------------#
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        #----------------------#
        #   EMA更新次数
        #----------------------#
        self.updates = updates
        #----------------------#
        #   根据更新次数,获取衰减系数
        #----------------------#
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        #----------------------#
        #   去掉梯度,ema不需要梯度
        #----------------------#
        for p in self.ema.parameters():
            p.requires_grad_(False)

    #----------------------#
    #   根据ema更新次数获取衰减系数,再根据衰减系数和当前模型(去并行化)来修改当前ema模型
    #----------------------#
    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            #----------------------#
            #   EMA更新次数+1
            #----------------------#
            self.updates += 1
            #----------------------#
            #   根据更新次数,获取衰减系数
            #----------------------#
            d = self.decay(self.updates)
            print('decay:',d)
            dict_decay.append(d)

            #----------------------#
            #   根据衰减系数,当前模型(去并行化)来修改当前ema模型
            #----------------------#
            msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)


#----------------------#
#   搭建训练框架
#----------------------#
model = models.AlexNet()
model = model.train()
#----------------------#
#   创建EMA模型
#----------------------#
ema = ModelEMA(model)


num_train_data = 100
batch_size = 10
epoch_step = num_train_data // batch_size

Init_epoch = 50
Total_epoch = 300

#----------------------#
#   记录EMA更新次数
#----------------------#
ema.updates = Init_epoch * epoch_step

#----------------------#
#   训练
#----------------------#
num_update = 0
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        #----------------------#
        #   根据ema更新次数获取衰减系数,再根据衰减系数和当前模型(去并行化)来修改当前ema模型
        #----------------------#
        ema.update(model)
        num_update += 1
        dict_update_num.append(num_update)


#----------------------#
#   验证
#----------------------#
#----------------------#
#   获取EMA eval mode,去并行化
#----------------------#
model = ema.ema
for epoch in range(Init_epoch, Total_epoch):
    for iter in range(epoch_step):
        pass

#----------------------#
#   保存权重
#----------------------#
#----------------------#
#   获取EMA模型的权重
#----------------------#
save_state_dict = ema.ema.state_dict()
path = "yourpath"
#torch.save(save_state_dict,path)
print('dene')



# -----------------------------------------------#
#   save EMA decay figure
# -----------------------------------------------#
plt.figure()
plt.title('EMA decay during training')
plt.plot(dict_update_num, dict_decay, label="EMA decay")
plt.legend()
plt.grid()
plt.draw()
plt.savefig('EMA decay')
plt.show()

EMA decay 曲线变化图
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1628507.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis面试题二(数据存储)

目录 1.redis 的数据过期策略 1. 惰性删除(Lazy Expiration) 2. 定期删除(Periodic Expiration) 3. 定时删除(Timing-Based Expiration) 实际应用中的组合策略 2.redis 有哪些内存淘汰机制 volatile&…

uniapp 根据不同角色实现动态底部TabBar导航栏

文章目录 前言最终效果一、实现步骤1.配置page.json中的tabBar属性2.创建自定义tabBar文件3.配置Vuex4.在main.js中引入并挂载store:5.登录页内引入自定义tabbar,根据角色进行登录验证6.在每个导航页中使用自定义的tabbar 前言 在UniApp的开发过程中&am…

Swift - 函数

文章目录 Swift - 函数1. 函数的定义2. 隐式返回(Implicit Return)3. 返回元组:实现多返回值4. 函数的文档注释5. 参数标签(Argument Label)6. 默认参数值(Default Parameter Value)7. 可变参数(Variadic P…

【Java】全套云HIS源码包含EMR、LIS(多医院、卫生机构使用)

云HIS系统简介 SaaS模式Java版云HIS系统源码,在公立二甲医院应用三年,经过多年持续优化和打磨,系统运行稳定、功能齐全,界面布局合理、操作简便。 1、融合B/S版电子病历系统,支持电子病历四级,HIS与电子病…

Redis(七) zset有序集合类型

文章目录 前言命令ZADDZCARDZCOUNTZRANGEZREVRANGEZRANGEBYSCOREZPOPMAXZPOPMIN两个阻塞版本的POP命令BZPOPMAX BZPOPMINZRANKZREVRANKZSCOREZREMZREMRANGEBYRANKZREMRANGEBYSCOREZINCRBY集合间操作ZINTERSTOREZUNIONSTORE 命令小结 内部编码使用场景 前言 对于有序集合这个名…

航片水体空洞修补

水体空洞情况如下图所示: 水体空洞修补结果如下图所示: 操作视频教程: MCM智拼图软件V8.5-漏洞空洞修补-水体修补_哔哩哔哩_bilibili

【SDC时序约束】1.主时钟创建

一、时钟 DC工具在进行综合时,需要根据一个时钟进行时序分析。   因此我们需要通过SDC给DC提供一个时钟。   时钟创建是必须的,在创建时钟的同时对时钟进行约束,从而确定整个设计的性能和限制外部时钟。 二、时钟创建 时钟约束通过creat…

详解centos8 搭建使用Tor 创建匿名服务和匿名网站(.onion)

1 Tor运行原理: 请求方需要使用:洋葱浏览器(Tor Browser)或者Google浏览器来对暗,网网站进行访问 响应放需要使用:Tor协议的的Hidden_service 2 好戏来了 搭建步骤: 1.更新yum源 rpm -Uvh h…

React复习笔记

基础语法 创建项目 借助脚手架,新建一个React项目(可以使用vite或者cra,这里使用cra) npx create-react-app 项目名 create-react-app是React脚手架的名称 启动项目 npm start 或者 yarn start src是源文件index.js相当于Vue的main.js文件。整个…

C++—DAY4

在Complex类的基础上&#xff0c;完成^&#xff0c;<<&#xff0c;>>&#xff0c;~运算符的重载 #include <iostream>using namespace std; class Complex {int rel;int vir; public:Complex(){}Complex(int rel,int vir):rel(rel),vir(vir){}void show(){c…

力扣每日一题-总行驶距离-2024.4.25

力扣题目&#xff1a;总行驶距离 题目链接: 2739.总行驶距离 题目描述 代码思路 直接用数学模拟计算即可 代码纯享版 class Solution {public int distanceTraveled(int mainTank, int additionalTank) {int sum 0;while(additionalTank > 0){if(mainTank > 5){mai…

动静态库以及动态链接

文章目录 静态库制作静态库如何使用静态库 动态库动态库的制作动态库的使用动态链接 库是给别人用的&#xff0c;所以库中一定不存在main函数。库一般会有lib前缀和后缀&#xff0c;去掉前缀和后缀才是库名。 静态库 静态库&#xff08;.a&#xff09;&#xff1a;程序在编译…

OpenHarmony实战开发-使用SmartPerf-Host分析应用性能

简介 SmartPerf-Host是一款深入挖掘数据、细粒度展示数据的性能功耗调优工具&#xff0c;可采集CPU调度、频点、进程线程时间片、堆内存、帧率等数据&#xff0c;采集的数据通过泳道图清晰地呈现给开发者&#xff0c;同时通过GUI以可视化的方式进行分析。该工具当前为开发者提…

C#技巧之窗体去鼠标化

简介 在窗体程序中不用鼠标&#xff0c;直接使用键盘完成想要的操作。 实现的方法有两种&#xff0c;一种是使用键盘上的Tab键使控件获得焦点&#xff0c;然后用enter键触发该控件上的事件&#xff08;一般为click事件&#xff09;。另一种是&#xff0c;为控件添加快捷键&am…

优维全新力作:统一采控平台

在本月&#xff0c;优维新一代核心系统「EasyOps」7.0大版本重磅上线&#xff0c;为广大用户带来了“更核心、更智能、更开放、更客制”的产品能力。&#xff08;点击回看&#xff1a;重磅&#xff01;优维科技发布EasyOps7.0大版本&#xff09;在本次版本能力分享上&#xff0…

基于springboot实现中药实验管理系统设计项目【项目源码+论文说明】计算机毕业设计

基于springboot实现中药实验管理系统设计演示 摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了中药实验管理系统的开发全过程。通过分析中药实验管理系统管理的不足&#xff0c;创建了一个计算机管理中药实验管…

基于WOA鲸鱼优化的购售电收益与风险评估算法matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于WOA鲸鱼优化的购售电收益与风险评估算法.WOA优化算法是一种基于鲸鱼捕食过程的仿生优化算法&#xff0c;其包括鲸鱼行走觅食、鲸鱼包围以及鲸鱼螺旋捕食三个步骤。在WOA优…

用 Python 创建 Voronoi 图

概述 最常见的空间问题之一是找到距离我们当前位置最近的兴趣点 (POI)。假设有人很快就会耗尽汽油&#xff0c;他/她需要在为时已晚之前找到最近的加油站&#xff0c;解决这个问题的最佳解决方案是什么&#xff1f;当然&#xff0c;驾驶员可以检查地图来找到最近的加油站&…

C++从入门到精通——string类

string类 前言一、为什么学习string类C语言中的字符串示例 二、标准库中的string类string类string类的常用接口说明string类对象的常见构造string类对象的容量操作string的接口测试及使用string类对象的访问及遍历操作下标和方括号遍历范围for遍历迭代器遍历相同的代码&#xf…

6.模板初阶

目录 1.泛型编程 2. 函数模板 2.1 函数模板概念 2.2函数模板格式 2.3 模板的实现 2.4函数模板的原理 2.5 函数模板的实例化 3.类模板 1.泛型编程 我们如何实现一个 交换函数呢&#xff1f; 使用函数重载虽然可以实现&#xff0c;但是有一下几个不好的地方&#xff1a; …