Exponential Moving Average (EMA) in Stable Diffusion


1. Moving Average in Stable Diffusion (SMA & EMA)

References:
1. Moving average
2. 移动平均值 (Moving average, Chinese-language reference)
3. How We Trained Stable Diffusion for Less than $50k (Part 3)

Moving Average
In statistics, a moving average is a calculation used to analyze data points by creating a series of averages of different subsets of the full data set.

Given a sequence of numbers and a fixed subset size, the first element of the moving average is obtained by averaging the initial fixed subset of the number series. The subset is then modified by "shifting forward"; that is, the first number of the series is excluded and the next value after the subset is included.

(The description of the moving average above is taken from reference 2, 移动平均值.)

1.1 Simple Moving Average (SMA, an unweighted MA)
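A simple moving average weights every sample in its window equally: for a window of the $k$ most recent values,

$\text{SMA}_t = \frac{1}{k}\sum_{i=0}^{k-1} x_{t-i}$

Moving from step $t$ to step $t+1$ drops the oldest value in the window and adds the newest one, which is exactly the "shifting forward" described above.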


1.2 Exponential Moving Average (EMA, a weighted MA)

In the context of Stable Diffusion, the Exponential Moving Average (EMA) is a technique used during the training of machine learning models, particularly neural networks, to stabilize and improve the model’s performance.

The Exponential Moving Average is a method of averaging that gives more weight to recent data points, making it more responsive to recent changes compared to a simple moving average, which treats all data points equally.
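Written as a recurrence, with the smoothing factor $\alpha \in (0, 1]$ acting as the weight of the newest sample,

$\text{EMA}_t = \alpha \cdot x_t + (1-\alpha)\cdot \text{EMA}_{t-1}$, with $\text{EMA}_0 = x_0$.

The Stable Diffusion 2 update described in section 1.2.1 below is this recurrence with $\alpha = 0.0001$, and the EMA class later in this article uses $\alpha = 1 - \beta$ with $\beta = 0.995$.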

1.2.1 EMA in Stable Diffusion

In the context of Stable Diffusion, EMA is applied to the model parameters during training to create a smoothed version of the model. This is particularly useful in machine learning because the training process can be noisy, with the model parameters oscillating as they converge towards an optimal solution. By maintaining an EMA of the model parameters, the training process can benefit from the following:

  1. Smoothing: EMA smooths out the parameter updates, reducing the impact of noise and making the training process more stable (a toy numerical sketch follows this list).
  2. Better Generalization: The EMA version of the model often generalizes better on unseen data compared to the model with the raw parameters. This is because EMA tends to favor parameter values that are more consistent over time.
  3. Preventing Overfitting: By averaging the parameters over time, EMA can help mitigate overfitting, especially in cases where the model might otherwise converge too quickly to a suboptimal solution.
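
As a toy illustration of the smoothing effect in point 1, the snippet below applies the EMA recurrence to a noisy scalar "parameter" trajectory. The trajectory is made up purely for illustration and is not part of the Stable Diffusion code:

import numpy as np

rng = np.random.default_rng(0)

# A fake parameter trajectory: converges towards 1.0 but oscillates because of noisy updates
raw = 1.0 - np.exp(-np.linspace(0.0, 5.0, 200)) + rng.normal(0.0, 0.05, 200)

alpha = 0.05                 # weight of the newest sample
ema = np.empty_like(raw)
ema[0] = raw[0]              # initialize the EMA with the first value
for t in range(1, len(raw)):
    ema[t] = alpha * raw[t] + (1 - alpha) * ema[t - 1]

# The EMA trace fluctuates far less than the raw trace once training has roughly converged
print("raw std (last 50 steps):", raw[-50:].std())
print("ema std (last 50 steps):", ema[-50:].std())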

My own understanding:
The loss function is a function of the parameters (weights and biases), so each loss value corresponds to one set of parameter values; when the loss oscillates, the model parameters are changing as well. During Stable Diffusion training, the MSE loss oscillates up and down over the course of gradient descent, and the corresponding model parameters oscillate with it. EMA can be used to obtain an intermediate value of these oscillating parameters; this intermediate value better represents the average level of the model parameters across all time steps and gives the model better generalization.

Stable Diffusion 2 uses Exponential Moving Averaging (EMA), which maintains an exponential moving average of the weights. At every time step, the EMA model is updated by taking 0.9999 times the current EMA model plus 0.0001 times the new weights after the latest forward and backward pass. By default, the EMA algorithm is applied after every gradient update for the entire training period. However, this can be slow due to the memory operations required to read and write all the weights at every step.
Applying EMA to every parameter at every time step is expensive, because all of the model's parameters must be read and written at every step.
$\text{EMA}_t = 0.0001 \cdot x_t + 0.9999 \cdot \text{EMA}_{t-1}$
To reduce the cost of computing the EMA, it is only applied during the final portion of training.
To avoid this costly procedure, we start with a key observation: since the old weights are decayed by a factor of 0.9999 at every batch, the early iterations of training only contribute minimally to the final average. This means we only need to take the exponential moving average of the final few steps. Concretely, we train for 1,400,000 batches and only apply EMA for the final 50,000 steps, which is about 3.5% of the training period. The weights from the first 1,350,000 iterations decay away by (0.9999)^50000, so their aggregate contribution would have a weight of less than 1% in the final model. Using this technique, we can avoid adding overhead for 96.5% of training and still achieve a nearly equivalent EMA model.
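
A quick sanity check of the arithmetic above (plain Python; the numbers are the ones quoted from the MosaicML post):

decay = 0.9999
ema_steps = 50_000
total_batches = 1_400_000

# Aggregate weight that everything before the EMA window keeps in the final average
print(decay ** ema_steps)            # ≈ 0.0067, i.e. less than 1%

# Fraction of training that actually pays the EMA overhead
print(ema_steps / total_batches)     # ≈ 0.036, about 3.5% of the batches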

1.2.2 Implementation in Stable Diffusion

During the training of a diffusion model, the EMA of the model’s weights is updated alongside the regular updates. Here’s a typical process (a minimal PyTorch sketch follows the list):

  1. Initialize EMA Weights: At the start of training, initialize the EMA weights to be the same as the model’s initial weights.
  2. Update During Training: After each batch update, update the EMA weights using the formula mentioned above. This requires storing a separate set of weights for the EMA.
  3. Use for Inference: At the end of the training, use the EMA weights for inference instead of the raw model weights. This is because the EMA weights represent a more stable and potentially better-performing version of the model.
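
For comparison with the hand-written EMA class used later in this article, PyTorch's torch.optim.swa_utils.AveragedModel can maintain the same kind of average. This is only a minimal sketch, not the code this article actually uses: the Linear layer stands in for the real UNet, and the 0.9999 decay follows the Stable Diffusion 2 setting quoted above.

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(8, 8)  # stand-in for the diffusion UNet

# avg_fn(ema_param, current_param, num_averaged) defines the running update;
# here it is the EMA recurrence with decay 0.9999
ema_model = AveragedModel(
    model,
    avg_fn=lambda ema_p, cur_p, n: 0.9999 * ema_p + (1 - 0.9999) * cur_p,
)

# Inside the training loop, call this after every optimizer.step():
ema_model.update_parameters(model)

# At inference time, use ema_model (ema_model.module holds the averaged weights)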

1.2.3 Practical Considerations

  1. Choosing $\alpha$: The smoothing factor $\alpha$ is a hyperparameter that needs to be chosen carefully. A common practice is to set $\alpha$ based on the number of iterations or epochs, such as $\alpha = \frac{2}{N+1}$, where $N$ is the number of iterations.
  2. Performance Overhead: Maintaining EMA weights requires additional memory and computational overhead, but the benefits in terms of model stability and performance often outweigh these costs.

module.py

class EMA:
    def __init__(self, beta):
        # beta: smoothing factor (decay) for the exponential moving average
        self.beta = beta
        # Step counter to keep track of the number of updates
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        # Update the moving average (EMA) of the EMA model's parameters (ma_model)
        # based on the current model's parameters (current_model)
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        # Compute the exponentially weighted average of the old and new parameters
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        # For the first step_start_ema steps, keep the EMA model identical to the
        # current model; afterwards, update the EMA model from the current model.
        # The step counter is incremented on every call.
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1  # Increment the step counter

    def reset_parameters(self, ema_model, model):
        # Copy the current model's parameters into the EMA model to initialize it
        ema_model.load_state_dict(model.state_dict())

train.py

import copy
import logging
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# UNET, EMA, diffusion, ddpm_sampler, train_loader and timesteps_to_time_emb
# come from the rest of the project (e.g. module.py and the data/sampler setup)
# and are not shown here.

def train(args):
    device = args.device  # Get the device to run the training on
    model = UNET().to(device)  # Initialize the model and move it to the device
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)  # Set up the optimizer with AdamW
    mse = nn.MSELoss()  # Mean Squared Error loss function
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)
    # EMA: Exponential Moving Average
    ema = EMA(0.995)  # Exponential Moving Average with decay rate 0.995
    # At the start of training, initialize the EMA weights to be the same as the model's initial weights
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)  # Copy of the model for EMA, eval mode, no gradients
    print('Start into the loop !')
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")  # Log the start of the epoch
        progress_bar = tqdm(train_loader)  # Progress bar for the dataloader
        optimizer.zero_grad()  # Explicitly zero the gradient buffers
        accumulation_steps = 4
        # Iterate over the batches
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)  # Move images to the device
            # The dataloader adds a batch dimension, but the VAE and CLIP inputs already
            # have one, so squeeze it out and keep only the dataloader's batch size
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)  # Move captions to the device
            text_embeddings = torch.squeeze(captions, dim=1)  # Squeeze the extra batch dimension
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)  # Sample random timesteps
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)  # Add noise to the images
            time_embeddings = timesteps_to_time_emb(timesteps)
            # x_t (batch_size, channel, Height/8, Width/8) (bs, 4, 256/8, 256/8)
            # caption (batch_size, seq_len, dim) (bs, 77, 768)
            # t (batch_size, channel) (batch_size, 1280)
            # The forward pass must track gradients so the loss can backpropagate into the UNET
            last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)  # (bs, 320, H/8, W/8)
            final_output = diffusion.final.to(device)
            predicted_noise = final_output(last_decoder_noise).to(device)  # (bs, 4, H/8, W/8)
            loss = mse(noises, predicted_noise)  # Compute the loss
            loss.backward()  # Backpropagate the loss
            if (batch_idx + 1) % accumulation_steps == 0:  # Wait for several backward passes
                optimizer.step()  # Now we can do an optimizer step
                optimizer.zero_grad()  # Reset gradients to zero
                # EMA: update the EMA model after each optimizer step
                ema.step_ema(ema_model, model)
            progress_bar.set_postfix(MSE=loss.item())  # Update the progress bar with the loss
            # Log the loss to TensorBoard
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
        # Save the model checkpoint
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        torch.save(optimizer.state_dict(),
                   os.path.join("models", args.run_name, "optim.pt"))  # Save the optimizer state
