指数移动平均(EMA)

news2024/11/18 15:23:38

文章目录

  • 前言
  • EMA的定义
  • 在深度学习中的应用
    • PyTorch代码实现
    • yolov5中模型的EMA实现
  • 参考


前言

在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
实际上,_EMA可以看作是Temporal Ensembling,在模型学习过程中融合更多的历史状态,从而达到更好的优化效果。

EMA的定义

指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。
假设有n个权重数据image.png

  • 普通的平均数:image.png
  • EMA:image.png

其中, vt表示前 t条的平均值 ( v0=0 ),β是加权权重值 (一般设为0.9-0.999)。

Andrew Ng在Course 2 Improving Deep Neural Networks中讲到,EMA可以近似看成过去 1/(1−β) 个时刻 v 值的平均。
普通的过去n时刻的平均是这样的:image.png
类比EMA,可以发现当 image.png 时,两式形式上相等。需要注意的是,两个平均并不是严格相等的,这里只是为了帮助理解。
实际上,EMA计算时,过去 1/(1−β) 个时刻之前的数值平均会decay到 1/e 的加权比例,证明如下。
如果将这里的 vt展开,可以得到:
image.png
其中, image.png,代入可以得到 image.png

在深度学习中的应用

上面讲的是广义的ema定义和计算方法,特别的,在深度学习的优化过程中, image.png是t时刻的模型权重weights, vt是t时刻的影子权重(shadow weights)。在梯度下降的过程中,会一直维护着这个影子权重,但是这个影子权重并不会参与训练。
基本的假设是,模型权重在最后的n步内,会在实际的最优点处抖动,所以我们取最后n步的平均,能使得模型更加的鲁棒。

PyTorch代码实现

下面是一个简单的指数移动平均(EMA)的PyTorch实现:

import torch

class EMA():
    def __init__(self, alpha):
        self.alpha = alpha    # 初始化平滑因子alpha
        self.average = None   # 初始化平均值为空
        self.count = 0        # 初始化计数器为0

    def update(self, x):
        if self.average is None:  # 如果平均值为空,则将其初始化为与x相同大小的全零张量
            self.average = torch.zeros_like(x)
        self.average = self.alpha * x + (1 - self.alpha) * self.average  # 更新平均值
        self.count += 1   # 更新计数器

    def get(self):
        return self.average / (1 - self.alpha ** self.count)   # 根据计数器和平滑因子计算EMA值,并返回平均值除以衰减系数的结果

在这个类中,我们定义了三个方法,分别是__init__、update和get。

  • __init__方法用于初始化平滑因子alpha、平均值average和计数器count
  • update方法用于更新EMA值
  • get方法用于获取最终的EMA值。

使用这个类时,我们可以先实例化一个EMA对象,然后在每个时间步中调用update方法来更新EMA值,最后调用get方法来获取最终的EMA值。
例如:

ema = EMA(alpha=0.5)
for value in data:
    ema.update(torch.tensor(value))
smoothed_data = ema.get()

在这个例子中,我们使用alpha=0.5来初始化EMA对象,然后遍历数据集data中的每个数据点,调用update方法更新EMA值。最后我们调用get方法来获取平滑后的数据。

yolov5中模型的EMA实现

如下:

class ModelEMA:
    """ 
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """
    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)
        msd = de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
    # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)

参考

https://zhuanlan.zhihu.com/p/68748778


如果有用,请点个三连呗 点赞、关注、收藏
你的鼓励是我最大的动力

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1505989.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全栈的自我修养 ———— css中常用的布局方法flex和grid

在项目里面有两种常用的主要布局:flex和grid布局(b站布局),今天分享给大家这两种的常用的简单方法! 一、flex布局1、原图2、中心对齐3、主轴末尾或者开始对其4、互相间隔 二、grid布局1、基本效果2、加间隔3、放大某一个元素 一、…

数据的加密方式及操作方法

目录 一 什么是加密 二 加密方法 对称加密(如AES加密) 非对称加密(如RSA加密) 散列(如MD5加密) 三 加密操作 1 MD5加密(散列) 2 AES加密(对称加密) …

HTMK5七天学会基础动画网页10(2)

制作立方体 学完前面的基础内容&#xff0c;制作立方体是个不错的练习方法&#xff0c;先看成品 再分析一下&#xff0c;六个面让每个面旋转平移就可以实现一个立方体&#xff0c;来看代码: <title> 制作立方体</title> <style> *{ margin: 0; padding: 0; …

如何搭建财务数据运营体系:基于财务五力模型的分析

在当今复杂多变的商业环境中,财务数据作为企业决策的重要参考依据,其运营体系的搭建显得尤为关键。一个健全、高效的财务数据运营体系不仅能够为企业提供准确的财务数据支持,还能帮助企业在激烈的市场竞争中保持领先地位。基于财务五力模型的分析,我们可以从收益力、安定力…

基于深度学习YOLOv8+Pyqt5的抽烟吸烟检测识别系统(源码+跑通说明文件)

wx供重浩&#xff1a;创享日记 对话框发送&#xff1a;39抽烟 获取完整源码源文件4000张已标注的数据集配置说明文件 可有偿59yuan一对一远程操作跑通 效果展示 基于深度学YOLOv8PyQt5的抽烟吸烟检测识别系统&#xff08;完整源码跑通说明文件&#xff09; 各文件说明 模型评价…

Mybatis-Plus——07,性能分析插件

性能分析插件 一、导入插件二、SpringBoot中配置环境为dev或test环境三、运行测试————————创作不易&#xff0c;笔记不易&#xff0c;如觉不错&#xff0c;请三连&#xff0c;谢谢~~ MybatisPlus也提供了性能分析插件&#xff0c;如果超过这个时间就停止运行&#xff0…

常见3大web漏洞

常见3大web漏洞 XSS攻击 描述&#xff1a; 跨站脚本&#xff08;cross site script&#xff09;-简称XSS&#xff0c;常出现在web应用中的计算机安全漏桶、web应用中的主流攻击方式。 攻击原理&#xff1a; 攻击者利用网站未对用户提交数据进行转义处理或者过滤不足的缺点。 …

前端文件上传

文件上传方式 前端文件上传有两种方式&#xff0c;第一种通过二进制blob传输&#xff08;formData传输&#xff09;&#xff0c;第二种是通过base64传输 文件相关的对象 file对象其实是blob的子类 blob对象的第一个参数必须是一个数组&#xff0c;你可以把一个file对象放进去…

Oracle SQL优化(读懂执行计划 一)

目录 SQL执行计划的作用示例演示执行计划概念介绍执行计划实例DISPLAY_CURSOR 类型DISPLAY_AWR 类型 指标详解 SQL执行计划的作用 示例演示 执行计划概念介绍 执行计划实例 DISPLAY_CURSOR 类型 DISPLAY_AWR 类型 指标详解

Vivado原语模板

1.原语的概念 原语是一种元件&#xff01; FPGA原语是芯片制造商已经定义好的基本电路元件&#xff0c;是一系列组成逻辑电路的基本单元&#xff0c;FPGA开发者编写逻辑代码时可以调用原语进行底层构建。 原语可分为预定义原语和用户自定义原语。预定义原语为如and/or等门级原语…

【电路笔记】-PNP晶体管

PNP晶体管 文章目录 PNP晶体管1、概述2、PNP晶体管电路示例3、PNP晶体管识别1、概述 PNP 晶体管与我们在上一篇教程中看到的 NPN 晶体管器件完全相反。 在这种类型的 PNP 晶体管结构中,两个互连的二极管相对于之前的 NPN 晶体管是相反的。 这会产生正-负-正类型的配置,箭头…

vxe-table配合Export2Excel导出object类型数据{type,count}。表格数据呈现是利用插槽,导出只要count该怎么做

先贴一张数据来&#xff1a; 一、然后是vxe-grid的columns配置&#xff1a; 然后就正常用封装好的Export2Excel就行。 碰到一次在控制台报错&#xff1a; 没复现出来&#xff0c;大概就说是count咋样咋样。 以后碰到的话再说&#xff0c;各位要用的话也注意看看 二、或者 用js…

不知道吧,腾讯云轻量应用服务器使用有一些限制!

腾讯云轻量应用服务器相对于云服务器CVM是有一些限制的&#xff0c;比如轻量服务器不支持更换内网IP地址&#xff0c;不支持自定义私有网络VPC&#xff0c;内网连通性方面也有限制&#xff0c;轻量不支持CPU内存、带宽或系统盘单独升级&#xff0c;只能整个套餐整体升级&#x…

AWS 入门实践-远程访问AWS EC2 Linux虚拟机

远程访问AWS EC2 Linux虚拟机是AWS云计算服务中的一个基本且重要的技能。本指南旨在为初学者提供一系列步骤&#xff0c;以便成功地设置并远程访问他们的EC2 Linux实例。包括如何上传下载文件、如何ssh远程登录EC2虚拟机。 一、创建一个AWS EC2 Linux 虚拟机 创建一个Amazon…

MySQL通过SQL语句进行递归查询

这里主要是针对于MySQL8.0以下版本&#xff0c;因为MySQL8.0版本出来了一个WITH RECURSIVE函数专门用来进行递归查询的 先看下表格数据&#xff0c;就是很普通的树结构数据&#xff0c;通过parentId关联上下级关系 下面我们先根据上级节点id递归获取所有的下级节点数据&#x…

nmcli绑定bond双网卡(active-backup模式)

当前网卡mac地址IP都不一样 创建名为“jbl”的新连接&#xff0c;并将其模式设置为“active-backup” nmcli connection add type bond ifname jbl mode active-backup添加物理网卡到bond(JBL),两个物理网卡添加到新创建的bond连接中 nmcli connection add type bond-slave…

python:布伊山德U检验(Buishand U test,BUT)突变点检测(以NDVI时间序列为例)

作者:CSDN @ _养乐多_ 本文将介绍布伊山德U检验(Buishand U test,BUT)突变点检测代码。以 NDVI 时间序列为例。输入数据可以是csv,一列NDVI值,一列时间。代码可以扩展到遥感时间序列突变检测(突变年份、突变幅度等)中。 结果如下图所示, 文章目录 一、准备数据二、…

Building Systems with the ChatGPT API

Building Systems with the ChatGPT API 本文是 https://www.deeplearning.ai/short-courses/building-systems-with-chatgpt/ 这门课程的学习笔记。 文章目录 Building Systems with the ChatGPT APIWhat you’ll learn in this course Language Models, the Chat Format and…

编译原理之词法分析-语法分析-中间代码生成

编译原理之词法分析-语法分析-中间代码生成 文章说明源码效果展示Gitee链接 文章说明 学习编译原理后&#xff0c;总是想制作自己的一款小语言编译器&#xff0c;虽然对技术不是很理解&#xff0c;学的不是很扎实&#xff0c;但还是想着尝试尝试&#xff1b;目前该效果只是初步…

关于安卓ZXing条码识别(一)引入源码

背景 从0-1引入安卓zxing&#xff0c;实现条码识别 环境 win10 as4 jdk8 引入 首先&#xff0c;官方网站&#xff0c;就是源码。链接 选择你要引入的分支&#xff0c;这里博主选择的是最近更新的分支&#xff0c;如下图&#xff1a; 上图中&#xff0c;1和2都需要引入&am…