【现代深度学习技术】深度学习计算 | 参数管理

news2025/1/31 3:06:43

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、参数访问
      • (一)目标参数
      • (二)一次性访问所有参数
      • (三)从嵌套块收集参数
    • 二、参数初始化
      • (一)内置初始化
      • (二)自定义初始化
    • 三、参数绑定
    • 小结


  在选择了架构并设置了超参数后,我们就进入了训练阶段。此时,我们的目标是找到使损失函数最小化的模型参数值。经过训练后,我们将需要使用这些参数来做出未来的预测。此外,有时我们希望提取参数,以便在其他环境中复用它们,将模型保存下来,以便它可以在其他软件中执行,或者为了获得科学的理解而进行检查。

  之前的介绍中,我们只依靠深度学习框架来完成训练的工作,而忽略了操作参数的具体细节。本节,我们将介绍以下内容:

  • 访问参数,用于调试、诊断和可视化;
  • 参数初始化;
  • 在不同模型组件间共享参数。

  我们首先看一下具有单隐藏层的多层感知机。

import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
net(X)

在这里插入图片描述

一、参数访问

  我们从已有模型中访问参数。当通过Sequential类定义模型时,我们可以通过索引来访问模型的任意层。这就像模型是一个列表一样,每层的参数都在其属性中。如下所示,我们可以检查第二个全连接层的参数。

print(net[2].state_dict())

在这里插入图片描述

  输出的结果告诉我们一些重要的事情:首先,这个全连接层包含两个参数,分别是该层的权重和偏置。两者都存储为单精度浮点数(float32)。注意,参数名称允许唯一标识每个参数,即使在包含数百个层的网络中也是如此。

(一)目标参数

  注意,每个参数都表示为参数类的一个实例。要对参数执行任何操作,首先我们需要访问底层的数值。有几种方法可以做到这一点。有些比较简单,而另一些则比较通用。下面的代码从第二个全连接层(即第三个神经网络层)提取偏置,提取后返回的是一个参数类实例,并进一步访问该参数的值。

print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)

在这里插入图片描述

  参数是复合的对象,包含值、梯度和额外信息。这就是我们需要显式参数值的原因。除了值之外,我们还可以访问每个参数的梯度。在上面这个网络中,由于我们还没有调用反向传播,所以参数的梯度处于初始状态。

net[2].weight.grad == None

在这里插入图片描述

(二)一次性访问所有参数

  当我们需要对所有参数执行操作时,逐个访问它们可能会很麻烦。当我们处理更复杂的块(例如,嵌套块)时,情况可能会变得特别复杂,因为我们需要递归整个树来提取每个子块的参数。下面,我们将通过演示来比较访问第一个全连接层的参数和访问所有层。

print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])

在这里插入图片描述

  这为我们提供了另一种访问网络参数的方式,如下所示。

net.state_dict()['2.bias'].data

在这里插入图片描述

(三)从嵌套块收集参数

  让我们看看,如果我们将多个块相互嵌套,参数命名约定是如何工作的。我们首先定义一个生成块的函数(可以说是“块工厂”),然后将这些块组合到更大的块中。

def block1():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())

def block2():
    net = nn.Sequential()
    for i in range(4):
        # 在这里嵌套
        net.add_module(f'block {i}', block1())
    return net

rgnet = nn.Sequential(block2(), nn.Linear(4, 1))
rgnet(X)

在这里插入图片描述

  设计了网络后,我们看看它是如何工作的。

print(rgnet)

在这里插入图片描述

  因为层是分层嵌套的,所以我们也可以像通过嵌套列表索引一样访问它们。下面,我们访问第一个主要的块中、第二个子块的第一层的偏置项。

rgnet[0][1][0].bias.data

在这里插入图片描述

二、参数初始化

  知道了如何访问参数后,现在我们看看如何正确地初始化参数。我们在【深度学习基础】多层感知机 | 数值稳定性和模型初始化 中讨论了良好初始化的必要性。深度学习框架提供默认随机初始化,也允许我们创建自定义初始化方法,满足我们通过其他规则实现初始化权重。

  默认情况下,PyTorch会根据一个范围均匀地初始化权重和偏置矩阵,这个范围是根据输入和输出维度计算出的。PyTorch的nn.init模块提供了多种预置初始化方法。

(一)内置初始化

  让我们首先调用内置的初始化器。下面的代码将所有权重参数初始化为标准差为0.01的高斯随机变量,且将偏置参数设置为0。

def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=0, std=0.01)
        nn.init.zeros_(m.bias)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

在这里插入图片描述

  我们还可以将所有参数初始化为给定的常数,比如初始化为1。

def init_constant(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, 1)
        nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

在这里插入图片描述

  我们还可以对某些块应用不同的初始化方法。例如,下面我们使用Xavier初始化方法初始化第一个神经网络层,然后将第三个神经网络层初始化为常量值42。

def init_xavier(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
def init_42(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, 42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)

在这里插入图片描述

(二)自定义初始化

  有时,深度学习框架没有提供我们需要的初始化方法。在下面的例子中,我们使用以下的分布为任意权重参数 w w w定义初始化方法:

w ∼ { U ( 5 , 10 )  可能性  1 4 0  可能性  1 2 U ( − 10 , − 5 )  可能性  1 4 (1) \begin{aligned} w \sim \begin{cases} U(5, 10) & \text{ 可能性 } \frac{1}{4} \\ 0 & \text{ 可能性 } \frac{1}{2} \\ U(-10, -5) & \text{ 可能性 } \frac{1}{4} \end{cases} \end{aligned} \tag{1} w U(5,10)0U(10,5) 可能性 41 可能性 21 可能性 41(1)

  同样,我们实现了一个my_init函数来应用到net

def my_init(m):
    if type(m) == nn.Linear:
        print("Init", *[(name, param.shape) for name, param in m.named_parameters()][0])
        nn.init.uniform_(m.weight, -10, 10)
        m.weight.data *= m.weight.data.abs() >= 5

net.apply(my_init)
net[0].weight[:2]

在这里插入图片描述

  注意,我们始终可以直接设置参数。

net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]

在这里插入图片描述

三、参数绑定

  有时我们希望在多个层间共享参数:我们可以定义一个稠密层,然后使用它的参数来设置另一个层的参数。

# 我们需要给共享层一个名称,以便可以引用它的参数
shared = nn.Linear(8, 8)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), shared, nn.ReLU(),
                    shared, nn.ReLU(), nn.Linear(8, 1))
net(X)
# 检查参数是否相同
print(net[2].weight.data[0] == net[4].weight.data[0])
net[2].weight.data[0, 0] = 100
# 确保它们实际上是同一个对象,而不只是有相同的值
print(net[2].weight.data[0] == net[4].weight.data[0])

在这里插入图片描述

  这个例子表明第三个和第五个神经网络层的参数是绑定的。它们不仅值相等,而且由相同的张量表示。因此,如果我们改变其中一个参数,另一个参数也会改变。这里有一个问题:当参数绑定时,梯度会发生什么情况?答案是由于模型参数包含梯度,因此在反向传播期间第二个隐藏层(即第三个神经网络层)和第三个隐藏层(即第五个神经网络层)的梯度会加在一起。

小结

  • 我们有几种方法可以访问、初始化和绑定模型参数。
  • 我们可以使用自定义初始化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2284974.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c++ 定点 new

&#xff08;1&#xff09; 代码距离&#xff1a; #include <new> // 需要包含这个头文件 #include <iostream>int main() {char buffer[sizeof(int)]; // 分配一个足够大的字符数组作为内存池int* p new(&buffer) int(42); // 使用 placement new…

宫本茂的游戏设计思想:有趣与风格化

作为独立游戏开发者之一&#xff0c;看到任天堂宫本茂20年前的言论后&#xff0c;深感认同。 游戏研发思想&#xff0c;与企业战略是互为表里的&#xff0c;游戏是企业战略的具体战术体现&#xff0c;虚空理念的有形载体。 任天堂长盛不衰的关键就是靠简单有趣的游戏&#xf…

【AI论文】扩散对抗后训练用于一步视频生成总结

摘要&#xff1a;扩散模型被广泛应用于图像和视频生成&#xff0c;但其迭代生成过程缓慢且资源消耗大。尽管现有的蒸馏方法已显示出在图像领域实现一步生成的潜力&#xff0c;但它们仍存在显著的质量退化问题。在本研究中&#xff0c;我们提出了一种在扩散预训练后针对真实数据…

在线可编辑Excel

1. Handsontable 特点&#xff1a; 提供了类似 Excel 的表格编辑体验&#xff0c;包括单元格样式、公式计算、数据验证等功能。 支持多种插件&#xff0c;如筛选、排序、合并单元格等。 轻量级且易于集成到现有项目中。 具备强大的自定义能力&#xff0c;可以调整外观和行为…

【javaweb项目idea版】蛋糕商城(可复用成其他商城项目)

该项目虽然是蛋糕商城项目&#xff0c;但是可以复用成其他商城项目或者购物车项目 想要源码的uu可点赞后私聊 技术栈 主要为&#xff1a;javawebservletmvcc3p0idea运行 功能模块 主要分为用户模块和后台管理员模块 具有商城购物的完整功能 基础模块 登录注册个人信息编辑…

langchain基础(三)

Chain&#xff1a; 关于三个invoke&#xff1a; 提示模板、聊天模型和输出解析器都实现了langchain的runnable接口&#xff0c; 都具有invoke方法&#xff08;因为invoke方法是Runnable的通用调用方法&#xff09; 所以可以一次性调用多次invoke直接得到最终结果&#xff1a;…

Kafka 深入服务端 — 时间轮

Kafka中存在大量的延迟操作&#xff0c;比如延时生产、延时拉取和延时删除等。Kafka基于时间轮概念自定义实现了一个用于延时功能的定时器&#xff0c;来完成这些延迟操作。 1 时间轮 Kafka没有使用基于JDK自带的Timer或DelayQueue来实现延迟功能&#xff0c;因为它们的插入和…

一文掌握ADB的安装及使用

文章目录 一、什么是ADB&#xff1f;二、 安装ADB2.1 下载ADB2.2 配置环境变量 三、连接Android设备四、 常用ADB命令五、ADB高级功能5.1 屏幕截图和录制5.2 模拟按键输入5.3 文件管理5.4 系统设置管理5.5 系统操作指令5.6 日志操作指令5.7 APK操作指令5.8 设备重启和恢复 六、…

Linux系统下速通stm32的clion开发环境配置

陆陆续续搞这个已经很久了。 因为自己新电脑是linux系统无法使用keil&#xff0c;一开始想使用vscode里的eide但感觉不太好用&#xff1b;后面想直接使用cudeide但又不想妥协&#xff0c;想趁着这个机会把linux上的其他单片机开发配置也搞明白&#xff1b;而且非常想搞懂cmake…

OpenCSG月度更新2025.1

1月的OpenCSG取得了一些亮眼的成绩 在2025年1月&#xff0c;OpenCSG在产品和社区方面继续取得了显著进展。产品方面&#xff0c;推出了AutoHub浏览器自动化助手&#xff0c;帮助用户提升浏览体验&#xff1b;CSGHub企业版功能全面升级&#xff0c;现已开放试用申请&#xff0c…

AWTK 骨骼动画控件发布

Spine 是一款广泛使用的 2D 骨骼动画工具&#xff0c;专为游戏开发和动态图形设计设计。它通过基于骨骼的动画系统&#xff0c;帮助开发者创建流畅、高效的角色动画。本项目是基于 Spine 实现的 AWTK 骨骼动画控件。 代码&#xff1a;https://gitee.com/zlgopen/awtk-widget-s…

go到底是什么意思:对go的猜测或断言

go这个单词&#xff0c;简单地讲&#xff0c;表示“走或去”的意思&#xff1a; go v.去&#xff1b;走 认真想想&#xff0c;go是一个非常神秘的单词&#xff0c;g-和o-这两个字母&#xff0c;为什么就会表达“去&#xff1b;走”的意思呢&#xff1f;它的字面义或本质&…

学习数据结构(2)空间复杂度+顺序表

1.空间复杂度 &#xff08;1&#xff09;概念 空间复杂度也是一个数学表达式&#xff0c;表示一个算法在运行过程中根据算法的需要额外临时开辟的空间。 空间复杂度不是指程序占用了多少bytes的空间&#xff0c;因为常规情况每个对象大小差异不会很大&#xff0c;所以空间复杂…

DeepSeek--通向通用人工智能的深度探索者

一、词源与全称 “DeepSeek"由"Deep”&#xff08;深度&#xff09;与"Seek"&#xff08;探索&#xff09;组合而成&#xff0c;中文译名为"深度求索"。其全称为"深度求索人工智能基础技术研究有限公司"&#xff0c;英文对应"De…

Unity游戏(Assault空对地打击)开发(1) 创建项目和选择插件

目录 前言 创建项目 插件导入 地形插件 前言 这是游戏开发第一篇&#xff0c;进行开发准备。 创作不易&#xff0c;欢迎支持。 我的编辑器布局是【Tall】&#xff0c;建议调整为该布局&#xff0c;如下。 创建项目 首先创建一个项目&#xff0c;过程略&#xff0c;名字请勿…

(三)Session和Cookie讲解

目录 一、前备知识点 &#xff08;1&#xff09;静态网页 &#xff08;2&#xff09;动态网页 &#xff08;3&#xff09;无状态HTTP 二、Session和Cookie 三、Session 四、Cookie &#xff08;1&#xff09;维持过程 &#xff08;2&#xff09;结构 正式开始说 Sessi…

1.Template Method 模式

模式定义 定义一个操作中的算法的骨架&#xff08;稳定&#xff09;&#xff0c;而将一些步骤延迟&#xff08;变化)到子类中。Template Method 使得子类可以不改变&#xff08;复用&#xff09;一个算法的结构即可重定义&#xff08;override 重写&#xff09;该算法的某些特…

【PyTorch】5.张量索引操作

目录 1. 简单行、列索引 2. 列表索引 3. 范围索引 4. 布尔索引 5. 多维索引 个人主页&#xff1a;Icomi 在深度学习蓬勃发展的当下&#xff0c;PyTorch 是不可或缺的工具。它作为强大的深度学习框架&#xff0c;为构建和训练神经网络提供了高效且灵活的平台。神经网络作为…

[EAI-023] FAST: Efficient Action Tokenization for Vision-Language-Action Models

Paper Card 论文标题&#xff1a;FAST: Efficient Action Tokenization for Vision-Language-Action Models 论文作者&#xff1a;Karl Pertsch, Kyle Stachowicz, Brian Ichter, Danny Driess, Suraj Nair, Quan Vuong, Oier Mees, Chelsea Finn, Sergey Levine 论文链接&…

2025年AI手机集中上市,三星Galaxy S25系列上市

2025年被认为是AI手机集中爆发的一年&#xff0c;各大厂商都会推出搭载人工智能的智能手机。三星Galaxy S25系列全球上市了。 三星Galaxy S25系列包含S25、S25和S25 Ultra三款机型&#xff0c;起售价为800美元&#xff08;约合人民币5800元&#xff09;。全系搭载骁龙8 Elite芯…