register_parameter和register_buffer 详解

news2025/1/18 16:47:34

在参考yolo系列代码或其他开源代码,经常看到register_buffer register_parameter的使用,接下来将详细对他们进行介绍。

1. 前沿

在搭建网络时,我们 自定义的参数,往往不会保存到模型权重文件中,或者成为模型可学习的参数。即我们通过 net.named_parameters() (模型可学习参数)或 net.state_dict().items()(保存模型权重值)方法都无法遍历输出。那如何解决呢,这就需要用到本文讲的register_parameterregister_buffer方法。

2. register_parameter

register_parameter() 是 torch.nn.Module 类中的一个方法。

2.1 主要作用

  • 用于定义可学习参数
  • 定义的参数可被保存到网络对象的参数中,可使用 net.parameters()net.named_parameters() 查看
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件

2.2 函数说明

register_parameter(name,param)

参数:

  • name:参数名称

  • param:参数张量, 须是torch.nn.Parameter()对象 或 None ,否则报错如下
    TypeError: cannot assign 'torch.FloatTensor' object to parameter 'xx' (torch.nn.Parameter or None required)

2.3 举例说明

(1)自定义的参数未使用register_parameter

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.weight = torch.ones(10,10)
        self.bias = torch.zeros(10)


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():
    print(name, param.shape)


print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():
    print(key, val.shape) 

输出:
在这里插入图片描述
在网络搭建的代码中,我们自定义了self.weightself.bias参数。我们思考下2个问题:1. 我们定义的self.weightself.bias参数是否会保存到网络的参数中,是否能在优化器的作用下进行学习。2. 这些参数是否能够保存到模型文件中,从而可以利用state_dict中遍历出来。通过上面的打印信息我们发现:

  • 使用net.named_parameters()迭代网络中可学习的参数,发现输出的参数只有conv1conv2的weight参数,并没有输出我们定义的self.weightself.bias
  • 接下来使用net.state_dict()方法迭代保存的参数,同样发现self.weightself.bias参数也没有被输出出来。

(2)通过register_parameter方法来定义参数

  • 接下来我们使用register_parameter来定义weight和bias参数,看看会有啥效果。代码修改如下:
self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))
self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))

完整代码

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))
        self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():
    print(name, param.shape)


print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():
    print(key, val.shape) 

在这里插入图片描述

  • 可以看到,使用了register_parameter定义的参数weight和bias,可以通过net.named_parameters或者net.parameters迭代出来的,这说明weight和bias已经存到了网络的参数中,他们是可学习的参数
  • 同时,通过state_dict()也能将参数和值给迭代出来,就说明如果要保存模型权重或网络参数时,这两个参数时可以被保存起来的。

3 register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 用于定义不可学习的参数
  • 定义的参数不会被保存到网络对象的参数中,使用 net.parameters() 或 net.named_parameters() 查看不到
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件中

register_buffer() 用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数),它与参数的区别为:

  • 参数:可以被优化器更新 (requires_grad=False / True)

  • buffer 中的数据 (不可学习): 不会被优化器更新

3.2、举例说明

将定义的weight和bias,通过register_buffer来定义。

self.register_buffer('weight',torch.ones(10,10))
self.register_buffer('bias',torch.zeros(10))

运行完整代码看看效果:

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_buffer('weight',torch.ones(10,10))
        self.register_buffer('bias',torch.zeros(10))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()z

print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():
    print(name, param.shape)


print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():
    print(key, val.shape) 

在这里插入图片描述
我们可以看到:

  • 通过register_buffer定义的参数weight和bias,它是没有被named_parameter给迭代出来的,也就是说weight和bias不是网络的可学习参数,无法通过优化器来迭代更新,我们把它叫做buffer,而不是参数
  • 然而我们使用net.state_dict去迭代的话,weight和bias事可以被迭代出来的,这就说明使用register_buffer定义的数据,可以保持到模型或者权重文件中。

注意:

  • 在使用register_parameter定义参数时,必须定义为可学习的参数,因此需要通过torch.nn.Parameter去定义为一个可学习的参数
  • 而我们使用register_buffer定义参数时,是不需要通过torch.nn.Parameter去定义为可学习的参数的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1171902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【多线程】并发问题

public class BuyTicket implements Runnable{private int ticketNums10;Overridepublic void run() {for(int i1;i<ticketNums;i){if(ticketNums<0){break;}System.out.println(Thread.currentThread().getName() "抢到了第" i "张票");ticketNu…

数字处理-第10届蓝桥杯省赛Python真题精选

[导读]&#xff1a;超平老师的Scratch蓝桥杯真题解读系列在推出之后&#xff0c;受到了广大老师和家长的好评&#xff0c;非常感谢各位的认可和厚爱。作为回馈&#xff0c;超平老师计划推出《Python蓝桥杯真题解析100讲》&#xff0c;这是解读系列的第3讲。 数字处理&#xff…

高性能网络编程 - 关于单台服务器并发TCP连接数理论值的讨论

文章目录 概述操作系统的限制因素文件句柄限制1. 进程限制2. 全局限制 端口号范围限制 概述 单台服务器可以支持的并发TCP连接数取决于多个因素&#xff0c;包括硬件性能、操作系统限制、网络带宽和应用程序设计。以下是一些影响并发TCP连接数的因素&#xff1a; 服务器硬件性…

软件设计模式原则(二)开闭原则

继续讲解第二个重要的设计模式原则——开闭原则~ 一.定义 开闭原则&#xff0c;在面向对象编程领域中&#xff0c;规定“软件中的对象&#xff08;类&#xff0c;模块&#xff0c;函数等等&#xff09;应该对于扩展是开放的&#xff0c;但是对于修改是封闭的”&#xff0c;这意…

Golang源码分析之golang/sync之singleflight

1.1. 项目介绍 golang/sync库拓展了官方自带的sync库&#xff0c;提供了errgroup、semaphore、singleflight及syncmap四个包&#xff0c;本次分析singlefliht的源代码。 singlefliht用于解决单机协程并发调用下的重复调用问题&#xff0c;常与缓存一起使用&#xff0c;避免缓存…

Capto2024专为Mac电脑设计的屏幕录制和视频编辑软件

不得不说视频编辑功能&#xff1a;Capto提供了多种视频编辑功能&#xff0c;例如剪辑、旋转、裁剪、调整音频和视频的音量、加入水印、添加注释等&#xff0c;你能够使用Capto编辑你的视频&#xff0c;使之更加专业和生动。有目共睹的是录制完成后&#xff0c;你能够使用Capto提…

PowerShell实战:文件操作相关命令笔记

目录 1、New-Item 创建新项命令 2、Remove-Item 删除项命令 3、Rename-Item 项重命名 1、New-Item 创建新项命令 cmdlet New-Item 将创建新项并设置其值。 可创建的项类型取决于项的位置。 例如&#xff0c;在文件系统 New-Item 中创建文件和文件夹。 在注册表中&#xff0c; N…

叶片卷曲

叶片卷曲 上卷/内卷白粉病强烈阳光&温度太高虫害&#xff08;蓟马&#xff09; 下卷 叶片卷曲的原因有很多&#xff0c;很多情况无法从外表分辨&#xff0c;并且有可能多种原因混杂&#xff0c;扰乱判断 上卷/内卷 白粉病 当植株感染白粉病时&#xff0c;白粉病菌孢子附…

c语言进阶部分详解(《高质量C-C++编程》经典例题讲解及柔性数组)

上篇文章我介绍了介绍动态内存管理 的相关内容&#xff1a;c语言进阶部分详解&#xff08;详细解析动态内存管理&#xff09;-CSDN博客 各种源码大家可以去我的github主页进行查找&#xff1a;唔姆/比特学习过程2 (gitee.com) 今天便接“上回书所言”&#xff0c;来介绍《高质…

CANoe新建XML自动化Test Modules

文章目录 1.打开Test Modules2.新建Environment3.新建XML Test Modules4.新建.can文件5.打开XML Test Modules6.新建xml脚本并保存7.编译8.在.can文件写个测试用例9.修改报告格式为HTML10.运行查看报告后面介绍的文章会重复用到这部分,这里单独介绍下,后面不做重复介绍。 1.…

Envoy XDS协议学习

Envoy xds学习 资料地址 envoy官网资料连接 接口说明 xds分为增量接口和全量接口SotW&#xff1a;state of the world 即全量的数据Incremental&#xff1a; 增量的数据 具体接口 Listener: Listener Discovery Service (LDS) SotW: ListenerDiscoveryService.StreamList…

一文搞懂设计模式之工厂模式

大家好&#xff0c;我是晴天&#xff0c;本周将同大家一起学习设计模式系列的第二篇文章——工厂模式&#xff0c;我们将依次学习简单工厂模式&#xff0c;工厂方法模式和抽象工厂模式。拿好纸和笔&#xff0c;我们现在开始啦~ 前言 我们在进行软件开发的时候&#xff0c;虽然…

vector类模拟实现(c++)(学习笔记)

vector 构造函数析构函数[]push_backsize()capacity()reserve()push_back() 迭代器实现非const和const版本 pop_back()resize()insert()***重点erase()***重点再谈构造函数&#xff01;拷贝构造函数****&#xff08;重点&#xff09;运算符重载***&#xff08;重点&#xff09;…

详解RSA加密算法 | Java模拟实现RSA算法

目录 一.什么是RSA算法 二.RSA算法的算法原理 算法描述 三.RSA算法安全性 四.RSA算法的速度 五.用java实现RSA算法 一.什么是RSA算法 1976年&#xff0c;Diffie和Hellman在文章“密码学新方向&#xff08;New Direction in Cryptography&#xff09;”中首次提出了公开…

arduino - NUCLEO-H723ZG - test

文章目录 arduino - NUCLEO-H723ZG - test概述笔记物理串口软串口备注END arduino - NUCLEO-H723ZG - test 概述 准备向NUCLEO-H723ZG上移植西门子飞达控制的Arduino程序. 先确认一下知识点和效果. 笔记 物理串口 NUCLEO-H723ZG在STM32 Arduino 库中, 只提供了一个串口 Se…

快速了解推荐引擎检索技术

目录 一、推荐引擎和其检索技术 二、推荐引擎的整体架构和工作过程 &#xff08;一&#xff09;用户画像 &#xff08;二&#xff09;文章画像 &#xff08;三&#xff09;推荐算法召回 三、基于内容的召回 &#xff08;一&#xff09;召回算法 &#xff08;二&#xf…

uni-app---- 点击按钮拨打电话功能点击按钮调用高德地图进行导航的功能【安卓app端】

uniapp---- 点击按钮拨打电话功能&&点击按钮调用高德地图进行导航的功能【安卓app端】 先上效果图&#xff1a; 1. 在封装方法的文件夹下新建一个js文件&#xff0c;然后把这些功能进行封装 // 点击按钮拨打电话 export function getActionSheet(phone) {uni.showAct…

【雷达原理】雷达杂波抑制方法

目录 一、杂波及其特点 1.1 什么是杂波&#xff1f; 1.2 杂波的频谱特性 二、动目标显示(MTI)技术 2.1 对消原理 2.2 数字对消器设计 三、MATLAB仿真 3.1 对消效果验证 3.2 代码 一、杂波及其特点 1.1 什么是杂波&#xff1f; 杂波是相对目标回波而言的&#xff0c;…

【Python工具】Panoply介绍及安装步骤

Panoply介绍及安装步骤 1 Panoply介绍2 Panoply安装步骤&#xff08;Windows&#xff09;2.1 下载并安装JAVA环境2.2 下载Panoply报错&#xff1a;Error: A JNI error has occurred, please check your installation and try again. 参考 1 Panoply介绍 Panoply是一款由美国国…

【大数据】Apache NiFi 数据同步流程实践

Apache NiFi 数据同步流程实践 1.环境2.Apache NIFI 部署2.1 获取安装包2.2 部署 Apache NIFI 3.NIFI 在手&#xff0c;跟我走&#xff01;3.1 准备表结构和数据3.2 新建一个 Process Group3.3 新建一个 GenerateTableFetch 组件3.4 配置 GenerateTableFetch 组件3.5 配置 DBCP…