pytorch中nn.Sequential详解

news2024/9/28 11:16:58

1 nn.Sequential概述

1.1 nn.Sequential介绍

nn.Sequential是一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。

因此,Sequential可以看成是有多个函数运算对象,串联成的神经网络,其返回的是Module类型的神经网络对象。

1.2 nn.Sequential的本质作用

与一层一层的单独调用模块组成序列相比,nn.Sequential() 可以允许将整个容器视为单个模块(即相当于把多个模块封装成一个模块),forward()方法接收输入之后,nn.Sequential()按照内部模块的顺序自动依次计算并输出结果。这就意味着我们可以利用nn.Sequential() 自定义自己的网络层。

示例代码:

from torch import nn


class net(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(net, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(in_channel, in_channel / 4, kernel_size=1),
                                    nn.BatchNorm2d(in_channel / 4),
                                    nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(in_channel / 4, in_channel / 4),
                                    nn.BatchNorm2d(in_channel / 4),
                                    nn.ReLU())
        self.layer3 = nn.Sequential(nn.Conv2d(in_channel / 4, out_channel, kernel_size=1),
                                    nn.BatchNorm2d(out_channel),
                                    nn.ReLU())
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        
        return x

上边的代码,我们通过nn.Sequential()将卷积层,BN层和激活函数层封装在一个层中,输入x经过卷积、BN和ReLU后直接输出激活函数作用之后的结果。

1.3 nn.Sequential源码

def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

nn.Sequential()首先判断接收的参数是否为OrderedDict类型,如果是的话,分别取出OrderedDict内每个元素的key(自定义的网络模块名)和value(网络模块),然后将其通过add_module方法添加到nn.Sequrntial()中。

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types).  Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        for module in self:
            input = module(input)
        return input

 调用forward()方法进行前向传播时,for循环按照顺序遍历nn.Sequential()中存储的网络模块,并以此计算输出结果,并返回最终的计算结果。

 

1.3 nn.Sequential与其它容器的区别

2 使用nn.Sequential定义网络

2.1 顺序添加网络模块到容器中

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(28 * 28, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
    nn.Softmax(dim=1)
)
print("model:", model)
print("model.parameters:", model.parameters)

x_input = torch.randn(2, 28, 28, 1)
print("x_input:", x_input)
print("x_input.shape:", x_input.shape)

y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential(
  (0): Linear(in_features=784, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=10, bias=True)
  (3): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential(
  (0): Linear(in_features=784, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=10, bias=True)
  (3): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.1127, 0.0652, 0.1399, 0.0973, 0.1085, 0.0859, 0.1193, 0.1048, 0.0865,
         0.0800],
        [0.0986, 0.0955, 0.0927, 0.0765, 0.0782, 0.1004, 0.1171, 0.1605, 0.0883,
         0.0922]], grad_fn=<SoftmaxBackward0>)

2.2 包含神经网络模块的OrderedDict传入容器中

import torch
import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
print("model:", model)
print("model.parameters:", model.parameters)

x_input = torch.randn(2, 28, 28, 1)
print("x_input.shape:", x_input.shape)

y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): ReLU()
  (out): Linear(in_features=32, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): ReLU()
  (out): Linear(in_features=32, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.0836, 0.1185, 0.1422, 0.0801, 0.0817, 0.0870, 0.0948, 0.1099, 0.1131,
         0.0892],
        [0.0772, 0.0933, 0.1312, 0.1135, 0.1214, 0.0736, 0.1461, 0.0711, 0.0908,
         0.0818]], grad_fn=<SoftmaxBackward0>)

3 nn.Sequential网络操作

3.1 索引查看子模块

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
print("index0:", model[0])
print("index1:", model[1])
print("index2:", model[2])

运行代码显示:

index0: Linear(in_features=784, out_features=32, bias=True)
index1: ReLU()
index2: Linear(in_features=32, out_features=10, bias=True)

3.2 修改子模块

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
model[1] = nn.Sigmoid()
print(model)

运行代码显示:

Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): Sigmoid()
  (out): Linear(in_features=32, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)

3.3 添加子模块

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
model.append(nn.Linear(10, 2))
print(model)

运行代码显示:

Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): ReLU()
  (out): Linear(in_features=32, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
  (4): Linear(in_features=10, out_features=2, bias=True)
)

3.4 删除子模块

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
del model[2]
print(model)

运行代码显示:

Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): ReLU()
  (softmax): Softmax(dim=1)
)

3.5 嵌套子模块

import torch.nn as nn

seq_1 = nn.Sequential(nn.Linear(15, 10), nn.ReLU(), nn.Linear(10, 5))
seq_2 = nn.Sequential(nn.Linear(25, 15), nn.Sigmoid(), nn.Linear(15, 10))
seq_3 = nn.Sequential(seq_1, seq_2)
print(seq_3)

运行代码显示:

Sequential(
  (0): Sequential(
    (0): Linear(in_features=15, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=5, bias=True)
  )
  (1): Sequential(
    (0): Linear(in_features=25, out_features=15, bias=True)
    (1): Sigmoid()
    (2): Linear(in_features=15, out_features=10, bias=True)
  )
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1323486.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

音视频技术开发周刊 | 324

每周一期&#xff0c;纵览音视频技术领域的干货。 新闻投稿&#xff1a;contributelivevideostack.com。 467亿参数MoE追平GPT-3.5&#xff01;爆火开源Mixtral模型细节首公开&#xff0c;中杯逼近GPT-4 今天&#xff0c;Mistral AI公布了Mixtral 8x7B的技术细节&#xff0c;不…

Java精品项目源码新基于协同过滤算法的旅游推荐系统(编号V69)

Java精品项目源码新基于协同过滤算法的旅游推荐系统(编号V69) 大家好&#xff0c;小辰今天给大家介绍一个基于协同过滤算法的旅游推荐系统

java参数校验

引入依赖 <!--参数效验--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-validation</artifactId></dependency><!--Length参数效验--><dependency><groupId>org.hib…

Golang(壹)

爱情不需要华丽的言语&#xff0c;只需要默默的行动。 简介 应用领域&#xff1a; 下载vscode 使用vscode Go下载 - Go语言中文网 - Golang中文社区 下载sdk 解压到文件中&#xff0c;打开sdk解压文件 穿插dos操作系统知识点&#xff1a; 测试go语言环境 看到vscode 的目录结…

[Win10系统] Win10 任务栏软件图标显示为空白 | 解决方案

文章目录 [Win10系统] Win10 任务栏软件图标显示为空白 | 解决方案前言产生错误的原因解决方案方法一&#xff1a;手动操作方法二&#xff1a;自动操作 总结 [Win10系统] Win10 任务栏软件图标显示为空白 | 解决方案 前言 有时候&#xff0c;我们在使用 Windows 10 系统时&…

深度学习环境配置------windows系统(GPU)------Pytorch

深度学习环境配置------windows系统&#xff08;GPU&#xff09;------Pytorch 准备工作明确操作系统明确显卡系列 CUDA和Cudnn下载与安装1.下载2.安装 环境配置过程1.安装Anacoda2.配置环境1&#xff09;创建一个新的虚拟环境2&#xff09;pytorch相关库的安装 2.安装VScode1&…

图片去除背景,无水印下载的六大免费平台!

随着人工智能技术的不断进步&#xff0c;越来越多的应用场景开始利用人工智能技术来提升用户体验。其中&#xff0c;AI去除图片背景是一项非常实用的功能。AIGCer尝试了多个平台&#xff0c;排除了很多有水印&#xff0c;需要付费&#xff0c;去除效果差等平台&#xff0c;为大…

[Verilog] 设计方法和设计流程

主页&#xff1a; 元存储博客 文章目录 1. 设计方法2. 设计流程 3 Vivado软件设计流程总结 1. 设计方法 Verilog 的设计多采用自上而下的设计方法&#xff08;top-down&#xff09;。设计流程是指从一个项目开始从项目需求分析&#xff0c;架构设计&#xff0c;功能验证&#…

Re解析(正则表达式解析)

正则表达式基础 元字符 B站教学视频&#xff1a; 正则表达式元字符基本使用 量词 贪婪匹配和惰性匹配 惰性匹配如下两张图&#xff0c;而 .* 就表示贪婪匹配&#xff0c;即尽可能多的匹配到符合的字符串&#xff0c;如果使用贪婪匹配&#xff0c;那么结果就是图中的情况三 p…

vue-springboot+java导师选择分配双选管理系统 0spy6

.2.3功能需求 本导师选择管理系统是为了提高用户查阅信息的效率和管理人员管理信息的工作效率&#xff0c;可以快速存储大量数据&#xff0c;还有信息检索功能&#xff0c;这大大的满足了学生、导师和管理员这三者的需求。操作简单易懂&#xff0c;合理分析各个模块的功能&…

凤凰架构之事务处理

目录 本地事务全局事务共享事务分布式事务可靠消息队列TCC事务SAGA事务 本地事务 本地事务是最基础的一种事务解决方案&#xff0c;只适用于单个服务使用单个数据源的场景。从应用角度看&#xff0c;它是直接依赖于数据源本身提供的事务能力来工作的&#xff0c;在程序代码层面…

石器时代H5小游戏架设教程

本文讲解石器时代 H5 之恐龙宝贝架设教程&#xff0c;想研究 H5 游戏如何实现&#xff0c;那请跟着此次教程学习在拥有小游戏源码的情况下该如何搭建起来 开始架设 1. 架设条件 石器时代架设需要准备&#xff1a; 一台linux 服务器&#xff0c;建议 CentOs 7.6 版本&#xf…

2023 英特尔On技术创新大会直播 |AI科技创新的引路者

英特尔大会 前言英特尔人工智能英特尔创新技术基于英特尔架构的科学计算总结 前言 英特尔技术创新大会是一个令人激动和启发的盛会。在这次大会上&#xff0c;我有幸观看了许多令人瞩目的科技创新和前沿技术的展示。这些展示不仅展示了英特尔作为科技巨头的实力&#xff0c;更…

告诉你playwright 不使用with sync_playwright() as编写脚本的新方法

大家都知道playwright代码的标准写法是&#xff1a; with sync_playwright() as p:browser p.chromium.launch(channel"chrome", headlessFalse)page browser.new_page()page.goto("http://www.baidu.com")print(page.title())browser.close() with sy…

Vue - 组件注册及其原理

1 Vue组件注册 Vue中注册组件的方式有两种&#xff1a;全局注册和局部注册。 2 局部注册 import HelloWorld from xxx/xxx export default {components: {HelloWorld} }3 全局注册 3.1 全局组件挂载 示例一&#xff1a; /** src/main.js */ // 表格动态列组件 import Dyn…

mysql 22day 对表格的增删改查、对数据的增删改查、对内容进行操作

目录 mysql 配置文件授权 远程链接 &#xff08;grant&#xff09;数据库操作创建库&#xff08;create&#xff09;切换数据库&#xff08;use&#xff09;查看当前所在库 表操作创建一张员工表查看表结构修改表名称增加字段修改字段名&#xff08;ALTER &#xff09;修改字段…

Floyd求最短路(Floyd算法)

参考&#xff1a;约会怎么走到目的地最近呢&#xff1f;一文讲清所有最短路算法问题-CSDN博客 有4个城市8条路&#xff0c;公路上的数字表示这条公路的长短&#xff0c;并且路是单向的&#xff0c;现在要求我们求出任意两个城市之间的最短路程&#xff0c;也就是求任意两个点之…

MIT18.06线性代数 笔记1

文章目录 方程组的几何解释矩阵消元乘法和逆矩阵A的LU分解转置-置换-向量空间R列空间和零空间求解Ax0主变量 特解求解Axb可解性和解的结构线性相关性、基、维数四个基本子空间矩阵空间、秩1矩阵和小世界图图和网络复习一 方程组的几何解释 线性组合&#xff1a; 找到合适的x和…

GitHub 如何修改 Fork from

如果你的仓库上面是 Fork from 的话&#xff0c;我们有什么办法能够取消掉这个 Fork from&#xff1f; 解决办法 GitHub 上面没有让你取消掉 Fork 的办法。 如果进入设置&#xff0c;在可见设置中也没有办法修改仓库的可见设置选项。 唯一的解决办法就是对你需要修改的仓库先…

透视数据:数据可视化工具的多重场景应用

数据可视化工具已经成为了许多领域中的重要利器&#xff0c;它们在各种场景下发挥着重要作用。下面我就以可视化从业者的角度简单谈谈数据可视化工具在不同场景下的应用&#xff1a; 企业数据分析与决策支持 在企业层面&#xff0c;数据可视化工具被广泛应用于数据分析和决策…