深度学习——神经网络参数的保存和提取

news2024/11/15 3:49:11

代码与详细注释:

Talk is cheap. Show you the code!

import torch
import matplotlib.pyplot as plt


# 造数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)


# 保存net1
def save():
    # 使用Sequential快速搭建神经网络
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    # 优化器设置为SGD
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    # 损失函数设置为MSELoss
    loss_func = torch.nn.MSELoss()

    # 训练100步
    for t in range(100):
        # 计算预测值
        prediction = net1(x)
        # 计算预测值和真实值之间的误差
        loss = loss_func(prediction, y)
        # 将梯度设置为0
        optimizer.zero_grad()
        # 误差反向传播
        loss.backward()
        # 优化器逐步优化
        optimizer.step()

    # 绘制第一张子图
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    # 绘制原数据的散点图
    plt.scatter(x.data.numpy(), y.data.numpy())
    # 绘制回归曲线
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

    # 方法一:保存整个网络
    torch.save(net1, 'net.pkl')
    # 方法二:只保存网络参数
    torch.save(net1.state_dict(), 'net_params.pkl')


# 提取
def restore_net():
    # 方法一:提取整个网络
    net2 = torch.load('net.pkl')
    # 使用net2进行预测
    prediction = net2(x)

    # 绘制第二张子图
    plt.subplot(132)
    # 设置子图标题
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)


def restore_params():
    # 使用Sequential快速搭建一个和net1结构一致的网络net3
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    # 方法二:加载net1的参数,提取,然后赋给net3的参数
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)

    # 绘制子图3
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.show()


# net1进行预测,并且用两种方法进行保存模型
save()

# 提取整个网络到net2进行预测并绘图
restore_net()

# 提取net1的网络参数,然后赋给net3预测并绘图
restore_params()

运行结果:

在这里插入图片描述

因为网络结构和网络参数是一样的,所以训练出来的效果也是一致的!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/739691.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

unity 调用高德SDK

unity 2022.2.20f1c1 一、准备工作: 方式一:Unity打包arr 导入AndroidStudio ,AndroidStudio打包 方式二:Unity通过MainActivity.java调用SDK ,MainActivity.java 放入到Android Studio中编写代码 二、打包环境…

数字化时代,企业的数据指标体系

在社会节奏越来越快,处理的信息量越来越大的今天,传统的经营管理模式已经适应不了当下的环境。而由经验、情感组成的业务调整以及决策能力不再能正确指导企业走在正确的方向上,所以数据就成为了企业新的业务优化调整和支撑企业高层管理进行决…

关于saltstack的监控系统部署

环境 master 是centos7-linux 192.14.0.79 minios 是 windows11 192.14.0.207 下载saltstack主节点 sudo yum install salt-master下载saltstack 客户端 windows的minios配置Salt-Minion-3006.1-Py3-AMD64-Setup.exe 过程 master 端 vim /etc/salt/master.d/network.conf…

如何让一个盒子因为内容不同,而样式也不同呢

例如,每个盒子上面都有一个色块,静态,动态,岗位。如何让不同的内容就有不同的字体颜色和背景呢? 可以给每个盒子重复一样的步骤,但是显然最简单的方法是用一个循环。循环遍历数据,直接写一个盒…

《Pytorch深度学习和图神经网络(卷 1)》学习笔记——第八章

本书之后的内容与当前需求不符合不再学习 信息熵与概率的计算关系… 联合熵、条件熵、交叉熵、相对熵(KL散度)、JS散度、互信息 无监督学习 监督训练中,模型能根据预测结果与标签差值来计算损失,并向损失最小的方向进行收敛。无…

CRYPTO-36D-rsaEZ

0x00 前言 CTF 加解密合集:CTF 加解密合集 0x01 题目 给了一个秘钥,三个加密后的文件 0x02 Write Up 先获取n和e # 导入公钥 with open(r"C:\Users\wdd\Downloads\flag\fujian\public.key", "rb") as f:key RSA.import_key(f…

行业追踪,2023-07-10,汽车零部件如期调整,需要耐心等待第二波

自动复盘 2023-07-10 成交额超过 100 亿 排名靠前,macd柱由绿转红 成交量要大于均线 有必要给每个行业加一个上级的归类,这样更能体现主流方向 rps 有时候比较滞后,但不少是欲杨先抑, 应该持续跟踪,等 macd 反转时参与…

Vue简单使用及整合elementui

创建vue工程 在vue工程目录下npm install vue 下载离线vue https://v2.vuejs.org/v2/guide/installation.html 引入工程中 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" c…

C语言指针进阶

目录 0.指针初阶回顾&#xff1a; 1.字符指针 2.指针数组 3.数组指针 &#xff08;1&#xff09;数组指针的定义&#xff1a; &#xff08;2&#xff09;数组名和&数组名 &#xff08;3&#xff09;数组指针的使用 &#xff08;4&#xff09;数组指针的数组&#x…

使用均值漂移来量化带宽分类数据

均值漂移概念 均值漂移的基本概念&#xff1a;沿着密度上升方向寻找聚簇点&#xff0c;其计算过程如下&#xff1a; 1 均值漂移算法首先找到一个中心点center&#xff08;随机选择&#xff09;&#xff0c;然后根据半径划分一个范围 把这个范围内的点输入簇x的标记个数加1 2 在…

GPT和爬虫有什么区别?如何利用爬虫实现GPT功能

GPT&#xff08;Generative Pre-trained Transformer&#xff09;和爬虫是两个完全不同的概念和技术。GPT是一种基于Transformer模型的自然语言处理模型&#xff0c;用于生成文本&#xff0c;而爬虫是一种用于从互联网上收集数据的技术。 GPT是由OpenAI开发的一种深度学习模型&…

以太网之ARP协议(八)

一、概要 在网络通信中主要以IP为主机标识进行数据通信的&#xff0c;但实际的数据链路层传输以MAC地址为数据传输的节点地址。那设备之间又是如何通过IP地址确认对应主机的MAC地址的&#xff1f;这就是ARP协议的工作。 ARP是一种以目标IP地址为线索&#xff0c;用来定位下一个…

300PLC转以太网模块plc300以太网通信模块

摘要 工业通讯的发展已经迅速到了一个令人咋舌的地步&#xff0c;以太网通讯已经成为了工业通讯的主流。而今天&#xff0c;我们要介绍的是一款以太网通讯处理器——捷米特ETH-S7300-JM01&#xff0c;它不仅成熟、稳定&#xff0c;而且价格优惠&#xff0c;为工业以太网通讯领域…

Qt对地震数据(文件格式*.Segd)实现将时域数据转频域数据

文件格式以segd为例&#xff0c;其他地震文件格式同理。 时域数据 时域数据通俗点讲就是我在某个时间段记录的一个值&#xff0c;然后经过一段时间后&#xff0c;产生的一组数据就是时域数据。 频域数据 频域数据是指信号在频率域上的表示&#xff0c;即信号的频率特性。频…

PowerShell快速ssh

文件 ~/.ssh/config 内容 Host masterHostName 192.168.10.154User root访问 $ ssh master 效果 进阶 配置秘钥 待续。。。

Transform、GameObject、Rigidbody

文章目录 零、初衷和溯源一、Transform类二、GameObject类三、Rigidbody类 零、初衷和溯源 这三个类的API官方文档&#xff0c;有些杂乱——本可以把它们分门别类的整理好&#xff0c;结果却是凌乱的堆在一起&#xff0c;令人恼火。   之所以把它仨放一起总结&#xff0c;是因…

【数据挖掘】时间序列教程【十】

5.4 通用卡尔曼滤波 上一节中描述的状态空间模型作为观测方程的更一般的公式 和状态方程 这里是一个p1 向量

simulink stateFlow流程图

基础 修改分支优先级 使用matlab workspace变量 例题 输出数组输入数组的平方 for循环 使用脚本的数值 实现数组索引

2021 RoboCom 世界机器人开发者大赛-本科组(初赛)

编程题得分&#xff1a;100 总分&#xff1a;100 7-1 懂的都懂 (20分) 众所周知&#xff0c;在互联网上有很多话是不好直接说出来的&#xff0c;不过一些模糊的图片仍然能让网友看懂你在说什么。然而对这种言论依然一定要出重拳&#xff0c;所以请你实现一个简单的匹配算法。 …

图像分类论文阅读

该论文通过结合VGG-19和VIT模型,实现乳腺超声图像的分类Breast Ultrasound Images Dataset | Kaggle PyTorch VGG19复现代码 # VGG19.py import torch import torch.nn as nnclass Conv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,…