模型的权值平均的原理和Pytorch的实现

news2025/1/10 1:36:59

一、前言

模型权值平均是一种用于改善深度神经网络泛化性能的技术。通过对训练过程中不同时间步的模型权值进行平均,可以得到更宽的极值点(optima)并提高模型的泛化能力。 在PyTorch中,官方提供了实现模型权值平均的方法。

这里我们首先介绍指数移动平均(EMA)方法,它使用一个衰减系数来平衡当前权值和先前平均权值。其次,介绍了随机加权平均(SWA)方法,它通过将当前权值与先前平均权值进行加权平均来更新权值。最后,介绍了Tanh自适应指数移动EMA算法(T_ADEMA),它使用Tanh函数来调整衰减系数,以更好地适应训练过程中的不同阶段。

为了方便使用这些权值平均方法,我将官方的代码写成了一个基类AveragingBaseModel,以此引出EMAModel、SWAModel和T_ADEMAModel等方法。这些类可以用于包装原始模型,并在训练过程中更新平均权值。 为了验证这些权值平均方法的效果,我还在ResNet18模型上进行了简单的实验。实验结果表明,使用权值平均方法可以提高模型的准确率,尤其是在训练后期。

但请注意,博客中所提供的代码示例仅用于演示权值平均的原理和PyTorch的实现方式,并不能保证在所有情况下都能取得理想的效果。实际应用中,还需要根据具体任务和数据集来选择适合的权值平均方法和参数设置。

二、算法介绍

基类实现

这里我们的基类完全是参照于torch源码部分,仅仅进行了一点细微的修改。

它首先通过de_parallel函数将原始模型转换为单个GPU模型。de_parallel函数用于处理并行模型,将其转换为单个GPU模型。然后,它将转换后的模型复制到适当的设备(CPU或GPU)上(这一步很重要,问题大多数就是因为计算不匹配),并注册一个名为n_averaged的缓冲区,用于跟踪已平均的次数。

在forward方法中,它简单地将调用传递给转换后的模型。update方法首先获取当前模型和新模型的参数,并将它们转换为可迭代对象,用于更新平均权值。它接受一个新的模型作为参数,并将其与当前模型(已平均的权值)进行比较。

from copy import deepcopy
from pyzjr.core.general import is_parallel
import itertools
from torch.nn import Module

def de_parallel(model):
    """
    将并行模型(DataParallel 或 DistributedDataParallel)转换为单 GPU 模型。
    """
    return model.module if is_parallel(model) else model

class AveragingBaseModel(Module):
    def __init__(self, model, cuda=False, avg_fn=None, use_buffers=False):
        super(AveragingBaseModel, self).__init__()
        device = 'cuda' if cuda and torch.cuda.is_available() else 'cpu'
        self.module = deepcopy(de_parallel(model))
        self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.avg_fn = avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update(self, model):
        self_param = itertools.chain(self.module.parameters(), self.module.buffers() if self.use_buffers else [])
        model_param = itertools.chain(model.parameters(), model.buffers() if self.use_buffers else [])

        self_param_detached = [p.detach() for p in self_param]
        model_param_detached = [p.detach().to(p_averaged.device) for p, p_averaged in zip(model_param, self_param_detached)]

        if self.n_averaged == 0:
            for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                p_averaged.copy_(p_model)

        if self.n_averaged > 0:
            for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                n_averaged = self.n_averaged.to(p_averaged.device)
                p_averaged.copy_(self.avg_fn(p_averaged, p_model, n_averaged))

        if not self.use_buffers:
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.copy_(b_model.to(b_swa.device).detach())

        self.n_averaged += 1

若当前模型尚未进行过平均(即n_averaged为0),则直接将新模型的参数复制到当前模型中。若当前模型已经进行过平均,则通过avg_fn函数计算当前模型和新模型的加权平均,并将结果复制到当前模型中。如果use_buffers为True,则会将缓冲区从新模型复制到当前模型。最后,n_averaged增加1,表示已进行一次平均。

指数移动平均(EMA)

EMA被用于根据当前参数和之前的平均参数来更新平均参数。其计算公式如下所示:

EMA_{param} = decay * EMA_{param} + (1 - decay) * current_{param}

这里的EMA param是当前的平均参数,current param是当前的参数,decay是一个介于0和1之间的衰减因子,它用于控制当前参数对平均参数的贡献程度。decay越接近1,平均参数对当前参数的影响就越小,反之亦是。

def get_ema_avg_fn(decay=0.999):
    @torch.no_grad()
    def ema_update(ema_param, current_param, num_averaged):
        return decay * ema_param + (1 - decay) * current_param
    return ema_update

class EMAModel(AveragingBaseModel):
    def __init__(self, model, cuda = False, decay=0.9, use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_ema_avg_fn(decay), use_buffers=use_buffers)

随机加权平均(SWA)

SWA通过对神经网络的权重进行平均来改善模型的泛化能力。其计算公式如下所示:

SWA_{param} = avg_{param} + (current_{param} - avg_{param}) / (num_{avg} + 1)

SWA param是新的平均参数,averaged param是之前的平均参数,current param是当前的参数,num avg是已经平均的参数数量。

def get_swa_avg_fn():
    @torch.no_grad()
    def swa_update(averaged_param, current_param, num_averaged):
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
    return swa_update

class SWAModel(AveragingBaseModel):
    def __init__(self, model, cuda = False,use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_swa_avg_fn(), use_buffers=use_buffers)

Tanh自适应指数移动EMA算法(T_ADEMA)

这一个是在查询资料的时候,找到的一篇论文描述的,是否有效,还得经过实验才对。

全文阅读--XML全文阅读--中国知网 (cnki.net)

论文表示是为了在神经网络训练过程中根据不同的训练阶段更有效地过滤噪声,所提出的公式:

decay = alpha * tanh(num_{avg})

T_ADEMA_{param} = decay * avg_{param} + (1 - decay) * current_{param}

T_ADEMA param是新的平均参数,avg param是之前的平均参数,current param是当前的参数,num avg是已经平均的参数数量。alpha是一个控制衰减速率的超参数。通过将参数数量作为输入传递给切线函数的参数,动态地计算衰减因子。切线函数(tanh)的输出范围为[-1, 1],随着参数数量的增加,衰减因子会逐渐趋近于1。由于切线函数的特性,当参数数量较小时,衰减因子接近于0;当参数数量较大时,衰减因子接近于1。

def get_t_adema(alpha=0.9):
    num_averaged = [0]  # 使用列表包装可变对象,以在闭包中引用
    @torch.no_grad()
    def t_adema_update(averaged_param, current_param, num_averageds):
        num_averaged[0] += 1
        decay = alpha * torch.tanh(torch.tensor(num_averaged[0], dtype=torch.float32))
        tadea_update = decay * averaged_param + (1 - decay) * current_param
        return tadea_update
    return t_adema_update

class T_ADEMAModel(AveragingBaseModel):
    def __init__(self, model, cuda=False, alpha=0.9, use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_t_adema(alpha), use_buffers=use_buffers)

三、构建一个简单的实验测试

这一部分我正在做实验,下面是调用了一个简单的resnet18网络,看看逻辑上面是否有错。

if __name__=="__main__":
    # 创建 ResNet18 模型
    import torch
    import torchvision.models as models
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    from torch.optim.swa_utils import AveragedModel

    class RandomDataset(torch.utils.data.Dataset):
        def __init__(self, size=224):
            self.data = torch.randn(size, 3, 224, 224)
            self.labels = torch.randint(0, 2, (size,))

        def __getitem__(self, index):
            return self.data[index], self.labels[index]

        def __len__(self):
            return len(self.data)


    model = models.resnet18(pretrained=False)
    # model = model.to('cuda')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    # 创建数据加载器
    train_dataset = RandomDataset()
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # 定义权重平均模型
    swa_model = SWAModel(model, cuda=True)
    ema_model = EMAModel(model, cuda=True)
    t_adema_model = T_ADEMAModel(model, cuda=True)

    for epoch in range(5):
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{5}"):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # 更新权重平均模型
            ema_model.update(model)
            swa_model.update(model)
            t_adema_model.update(model)

    # 测试模型
    test_dataset = RandomDataset(size=100)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


    def evaluate(model):
        model.eval()  # 切换到评估模式
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to('cuda'), labels.to('cuda')
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
        print(f"模型准确率:{accuracy * 100:.2f}%")

    # 原模型测试
    print("Model Evaluation:")
    evaluate(model.to('cuda'))   #
    # 测试权重平均模型
    print("SWAModel Evaluation:")
    evaluate(swa_model.to('cuda'))

    print("EMAModel Evaluation:")
    evaluate(ema_model.to('cuda'))

    print("T-ADEMAModel Evaluation:")
    evaluate(t_adema_model.to('cuda'))

运行效果:

Model Evaluation:
模型准确率:46.00%
SWAModel Evaluation:
模型准确率:54.00%
EMAModel Evaluation:
模型准确率:58.00%
T - ADEMAModel Evaluation:
模型准确率:58.00%

仅仅是测试是否能够跑通,过程中也有比原模型要低的时候,而且权值平均主要是用于训练中后期,所以有没有效果应该需要自己去做实验。

当前你可以下载pip install pyzjr==1.2.9,调用from pyzjr.nn import EMAModel运行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1373705.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

李沐-《动手学深度学习》--02-目标检测

一 、目标检测算法 1. R-CNN a . 算法步骤 使用启发式搜索算法来选择锚框(选出多个锚框大小可能不一,需要使用Rol pooling)使用预训练好的模型(去掉分类层)对每个锚框进行特征抽取(如VGG,AlexNet…)训练…

MYSQL篇--事务机制高频面试题

事务 1 什么是数据库事务? 事务是一个不可分割的数据库操作序列,也是数据库并发控制的基本单位,其执行的结果必须使数据库从一种一致性状态变到另一种一致性状态。事务是逻辑上的一组操作,要么都执行,要么都不执行。…

【sqlite3】sqlite3在linux下使用sqlitebrowser工具实现数据可视化

sqlite3在linux下使用sqlitebrowser工具实现数据可视化 1. ### install sqlitebrowser 1. ### install sqlitebrowser 安装指令 sudo apt-get install sqlitebrowser通过工具打开数据库 sqlitebrowser stereo.db打开效果

HTTPS详解及openssl简单使用

OpenSSL 中文手册 | OpenSSL 中文网 本文介绍https传输协议中涉及的概念,流程,算法,如何实现等相关内容。 HTTP传输过程 HTTP 之所以被 HTTPS 取代,最大的原因就是不安全,至于为什么不安全,看了下面这张图…

Linux第25步_在虚拟机中备份“ST官方的TF-A源码”

TF-A是ARM公司提供的,ST公司通过修改它,做了一个自己的TF-A代码。因为在后期开发中,若硬件被改变了,我们需要通过修改"ST官方的TF-A源码"就可以自己的TF-A代码了。为了防止源文件被误改了,我们需要将"S…

亲测,Chatgpt4.0充值(虚拟卡充值)

一、准备工作: 1、一个ChatGPT3.5账号 2、一张支持ChatGPT4.0的虚拟卡 二、流程【网页版充值】 充值前请先确认以下三点: 1,ChatGPT账户正常登陆。 2,充值过程中始终保持美区环境,且开启全局模式。 3&#xff0…

简洁计算器Python代码

简洁的Python计算器,直接上代码(用时10分钟): Python Gui图形化开发探索GUI开发的无限可能,使用强大的PyQt5、默认的Tkinter和跨平台的Kivy等工具,让Python成为你构建应用程序的得力助手。从本机用户界面到…

在 WinForms 应用程序中实现 FTP 文件操作及模式介绍

在 WinForms 应用程序中实现 FTP 文件操作及模式介绍 简介 在许多应用程序中,能够从远程服务器获取文件是一个非常有用的功能。本文将详细介绍如何在 Windows Forms (WinForms) 应用程序中使用 FTP 协议进行文件操作,包括连接到 FTP 服务器、列出目录、…

邂逅Node.JS的那一夜

邂逅Node.JS的那一夜🌃 本篇文章,学习记录于:尚硅谷🎢 本篇文章,并不完全适合小白,需要有一定的HTML、CSS、JS、HTTP、Web等知识及基础学习: 🆗,紧接上文,…

通过反射修改MultipartFile类文件名

1、背景 项目上有这样一个需求&#xff0c;前端传文件过来&#xff0c;后端接收后按照特定格式对文件进行重命名。(修改文件名需求其实也可以在前端处理的) //接口类似于下面这个样子 PosMapping("/uploadFile") public R uploadFile(List<MultipartFile> fil…

Spring Boot注解大全:从入门到精通,轻松掌握Spring Boot核心注解!

目录 1、前言 2、介绍 2.1 Spring Boot简介 2.2 为什么要学习Spring Boot注解 3、Spring Boot基本注解 3.1 SpringBootApplication 3.2 EnableAutoConfiguration 3.3 ComponentScan 4、控制器注解 4.1 RestController 4.2 RequestMapping 4.3 PathVariable 4.4 Re…

主播风格的多样性

主播风格是主播在直播过程中表现出来的一种个性特点&#xff0c;它可以影响观众的感知和互动体验。以下是常见的几种主播风格: 1.时尚型:这种风格的主播通常穿着时尚、前卫&#xff0c;以潮流、新颖的形象出现在观众面前&#xff0c;善于捕捉时尚元素&#xff0c;并能够将其融…

JAVA销售数据决策管理系统源码

JAVA销售数据决策管理系统源码 基于BS&#xff08;Extjs Strus2springhibernate Mysql&#xff09;的销售数据的决策支持 主要的功能有 系统功能具体内容包括基础资料、进货管理、出货管理、库存管理、决策分析、系统管理。

基于书生·浦语大模型应用开发范式介绍

文章目录 大模型应用开发范式LangChain简介构建向量数据库搭建知识库助手RAG方案优化建议 大模型应用开发范式 通用大模型的优势&#xff1a; 强大的语言理解、指令跟随、语言生成的能力可以理解用户自然语言的指令具有强大的知识储备和一定的逻辑推理能力。 通用大模型局限…

springboot私人健身与教练预约管理系统源码和论文

随着信息技术和网络技术的飞速发展&#xff0c;人类已进入全新信息化时代&#xff0c;传统管理技术已无法高效&#xff0c;便捷地管理信息。为了迎合时代需求&#xff0c;优化管理效率&#xff0c;各种各样的管理系统应运而生&#xff0c;各行各业相继进入信息管理时代&#xf…

中国智造闪耀CES | 木牛科技在美国CES展亮相多领域毫米波雷达尖端方案

素有全球科技潮流“风向标”之称的2024国际消费类电子产品展&#xff08;CES&#xff09;&#xff0c;于1月9-12日在美国拉斯维加斯会议中心举办。CES是全球最大的消费电子和消费技术展览会之一&#xff0c;汇集了世界各地优秀的消费电子和科技公司&#xff0c;带着最好的产品来…

vue3中ref和reactive联系与区别以及如何选择

vue3中ref和reactive区别与联系 区别 1、ref既可定义基本数据类型&#xff0c;也可以定义引用数据类型&#xff0c;reactive只能定义应用数据类型 2、ref在js中取响应值需要使用 .value&#xff0c;而reactive则直接取用既可 3、ref定义的对象通过.value重新分配新对象时依旧…

Windows下上帝模式的实现

在windows系统上有个特殊模式&#xff0c;那就是上帝模式&#xff0c;几乎包含了windows中所有的快捷方式&#xff0c;有很多小伙伴还不知道&#xff0c;让我们一起来实现这一操作吧&#xff01; 一、首先新建一个文件夹 二、接着将文件夹重命名&#xff0c;命名为以下代码&am…

【OpenCV学习笔记07】- 【彩蛋】实现轨迹条控制画笔颜色和笔刷半径,并可以正常绘画

彩蛋 实现轨迹条控制画笔颜色和笔刷半径&#xff0c;并可以正常绘画。 直接上彩蛋代码 示例代码&#xff1a; # 彩蛋&#xff0c;创建一个可以调节颜色和笔刷半径的轨迹栏&#xff0c;并且可以通过鼠标进行绘画 import numpy as np import cv2 as cv# 定义全局变量 # 如果 …

Linux 文件(夹)权限查看

命令 : ls -al ls -al 是一个用于列出指定目录下所有文件和子目录的命令,包括隐藏文件和详细信息。其中,-a 选项表示显示所有文件,包括以 . 开头的隐藏文件,-l 选项表示以列表的形式显示文件的详细信息。 本例中:drwxrwxr-x 为权限细节。 权限细节(Permission detail…