pytorch05:卷积、池化、激活

news2025/4/24 7:33:23

目录

  • 一、卷积
    • 1.1 卷积的概念
    • 1.2 卷积可视化
    • 1.3 卷积的维度
    • 1.4 nn.Conv2d
      • 1.4.1 无padding 无stride卷积
      • 1.4.2 无padding stride=2卷积
      • 1.4.3 padding=2的卷积
      • 1.4.4 空洞卷积
      • 1.4.5 分组卷积
    • 1.5 卷积输出尺寸计算
    • 1.6 卷积的维度
    • 1.7 转置卷积
      • 1.7.1 为什么被称为转置卷积
      • 1.7.2 nn.ConvTranspose2d
      • 1.7.3 转置卷积的计算方法
      • 1.7.4 核心代码
  • 二、池化层(Pooling Layer)
    • 2.1 池化的概念
    • 2.2 nn.MaxPool2d
      • 2.2.1 代码实现
    • 2.3 nn.AvgPool2d
      • 2.3.1 代码实现
    • 2.4 最大池化与平均池化区别
    • 2.5 nn.MaxUnpool2d
      • 2.5.1 核心代码实现
  • 三、线性层(Linear Layer)
    • 3.1nn.Linear
  • 四、激活函数层(Activation Layer)
    • 4.1 概念
    • 4.2 nn.Sigmoid激活函数
    • 4.3 nn.tanh激活函数
    • 4.4 nn.ReLU激活函数
    • 4.5 ReLU变体形式

一、卷积

1.1 卷积的概念

卷积运算:卷积核在输入信号(图像)上滑动,相应位置上进行乘加
卷积核:又称为滤波器,过滤器,可认为是某种模式,某种特征。
卷积过程类似于用一个模版去图像上寻找与它相似的区域,与卷积核模式越相似,激活值越高,从而实现特征提取。
在这里插入图片描述

1.2 卷积可视化

AlexNet这篇论文对卷积核进行了可视化,发现卷积核学习到的是边缘,条纹,色彩这一些细节模式,但是只有前几层卷积提取的特征可视化较为明显,随着网络的加深,卷积次数的增加,特征可视化也逐渐模糊。
在这里插入图片描述

1.3 卷积的维度

卷积维度:一般情况下,卷积核在几个维度上滑动,就是几维卷积,下面三幅图分别是一维卷积、二维卷积、三维卷积。我们常见的图片特征提取使用的是二维卷积(conv2d),在医学图像领域用于癌细胞切片分析使用的是三维卷积(conv3d)。
一维卷积
在这里插入图片描述
在这里插入图片描述

1.4 nn.Conv2d

功能:对多个二维信号进行二维卷积,例如图片
主要参数:
• in_channels:输入通道数
• out_channels:输出通道数,等价于卷积核个数
• kernel_size:卷积核尺寸
• stride:步长,卷积核每次移动的长度
• padding :图片边缘填充个数
• dilation:空洞卷积大小,常用于图像分割任务,用来提升感受野
• groups:分组卷积设置
• bias:偏置
在这里插入图片描述

1.4.1 无padding 无stride卷积

每次在原图滑动1个单位
在这里插入图片描述

1.4.2 无padding stride=2卷积

每次在原图滑动两个单位
在这里插入图片描述

1.4.3 padding=2的卷积

在原图的边缘增加2个单位的填充。
在这里插入图片描述

1.4.4 空洞卷积

在这里插入图片描述

1.4.5 分组卷积

同一种张图片使用两个不同的GPU进行训练,最后将两张GPU提取的特征进行融合。在这里插入图片描述

1.5 卷积输出尺寸计算

在这里插入图片描述
完整尺寸计算公式:
在这里插入图片描述
一般我们输入的图像都会进行预处理,将长宽变为相同大小,所以H,W两个公式可以看为相等。

1.6 卷积的维度

卷积维度:一般情况下,卷积核在几个维度上滑动,就是几维卷积,我们的图像是二维图像,卷积核的维度也是二维。
我们的图像是RGB三个通道,所以会在三个二维图像上进行滑动提取特征,最后将红绿蓝三个通道特征提取之后进行相加,得到一个output特征图。
在这里插入图片描述

1.7 转置卷积

转置卷积又称为反卷积(Deconvolution)和部分跨越卷积(Fractionallystrided Convolution) ,用于对图像进行上采样(UpSample)

1.7.1 为什么被称为转置卷积

正常卷积,图片经过卷积之后,等到的特征图尺寸会比原图小
在这里插入图片描述

而转置卷积经过卷积核之后会将原图尺寸方法常用于上采样,提升图片的尺度
在这里插入图片描述

在这里插入图片描述

1.7.2 nn.ConvTranspose2d

功能:转置卷积实现上采样
在这里插入图片描述
主要参数:
• in_channels:输入通道数
• out_channels:输出通道数
• kernel_size:卷积核尺寸
• stride:步长
• padding :填充个数
• dilation:空洞卷积大小
• groups:分组卷积设置
• bias:偏置

1.7.3 转置卷积的计算方法

在这里插入图片描述
完整版本:
在这里插入图片描述

1.7.4 核心代码

flag = 1
if flag:
    conv_layer = nn.ConvTranspose2d(3, 1, 3, stride=2)  # input:(i, o, size)
    nn.init.xavier_normal_(conv_layer.weight.data)
    # calculation
    img_conv = conv_layer(img_tensor)

输出结果:
在这里插入图片描述
在这里插入图片描述

二、池化层(Pooling Layer)

2.1 池化的概念

池化运算:对信号进行 “收集”并 “总结”,类似水池收集水资源,因而得名池化层,“收集”:多变少;“总结”:最大值/平均值

池化有最大池化和平均池化
最大池化:取池化范围内最大的数,下图中池化范围2x2,取每个池化范围内数值最大的
平均池化:取池化范围内的平均值,下图中池化范围2x2,取每个池化范围内数值之和,再求平均
在这里插入图片描述

2.2 nn.MaxPool2d

功能:对二维信号(图像)进行最大值池化
在这里插入图片描述
主要参数:
• kernel_size:池化核尺寸
• stride:步长
• padding :填充个数
• dilation:池化核间隔大小
• ceil_mode:尺寸向上取整
• return_indices:记录池化像素索引

2.2.1 代码实现

import os
import torch
import random
import numpy as np
import torchvision
import torch.nn as nn
from torchvision import transforms
from matplotlib import pyplot as plt
from PIL import Image
from common_tools import transform_invert, set_seed

set_seed(1)  # 设置随机种子

# ================================= load img ==================================
path_img = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lena.png")
img = Image.open(path_img).convert('RGB')  # 0~255

# convert to tensor
img_transform = transforms.Compose([transforms.ToTensor()])
img_tensor = img_transform(img)
img_tensor.unsqueeze_(dim=0)  # C*H*W to B*C*H*W

# ================ maxpool
flag = 1
# flag = 0
if flag:
    maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2)) #这里为什么池化和步长都设置(2,2),是为了保证每次池化的区域不重叠
    img_pool = maxpool_layer(img_tensor)
# ================================= 展示图像 ==================================
print("池化前尺寸:{}\n池化后尺寸:{}".format(img_tensor.shape, img_pool.shape))
img_pool = transform_invert(img_pool[0, 0:3, ...], img_transform)
img_raw = transform_invert(img_tensor.squeeze(), img_transform)
plt.subplot(122).imshow(img_pool)
plt.subplot(121).imshow(img_raw)
plt.show()

输出结果,图片大小为原来的一半:
在这里插入图片描述
在这里插入图片描述

2.3 nn.AvgPool2d

功能:对二维信号(图像)进行平均值池化
在这里插入图片描述
主要参数:
• kernel_size:池化核尺寸
• stride:步长
• padding :填充个数
• ceil_mode:尺寸向上取整
• count_include_pad:填充值用于计算
• divisor_override :除法因子

2.3.1 代码实现

核心代码:

flag = 1
# flag = 0
if flag:
    avgpoollayer = nn.AvgPool2d((2, 2), stride=(2, 2))  # input:(i, o, size) weights:(o, i , h, w)
    img_pool = avgpoollayer(img_tensor)

输出结果:
在这里插入图片描述
在这里插入图片描述

2.4 最大池化与平均池化区别

下面第一幅图是最大池化,第二幅图是平均池化,因为最大池化取的是一个区域内的最大值,所以第一幅图比第二幅图某些区域更亮,特征更明显。
在这里插入图片描述

2.5 nn.MaxUnpool2d

功能:对二维信号(图像)进行最大值池化进行上采样,但是需要根据池化中的最大值位置索引进行上采样,例如[1,2,0,1]经过最大池化,取第二个位置,当前索引为2,所以[3,2,1,7]进行上采样,其中3是在上采样后索引为2的位置上,其他区域为0.
在这里插入图片描述
在这里插入图片描述
主要参数:
• kernel_size:池化核尺寸
• stride:步长
• padding :填充个数

2.5.1 核心代码实现

flag = 1
if flag:
    # pooling
    img_tensor = torch.randint(high=5, size=(1, 1, 4, 4), dtype=torch.float) # 生成特征图
    maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True) # 设置池化层
    img_pool, indices = maxpool_layer(img_tensor) #获取池化后的数据以及索引

    # unpooling
    img_reconstruct = torch.randn_like(img_pool, dtype=torch.float) #根据img_poolshape随机构建数据
    maxunpool_layer = nn.MaxUnpool2d((2, 2), stride=(2, 2)) #搭建最大池化上采样层
    img_unpool = maxunpool_layer(img_reconstruct, indices)

    print("raw_img:\n{}\nimg_pool:\n{}".format(img_tensor, img_pool))
    print("索引位置:{}".format(indices))
    print("img_reconstruct:\n{}\nimg_unpool:\n{}".format(img_reconstruct, img_unpool))

输出结果:
在这里插入图片描述

三、线性层(Linear Layer)

线性层又称全连接层,其每个神经元与上一层所有神经元相连,实现对前一层的线性组合,线性变换。
在这里插入图片描述
在这里插入图片描述
输入的input=[1,2,3],经过加权相乘得到的hidden=[6,1,18,24]

3.1nn.Linear

功能:对一维信号(向量)进行线性组合
在这里插入图片描述
主要参数:
• in_features:输入结点数
• out_features:输出结点数
• bias :是否需要偏置
计算公式:y = 𝒙*𝑾𝑻 + 𝒃𝒊𝒂𝒔

代码实现:

flag = 1
if flag:
    inputs = torch.tensor([[1., 2, 3]])
    linear_layer = nn.Linear(3, 4)
    linear_layer.weight.data = torch.tensor([[1., 1., 1.],
                                             [2., 2., 2.],
                                             [3., 3., 3.],
                                             [4., 4., 4.]])

    linear_layer.bias.data.fill_(0.5)  # 偏执项,x*w+b
    output = linear_layer(inputs)
    print(inputs, inputs.shape)
    print(linear_layer.weight.data, linear_layer.weight.data.shape)
    print(output, output.shape)

输出结果:
在这里插入图片描述

四、激活函数层(Activation Layer)

4.1 概念

激活函数对特征进行非线性变换,赋予多层神经网络具有深度的意义。
为什么要使用激活函数呢,因为输入的特征只是通过线性变换,不过是经过多层网络都还是线性变换,就如下面这幅图的计算公式一样。
在这里插入图片描述

4.2 nn.Sigmoid激活函数

函数图像:
在这里插入图片描述
计算公式:
在这里插入图片描述

4.3 nn.tanh激活函数

函数图像:
在这里插入图片描述
计算公式:
在这里插入图片描述

4.4 nn.ReLU激活函数

函数图像:
在这里插入图片描述
计算公式:
在这里插入图片描述

4.5 ReLU变体形式

nn.LeakyReLU:在负半轴添加一点斜率;
nn.PReLU:将负半轴的斜率变为可学习的;
nn.RReLU:负半轴的斜率上下均匀分布;
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1345666.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

React Hooks 面试题 | 08.精选React Hooks面试题

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

git回滚操作,常用场景

文章目录 git回滚操作1.git reset --hard 【版本号】2.回滚后的版本v2又想回到之前的版本v32.1 git reflog 3.git checkout -- 文件名4.git reset HEAD 文件名 git回滚操作 假设我们现在有三个版本 现在回滚一个版本 1.git reset --hard 【版本号】 发现只剩下两个版本了 2.…

html文件Js写输入框和弹框调接口jQuery

业务场景&#xff1a;需要使用写一个html文件&#xff0c;实现输入数字&#xff0c;保存调接口。 1、使用 JS原生写法&#xff0c; fetchAPI调接口&#xff0c;使用 alert 方法弹框会阻塞线程&#xff0c;所以写了一个弹框。 <!DOCTYPE html> <html lang"en"…

SpringMVC源码解析——DispatcherServlet初始化

在Spring中&#xff0c;ContextLoaderListener只是辅助功能&#xff0c;用于创建WebApplicationContext类型的实例&#xff0c;而真正的逻辑实现其实是在DispatcherServlet中进行的&#xff0c;DispatcherServlet是实现Servlet接口的实现类。Servlet是一个JAVA编写的程序&#…

机器学习(二) -- 数据预处理(3)

系列文章目录 机器学习&#xff08;一&#xff09; -- 概述 机器学习&#xff08;二&#xff09; -- 数据预处理&#xff08;1-3&#xff09; 未完待续…… 目录 前言 tips&#xff1a;这里只是总结&#xff0c;不是教程哈。本章开始会用到numpy&#xff0c;pandas以及matpl…

DFS BFS

用DFS和BFS分别实现 //这边给出DFS的模版 void dfs(int x,int y) {//判断是否到达终点&#xff08;只有给出结束点的时候需要&#xff09; if (x ex && y ey) {if (min_steps > step) {min_steps step;}return;}//给出移动方向int move[4][2] {{0, 1}, {0, -1}…

如何使用python脚本生成redis格式的数据包

用python脚本生成redis格式的数据包 &#xff08;1&#xff09;使用下述网站下载开源的生成gopher协议规则的包的工具 https://github.com/firebroo/sec_tools/tree/master/redis-over-gopher &#xff08;2&#xff09;首先要修改redis.cmd中的内容 flushall config set di…

Linux 运维工具之1Panel

一、1Panel 简介 1Panel 是一个现代化、开源的 Linux 服务器运维管理面板。 特点&#xff1a; 快速建站&#xff1a;深度集成 Wordpress 和 Halo&#xff0c;域名绑定、SSL 证书配置等一键搞定&#xff1b;高效管理&#xff1a;通过 Web 端轻松管理 Linux 服务器&#xff0…

第一讲:BeanFactory和ApplicationContext

BeanFactory和ApplicationContext 什么是BeanFactory 它是ApplicationContext的父接口它才是Spring的核心容器&#xff0c;主要的ApplicationContext实现都组合了它的功能 BeanFactory能做什么? 表面上看BeanFactory的主要方法只有getBean()&#xff0c;实际上控制反转、基…

力扣:63. 不同路径 II(动态规划)

题目&#xff1a; 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中标记为 “Finish”&#xff09;。 现在考虑网格中有障碍物。那…

【Matlab】基于遗传算法优化BP神经网络 (GA-BP)的数据时序预测

资源下载&#xff1a; https://download.csdn.net/download/vvoennvv/88682033 一&#xff0c;概述 基于遗传算法优化BP神经网络 (GA-BP) 的数据时序预测是一种常用的机器学习方法&#xff0c;用于预测时间序列数据的趋势和未来值。 在使用这种方法之前&#xff0c;需要将时间序…

visual studio + intel Fortran 错误解决

版本&#xff1a;VS2022 intel Fortran 2024.0.2 Package ID: w_oneAPI_2024.0.2.49896 共遇到三个问题。 1.rc.exe not found 2.kernel32.lib 无法打开 3.winres.h 无法打开 我安装时参考的教程&#xff1a;visual studio和intel oneAPI安装与编写fortran程序_visual st…

小巧的Windows Memory Cleaner内存清理工具-释放内存,提升电脑的性能-供大家学习研究参考

软件介绍 Windows Memory Cleaner是一款非常不错的内存清理工具大小仅200KB&#xff0c;这款免费的 RAM 清理器使用本机 Windows 功能来清理内存区域&#xff0c;帮助用户释放内存&#xff0c;提升电脑的性能&#xff0c;有时程序不会释放分配的内存&#xff0c;从而使计算机变…

【Vue2+3入门到实战】(15)VUE路由入门声明式导航的基本使用与详细代码示例

目录 一、声明式导航-导航链接1.需求2.解决方案3.通过router-link自带的两个样式进行高亮4.总结 二、声明式导航-两个类名1.router-link-active2.router-link-exact-active3.在地址栏中输入二级路由查看类名的添加4.总结 三、声明式导航-自定义类名&#xff08;了解&#xff09…

日志高亮 | notepad

高亮显示日志 日志文件无法清晰看到关键问题所在? 看到一堆日志头疼?高亮日志可以清晰展示出日志的 ERROR级等各种等级的问题, 一下浏览出日志关键所在 tailspin 项目地址&#xff1a; https://githubfast.com/bensadeh/tailspin 使用Rust包管理器cargo安装 安装 - Cargo 手…

LeetCode二叉树路径和专题:最大路径和与路径总和计数的策略

目录 437. 路径总和 III 深度优先遍历 前缀和优化 124. 二叉树中的最大路径和 437. 路径总和 III 给定一个二叉树的根节点 root &#xff0c;和一个整数 targetSum &#xff0c;求该二叉树里节点值之和等于 targetSum 的 路径 的数目。 路径 不需要从根节点开始&#xf…

【MYSQL】-函数

&#x1f496;作者&#xff1a;小树苗渴望变成参天大树&#x1f388; &#x1f389;作者宣言&#xff1a;认真写好每一篇博客&#x1f4a4; &#x1f38a;作者gitee:gitee✨ &#x1f49e;作者专栏&#xff1a;C语言,数据结构初阶,Linux,C 动态规划算法&#x1f384; 如 果 你 …

在线智能防雷监控检测系统应用方案

在线智能防雷监控检测系统是一种利用现代信息技术&#xff0c;对防雷设施的运行状态进行实时监测、管理和控制的系统&#xff0c;它可以有效提高防雷保护的安全性、可靠性和智能化程度&#xff0c;降低运维成本和风险&#xff0c;为用户提供全方位的防雷解决方案。 地凯科技在…

Vue常见面试问答

vue响应式数据 vue2 Vue2 的对象数据是通过 Object.defineProperty 对每个属性进行监听&#xff0c;当对属性进行读取的时候&#xff0c;就会触发 getter&#xff0c;对属性进行设置的时候&#xff0c;就会触发 setter。 /** * 这里的函数 defineReactive 用来对 Object.def…