PyTorch之nn.Module、nn.Sequential、nn.ModuleList使用详解

news2025/1/17 2:56:11

文章目录

  • 1. nn.Module
    • 1.1 基本使用
    • 1.2 常用函数
      • 1.2.1 核心函数
      • 1.2.2 查看函数
      • 1.2.3 设置函数
      • 1.2.4 注册函数
      • 1.2.5 转换函数
      • 1.2.6 加载函数
  • 2. nn.Sequential()
    • 2.1 基本定义
    • 2.2 Sequential类不同的实现
    • 2.3 nn.Sequential()的本质作用
  • 3. nn.ModuleList
  • 参考资料

本篇文章主要介绍 torch.nn.Moduletorch.nn.Sequential()torch.nn.ModuleList 的使用方法与区别。

1. nn.Module

1.1 基本使用

在PyTorch中,nn.Module 类扮演着核心角色,它是构建任何自定义神经网络层、复杂模块或完整神经网络架构的基础构建块。通过继承 nn.Module 并在其子类中定义模型结构和前向传播逻辑(forward() 方法),开发者能够方便地搭建并训练深度学习模型。

在自定义一个新的模型类时,通常需要:

  • 继承 nn.Module
  • 重新实现 __init__ 构造函数
  • 重新实现 forward 方法

实现代码如下:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    def __init__(self):
        super(Model, self).__init__()   # 等价与nn.Module.__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
	def forward(self, x):
		x = F.relu(self.conv1(x))
		return F.relu(self.conv2(x))
    
   
model=Model()
print(model)

输出如下:

在这里插入图片描述

注意:

  • 一般把网络中具有可学习参数的层(如全连接层、卷积层)放在构造函数 __init__()
  • forward() 方法必须重写,它是实现模型的功能,实现各个层之间连接关系的核心

nn.Module类中的关键属性和方法包括:

  1. 初始化 (init):在类的初始化方法中定义并实例化所有需要的层、参数和其他组件。
    在实现自己的MyModel类时继承了nn.Module,在构造函数中要调用Module的构造函数 super(MyModel,self).init()

  2. 前向传播 (forward):实现前向传播函数来描述输入数据如何通过网络产生输出结果。
    因为parameters是自动求导,所以调用forward()后,不用自己写和调用backward()函数。而且一般不是显式的调用forward(layer.farword),而是layer(input),会自执行forward()。

  3. 管理参数和模块

  • 使用 .parameters() 访问模型的所有可学习参数。
  • 使用 add_module() 添加子模块,并给它们命名以便于访问。
  • 使用 register_buffer() 为模型注册非可学习的缓冲区变量。
  1. 训练与评估模式切换
  • 使用 model.train() 将模型设置为训练模式,这会影响某些层的行为,如批量归一化层和丢弃层。
  • 使用 model.eval() 将模型设置为评估模式,此时会禁用这些依赖于训练阶段的行为。
  1. 保存和加载模型状态
  • 调用 model.state_dict() 获取模型权重和优化器状态的字典形式。
  • 使用 torch.save() 和 torch.load() 来保存和恢复整个模型或者仅其状态字典。
  • 通过 model.load_state_dict(state_dict) 加载先前保存的状态字典到模型中。

此外,nn.Module 还提供了诸如移动模型至不同设备(CPU或GPU)、零化梯度等实用功能,这些功能在整个模型训练过程中起到重要作用。

1.2 常用函数

torch.nn.Module 这个类的内部有多达 48 个函数,下面就一些比较常用的函数进行讲解。

1.2.1 核心函数

  • __init__ 函数 和 forward() 函数
    __init__中主要是初始化一些内部需要用到的state;forward在这里没有具体实现,是需要在各个子类中实现的,如果子类中没有实现就会报错raise NotImplementedError。

  • apply(fn) 函数
    将Module及其所有的SubModule传进给定的fn函数操作一遍。我们可以用这个函数来对Module的网络模型参数用指定的方法初始化。下边这个例子就是将网络模型net中的子模型Linear的参数全部赋值为 1 。

def init_weight(m):
    if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        m.bias.data.fill_(0)


net = nn.Sequential(nn.Linear(2, 2))
net.apply(init_weight)

输出如下:
在这里插入图片描述

  • state_dict() 函数
    返回一个包含module的所有state的dictionary,而这个字典的Keys对应的就是parameter和buffer的名字names。该函数的源码部分有一个循环可以递归遍历Module中所有的SubModule。
net = torch.nn.Linear(2, 2)
print(net.state_dict())

输出如下:

在这里插入图片描述

print(net.state_dict().keys())

在这里插入图片描述

  • add_module()函数
    将子模块加入当前模块中,被添加的模块可以用name来获取

1.2.2 查看函数

使用 nn.Module 中的查看类函数可以对网络中的参数进行有效管理,常用的查看类参数如下:

parameters()  #返回一个包含模型所有参数的迭代器
buffers() 
children()  # 返回当前模型子模块的迭代器
modules()  # 返回一个包含当前模型所有模块的迭代器

与之对应的四个函数:

named_parameters()
named_buffers()
namde_children()
named_modules() 
  • parameters() 函数
    可以使用for param in model.parameters()来遍历网络模型中的参数,因为该函数返回的是一个迭代器iterator。我们在使用优化算法的时候就是将model.parameters()传给优化器Optimizer。
net = nn.Sequential(nn.Linear(2, 2))
params = list(net.parameters())
print(params)

输出如下:
在这里插入图片描述

  • buffers 函数、 children 函数 和 modules 函数
    与parameters()函数类似。

  • named_parameters() 函数

net = nn.Sequential(nn.Linear(2, 2))
print(type(net.named_parameters()))
for name, params in net.named_parameters():
    print(name, params)

输出如下:

在这里插入图片描述

  • named_buffers 函数、 named_children 函数 和 named_modules 函数
    与named_parameters()函数类似。

1.2.3 设置函数

设置类包含包括设置模型的训练/测试状态、梯度设置、设备设置等。

  • train() 函数 和 eval() 函数
    • train(): 将Module及其SubModule设置为training mode
    • eval(): 将Module及其SubModule设置为evaluation mode

这两个函数只对特定的Module有影响,例如Class Dropout、Class BatchNorm。

  • requires_grad() 函数 和 zero_grad()函数

    • 设置self.parameters()是否需要record梯度,默认情况下是True。
    • 函数zero_grad 用于设置self.parameters()的gradients为零。
  • cuda() 函数 和 cpu()函数

    • cuda(): Moves all model parameters and buffers to the GPU.
    • cpu(): Moves all model parameters and buffers to the CPU.

两者返回的都是Module本身且都调用了_apply函数。

  • to() 函数
    函数to的作用是原地 ( in-place ) 修改Module,它可以当成三种函数来使用:
    • to(device=None, dtype=None, non_blocking=False):设备
    • to(dtype, non_blocking=False):类型
    • to(tensor, non_blocking=False): 张量

基于nn.Modeule构建Linear层:

linear = nn.Linear(2, 2)
print(linear.weight)
# Parameter containing:
# tensor([[ 0.4331,  0.6347],
#         [ 0.5735, -0.0210]], requires_grad=True)

修改参数类型:

linear.to(torch.double)
print(linear.weight)
# Parameter containing:
# tensor([[ 0.4331,  0.6347],
#         [ 0.5735, -0.0210]], dtype=torch.float64, requires_grad=True)

修改设备类型:

gpu1 = torch.device("cuda:1")
linear.to(gpu1, dtype=torch.half, non_blocking=True)
# Linear(in_features=2, out_features=2, bias=True)

print(linear.weight)
# Parameter containing:
# tensor([[ 0.4331, 0.6347],
#         [ 0.5735, -0.0210]], dtype=torch.float16, device='cuda:1')
cpu = torch.device("cpu")
linear.to(cpu)
# Linear(in_features=2, out_features=2, bias=True)

print(linear.weight)
# Parameter containing:
# tensor([[ 0.4331, 0.6347],
#         [0.5735, -0.0210]], dtype=torch.float16)

1.2.4 注册函数

register_parameter   # 向self._parameters注册新元素
register_buffer      # 向self._buffers注册新元素

register_backward_hook   # 向self._backward_hook注册新元素
register_forward_pre_hook   # 向self._forward_pre_hook注册新元素
register_forward_hook   # 向self._forward_hook注册新元素

1.2.5 转换函数

to()  # 转换为张量,设置类型、设备等
type()  # 将parameters和buffers的数据类型转换为目标类型dst_type
double()  # 将parameters和buffers的数据类型转换为double
float() # 将parameters和buffers的数据类型转换为float
half()  # 将parameters和buffers的数据类型转换为half

1.2.6 加载函数

可以很方便的进行 save 和 load,以防止突然发生的断点和系统崩溃现象

load_state_dict(state_dict, strict=True)
# 将state_dict中的参数和缓冲区复制到此模块及其后代中。如果strict为真,则state_dict的键必须与该模块的state_dict()函数返回的键完全匹配。

"""
state_dict (dict) – 保存parameters和persistent buffers的字典。
将state_dict中的parameters和buffers复制到此module和它的后代中。

state_dict中的key必须和 model.state_dict()返回的key一致。
"""

2. nn.Sequential()

nn.Sequential()是一个序列容器,用于搭建神经网络的模块按照被传入构造器的顺序添加到nn.Sequential()容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。

2.1 基本定义

先简单看一下它的定义:

class Sequential(Module): # 继承Module
    def __init__(self, *args):  # 重写了构造函数
    def _get_item_by_idx(self, iterator, idx):
    def __getitem__(self, idx):
    def __setitem__(self, idx, module):
    def __delitem__(self, idx):
    def __len__(self):
    def __dir__(self):
    def forward(self, input):  # 重写关键方法forward

2.2 Sequential类不同的实现

方法一:最简单的序列模型

import torch.nn as nn

model = nn.Sequential(
          nn.Conv2d(1, 20, 5),
          nn.ReLU(),
          nn.Conv2d(20, 64, 5),
          nn.ReLU()
        )
# 采用第一种方式,默认命名方式为  [0,1,2,3,4,...]
print(model, '\n')
print(model[2]) # 通过索引获取第几个层

输出如下:
在这里插入图片描述

在每一个包装块里面,各个层是没有名称的,层的索引默认按照0、1、2、3、4来排名。

方法二:有序字典(给每一个层添加名称)

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))
print(model, '\n')
print(model[2]) # 通过索引获取第几个层

输出如下:
在这里插入图片描述

很多人认为python中的字典是无序的,因为它是按照hash来存储的,但是python中有个模块collections(英文,收集、集合),里面自带了一个子类OrderedDict,实现了对字典对象中元素的排序。
从上面的结果中可以看出,这个时候每一个层都有了自己的名称,但是此时需要注意,并不能够通过名称直接获取层,依然只能通过索引index,即model[2] 是正确的,model[“conv2”] 是错误的,这其实是由它的定义实现的,看上面的Sequenrial定义可知,只支持index访问。

方法三:add_module()

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential()

model.add_module("conv1", nn.Conv2d(1, 20, 5))
model.add_module('relu1', nn.ReLU())
model.add_module('conv2', nn.Conv2d(20, 64, 5))
model.add_module('relu2', nn.ReLU())

print(model, '\n')
print(model[2])  # 通过索引获取第几个层

输出如下:
在这里插入图片描述
这里,add_module()这个方法是定义在它的父类Module里面的,Sequential继承了该方法。

2.3 nn.Sequential()的本质作用

与一层一层的单独调用模块组成序列相比,nn.Sequential() 可以允许将整个容器视为单个模块(即相当于把多个模块封装成一个模块),forward()方法接收输入之后,nn.Sequential()按照内部模块的顺序自动依次计算并输出结果。

这就意味着我们可以利用nn.Sequential() 自定义自己的网络层,示例如下:

import torch.nn as nn


class Model(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Model, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(in_channel, in_channel / 4, kernel_size=1),
                                    nn.BatchNorm2d(in_channel / 4),
                                    nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(in_channel / 4, in_channel / 4),
                                    nn.BatchNorm2d(in_channel / 4),
                                    nn.ReLU())
        self.layer3 = nn.Sequential(nn.Conv2d(in_channel / 4, out_channel, kernel_size=1),
                                    nn.BatchNorm2d(out_channel),
                                    nn.ReLU())
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        
        return x

上边的代码,我们通过nn.Sequential()将卷积层,BN层和激活函数层封装在一个层中,输入x经过卷积、BN和ReLU后直接输出激活函数作用之后的结果。

3. nn.ModuleList

nn.ModuleList就像一个普通的Python的List,我们可以使用下标来访问它。好处是传入的ModuleList的所有Module都会注册到PyTorch里,这样Optimizer就能找到其中的参数,从而用梯度下降进行更新。但是nn.ModuleList并不是Module(的子类),因此它没有forward等方法,通常会被放到某个Module里。

  • nn.ModuleList()

ModuleList 具有和List 相似的用法,实际上可以把它视作是 Module 和 list 的结合。

# 输入参数  modules (list, optional) – 将要被添加到MuduleList中的 modules 列表

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers=nn.ModuleList([
            nn.Linear(1,10), nn.ReLU(),
            nn.Linear(10,1)])
    def forward(self,x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
    
model = Model()
print(model)

输出如下:
在这里插入图片描述

  • append(module)
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers=nn.ModuleList([
            nn.Linear(1,10), nn.ReLU(),
            nn.Linear(10,1)])
        self.layers.append(nn.Linear(1, 5))
    def forward(self,x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

  • extend(modules)
    extend(),必须也为一个list
self.layers.extend([nn.Linear(size1, size2) for i in range(1, num_layers)])

nn.Sequential()和nn.ModuleList的区别:

  • nn.Sequential()定义的网络中各层会按照定义的顺序进行级联,因此需要保证各层的输入和输出之间要衔接。
  • nn.Sequential()实现了farward()方法,因此可以直接通过类似于x=self.combine(x)实现 forward()。
  • nn.ModuleList则没有顺序性要求,并且也没有实现forward()方法。

参考资料

  • 【PyTorch】torch.nn.Module 源码分析
  • pytorch nn.Module()模块
  • https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter04_DL_computation/4.1_model-construction.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1881542.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

护网蓝队面试

一、sql注入分类 **原理:**没有对用户输入项进行验证和处理直接拼接到查询语句中 查询语句中插⼊恶意SQL代码传递后台sql服务器分析执行 **从注入参数类型分:**数字型注入、字符型注入 **从注入效果分:**报错注入、布尔注入、延时注入、联…

【Python时序预测系列】基于LSTM实现多输入多输出单步预测(案例+源码)

这是我的第312篇原创文章。 一、引言 单站点多变量输入多变量输出单步预测问题----基于LSTM实现。 多输入就是输入多个特征变量 多输出就是同时预测出多个标签的结果 单步就是利用过去N天预测未来1天的结果 二、实现过程 2.1 读取数据集 dfpd.read_csv("data.csv&qu…

Java进阶学习|Day3.Java集合类(容器),Stream的使用,哈希初接触

java集合类(容器) Java中的集合类主要由Collection和Map这两个接口派生而出,其中Collection接口又派生出三个子接口,分别是Set、List、Queue。所有的Java集合类,都是Set、List、Queue、Map这四个接口的实现类&#xf…

7月408规划,保底100冲120+!

两句话看懂408! 408是计算机考研全国统考考试的代码,只是一个代码。 408由教育部统一出题,包含四门课程,分别是数据结构,计算机组成原理,计算机网络,操作系统.。 参考教材是: 数…

WPF的IValueConverter用于校验和格式化TextBox的数字输入

在数据绑定(Data Binding)的上下文中,我们经常使用继承 IValueConverter 接口的类,用于在源值和目标值之间进行转换。该接口定义了两个方法:Convert 和 ConvertBack,这两个方法分别用于从源值到目标值的转换…

【折腾手机】一加6T刷机postmarketOS经历和体验

写在前面 到目前为止,我已经花了非常多的时间去学习和了解x86架构和RISC-V架构,对它们的指令集编程、指令格式的设计、编译套件的使用都亲自去体会和实践过,学到了很多的东西。但是对于离我们最近的arm架构却了解甚少。为什么说离我们最近呢…

探索数据赋能的未来趋势:嵌入式BI技术的挑战与突破

数据分析能力越来越成为消费者和企业的必备品应用程序,复杂程度各不相同,从简单地一个网页或门户上托管一个可视化或仪表板,到在一个云服务上实现数据探索、建模、报告和可视化创建的应用程序。BI的实现方式越来越多,无论规模大小…

自动雪深传感器的类型

TH-XL2随着科技的飞速发展,气象监测技术也在不断进步。在降雪天气频发的冬季,雪深数据对于保障道路交通、农业生产和电力供应等具有至关重要的作用。自动雪深传感器作为气象监测的重要工具,其类型多样、功能各异,为气象数据的准确…

国产分布式数据库灾备高可用实现

最近在进行核心业务系统的切换演练测试,就在想一个最佳的分布式数据库高可用部署方案是如何保证数据不丢、系统可用的,做到故障时候可切换、可回切,并且业务数据的一致性。本文简要介绍了OceanBase数据库和GoldenDB数据库在灾备高可用的部署方…

leetCode-hot100-动态规划专题

动态规划 动态规划定义动态规划的核心思想动态规划的基本特征动态规划的基本思路例题322.零钱兑换53.最大子数组和72.编辑距离139.单词拆分62.不同路径63.不同路径Ⅱ64.最小路径和70.爬楼梯121.买卖股票的最佳时机152.乘积最大子数组 动态规划定义 动态规划(Dynami…

嫦娥六号成功带回月球背面土壤,嫦娥七号整装待发,2030年前实现载人登月!

本文首发于公众号“AntDream”,欢迎微信搜索“AntDream”或扫描文章底部二维码关注,和我一起每天进步一点点 嫦娥六号圆满成功 嫦娥六号任务是中国探月工程的一次重大成功,探测器于5月3日在中国文昌航天发射场发射升空并进入地月转移轨道。经…

【SQL】已解决:SQL分组去重并合并相同数据

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决:SQL分组去重并合并相同数据 在数据库操作中,数据的分组、去重以及合并是常见需求。然而,初学者在编写SQL语句时,可能会遇到一…

2024华为OD机试真题- 电脑病毒感染-(C++/Python)-C卷D卷-200分

2024华为OD机试题库-(C卷+D卷)-(JAVA、Python、C++) 题目描述 一个局域网内有很多台电脑,分别标注为 0 ~ N-1 的数字。相连接的电脑距离不一样,所以感染时间不一样,感染时间用 t 表示。 其中网络内一台电脑被病毒感染,求其感染网络内所有的电脑最少需要多长时间。如果…

整合、速通 版本控制器-->Git 的实际应用

目录 版本控制器 -- Git1、Git 和 SVN 的区别2、Git 的卸载和安装2-1:Git 卸载1、先查下原本的Git版本2、删除环境变量3、控制面板卸载 Git 2-2:Git 下载安装1、官网下载2、详细安装步骤3、安装成功展示 3、Git 基础知识3-1:基本的 Linux 命令…

逆向开发环境准备

JDK安装 AndroidStudio安装 默认sdk路径 C:\Users\Administrator\AppData\Local\Android\Sdk 将platform-tools所在的目录添加到path C:\Users\Administrator\AppData\Local\Android\Sdk\platform-tools 主要目的是使用该目录下的adb等命令 将tools所在的目录添加到path C:\Us…

LabVIEW风机跑合监控系统

开发了一种基于LabVIEW的风机跑合监控系统,提高风机测试的效率和安全性。系统通过自动控制风机的启停、实时监控电流和功率数据,并具有过流保护功能,有效减少了人工操作和安全隐患,提升了工业设备测试的自动化和智能化水平。 项目…

解决注册表删除Google报错问题

删除注册表中的Google时报错: 解决方式: 1、右键com.microsoft.browsercore,选择【权限】,在弹出的窗口中点击【高级】 2、可以看到现在的所有者是:TrustedInstaller,点击【更改】 3、点击选择用户和组中的…

东方航空逆向

声明(lianxi a15018601872) 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! …

问题解决|endnote文献手工导入

一、背景介绍 手工导入一篇文献是指手动编辑文献的相关信息Preference。为什么要手动这么麻烦?因为有的文献比较老只有纸质版本,有的文献信息不全,有的则是没有编码无法识别等等,需要手工录入;一般需要手工录入的情况比…

使用gradio搭建私有云ChatGLM3网页客户端

【图书推荐】《ChatGLM3大模型本地化部署、应用开发与微调》-CSDN博客 通过简单的代码领略一下ChatGLM3大模型_chatglm3 history怎么写-CSDN博客 对于一般使用网页端完成部署的用户来说,最少需要准备一个自定义的网页端界面。在网页端界面上,可以设置文…