Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

news2025/1/10 16:15:52

相关阅读

Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题,但发现其他参数加载成功,思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型,恰好注册时的默认值给的是0,而checkpoint中的浮点数又在0到1之间,因此出现了这个令人困惑的bug。

        下面首先复现这个bug。

import torch
import torch.nn as nn

# 定义一个简单的线性模型,参数类型为整数
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.register_buffer('test', torch.tensor(0)) # 注册一个整型张量

# 创建一个简单模型实例
model = SimpleModel()

# 创建一个浮点数作为参数
float_parameter = torch.tensor(0.6)

# 将注册名指向另一个浮点型张量
model.test = float_parameter

# 保存模型
torch.save(model.state_dict(), 'model.pth')

# 直接使用原模型加载
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint)

# 打印加载后的参数
print(model.test)

# 直接使用新模型加载
model_1 = SimpleModel()
model_1.load_state_dict(checkpoint)

# 打印加载后的参数
print(model_1.test)
输出:
tensor(0.6000)
tensor(0)

        可以看到,当模型中注册的名字(test),指向了一个类型不符的张量后,并不会导致浮点型张量被截断为整型,这是因为此处是直接使用赋值号=,使名字指向了另一个张量。

        但使用load_state_dict()方法与使用赋值号是不同的,load_state_dict()方法的实现中,调用了_load_from_state_dict()方法,其中调用了copy_()方法,进行了原位(in-place)数据替换,这可能会进行截断,下面是原位替换的一个例子。

import torch

# 创建两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5.1, 6.1], [7.1, 8.1]])

# 查看张量对象的id
print(id(a))
print(id(b))

# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())

# 将张量 b 中的值复制到张量 a 中
a.copy_(b)

# 打印复制后的结果
print(a)

# 查看张量对象的id
print(id(a))
print(id(b))

# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())
输出:
2604425272672
2604426953808  
2604511348096  
2602930352832  
tensor([[5, 6],
        [7, 8]])
2604425272672
2604426953808
2604511348096
2602930352832

        在保存了模型的状态字典后,使用load_state_dict()方法加载后,也不会有任何截断问题,因为对于原模型而言,名字test指向的是一个浮点型张量,此时原位替换,类型吻合。但是对于一个新的模型,此时的test指向的是一个整型张量,此时原位替换,会发生截断。

        因此,在注册一个张量时,需要确保其在注册时和保存时的类型吻合,此处除了指形状,还有类型,否则可能会出现意想不到的bug。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1642156.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ATTCK的优缺点分别是什么

ATT&CK(Adversarial Tactics, Techniques, and Common Knowledge)框架是一个广泛使用的资源,它提供了对网络威胁的深入洞察,特别是关于攻击者可能采取的战术、技术和程序(TTPs)。以下是ATT&CK框架的优缺点: 优点: 全面的威胁情报:ATT&CK框架详细描述了各种…

Linux基础之gcc/g++

目录 一、gcc/g的介绍 二、一个程序的翻译过程 2.1 预处理阶段 2.2 编译阶段 2.3 汇编阶段 2.4 链接阶段 三、动静态库简介 四、动静态库的优缺点 一、gcc/g的介绍 首先,先简单的介绍一下gcc/g。 GCC(GNU Compiler Collection)是一个…

MySql安装到配置-超详细版

哈喽宝子们,好久不见,大一五一有没有出去玩呀~反正我是没有出去,就5月1号那天晚上跟室友去看了个电影,然后这几天基本都在宿舍“卷”,其实也不是啦,就是学习学习,因为一方面,暑期实习…

debian10 (armbian) 配置CUPS 服务

更新apt apt-update安装相关软件 apt-get install ghostscript apt-get install dc apt-get install foomatic-db-engine apt-get install cups3.修改配置文件 nano /etc/cups/cupsd.conf Listen localhost:631改为 Listen 0.0.0.0:631 以下四段配置加入Allow All # Only li…

大模型公开可用的模型检查点或 API

文章目录 公开可用的模型检查点或 APILLaMA 变体系列大语言模型的公共 API 公开可用的模型检查点或 API 众所周知,大模型预训练是一项对计算资源要求极高的任务。因此,经过预训练的公开模型检查点(Model Checkpoint)对于推动大语言…

2024牛客五一集训派对day2 Groundhog Looking Dowdy 个人解题思路

前言: 被实验室教练要求要打的这次五一牛客的训练赛,这些区域赛难度的题对于大一的我来说难度实在是太高了,我和我的队友只写了一些非常简单的签到题,其他题目都没怎么看(我们太弱了),但我可以分…

Spring Cloud学习笔记(Hystrix):execute,queue,observe,toObservable样例和特性

这是本人学习的总结,主要学习资料如下 - 马士兵教育 1、Overview2、execute()2.1、Overview2.2、示例 3、queue()3.1、Overview3.2、示例 4、observe()4.1、Overview4.2、示例 5、toObservable()5.1、observe()和toObservable()的区别 1、Overview 我们知道Hystrix…

一文看懂卷积神经网络CNN(1)—前馈神经网络

目录 参考资料 一、神经网络 1、人脑神经网络 2、人工神经网络 3、神经网络的发展历史 二、前馈神经网络 1、神经元 (1)Sigmoid型函数 ① Logistic函数 ②Tanh函数 ③两个函数形状对比 (2)ReLU函数 ① 带泄露的ReLU函…

vue 设置输入框只能输入数字且只能输入小数点后两位,并且不能输入减号

<el-input v-model.trim"sb.price" placeholder"现价" class"input_w3" oninput"valuevalue.replace(/[^0-9.]/g,).replace(/\.{2,}/g,.).replace(/^(\-)*(\d)\.(\d\d).*$/,$1$2.$3)"/> 嘎嘎简单、、、、、、、、、

微软如何打造数字零售力航母系列科普08 - Yobe 如何联手微软Azure,安全使用客户数据,预测客户购买行为?

Yobe 如何联手Azure&#xff0c;安全使用客户数据&#xff0c;预测客户购买行为&#xff1f; 在当今数据驱动的世界中&#xff0c;了解客户行为并有能力通过数据和分析预测客户意图是企业保持竞争力所应具备的首要优势。Yobi由Max Snow、Bill Wise和Tom Griffiths于2019年创立&…

【软考高项】三十一、成本管理4个过程

一、规划成本管理 1、定义、作用 定义&#xff1a;确定如何估算、预算、管理、监督和控制项目成本的过程作用&#xff1a;在整个项目期间为如何管理项目成本提供指南和方向 应该在项目规划阶段的早期就对成本管理工作进行规划&#xff0c;建立各成本管理过程的基本框架&…

题目:极速返航

问题描述&#xff1a; 解题思路&#xff1a; 看到题目要求最大值最小&#xff0c;最小值最大&#xff1a;一眼二分答案。二分的时间复杂度是O(log n)。 二分枚举可能的答案X。check()函数判断合法情况。 AC代码&#xff1a; #include<bits/stdc.h> using namespace std…

算法课程笔记——蓝桥云课第六次直播

&#xff08;只有一个数&#xff0c;或者因子只有一个&#xff09;先自己打表&#xff0c;找找规律函数就是2的n次方 异或前缀和 相等就抵消 先前缀和再二分

OpenCV(一) —— OpenCV 基础

1、OpenCV 简介 OpenCV&#xff08;Open Source Computer Vision Library&#xff09;是一个基于 BSD 许可开源发行的跨平台的计算机视觉库。可用于开发实时的图像处理、计算机视觉以及模式识别程序。由英特尔公司发起并参与开发&#xff0c;以 BSD 许可证授权发行&#xff0c…

【深度学习】位置编码

一、引言 Self-Attention并行的计算方式未考虑输入特征间的位置关系&#xff0c;这对NLP来说是不可接受的&#xff0c;毕竟一个句子中每个单词都有着明显的顺序关系。Transformer没有RNN、LSTM那样的顺序结构&#xff0c;所以Transformer在提出Self-Attention的同时提出了Posi…

ruoyi漏洞总结

若依识别 黑若依 :icon hash"-1231872293 绿若依 :icon hash"706913071” body" 请通过前端地址访 " body" 认证失败&#xff0c;无法访问系统资源 " 如果页面访问显示不正常&#xff0c;可添加默认访问路径尝试是否显示正常 /login?redi…

STM32 F103C8T6学习笔记17:类IIC通信(SMBus协议)—MLX90614红外非接触温度计

今日学习配置MLX90614红外非接触温度计 与 STM32 F103C8T6 单片机的通信 文章提供测试代码讲解、完整工程下载、测试效果图 本文需要用到的大概基础知识&#xff1a;1.3寸OLED配置通信显示、IIC通信、 定时器配置使用 这里就只贴出我的 OLED驱动方面的网址链接了&#xff1a…

QT+串口调试助手+基本版

一、创建串口调试助手UI界面 1、首先生成串口连接必要参数界面&#xff0c;删除关闭串口控件 2、给参数下拉框添加常见的选项&#xff0c;删除关闭串口控件 3、将串口调试助手参数界面布局整齐&#xff0c;删除关闭串口控件 4、更改控件名字&#xff0c;方便后续编程&#xff…

深度学习中权重初始化的重要性

深度学习模型中的权重初始化经常被人忽略&#xff0c;而事实上这是非常重要的一个步骤&#xff0c;模型的初始化权重的好坏关系到模型的训练成功与否&#xff0c;以及训练速度是否快速&#xff0c;效果是否更好等等&#xff0c;这次我们专门来看看深度学习中的权重初始化问题。…

stm32之hal库串口中断和ringbuffer的结合

前言 结合hal库封装的中断处理函数使用rt-thread内部的rt-ringbuffer数据结构源码改造hal库串口部分的源码&#xff0c;将内部静态方法变为弱引用的函数&#xff0c;方便重写标志位采用信号量或变量的两种方式&#xff0c;内部数据分配方式采用动态和静态两种方式 hal库部分串…