Pytorch深度学习-----神经网络模型的保存与加载(VGG16模型)

news2024/9/24 9:25:37

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)
Pytorch深度学习-----神经网络之线性层用法
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
Pytorch深度学习-----损失函数(L1Loss、MSELoss、CrossEntropyLoss)
Pytorch深度学习-----优化器详解(SGD、Adam、RMSprop)
Pytorch深度学习-----现有网络模型的使用及修改(VGG16模型)


文章目录

  • 系列文章目录
  • 一、网络模型的保存
    • 1.方法一
    • 2.方法二
  • 二、网络模型的加载
    • 1.方法一
    • 2.方法二
  • 三、总结


一、网络模型的保存

1.方法一

保存整个模型,包括其相关的所有参数

torch.save(obj, f, pickle_protocol=DEFAULT_PROTOCOL)

参数说明:

obj: 要保存的对象,可以是模型、张量、字典等。
f: 要保存到的文件路径或文件对象。
pickle_protocol: 序列化协议的版本,默认为DEFAULT_PROTOCOL。

代码如下:

import torch
import torchvision.models as models
from torch import nn

vgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)

torch.save(vgg16_true, "vgg16_model_true.pth")

其中.pth是后缀标志。

在这里插入图片描述

2.方法二

只保存模型参数,在原有vgg16对象中使用.state_dict()方法即可。

代码如下:

import torch
import torchvision.models as models
from torch import nn

vgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)

torch.save(vgg16_true.state_dict(), "vgg16_model_true_2.pth")

在这里插入图片描述

二、网络模型的加载

1.方法一

对应于上述中保存模型的方法1进行加载。

相关函数如下:

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

参数说明:

f: 要加载的文件路径或文件对象。
map_location: 可选参数,用于指定在哪个设备上加载模型。如果不提供该参数,默认会加载到当前设备。
pickle_module: 可选参数,用于指定用于反序列化的模块。默认为pickle。
pickle_load_args: 其他可选的用于反序列化的参数。

代码如下:

import torch
import torchvision.models as models
from torch import nn

model1 = torch.load("vgg16_model_true.pth")  # 因为vgg16_model_true.pth是使用方法一保存的,故输出后是整个模型网络结构
print(model1)
model2 = torch.load("vgg16_model_true_2.pth")  # 因为vgg16_model_true_2.pth是使用方法二保存的,只保留模型参数,故输出后是整个字典类型
print(model2)

vgg16_model_true.pth加载结果

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

vgg16_model_true_2.pth加载结果

OrderedDict([('features.0.weight', tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],
          [-5.8312e-01,  3.5655e-01,  7.6566e-01],
          [-6.9022e-01, -4.8019e-02,  4.8409e-01]],

         [[ 1.7548e-01,  9.8630e-03, -8.1413e-02],
          [ 4.4089e-02, -7.0323e-02, -2.6035e-01],
          [ 1.3239e-01, -1.7279e-01, -1.3226e-01]],

         [[ 3.1303e-01, -1.6591e-01, -4.2752e-01],
          [ 4.7519e-01, -8.2677e-02, -4.8700e-01],
          [ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],


        [[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],
          [-4.2805e-01, -2.4349e-01,  2.4628e-01],
          [-2.5066e-01,  1.4177e-01, -5.4864e-03]],

         [[-1.4076e-01, -2.1903e-01,  1.5041e-01],
          [-8.4127e-01, -3.5176e-01,  5.6398e-01],
          [-2.4194e-01,  5.1928e-01,  5.3915e-01]],

         [[-3.1432e-01, -3.7048e-01, -1.3094e-01],
          [-4.7144e-01, -1.5503e-01,  3.4589e-01],
          [ 5.4384e-02,  5.8683e-01,  4.9580e-01]]],


        [[[ 1.7715e-01,  5.2149e-01,  9.8740e-03],
          [-2.7185e-01, -7.1709e-01,  3.1292e-01],
          [-7.5753e-02, -2.2079e-01,  3.3455e-01]],

         [[ 3.0924e-01,  6.7071e-01,  2.0546e-02],
          [-4.6607e-01, -1.0697e+00,  3.3501e-01],
          [-8.0284e-02, -3.0522e-01,  5.4460e-01]],

         [[ 3.1572e-01,  4.2335e-01, -3.4976e-01],
          [ 8.6354e-02, -4.6457e-01,  1.1803e-02],
          [ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],


        ...,


2.方法二

import torch
import torchvision.models as models
from torch import nn

vgg16_true = models.vgg16(weights=True)

vgg16_true.load_state_dict(torch.load("vgg16_model_true_2.pth"))  # 针对第二种加载参数的情况,使其显示完整的网络结构
print(vgg16_true)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意: 加载模型时,要确保当前代码中使用的模型类与之前保存的模型类相同。

三、总结

torch.load()是PyTorch中用于加载保存的对象的函数,可以加载之前使用torch.save()保存的模型、张量、字典等。可以指定要加载的文件路径或文件对象,并可选地指定加载到的设备、反序列化模块等参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/854701.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端先行模拟接口(mock+expres+json)

目录 mock模拟数据:data/static.js 路由:index.js 服务器:server.js yarn /node 启动服务器:yarn start 客户端:修改代理路径(修改设置后都要重启才生效) 示例 后端框架express构建服务器 前端发起请求 静态数…

Power BI中实现购物篮分析详解

一、购物篮分析简介 相信,很多人都听过沃尔玛购物篮分析的故事---“啤酒和尿布湿“,即分析购买尿布湿的顾客最喜欢购买的商品是什么?(啤酒)。在零售终端经营中,通过购物篮分析,分析不同商品之间…

Leetcode-每日一题【剑指 Offer 16. 数值的整数次方】

题目 实现 pow(x, n) ,即计算 x 的 n 次幂函数(即,xn)。不得使用库函数,同时不需要考虑大数问题。 示例 1: 输入:x 2.00000, n 10输出:1024.00000 示例 2: 输入&#…

数据挖掘全流程解析

数据挖掘全流程解析 数据指标选择 在这一阶段,使用直方图和柱状图的方式对数据进行分析,观察什么数据属性对于因变量会产生更加明显的结果。 如何绘制直方图和条形统计图 数据清洗 观察数据是否存在数据缺失或者离群点的情况。 数据异常的两种情况…

每日后端面试5题 第三天

1. 线程有哪几种状态以及各种状态之间的转换?(必会) 看图: 图片来自 线程状态转换图及其5种状态切换_小曹的blog的博客-CSDN博客 图片来自 总算把线程六种状态的转换说清楚了! - 知乎 线程一共有4种状态,分别是: 1.…

js手写贪吃蛇游戏

前端手写贪吃蛇游戏 贪吃蛇游戏 场景 使用了js 和 html /css 就可以完成 一个贪吃蛇小游戏 技术分析 主要用到的几个技术点: clientWidth :元素的宽度,包含内边距clientHeight :元素的高度,包含内边距setInterval&am…

【论文笔记】Cross Modal Transformer: Towards Fast and Robust 3D Object Detection

原文链接:https://arxiv.org/abs/2301.01283 1. 引言 受到DETR启发,本文提出鲁棒的端到端多模态3D目标检测方法CMT(跨模态Transformer)。首先使用坐标编码模块(CEM),通过将3D点集隐式地编码为多…

面试笔记:Android 架构岗,一次4小时4面的体验

作者:橘子树 此次面试一共4面4小时,中间只有几分钟间隔。对持续的面试状态考验还是蛮大的。 关于面试的心态,保持悲观的乐观主义心态比较好。面前做面试准备时保持悲观,尽可能的做足准备。面后积极做复盘,乐观的接受最…

[分享]STM32G070 串口 乱码 解决方法

硬件 NUCLEO-G070RB 工具 cubemx 解决方法 7bit 改为 8bit printf 配置方法 添加头文件 #include <stdio.h> 添加重定向代码 #ifdef __GNUC__#define PUTCHAR_PROTOTYPE int __io_putchar(int ch)#else#define PUTCHAR_PROTOTYPE int fputc(int ch, FILE *f)#endi…

安装程序报错问题解决 -2147287037 <<30005>> 2203

本文如下报错适用&#xff1a; 一、The installer has encountered an unexpected error installing this package. Thismay indicate a problem with this package. The error code is 2203 二、错误 2203.数据库&#xff1a; C:\WINDOWS\Installer\inprogressinstallinfo.i…

别找了,这7个AI绘画图软件够你用了!

AI 绘图工具最妙的是也让人人都能成为朋友圈里的“画家”&#xff0c;如果你也想要拥有一个趁手的 AI 绘画工具&#xff0c;那么就跟随本文一起来看看吧&#xff01;本文精选了7全球顶尖的AI绘图工具给大家&#xff0c;包括&#xff1a;即时灵感、Jasper Art、Images.ai、Night…

休闲卤味强势崛起:卤味零食成为新一代热门美食

随着人们生活水平的提高和消费观念的转变&#xff0c;休闲卤味逐渐成为了人们日常生活中的热门美食。据最新数据显示&#xff0c;2022年&#xff0c;我国卤味市场销售额达到了约2000亿元&#xff0c;预计到2025年将突破3000亿元大关。其中&#xff0c;休闲卤味以每年10%的速度持…

趋势洞察:中国企业高质量出海白皮书!

目前&#xff0c;我国仍处于战略发展机遇期的大背景&#xff0c; 面对全球经济放缓、不确定性增强的常态&#xff0c;国内高端市场的竞争也日趋激烈&#xff0c;对于寻求高质量发展的中国企业&#xff0c; 出海将成为重要的增长点。 今天运营坛为大家整理了一份《中国企业高质量…

弹簧阻尼系统前馈PID位置控制(PLC完整闭环仿真SCL+ST代码)

弹簧阻尼系统的前馈PID控制请参看下面文章链接: 前馈控制之如何计算前馈量(质量弹簧阻尼系统)_前馈控制量_RXXW_Dor的博客-CSDN博客带前馈控制的博途PID程序请参看下面的文章链接:首先我们看下什么是弹簧阻尼系统。1、质量弹簧阻尼模型。_前馈控制量https://rxxw-control.bl…

使用Spring五大注解来更加简单的存储Bean对象

在使用Spring框架的时候我们如果使用这种方式来存储bean对象的话未免有点太麻烦了 <bean id"xxx" class"xxx"> </bean> 为了简化存储Bean对象的操作&#xff0c;我们可以使用五大类注解来进行存储Bean对象 我们首先要在配置文件配置扫描路径…

IoTDB在springboot2中的(二) 查询

上一章我们处理的基本的构建接入&#xff0c;以及插入的处理&#xff0c;那么接下来我们进行查询的操作处理。 我们继续在IoTDBSessionConfig工具类中加入查询的方法处理 /*** description: 根据SQL查询最新一条数据* author:zgy* param sql sql查询语句&#xff0c;count查询…

JVM 类加载和垃圾回收

JVM 1. 类加载1.1 类加载过程1.2 双亲委派模型 2. 垃圾回收机制2.1 死亡对象的判断算法2.2 垃圾回收算法 1. 类加载 1.1 类加载过程 对应一个类来说, 它的生命周期是这样的: 其中前 5 步是固定的顺序并且也是类加载的过程&#xff0c;其中中间的 3 步我们都属于连接&#xf…

【Java-16】动态代理的使用方法及原理实现

代理模式&#xff1a;静态代理 目标 了解静态代理模式实现 路径 静态代理概述静态代理案例 静态代理概述 静态代理&#xff1a; 是由程序员创建或工具生成代理类的源码&#xff0c;再编译成为字节码 &#xff08;字节码文件在没有运行java之前就存在了&#xff09; 在编译…

Linux——常用命令(2)

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a;网络豆的主页​​​​​ 前期回顾 【新星计划Linux】——常用命令&#xff08;1&#xff09; 目录 一.其它常用命…

vue或uniapp使用pdf.js预览

一、先下载稳定版的pdf.js&#xff0c;可以去官网下载 官网下载地址 或 pdf.js包下载(已配置好&#xff0c;无需修改) 二、下载好的pdf.js文件放在public下静态文件里&#xff0c; uniapp是放在 static下静态文件里 三、使用方式 1. vue项目 注意路径 :src"static/pd…