PyTorch|保存及加载模型、nn.Sequential、ModuleList和ModuleDict

news2024/11/25 2:29:45

系列文章目录

PyTorch|Dataset与DataLoader使用、构建自定义数据集
PyTorch|搭建分类网络实例、nn.Module源码学习
pytorch|autograd使用、训练模型

文章目录

  • 系列文章目录
  • 一、保存及加载模型
    • (一)保存及加载模型的权重
    • (二)保存及加载优化器的权重
    • (三)保存及加载整个模型
    • (四)保存及加载更具一般性的checkpoint
    • (五)保存多个模型
  • 二、nn.Sequential源码分析
    • (一)init函数
    • (二)forward函数
  • 三、ModuleList和ModuleDict
    • (一)ModuleList
    • (二)ModuleDict


一、保存及加载模型

通过torch.save可以将该模型的参数、优化器状态、batch normalization、dropout、buffer变量等信息。

import torch
import torchvision.models as models

(一)保存及加载模型的权重

模型取自torchvision.models里的vgg16,权重为IMAGENET1K_V1。

model.state_dict()是模型的权重。state_dict状态字典:一般包含当前model的参数及buffer变量

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

推理时可以实现模型的加载:

  • 创建模型实例
  • 将实现保存的模型信息通过torch.load导入进来
  • 采用load_state_dict函数将模型信息载入模型实例
  • model.eval()使得模型进入推理模式
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

(二)保存及加载优化器的权重

保存优化器权重:
在这里插入图片描述

加载优化器权重:
在这里插入图片描述

(三)保存及加载整个模型

保存整个模型:

torch.save(model, 'model.pth')

加载整个模型:

model = torch.load('model.pth')

(四)保存及加载更具一般性的checkpoint

保存并加载用于推理或恢复训练的一般性checkpoint有助于从上次中断的地方重新开始。在保存一般检查点时,不仅仅是保存模型的state_dict,还包括保存优化器的state_dict、停止使用的时间,最近记录的训练损失,外部的torch.nn.Embedding层等等。

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

加载:

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

(五)保存多个模型

保存多个模型时可以将其直接合并到一个大字典中保存。

# Specify a path to save to
PATH = "model.pt"

torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

二、nn.Sequential源码分析

nn.Sequential是有序的,当实例化nn.Sequential时,传入的模块顺序就是神经网络前向传播的顺序

在使用nn.Sequential时,可以按顺序传入模块,也可以输入一个字典。
在这里插入图片描述

(一)init函数

如果输入的是一个字典,init函数会采用遍历字典的方式,如果是一个一个的模块,init函数也会针对性的采取其他遍历方法。
在这里插入图片描述

(二)forward函数

对于一个模型的输入,nn.Sequential会依次的过其中的子模块。
在这里插入图片描述

nn.Sequential相比于ModuleList和ModuleDict来说,优势在于具有forward的功能。

三、ModuleList和ModuleDict

(一)ModuleList

pytorch允许我们把很多子模块放到一个列表中。ModuleList就是用于存放多个子模块的一个列表,在使用时可以对其进行遍历。ModuleList不单纯是一个列表,它本身就是一个module。
在这里插入图片描述

(二)ModuleDict

ModuleDict是用于存放多个子模块的一个字典,在使用时可以根据索引获得对应的子模块。ModuleDict不单纯是一个字典,它本身也是一个module。
在这里插入图片描述

除此之外,还有ParameterList、ParameterDict等,这些与ModuleList和ModuleDict的作用及使用方式类似。

参考:
8、深入剖析PyTorch的state_dict、parameters、modules源码
9、深入剖析PyTorch的nn.Sequential及ModuleList源码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1607381.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端 - 基础 表单标签 - label 标签

# label 标签 其实不属于 表单标签名单经常和 表单标签 搭配使用。 # <label> 标签 为 input 元素 定义 标注&#xff08; 标签 &#xff09; 使用场景 # 其实说白&#xff0c;<label> 标签就是为了方便用户体验的,举例说明 就是说&#xff0c;如上示&am…

软件需求开发和管理过程性指导文件

1. 目的 2. 适用范围 3. 参考文件 4. 术语和缩写 5. 需求获取的方式 5.1. 与用户交谈向用户提问题 5.1.1. 访谈重点注意事项 5.1.2. 访谈指南 5.2. 参观用户的工作流程 5.3. 向用户群体发调查问卷 5.4. 已有软件系统调研 5.5. 资料收集 5.6. 原型系统调研 5.6.1. …

【深度学习】yolov5目标检测学习与调试

2024.4.15 -2024.4.16 完结 0.准备&&补充知识点 yolo检测算法可以实现目标检测、分割和分类任务。 项目仓库地址&#xff1a;https://github.com/ultralytics/yolov5 跟练视频&#xff1a;目标检测 YOLOv5 开源代码项目调试与讲解实战 lux下载视频神器&#xff1a;h…

【氮化镓】栅极漏电对阈值电压和亚阈值摆幅影响建模

本文是一篇关于p-GaN门AlGaN/GaN高电子迁移率晶体管&#xff08;HEMTs&#xff09;的研究文章&#xff0c;发表于《应用物理杂志》&#xff08;J. Appl. Phys.&#xff09;2024年4月8日的期刊上。文章的标题为“Analysis and modeling of the influence of gate leakage curren…

从智能家居到智能城市:物联网中的隐私和安全风险

随着科技的不断进步&#xff0c;智能设备和物联网&#xff08;IoT&#xff09;技术已经逐渐渗透到我们的生活中。从智能家居设备到智能城市的实现&#xff0c;这些设备和技术可以让我们的生活变得更加便捷和高效。但是&#xff0c;这些设备也带来了不可忽视的隐私和安全风险。 …

Windows(Win10、Win11)本地部署开源大模型保姆级教程

目录 前言1.安装ollama2.安装大模型3.安装HyperV4.安装Docker5.安装聊天界面6.总结 点我去AIGIS公众号查看本文 本期教程用到的所有安装包已上传到百度网盘 链接&#xff1a;https://pan.baidu.com/s/1j281UcOF6gnOaumQP5XprA 提取码&#xff1a;wzw7 前言 最近开源大模型可谓闹…

内外网文件摆渡系统,如何贯通网络两侧被隔断的工作流?

随着业务范围不断扩大&#xff0c;产生的数据体量越来越多&#xff0c;企业会采取网络隔离&#xff0c;对核心数据进行保护。网络隔离主要目的是保护企业内部的敏感数据和系统不受外部网络攻击的风险&#xff0c;可以通过物理或逻辑方式实现&#xff0c;例如使用防火墙、网闸、…

如何让指定 Windows 程序崩溃

一、为何要把人家搞崩溃呢 看到这个标题&#xff0c;大家可能觉得奇怪&#xff0c;为什么要让指定程序崩溃呢&#xff0c;难道是想作恶吗&#xff1f;&#x1f613; 哈哈&#xff0c;绝对不是&#xff0c;真实原因是这样的。如果大家用过 Windows 电脑&#xff0c;可能见过类…

正版四月惠,MarginNote _ BookxNote _ 白描优惠啦!会场软件 5 折起

我们的老朋友数码荔枝&#xff0c;最近开启了「正版四月惠」活动&#xff01;会场精选了一批高效办公软件和系统增强工具&#xff0c;快来看看有没有你期待的那一款吧&#xff5e; 会场商品低至 5 折&#xff0c;快把它们带回家&#xff1a; MarginNote 3&#xff1a;7 折价 4…

Linux 系统下的进程间通信 IPC 入门 「下」

以下内容为本人的学习笔记&#xff0c;如需要转载&#xff0c;请声明原文链接 微信公众号「ENG八戒」https://mp.weixin.qq.com/s/IvPHnEsC6ZdIHaFL8Deazg 共享内存 我们在进程间传输比较大的数据块时&#xff0c;通常选用共享内存的方式。共享内存大小也是有限制的&#xff0…

python-django企业设备配件检修系统flask+vue

本课题使用Python语言进行开发。代码层面的操作主要在PyCharm中进行&#xff0c;将系统所使用到的表以及数据存储到MySQL数据库中&#xff0c;方便对数据进行操作本课题基于WEB的开发平台&#xff0c;设计的基本思路是&#xff1a; 前端&#xff1a;vue.jselementui 框架&#…

OpenCV杂记(2):图像拼接(hconcat, vconcat)

OpenCV杂记&#xff08;1&#xff09;&#xff1a;绘制OSD&#xff08;cv::getTextSize, cv::putText&#xff09;https://blog.csdn.net/tecsai/article/details/137872058 1. 简述 做图像处理或计算机视觉技术的同学都知道&#xff0c;我们在工作中会经常遇到需要将两幅图像拼…

李沐51_序列数据——自学笔记

1.时序模型中&#xff0c;当前数据跟之前观察到的数据相关 2.自回归模型使用自身过去数据来预测未来 3马尔可夫模型假设当前只跟最近少数数据相关&#xff0c;从而简化模型 4.潜变量模型使用潜变量来概括历史信息 生成一些数据&#xff1a;使用正弦函数和一些可加性噪声来生…

Qt/QML编程之路:carplay认证(52)

现在有些中控采用高通的芯片如8155、8295等,实现多屏互动等,但是也有一些车型走低成本方案,比如能够实现HiCar、CarLife或者苹果Apple的Carplay等能进行手机投屏就好了。 能实现CarPlay功能通过Carplay认证,也就成了一些必须的过程,国产车规级中控芯片里,开阳有一款ARK1…

Android开发——ViewPager

适配器 package com.example.myapplication; import android.view.View; import android.view.ViewGroup; import androidx.annotation.AnimatorRes; import androidx.annotation.NonNull; import androidx.viewpager.widget.PagerAdapter; import java.util.ArrayList; publi…

单链表逆置(头插法,递归,数据结构栈的应用)

链表逆置就是把最后一个数据提到最前面&#xff0c;倒数第二个放到第二个……依次类推&#xff0c;直到第一个到最后一个。 由于链表没有下标&#xff0c;所以不能借助下标来实行数据的逆置&#xff0c;要靠空间的转移来完成链表的逆置&#xff0c;这里采用没有头节点的链表来实…

SSM项目前后端分离详细说明

1.后端 1.1打包 说明&#xff1a;使用idea打开项目&#xff0c;然后进行打包。 1.2tomcat 说明&#xff1a;把后端打成war包后放入tomcat启动。 1.3启动tomcat 说明&#xff1a; 找到tomcat中bin目录中的startup.bat文件&#xff0c;进行启动。如果启动失败&#xff0c;可以…

【英文演讲】人工智能,Artificial Intelligence: A Glimpse into the Future World

文章目录 1、Power Point(演示文稿)2、Speech manuscript(演讲稿)【假】序言:在这个充满机遇与挑战的时代,人工智能正以惊人的速度改变着我们的生活与工作方式。它不仅是一种技术,更是一种全新的思维方式,引领着我们走向未来世界的新篇章。本次演讲将深入探讨人工智能对…

wechat机器人个性化维护部署修改

大家好&#xff0c;我是雄雄&#xff0c;欢迎关注微信公众号&#xff1a;雄雄的小课堂。 服务端部署配置 在新服务器上安装mysql8.0 ,redis ,nginx,emqx修改数据库的远程访问权限&#xff0c;导入数据库文件application.yml中修改redis的信息application-druild.yml中修改数据…

一文详解视觉Transformer模型压缩和加速策略(量化/低秩近似/蒸馏/剪枝)

视觉Transformer&#xff08;ViT&#xff09;在计算机视觉领域标志性地实现了一次革命&#xff0c;超越了各种任务的最先进模型。然而&#xff0c;它们的实际应用受到高计算和内存需求的限制。本研究通过评估四种主要的模型压缩技术&#xff1a;量化、低秩近似、知识蒸馏和剪枝…