6-pytorch - 网络的保存和提取

news2024/11/23 21:05:44

前言

我们训练好的网络,怎么保存和提取呢?
总不可以一直不关闭电脑吧,训练到一半,想结束到明天再来训练,这就需要进行网络的保存和提取了。
本文以前面博客3-pytorch搭建一个简单的前馈全连接层网络(回归问题)的网络进行网络的保存和提取,建议先看完上面博客再来看本博客。

一、生成训练数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# 生成数据(fake data)
x = torch.linspace(-1,1,100).reshape(-1,1)
# 加上点噪声
y = x.pow(2) + 0.2*torch.rand(x.shape)

# 可视化一下数据
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

输出:
在这里插入图片描述

二、网络保存

def save():
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(), 
        torch.nn.Linear(10,1)
    )
    optimizer = torch.optim.SGD(net1.parameters(),lr=0.5)
    loss_func = torch.nn.MSELoss()
    
    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # 下面介绍两种不同的保存方法,方法二可能运行速度要快点
    # 保存整个网络的所有
    torch.save(net1, 'net.pkl')     
    # 保存好网络的参数
    torch.save(net1.state_dict(),'net_params.pkl')
    
    # plot result
    plt.figure(1,figsize=(10,3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

【注】:保存整个网络还是保存网络参数,个人建议仅保存参数,这个速度更快。

三、网络提取

def restore_net():
    net2 = torch.load('net.pkl')
    prediction = net2(x)
    
    # plot result
    plt.subplot(132)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
    
def restore_params():
    # 如果只是保留参数的情况,提取时需要再次定义相同网络才行
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)
    
    # plot result
    plt.subplot(133)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

四、对保存网络提取进行结果展示

save()
restore_net()
restore_params()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1600988.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Level protection and deep learning

1.模拟生成的数据 import randomdef generate_data(level, num_samples):if level not in [2, 3, 4]:return Nonedata_list []for _ in range(num_samples):# 构建指定等级的数据data str(level)for _ in range(321):data str(random.randint(0, 9))data_list.append(data)…

2.4G漂移小车电子方案 酷得智能科技

漂移高速遥控车是一种专门设计用于执行高速漂移动作的遥控车模型。以下是一些关于漂移高速遥控车的功能介绍: 1、高速性能:漂移车通常配备有强力的电机和电池,以便在保持高速的同时进行漂移动作。 2、漂移能力:漂移车的轮胎和悬挂…

操作系统—实现可变式分区分配算法

文章目录 实现可变式分区分配算法1.实验环境2.如何在xv6中实现分区分配算法?(1).xv6的内存管理机制(2).实现思路 3.最佳适应算法(1).基本思路(2).步骤(3).测试&Debug 总结参考资料 实现可变式分区分配算法 1.实验环境 因为这一次的实验仍然是在xv6中进行&#…

【AIGC】AIGC在虚拟数字人中的应用:塑造未来互动体验的革新力量

🚀 🚀 🚀随着科技的快速发展,AIGC已经成为引领未来的重要力量。其中,AIGC在虚拟数字人领域的应用更是引起了广泛关注。虚拟数字人作为一种先进的数字化表达形式,结合了3D建模、动画技术、人工智能等多种先进…

PaddleOCR训练自己模型(2)----参数配置及训练

一、介绍 paddleocr分为文字定位(Det)和文字识别(Rec)两个部分 二、定位模型训练 (1)Det预训练模型下载:https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_train.tar (2)下载完之后,…

女上司问我:误删除PG百万条数据,可以闪回吗?

作者:IT邦德 中国DBA联盟(ACDU)成员,10余年DBA工作经验 擅长主流数据Oracle、MySQL、PG、openGauss运维 备份恢复,安装迁移,性能优化、故障应急处理等可提供技术业务: 1.DB故障处理/疑难杂症远程支援 2.Mysql/PG/Oracl…

32. BI - 依据淘宝的用户行为,从 0 开始实现一个简单的移动推荐系统

本文为 「茶桁的 AI 秘籍 - BI 篇 第 32 篇」 Hi, 你好。我是茶桁。 今天咱们要来完成一个简单的推荐系统的建立。 之前的课程里给大家讲了两种模型,也希望大家对模型的概念以及使用场景会有些了解。不光是推荐系统,在生物、心理学、社交网络等等里面都…

LangChain-Chatchat 开源知识库来了

LangChain-Chatchat 开源知识库来了 LangChain-Chatchat 架构设计LangChain-ChatChat 具体实现过程 一键本地离线部署软件环境硬件环境支持三种部署方式 LangChain-Chatchat 是基于 ChatGLM 等大语言模型与 LangChain 等应用框架实现,开源、可离线部署的 RAG 检索增…

安全特低电压 SELV(Safety Extra Low Voltage,缩写SELV) 是不接地系统的安全特低电压

SELV LED驱动器 市场上有很多LED灯是非隔离的,甚至还有灯条要100多伏特电压才能点亮的,安全吗? 国外多数LED驱动器标注了SELV,为什么? 安全特低电压 SELV(Safety Extra Low Voltage,缩写SELV) 是不接地系…

通过adb 命令打印安装在第三方模拟器上的log

1,环境:Windows 11 ,第三方模拟器 网易的MuMu 步骤: 1,打开cmd,输入 adb connect 172.0.0.1:7555 2,在cmd,再次输入adb logcat 回车

MongoDB的CURD(增删改查操作)

读者大大们好呀!!!☀️☀️☀️ 🔥 欢迎来到我的博客 👀期待大大的关注哦❗️❗️❗️ 🚀欢迎收看我的主页文章➡️寻至善的主页 ✈️如果喜欢这篇文章的话 🙏大大们可以动动发财的小手👉&#…

初步学习node.js文件模块

环境已安装好; 写一个read1.js如下; var fs require("fs"); var data ;// 创建一个流 var stream1 fs.createReadStream(test1.jsp); stream1.setEncoding(UTF8);// 绑定data事件 stream1.on(data, function(mydata) {data mydata; });/…

竞赛 基于GRU的 电影评论情感分析 - python 深度学习 情感分类

文章目录 1 前言1.1 项目介绍 2 情感分类介绍3 数据集4 实现4.1 数据预处理4.2 构建网络4.3 训练模型4.4 模型评估4.5 模型预测 5 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 基于GRU的 电影评论情感分析 该项目较为新颖,适合作为竞…

注册表让我重回80年代(狗头保命

现在是2024年4月16日23:09:07,今天之所以这么晚才睡,是因为遇到了一个很有意思的事情,以至于解决完之后,强挺困意,将其记录—— 缘由是想只用键盘操纵电脑,上面有写,那用winR就是家常便饭。只不…

贴片滚珠振动开关 / 振动传感器的用法

就是这种小东西: 上面的截图来自:https://item.szlcsc.com/3600130.html 以前写过一篇介绍这种东西内部的结构原理:贴片微型滚珠振动开关的结构原理。就是有个小滚珠会接通开关两边的电极,振动时滚珠会在内部蹦跳,开关…

基于Springboot的影城管理系统

基于SpringbootVue的影城管理系统的设计与实现 开发语言:Java数据库:MySQL技术:SpringbootMybatis工具:IDEA、Maven、Navicat 系统展示 用户登录 首页展示 电影信息 电影资讯 后台登录页 后台首页 用户管理 电影类型管理 放映…

RabbitMQ Stream插件使用详解

2.4版为RabbitMQ流插件引入了对RabbitMQStream插件Java客户端的初始支持。 RabbitStreamTemplateStreamListener容器 将spring rabbit流依赖项添加到项目中&#xff1a; <dependency><groupId>org.springframework.amqp</groupId><artifactId>sprin…

WebKit内核游览器

WebKit内核游览器 基础概念游览器引擎Chromium 浏览器架构Webkit 资源加载这里就不得不提到http超文本传输协议这个概念了&#xff1a; 游览器多线程HTML 解析总结 基础概念 百度百科介绍 WebKit 是一个开源的浏览器引擎&#xff0c;与之相对应的引擎有Gecko&#xff08;Mozil…

C# 字面量null对于引用类型变量和值类型变量

编译器让相同的字符串字面量共享堆中的同一内存位置以节约内存。 在C#中&#xff0c;字面量&#xff08;literal&#xff09;是指直接表示固定值的符号&#xff0c;比如数字、字符串或者布尔值。而关键字&#xff08;keyword&#xff09;则是由编程语言定义的具有特殊含义的标…

mysql 转pg 两者不同的地方

因项目数据库&#xff08;原来是MySQL&#xff09;要改成PostgreSQL。 项目里面的sql要做一些调整。 1&#xff0c;写法上的区别&#xff1a; 1&#xff0c;数据准备&#xff1a; 新建表格&#xff1a; CREATE TABLE property_config ( CODE VARCHAR(50) NULL…