pytorch的buffer学习整理

news2024/9/26 5:21:19

pytorch模型中的buffer

这段时间忙于做项目,但是在项目中一直在模型构建中遇到buffer数据,所以花点时间整理下模型中的parameter和buffer数据的区别💕

1.torch.nn.Module.named_buffers(prefix=‘‘, recurse=True)

贴上pytorch官网对其的说明:
在这里插入图片描述
官网翻译:

named_buffers(prefix='', recurse=True)
方法: named_buffers(prefix='', recurse=True)

    Returns an iterator over module buffers, yielding both the name of the buffer as well 
    as the buffer itself.
    返回一个迭代器,该迭代器能够遍历模块的缓冲buffer,并且迭代返回的结果是缓冲的名字和缓冲本身.
    Parameters  参数
            prefix (str) – prefix to prepend to all buffer names.
            prefix (字符串) – 添加到所有缓冲名字之前的前缀.
            recurse (bool)if True, then yields buffers of this module and all submodules. 
            Otherwise, yields only buffers that are direct members of this module.
            recurse (布尔类型) – 如果该参数是True,那么表示递归地迭代返回,即迭代返回该模块的缓冲以及
            该模块的所有子模块的缓冲. 默认为True
    Yields  迭代返回
        (string, torch.Tensor) – Tuple containing the name and buffer
        (字符串,torch.Tensor类型) - 包含缓冲名字和缓冲自身的元组
        
    Example:  例子:

    >>> for name, buf in self.named_buffers():
    >>>    if name in ['running_var']:
    >>>        print(buf.size())

总结,缓冲buffer必须要登记注册才会有效,如果仅仅将张量赋值给Module模块的属性,不会被自动转为缓冲buffer.因而也无法被state_dict()、buffers()、named_buffers()访问到.此外state_dict()可以遍历缓冲buffer和参数Parameter.
可以概括为,缓冲buffer和参数Parameter的区别是前者不需要训练优化,而后者需要训练优化.在创建方法上也有区别,前者必须要将一个张量使用方法register_buffer()来登记注册,后者比较灵活,可以直接赋值给模块的属性,也可以使用方法register_parameter()来登记注册.
下面使用代码测试一下buffer数据:

import torch 
import torch.nn as nn
torch.manual_seed(seed=20200910)
class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1=torch.nn.Sequential(  # 输入torch.Size([64, 1, 28, 28])
                torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),
                torch.nn.ReLU(),  # 输出torch.Size([64, 64, 28, 28])
        )
        self.attribute_buffer_in = torch.randn(3,5)                       # 仅仅赋值给模型属性,是无法访问到该buffer数据
        register_buffer_in_temp = torch.randn(4,6)               
        self.register_buffer('register_buffer_in', register_buffer_in_temp)   # 注册buffer数据,才能生效,能获取到数据

    def forward(self,x): 
        pass

print('cuda(GPU)是否可用:',torch.cuda.is_available())
print('torch的版本:',torch.__version__)
model = Model() #.cuda()



print('初始化之后模型修改之前'.center(100,"-"))
print('调用named_buffers()'.center(100,"-"))   
for name, buf in model.named_buffers():
    print(name,'-->',buf.shape)

print('调用named_parameters()'.center(100,"-"))
for name, param in model.named_parameters():     # 访问模型的parameter参数数据的名字和其本身
    print(name,'-->',param.shape)

print('调用buffers()'.center(100,"-"))           # 访问模型中的buffer数据本身
for buf in model.buffers():
    print(buf.shape)

print('调用parameters()'.center(100,"-"))        # 访问模型中的parameter数据本身
for param in model.parameters():
    print(param.shape)

print('调用state_dict()'.center(100,"-"))        # 同时获取模型的parameter参数数据、buffer参数数据
for k, v in model.state_dict().items():
    print(k, '-->', v.shape)



model.attribute_buffer_out = torch.randn(10,10)      # 赋值给模型属性
register_buffer_out_temp = torch.randn(15,15)
model.register_buffer('register_buffer_out', register_buffer_out_temp)  # 通过注册的方式,使得模型的buffer成员属性生效
print('模型初始化以及修改之后'.center(100,"-"))
print('调用named_buffers()'.center(100,"-"))         # 修改模型buffer属性之后,访问buffer数据名字和其本身
for name, buf in model.named_buffers():
    print(name,'-->',buf.shape)

print('调用named_parameters()'.center(100,"-"))      # 修改模型buffer属性之后,访问模型parameter数据名字和其本身
for name, param in model.named_parameters():
    print(name,'-->',param.shape)

print('调用buffers()'.center(100,"-"))
for buf in model.buffers():
    print(buf.shape)

print('调用parameters()'.center(100,"-"))
for param in model.parameters():
    print(param.shape)

print('调用state_dict()'.center(100,"-"))
for k, v in model.state_dict().items():
    print(k, '-->', v.shape)  

输出结果为:

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 840 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '63490' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test2.py'
cuda(GPU)是否可用: True
torch的版本: 1.2.0+cu92
--------------------------------------------初始化之后模型修改之前---------------------------------------------
-----------------------------------------调用named_buffers()------------------------------------------
register_buffer_in --> torch.Size([4, 6])                     # 
----------------------------------------调用named_parameters()----------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------调用buffers()---------------------------------------------
torch.Size([4, 6])
-------------------------------------------调用parameters()-------------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
-------------------------------------------调用state_dict()-------------------------------------------
register_buffer_in --> torch.Size([4, 6])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------模型初始化以及修改之后---------------------------------------------
-----------------------------------------调用named_buffers()------------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
----------------------------------------调用named_parameters()----------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------调用buffers()---------------------------------------------
torch.Size([4, 6])
torch.Size([15, 15])
-------------------------------------------调用parameters()-------------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
-------------------------------------------调用state_dict()-------------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 

模型中的buffer和parameter区别

在这里插入图片描述
在这里插入图片描述
下面使用代码进行说明:
pytorch保存模型参数的一种方式为:

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

可以看到模型保存的是 model.state_dict() 的返回对象。 model.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数,例如:

class MyModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModule, self).__init__()
        self.lin = nn.Linear(input_size, output_size)
    def forward(self, x):
        return self.lin(x)

module = MyModule(4, 2)
print(module.state_dict())

输出结果:
在这里插入图片描述
分析一个parameter和buffer的例子:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        buffer = torch.randn(2, 3)  # tensor
        self.register_buffer('my_buffer', buffer)
        self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量

    def forward(self, x):
        # 可以通过 self.param 和 self.my_buffer 访问
        pass
model = MyModel()
for param in model.parameters():
    print(param)
print("----------------")
for buffer in model.buffers():
    print(buffer)
print("----------------")
print(model.state_dict())

输出结果:
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/31660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

sqli-labs/Less-52

这一关输入几次rand()页面发生了改变 说明这一关的注入类型属于数字型注入 接着尝试一下报错注入 输入如下 sortupdatexml(1,if(11,concat(0x7e,database(),0x7e),1),1)-- 发现没有回显 显然不能使用报错注入 只能使用盲注了 这一关我们通过rand()函数的形式来实现盲注 首先…

HTML+CSS+JavaScript仿京东购物商城网站 web前端制作服装购物商城 html电商购物网站

常见网页设计作业题材有 个人、 美食、 公司、 学校、 旅游、 电商、 宠物、 电器、 茶叶、 家居、 酒店、 舞蹈、 动漫、 服装、 体育、 化妆品、 物流、 环保、 书籍、 婚纱、 游戏、 节日、 戒烟、 电影、 摄影、 文化、 家乡、 鲜花、 礼品、 汽车、 其他等网页设计题目, A…

SPC5777CDK3MMO4 IC MCU 32BIT,SPC5777CDK3MME3

MPC5777C Power Architecture 微控制器是一款高性能多核MCU,优化用于要求先进性能、计时系统、安全性和功能性安全能力的工业和汽车控制应用。MPC5777C设有两个独立的Power Architecture z7内核(运行速度高达300MHz)以及一个z7内核&#xff0…

搜索技术——盲目与启发

如果有兴趣了解更多相关内容,欢迎来我的个人网站看看:瞳孔空间 搜索是人工智能中的一个基本问题,并与推理密切相关。搜索策略的优劣将直接影响到智能系统的性能与推理效率。 一:搜索的基本概念 搜索:根据问题的实际…

Linux openvino 环境搭建遇见的问题

1.编译openvino源码,报错(PythonLibsNew) 通过报错路径结合cmakeLists.txt发现,有个文件夹内容为空导致的,因此需要单独下载对应的文件(这个文件夹藏的比较深,之前没有注意到,只关注openvino/thridparty下&…

(附源码)计算机毕业设计JavaJava毕设项目财务管理系统的设计与实现

项目运行 环境配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: Springboot mybatis Maven Vue 等等组成,B/…

膜前氟离子超标的解决方法,除氟离子技术

原水氟化物浓度在150mg/l左右,处理水量大概在30m/h,要求出水氟化物浓度要小于10mg/l,同时呢对出水稳定性方面要求也非常严格。 “预过滤系统离子交换除氟反渗透系统超纯水系统”的工艺,利用季胺1型官能团选择性吸附氟化物&#x…

Nginx负载均衡配置、限流配置、Https配置详解

一. 负载均衡 1. 用法 通过proxy_pass 可以把请求代理至后端服务,但是为了实现更高的负载及性能, 我们的后端服务通常是多个, 这个是时候可以通过upstream 模块实现负载均衡。 使用的模块为:【ngx_http_upstream_module】&#…

股票买卖明细接口是怎样实现查询交易数据的?

股票买卖明细接口作为软件应用而言,很多资源和数据不一定就是由其自身提供的,所以说某些功能还是需要调用第三方提供的服务,这其中就涉及到API接口的调用。也就是说,股票买卖明细接口是与数据端直接挂钩的,通过一些量化…

大数据毕设选题 - 招聘岗位数据分析可视化(python 爬虫)

文章目录1 前言1 课题背景2 实现效果3 项目实现3.1 概述3.2 数据采集3.3 数据清洗与预处理4 数据分析与可视化Flask框架介绍5 最后1 前言 🔥 Hi,大家好,这里是丹成学长的毕设系列文章! 🔥 对毕设有任何疑问都可以问学…

Head First设计模式(阅读笔记)-03.装饰者模式

星巴兹咖啡 咖啡存在许多的种类,同时也有不同的调料。此时用户可以单点咖啡,也可以点咖啡调料,请计算费用(这里咖啡和调料都属于Drink的一类) 简单实现 方案1 每出现一种组合就实现一个类,但是每次增加一个咖啡种类或者一个新的调…

Centos7通过SSH使用密钥实现免密登录

Centos7通过SSH使用密钥实现免密登录 日常开发中,难免会有登录服务器的操作,而通过ssh方式登录无疑是比较方便的一种方式。 如果登录较频繁,使用密钥实现免密登录无疑更是方便中的方便。因此本文就简单说一说如何实现免密登录。一、安装配置ssh服务 默认情况下Centos7是安装…

推荐一款制作精良、功能强大、毫秒级精度的定时任务执行软件

目录 一、定时执行专家 - 功能详细 二、定时执行专家 - 最新版下载 三、定时执行专家 - 更新日志 四、关键字/Keyword 一、定时执行专家 - 功能详细 1、支持多种触发方式(定时方式):倒计时执行、持续执行、键盘鼠标空闲指定时长时执行、…

了解的Java泛型

作者:~小明学编程 文章专栏:JavaSE基础 格言:目之所及皆为回忆,心之所想皆为过往 目录 前言 什么是泛型 为什么要引入泛型 使用泛型 裸类型 泛型类的定义 类型擦除 通配符 什么是通配符 通配符的上下界 通配符的使用 …

Cookie和Session的工作流程以及Servlet中与之相关的API

目录 一、认识Cookie和Session 1、Cookie 2、Session 二、Cookie和Session的工作流程 三、Servlet中与Cookie和Session相关的API 1、HttpServletRequest类中的相关方法 2、HttpServletResponse类中的相关方法 3、HttpSession类中的相关方法 4、Cookie类中的相关方法 …

常用的框架技术-10 Spring Security Spring的企业应用系统提供声明式的安全访问控制解决方案的安全框架

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录1.Spring Security简介1.1 Spring Security概述1.2 Spring Security历史发展1.3 产品的对比1.3.1 Spring Security1.3.2 Shiro1.4 Spring Security 核心类1.4.1 Auth…

既然有了ES,为何还用ClickHouse——从原理万字总结ClickHouse为何这么快

通过了解 CH 的几大特性了解千亿级企业 ClickHouse 实时处理引擎架构设计、核心技术设计、运行机理全流程。 文章目录1 初始 ClickHouse1.1 什么是 ClickHouse1.2 ClickHouse 的优缺点1.3 谁在用 ClickHouse3 数据引擎3.1 库引擎3.2 表引擎3.3 MergeTree 引擎4 工作原理4.1 数据…

浙大MBA经验分享:在工作生活的缝隙中奋勇上岸

非常高兴可以为大家分享我的浙大MBA备考经验!首先针对我的背景简要介绍一下,我本科毕业于省内的普通大学浙江理工大学,学的是设计专业,就业于一家外企公司。在2022年的联考中获得了综合133,英语75,总分是20…

一个简单的音乐网站设计与实现(HTML+CSS)

⛵ 源码获取 文末联系 ✈ Web前端开发技术 描述 网页设计题材,DIVCSS 布局制作,HTMLCSS网页设计期末课程大作业 | 音乐网页设计 | 仿网易云音乐 | 各大音乐官网网页 | 明星音乐演唱会主题 | 爵士乐音乐 | 民族音乐 | 等网站的设计与制作 | HTML期末大学生网页设计作…

常见集群算法解析

Gossip协议 Gossip协议简介 定义 Gossip protocol,又叫 Epidemic Protocol (流行病协议),也叫“流言算法” 、 “疫情传播算法”等。其名称已经形象的说明了算法的原理和工作方式 应用场景 分布式网络,无集中管理节…