PyTorch学习笔记(十四)——网络模型的保存与读取

news2025/1/26 15:50:14

两种方式保存和加载模型

方式一

保存模型

不仅保存了网络模型的结构,也保存了网络模型的参数

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=False)
torch.save(vgg16,"vgg16_method1.pth")

加载模型

打印出的是网络模型的结构 

import torch

model = torch.load("vgg16_method1.pth")
print(model)

方式二

保存模型

网络模型的参数保存为字典,不保存网络模型的结构(官方推荐的保存方式,用的空间小)

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=False)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

 加载模型

打印出的是参数的字典形式

import torch

model = torch.load("vgg16_method2.pth")
print(model)

 如何恢复网络模型结构?

import torchvision.models
 
vgg16 = torchvision.models.vgg16(pretrained=False)  # 预训练设置为False
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))  # vgg16通过字典形式,加载状态即参数
print(vgg16)

 陷阱

问题描述

首先在 model_save.py 中写以下代码并运行

import torch
from torch import nn


# 陷阱
class Mynn(nn.Module):
    def __init__(self):
        super(Mynn, self).__init__()
        self.conv1 = nn.Conv2d(3,64,3)

    def forward(self,x):
        x = self.conv1(x)
        return x

mynn = Mynn()
torch.save(mynn,"mynn_method.pth")

再在 model_load.py 中写以下代码

import torch

# 陷阱
model = torch.load("mynn_method.pth")
print(model)

运行后报错

 解决办法1:

需要将 model_save.py 中的网络结构复制到 model_load.py 中(不需要写mynn = Mynn()),即

import torch
from torch import nn

class Mynn(nn.Module):
    def __init__(self):
        super(Mynn, self).__init__()
        self.conv1 = nn.Conv2d(3,64,3)

    def forward(self,x):
        x = self.conv1(x)
        return x

model = torch.load("mynn_method.pth")
print(model)

 解决办法2:

实际写项目过程中,直接定义在一个单独的文件中(如model_save.py),再在 model_load.py 中

from model_save import *

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/894521.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++新经典05--文件操作

文件简介 文件在程序设计中是一个比较重要的概念,这里所说的文件,是指保存在硬盘、U盘等存储介质上的数据,这些存储介质(简称磁盘)上的数据就是以一个个文件的形式体现,每一个文件有一个对应的名字&#x…

描述性统计:集中趋势和分散

一、说明 在本文中,我们将深入研究描述性统计领域,探索其不同方面,包括统计类型、总体与样本、参数与统计、数据类型以及集中趋势和离散的度量。 让我幽默地向您介绍统计数据。 “统计数据就像比基尼。他们揭示的东西是暗示性的,但…

API开放!将语聚AI智能助手接入到您的自有系统中

概述 语聚AI基于集简云强大的应用软件“连接器”能力,提供了工具延展、知识延展、模型延展和嵌入集成等一系列功能,为用户带来了更加强大和智能的AI新体验。 我们深知,每家企业对于AI应用都有自己独特的需求和应用场景,只有通过开…

接口测试之Postman 安装与使用

Postman 安装 官网下载地址 www.postman.com/downloads Postman 使用 发送get请求 新建请求 填写请求方式:GET 填写请求 URL: ceshiren.com/ httpbin.ceshiren.com/get 填写请求参数: para_key para_value 发送 POST 请求 请求方式&a…

2023年的IC求职究竟有多难?

去年应移知教育要求,写了一篇关于秋招的看法《聊一聊今年的芯片就业市场》,当时提出来的点很简单: ● 处在赛道内的人要正视竞争的难度,提升自身的企业价值分; ● 想要进入赛道的人要放平心态,降低和保留…

C++ string类的模拟实现

模拟实现string类不是为了造一个更好的轮子,而是更加理解string类,从而来掌握string类的使用 string类的接口设计繁多,故而不会全部涵盖到,但是核心的会模拟实现 库中string类是封装在std的命名空间中的,所以在模拟…

ImageKit10 VCL Crack

ImageKit10 VCL Crack ImageKit10 VCL是一个允许您快速轻松地将图像处理功能添加到应用程序中的组件。使用ImageKit10 VCL,您可以编写从TWAIN扫描仪和数码相机检索图像的应用程序;加载和保存图像文件,并将图像从一种格式转换为另一种格式;编辑图像、在图…

MySQL的Json类型字段IN查询分组和优化方法

前言 MySQL从5.7的版本开始支持Json后,我时常在设计表格时习惯性地添加一个Json类型字段,用做列的冗余。毕竟Json的非结构性,存储数据更灵活,比如接口请求记录用于存储请求参数,因为每个接口入参不一致,也…

python的交互式库Qgrid

目录 Qgrid介绍Qgrid使用Qgrid使用过程中遇到的问题解决方案 Qgrid介绍 在Jupyter notebook中直接读取DataFrame数据,只显示为静态表格的形式,没有类似于excel的筛选等交互式功能。Qgrid作为 Jupyter notebook 组件,可以为我们的 DataFrame …

三本书与三场发布会,和鲸社区重新定义编程类书籍从阅读到实践新体验

当 AI 开发者社区配备 AI 基础设施开发平台工具时,它还能做什么? 答案是:过去半年,和鲸社区凭借在气象、医学、社科等垂直领域的长期积累以及多方伙伴的支持,联合举办了三场新书发布会——从 Python 到 R 语言 、从气…

程序员与ChatGPT的交织:探索人工智能和软件开发的新篇章

目录 前言创作者程序员会被替代吗程序员如何更好的使用chatgpt 前言 在技术持续进步的当今世界,程序员与人工智能(AI)之间的关系越来越紧密。特别是对于一些创新性的技术如OpenAI旗下的ChatGPT,这种联系就更为明显。程序员与Chat…

2023/8/16 华为云OCR识别驾驶证、行驶证

目录 一、 注册华为云账号开通识别驾驶证、行驶证服务 二、编写配置文件 2.1、配置秘钥 2.2、 编写配置工具类 三、接口测试 3.1、测试接口 3.2、结果 四、实际工作中遇到的问题 4.1、前端传值问题 4.2、后端获取数据问题 4.3、使用openfeign调用接口报错 4.3、前端显示问题…

python bytes基本用法

目录 1 第一个字符变大写,其余字符变小写 capitalize() 2 生成指定长度内容,然后把指定的bytes放到中间 center() 3 计数 count() 4 解码 decode() 5 是否以指定的内容结尾 endswith() 6 将制表符调整到指定大小 expandtabs() 7 寻找指…

ref拿到组件的实例对象或者原生html标签

在组件中,或者html标签中写ref属性,就是在注册引用 可以通过ref拿到组件的实例对象 也可以通过ref拿到原生的html标签

Linux系统安装及使用HHDBCS

1 安装 1.1 下载HHDBCS 使用浏览器进入官方社区(恒辉产品社区),选择HHDBCS子社区,首页点击下载,进入下载页面; 选择官网下载/云盘下载皆可。 在弹出框中选择如图所示选项,点击下载&#xff…

带着设计思维画版图——第一次和第二次

版图设计目标: 面积小,性能好(少恶化),成本低 设计规则规定了同层与不同层之间的最小距离,因此限制了最小面积 模拟版图设计流程 第一步:设计原理图输入 常用快捷键如下: 介…

YOLO算法封装进入ros系统,识别结果供其他节点订阅

一,前期工作空间搭建 新建工作空间,第一级名称可以换,第二级src最好别换,这是ros系统的固定格式 mkdir -p workspace_yolo/src切换到工作空间 workspace_yolo,进行编译构建项目 cd workspace_yolo/catkin_make输出如下所示: 添加环境变量 cd devel/ 获取到devel文件路径…

模型预测笔记(一):数据清洗分析及可视化、模型搭建、模型训练和预测代码一体化和对应结果展示(可作为baseline)

模型预测 一、导入关键包二、如何载入、分析和保存文件三、修改缺失值3.1 众数3.2 平均值3.3 中位数3.4 0填充 四、修改异常值4.1 删除4.2 替换 五、数据绘图分析5.1 饼状图5.1.1 绘制某一特征的数值情况(二分类) 5.2 柱状图5.2.1 单特征与目标特征之间的…

花生十三 判断推理(三)分析类、推出类

分析类 题型 真假分析 定义:孰真孰假的真假话分析,命题真假无法确定,无法利用推出关系解题 解题思路 矛盾法(三种矛盾):A和非A,“A或B” 与“非A且非B” 技巧:一“找”矛盾&am…

在ARM服务器上一键安装Proxmox VE(以在Oracle Cloud VPS上为例)(甲骨文)

前言 如题,具体用到的说明文档如下 virt.spiritlhl.net 具体流程 首先是按照说明,先得看看自己的服务器符不符合安装 Proxmox VE的条件 https://virt.spiritlhl.net/guide/pve_precheck.html#%E5%90%84%E7%A7%8D%E8%A6%81%E6%B1%82 有提到硬件和软…