pytorch初学笔记(十):神经网络基本结构之最大池化的使用

news2024/11/19 5:30:36

目录

一、最大池化:下采样

1.1 参数介绍 

1.2 公式

二、最大池化的作用和目的

三、代码实战

3.1 题目要求

3.2 池化的具体实现

3.2.1 步骤

3.2.2 报错及其原因 

3.2.3 ceil_mode不同运行的结果不同

3.2.4 完整代码

3.3 tensorboard可视化


一、最大池化:下采样

官方文档:torch.nn — PyTorch 1.13 documentation

1.1 参数介绍 

Parameters:

  • kernel_size 池化核 (Union[intTuple[intint]]) – the size of the window to take a max over

  • stride (Union[intTuple[intint]]) – the stride of the window. Default value is kernel_size

  • padding (Union[intTuple[intint]]) – implicit zero padding to be added on both sides

  • dilation (Union[intTuple[intint]]) – a parameter that controls the stride of elements in the window

  • return_indices (bool) – if True, will return the max indices along with the outputs. Useful for torch.nn.MaxUnpool2d later

  • ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape,是否对结果进行保留,默认为FALSE

注意: 

1. stride的默认大小为池化核的大小 

2. dilation:空洞卷积,如右图,进行卷积操作时会隔n个取一个。

 

  

3. ceil_mode:ceil为向上取整,floor为向下取整。

  • ceil_mode=True,结果进行保留;
  • ceil_mode=False,结果不进行保留

1.2 公式

输入的input要求为四维或者三维,需要输入通道数以及长和宽。

        因此当我们自定义输入一个input矩阵时,需要再使用torch.reshape方法将其转变成(N, C, H, W)的维度。

 

二、最大池化的作用和目的

作用:最大限度的保留图片特征,同时减少数据量。加速训练速度。

三、代码实战

3.1 题目要求

输入tensor矩阵为5*5,如下图所示,池化核为大小为3*3,经过池化后根据ceil_mode的不同应输出下图所示的两个矩阵,输出如下的采样结果。

 

3.2 池化的具体实现

3.2.1 步骤

  1. 输入tensor型变量input
  2. 按照池化函数所需的input尺寸reshape输入的大小: 
    input = torch.reshape(input, (-1,1, 5, 5))
  3. 自定义神经网络,完成池化操作 
  4. 实例化神经网络,输出结果

3.2.2 报错及其原因 

import torch
from torch.nn import MaxPool2d
#输入的矩阵 
input = torch.tensor([
    [1,2,0,3,1],
    [0,1,2,3,1],
    [1,2,1,0,0],
    [5,2,3,1,1],
    [2,1,0,1,1]
])

input = torch.reshape(input,(1,5,5))
print(input.shape)

class Maweiyi(torch.nn.Module):
    def __init__(self):
        super(Maweiyi, self).__init__()
        # 设置池化
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,input):
        output = self.maxpool1(input)
        return output

maweiyi = Maweiyi()
output = maweiyi(input)
print(output)

会出现如下报错:无法实现long型的数据。

解决:需要修改input的类型,设置tensor的dtype=float32 。

#输入的矩阵
input = torch.tensor([
    [1,2,0,3,1],
    [0,1,2,3,1],
    [1,2,1,0,0],
    [5,2,3,1,1],
    [2,1,0,1,1]
],dtype=torch.float32)

 修改过后即可成功运行。

3.2.3 ceil_mode不同运行的结果不同

1. ceil_mode = True,保留最大采样过程中的所有结果,运行出的tensor大小为4*4

self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)

 

2. ceil_mode = False,不保留最大采样过程中的所有结果,运行出的tensor大小为1*1

self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=False)

3.2.4 完整代码

import torch
from torch.nn import MaxPool2d
#输入的矩阵
input = torch.tensor([
    [1,2,0,3,1],
    [0,1,2,3,1],
    [1,2,1,0,0],
    [5,2,3,1,1],
    [2,1,0,1,1]
],dtype=torch.float32)

input = torch.reshape(input,(1,5,5))
print(input.shape)

class Maweiyi(torch.nn.Module):
    def __init__(self):
        super(Maweiyi, self).__init__()
        # 设置池化
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)

    def forward(self,input):
        output = self.maxpool1(input)
        return output

maweiyi = Maweiyi()
output = maweiyi(input)
print(output)

 

3.3 tensorboard可视化

import torch
import torchvision.datasets
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root=".\CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=64)
from torch.utils.tensorboard import SummaryWriter

input = torch.tensor([
    [1, 2, 0, 3, 1],
    [0, 1, 2, 3, 1],
    [1, 2, 1, 0, 0],
    [5, 2, 3, 1, 1],
    [2, 1, 0, 1, 1]
], dtype=torch.float32)

input = torch.reshape(input, (-1,1, 5, 5))


class Maweiyi(torch.nn.Module):
    def __init__(self):
        super(Maweiyi, self).__init__()
        self.maxPool1 = MaxPool2d(kernel_size=3,ceil_mode=True)

    def forward(self, input):
        output = self.maxPool1(input)
        return output

writer = SummaryWriter("logs")
step = 0

maweiyi = Maweiyi()

for data in dataloader:
    imgs,labels=data
    writer.add_images("inputs",imgs,step)
    output = maweiyi(imgs)
    writer.add_images("output",output,step)
    step+=1

writer.close()

输出如下所示:

可以看到经过最大池化之后图片变为了马赛克形式,不太清晰,但是大体能保留原图像的特征。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/29567.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【强化学习论文合集】AAMAS-2022 强化学习论文 | 2022年合集(三)

强化学习(Reinforcement Learning, RL),又称再励学习、评价学习或增强学习,是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题。 本专栏整理了近几年国际顶级会议中,涉及强化学习(Rein…

相控阵天线(六):直线阵列特殊综合方法(变形泰勒综合法、贝利斯综合法、伍德沃德抽样法)

目录简介变形泰勒综合法贝利斯综合法伍德沃德-劳森抽样法配相抵消法简介 阵列天线的综合问题是其分析的逆问题,即是在预先给定辐射特性(如方向图形状、副瓣电平等)的情况下,综合出阵列激励幅度和相位。其中特殊综合主要包括:左右副瓣电平不相…

关于Idea合并不同分支代码你怎么看

一、环境说明 1. IDEA版本 2020.1 2. git版本 2.33.0 二、整体合并 1. 软件开发中,在一次版本迭代过程中,大家可能会在同一个开发分支dev进行开发,同时开发不同功能 ,开发完以后需要自行合并到测试分支test,交给测试…

Feign高级实战-源码分析

目录参考导读什么是FeignFeign 和 Openfeign 的区别OpenFeign的启动原理在启动类申明EnableFeignClientsregisterDefaultConfigurationregisterFeignClientsregisterFeignClientgetTarget()创建一个代理对象HttpClientFeignLoadBalancerConfigurationOpenFeign 的工作原理动态代…

多策略协同改进的阿基米德优化算法及其应用(Matlab代码实现)

🍒🍒🍒欢迎关注🌈🌈🌈 📝个人主页:我爱Matlab 👍点赞➕评论➕收藏 养成习惯(一键三连)🌻🌻🌻 🍌希…

论文阅读:On the User Behavior Leakage from Recommender System Exposure

论文地址 Motivation: 现阶段对于用户行为的保护仅仅从用户端来考虑,比如用户的行为数据等。然而推荐系统是一个闭环的过程,即用户交互了物品,推荐系统根据用户的交互信息去推荐物品,用户也会根据推荐系统推荐的物品做…

[Java] 浅析rpc的原理及所用到的基本底层技术

文章目录前言阅读前须知rpc是什么?别的进程 vs 别的机器rpc的目的或是我们为什么需要rpc?实现rpc所涉及到的底层技术1. 通信技术(网络IO、Network IO)套接字(Socket)bio、nio与Netty2. 网络协议&#xff08…

【仿真建模】第三课:AnyLogic入门基础课程 - 多层建筑行人疏散仿真讲解

文章目录一、Agent类的概念二、行人疏散仿真2.1 仿真模型示意图2.2 具体实现步骤一、Agent类的概念 二、行人疏散仿真 2.1 仿真模型示意图 2.2 具体实现步骤 首先,新建模型 新建一个MyFloor1对象,代表第一个楼层 创建矩形墙,并放到原点…

专业数采软件DXP OPC Server售后问题解决方案

DeviceXPlorer OPC Server是一套实现工业自动化设备数据读取或发送的软件。它提供与制造车间中的控制设备(如 PLC、机床和机器人)的连接,支持200多种设备通讯协议,便捷的配置,快速实现设备联网采集。 在与设备通讯方面…

HTML+CSS大作业 环境网页设计与实现(垃圾分类) web前端开发技术 web课程设计 网页规划与设计

🎀 精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

idea永久设置maven配置,新项目不用再设置

在这里设置就是永久的设置,新项目将使用该设置,maven的配置也在新项目和新模块创建的时候直接加载 英文的话,看位置大概也应该可以找到 点开后左上角搜索maven,找到如图maven的设置 主路径就是maven的安装包软件的路径 用户设置…

LeetCode 110平衡二叉树 257.二叉树的所有路径 404左叶子之和

文章目录110平衡二叉树c 代码实现python 代码实现257.二叉树的所有路径c代码实现python 代码实现404左叶子之和c 代码实现python 代码110平衡二叉树 给定一个二叉树,判断它是否是高度平衡的二叉树。 本题中,一棵高度平衡二叉树定义为: 一个…

http 知识整理

1. 启发式缓存 在不设置cache-control/expires的情况下,浏览器不会默认进入协商缓存。而是根据Date/LastModified去自动计算出合适的缓存时间。 计算方式为:(Date - LastModified) * n n:LM-Factor,处于[0,1]之间 2. 强制缓存 -…

Vue的模版代码与数据绑定方式

目录 模版代码 插值语法 指令语法 数据多层访问 vue模版语小结 数据绑定方式 模版代码 插值语法 插值语法就是使用{{xxx}}描述的 <div id"root">{{name}} </div> 指令语法 <div id"root"><a :href"school.url">…

lazada买家订单导出

下载安装与运行 https://www.yuque.com/webcrawl/handbook/mtad3q 用途与功能 所见即所得的导出自由选择导出项支持Excel、JSON两种方式导出自由排序Excel导出列顺序导出过程中有进度提示&#xff0c;用户可以随时提前中止 导出过程演示 选择lazada订单导出&#xff0c;开始…

linux内核整体架构

操作系统概念 操作系统属于软件范畴&#xff0c;负责管理系统的硬件资源。OS具备的功能&#xff1a;1.为应用程序提供执行环境。2.为多用户和应用程序管理计算机的硬件资源。3.虚拟化功能。4.支持并发。 宏内核与微内核架构 宏内核&#xff1a;所有的内核代码都编译成二进制…

基于JAVA的学生课程后台管理系统【数据库设计、源码、开题报告】

数据库脚本下载地址&#xff1a; https://download.csdn.net/download/itrjxxs_com/86427641 开学选好课是具备学术能力的首要表现。学生不能为了拿高分&#xff0c;只选简单课程&#xff0c;也没有必要为了显示出自己热衷自我挑战&#xff0c;奋不顾身地一头扎进高难度课程。在…

强化深度学习中利用时序差分法中的Sarsa算法解决风险投资问题实战(附源码 超详细必看)

需要源码请点赞关注收藏后评论区留下QQ~~~ 一、Sarsa算法简介 Sarsa算法每次更新都需要获取五元组&#xff08;S,A,R,S,A&#xff09;这也是该算法称为Sarsa的原因&#xff0c;每当从非终止状态进行一次转移后&#xff0c;就进行一次更新&#xff0c;但需要注意的是&#xff0…

【论文阅读】社交网络传播最大化问题-04

Efficient Influence Maximization in Social Networks相关工作改进的贪心算法对独立级联模型的改进对加权级联模型的改进改进度折扣算法影响力最大化&#xff1a;在社交网络中找到一小部分能够最大化传播影响力的节点(种子节点)。一是改进原有的贪心算法&#xff0c;进一步缩短…

KMP算法——通俗易懂讲好KMP算法:实例图解分析+详细代码注解

文章目录1.kmp算法基本介绍2.字符串的最长公共前后缀&部分匹配表2.1 什么是最长公共前后缀2.2 什么是部分匹配表Next2.3 字符串最长公共前后缀&部分匹配表的代码实现2.4 代码测试3.根据部分匹配表搜索字符串匹配位置3.1 匹配成功一个就退出匹配的代码3.1.1 KMP算法的大…