PyTorch-网络模型的保存和读取

news2024/9/30 20:21:57

1. 模型的保存

方法一:保存模型的结构+模型的参数

陷阱:需要让文件访问到你自己的模型定义方式,可以用import的方式引入先前的模型定义。

model_save.py

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=None)
# 保存方式一
torch.save(vgg16, 'vgg16_method1.pth')

方法二:保存模型的参数(官方推荐,文件小一些)

model_save.py

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=None)
# 保存方式二 保存网络模型的参数
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')

2. 模型的加载

model_load.py(对应方法一的)

import torch

# 加载模型
model = torch.load('vgg16_method1.pth')
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Process finished with exit code 0

model_load.py(对应方法二的)

model2 = torch.load('vgg16_method2.pth')
print(model2)

OrderedDict([('features.0.weight', tensor([[[[ 0.0588, -0.0743, -0.1424],
          [-0.0034,  0.0577,  0.0819],
          [-0.0233, -0.0427,  0.1821]],

         [[ 0.0583, -0.0244,  0.0121],
          [ 0.0243, -0.0532,  0.0252],
          [-0.0372,  0.0098,  0.0754]],

         [[ 0.0480,  0.0094,  0.0544],
          [-0.0291, -0.0081,  0.0834],
          [-0.0282,  0.0537, -0.0363]]],

......

若要恢复网络模型: 

import torch
import torchvision

# 加载模型
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load('vgg16_method2.pth'))
print(vgg16)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)

......

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/585422.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux— 网络编程套接字

目录 预备知识 认识端口号 理解源端口号和目的端口号 认识TCP协议 认识UDP协议 网络字节序 socket编程接口 socket 常见API sockaddr结构 sockaddr 结构​编辑 sockaddr_in 结构 in_addr结构 地址转换函数 简单的UDP网络程序 实现一个简单的英译汉的功能 简易的远程…

注解-反射-XML配置原理

java刚开始原本是直接在方法中创建对象执行程序等,部分代码重复率高,后来就发展成方法封装调用,再后来出现的像spring框架等,引入了XML配置,使得程序更加简洁,方便等,其中XML配置也是基于java反…

java 线程安全和多线程

文章目录 前言一、ThreadLocal是什么?二、synchronized 和 ReentrantLock 都是 Java 中提供的可重入锁,二者的主要区别有以下 5 个:三、线程安全的集合类有哪些?四、说一下你对CompletableFuture的理解四、项目中是如何创建线程池…

R语言生物群落(生态)数据统计分析与绘图(从数据整理到分析结果展示)

R 语言作的开源、自由、免费等特点使其广泛应用于生物群落数据统计分析。生物群落数据多样而复杂,涉及众多统计分析方法。以生物群落数据分析中的最常用的统计方法回归和混合效应模型、多元统计分析技术及结构方程等数量分析方法为主线,通过多个来自经典…

C++类和对象三

文章目录 类和对象三初始化列表用途与特性 explicit关键字再谈构造函数static成员static的特性 友元友元函数友元函数特性 友元类友元类特性 内部类概念特性 匿名对象拷贝对象时的一些编译器优化 类和对象三 初始化列表 初始化列表:以一个冒号开始,接着…

【源码解析】Spring Bean生命周期源码解析

Spring启动核心 AbstractApplicationContext#refresh,Spring刷新容器的核心方法。最关键的就是 AbstractApplicationContext#invokeBeanFactoryPostProcessors,扫描BeanAbstractApplicationContext#finishBeanFactoryInitialization,生成Be…

【MySql】InnoDB一棵B+树可以存放多少行数据?

文章目录 背景一、怎么得到InnoDB主键索引B树的高度?二、小结三、最后回顾一道面试题总结参考资料 背景 InnoDB一棵B树可以存放多少行数据?这个问题的简单回答是:约2千万。为什么是这么多呢?因为这是可以算出来的,要搞…

[C语言实现]带你手撕带头循环双链表

目录 什么是双链表? 带头结点的优势: 双链表的实现: 什么是循环双链表? 众所周知,顺序表的插入和删除有时候需要大量移动数据,并且每次开辟空间都可能会浪费大量内存和CPU资源,于是我们有了链表,我们之…

【实用篇】SpringCloud01

SpringCloud01 1.认识微服务 随着互联网行业的发展,对服务的要求也越来越高,服务架构也从单体架构逐渐演变为现在流行的微服务架构。这些架构之间有怎样的差别呢? 1.0.学习目标 了解微服务架构的优缺点 1.1.单体架构 单体架构&#xff…

基于复旦微 JFM7K325T 全国产FPGA的高速数据采集、图像处理方案

板卡概述 PCIE-XM711 是一款基于 PCIE 总线架构的高性能数据预处理 FMC载板,板卡采用复旦微的 JFM7K325T FPGA 作为实时处理器,实现 各个接口之间的互联。该板卡可以实现 100%国产化。 板卡具有 1 个 FMC(HPC)接口,1 路…

这10道测试用例面试题,面试官肯定会问到

前言 软件测试面试中,测试用例是非常容被问到的一个点,今天小编就给大家把最常见的测试用例方面的问题给大家整理出来,希望对大家的面试提供帮助。 1、 什么是测试用例‍ 答:测试用例的设计就是如何覆盖所有软件表现出来的状态…

硬盘有多少图片取决我的网速, Python获取硬盘少女图集

目录 前言开发环境:案例实现的步骤:代码展示尾语 💝 前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! 图片数据在网络中随处可见, 如果需要一些图片素材 一睹为快把! ! ! 开发环境: 解释器: python 3.8 编辑器: pycharm 2022.3 专业版 内置模块使用 os &g…

CH341的I2C接口编程说明

CH341的I2C接口特性: 1、支持I2C速度20K/100K/400K/750K; 2、默认不支持设备的ACK应答监测,即忽略ACK状态;强制支持需修改软件; 引脚序号功能说明24SCL23SDA Windows系统SPI通讯接口函数 HANDLE WINAPI CH341OpenD…

Piccolo傻瓜式环境配置

首先下载Piccolo的代码 Piccolo下载网址点击Code然后点击Download ZIP。可能有点慢,总共要下139M。 下载CMAKE CMAKE下载网址下载完后安装CMAKE 构建环境 将下来的Piccolo-main压缩包解压为Piccolo-main文件夹。打开CMAKE,如下图进行目录选择&…

Linux - 第19节 - 网络基础(传输层二)

1.TCP相关实验 1.1.理解listen的第二个参数 在编写TCP套接字的服务器代码时,在进行了套接字的创建和绑定之后,需要调用listen函数将创建的套接字设置为监听状态,此后服务器就可以调用accept函数获取建立好的连接了。其中listen函数的第一个参…

LeetCode - 1049 最后一块石头的重量 II (0-1背包)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/130935119 LeetCode:1049. 最后一块石头的重量 II 题目:有一堆石头,用整数数组 stones 表示。其中 stones[i] 表示第 i 块石头的重量。 每一回合,从中选…

人脸识别3:C/C++ InsightFace实现人脸识别Face Recognition(含源码)

人脸识别3:C/C InsightFace实现人脸识别Face Recognition(含源码) 目录 1. 前言 2. 项目安装 (1)项目结构 (2)配置开发环境(OpenCVOpenCLbase-utilsTNN) (3)部署TNN模型 (4&a…

【C++】手把手教你模拟实现string类

模拟实现string 前言类的成员变量构造函数析构函数size和length[ ] 重载迭代器赋值运算符重载和拷贝构造函数拷贝构造函数赋值运算符重载现代式写法 reserve 和 resizereserveresize 字符串追加push_backappend insertpos位置插字符pos位置插字符串 erase>> 和 <<&…