深度学习基础知识 nn.Sequential | nn.ModuleList | nn.ModuleDict

news2024/11/18 8:43:37

深度学习基础知识 nn.Sequential | nn.ModuleList | nn.ModuleDict

  • 1、nn.Sequential 、 nn.ModuleList 、 nn.ModuleDict 类都继承自 Module 类。
  • 2、nn.Sequential、nn.ModuleList 和 nn.ModuleDict语法
  • 3、Sequential 、ModuleDict、 ModuleList 的区别
  • 4、ModuleDict、 ModuleList 的区别
  • 5、nn.ModuleList 、 nn.ModuleDict 与 Python list、Dict 的区别

1、nn.Sequential 、 nn.ModuleList 、 nn.ModuleDict 类都继承自 Module 类。

2、nn.Sequential、nn.ModuleList 和 nn.ModuleDict语法

net = nn.Sequential(nn.Linear(32, 64), nn.ReLU()) →→只需要将定义的层按照顺序写入括号内就可以了

net = nn.ModuleList([nn.Linear(32, 6)4, nn.ReLU()]) →→在定义式需要加上中括号[],将定义的层写入到中括号内

net = nn.ModuleDict({‘linear’: nn.Linear(32, 64), ‘act’: nn.ReLU()}) →→需要大括号,将定义的层以键值对的形式写入

代码

import torch
import torch.nn as nn

net1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
net2 = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net3 = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})

print(net1)
print(net2)
print(net3)

在这里插入图片描述

3、Sequential 、ModuleDict、 ModuleList 的区别

1、 ModuleList 仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现 forward 功能需要自己实现

2、和 ModuleList 一样, ModuleDict 实例仅仅是存放了一些模块的字典,并没有定义 forward 函数需要自己定义

3、而 Sequential 内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部 forward 功能已经实现,所以,直接如下写模型,是可以直接调用的,不再需要写forward,sequential 内部已经有 forward

代码:

import torch
import torch.nn as nn

net1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
net2 = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net3 = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})


x = torch.randn(8, 3, 32)
print(net1(x).shape)    # 输出内容: torch.Size([8, 3, 64])
# print(net2(x).shape)  # 会报错,提示缺少forward
# print(net3(x).shape)   # 会报错,提示缺少forward

为 nn.ModuleList 写 forward 函数
代码:

import torch
import torch.nn as nn


class My_Model(nn.Module):
    def __init__(self):
        super(My_Model, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(32, 64),nn.ReLU()])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = My_Model()

x = torch.randn(8, 3, 32)
out = net(x)
print(out.shape)

输出结果:
在这里插入图片描述
为 nn.ModuleDict 写 forward 函数

import torch
import torch.nn as nn


class My_Model(nn.Module):
    def __init__(self):
        super(My_Model, self).__init__()
        self.layers = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x

net = My_Model()
x = torch.randn(8, 3, 32)
out = net(x)
print(out.shape)

将 nn.ModuleList 转换成 nn.Sequential

import torch
import torch.nn as nn

module_list = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net = nn.Sequential(*module_list)
x = torch.randn(8, 3, 32)
print(net(x).shape)

输出如下:
在这里插入图片描述

将 nn.ModuleDict 转换成 nn.Sequential

import torch
import torch.nn as nn

module_dict = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})
net = nn.Sequential(*module_dict.values())
x = torch.randn(8, 3, 32)
print(net(x).shape)

输出如下:
在这里插入图片描述

4、ModuleDict、 ModuleList 的区别

1、ModuleDict 可以给每个层定义名字,ModuleList 不会
2、ModuleList 可以通过索引读取,并且使用 append 添加元素

import torch.nn as nn

net = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net.append(nn.Linear(64, 10))
print(net)

3、ModuleDict 可以通过 key 读取,并且可以像 字典一样添加元素

import torch.nn as nn

net = nn.ModuleDict({'linear1': nn.Linear(32, 64), 'act': nn.ReLU()})
net['linear2'] = nn.Linear(64, 128)
print(net)

5、nn.ModuleList 、 nn.ModuleDict 与 Python list、Dict 的区别

import torch.nn as nn

net = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])

for name, param in net.named_parameters():
    print(name, param)


print("-----------------------------")
for name, param in net.named_parameters():
    print(name, param.size())

显示结果如下:
在这里插入图片描述

import torch.nn as nn

net = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})

for name, param in net.named_parameters():
    print(name, param.size())
print("--------------------------")

for name, param in net.named_parameters():
    print(name, param.size())

显示结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1072881.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一文熟练使用python修改Excel中的数据

使用python修改Excel中的内容 1.初级修改 1.1 openpyxl库的功能: openpyxl模块是一个读写Excel 2010文档的Python库,如果要处理更早格式的Excel文档,需要用到额外的库,例如Xlwings。openpyxl是一个比较综合的工具,能…

HarmonyOS/OpenHarmony原生应用-ArkTS万能卡片组件Stack

堆叠容器,子组件按照顺序依次入栈,后一个子组件覆盖前一个子组件。该组件从API Version 7开始支持。可以包含子组件。 一、接口 Stack(value?: { alignContent?: Alignment }) 从API version 9开始,该接口支持在ArkTS卡片中使用。 二、…

缓存雪崩、缓存穿透和缓存击穿产生的原因及解决方案

目录 什么是缓存雪崩? 缓存雪崩的解决方案 什么是缓存穿透? 缓存穿透的解决方案 什么是缓存击穿? 缓存击穿的解决方案 缓存在提高系统性能和响应速度方面起着关键作用,但在实际应用中,我们常常面临一些与缓存相…

关于打造安卓测试机的方法以及常见问题的解决方式

摘要: 本文主要讲解如何打造安卓测试机,并刷机google原生系统、部署Magisk对测试机进行root的常用方式;并对一些常见问题进行思路解答。本文适合新手学习,大佬请绕过 本次实验使用的设备及环境如下: Nexus 5x 测试机…

【多线程进阶】线程安全的集合类

文章目录 前言1. 多线程环境使用 ArrayList2. 多线程环境使用队列3. 多线程环境使用哈希表3.1 HashTable3.2 ConcurrentHashMap 总结 前言 本文主要讲解 Java 线程安全的集合类, 在之前学习过的集合类中, 只有 Vector, Stack, HashTable, 是线程安全的, 因为在他们的关键方法中…

高端品牌如何利用软文抓住顾客的心?

如今高端品市场价值巨大,但之前由于“口罩”影响和冲击,高端品牌的线上销售份额占比较少,同时得益于互联网和新媒体技术的发展,高端品的利润来源大多数是线上推广进行销售,而软文就是高端品常用的推广方式,…

鉴源论坛 · 观模丨基于软件性质的自动化测试技术

作者 | 熊一衡 华东师范大学软件工程学院博士 苏亭 华东师范大学软件工程学院教授 版块 | 鉴源论坛 观模 社群 | 添加微信号“TICPShanghai”加入“上海控安51fusa安全社区” 在软件开发的生命周期中,测试是至关重要的一环。为了确保软件产品的质量,开…

MOM与MES管理系统有哪些本质上的区别

随着企业业务的不断发展,许多制造企业开始面临车间管理失控、生产不透明等问题。这时候,很多企业选择上线MES生产管理系统来提高生产管理水平。然而,随着企业业务的不断拓展,MES系统也逐渐暴露出其局限性。于是,MOM平台…

详解CAN通信的标识符掩码和标识符列表两种过滤机制

CAN 通信的应用非常广泛,本文不涉及CAN通信的基础配置,重点分析一下STM32和GD32的CAN通信两种ID过滤方式。 首先,不管是STM32还是GD32,实现CAN通信ID过滤的机制和原理一定是一样的,只是用到的寄存器有差别。 1. ID过…

TensorFlow入门(十四、数据读取机制(1))

TensorFlow的数据读取方式 TensorFlow的数据读取方式共有三种,分别是: ①预加载数据(Preloaded data) 预加载数据的方式,其实就是静态图(Graph)的模式。即将数据直接内嵌到Graph中,再把Graph传入Session中运行。 示例代码如下: import tensorflow.compat.v1 as tf tf.disabl…

超好用的IDEA插件推荐!

大家好,Apipost 最新推出IDEA插件V2版本!V2版本主要是Apipost 符合更多用户的需求而推出,支持在插件中获取 token、支持代码完成后在插件中进行 API调试 ,同时也保留了1.0版本部分功能如上传选择目录功能等。 V1版本还会继续保留…

韦东山D1S板子——xfel工具无法烧写bin文件到spi norFlash问题解决

1、早期问题排查 (1)参考博客:《韦东山D1S板子——烧录spi norFlash失败问题排查过程》; (2)早期排查到xfel工具烧写spi norFlash显示成功,但是实际没有烧写进bin文件,怀疑是norFlas…

“揭秘淘宝店铺所有商品接口:一键获取海量热销宝贝信息!“

淘宝店铺所有商品接口可以通过shop id或店铺主链接获取到整店商品,数据包括:商品ID,图片地址,店铺标题,优惠价,价格,销量,宝贝链接等整个店铺的商品。 要使用这个接口,需…

Linux 系统性能瓶颈分析(超详细)

Author:rab 目录 前言一、性能指标1.1 进程1.1.1 进程定义1.1.2 进程状态1.1.3 进程优先级1.1.4 进程与程序间的关系1.1.5 进程与进程间的关系1.1.6 进程与线程的关系 1.2 内存1.2.1 物理内存与虚拟内存1.2.2 页高速缓存与页写回机制1.2.3 Swap Space 1.3 文件系统1…

9个最常用的人体姿态估计模型

“姿态估计?”……“姿态”一词对于不同的人来说可能有不同的含义,但我们不是在讨论阿诺德经典、奥林匹亚或选美表演。 那么,姿态估计到底是什么? 那么,让我们深入探讨这个话题。 推荐:用 NSDT编辑器 快速搭…

【git】git命令行

首先要了解git整个流程的一个分类: workspace:工作区staging area:暂存区/缓存区local repository:版本库或本地仓库remote repository:远程仓库 创建仓库 git clone gitgithub.comxxxxxxxxxxxx//拷贝一份远程仓库 …

笔记34:转置卷积 Transposed Convolution 的由来

注:该文章为视频课的笔记补充 视频课:转置卷积(transposed convolution)_哔哩哔哩_bilibili 更详细的推导在:抽丝剥茧,带你理解转置卷积(反卷积)_逆卷积-CSDN博客 a a a 补充1…

3D模型轻量化工具HOOPS Communicator在3D打印行业中的应用分析

3D打印技术自问世以来,已经在制造业、医疗领域、航空航天和建筑等行业中产生了革命性的影响。随着3D打印技术的不断发展,对3D模型的需求也在不断增加。 随着3D模型复杂性的增加,模型文件的体积也不断膨胀,这对计算资源和数据传输提…

当zk某个节点坏掉如何修复

2.1 当zk某个节点坏掉如何修复 当发生zk数据文件丢失(误删或者磁盘损坏节点损坏都可能出现)时,cdh会出现如下告警

CoreData + CloudKit 在初始化 Schema 时报错 A Core Data error occurred 的解决

问题现象 如果希望为 CoreData 支持的 App 增加云数据备份和同步功能,那么 CloudKit 是绝佳的选择。CloudKit 会帮我们默默处理好一切,我们基本不用为升级而操心。 不过,有时在用本地 CoreData NSManagedObjectModel 初始化 iCloud 中的 Schema 时会发生如下错误: Error …