PyTorch学习笔记-Convolution Layers与Pooling Layers

news2025/1/19 20:25:51

1. Convolution Layers

由于图像是二维的,因此基本上最常用到的就是二维的卷积类:torch.nn.Conv2d,官方文档:torch.nn.Conv2d。

Conv2d 的主要参数有以下几个:

  • in_channels:输入图像的通道数,彩色图像一般都是三通道。
  • out_channels:通过卷积后产生的输出图像的通道数。
  • kernel_size:可以是一个数或一个元组,表示卷积核的大小,卷积核的参数是从数据的分布中采样得到的,这些数是多少无所谓,因为在神经网络训练的过程中就是对这些参数进行不断地调整。
  • stride:步长。
  • padding:填充。
  • padding_mode:填充模式,有 zerosreflectreplicatecircular,默认为 zeros
  • dilation:可以是一个数或一个元组,表示卷积核各个元素间的距离。
  • group:一般设置为1,基本用不到。
  • bias:偏置,一般设置为 True。

例如以下代码构建了一个只有一层卷积层的神经网络,该卷积层的输入和输出通道数都为三通道,卷积核大小为3*3,步长为1,无填充,然后用 CIFAR10 测试数据集进行测试:

from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn

test_set = datasets.CIFAR10('dataset/CIFAR10', train=False, transform=transforms.ToTensor())

data_loader = DataLoader(test_set, batch_size=64)

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=0)

    def forward(self, input):
        output = self.conv1(input)
        return output

network = Network()
print(network)  # Network((conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)))

writer = SummaryWriter('logs')

for step, data in enumerate(data_loader):
    imgs, targets = data
    output = network(imgs)
    writer.add_images('input', imgs, step)
    writer.add_images('output', output, step)

writer.close()

测试结果如下:

在这里插入图片描述

可以看到卷积运算能够提取输入图像的不同特征,第一层卷积层可能只能提取一些低级的特征如边缘、线条和角等层级,更多层的网络能从低级特征中迭代提取更复杂的特征。

2. Pooling Layers

Pooling Layers 相关函数介绍的官方文档:Pooling Layers。

其中的 MaxPool 表示最大池化,也称上采样;MaxUnpool 表示最小池化,也称下采样;AvgPool 表示平均池化。其中最常用的为 MaxPool2d,官方文档:torch.nn.MaxPool2d。

最大池化的步骤如下图所示:

在这里插入图片描述

MaxPool2d 的主要参数有以下几个:

  • kernel_size:用来取最大值的窗口(池化核)大小,和之前的卷积核类似。
  • stride:步长,注意默认值为 kernel_size
  • padding:填充,和 Conv2d 一样。
  • dilation:池化核中各个元素间的距离,和 Conv2d 一样。
  • return_indices:如果为 True,表示返回值中包含最大值位置的索引。注意这个最大值指的是在所有窗口中产生的最大值,如果窗口产生的最大值总共有5个,就会有5个返回值。
  • ceil_mode:如果为 True,表示在计算输出结果形状的时候,使用向上取整,否则默认向下取整。

输出结果形状的计算公式如下:

在这里插入图片描述

接下来我们用代码实现这个池化层:

from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

    def forward(self, input):
        output = self.maxpool1(input)
        return output

input = torch.tensor([
    [1, 2, 1, 0],
    [0, 1, 2, 3],
    [3, 0, 1, 2],
    [2, 4, 0, 1]
], dtype=torch.float32)  # 注意池化层读入的数据需要为浮点型

input = torch.reshape(input, [1, 1, 4, 4])

network = Network()
print(network)  # Network((maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))

output = network(input)
print(output)
# tensor([[[[2., 3.],
#           [4., 2.]]]])

我们用图像来试试效果:

from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn

test_set = datasets.CIFAR10('dataset/CIFAR10', train=False, transform=transforms.ToTensor())

data_loader = DataLoader(test_set, batch_size=64)

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

    def forward(self, input):
        output = self.maxpool1(input)
        return output

network = Network()

writer = SummaryWriter('logs')

for step, data in enumerate(data_loader):
    imgs, targets = data
    output = network(imgs)
    writer.add_images('input', imgs, step)
    writer.add_images('output', output, step)

writer.close()

测试结果如下:

在这里插入图片描述

可以看到最大池化的目的是保留输入数据的特征,同时减小特征的数据量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/44176.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IDEA关于数据库报错SQL dialect is not configured或Unable to resolve table ‘表名‘

目录一、SQL dialect is not configured1.1 报错场景展示1.2 方式一,万能altenter1.3 方式二,在setting中设置二、Unable to resolve table 表名2.1 报错场景展示2.2 方式一,万能altenter2.3 方式二,在setting中设置一、SQL diale…

vscode开发STM32(三)---调试篇

vscode开发STM32(三)—调试篇 文章目录vscode开发STM32(三)---调试篇前提条件配置调试配置JLink使用JLinkGDB进行调试配置stlink使用openOCD进行调试完整的launch文件内容前提条件 安装Cortex-Debug插件 安装OpenOCD 安装JLink驱…

LeetCode HOT 100 —— 48.旋转图像

题目 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 思路 方法一:使用辅助数组 可以得出规律,将图像旋…

集合框架----源码解读HashMap篇(一)

1.HashMap官方介绍 基于哈希表的Map接口实现。该实现提供了所有可选的映射操作,并允许空值和空键。(HashMap类大致相当于Hashtable,除了它是非同步的,并且允许为空值。)这个类不保证映射的顺序;特别是,它不能保证顺序随时间的推移…

Nodejs -- Express托管静态资源

文章目录托管静态资源1 expess.static()2 托管多个静态资源目录3 挂载路径前缀托管静态资源 1 expess.static() express提供了一个非常好用的函数,叫做express.static(),通过它,我们可以非常方便地创建一个静态资源服务器,例如&…

PG::FunboxEasyEnum

nmap -Pn -p- -T4 --min-rate1000 192.168.81.132 nmap -Pn -p 22,80 -sCV 192.168.81.132 80端口是Apache2 Ubuntu的默认页面 尝试路径爆破 /mini.php可以进行文件上传 直接上传reverse-php-shell 上传linpeas脚本进行枚举,得到oracle用户的密码hash oracle…

2022-11-28-大数据可视化“可视化国产/进口电影票房榜单”分析,特征维度大于50

可视化国产/进口电影票房榜单前言数据分析数据可视化过程分析总结前言 党的十八大以来,国产电影产业与事业快速发展,创作水平不断提高,题材类型丰富多元,受众口碑不断提升,在市场竞争中表现愈发突出,已成为…

《论文阅读》BA-NET: DENSE BUNDLE ADJUSTMENT NETWORKS

留个笔记自用 BA-NET: DENSE BUNDLE ADJUSTMENT NETWORKS 做什么 首先是最基础的,Structure-from-Motion(SFM),SFM可以简单翻译成运动估计,是一种基于dui8序列图片进行三维重建的算法。简单来说就是是从运动中不同角…

【Python】记录从3.9升级到3.11踩的坑

写在前面的话:如果想体验python3.11,不推荐生产环境升级,可以现在测试环境试试看 环境变化 原始环境 Python3.9,有挺多安装的第三方库,有自己写的类和方法,程序一切运行正常 升级环境 Python3.11&#…

如何获取Adreno GPU数据

什么是GPU GPU(Graphic Processing Unit)是图形处理器,相当于在计算机和移动终端上做图形图像运算工作的微处理器,显示芯片。通过向量计算和并行计算等方式加速了原有的计算工作,能够更好地处理几何转换和光照计算等&a…

如何与意法半导体STMicro建立EDI连接?

项目背景 意法半导体STMicro是全球最大的半导体公司之一,2010 年净收入 103.5 亿美元,2011 年第二季度净收入 25.7亿美元。 以业内最广泛的产品组合著称,凭借多元化的技术、尖端的设计能力、知识产权组合、合作伙伴战略和高效的制造能力&…

pdf怎么编辑?分享两款pdf编辑软件,编辑pdf也很简单!

pdf怎么编辑?其实也很简单,现在跟大家分享两款pdf编辑软件,可以让我们对pdf实现自由编辑修改,有了这两款pdf编辑软件,编辑pdf将不再困难。 pdf编辑软件一:万兴pdf编辑软件 万兴pdf是一款受众广泛&#xff0…

【设计】OOA、OOD、OOP

这三者都是 OO(Object-Oriented)领域的思想。 一般我们我们接到产品经理的需求后,开发阶段分这样几个步骤: 可行性预研阶段,此阶段评估需求是否合理,能否实现;OOA阶段,此阶段分析用…

【Lilishop商城】No2-5.确定软件架构搭建四(本篇包括消息中间件RocketMQ)

仅涉及后端,全部目录看顶部专栏,代码、文档、接口路径在: 【Lilishop商城】记录一下B2B2C商城系统学习笔记~_清晨敲代码的博客-CSDN博客 全篇只介绍重点架构逻辑,具体编写看源代码就行,读起来也不复杂~ 谨慎&#xff…

Python:如何在 CentOS 8 服务器上运行 Selenium 代码?

前言 因项目需求,需要在 CentOS 8 服务器上运行 Python-Selenium 代码,那么该如何操作呢? 运行环境 CentOS Stream 8Python 3.9.13selenium4.6.0Google Chrome 107.0.5304.121 操作步骤 安装 Google Chrome 下载 Linux 版本的 Chrome 将下…

怎么合并视频?快把这些方法收好

小伙伴们平时会在通过网课来提高自己的技能吗?我经常会在网上保存一系列的视频进行学习,可是当保存的网课视频数量多起来后,每次想要找对应的视频,都得花上不少的时间。其实我们可以通过将相同系列的视频合并起来的方法&#xff0…

java word,excel,ppt转pdf

准备工作 1.下载 jacob.jar 链接:https://pan.baidu.com/s/1TWIGyX9A3xQ6AG9Y3mVlVg 提取码:abcd 2.下载安装wpsWPS Office-支持多人在线编辑多种文档格式_WPS官方网站 3.添加 jar到项目和ddl文件放在jdk的jre/bin目录下,记得自己系统是…

13_cgi

知识点1【cgi实现计算器案例】 2、GET的同步方式&#xff1a; index.html <html><head><title>table</title><meta charset"UTF-8"><!--这是描述 js中的函数来之哪个js文件--><script type"text/javascript" sr…

Kafka基础与核心概念

本文&#xff0c;我们将试图回答什么是apache kafka。 kafka是一个分布式流平台或者分布式消息提交日志 分布式 Kafka 由一个或多个节点组成的工作集群&#xff0c;这些节点可以位于不同的数据中心&#xff0c;我们可以在 Kafka 集群的不同节点之间分布数据/负载&#xff0c;并…

【学习笔记47】开关变量和拖拽效果

一、开关案例 <button>点击获取验证码</button>&#xff08;一&#xff09;基本功能的实现 // 获取标签对象const oBtn document.querySelector(button);// 给按钮添加点击事件oBtn.addEventListener(click, function () {// 定义变量 用于获取验证码let count 5…