【机器学习】卷积神经网络(CNN)的特征数计算

news2024/12/23 10:00:09

文章目录

  • 基本步骤
  • 示例
  • 图解过程

基本步骤

在卷积神经网络(CNN)中,计算最后的特征数通常涉及到以下步骤:

  1. 确定输入尺寸

    首先,你需要知道输入数据的尺寸。对于图像数据,这通常是 (batch_size, channels, height, width)

  2. 应用卷积层

    在卷积操作过程中,图像与卷积核进行滑动窗口式的乘加运算,这会导致图像尺寸的变化。特征数会根据卷积核的数量和大小以及步长等因素发生变化。

    • in_channels:输入数据的通道数。
    • out_channels:卷积层产生的输出特征图的数量,即卷积核的数量。
    • kernel_size:卷积核(filter)的大小(FxF)(kernel_size的选择对模型的性能有很大影响,因为它决定了模型能够捕捉到的特征的尺度和复杂性。增大kernel_size可以捕获更大范围的特征,但可能会增加计算复杂性和过拟合的风险;减小kernel_size则可以关注更细节、局部的特征,但可能忽略掉一些重要的全局信息。因此,选择合适的kernel_sizeCNN设计中的一个重要环节)。
    • stride:卷积核在输入数据上滑动的步长。
    • padding:在输入数据边缘添加的零填充的数量。


    卷积层的输出尺寸可以通过以下公式计算(floor()是向下取整函数):

    output_height = floor((input_height - kernel_size + 2 * padding) / stride) + 1
    output_width = floor((input_width - kernel_size + 2 * padding) / stride) + 1
    

    特征数(或通道数)在卷积层后变为 out_channels

  3. 应用池化层

    池化层通常不会改变特征数,但会改变特征图的高度和宽度。

    池化层的输出尺寸可以通过以下公式计算:

    output_height = floor((input_height - kernel_size) / stride) + 1
    output_width = floor((input_width - kernel_size) / stride) + 1
    
  4. 重复以上步骤

    继续应用卷积层和池化层,每次更新特征图的尺寸和特征数。

  5. 全局平均池化或全连接层

    在某些情况下,网络可能包含全局平均池化层或全连接层,这些层可以进一步改变特征数。为了将这些特征图转换为一维向量以输入到全连接层,你需要将特征图的元素"展平"(flatten)。展平的过程是将所有元素按顺序排列成一个单一的向量。

    计算展平后的输入维度(in_features)的公式为:

    in_features = channels * height * width
    
  6. 最终特征数

    网络的最后一层之前的特征图的通道数就是最后的特征数。

示例

以下是一个简单的例子来说明如何计算最后特征图的尺寸:给定 RGB 图像 (batch_size=32,channels=3,height=60,width=90)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.fc2 = nn.Sequential(
            nn.Linear(18816, 9408),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(9408, 4704),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4704, 5)
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc2(x)
        return x

在上述代码中,给定一个 RGB 图像 (batch_size=32,channels=3,height=60,width=90),我们将图像输入到 self.conv_block1self.conv_block2 进行处理。

首先,我们计算经过 self.conv_block1 后的特征数:

  • 输入数据有 3 个通道(RGB 图像)。
  • 第一个卷积层将输出通道数增加到 32

由于 kernel_size=3, stride=1, padding=1,即卷积核的大小为 3×3,步长为 1,填充为 1,我们可以计算新的特征图尺寸:

output_height = (60 - 3 + 2 * 1) / 1 + 1 = 60
output_width = (90 - 3 + 2 * 1) / 1 + 1 = 90
  • 经过 ReLU 激活函数后,特征数保持为 32
  • 第二个卷积层仍然保持 32 个输出通道,同上特征图的高度和宽度不变。
  • 再经过 ReLU 激活函数后,特征数仍为 32
  • 最后,最大池化层不会改变通道数,但会减小特征图的高度和宽度。

由于 nn.MaxPool2d(kernel_size=3, stride=2),即最大池化层的池化窗口的大小为 3×3 步长为 2,我们可以计算新的特征图尺寸:

output_height = (60 - 3) / 2 + 1 = 29
output_width = (90 - 3) / 2 + 1 = 44

所以,经过self.conv_block1后,特征图的尺寸为(1, 32, 29, 44),特征数为 32

接下来,我们将这个 32 通道的特征图输入到self.conv_block2

  • 第一个卷积层将输出通道数从 32 增加到 64,同上特征图的高度和宽度不变。
  • 经过 ReLU 激活函数后,特征数保持为 64
  • 第二个卷积层仍然保持 64 个输出通道,同上特征图的高度和宽度不变。
  • 再经过 ReLU 激活函数后,特征数仍为 64
  • 最后,最大池化层不会改变通道数,但会进一步减小特征图的高度和宽度。

同样地,最大池化层的池化窗口的大小为 3×3 步长为 2,我们可以计算新的特征图尺寸:

output_height = (29 - 3) / 2 + 1 = 14
output_width = (29 - 3) / 2 + 1 = 21

因此,经过 self.conv_block1self.conv_block2 后,最终的特征图的尺寸为 (32, 64, 14, 21)

nn.LinearPyTorch 中的一个全连接层(Fully Connected Layer),它用于执行线性变换。全连接层的输入和输出维度通常是由网络架构和数据的特性决定的。

nn.Linear 的第一个参数,即输入维度(input_featuresin_features

为了将这些特征图转换为一维向量以输入到全连接层,你需要将特征图的元素“展平”(flatten)。展平的过程是将所有元素按顺序排列成一个单一的向量。我们可以计算展平后新的特征数,即输入维度 (in_features)

in_features = 64 * 14 * 21 = 18816

第一个全连接层输出维度为 9408,再经过 ReLU 激活函数。

nn.DropoutPyTorch 库中的一种正则化技术的实现,常用于防止过拟合。在深度学习模型训练过程中,dropout 通过随机忽略(“丢弃”)一部分神经元的输出来降低模型的复杂性。这里 dropout 比例为 0.5,那么在训练过程中,每一步有 50% 的神经元输出会被随机设置为0。

同上过程,再来一次最后输出维度为 5,显然这是个 5-分类问题

图解过程

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1322004.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ST股票预测模型(机器学习_人工智能)

知己知彼,百战不殆;不知彼而知己,一胜一负;不知彼,不知己,每战必贻。--《孙子兵法》谋攻篇 ST股票 ST股票是指因连续两年净利润为负而被暂停上市的股票,其风险较高,投资者需要谨慎…

域架构下的功能安全思考

来源:联合电子 随着整车电子电气架构的发展,功能域控架构向整车集中式区域控制演进。新的区域控制架构下,车身控制模块(BCM),整车控制单元(VCU),热管理系统(TMS)和动力底…

JDK各个版本特性讲解-JDK14特性

JDK各个版本特性讲解-JDK14特性 一、Java14概述二、语法层面的变化1. instanceof2. switch表达式3. 文本块的改进4. Records记录类型 二、关于GC1.G1的NUMA内存分配优化2. 弃用SerialCMS,ParNewSerial Old3.删除CMS4.ZGC on macOS and Windows 三、其他变化1.友好的空指针异常提…

利用python在abaqus中画Voronoi多面体简单示例

利用python在abaqus中画Voronoi多面体简单示例 利用scipy.spatial库得到Voronoi多面体顶点坐标abaqus中绘制多面体CAE操作得到相应rpy文件0、 将vertices.csv和ridge_vertices.csv导入abaqus1、 新建一个part2、创建点3、画线4、画面 完整代码 利用scipy.spatial库得到Voronoi多…

【03】GeoScene创建海图或者电子航道图数据

1 配置Nautical属性 1.1 管理长名称 长名称(LNAM)是一个必要的对象标识符,是生产机构(AGEN)、要素识别号码(FIDN)和要素识别子项(FIDS)组件的串联。这三个子组件用于数…

azkaban编译时报错的解决方案

大数据单机学习环境搭建(11)Azkaban单机部署,关于Azkaban和gradle下载,本文编译不限于单机solo模式。 一.大多数报错处理 1.1首先操作 1)安装 git yum install git -y 2)替换 azkaban 目录下的 build.gradle 文件的 2处 repositories 信息。改为 阿里…

回归预测 | MATLAB实现GA-LSSVM基于遗传算法优化最小二乘向量机的多输入单输出数据回归预测模型 (多指标,多图)

回归预测 | MATLAB实现GA-LSSVM基于遗传算法优化最小二乘向量机的多输入单输出数据回归预测模型 (多指标,多图) 目录 回归预测 | MATLAB实现GA-LSSVM基于遗传算法优化最小二乘向量机的多输入单输出数据回归预测模型 (多指标&#…

【HCIP学习记录】OSPF之DD报文

1.OSPF报文格式 24字节 字段长度含义Version1字节版本,OSPF的版本号。对于OSPFv2来说,其值为2。Type1字节类型,OSPF报文的类型,有下面几种类型: 1:Hello报文;● 2:DD报文&#xff1…

使用Kaptcha实现的验证码功能

目录 一.需求 二.验证码功能实现步骤 验证码 引入kaptcha依赖 完成application.yml配置文件 浏览器显示验证码 前端页面 登录页面 验证成功页面 后端 此验证码功能是以SpringBoot框架下基于kaptcha插件来实现的。 一.需求 1.页面生成验证码 2.输入验证码&#xff…

vue中echarts柱状图点击x轴数据复制

参考自:Vue 3 使用 vue-echarts 的柱状图 barItem 和 x, y 轴点击事件实现_echarts x轴点击事件-CSDN博客 例如柱状图如下: 步骤: 一、数据处理的时候需要在 xAxis 对象中添加:triggerEvent: true 这个键值对,以增加…

ES索引误删的名场面

慌了3秒,果断发个邮件; 01 最近,在版本发布时; ES线上未备份的索引,被当场「误删」了; 对于新手来说,妥妥的社死名场面; 对于老手来说,慌它3秒表示一下态度&#xff1…

Python3,100行代码,写一段新年祝福视频,为新年喝彩。

新年祝福 1、引言2、代码示例2.1 思路2.2 介绍2.2.1 画布2.2.2 用法 2.3 实例 3、总结 1、引言 小屌丝:鱼哥, 这2023年马上就结束了, 是不是要表示表示。 小鱼:我也在思考这个事情。 小屌丝:这还需要思考?…

kubernetesr安全篇之云原生安全概述

云原生 4C 安全模型 云原生 4C 安全模型,是指在四个层面上考虑云原生的安全: Cloud(云或基础设施层)Cluster(Kubernetes 集群层)Container(容器层)Code(代码层&#xf…

modelsim使用技巧

Modelsim关闭Add items to the Project后,该如何添加existing file: 在project页面下,右键选择add to project-add existing file 设置modelsim的仿真波形时间单位: 打开Modelsim后,在Wave-Wave Preferences后&#…

从零开始学习Web自动化:用Python和Selenium实现网站登录功能!

Web自动化测试实战项目:使用Selenium和Python完成网站登录功能的自动化测试 本文将介绍如何使用Selenium和Python编写自动化测试脚本,对网站登录功能进行测试。我们将通过模拟用户在网站上输入用户名和密码,并点击登录按钮,来检验…

JavaWeb编程语言—登录校验

一、前言&简介 前言:小编的上一篇文章“JavaWeb编程语言—登录功能实现”,介绍了如何通过Java代码实现通过接收前端传来的账号、密码信息来登录后端服务器,但是没有实现登录校验功能,这代表着用户不需要登录也能直接访问服务器…

设计模式 原型模式 与 Spring 原型模式源码解析(包含Bean的创建过程)

原型模式 原型模式(Prototype模式)是指:用原型实例指定创建对象的种类,并且通过拷贝这些原型,创建新的对象。 原型模式是一种创建型设计模式,允许一个对象再创建另外一个可定制的对象,无需知道如何创建的细节。 工作原…

技术分享-Jenkins

持续集成及Jenkins介绍 软件开发生命周期叫SDLC(Software Development Life Cycle),集合了计划、开发、测试、部署过程。 在平常的开发过程中, 需要频繁地(一天多次)将代码集成到主干,这个叫持…

电子烟单片机方案开发,32位单片机PY32F030电子烟解决方案

电子烟是一种低压的微电子雾化设备。可以通过加热液体产生雾状物质,供用户吸入使用的新型电子产品。它是由微控制器(MCU)、超声波雾化发生器、充电管理IC、锂离子电池、发热棒等器件构成,主要用于替代传统香烟和戒烟,与…

一文读懂什么是智能工厂?

引言 在当今快速变革的制造业中,智能工厂如一盏明灯,照亮着未来生产的道路。它们不仅代表着技术的进步,更是制造业向前迈进的里程碑。智能工厂利用先进的技术和创新方法,将传统工厂转化为高度自动化、数字化和智能化的生产中心。…