14、保存与加载PyTorch训练的模型和超参数

news2024/11/29 2:13:38

文章目录

  • 1. state_dict
  • 2. 模型保存
  • 3. check_point
  • 4. 详细保存
  • 5. Docker
  • 6. 机器学习常用库

1. state_dict

nn.Module 类是所有神经网络构建的基类,即自己构建一个深度神经网络也是需要继承自nn.Module类才行,并且nn.Module中的state_dict包含神经网络中的权重weight ,偏置bias,过程量buffer,举例说明:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :NN_Embedding.py
# @Time      :2024/11/26 22:50
# @Author    :Jason Zhang
import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear1 = nn.Linear(3, 4)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4, 5)
        self.batch_norm = nn.BatchNorm2d(4)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        y = self.linear2(x)
        return y


if __name__ == "__main__":
    my_test = MyModel()
    my_keys = my_test.state_dict().keys()
    print(f"my_keys={my_keys}")
  • 结果:
    从结果中看出,跟说明的一样,不仅存的是weight,bias ,还有buffer
y_keys=odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'batch_norm.weight', 'batch_norm.bias', 'batch_norm.running_mean', 'batch_norm.running_var', 'batch_norm.num_batches_tracked'])

2. 模型保存

保存和加载

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :torch_save.py
# @Time      :2024/11/27 21:33
# @Author    :Jason Zhang
import torch
import torchvision.models as models

if __name__ == "__main__":
    run_code = 0
    model = models.vgg16(weights='IMAGENET1K_V1')
    torch.save(model.state_dict(), 'model_weights.pth')
    model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
    model.eval()
    torch.save(model, 'model.pth')

3. check_point

# Define model
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F


class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

4. 详细保存

在训练过程中,我们希望详细保存,以至于我们可以在中断训练中恢复训练。
保存模型

5. Docker

关于Docker方式搭建深度神经网络环境和配置
在这里插入图片描述

6. 机器学习常用库

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249448.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机网络】多路转接之poll

poll也是一种linux中的多路转接方案(poll也是只负责IO过程中的"等") 解决:1.select的fd有上限的问题;2.每次调用都要重新设置关心的fd 一、poll的使用 int poll(struct pollfd *fds, nfds_t nfds, int timeout); ① struct pollfd *fds&…

矩阵重新排列——sort函数

s o r t sort sort函数表示排序,对向量和矩阵都成立 向量 s o r t ( a ) sort(a) sort(a)将向量 a a a中元素从小到大排序 s o r t ( a , ′ d e s c e n d ′ ) sort(a,descend) sort(a,′descend′)将向量 a a a中元素从大到小排序 [ s o r t a , i d ] s o r…

深入解密 K 均值聚类:从理论基础到 Python 实践

1. 引言 在机器学习领域,聚类是一种无监督学习的技术,用于将数据集分组成若干个类别,使得同组数据之间具有更高的相似性。这种技术在各个领域都有广泛的应用,比如客户细分、图像压缩和市场分析等。聚类的目标是使得同类样本之间的…

Leetcode322.零钱兑换(HOT100)

链接 代码&#xff1a; class Solution { public:int coinChange(vector<int>& coins, int amount) {vector<int> dp(amount1,amount1);//要兑换amount元硬币&#xff0c;我们就算是全选择1元的硬币&#xff0c;也不过是amount个&#xff0c;所以初始化amoun…

【61-70期】Java面试题深度解析:从集合框架到线程安全的最佳实践

&#x1f680; 作者 &#xff1a;“码上有前” &#x1f680; 文章简介 &#xff1a;Java &#x1f680; 欢迎小伙伴们 点赞&#x1f44d;、收藏⭐、留言&#x1f4ac; 文章题目&#xff1a;Java面试题深度解析&#xff1a;从集合框架到线程安全的最佳实践 摘要&#xff1a; 本…

关闭AWS账号后,服务是否仍会继续运行?

在使用亚马逊网络服务&#xff08;AWS&#xff09;时&#xff0c;用户有时可能会考虑关闭自己的AWS账户。这可能是因为项目结束、费用过高&#xff0c;或是转向使用其他云服务平台。然而&#xff0c;许多人对关闭账户后的服务状态感到困惑&#xff0c;我们九河云和大家一起探讨…

JVM_垃圾收集器详解

1、 前言 JVM就是Java虚拟机&#xff0c;说白了就是为了屏蔽底层操作系统的不一致而设计出来的一个虚拟机&#xff0c;让用户更加专注上层&#xff0c;而不用在乎下层的一个产品。这就是JVM的跨平台&#xff0c;一次编译&#xff0c;到处运行。 而JVM中的核心功能其实就是自动…

若依框架部署在网站一个子目录下(/admin)问题(

部署在子目录下首先修改vue.config.js文件&#xff1a; 问题一&#xff1a;登陆之后跳转到了404页面问题&#xff0c;解决办法如下&#xff1a; src/router/index.js 把404页面直接变成了首页&#xff08;大佬有啥优雅的解决办法求告知&#xff09; 问题二&#xff1a;退出登录…

BERT解析

BERT项目 我在BERT添加注释和部分推理代码 main.py vocab WordVocab.load_vocab(args.vocab_path)#加载vocab那么这个加载的二进制是什么呢&#xff1f; 1. 加载数据集 继承关系&#xff1a;TorchVocab --> Vocab --> WordVocab TorchVocab 该类主要是定义了一个词…

连接共享打印机0X0000011B错误多种解决方法

打印机故障一直是一个热门话题&#xff0c;特别是共享打印机0x0000011b错误特别头疼&#xff0c;有很多网友经常遇到共享打印机0x0000011b错误。0x0000011b有更新补丁导致的、有访问共享打印机服务异常、有访问共享打印机驱动异常等问题导致的&#xff0c;针对共享打印机0x0000…

spring +fastjson 的 rce

前言 众所周知&#xff0c;spring 下是不可以上传 jsp 的木马来 rce 的&#xff0c;一般都是控制加载 class 或者 jar 包来 rce 的&#xff0c;我们的 fastjson 的高版本正好可以完成这些&#xff0c;这里来简单分析一手 环境搭建 <dependency><groupId>org.spr…

jeecgbootvue2重新整理数组数据或者添加合并数组并遍历背景图片或者背景颜色

想要实现处理后端返回数据并处理&#xff0c;添加已有静态数据并遍历快捷菜单背景图 遍历数组并使用代码 需要注意&#xff1a; 1、静态数组的图片url需要的格式为 require(../../assets/b.png) 2、设置遍历背景图的代码必须是: :style"{ background-image: url( item…

15分钟做完一个小程序,腾讯这个工具有点东西

我记得很久之前&#xff0c;我们都在讲什么低代码/无代码平台&#xff0c;这个概念很久了&#xff0c;但是&#xff0c;一直没有很好的落地&#xff0c;整体的效果也不算好。 自从去年 ChatGPT 这类大模型大火以来&#xff0c;各大科技公司也都推出了很多 AI 代码助手&#xff…

jenkins 2.346.1最后一个支持java8的版本搭建

1.jenkins下载 下载地址&#xff1a;Index of /war-stable/2.346.1 2.部署 创建目标文件夹&#xff0c;移动到指定位置 创建一个启动脚本&#xff0c;deploy.sh #!/bin/bash set -eDATE$(date %Y%m%d%H%M) # 基础路径 BASE_PATH/opt/projects/jenkins # 服务名称。同时约定部…

3D建筑模型的 LOD 规范

LOD&#xff08;细节层次&#xff09; 是3D城市建模中用于表示建筑模型精细程度的标准化描述不同的LOD适用于不同的应用场景 LOD是3D建模中重要的分级标准&#xff0c;不同层级适合不同精度和用途的需求。 从LOD0到LOD4&#xff0c;细节逐渐丰富&#xff0c;复杂性和精度也逐…

解锁 Vue 项目中 TSX 配置与应用简单攻略

在 Vue 项目中配置 TSX 写法 在 Vue 项目中使用 TSX 可以为我们带来更灵活、高效的开发体验&#xff0c;特别是在处理复杂组件逻辑和动态渲染时。以下是详细的配置步骤&#xff1a; 一、安装相关依赖 首先&#xff0c;我们需要在命令行中输入以下命令来安装 vitejs/plugin-v…

【UE5 C++课程系列笔记】04——创建可操控的Pawn

根据官方文档创建一个可以控制前后左右移动、旋转视角、缩放视角的Pawn 。 步骤 一、创建Pawn 1. 新建一个C类&#xff0c;继承Pawn类&#xff0c;这里命名为“PawnWithCamera” 2. 在头文件中申明弹簧臂、摄像机和静态网格体组件 3. 在源文件中引入组件所需库 在构造函数…

多目标优化算法——多目标粒子群优化算法(MOPSO)

Handling Multiple Objectives With Particle Swarm Optimization&#xff08;多目标粒子群优化算法&#xff09; 一、摘要&#xff1a; 本文提出了一种将帕累托优势引入粒子群优化算法的方法&#xff0c;使该算法能够处理具有多个目标函数的问题。与目前其他将粒子群算法扩展…

HTML5好看的音乐播放器多种风格(附源码)

文章目录 1.设计来源1.1 音乐播放器风格1效果1.2 音乐播放器风格2效果1.3 音乐播放器风格3效果1.4 音乐播放器风格4效果1.5 音乐播放器风格5效果 2.效果和源码2.1 动态效果2.2 源代码 源码下载万套模板&#xff0c;程序开发&#xff0c;在线开发&#xff0c;在线沟通 作者&…

通用网络安全设备之【防火墙】

概念&#xff1a; 防火墙&#xff08;Firewall&#xff09;&#xff0c;也称防护墙&#xff0c;它是一种位于内部网络与外部网络之间的网络安全防护系统&#xff0c;是一种隔离技术&#xff0c;允许或是限制传输的数据通过。 基于 TCP/IP 协议&#xff0c;主要分为主机型防火…