【深度学习入门篇 ⑦】PyTorch池化层

news2024/9/23 3:25:03

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


池化层 (Pooling) 降低维度,缩减模型大小,提高计算速度. 即: 主要对卷积层学习到的特征图进行下采样(SubSampling)处理 。

  • 通过下采样,我们可以提取出特征图中最重要的特征,同时忽略掉一些不重要的细节。
  • 上采样是指增加数据(图像)的尺寸;通常用于图像的分割、超分辨率重建或生成模型中,以便将特征图恢复到原始图像的尺寸或更大的尺寸。 

池化层

池化包含最大池化和平均池化,有一维池化,二维池化,三维池化,在这里以二维池化为例

最大池化

最大池化就是求一个区域中的最大值,来代替该区域。

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

假设输入的尺寸是(𝑁,𝐶,𝐻,𝑊),输出尺寸是(𝑁,𝐶,𝐻𝑜𝑢𝑡,𝑊𝑜𝑢𝑡),kernel_size是(𝑘𝐻,𝑘𝑊),可以写成下面形式 :

其中,输入参数 kernel_sizestridepaddingdilation可以是

  • 一个 int :代表长宽使用同样的参数
  • 两个int组成的元组:第一个int用在H维度,第二个int用在W维度
import torch
import torch.nn as nn
#长宽一致的池化,核尺寸为3x3,池化步长为2
ml = nnMaxPool2d(3, stride=2)
#长宽不一致的池化
m2 = nn.MaxPool2d((3,2), stride=(2,1))
input = torch.randn(4,3,24,24)
output1 = m1( input)
output2 = m2( input)
print( "input.shape = " ,input.shape)
print( "output1.shape = " , output1.shape)
print( "output2.shape = " , output2.shape)

 输出:

input.shape = torch.size([4,3,24,24])
output1.shape = torch. size([4,3,11,11])
output2.shape = torch.size([4,3,11,23])
平均池化

平均池化就是用一个区域中的平均数来代替本区域

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

import torch
import torch.nn as nn
#长宽一致的池化,核尺寸为3x3,池化步长为2
ml = nn. AvgPool2d( 3, stride=2)
#长宽不一致的池化
m2 = nn. AvgPool2d(( 3,2), stride=(2,1) )
input = torch.randn(4,3,24,24)
output1 = m1( input)
output2 = m2( input)
print("input.shape = ",input. shape)
print("output1.shape = " , output1.shape)
print( "output2.shape = ", output2.shape)
  • randn是生成形状为[batch_size, channels, height, width] 

输出:

input.shape = torch.size([4,3,24,24])
output1.shape = torch.size([4,3,11,11])
output2.shape = torch.size([4,3,11,23])

BN层

BN,即Batch Normalization,是对每一个batch的数据进行归一化操作,可以使得网络训练更稳定,加速网络的收敛。

import torch
import torch.nn as nn
#批量归一化层(具有可学习参数)
m_learnable = nn. BatchNorm2d(100)
#批量归一化层(不具有可学习参数)
m_non_learnable = nn.BatchNorm2d(100,affine=False)
#随机生成输入数据
input = torch.randn(20,100,35,45)
#应用具有可学习参数的批量归一化层
output_learnable = m_learnable(input)
#应用不具有可学习参数的批量归一化层
output_non_learnable = m_non_learnable(input)
print( "input.shape = ", input.shape)
print( "output_learnable.shape = ", output_learnable.shape)
print( "output_non_learnable.shape = ", output_non_learnable.shape)

 输出:

input.shape = torch.size([20,100,35,45])
output_learnable.shape = torch.size( [20,100,35,45])
output_non_learnable.shape = torch.size([20,100,35,45])

常见的层就是上面提到的这些,如果这些层结构被反复调用,我们可以将其封装成一个个不同的模块。

案例:复现LeNet

LeNet结构,使用PyTorch进行复现,卷积核大小5x5,最大池化层,核大小2x2

import torch
import torch.nn as nn
from torchsummary import summary
class LeNet( nn . Module):
    def _init_( self,num_classes=10):
        super(Leet, self)._init__()
        self.conv1 = nn.conv2d( in_channels=3,out_channels=6,kernel_size=5)
        self.pool1 = nn. MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
        self.pool2 = nn. MaxPool2d(kernel_size=2)
        self.conv3 = nn.conv2d(in_channels=16,out_channels=120, kernel_size=5)
        self.fc1 = nn.Linear(in_features=120,out_features=84)
        self.fc2 = nn.Linear(in_features=84,out_features=10)
    
    def forward(self, x):
        #通过卷积层、ReLU和池化层
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = x.view( -1,120)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
#创建网络实例
num_classes = 10
net = LeNet( num_classes)

#创建一个输入
batch_size = 4
input_tensor = torch.randn(batch_size,3,32,32)
# 假设输入是32x32的RGB图像
#将输入Tensor传递给网络
output = net(input_tensor)
# #显示输出Tensor的形状
print(output.shape)
summary(net,(3,32,32))

Sequential: 顺序容器

Sequential属于顺序容器。模块将按照在构造函数中传递的顺序从上到下进行运算。

使用OrderedDict,可以进一步对传进来的层进行重命名。

#使用sequential来创建小模块,当有输入进来,会从上到下依次经过所有模块
model = nn. Sequential(
nn.conv2d(1,20,5),nn.ReLu() ,
nn.conv2d(20,64,5),nn.ReLU()
)
#使用orderedDict,可以对传进来的模块进行命名,实现效果同上
from collections import orderedDict
model = nn. sequential ( orderedDict([
        ( 'conv1 ', nn.Conv2d( 1,20,5)),
        ( 'relu1 ', nn.ReLU( ) ),
        ( 'conv2 ', nn.conv2d(20,64,5)),
        ( 'relu2 ', nn.ReLU())
]))

除此之外,还可以用 ModuleList和 ModuleDict 来存放子模块,但是用的不多,掌握了上面的内容就足够了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1930048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows 系统利用 SSH 和 WSL2 子系统当服务器

由于最近组内需要将一台 Windows 系统的电脑 W A W_A WA​ 转成能通过 SSH 访问,并且能用 Linux 命令当服务器运行。忙活了一天,终于是把全部东西弄通了。 安装 SSH 首先就是 W A W_A WA​ 先要安装 OpenSSH 服务,直接按照下面的教程安装…

LVS+Keepalive高可用

1、keepalive 调度器的高可用 vip地址主备之间的切换,主在工作时,vip地址只在主上,vip漂移到备服务器。 在主备的优先级不变的情况下,主恢复工作,vip会飘回到住服务器 1、配优先级 2、配置vip和真实服务器 3、主…

【Python实战因果推断】38_双重差分9

目录 Doubly Robust Diff-in-Diff Propensity Score Model Delta Outcome Model All Together Now Doubly Robust Diff-in-Diff 另一种纳入干预前协变变量和时间不变协变变量以考虑条件平行趋势的方法是制作双稳健差分法(DRDID)。要做到这一点&#…

鸿蒙语言基础类库:【@system.brightness (屏幕亮度)】

屏幕亮度 说明: 从API Version 7 开始,该接口不再维护,推荐使用新接口[ohos.brightness]。本模块首批接口从API version 3开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 导入模块 import brightness from sy…

【接口自动化_13课_接口自动化总结】

一、自我介绍 二、项目介绍 自己的职责、项目流程 1)功能测试,怎么设计用例的--测试策略 2)功能测试为什么还有代码实现,能用工具实现,为什么还用代码实现。 基本情况 项目名称:项目类型:项目测试人员…

ubantu22.04安装OceanBase 数据库

1、管理员启动cmd,运行 sudo bash -c "$(curl -s https://obbusiness-private.oss-cn-shanghai.aliyuncs.com/download-center/opensource/service/installer.sh)" 2、提示如下代表安装完成 3、修改数据库配置文件的密码 sudo vim /etc/oceanbase.cnf 然后保存退…

如何申请自费访问学者?自费访问学者材料要哪些?

四、申请自费访问学者需要准备哪些材料 访问学者的申请材料应包括个人简历、推荐信、成绩单、研究计划或课题、语言能力证明等,资费访问学者需提交财务证明。 申请表格填写基本个人信息、访学时间、访学目的等。且应在个人简历详细列出教育背景、工作经历、学术成…

一张图生成绘画全过程,这下人人都成“原画师”了

玩过SD的应该都知道ControlNet吧,最近ControlNet的作者Lvmin Zhang 又搞了一个开源项目PaintsUndo,在Github刚上线就收获了2.7k Star。 只需要上传一张静态图像,PaintsUndo就可以根据提供的图像自动生成对应的绘画全过程视频。 展示从一张白…

【python报错已解决】Stack Overflow

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 引言 在开发的旅程中,我们难免会遇到各种各样的报错信息,这些报错就像旅途中的绊脚石,阻挡…

GA-Kmeans-Transformer-GRU时序聚类+状态识别组合模型,创新发文无忧!

GA-Kmeans-Transformer-GRU时序聚类状态识别组合模型,创新发文无忧! 目录 GA-Kmeans-Transformer-GRU时序聚类状态识别组合模型,创新发文无忧!效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.GA-Kmeans-Transformer-GRU时…

0601大学物理电磁篇 静电场中的导体和电介质

静电场中的导体和电介质01 6-1静电场中的导体 6-1静电场中的导体

【Hot100】LeetCode—155. 最小栈

目录 题目1- 思路2- 实现⭐155. 最小栈——题解思路 3- ACM 实现 题目 原题连接&#xff1a;155. 最小栈 1- 思路 思路 最小栈 ——> 借助两个栈来实现 2- 实现 ⭐155. 最小栈——题解思路 class MinStack {Stack<Integer> data;Stack<Integer> min;public …

如何在Linux上部署Ruby on Rails应用程序

在Linux上部署Ruby on Rails应用程序是一个相对复杂的过程&#xff0c;需要按照一系列步骤进行。下面是一个基本的部署过程&#xff0c;涵盖了从安装所需软件到部署应用程序的所有步骤。 安装必要的软件 在部署Ruby on Rails应用程序之前&#xff0c;需要确保Linux系统上安装了…

etime:拓展time

拓展C库的time模块&#xff0c;时间格式转换、代码块计时器。

公司想无偿裁员,同事赖着不走

关注卢松松&#xff0c;会经常给你分享一些我的经验和观点。 这招好像也不错! 事情是这样的&#xff1a;某公司准备把成本高的员工都裁掉&#xff0c;主要包含研发部和程序员&#xff0c;总共18个人&#xff0c;准备裁掉10人&#xff0c;因为他们工资开的太高了&#xff0c;…

CH390H+STM32F1+LWIP

文章目录 1、CH390芯片介绍2、电路部分3、LWIP调试3.1修改点13.2 修改点2 4、结果展示参考 1、CH390芯片介绍 官网地址&#xff1a; 南京沁恒微电子股份有限公司 特点&#xff1a; 2、电路部分 CH390及接口&#xff1a; STM32F1引脚&#xff1a; 不含LWIP的demo及LWIP…

数据安全评估师CCRC-DSA:数据安全如何“护航”创新发展

在2024年的中关村论坛年会关于数据安全治理与发展的讨论中&#xff0c;新加坡资讯通信媒体发展局局长柳俊泓指出&#xff0c;根据预测&#xff0c;今年全球将产生约147泽字节的数据量&#xff0c;这意味着每个地球上的人都拥有相当于150部iPhone的数据量。 柳俊泓强调&#xf…

【Linux网络】poll{初识poll / poll接口 / poll vs select / poll开发多客户端echo服务器}

文章目录 1.初识pollpoll与select的主要联系与区别poll的原理poll的优点poll的缺点poll vs select 2.poll开发多客户端echo服务器封装套接字接口Makefile主函数日志服务聊天服务器 1.初识poll poll是Linux系统中的一个系统调用&#xff0c;它用于监控多个文件描述符&#xff08…

RocketMQ实现分布式事务

RocketMQ的分布式事务消息功能&#xff0c;在普通消息基础上&#xff0c;支持二阶段的提交。将二阶段提交和本地事务绑定&#xff0c;实现全局提交结果的一致性。 1、生产者将消息发送至RocketMQ服务端。 2、RocketMQ服务端将消息持久化成功之后&#xff0c;向生产者返回Ack确…

2. LangChain4j 之AI Services

一: AI Services介绍 LangChain4j提供了很多基础的组件&#xff0c;每次使用都需要你编写大量的样板代码&#xff0c;LangChain4j就提供了一个高级的组件AI Services&#xff0c;它可以简化与大模型(LLM)和其他组件交互的复杂度。让开发者更专注于业务逻辑&#xff0c;而不是底…