PyTorch|搭建分类网络实例、nn.Module源码学习

news2025/1/22 16:13:02

系列文章目录

PyTorch|Dataset与DataLoader使用、构建自定义数据集

文章目录

  • 系列文章目录
  • 一、Transforms
  • 二、构建神经网络模型
  • 三、模型分层
    • (一)模型输入
    • (二)nn.Flatten
    • (三)nn.Linear
    • (四)非线性激活函数nn.ReLU
    • (五)nn.Sequential
    • (六)nn.Softmax
    • (七)模型参数
  • 四、nn.Module源码
    • (一)init函数
    • (二)register_buffer函数
    • (三)register_parameter函数
    • (四)add_module函数、register_module函数、get_submodule函数
    • (五)get_parameter函数、get_buffer函数
    • (六)_apply函数和apply函数
    • (七)cuda函数、xpu函数、cpu函数
    • (八)type函数、float函数、double函数、half函数、bfloat16函数
    • (九)to函数、to_empty函数
    • (十)__getattr__函数、parameters函数、buffers函数、modules函数
    • (十一)_save_to_state_dict函数、state_dict函数、_load_from_state_dict函数、load_state_dict函数
    • (十二)train函数、eval函数
    • (十三)requires_grad_函数、zero_grad函数


一、Transforms

数据并不总是以训练机器学习算法所需的最终处理形式出现。Transforms是对数据的特征和标签等进行变换,使其满足神经网络的输入要求。Transforms函数一般是在Dataset中定义好,然后通过get_item应用。

  • transform:修改特征
  • target_transform:修改标签

FashionMNIST数据集为PILlmage格式,标签为整数。神经网络的训练需要将特征归一化张量,标签是单热编码张量。为了进行这些转换,我们使用ToTensor和Lambda:

  • ToTensor将PIL图像或NumPy数组转换为FloatTensor,并在范围内缩放图像的像素强度值[0,1]。
  • Lambda变换应用于任何用户定义的Lambda函数。这里定义了一个函数来将整数转换为一个独热编码张量。它首先创建一个大小为10的零张量(我们数据集中的标签数量),并调用scatter_,它在标签y给出的索引上赋值1。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))

二、构建神经网络模型

神经网络由对数据执行操作的层/模块组成。torch nn命名空间提供了构建自己的神经网络所需的所有构建块。PyTorch中的每个模块都是n. module的子类。

构建一个神经网络来对FashionMNIST数据集中的图像进行分类:

导入相关库:

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

确定训练的设备:

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

定义分类模型网络:
所有的层、网络、模型都需要继承自nn.Module父类,并且通常需要定义两个方法:

  • init方法:创建子模块,初始化神经网络层
  • forward方法:对输入数据的前向运算
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__() #调用父类init方法
        
        self.flatten = nn.Flatten() 
        # 线性relu堆叠模块
        self.linear_relu_stack = nn.Sequential( #串联模块
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x) #维度展开,铺平
        logits = self.linear_relu_stack(x)
        return logits

调用分类网络:

model = NeuralNetwork().to(device)
print(model)

在这里插入图片描述

使用分类网络:

X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits) 
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

在这里插入图片描述

三、模型分层

(一)模型输入

采用3张大小为28x28的图像的样本minibatch输入到网络中:

input_image = torch.rand(3,28,28) #随机生成一个维度为(3,28,28)的张量
print(input_image.size())

输入张量的大小:
在这里插入图片描述

(二)nn.Flatten

从start_dim维度到end_dim维度进行铺平,默认从第一维到最后一维(最终只保留第0维和其他维共两个维度)
在这里插入图片描述

调用flatten之后维度就从3x28x28转换为了3x784:

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.size())

在这里插入图片描述

(三)nn.Linear

nn.Linear包含输入维度、输出维度、偏置、设备、类型等参数,Linear层还包括weight和bias属性。
在这里插入图片描述

将维度为[3,784]的数据输入到线性层中,返回输出维度为[3,20]。

layer1 = nn.Linear(in_features=28*28, out_features=20)
hidden1 = layer1(flat_image)
print(hidden1.size())

在这里插入图片描述

(四)非线性激活函数nn.ReLU

在这里插入图片描述

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

在这里插入图片描述

(五)nn.Sequential

nn.Sequential是一个关于模块的堆叠容器。

seq_modules = nn.Sequential(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Linear(20, 10)
)
input_image = torch.rand(3,28,28)
logits = seq_modules(input_image)

(六)nn.Softmax

实例归一化:

softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)

(七)模型参数

print(f"Model structure: {model}\n\n")

for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

由于第1层和第3层是relu函数,不含参数,所以不会打印出来。
在这里插入图片描述

四、nn.Module源码

(一)init函数

默认情况下training=True
在这里插入图片描述

(二)register_buffer函数

神经网络中的“buffer”: 通常指代在网络中的某些层或操作中存储或缓存的临时数据。这些缓冲区可以在网络的前向传播和反向传播过程中被使用,以帮助网络进行参数更新和计算梯度。例如:

  • Batch Normalization 中的均值和方差: 在批量归一化层中,通常会在训练过程中计算每个批次的输入数据的均值和方差,并将它们存储在缓冲区中。这些均值和方差用于标准化输入数据,以便提高网络的训练效果。
  • 滑动平均(Exponential Moving Average,EMA): 在一些优化算法(如 Momentum、Adam 等)中,会使用滑动平均来估计参数的移动均值,以稳定优化过程。这些移动均值通常存储在缓冲区中,并在每次迭代中更新。
  • 卷积层的权重和偏置: 在卷积神经网络中,卷积层的权重和偏置通常存储在缓冲区中。这些参数在网络的训练过程中被更新,并在前向传播和反向传播中被使用。
  • 循环神经网络(RNN)中的隐藏状态: 在循环神经网络中,隐藏状态通常被存储在缓冲区中,并在每个时间步被更新。这些隐藏状态在网络的每个时间步被传递和使用。

register_buffer函数的作用: 定义一组参数,该组参数在模型训练时不会更新(即调用 optimizer.step() 后该组参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。

在这里插入图片描述

(三)register_parameter函数

Parameter和Buffer的区分:

  • 模型中需要进行更新的参数注册为Parameter,不需要进行更新的参数注册为buffer
  • 模型保存的参数是 model.state_dict() 返回的 OrderDict
  • 模型进行设备移动时,模型中注册的参数(Parameter和buffer)会同时进行移动

register_parameter函数主要用于注册一个可训练更新的参数: 将一个不可训练的类型Tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面,相当于变成了模型的一部分,成为了模型中可以根据训练进行变化的参数。

在这里插入图片描述

使用实例:
在这里插入图片描述

(四)add_module函数、register_module函数、get_submodule函数

add_module函数:往当前module中再去增加一个子模块,这个子模块会加入到self._modules字典中
在这里插入图片描述

register_module函数:用于注册模块
在这里插入图片描述

get_submodule函数:用于获取当前模块中的子模块
在这里插入图片描述

(五)get_parameter函数、get_buffer函数

get_parameter函数:根据字符串获得对应的模型参数
在这里插入图片描述
get_buffer函数:根据字符串获得对应的buffer
在这里插入图片描述

(六)_apply函数和apply函数

_apply函数:

  • 对所有的子模块调用某个function
  • 对所有的参数调用某个function
  • 对buffer变量调用某个function

apply函数:

  • 在模型参数初始化时会用到apply函数,主要作用是递归的将某个function运用到子模块
    在这里插入图片描述

(七)cuda函数、xpu函数、cpu函数

cuda函数是将所有的模型参数及buffer变量移动到gpu上
在这里插入图片描述

xpu函数、cpu函数类似,是将所有的模型参数及buffer变量移动到xpu、cpu上
在这里插入图片描述

在这里插入图片描述

这三个函数本质上都是使用的_apply函数

(八)type函数、float函数、double函数、half函数、bfloat16函数

  • type函数:将所有的参数、buffer都转化一个数据类型
  • float函数、double函数、half函数、bfloat16函数都是实现对于浮点数的转换(都是转换函数,但是只针对浮点类型)

(九)to函数、to_empty函数

  • to_empty函数:将当前模型中的参数、buffer都移动到一个设备上,但是不会拷贝存储空间
  • to函数有许多种用法:
    在这里插入图片描述

(十)__getattr__函数、parameters函数、buffers函数、modules函数

__getattr__函数中的所有_parameters、_buffers、_modules没有对子模块进行遍历,只会去对当前模块进行查找
在这里插入图片描述
而parameters函数、buffers函数及modules函数都是递归的,会返回当前module及子module的参数或者buffers。

(十一)_save_to_state_dict函数、state_dict函数、_load_from_state_dict函数、load_state_dict函数

  • _save_to_state_dict函数:把当前module的所有参数及buffers放入一个字典中

  • state_dict函数:对当前module及子module的所有参数及buffers放入一个字典中

  • _load_from_state_dict函数:从一个state_dict中得到参数及buffers中然后载入到当前的模型中

  • load_state_dict函数:递归载入所有参数及buffers

(十二)train函数、eval函数

train函数:参数设置成true就说明已经将该模型设置为训练模式(包括子模块)
在这里插入图片描述

eval函数:将模型设置为评估模式,其实就是将train函数的参数设置为false

在这里插入图片描述

(十三)requires_grad_函数、zero_grad函数

自动微分是否需要在这些参数上记录操作,换句话说就是是否需要对这个模型记录导数值
在这里插入图片描述

zero_grad函数用于清理之前的累计梯度,在训练中一般不用对模型参数进行清理,优化器中会有用到。

参考:
PyTorch官方教程
7、深入剖析PyTorch nn.Module源码
8、深入剖析PyTorch的state_dict、parameters、modules源码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1581664.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spring boot admin搭建,监控springboot程序运行状况

新建一个spring boot web项目&#xff0c;添加以下依赖 <dependency><groupId>de.codecentric</groupId><artifactId>spring-boot-admin-starter-server</artifactId><version>2.3.0</version></dependency> <dependency&…

大语言模型如何工作?

此为观看视频How Large Language Model works的笔记。 GPT&#xff08;Generative Pre-trained Transformer&#xff09;是一个大语言模型&#xff08;LLM&#xff09;&#xff0c;可以生成类似人类的文本。本文阐述&#xff1a; 什么是LLMLLM如何工作LLM的应用场景 什么是…

【算法】双指针算法

个人主页 &#xff1a; zxctscl 如有转载请先通知 题目 1. 283. 移动零1.1 分析1.2 代码 2. 1089. 复写零2.1 分析2.2 代码 3. 202. 快乐数3.1 分析3.2 代码 4. 11. 盛最多水的容器4.1 分析4.2 代码 5. LCR 179. 查找总价格为目标值的两个商品5.1 分析5.2 代码 6. 15. 三数之和…

【前端】es-drager 图片同比缩放 缩放比 只修改宽 只修改高

【前端】es-drager 图片同比缩放 缩放比 ES Drager 拖拽组件 (vangleer.github.io) 核心代码 //初始宽 let width ref(108)//初始高 let height ref(72)//以下两个变量 用来区分是单独的修改宽 还是高 或者是同比 //缩放开始时的宽 let oldWidth 0 //缩放开始时的高 let o…

JRT判断数据是否存在优化

有一种业务情况类似下图&#xff0c;质控能做的项目是仪器关联的项目。这时候维护质控物时候开通项目时候要求加载仪器项目里面的项目&#xff08;没有开通的子业务数据的部分&#xff09;。对右边已经开通的部分要求加载仪器项目里面的项目&#xff08;有开通业务子数据的部分…

概率论基础——拉格朗日乘数法

概率论基础——拉格朗日乘数法 概率论是机器学习和优化领域的重要基础之一&#xff0c;而拉格朗日乘数法与KKT条件是解决优化问题中约束条件的重要工具。本文将简单介绍拉格朗日乘数法的基本概念、应用以及如何用Python实现算法。 1. 基本概念 拉格朗日乘数法是一种用来求解…

Redis缓存穿透和缓存雪崩

一、缓存穿透 1 什么是缓存穿透 缓存穿透说简单点就是大量请求的 key 根本不存在于缓存中&#xff0c;导致请求直接到了数据库上&#xff0c;根本没有经过缓存这一层。举个例子&#xff1a;某个黑客故意制造我们缓存中不存在的 key 发起大量请求&#xff0c;导致大量请求落到数…

scFed:联邦学习用于scRNA-seq分类

scRNA-seq的出现彻底改变了我们对生物组织中细胞异质性和复杂性的理解。然而&#xff0c;大型&#xff0c;稀疏的scRNA-seq数据集的隐私法规对细胞分类提出了挑战。联邦学习提供了一种解决方案&#xff0c;允许高效和私有的数据使用。scFed是一个统一的联邦学习框架&#xff0c…

用户态网络缓冲区的设计

一、网络缓冲区 在内核中也是有网络缓冲区的&#xff0c;比如使用 read 读取数据&#xff08;read 是一种系统调用&#xff0c;第一个参数为 fd&#xff09;&#xff0c;当陷入到内核态的时候&#xff0c;会通过 fd 指定 socket&#xff0c;socket 会找到对应的接收缓冲区。在…

安装VMware ESXi虚拟机系统

简介&#xff1a;ESXi是VMware公司开发的一款服务器虚拟化操作系统。它能够在一台物理服务器上运行多个虚拟机&#xff0c;每个虚拟机都可以独立运行操作系统和应用程序&#xff0c;而且对硬件配置要求低&#xff0c;系统运行稳定。 准备工具&#xff1a; 1.8G或者8G以上容…

7-155 好玩的游戏:消消乐

消消乐是一个非常流行的手机游戏。现在游戏创意设计师Jerry突发奇想设计一个如下图所示的一维消消乐游戏,Jerry想知道游戏生成的小球布局在玩家玩的过程中最高总分能得多少,现在Jerry向资深的程序员你求助,希望你能帮助他算出每个游戏初局的最高得分。 游戏规则是这样的:…

SWM341系列应用(RTC、FreeRTOS\RTTHREAD应用和Chip ID)

SWM341系列RTC应用 22.1、RTC的时钟基准 --liuzc 2023-8-17 现象:客户休眠发现RTC走的不准&#xff0c;睡眠2小时才走了5分钟。 分析与解决&#xff1a;经过排查RTC的时钟源是XTAL_32K&#xff0c;由于睡眠时时设置XTAL->CR0&#xff1b;&#xff0c;会把XTAL_32K给关…

【DM8】外部表

外部表是指不存在于数据库中的表。 通过向达梦数据库定义描述外部表的元数据&#xff0c;可以把一个操作系统文件当成一个只读的数据库表&#xff0c;对外部表将像普通定义的表一样访问。 外部表的数据存储在操作系统文件中&#xff0c;建立外部表的时候&#xff0c;不会产生…

B02、分析GC日志-6.3

1、相关GC日志参数 -verbose:gc 输出gc日志信息&#xff0c;默认输出到标准输出-XX:PrintGC 输出GC日志。类似&#xff1a;-verbose:gc-XX:PrintGCDetails 在发生垃圾回收时打印内存回收详细的日志&#xff0c; 并在进程退出时输出当前内存各区域分配情况-XX:PrintGCTimeStamp…

深度学习-多尺度训练的介绍与应用

一、引言 在当今快速发展的人工智能领域&#xff0c;多尺度训练已经成为了一种至关重要的技术&#xff0c;特别是在处理具有复杂结构和不同尺度特征的数据时。这种技术在许多应用中发挥着关键作用&#xff0c;例如图像识别、自然语言处理和视频分析等。 多尺度训练的定义 多尺…

设计模式(22):解释器模式

解释器 是一种不常用的设计模式用于描述如何构成一个简单的语言解释器&#xff0c;主要用于使用面向对象语言开发的解释器和解释器设计当我们需要开发一种新的语言时&#xff0c;可以考虑使用解释器模式尽量不要使用解释器模式&#xff0c;后期维护会有很大麻烦。在项目中&…

4月9日学习记录

[GXYCTF 2019]禁止套娃 涉及知识点&#xff1a;git泄露&#xff0c;无参数RCE 打开环境&#xff0c;源码什么的都没有&#xff0c;扫描后台看看 扫描发现存在git泄露 用githack下载查看得到一串源码 <?php include "flag.php"; echo "flag在哪里呢&#…

django之ajax

【一】前言 Ajax 异步提交局部刷新 发送请求的方式 浏览器地址栏直接输入url回车 GET请求a标签href属性 GET请求form表单 GET请求/POST请求ajax GET请求/POST请求 ​ AJAX 不是新的编程语言&#xff0c; 而是一种使用先有标准的新方法&#xff08;比如装饰器&#xff09; …

AtCoder ABC347 A-D题解

个人感觉这次D有点难。 比赛链接:ABC347 Problem A: 签到题。 #include <bits/stdc.h> using namespace std; int main(){int N,K;cin>>N>>K;for(int i1;i<N;i){int A;cin>>A;if(A%K0)cout<<A/K;}return 0; } Problem B: 主要考substr的…

unity按路径移动

using System; using System.Collections; using System.Collections.Generic; using UnityEngine;public class FollowPathMove : MonoBehaviour {public Transform[] wayPointArray;[SerializeField] private Transform PathA;//路径点的父物体[SerializeField]private Trans…