EMA介绍
EMA,指数移动平均,常用于更新模型参数、梯度等。
EMA的优点是能提升模型的鲁棒性(融合了之前的模型权重信息)
代码示例
下面以yolov7/utils/torch_utils.py代码为例:
class ModelEMA:
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
"""
def __init__(self, model, decay=0.9999, updates=0):
# Create EMA
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
for p in self.ema.parameters():
p.requires_grad_(False)
def update(self, model):
# Update EMA parameters
with torch.no_grad():
self.updates += 1
d = self.decay(self.updates)
msd = model.module.state_dict() if is_parallel(model) else model.state_dict()
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1. - d) * msd[k].detach()
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
copy_attr(self.ema, model, include, exclude)
ModelEMA类的__init__ 函数介绍
__init__ 函数的输入参数介绍
- model:需要使用EMA策略更新参数的模型
- decay:加权权重,默认为0.9999
- updates:模型参数更新/迭代次数
__init__ 函数的初始化介绍
首先深拷贝一份模型
"""
创建EMA模型
model.eval()的作用:
1. 保证BN层使用的是训练数据的均值(running_mean)和方差(running_val), 否则一旦test的batch_size过小, 很容易就会被BN层影响结果
2. 保证Dropout不随机舍弃神经元
3. 模型不会计算梯度,从而减少内存消耗和计算时间
is_parallel()的作用:
如果模型是并行训练(DP/DDP)的, 则深拷贝model.module,否则就深拷贝model
"""
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
接着,初始化updates次数,若是从头开始训练,则该参数为0
self.updates = updates
最后,定义加权权重decay的计算公式(这里呈指数型变化),
self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
ModelEMA类的update()函数介绍
如果调用该函数,则更新updates以及decay,
self.updates += 1
## d随着updates的增加而逐渐增大, 意味着随着模型迭代次数的增加, EMA模型的权重会越来越偏向于之前的权重
d = self.decay(self.updates)
取出当前模型的参数,为更新EMA模型的参数做准备,
msd = model.module.state_dict() if is_parallel(model) else model.state_dict()
对EMA模型参数以及当前模型参数进行加权求和,作为EMA模型的新参数,
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1. - d) * msd[k].detach()
【参考文章】
【代码解读】在pytorch中使用EMA - 知乎
【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎
以史为鉴!EMA在机器学习中的应用 - 知乎