PyTorch学习笔记-神经网络模型搭建小实战

news2025/3/1 1:44:28

1. torch.nn.Sequential

torch.nn.Sequential 是一个Sequential 容器,能够在容器中嵌套各种实现神经网络中具体功能相关的类,来完成对神经网络模型的搭建。模块的加入一般有两种方式,一种是直接嵌套,另一种是以 OrderedDict 有序字典的方式进行传入,这两种方式的唯一区别是:

  • 使用 OrderedDict 搭建的模型的每个模块都有我们自定义的名字。
  • 直接嵌套默认使用从零开始的数字序列作为每个模块的名字。

(1)直接嵌套方法的代码如下:

import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)

print(model)
# Sequential(
#   (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (1): ReLU()
#   (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (3): ReLU()
# )

(2)使用 OrderedDict 的代码如下:

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([
    ('Conv1', nn.Conv2d(1, 20, 5)),
    ('ReLU1', nn.ReLU()),
    ('Conv2', nn.Conv2d(20, 64, 5)),
    ('ReLU2', nn.ReLU())
]))
# print(model)
# Sequential(
#   (Conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (ReLU1): ReLU()
#   (Conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (ReLU2): ReLU()
# )

2. 小实战

现在我们来搭建下图所示的很简单的神经网络模型:

在这里插入图片描述

对模型进行分析可以看到第一层为卷积层 conv1,其卷积核 kernel_size=5,处理后的图像通道数从3变成了32,说明 out_channels=32,且图像大小没有变化,通过计算可知 stride=1, padding=2;第二层为最大池化层,其池化核 kernel_size=2,处理后的图像大小变为原来的一半。同理可以分析出之后每一层的参数。

代码如下:

from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch

class CIFAR10_Network(nn.Module):
    def __init__(self):
        super(CIFAR10_Network, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),  # [32, 32, 32]
            nn.MaxPool2d(kernel_size=2),  # [32, 16, 16]
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),  # [32, 16, 16]
            nn.MaxPool2d(kernel_size=2),  # [32, 8, 8]
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),  # [64, 8, 8]
            nn.MaxPool2d(kernel_size=2),  # [64, 4, 4]
            nn.Flatten(),  # [1024]
            nn.Linear(in_features=1024, out_features=64),  # [64]
            nn.Linear(in_features=64, out_features=10) # [10]
        )

    def forward(self, input):
        output = self.model(input)
        return output

network = CIFAR10_Network()

input = torch.randn(64, 3, 32, 32)  # 返回一个包含了从标准正态分布中抽取的一组随机数的张量
print(input.shape)  # torch.Size([64, 3, 32, 32])
output = network(input)
print(output.shape)  # torch.Size([64, 10])

writer = SummaryWriter('logs')
writer.add_graph(network, input)  # 生成计算图
writer.close()

使用 add_graph 函数可以在 TensorBoard 中生成神经网络的计算图,通过计算图可以很清晰地看到每一层计算时数据流入流出的结果,打开 TensorBoard 看一下结果:

在这里插入图片描述

双击相应的标签可以进一步深入查看更详细的信息,例如我们将我们构建的这个神经网络的类展开可以看到其中的模型:

在这里插入图片描述

将这个模型继续展开后就能看到我们在 Sequential 中定义的网络的各个层:

在这里插入图片描述

最后展开某一层即可看到在这层的数据流动情况:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/44464.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LabVIEW创建类 1

LabVIEW创建类 1 通过创建LabVIEW类,可在LabVIEW中创建用户定义的数据类型。LabVIEW类定义了对象相关的数据和可对数据执行的操作(即方法)。通过封装和继承可创建模块化的代码,使代码更易修改而不影响应用程序中的其它代码。 在…

Terraform 华为云最佳实践

目录划分如下:首先是环境,分为网络和service。global是全局的配置,也就是backend的配置,这次使用s3的存储作为backend的存储。最后就是模块做了一些封装。 在global里面的backend里面的main.tf去创建s3的存储。华为云支持s3存储&a…

[附源码]Python计算机毕业设计Django病房管理系统

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

RK3588平台开发系列讲解(USB篇)USB 外设 CONFIG

平台内核版本安卓版本RK3588Linux 5.10Android 12文章目录 一、 Mass Storage Class CONFIG二、USB Serial Converter CONFIG三、USB HID CONFIG四、USB Net CONFIG五、USB Camera CONFIG六、USB Audio CONFIG七、 USB HUB CONFIG沉淀、分享、成长,让自己和他人都能有所收获!…

PG::Seppuku

nmap -Pn -p- -T4 --min-rate1000 192.168.81.90 nmap -Pn -p 21,22,80,139,445,7080,7601,8088 -sCV 192.168.81.90 查看7601端口的页面 对路径进行爆破 在/secret路径下得到了用户名和一个密码字典 尝试ssh爆破 得到密码 eeyoree ssh登录 这里使用sudo -l&#xff0…

FineReport表格软件- 计算操作符说明

1. 概述 FineReport 中使用函数需要用到很多的操作符。 操作符不仅包含很多运算符,还包括一些报表特有的操作符。 FineReport 11.0 优化了公式 2. 运算符类型 运算符用于指定要对公式中的元素执行的计算类型。有默认计算顺序,但可以使用括号更改此顺序…

企业表格软件-FineReport 数组函数概述

1. ADD2ARRAY ADD2ARRAY(array, insertArray, start):在数组 array 的第 start 个位置插入 insertArray 中的所有元素,再返回该数组。 示例: ADD2ARRAY([3, 4, 1, 5, 7], [23, 43, 22], 3)返回[3, 4, 23, 43, 22, 1, 5, 7]。 ADD2ARRAY([…

将 AWS IAM Identity Center (SSO) SAML 与 Amazon OpenSearch Dashboard集成

Amazon OpenSearch Amazon OpenSearch Service 是一项 AWS 托管服务,可以让您运行和扩展 OpenSearch 集群,而不必担心管理、监控和维护您的基础设施,或者不必在操作 OpenSearch 集群方面积累深入的专业知识。 基于 SAML 的 OpenSearch Dash…

Json用法总结

1、忽略json JsonIgnoreProperties(value{“addressId”}) JSONField(serializefalse) JsonIgnore 2、 JsonFiled JsonProperty XStreamAlias Builder.Default 网上可以查询下相关资料 3、 JSON.parseObject(response, ***Response.class) JSONObject.parseObject(response, **…

LockSupport的使用

参考链接: LockSupport使用场景及原理详解 AQS的引入 LockSupport的使用 LockSupport是一个工具类,提供了基本的线程阻塞和唤醒功能,它是创建锁和其他同步组件的基础工具,内部是使用sun.misc.Unsafe类实现的。LockSupport和使用…

android分区概述

Android 设备包括几个分区,它们在启动过程中提供不同的功能。 1、 标准隔断 注意:支持无缝更新的设备每个分区需要一个插槽用于boot 、 system 、 vendor和radio 。 boot分区。此分区包含内核映像,并使用mkbootimg创建。您可以使用虚拟分区…

idea搭建ssm项目全过程详解:

1&#xff0c;创建maven项目&#xff1a; 然后&#xff0c;点击next 其次 2&#xff0c;在pom.xml导入相关依赖&#xff1a;&#xff08;如果idea没有集成maven需要先集成maven&#xff09; <dependencies><dependency><groupId>org.springframework</gr…

【LeetCode】接雨水 II [H](堆)

407. 接雨水 II - 力扣&#xff08;LeetCode&#xff09; 一、题目 给你一个 m x n 的矩阵&#xff0c;其中的值均为非负整数&#xff0c;代表二维高度图每个单元的高度&#xff0c;请计算图中形状最多能接多少体积的雨水。 示例 1&#xff1a; 输入: heightMap [[1,4,3,1,3…

Wireshark TS | 三谈 TCP 握手异常问题

前言 继续以一个实际案例来说下 TCP 握手问题&#xff0c;该数据包仍然来自于 Wireshark sharkfest 2017&#xff0c;一些简短但有趣的 TCP 跟踪文件中的又一个&#xff0c;或者说是最后一个了。可以说这些都是和 TCP 握手相关的连接问题&#xff0c;有兴趣的朋友可以私信&…

Mybatis-Plus开发提速器mybatis-plus-generator-ui

前言 在基于Mybatis的开发模式中&#xff0c;很多开发者还会选择Mybatis-Plus来辅助功能开发&#xff0c;以此提高开发的效率。虽然Mybatis也有代码生成的工具&#xff0c;但Mybatis-Plus由于在Mybatis基础上做了一些调整&#xff0c;因此&#xff0c;常规的生成工具生成的代码…

【一文秒懂——SLF4j日志】

目录 1. SLF4j日志 2. 日志输出 1. SLF4j日志 在添加了spring-boot-starter的项目中&#xff0c;已经包含了SLF4j日志的相关依赖项。 在添加了lombok的项目中&#xff0c;可以在类上添加Slf4j注解&#xff0c;则lombok框架会在编译期在类中声明名为log的变量&#xff0c;通…

农民歌唱家大衣哥喜迎贵客,这三位明星一般人还真请不动

都知道农民歌唱家大衣哥家里热闹&#xff0c;不过大部分都是蹭流量拍视频的&#xff0c;真正的好朋友绝对没有几个。虽然说没有几个好朋友&#xff0c;但是也不代表一个没有&#xff0c;看看在大衣哥家里吃饭的三位&#xff0c;每一个都不是一般人物。 如今的大衣哥&#xff0c…

发现智能合约中的 bug 的 7 个方法

寻找智能合约bug可能是一项高回报的工作&#xff0c;而且它也保护了生态系统免受黑客攻击。我最近有幸采访了一位开发人员&#xff0c;他发现了一个价值 70 亿美元的错误&#xff0c;并因报告该错误而获得了 220 万美元的报酬。 在这篇文章中&#xff0c;我将详细介绍该开发人…

路由和流量控制

路由策略 控制路由,从而影响IP包的转发路径。 路由策略的主要功能有两个,1)过滤路由信息,2)修改路由属性值。 路由匹配工具 acl 只有基本acl(Basic ACL,编号为 2000-2999)可以匹配路由。ACL匹配路由时只能匹配路由的网络号,但无法匹配掩码长度。 [RouterA] acl n…

基于SpringBoot的会员制医疗预约服务管理信息系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SpringBoot 前端&#xff1a;Vue、HTML 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#…