self.register_buffer方法使用解析(pytorch)

news2024/10/7 15:24:05

self.register_buffer就是pytorch框架用来保存不更新参数的方法。

列子如下:

self.register_buffer("position_emb", torch.randn((5, 3)))

第一个参数position_emb传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数torch.randn((5, 3),并一次初始化后保存于模型,不会有梯度传播给它,能被模型的model.state_dict()记录下来,可以理解为模型的常数。当然,你想保留固定值,使用如下代码:

self.register_buffer("position_emb", torch.tensorrt([[2,5],[8,9]]))

进一步探讨训练对该参数是否有影响,答案是:没影响。具体可看下面实现的列子代码:

import torch
from torch.nn import Embedding

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = Embedding(5, 3)
        self.register_buffer("position_emb", torch.randn((5, 3)))
    def forward(self,vec):
        input = torch.tensor([0, 1, 2, 3, 4])
        emb_vec1 = self.emb(input)
        emb_vec1=emb_vec1+self.position_emb
        output = torch.einsum('ik, kj -> ij', emb_vec1, vec)
        return output
def simple_train():
    model = Model()
    vec = torch.randn((3, 1))
    label = torch.Tensor(5, 1).fill_(3)
    loss_fun = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.015)
    print('初始化后position_emb参数:\n',model.position_emb)
    for iter_num in range(100):
        output = model(vec)
        loss = loss_fun(output, label)
        opt.zero_grad()
        loss.backward(retain_graph=True)
        opt.step()
    print('训练后position_emb参数:\n', model.position_emb)

if __name__ == '__main__':
   simple_train()  # 训练与保存权重

实现结果如下:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1178209.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信Wxid转换微信号

微信号在申请的时候,系统随机分配了一个微信原始ID,该ID号以wxid_开头,后面是随机的字符串 分配的原始ID是目前是不可以直接用来加好友的,需要转换成微信号才能加好友, 经过逆向分析通过PC端找到了该接口并且可以成功用…

Langchain知识点(下)

背景: 这部分给主要介绍Langchain的agent部分,前面已经章节已经介绍了思维和思路作为一种数据资产是这一次LLM数据化的核心。也介绍了各种的chain,那么既然有了chain可以把专家思路和专家思维固化并且可被方便的共享和利用;那为什…

数据结构-链表的简单操作实现

目录 0.链表前序工作 1.构建出一个链表 2.展示链表中的所有存储数据 3.查找关键字key是否在链表中 4.求链表的长度 5.头插法 6.尾插法 7.插入任意位置(规定第一个元素位置为0下标) 8.删除第一次出现的值为key的关键字 9.删除所有值为key的关键字…

【算法】通信线路(二分,堆优化版dijkstra)

题目 在郊区有 N 座通信基站,P 条 双向 电缆,第 i 条电缆连接基站 Ai 和 Bi。 特别地,1 号基站是通信公司的总站,N 号基站位于一座农场中。 现在,农场主希望对通信线路进行升级,其中升级第 i 条电缆需要花费…

Tensor.scatter_add_函数解释:

Tensor.scatter_add_(dim, index, src) → Tensor out.scatter_add_(dim, index, src) 1.参数: dim (int) – 哪一dim进行操作 index (LongTensor) – 要在的out的哪一index进行操作 src (Tensor) – 待操作的源数字 2.官方的解释的操作如下: 3.例…

基于8086汽车智能小车控制系统

**单片机设计介绍,基于8086汽车智能小车控制系统 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于 8086 的汽车智能小车控制系统是一种将微处理器技术应用于汽车控制的系统。下面是其主要的设计介绍: 硬…

ubuntu22.04桌面版系统无法识别USB摄像头

虚拟机连接电脑摄像头连接失败(驱动程序错误) 本文为转载:版权归远作者所有,之所以转载是为了避免被原作者删除 巴黎铁塔下的女孩 你尽管努力,剩下的交给时间 虚拟机调用电脑的摄像头,正常情况下只需点击…

方案分享:F5机器人防御助企业应对复杂攻击

企业是Bot攻击者的目标,网络犯罪分子会不断调整他们的攻击,来攻破愈发成熟的Bot防护,这使企业安全团队时刻处于紧张状态。如果不能有效地管理Bot,应用性能、客户体验和业务都会被影响,但在尝试阻止这些攻击时&#xff…

技术分享 | web自动化测试-文件上传与弹框处理

实战演示 文件上传 input 标签使用自动化上传,先定位到上传按钮,然后 send_keys 把路径作为值给传进去. 如图所示,是企业微信文件上传的页面 定位到标签为 input,type 为 file 的元素信息,然后使用 send_keys 把文件…

Harbor企业级Registry基础镜像仓库的详细安装使用教程(保姆级)

Harbor Docker 官方提供的私有仓库 registry,用起来虽然简单 ,但在管理的功能上存在不足。 Harbor是vmware一个用于存储和分发Docker镜像的企业级Registry服务器,harbor使用的是官方的docker registry(v2命名是distribution)服务去完成。 ha…

原语:串并转换器

串并转换器OSERDESE2 可被Select IO IP核调用。 OSERDESE2允许DDR功能 参考: FPGA原语学习与整理第二弹,OSERDESE2串并转换器 - 知乎 (zhihu.com) 正点原子。 ISERDESE2原语和OSERDESE2原语是串并转换器,他的的功能都是实现串行数据和并行…

阿里云服务器怎么购买更省钱?优惠入口分享

阿里云服务器怎么购买更省钱?不要直接在云服务器页面购买,不划算,在阿里云特价活动上购买更优惠,阿腾云atengyun.com分享阿里云服务器省钱购买方法,节省90%,可以先在阿里云CLUB中心领券 aliyun.club 专用满…

JavaScript_Element对象_方法

1、Element.focus() Element.focus方法用于将当前页面的焦点,转移到指定元素上 2、Element.blur() Element.blur方法用于将焦点从当前元素移除 3、Element.remove() Element.remove方法用于将当前元素节点从它的父节点移除 4、Element.getBoundingClientRect() …

蓝桥杯练习

即约分数 题目 思路 遍历所有的x&#xff0c;y&#xff0c;判断x/y是不是即越约分数。 代码 #include <iostream> using namespace std; int gcd(int x,int y) {int r;while(y!0){rx%y;xy;yr;}return x; } int main() {// 请在此输入您的代码int sum4039;//1/y和x/1都…

C函数速查手册

链接下载&#xff1a;提取码:Tywdhttps://www.123pan.com/s/JRpSVv-PLnjv.html 双击打开即可

RxJava/RxAndroid的基本使用方法(一)

文章目录 一、什么是RxJava二、使用前的准备1、导入相关依赖2、字段含意3、Upstream/Downstream——上/下游4、BackPressure5、BackPressure策略6、“热” and “冷” Observables7、 基类8、事件调度器9、操作符是什么&#xff1f; 三、RxJava的简单用法1、Observable——Obse…

Docker安装教程

Docker安装教程 安装教程Centos7.6docker镜像源修改docker目录修改 Ubuntu20.04docker镜像源修改docker数据目录修改 安装教程 Centos7.6 &#x1f680;docker支持的Cetnos操作系统版本 CentOS 7 CentOS 8 (stream) CentOS 9 (stream) &#x1f680;支持的CPU ARM/X86_64 查看…

django+drf+vue 简单系统搭建 (1) - django创建项目

本系列文章为了记录自己第一个系统生成过程&#xff0c;主要使用django,drf,vue。本人非专业人士&#xff0c;此文只为记录学习&#xff0c;若有部分描述不够准确的地方&#xff0c;烦请指正。 建立这个系统的原因是因为&#xff0c;在生活中&#xff0c;很多觉得可以一两行代码…

分页存储管理、分段存储管理、段页式存储管理、两级页表

目录: 分页存储管理 基本地址存储机构 具有快表的地址存储机构 两级页表 分段存储管理 段页式管理方式 分页存储管理(重点) 首先回顾,逻辑地址和物理地址. 为什么要引入分页存储管理? 把物理地址下,离散的各个小片都利用起来,也就是在逻辑地址中看似是连续存储的,实际上对应…

ViT模型中的tokens和patches概念辨析

概念辨析 在ViT模型中&#xff0c;“tokens”&#xff08;令牌&#xff09;和"patches"&#xff08;图像块&#xff09;是两个相关但不同的概念。 令牌&#xff08;Tokens&#xff09;&#xff1a;在ViT中&#xff0c;令牌是指将输入图像分割成固定大小的图块&#…