PyTorch|保存与加载自己的模型

news2024/7/4 5:31:35

训练好一个模型之后,我们往往要对其进行保存,除非下次用时想再次训练一遍。

下面以一个简单的回归任务来详细讲解模型的保存和加载。

图片

来看这样一组数据:

x=torch.linspace(-1,1,50)x=x.view(50,1)y=x.pow(2)+0.3*torch.rand(50).view(50,1)

画图:

plt.scatter(x.numpy(),y.numpy())

图片

很显然,x与y基本呈二次函数关系,那么接下来我们就来拟合整个函数

import torchimport matplotlib.pyplot as pltimport torch.nn as nnimport torch.optim as optim

x=torch.linspace(-1,1,50)x=x.view(50,1)y=x.pow(2)+0.3*torch.rand(50).view(50,1)net1=nn.Sequential(nn.Linear(1,10),                  nn.ReLU(),                  nn.Linear(10,1))criterion=nn.MSELoss()optimizer=optim.SGD(net1.parameters(),lr=0.2)#训练模型for i in range(1000):    pred=net1(x)    loss=criterion(pred,y)    optimizer.zero_grad()    loss.backward()    optimizer.step()
#测试模型net1.eval()with torch.no_grad():    y1=net1(x)    plt.plot(x.numpy(),y1.numpy(),'r-')    plt.scatter(x.numpy(),y.numpy())

图片

结果似乎不错!

这里我们得到了一个网络net1,它可以被当作一个二次函数,用于描述之前的x,y数据的关系

得到这个网络后,我们想保存它,主要有两种方式

1,保存整个网络,包括训练后的各个层的参数

​​​​​​​

#保存整个网络,包括训练后的各个层的参数torch.save(net1,'net1weight.pkl')

2,只保存训练好的网络的参数,速度更快

​​​​​​​

#只保存训练好的网络的参数,速度更快torch.save(net1.state_dict(),'net1_params.pkl')

假设我们按第一种方式保存,那么下次想要使用次网络时需要这样做:

network=torch.load('net1weight.pkl')
#测试模型network.eval()with torch.no_grad():    y1=network(x)    plt.plot(x.numpy(),y1.numpy(),'b-')    plt.scatter(x.numpy(),y.numpy())

图片

假设我们按第二种方式保存,那么下次想要使用次网络时需要这样做:

network=nn.Sequential(nn.Linear(1,10),                  nn.ReLU(),                  nn.Linear(10,1))network.load_state_dict(torch.load('net1_params.pkl'))​​​​​​​
#测试模型network.eval()with torch.no_grad():    y1=network(x)    plt.plot(x.numpy(),y1.numpy(),'g-')    plt.scatter(x.numpy(),y.numpy())

图片

可以看出,第二次首先需要构造出一个一模一样的模型,接着再导入参数即可。当然,这只是个简单的回归模型,其它模型保存与加载同样如此。

总结一下:

模型保存与导入有两种方式:

方式一:​​​​​​​

#模型保存torch.save(net1,'net1weight.pkl')#模型导入network=torch.load('net1weight.pkl')

方式二:​​​​​​​

#模型保存torch.save(net1.state_dict(),'net1_params.pkl')#模型导入network.load_state_dict(torch.load('net1_params.pkl'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1366587.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot + Mybatis 实现多数据源原来如此简单

1、为什么需要整合多数据源 在开发的过程中,我们可能会遇到一个工程使用多个数据源的情况,总体而言分为以下几个原因 a、数据隔离:将不同的数据存储在不同的数据库中,如多租户场景 b、性能优化:将数据分散到多个数据库…

【项目实战】Cadence工具的使用2

代码覆盖率的收集 双击total,打开imc工具。total 下的文件是代码覆盖率文件 找到DUT模块!从图中可以看到代码的覆盖率已经是94.43% 添加exclude文件,注意和Synopsys的后缀不同。 导入.vRefine文件 代码覆盖率为100%。 原因是我们添加了exclu…

大学生如何当一个程序员——第三篇:热门专业学习之路5

第三篇:热门专业学习之路5 1.WEB前端快速入门2.JavaScript基础与深入解析3.jQuery应用与项目开发4.PHP、数据库编程与设计5. Http服务于Ajax编程6. 做一个阶段项目7. H5新特性与移动端开发8.高级框架9.微信小程序 各位小伙伴想要博客相关资料的话关注公众号&#xf…

LabVIEW在设备状态监测与故障诊断中的应用

在现代工业自动化领域,LabVIEW的系统设计平台在设备状态监测与故障诊断中扮演着举足轻重的角色。通过提供一个可视化和数据流编程语言,LabVIEW大大提升了设备安全监测的效率,减少了系统维护成本,同时增强了设备的可靠性和可维护性…

QT上位机开发(日志调试)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 程序开发中有很多的调试方法,比如说IDE调试,也就是设置断点、查看变量等等;比如说日志调试;比如说c…

OCP NVME SSD规范解读-5.命令超时限制-2

Sanitize清除的数据很彻底,对FTL映射表、User Data(包括已经写入NAND和仍在cache里的)、Meta Data、安全密匙、CMB中SQ/CQ相关信息、可能含有用户数据的log等等会全部清除。不过,sanitize操作不会改变RPMB、boot分区、不包含用户数据的cache等内容。 RP…

2.4 DEVICE GLOBAL MEMORY AND DATA TRANSFER

在当前的CUDA系统中,设备通常是带有自己的动态随机存取存储器(DRAM)的硬件卡。例如,NVIDIA GTX1080具有高达8 GB的DRAM,称为全局内存。我们将互换使用全局内存和设备内存这两个术语。为了在设备上执行内核,…

CSS 圆形分割按钮动画 带背景、图片

<template><view class="main"><view class="up"> <!-- 主要部分上 --><button class="card1"><image class="imgA" src="../../static/A.png"></image></button><butt…

数据库系统-甘晴void学习笔记

数据库系统笔记 计科210X 甘晴void 202108010XXX 教材&#xff1a;《数据库系统概论》第6版 &#xff08;图片来源于网络&#xff0c;侵删&#xff09; 文章目录 数据库系统<br>笔记第一篇 基础篇1 绪论1.1数据库系统概述1.2数据模型1.3数据库系统的结构(三级模式结构…

【优选算法】专题三:二分查找 --- 34. 在排序数组中查找元素的第一个和最后一个位置

从今天开始,xxxflower 带着小伙伴们一起学习算法 ~ 今天我们要写的题目是: 34. 在排序数组中查找元素的第一个和最后一个位置 以下是题目的详细解析: class Solution {public int[] searchRange(int[] nums, int target) {// 判断数组为空的情况下返回-1,-1int[] ret new in…

2023年12月 C/C++(三级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C++编程(1~8级)全部真题・点这里 第1题:因子问题 任给两个正整数N、M,求一个最小的正整数a,使得a和(M-a)都是N的因子。 时间限制:10000 内存限制:65536 输入 包括两个整数N、M。N不超过1,000,000。 输出 输出一个整数a,表示结果。如果某个案例中满足条件的正整数不存…

reiserfs文件系统的磁盘布局

reiserfs文件系统的磁盘布局比较简单&#xff0c;它把整块分区分成相同大小的block块&#xff0c;一个block块的大小默认是4K&#xff0c;而最大块数未2^32次方&#xff0c;即一个分区最大大小为16TB。 reiserfs文件系统分区的前64KB总是为分区标签&#xff08;partition labe…

《每天一分钟学习C语言·十一》static静态,内联函数,宏函数,排序

1、 static静态修饰符 static全局变量和普通全局变量区别 static全局变量只能在本文件中使用&#xff0c;普通全局变量可以在一个项目下的所有文件中使用&#xff0c;需要加extern。static全局变量只能初始化一次 static局部变量和普通局部变量区别&#xff1a; static局部变量…

【性能测试】JMeter分布式测试及其详细步骤

性能测试概要 性能测试是软件测试中的一种&#xff0c;它可以衡量系统的稳定性、扩展性、可靠性、速度和资源使用。它可以发现性能瓶颈&#xff0c;确保能满足业务需求。很多系统都需要做性能测试&#xff0c;如Web应用、数据库和操作系统等。 性能测试种类非常多&#xff0c…

盛元广通实验室业务流审批管理系统2.0

系统通过对取样、分析、数据处理、检验报告等分析全过程中多种影响因素的有效管理&#xff0c;强化检验质量&#xff0c;获得准确可靠的分析成果。业务流审批管理系统主要包括了检测管理、业务受理、样品管理、资源质量管理、分包管理、报告生成、统计分析等&#xff0c;系统能…

Spring学习 基于注解的AOP控制事务

8.1.拷贝上一章代码 8.2.applicationContext.xml <!-- 开启spring对注解事务的支持 --> <tx:annotation-driven transaction-manager"transactionManager"/> 8.3.service Service Transactional(readOnlytrue,propagation Propagation.SUPPORTS) publi…

商城小程序(7.加入购物车)

目录 一、配置vuex二、创建购物车的store模块三、在商品详情页中使用store模块四、实现购加入购物车功能五、动态统计购物车中商品的总数量六、持久化存储购物车的商品七、优化商品详情页的total侦听器八、动态为tabBar页面设置数据徽标九、将设置tabBar徽标的代码抽离为mixins…

在Spring Cloud Config Github配置中心

关于Spring Cloud系列我们其实讲解了很多&#xff0c;但是这里我们介绍一下Spring Cloud Config&#xff0c;它是一个解决分布式系统的配置管理方案&#xff0c;他包含了Client 和 Server 两个部分&#xff0c;server提供配置文件的存储&#xff0c;以接口的方式将配置文件内容…

【数据库】聊聊常见的索引优化-下

分页查询优化 主键排序 在实际的使用中&#xff0c;通过limit 10000,10 查询第10000记录到10010记录&#xff0c;mysql执行的时候是按照将前10010记录全部统计出来&#xff0c;然后剔除前10000条记录&#xff0c;选择后10条记录。这样来看的话&#xff0c;效率不高。 如果数据…