借助CIFAR10模型结构理解卷积神经网络及Sequential的使用

news2025/1/12 15:51:02

 CIFAR10模型搭建

CIFAR10模型结构

0. input 3@32x323通道32x32的图片 --> 特征图(Feature maps) : 32@32x32
即经过323@5x5的卷积层,输出尺寸没有变化(有x个特征图即有x个卷积核。卷积核的通道数与输入的通道数相等,即3@5x5)。
两种方法推导出padding = 2stride = 1的值:

公式法:

𝐻𝑜𝑢𝑡=32,𝐻𝑖𝑛=32,dilation = 1(默认值,此时没有空洞),kernel_size = 5

理论法:为保持输出尺寸不变,padding都是卷积核大小的一半,则有padding=kernel_size/2;奇数卷积核把中心格子对准图片第一个格子,卷积核在格子外有两层那么padding=2

1.input 32@32x32 --> output : 32@16x16
即经过2x2的最大池化层,stride = 2(池化层的步长为池化核的尺寸),padding = 0,特征图尺寸减小一半。
2.input 32@16x16 --> output : 32@16x16
即即经过323@5x5的卷积层,输出尺寸没有变化。padding = 2stride = 1
3.input : 32@16x16 --> output : 32@8x8
即经过2x2的最大池化层,stride = 2padding = 0,通道数不变,特征图尺寸减小一半。
4.input : 32@8x8 --> output : 64@8x8
即即经过643@5x5的卷积层,输出尺寸没有变化。padding = 2stride = 1
5.input : 64@8x8 --> output : 64@4x4
即经过2x2的最大池化层,stride = 2,padding = 0,通道数不变,特征图尺寸减小一半。
6.input:64@4x4-->output :1×1024
即经过展平层 Flatten 作用,将64@4x4的特征图依次排开。

7.input:1×1024-->output :​​​​​​​1×64
即经过线性层Linear1的作用。
8.input:1×64-->output:1×10
即经过线性层Linear2的作用。

代码验证:
按照网络结构一层一层搭建网络结构。
示例1:

# 导入需要用到的库
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


# 搭建CIFAR10模型网络
class Tudui(nn.Module):

    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2) # 第一个卷积层
        self.maxpool1 = MaxPool2d(2) # 第一个最大池化层

        self.conv2 = Conv2d(32, 32, 5, padding=2) # 第二个卷积层
        self.maxpool2 = MaxPool2d(2) # 第二个最大池化层

        self.conv3 = Conv2d(32, 64, 5, padding=2) # 第三个卷积层
        self.maxpool3 = MaxPool2d(2) # 第三个最大池化层

        self.flatten = Flatten() # 展平层

        # 两个线性层
        self.linear1 = Linear(1024, 64) # 第一个线性层
        self.linear2 = Linear(64, 10) # 第二个线性层

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


tudui = Tudui() # 实例化
print(tudui) # 观察网络信息
input = torch.ones((64, 3, 32, 32)) # 为网络创建假想输入,目的是检查网络是否正确
output = tudui(input) # 输出
print(output.shape) # torch.Size([64, 10]),结果与图片结果一致

 运行结果:

# 两个print出的内容分别为:
Tudui(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)
torch.Size([64, 10])

Sequential的使用

        当模型中只是简单的前馈网络时,即上一层的输出直接作为下一层的输入,这时可以采用torch.nn.Sequential()模块来快速搭建模型,而不必手动在forward()函数中一层一层地前向传播。因此,如果想快速搭建模型而不考虑中间过程的话,推荐使用torch.nn.Sequential()模块。

接下来用torch.nn.Sequential()改写示例 1,示例 2 如下。
示例2:

# 导入需要用到的库
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential


# 搭建CIFAR10模型网络
class Tudui(nn.Module):

    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
             Conv2d(3, 32, 5, padding=2),  # 第一个卷积层
             MaxPool2d(2),  # 第一个最大池化层

             Conv2d(32, 32, 5, padding=2), # 第二个卷积层
             MaxPool2d(2), # 第二个最大池化层

             Conv2d(32, 64, 5, padding=2),  # 第三个卷积层
             MaxPool2d(2),  # 第三个最大池化层

             Flatten(),  # 展平层

             # 两个线性层
             Linear(1024, 64),  # 第一个线性层
             Linear(64, 10)  # 第二个线性层
        )


    def forward(self, x):
        x = self.model1(x)
        return x


tudui = Tudui() # 实例化
print(tudui) # 观察网络信息
input = torch.ones((64, 3, 32, 32)) # 为网络创建假想输入,目的是检查网络是否正确
output = tudui(input) # 输出
print(output.shape) # torch.Size([64, 10]),结果与图片结果一致

运行结果:

# 两个print出来的结果分别为:
Tudui(
  (model1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)
torch.Size([64, 10])

        我们发现,使用Sequential之后得到的结果(示例2)与按照前向传播一层一层搭建得到的结果(示例1)一致,使用Sequential之后可以使得forward函数中的内容得以简化。

使用tensorboard实现网络结构可视化

# 导入需要用到的库
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriter

# 搭建CIFAR10模型网络



class Tudui(nn.Module):

    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2) # 第一个卷积层
        self.maxpool1 = MaxPool2d(2) # 第一个最大池化层

        self.conv2 = Conv2d(32, 32, 5, padding=2) # 第二个卷积层
        self.maxpool2 = MaxPool2d(2) # 第二个最大池化层

        self.conv3 = Conv2d(32, 64, 5, padding=2) # 第三个卷积层
        self.maxpool3 = MaxPool2d(2) # 第三个最大池化层

        self.flatten = Flatten() # 展平层

        # 两个线性层
        self.linear1 = Linear(1024, 64) # 第一个线性层
        self.linear2 = Linear(64, 10) # 第二个线性层

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


tudui = Tudui() # 实例化
print(tudui) # 观察网络信息
input = torch.ones((64, 3, 32, 32)) # 为网络创建假想输入,目的是检查网络是否正确
output = tudui(input) # 输出
print(output.shape) # torch.Size([64, 10]),结果与图片结果一致

# 使用tensorboard实现网络可视化
writer = SummaryWriter("./log_sequential")
writer.add_graph(tudui, input)
writer.close()

运行上述代码,则会在项目文件夹CIFAR10model中出现对应的日志文件夹log_sequential。

随后打开Terminal,如下图所示。

 输入tensorboard --logdir=log_sequential,如下图所示。

按下Enter键,得到一个网址,如下图所示。

 打开这个网址,得到可视化界面。

我们点开搭建好的网络Tudui,可以得到更具体的网络每一层,如下图所示。

我们将其放大,如下图所示。 

网络中的每一层

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/981325.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

01-从JDK源码级别剖析JVM类加载机制

1. 类加载运行全过程 当我们用java命令运行某个类的main函数启动程序时,首先需要通过类加载器把主类加载到JVM。 public class Math {public static final int initData 666;public static User user new User();public int compute() { //一个方法对应一块栈帧…

整理mongodb文档:事务(一)

个人博客 整理mongodb文档:事务(一) 原文链接,个人博客 求关注,本文主要讲下怎么在mongose下使用事务,建议电脑端看 文章概叙 本文的开发环境为Nodejs,在‘单机模式’讲解最基本的事务概念。并没有涉及分片以及集群&#xff0…

《向量数据库指南》——AI原生向量数据库Milvus Cloud 2.3新功能

New Feature Upsert 功能 支持用户通过 upsert 接口更新或插入数据。已知限制,自增 id 不支持 upsert;upsert 是内部实现是 delete + insert所以性能上会有一定损耗,如果明确知道是写入数据的场景请继续使用 insert。 Range Search 功能 支持用户通过输入参数指定 search 的…

TortoiseGit设置作者信息和用户名、密码存储

前言 Git 客户端每次与服务器交互,都需要输入密码,但是我们可以配置保存密码,只需要输入一次,就不再需要输入密码。 操作说明 在任意文件夹下,空白处,鼠标右键点击 在弹出菜单中按照下图点击 依次点击下…

LLVM 与代码混淆技术

项目源码 什么是 LLVM LLVM 计划启动于2000年,开始由美国 UIUC 大学的 Chris Lattner 博士主持开展,后来 Apple 也加入其中。最初的目的是开发一套提供中间代码和编译基础设施的虚拟系统。 LLVM 命名最早源自于底层虚拟机(Low Level Virtu…

LEARN GIT

概念 基础概念 本地电脑 代码区:工作区间,放代码的地方 暂存区:git所管理的暂存区域 本地仓库:git所管理的本机的硬盘区域 远程电脑 远程仓库:github、gitee 代码提交管理的过程 代码区------->暂存区-------&…

关于 RK3568的linux系统killed用户应用进程(用户现象为崩溃) 的解决方法

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/132710642 红胖子网络科技博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬…

模拟Proactor模式实现 I/O 处理单元

编写main.cpp 1.socket通信 服务器应用程序可以通过读取和写入 Socket 对象 来监听来自客户端的请求并向客户端返回响应 #define MAX_FD 65536 // 最大的文件描述符个数 #define MAX_EVENT_NUMBER 10000 // 监听的最大的事件数量 // 添加信号捕捉 void addsig(int sig, …

【MySQL】索引 详解

索引 详解 一. 概念二. 作用三. 使用场景四. 操作五. 索引背后的数据结构B-树B树聚簇索引与非聚簇索引 一. 概念 索引是一种特殊的文件,包含着对数据表里所有记录的引用指针。可以对表中的一列或多列创建索引,并指定索引的类型,各类索引有各…

机器学习的特征工程

字典特征提取 def dict_demo():"""字典特征提取:return:"""data [{city: 北京, temperature: 100}, {city: 上海, temperature: 60}, {city: 深圳, temperature: 30}]# data [{city:[北京,上海,深圳]},{temperature:["100","6…

《机器人学一(Robotics(1))》_台大林沛群 第 5 周【机械手臂 轨迹规划】 Quiz 5

我又行了!🤣 求解的 位置 可能会有 变动,根据求得的A填写相应值即可。注意看题目。 coursera链接 文章目录 第1题 Cartesian space求解 题1-3 的 Python 代码 第2题第3题第4题 Joint space求解 题4-6 的 Python 代码 第5题第6题其它可参考代…

leetcode 88:合并两个有序数组 。 双指针解法

题目 算法 双指针 code var merge function(nums1, m, nums2, n) {// 其实就是一个nums1数组从后向前的降序重排,从最后开始,比较nums1有效位置和nums2当前位置数的大小,依次填入,nums2最后若有剩余,则再多一步从后…

9、补充视频

改进后的dijkstra算法 利用小根堆 将小根堆特定位置更改,再改成小根堆 nodeHeap.addOrUpdateOrIgnore(edge.to, edge.weight + distance);//改进后的dijkstra算法 //从head出发,所有head能到达的节点,生成到达每个节点的最小路径记录并返回 public static HashMap<No…

Bytebase 和 GitLab 签署 Technology Partner 技术合作伙伴协议

Bytebase 和 GitLab 签署技术合作伙伴协议&#xff0c;携手为开发者提供流畅的数据库协作开发和管理体验。 GitLab 是世界领先的开源 AI 驱动 DevSecOps 平台&#xff0c;旨在帮助开发者团队更好协作、更高效交付软件。Bytebase 是一款为 DevOps 团队准备的数据库 CI/CD 工具&a…

一文讲解Linux内核内存管理架构

内存管理子系统可能是linux内核中最为复杂的一个子系统&#xff0c;其支持的功能需求众多&#xff0c;如页面映射、页面分配、页面回收、页面交换、冷热页面、紧急页面、页面碎片管理、页面缓存、页面统计等&#xff0c;而且对性能也有很高的要求。本文从内存管理硬件架构、地址…

上海控安携汽车网络安全新研产品出席AUTOSEMO“恒以致远,共创共赢”主题研讨会

8月31日&#xff0c;AUTOSEMO“恒以致远&#xff0c;共创共赢”主题研讨会在天津成功召开。本次大会由中国汽车工业协会软件分会中国汽车基础软件生态标委会&#xff08;简称&#xff1a;AUTOSEMO&#xff09;与天津市西青区人民政府联合主办。现场汇聚了100余位来自产学研政企…

单片机-LED介绍

简介 LED 即发光二极管。它具有单向导电性&#xff0c;通过 5mA 左右电流即可发光 电流 越大&#xff0c;其亮度越强&#xff0c;但若电流过大&#xff0c;会烧毁二极管&#xff0c;一般我们控制在 3 mA-20mA 之间&#xff0c;通常我们会在 LED 管脚上串联一个电阻&#xff0c…

unity 控制Dropdown的Arrow箭头变化

Dropdown打开下拉菜单会以“Template”为模板创建一个Dropdown List&#xff0c;在“Template”上添加一个脚本在Start()中执行下拉框打开时的操作&#xff0c;在OnDestroy()中执行下拉框收起时的操作即可。 效果代码如下用于控制Arrow旋转可以根据自己的想法进行修改&#xff…

雷达有源干扰识别仿真

各类干扰信号 基于数字射频存储(DRFM)技术的雷达干扰系统有三种工作方式&#xff1a;转发方式、应答方式和噪声方式&#xff0c;即&#xff0c;对应有三种干扰类型。 噪声干扰 DRFM干扰系统在噪声工作方式下不但可以产生传统噪声干扰&#xff0c;还可以通过将数字噪声调制到干…

网络空间内生安全数学基础(1)——背景

目录 &#xff08;一&#xff09;内生安全基本定义及实现什么是内生安全理论内生安全理论实现方法动态性异构性冗余性 &#xff08;二&#xff09;安全防御和可靠性问题起源内生安全防御、可靠性保证与香农可靠通信 &#xff08;三&#xff09;总结 &#xff08;一&#xff09;…