深度学习基础知识 register_buffer 与 register_parameter用法分析

news2024/11/15 19:46:36

深度学习基础知识 register_buffer 与 register_parameter用法分析

  • 1、问题引入
  • 2、register_parameter()
    • 2.1 作用
    • 2.2 用法
  • 3、register_buffer()
    • 3.1 作用
    • 3.2 用法

1、问题引入

思考问题:定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习

验证方案:定义网络模型,设置weigut与bias,遍历网络结构参数net.named_parameters(),如果定义的weight与bias在里面,则说明是可学习参数;否则,是不可学习参数

import torch
import torch.nn as nn

# 思考两个问题,定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule,self).__init__()
        self.conv1=nn.Conv2d(in_channels= 3,
                            out_channels= 6,
                            kernel_size=3,
                            stride = 1,
                            padding=1,
                            bias=False)
        
        self.conv2=nn.Conv2d(in_channels= 6,
                            out_channels= 9,
                            kernel_size=3,
                            stride = 1,
                            padding=1,
                            bias=False)
        

        self.waight=torch.ones(10,10)
        self.bias=torch.zeros(10)

    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x = x * self.weight + self.bias
        return x
    
net=MyModule()

for name,param in net.named_parameters():  # 如果weight与bias在里面,说明其是可学习参数;否则,是不可学习参数
    print(name,param.shape)

print("\n","-"*40,"\n")

for key,val in net.state_dict().items():  # 说明weight与bias是不会被state_dict转化为字典中的元素的
    print(key,val.shape)

打印分析结果:
在这里插入图片描述
可以看到,weight与bias不在其中,所以此种定义方式不会是的weight与bias成为可训练参数

2、register_parameter()

register_parameter()是 torch.nn.Module 类中的一个方法

2.1 作用

1、可将 self.weight 和 self.bias 定义为可学习的参数,保存到网络对象的参数中,被优化器作用进行学习
2、self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

2.2 用法

register_parameter(name,param)

  • name:参数名称
  • param:参数张量, 须是 torch.nn.Parameter() 对象 或 None ,

否则报错如下
在这里插入图片描述

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_parameter('weight', torch.nn.Parameter(torch.ones(10, 10)))
        self.register_parameter('bias', torch.nn.Parameter(torch.zeros(10)))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

for name, param in net.named_parameters():
    print(name, param.shape)

print('\n', '*'*40, '\n')

for key, val in net.state_dict().items():
    print(key, val.shape)

结果显示:
在这里插入图片描述

3、register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 将 self.weight 和 self.bias 定义为不可学习的参数,不会被保存到网络对象的参数中,不会被优化器作用进行学习

  • self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

它用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数)

  • 参数:可以被优化器更新 (requires_grad=False / True)
  • buffer 中的数据 : 不会被优化器更新

3.2 用法

register_buffer(name,tensor)

  • name:参数名称
  • tensor:张量

代码:

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_buffer('weight', torch.ones(10, 10))   # 注意:定义的方式
        self.register_buffer('bias', torch.zeros(10))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

for name, param in net.named_parameters():
    print(name, param.shape)

print('\n', '*'*40, '\n')

for key, val in net.state_dict().items():
    print(key, val.shape)

效果如下所示:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1074010.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决PlatformIO下载速度慢以及容易出错(解决vscode下载缓慢问题)

Content 问题描述:依赖下载缓慢问题解决:为vscode配置代理端口 问题描述:依赖下载缓慢 Arduino对于ESP32的开发提供了众多的库,但是Arduino IDE编译速度过于缓慢的问题属实让人难受。 为此我们使用vscode中的platformIO插件&…

斑馬打印機打印中文

创建项目 首先說一下,本文章是借鉴了其他大佬的文章,然后我记录一下的文章。 首先创建好一个.net framework的winform项目。 这里面主要用到两个库文件: Fnthex32.dll、LabelPrint.dll。 Fnthex32这个有8位参数和9位参数的,我这…

数据结构--》解锁数据结构中树与二叉树的奥秘(一)

数据结构中的树与二叉树,是在建立非线性数据结构方面极为重要的两个概念。它们不仅能够模拟出生活中各种实际问题的复杂关系,还常被用于实现搜索、排序、查找等算法,甚至成为一些大型软件和系统中的基础设施。 无论你是初学者还是进阶者&…

三角函数和角公式

该文章对三角函数和角公式做了多种证明:https://mp.weixin.qq.com/s?__bizMzI4ODYwNTM3Ng&mid2247484178&idx1&sn1f6e04c50ae30b63198201db3d9a4f05&chksmec3a96bddb4d1fab4baf8188ca6ba60d8d4364be4f08dc53e13b2e4cdfcec4fdb43928108001&scen…

10_9C++

X-mind #include <iostream> using namespace std; class Per { private:string name;int age;float *height;float *weight; public:Per()//无参构造函数{cout << "无参构造函数" << endl;}Per(string name,int age,float *height,float *weight)…

MySQL——单表与多表查询练习

MySQL 一、练习一二、练习二 一、练习一 这里首先将素材创建完毕&#xff0c;首先创建一个数据库并使用&#xff0c;这里我创建的数据库名为worker&#xff1a; 紧接着我们创建数据库表并创建表结构&#xff1a; 查看表结构&#xff1a; 接着我们导入数据&#xff1a; 这…

【LeetCode】剑指 Offer Ⅱ 第6章:栈(6道题) -- Java Version

题库链接&#xff1a;https://leetcode.cn/problem-list/e8X3pBZi/ 类型题目解决方案栈的应用剑指 Offer II 036. 后缀表达式模拟 栈 ⭐剑指 Offer II 037. 小行星碰撞分类讨论 栈 ⭐单调栈剑指 Offer II 038. 每日温度单调栈 ⭐剑指 Offer II 039. 直方图最大矩形面积单调栈…

Excel·VBA使用ADO合并工作簿

之前文章《ExcelVBA合并工作簿&#xff08;7&#xff0c;合并子文件夹同名工作簿中同名工作表&#xff0c;纵向汇总数据&#xff09;》处理合并工作簿问题&#xff0c;代码运行速度比较慢 而《ExcelVBA使用ADO读取工作簿工作表数据》读取数据非常快&#xff0c;那么是否可以使用…

vue3+elementui实现表格样式可配置

后端接口传回的数据格式如下图 需要依靠后端传回的数据控制表格样式 实现代码 <!-- 可视化配置-表格 --> <template><div class"tabulation_main" ref"myDiv"><!-- 尝试过在mounted中使用this.$refs.myDiv.offsetHeight,获取父元素…

Redis安装及key、string操作

安装 在官网下载的数据包上传到Linux家目录 Install Redis from Source | Redis wget https://download.redis.io/redis-stable.tar.gz tar -xzvf redis-stable.tar.gz cd redis-stable make 编译后出现以下提示后输入make install 出现以下提示则安装成功 输入redis-sever启…

扩展屏幕,副屏幕的使用与设置

1、设置扩展屏模式 按快捷键 win p 选择 扩展 模式 2、设置屏幕的方向 打开电脑的设置页面 选择 系统 点击 屏幕&#xff0c;然后 拖动两个屏幕的位置即可

Computer Architecture Subtitle:Engineering And Technology

原文链接&#xff1a;https://www.cs.umd.edu/~meesh/411/CA-online/index.html

FBZP 维护支持程序 创建国家付款方式

今天在扩充供应商时报了一个错误&#xff1a; 收付方式 I 没有为国家 HK 定义。 原因是香港没有I的支付方式&#xff0c;需要为HK增加一下。方法如下&#xff1a; SPRO 路径&#xff1a;财务会计&#xff08;新&#xff09;-->>应收帐目和应付帐目-->>业务交易--&…

ELementUI之CURD及表单验证

一.CURD 1.后端CURD实现 RequestMapping("/addBook")ResponseBodypublic JsonResponseBody<?> addBook(Book book){try {bookService.insert(book);return new JsonResponseBody<>("新增书本成功",true,0,null);} catch (Exception e) {e.p…

基于Winform的UDP通信

1、文件结构 2、UdpReceiver.cs using System; using System.Collections.Generic; using System.Linq; using System.Net; using System.Net.Sockets; using System.Text; using System.Threading.Tasks;namespace UDPTest.Udp {public class UdpStateEventArgs : EventArgs…

【C/C++】结构体内存分配问题

规则1&#xff1a;以多少个字节为单位开辟内存 就是说&#xff0c;该结构体最终所占字节大小&#xff0c;是这个单位的整数倍 给结构体变量分配内存的时候&#xff0c;会去结构体变量中找基本类型的成员 哪个基本类型的成员占字节数多&#xff0c;就以它大大小为单位开辟内存 …

数据产品读书笔记——数据产品经理和其他角色的关系

&#x1f34a;上一节我们初步对数据产品经理的角色有了初步的了解&#xff0c;今天我们继续学习数据产品经理与其他角色之间的关系。上一期的内容如下&#x1f447;: 链接: 数据产品读书笔记——认识数据产品经理 &#x1f340;当我们处在一个组织中&#xff0c;就一定会有与…

leetcode:2427. 公因子的数目(python3解法)

难度&#xff1a;简单 给你两个正整数 a 和 b &#xff0c;返回 a 和 b 的 公 因子的数目。 如果 x 可以同时整除 a 和 b &#xff0c;则认为 x 是 a 和 b 的一个 公因子 。 示例 1&#xff1a; 输入&#xff1a;a 12, b 6 输出&#xff1a;4 解释&#xff1a;12 和 6 的公因…

大模型rlhf 相关博客

想学习第一篇博客: https://huggingface.co/blog/zh/rlhf RLHF 技术分解 RLHF 是一项涉及多个模型和不同训练阶段的复杂概念&#xff0c;这里我们按三个步骤分解&#xff1a; 预训练一个语言模型 (LM) &#xff1b;聚合问答数据并训练一个奖励模型 (Reward Model&#xff0c;RM…

数据结构和算法(10):B-树

B-树&#xff1a;大数据 现代电子计算机发展速度空前&#xff0c;就存储能力而言&#xff0c;情况似乎也是如此&#xff1a;如今容量以TB计的硬盘也不过数百元&#xff0c;内存的常规容量也已达到GB量级。 然而从实际应用的需求来看&#xff0c;问题规模的膨胀却远远快于存储能…