Pytorch中的Exponential Moving Average(EMA)

news2024/11/13 9:36:21

EMA介绍

EMA,指数移动平均,常用于更新模型参数、梯度等。

EMA的优点是能提升模型的鲁棒性(融合了之前的模型权重信息)

代码示例

下面以yolov7/utils/torch_utils.py代码为例:

class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)

ModelEMA类的__init__ 函数介绍

__init__ 函数的输入参数介绍

  • model:需要使用EMA策略更新参数的模型
  • decay:加权权重,默认为0.9999
  • updates:模型参数更新/迭代次数

__init__ 函数的初始化介绍

首先深拷贝一份模型

"""
创建EMA模型

model.eval()的作用:
1. 保证BN层使用的是训练数据的均值(running_mean)和方差(running_val), 否则一旦test的batch_size过小, 很容易就会被BN层影响结果
2. 保证Dropout不随机舍弃神经元
3. 模型不会计算梯度,从而减少内存消耗和计算时间

is_parallel()的作用:
如果模型是并行训练(DP/DDP)的, 则深拷贝model.module,否则就深拷贝model

"""
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()

接着,初始化updates次数,若是从头开始训练,则该参数为0

self.updates = updates

最后,定义加权权重decay的计算公式(这里呈指数型变化),

self.decay = lambda x: decay * (1 - math.exp(-x / 2000))

ModelEMA类的update()函数介绍

如果调用该函数,则更新updates以及decay,

self.updates += 1
## d随着updates的增加而逐渐增大, 意味着随着模型迭代次数的增加, EMA模型的权重会越来越偏向于之前的权重
d = self.decay(self.updates)

取出当前模型的参数,为更新EMA模型的参数做准备,

msd = model.module.state_dict() if is_parallel(model) else model.state_dict()

对EMA模型参数以及当前模型参数进行加权求和,作为EMA模型的新参数,

for k, v in self.ema.state_dict().items():
    if v.dtype.is_floating_point:
        v *= d
        v += (1. - d) * msd[k].detach()

参考文章

【代码解读】在pytorch中使用EMA - 知乎

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎

以史为鉴!EMA在机器学习中的应用 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/705093.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ELK报错no handler found for uri and method [PUT] 原因

执行后提示no handler found for uri and method post,最新版8.2的问题? 原因: index.mapping.single_type: true在索引上 设置将启用按索引的单一类型行为,该行为将在6.0后强制执行。 原 {type} 要改为 _doc,格式如…

优炫软件自主研发再结硕果,共享存储SRAC集群数据库重磅发布

新一轮科技革命重塑全球经济结构,关键核心技术是产业发展的基石,数据库、芯片、操作系统是我国数字技术领域三大卡脖子难题。数据库向下发挥硬件算力,向上使能应用系统,是各行各业业务系统运行的基础,是软件行业皇冠上…

光口的作用及应用场景

在光通信中,交换机是一个非常重要的设备,它的作用是将来自不同设备的数据包进行收发和交换。之前发布的文章我们有了解到交换机的光口是如何配置的,本期文章我们将详细讨论交换机的光口的作用及应用场景。 一、光口的主要作用 交换机的光口…

适合团队人数少,预算低的四种办公室类型

如果团队人数少且预算低,以下是一些在深圳比较适合租赁的办公室类型: 1. 联合办公空间:联合办公空间是一种灵活的办公模式,通常提供共享的办公环境和设施,如会议室、休息区、打印机等。这种模式的办公室租金通常较低&…

Linux下Master-Master Replication Manager for MySQL 双主故障切换

简述: Master-Master Replication Manager for MySQL(MMRM)是一种用于MySQL数据库的主-主复制管理工具。它允许在多个MySQL主机之间建立双向的主-主复制关系,实现数据的同步和高可用性。 工作原理是通过在每个MySQL主机上配置双…

javascript 剪贴板数据

本篇文章将介绍在 JavaScript 中检测粘贴事件上的剪贴板数据。 JavaScript 剪贴板数据 当用户通过浏览器 UI 启动粘贴操作时,将引发粘贴事件。 当光标位于可编辑上下文中时,默认操作是将剪贴板的内容粘贴到光标位置的文档中。 此事件的处理程序可以通过…

性能优化 :删除项目中没有引用关系的文件 useless-files-webpack-plugin

一般此类包不需要安装到项目中,减少node_modules体积(以项目实际情况决定-S/-D) npm i useless-files-webpack-plugin -S然后再vue.config.js中配置 const UselessFile require(useless-files-webpack-plugin) chainWebpack: config > …

KEIL支持中文编辑

默认只支持英文,中文一粘贴就变成问号。怎么办? 菜单:Edit, Configuration 点击Encoding的下拉按钮,选择GB2312.

关于Http Header必须了解的知识

header是发送请求时携带的头部信息。通常一些鉴权的信息:authentication/cookie、响应的数据格式:content-type等等的设置。当然响应的数据也会返回一些头部信息。 Requests Header,请求头信息  Responses Header&…

自动化测试中如何比较颜色

引言 在做Nightwatch自动化测试中,出现需要比较颜色的时候如何来做? 基本的思路是首先需要取到这个element的颜色值,然后跟预期的颜色进行对比。 比如我要取下面这个会话窗口的颜色,选中这个图标,按F12,查…

Java实现office办公文档在线预览(word、excel、ppt、txt等)

文章目录 一、官网下载openOffice 安装包,运行安装(不同系统的安装请自行百度,这里不做过多描述) 二、pom中引入依赖 三、office文件转为pdf流的工具类 四、service层代码 五、controller层代码 office办公文档,如doc…

C语言编程—预处理器

预处理器不是编译器的组成部分,但是它是编译过程中一个单独的步骤。简言之,C 预处理器只不过是一个文本替换工具而已,它们会指示编译器在实际编译之前完成所需的预处理。我们将把 C 预处理器(C Preprocessor)简写为 CP…

文心一言 VS 讯飞星火 VS chatgpt (52)-- 算法导论6.2 4题

四、当i>A.heap-size/2 时,调用 MAX-HEAPIFY(A,i)会有什么结果? 文心一言: 当 i > A.heap-size/2 时,调用 MAX-HEAPIFY(A, i) 会将 A[i] 与其子树中的最大元素进行交换,并将 A[i] 视为新的根节点。这个操作会…

移动WEB开发之flex布局

css基础(一)css基础(一)_上半场结束,中场已休息,下半场ing的博客-CSDN博客Emmet语法Emmet语法_上半场结束,中场已休息,下半场ing的博客-CSDN博客css基础(二)c…

Spring Boot中的STOMP Broker:原理及使用

Spring Boot中的STOMP Broker:原理及使用 简介 STOMP(Simple Text Oriented Messaging Protocol)是一种基于文本的协议,用于在Web应用程序之间传递消息。STOMP提供了一种简单的方式来实现WebSocket的双向通信。在Spring Boot中&…

centos7.X安装docker---个人学习经验

工具:VMware Workstation Pro 16.1 系统:CentOS-7-x86_64-DVD-2009 docker:docker-ce-24.0.2-1 说明:这是个人在学习安装docker的时候一些经验,如有不对的还请指教,有些步骤因个人专业能力和时间问题并未…

Elasticsearch-01篇(单机版简单安装)

Elasticsearch-01篇(单机版简单安装) 1. 前言1.1 关于 Elastic Stack 2. Elasticsearch 的安装(Linux)2.1 准备工作2.1.1 下载2.1.2 解压(启动不能用root,所以最好此处换个用户) 2.2 修改相应的…

2023年上海市浦东新区网络安全管理员决赛理论题样题

目录 一、判断题 二、单选题 三、多选题 一、判断题 1.等保1.0至等保2.0从信息系统拓展为网络和信息系统。 正确 (1)保护对象改变 等保1.0保护的对象是信息系统,等保2.0增加为网络和信息系统,增加了云计算、大数据、工业控制系统、物联网、移动物联技术、网络基础…

vite环境变量

vite环境变量 import.meta.env对象中存储环vite的境变量 环境变量以VITE_ 为前缀 在不同环境下,自动读取不同的文件 一般命名 .env .env.development .env.test .env.production

四格表fisher检验

一、案例介绍 某医生用新旧两种药物治疗某病患者27人&#xff0c;治疗结果见下表&#xff0c;现在想知道两种两种药物的治疗效果有无差别&#xff1f; 二、问题分析 本案例的分析目的是探究两种治疗效果有无差异&#xff0c;总样本量为27<40&#xff0c;所以考虑使用四格表…