《动手学深度学习》笔记2.4——神经网络从基础→进阶 (文件读写-保存参数和模型)

news2024/9/27 5:37:45

目录

0. 前言

正文:读写文件

1. 加载和保存张量

2. 加载和保存模型参数

3. 小结


0. 前言

  • 课程全部代码(pytorch版)已上传到附件
  • 本章为原书第5章,共分为6节,本篇是5节:文件读写(保存参数和模型)
    • 第1节:《动手学深度学习》笔记2.1——神经网络从基础→进阶 (层和块 - 自定义块)-CSDN博客
  • 本节的代码位置:chapter_deep-learning-computation/read-write.ipynb
  • 本节的视频链接:读写文件_哔哩哔哩_bilibili

正文:读写文件

到目前为止,我们讨论了如何处理数据, 以及如何构建、训练和测试深度学习模型。 然而,有时我们希望保存训练的模型, 以备将来在各种环境中使用(比如在部署中进行预测)。 此外,当运行一个耗时较长的训练过程时, 最佳的做法是定期保存中间结果, 以确保在服务器电源被不小心断掉时,我们不会损失几天的计算结果。 因此,现在是时候学习如何加载和存储权重向量和整个模型了。

1. 加载和保存张量

对于单个张量,我们可以直接调用loadsave函数分别读写它们。 这两个函数都要求我们提供一个名称,save要求将要保存的变量作为输入。

In [1]:

import torch
from torch import nn
from torch.nn import functional as F
​
x = torch.arange(4)
torch.save(x, 'x-file')  # 在当前目录下,新建一个名为'x-file'的文件,把数据(权重和模型)存下来

我们现在可以将存储在文件中的数据读回内存。

In [2]:

x2 = torch.load('x-file') # 将存储在当前目录下'x-file'文件中的数据(权重和模型)读(load)回内存

x2
Out[2]:
tensor([0, 1, 2, 3])

我们可以[存储一个张量列表,然后把它们读回内存。]

In [3]:

y = torch.zeros(4)

torch.save([x, y],'x-files')  # 可以存(save)一个列表(list)
x2, y2 = torch.load('x-files')
(x2, y2)
Out[3]:
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

我们甚至可以(写入或读取从字符串映射到张量的字典)。 当我们要读取或写入模型中的所有权重时,这很方便。

In [4]:

mydict = {'x': x, 'y': y}

torch.save(mydict, 'mydict')  # 可以存(save)一个字典(dict)
mydict2 = torch.load('mydict')
mydict2
Out[4]:
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

2. 加载和保存模型参数

保存单个权重向量(或其他张量)确实有用, 但是如果我们想保存整个模型,并在以后加载它们, 单独保存每个向量则会变得很麻烦。 毕竟,我们可能有数百个参数散布在各处。 因此,深度学习框架提供了内置函数来保存和加载整个网络。 需要注意的一个重要细节是,这将保存模型的参数而不是保存整个模型。 例如,如果我们有一个3层多层感知机,我们需要单独指定架构。 因为模型本身可以包含任意代码,所以模型本身难以序列化。 因此,为了恢复模型,我们需要用代码生成架构, 然后从磁盘加载参数。 让我们从熟悉的多层感知机开始尝试一下。

In [5]:

class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)
​
    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))
​
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
 

接下来,我们[将模型的参数存储在一个叫做“mlp.params”的文件中。]

In [6]:

torch.save(net.state_dict(), 'mlp.params') # 用state_dict()得到所有参数(parameters)的、字符串到参数的映射

为了恢复模型,我们[实例化了原始多层感知机模型的一个备份。] 这里我们不需要随机初始化模型参数,而是(直接读取文件中存储的参数。)

In [7]:

clone = MLP() # 想要在别的地方用这些参数,不仅要带走参数'mlp.params',还要带走MLP的模型定义MLP()

# clone = MLP()里的参数已经被随机初始化了
clone.load_state_dict(torch.load('mlp.params'))  # 调用load_state_dict()复写(over write)掉上面初始化的参数
clone.eval()  # eval()将模型设为评估模式,返回self(就是模型本身),这里用来返回模型,看看参数写入是否成功
Out[7]:
MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

由于两个实例具有相同的模型参数,在输入相同的X时, 两个实例的计算结果应该相同。 让我们来验证一下。

In [8]:

Y_clone = clone(X)

Y_clone == Y  # 和clone之前的模型net = MLP()参数比较一下,是完全相等的,说明参数写入成功
Out[8]:
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

3. 小结

  • saveload函数可用于张量对象的文件读写。
  • 我们可以通过参数字典保存和加载网络的全部参数。
  • 保存架构必须在代码中完成,而不是在参数中完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2169039.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

kafka分区和副本的关系?

概念来一波 比如一个topic的消息存放在两个分区中,分区1和分区2.每个分区都有自己的一个副本。即比如分区1有副本1/副本2/副本3,分区2也有分区2的副本1/副本2/副本3。一个节点上的一个topic的可以由多个分区存放,但是每个分区的leader副本会尽…

丹摩智算平台部署 Llama 3.1:实践与体验

文章目录 前言部署前的准备创建实例 部署与配置 Llama 3.1使用心得总结 前言 在最近的开发工作中,我有机会体验了丹摩智算平台,部署并使用了 Llama 3.1 模型。在人工智能和大模型领域,Meta 推出的 Llama 3.1 已经成为了目前最受瞩目的开源模…

manim中文字和目标的对齐方法的使用

为什么要文字对齐 ? 对齐原则在现实生活中无处不在,比如:书籍、货架、地铁座位等等;对齐的目的其实就是在规整文案信息,对齐有利于信息传达以及视觉规范,当我们做文字编排工作时,要根据构图形…

【计算机网络 - 基础问题】每日 3 题(二十六)

✍个人博客:Pandaconda-CSDN博客 📣专栏地址:http://t.csdnimg.cn/fYaBd 📚专栏简介:在这个专栏中,我将会分享 C 面试中常见的面试题给大家~ ❤️如果有收获的话,欢迎点赞👍收藏&…

基于springboot+vue 旅游网站的设计与实现

基于springbootvue 旅游网站的设计与实现 摘 要 互联网发展至今,无论是其理论还是技术都已经成熟,而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播,搭配信息管理工具可以很好地为人们提供服务。针对信息管理混乱&#xff0c…

【中级通信工程师】终端与业务(五):市场与通信市场

【零基础3天通关中级通信工程师】 终端与业务(五):市场与通信市场 本文是中级通信工程师考试《终端与业务》科目第五章《市场与通信市场》的复习资料和真题汇总。终端与业务是通信考试中最基础的科目之一,复习重点包括通信市场的概念、通信市场结构、以…

【IOS】申请开发者账号(公司)

官网:Apple Developer (简体中文) 申请开发者账号前提 如果是第一次申请建议注册一个新的apple id作为组织的开发者账号。(确保apple id的个人信息是真实的,不能是网名或者是其他名。后续的申请步骤需要能和apple id的个人信息对上。&#…

嵌入式开发 —— IO口高阻态模式

目 录 高阻态输入模式一、区别于浮空输入二、高阻态输入的优点 高阻态输入模式 MCU管脚的高阻态模式是电路的一种输出或输入状态。在这种状态下,电路的输入端或输出端对地或对电源的电阻非常大,在实际应用中与引脚悬空类似。 电气特性 1、高电阻值&…

C#入门教程

目录 1.if分支语句 2.面向对象 3.static简单说明 1.if分支语句 我们的这个C#里面的if语句以及这个if-else语句和C语言里面没有区别,就是打这个输出上面的方式不一样,c#里面使用的是这个console.writeline这个指令,其他的这个判断逻辑都是一…

技术美术百人计划 | 《5.1.3 PBR-基于物理的灯光》笔记

1. 辐射度学 定义:辐射度学是一门以整个电磁波段的电磁辐射能测量为研究的科学。 而计算机图形学中涉及的辐射度学,则集中于整个电磁波普中光学谱段中的可见光谱段的辐射能的计算。 1.1. 立体角 概念:单位球体上的一块区域对应的球面部分的…

计算机毕业设计 中医院问诊系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

uniapp自定义底部tabBar

使用场景&#xff1a;在一个非tabbar页面&#xff0c;想要有底部导航效果&#xff0c;故自定义效果&#xff0c;系统原底部导航栏仍在正常使用 效果&#xff1a; 布局&#xff1a; <template><view class"tab-bar" :style"{height: height px}"…

《征服数据结构》哈夫曼树(Huffman Tree)

摘要&#xff1a; 1&#xff0c;哈夫曼树的介绍 2&#xff0c;哈夫曼树的构造 3&#xff0c;哈夫曼树带权路径长度计算 4&#xff0c;哈夫曼树的编码 5&#xff0c;哈夫曼树的解码 1&#xff0c;哈夫曼树的介绍 哈夫曼树(Huffman Tree)也叫霍夫曼树&#xff0c;或者赫夫曼树&am…

游戏怎么录制?王者荣耀游戏录制指南:iOS与电脑端全面教程

在王者荣耀的战场上&#xff0c;每一个五杀、每一次极限逃生都可能成为你游戏生涯中的高光时刻。但这些瞬间往往转瞬即逝&#xff0c;如何将它们永久保存&#xff0c;成为你游戏历程中不可磨灭的印记呢&#xff1f;本文将为你揭晓答案。无论你是手持iPhone的iOS用户&#xff0c…

正则中捕获组和非捕获组区别

捕获组和非捕获组 一. 捕获组&#xff08;Capturing Groups&#xff09;二. 非捕获组&#xff08;Non-Capturing Groups&#xff09;三. 区别四. 选择使用 这是我在这个网站整理的笔记,有错误的地方请指出&#xff0c;关注我&#xff0c;接下来还会持续更新。 作者&#xff1a;…

GESP等级考试C++二级-数学函数

C的cmath库中有丰富的数学函数&#xff0c;通过这些函数可以进行相应的数学计算。 1 cmath库的导入 通过import指令导入cmath库&#xff0c;代码如图1所示。 图1 导入cmath库的代码 2 abs()函数 abs()函数用来获取指定数的绝对值&#xff0c;代码如图2所示。 图2 abs()函数…

【递归】7. leetcode 404 左叶子之和

1 题目描述 题目链接&#xff1a;左叶子之和 2 解答思路 递归分为三步&#xff0c;接下来就按照这三步来思考问题 第一步&#xff1a;挖掘出相同的子问题 &#xff08;关系到具体函数头的设计&#xff09; 第二步&#xff1a;只关心具体子问题做了什么 &#xff08;关系…

macOS安装Redis教程, 通过brew命令, 时间是2024年9月26日, redis版本是0.7.2

搜索: brew search redis安装Redis: brew install redis关于启动命令的提示: To start redis now and restart at login:brew services start redis Or, if you dont want/need a background service you can just run:/opt/homebrew/opt/redis/bin/redis-server /opt/home…

【图像处理】多幅不同焦距的同一个物体的平面图象,合成一幅具有立体效果的单幅图像原理(二)

实现多幅不同焦距图像合成一幅具有立体效果的图像可以使用以下算法和开源库&#xff1a; 实现算法 图像对齐 使用特征点匹配&#xff08;如 SIFT、SURF 或 ORB&#xff09;来对齐图像。利用 RANSAC 算法剔除离群点&#xff0c;估计变换矩阵。 深度图生成 基于图像的焦距和视角…

Teams集成-会议侧边栏应用开发-会议转写

Teams应用开发&#xff0c;主要是权限比较麻烦&#xff0c;大量阅读和实践&#xff0c;摸索了几周&#xff0c;才搞明白。现将经验总结如下&#xff1a; 一、目标&#xff1a;开发一个Teams会议的侧边栏应用&#xff0c;实现会议的实时转写。 二、前提&#xff1a; 1&#x…