详细分析Pytorch中的register_buffer基本知识(附Demo)

news2024/9/20 15:54:20

目录

  • 1. 基本知识
  • 2. Demo
  • 3. 与自动注册的差异
    • 3.1 torch.nn.Parameter
    • 3.2 自动注册子模块
    • 3.3 总结

1. 基本知识

register_buffer 是 PyTorch 中 torch.nn.Module 提供的一个方法,允许用户将某些张量注册为模块的一部分,但不会被视为可训练参数。这些张量会随模型保存和加载,但在反向传播过程中不会更新

register_buffer 的作用

  • 将张量注册为模型的缓冲区(buffer),意味着这些张量会与模型一起保存和加载
  • 与参数不同,缓冲区不会参与梯度计算,因此不会在训练时更新
  • 常用于存储像均值、方差、掩码或其他状态信息

与模型参数的区别

  • 模型参数(register_parameterself.param = nn.Parameter(tensor)):这些张量会被认为是可学习的权重,在训练过程中会被优化器更新
  • 缓冲区(register_buffer):这些张量不会被优化器更新,适合用于保存模型的常量或中间状态

使用的场景有如下:

  1. 存储一些与训练无关但随模型保存的常量
  2. 存储一些在 eval() 模式下需要使用的统计数据(如 BatchNorm 层中的均值和方差)
  3. 缓存计算中间的状态或掩码,不希望它们在训练中被更新

2. Demo

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 注册一个缓冲区, 不参与训练
        self.register_buffer('constant_tensor', torch.tensor([1.0, 2.0, 3.0]))
    
    def forward(self, x):
        # 使用缓冲区中的张量
        return x + self.constant_tensor

model = MyModel()
print(model.constant_tensor)  # 输出缓冲区内容 tensor([1., 2., 3.])

在训练模式和评估模式中的差异

model.train()  # 训练模式
print(model(torch.tensor([1.0, 1.0, 1.0])))  # tensor([2.0, 3.0, 4.0])

model.eval()  # 评估模式
print(model(torch.tensor([1.0, 1.0, 1.0])))  # tensor([2.0, 3.0, 4.0])

缓冲区内容不会因为模型模式的切换(train() 或 eval())而改变,因为缓冲区不是可训练参数,它仅存储数据

示例 1: 存储中间计算的掩码

class MaskedModel(nn.Module):
    def __init__(self):
        super(MaskedModel, self).__init__()
        self.register_buffer('mask', torch.tensor([1.0, 0.0, 1.0]))
    
    def forward(self, x):
        # 应用掩码到输入上
        return x * self.mask

model = MaskedModel()
input_tensor = torch.tensor([4.0, 3.0, 2.0])
output = model(input_tensor)
print(output)  # tensor([4.0, 0.0, 2.0])

示例 2: 保存 BatchNorm 层的均值和方差
在批归一化层(BatchNorm)中,均值和方差是通过 register_buffer 来存储的
这些值在训练时会动态更新,但在推理时会固定使用

class MyBatchNorm(nn.Module):
    def __init__(self, num_features):
        super(MyBatchNorm, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
    
    def forward(self, x):
        # 这里简单模拟了批归一化的逻辑
        return (x - self.running_mean) / torch.sqrt(self.running_var + 1e-5)

bn_layer = MyBatchNorm(3)
input_tensor = torch.tensor([2.0, 3.0, 4.0])
output = bn_layer(input_tensor)
print(output)

示例 3: 用于固定的均值和方差
不希望在训练时更新某些统计数据(比如均值和方差),而希望使用固定的值

class FixedNorm(nn.Module):
    def __init__(self):
        super(FixedNorm, self).__init__()
        # 注册固定的均值和方差
        self.register_buffer('mean', torch.tensor([0.5]))
        self.register_buffer('std', torch.tensor([0.25]))
    
    def forward(self, x):
        return (x - self.mean) / self.std

model = FixedNorm()
input_tensor = torch.tensor([1.0, 0.5, 0.0])
output = model(input_tensor)
print(output)  # tensor([ 2.0, 0.0, -2.0])

3. 与自动注册的差异

补充与上述不同的知识点,上述为手动注册,下面为自动注册

3.1 torch.nn.Parameter

这些 Parameter 会参与反向传播和梯度更新。
通过 model.parameters() 可以获得所有自动注册的参数

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # torch.nn.Parameter 会自动注册为模型参数
        self.weight = nn.Parameter(torch.randn(3, 3))
    
    def forward(self, x):
        return x @ self.weight

model = MyModel()
# weight 自动注册为参数
print(list(model.parameters()))  # 输出: [Parameter containing: tensor(...)]

3.2 自动注册子模块

在模型的 __init__ 函数中定义了 torch.nn.Module 子模块(例如卷积层、线性层等),PyTorch 会自动将这些子模块注册到模型中,并且它们的参数也会一并注册

  • 通过 model.children() 或 model.named_children() 获取所有子模块
  • 子模块的所有参数也会自动注册,可以通过 model.parameters() 或 model.named_parameters() 获取
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # nn.Module 子模块会自动注册
        self.conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
    
    def forward(self, x):
        return self.conv(x)

model = MyModel()
# conv 层自动注册为模型的子模块
print(list(model.children()))  # 输出: [Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))]

3.3 总结

  • register_buffer:用于注册不需要梯度的张量,比如存储模型状态的变量,不参与反向传播和优化过程
  • register_parameter:用于注册模型的可训练参数,会参与梯度计算和优化过程

两者结合的Demo

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 自动注册的子模块
        self.fc = nn.Linear(5, 5)

        # 手动注册一个可训练的参数
        self.register_parameter('my_weight', nn.Parameter(torch.randn(5, 5)))

        # 手动注册一个不可训练的缓冲区
        self.register_buffer('my_buffer', torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]))

    def forward(self, x):
        x = self.fc(x)
        return x @ self.my_weight + self.my_buffer


model = MyModel()
# 输出所有注册的参数和缓冲区
print("Parameters:", list(model.named_parameters()))
print("Buffers:", list(model.named_buffers()))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2149336.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2区“发稿大户”!SCISSCI双检,3天上线出版,在这里,不用担心创新性不足~

【SciencePub学术】眼瞅评职晋升最后期限就在眼前,小编今天就给大家带来了一本“百发百中”的救命神刊~ 01 期刊详情 【期刊简介】IF:2.0-3.0 JCR2区中科院4区 【出版社】MDPI出版社 【自引率】8.30% 【类别】医学 【INDEX】SCIE&SSCI在检 02…

es由一个集群迁移到另外一个集群es的数据迁移

迁移es的数据 改下index的索引 就可以了。 查询 用curl -u就可以查询了

[数据集][目标检测]不同颜色的安全帽检测数据集VOC+YOLO格式7574张5类别

重要说明:数据集里面有2/3是增强数据集,请仔细查看图片预览,确认符合要求在下载,分辨率均为640x640 数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件…

微店商品列表API接口实战指南

微店商品列表数据接口是一种允许开发者在其应用程序中调用微店店铺所有商品数据的 API 接口。通过这个接口,开发者可以获取到微店店铺内所有商品的信息,包括但不限于商品的 ID、标题、价格、库存、销量、详情描述、图片等。以下是对微店商品列表数据接口…

如何确保Java程序分发后不被篡改?使用JNI对Java程序进行安全校验

前言 众所周知,Java/Kotlin编译后会编译成smali,使用Jadx这类的反编译工具或者Hook工具就能很轻松的把我们的软件安全校验给破解了。 为了防止这种情况发生,我们一般会将核心代码使用C编写,然后使用JNI技术,使用Java…

TCP报文格式

RFC9293协议规范,规定的TCP格式如图1, 对比RFC793规定的格式,控制位从6bit变成了8bit 图1,图片来源:datatracker.ietf.org 图2为,可对照的中文版TCP格式,中文版参照的是RFC793 图2 重点…

htop 命令:系统状态监控

一、命令简介 ​htop ​是一个互动式的进程查看器,它是 top ​命令的增强版本,提供了更丰富的功能和更好的用户界面。htop ​显示了系统的实时进程和资源使用情况(比如 CPU 和 memory 占用情况),允许用户进行交互式操…

基于Ubuntu的ECS实例实现OSS反向代理

阿里云OSS的存储空间(Bucket)访问地址会随机变换,您可以通过在ECS实例上配置OSS的反向代理,实现通过固定IP地址访问OSS的存储空间。 背景信息 阿里云OSS通过Restful API方式对外提供服务。最终用户通过OSS默认域名或者绑定的自定…

掌握Spring Boot数据库集成:用JPA和Hibernate构建高效数据交互与版本控制

在现代应用开发中,数据库操作是核心环节之一。Spring Boot提供了简化数据库集成的强大工具,而JPA(Java Persistence API)和Hibernate是两种非常流行的ORM(对象关系映射)框架,可以帮助我们将对象…

CLI示例(V2R8至V2R19C00版本):直连二层组网直接转发【AP+上层网络,增加AP下行口有线接入】

CLI示例(V2R8至V2R19C00版本):直连二层组网直接转发【AP+上层网络,增加AP下行口有线接入】 适用于:V200R008至V200R019C00版本的AC,以及有空闲以太网口的AP。 说明:本示例基于“直连二层组网直接转发【AP+AC+出口网关】”场景来介绍如何增加AP下行口有线接入。 业务需求…

Vue使用代理方式解决跨域问题

1、解决跨域问题 如果 Vue 前端应用请求后端 API 服务器,出现跨域问题(CORS),如下图: 解决方法:在 Vue 项目中,打开 vue.config.js 配置文件,在配置文件中使用代理解决跨域问题。 …

怎么找到抖音爆款内容,进行扩散传播?

企业如果想做好抖音平台的品牌营销,需要时刻监测抖音爆款内容并进行加热放大,据此快速创新和改进内容,才能短期提高品牌相关内容的曝光量,快速拉升品牌声量。怎么去找到抖音的爆款内容或者是值得品牌关注的优质内容,主…

印尼有几百种语言,初学者要怎么开始学习?《印尼语翻译通》app或许可以帮助你!印尼语零基础入门学习。

快速翻译,准确高效 采用最新技术,提供精准翻译。翻译结果符合中国人习惯。 体验印尼文化 学习地道印尼语,贴近当地文化。 旅游和工作的好帮手 提供旅游和商务用语,沟通无障碍。 学习印尼语的良师 文本和语音翻译,…

Spark-RDD持久化

一、Spark的三种持久化机制 1、cache 它是persist的一种简化方式,作用是将RDD缓存到内存中,以便后续快速访问,提高计算效率。cache操作是懒执行的,即执行action算子时才会触发。 2、persist 它提供了不同的存储级别&#xff0…

解锁数字转型新纪元:Vatee万腾平台,您的智能加速与策略智库

在数字经济时代的大潮中,企业的数字化转型已不再是选择题,而是必答题。面对这一挑战,Vatee万腾平台以其卓越的技术实力和前瞻性的战略视野,成为了众多企业加速数字化转型、实现智能化升级的得力助手和智囊团。 加速转型&#xff…

人工智能时代:程序员如何在变革中保持核心竞争力?

随着人工智能生成内容(AIGC)领域的快速发展,大语言模型如ChatGPT、Midjourney、Claude等层出不穷,AI辅助编程工具迅速普及,程序员的工作方式正在经历翻天覆地的变革。面对这一趋势,有人担心AI可能取代部分编…

嵌入式处理器详解

文章目录 一、CPU、MPU、MCU、SoC、Application Processors的概念1.CPU (Central Processing Unit)2.MPU (Micro Processor Unit)3.MCU (Micro Controller Unit)4.SoC(System on Chip)5.Application Processors 二、哈弗架构与冯诺伊曼架构三、XIP概念四、嵌入式系统硬件组成五…

【架构设计】多级缓存:应用案例与问题解决策略

【架构设计】多级缓存:应用案例与问题解决策略 多级缓存系统的工作原理及其在提升应用性能方面的关键作用。通过对比本地缓存与分布式缓存的特点 | 原创作者/编辑:凯哥Java | 分类:架构设计系列教程 多级缓存…

模拟电路分析基础知识总结笔记(电子电路分析与设计前置知识)

必备条件 电子电路的直流分析电子电路的正弦稳态分析RC电路的瞬态分析戴维南定理和诺顿定理拉普拉斯变换(看不懂,根本看不懂) 电子电路的直流分析 欧姆定律 ​ 在恒定温度下,电压与电流成正比,电压与电阻成正比&am…

Java-数据结构-优先级队列(堆)-(二) (゚▽゚*)

文本目录: ❄️一、PriorityQueue的常用接口: ➷ 1、PriorityQueue的特性: ➷ 2、使用PriorityQueue的注意: ➷ 3、PriorityQueue的构造: ☞ 1、无参数的构造方法: ☞ 2、有参数的构造方法: …