Pytorch学习:卷积神经网络—nn.Conv2d、nn.MaxPool2d、nn.ReLU、nn.Linear和nn.Dropout

news2024/9/20 8:59:50

文章目录

    • 1. torch.nn.Conv2d
    • 2. torch.nn.MaxPool2d
    • 3. torch.nn.ReLU
    • 4. torch.nn.Linear
    • 5. torch.nn.Dropout

卷积神经网络详解:csdn链接
其中包括对卷积操作中卷积核的计算、填充、步幅以及最大值池化的操作。

1. torch.nn.Conv2d

对由多个输入平面组成的输入信号应用2D卷积。

官方文档:torch.nn.Conv2d
CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
主要参数:

  • in_channels(int):输入图像中的通道数
  • out_channels(int):卷积产生的通道数
  • kernel_size(int或tuple):卷积内核的大小

默认参数:

  • stride(int或tuple,可选):卷积的步幅。默认值:1
  • padding(int,tuple或str,可选):添加到输入的所有四边的填充。默认值:0
  • padding_mode(str,可选):‘zeros’ 、 ‘reflect’ 、 ‘replicate’ 或 ‘circular’ 。默认值: ‘zeros’
  • dilation(int或tuple,可选):内核元素之间的间距。默认值:1
  • groups(int,可选):从输入通道到输出通道的阻塞连接数。默认值:1
  • bias(bool,可选):如果 True ,则向输出添加可学习的偏置。默认值: True

举例说明

import torch
from torch import nn

# 内核方正,步调一致
m1 = nn.Conv2d(16, 33, 3, stride=2)
# 非方形内核和不等距步长,并有填充
m2 = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
# 非方形内核和不等步长,以及填充和扩张
m3 = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))

卷积核通过选取内核大小,其中参数值是随机的

import torch
from torch import nn

k = nn.Conv2d(1, 1, 3, stride=1)
print(list(k.parameters()))

在这里插入图片描述

使用CIFAR10数据进行卷积操作,并进行可视化操作

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 使用CIFAR10的训练数据
train_data = torchvision.datasets.CIFAR10("../dataset", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
train_loader = DataLoader(train_data, batch_size=64)

writer = SummaryWriter("logs")


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1)

    def forward(self, x):
        x = self.model1(x)
        return x


step = 0
# 创建模型
model = Model()
for data in train_loader:
    imgs, targets = data
    output = model(imgs)
    writer.add_images("input", imgs, step)
    # output此时为6通道,为了可视化,现将其转换为3通道
    # imgs 的形状大小为 [64, 3, 32, 32]
    # 经过卷积操作后,output的高度和宽度为:32 - 3 + 1 = 30
    output = torch.reshape(output, (-1, 3, 30, 30))
    writer.add_images("output", output, step)
    step = step + 1

writer.close()

打开命令行,输入以下代码,并打开TensorBoard的链接: http://localhost:6006/

tensorboard --logdir=logs

在这里插入图片描述

2. torch.nn.MaxPool2d

最大值池化层,对由多个输入平面组成的输入信号应用2D最大池化。

官方文档:torch.nn.MaxPool2d
CLASS torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
参数 kernel_size 、 stride 、 padding 、 dilation 可以是:

  • 单个 int-在这种情况下,高度和宽度尺寸使用相同的值
  • 两个int的 tuple -在这种情况下,第一个int用于高度维度,第二个int用于宽度维度

主要参数:

  • kernel_size(Union[int,Tuple[int,int]])-窗口的最大值

默认参数:

  • stride(Union[int,Tuple[int,int]])-窗口的步幅。默认值为 kernel_size
  • padding(Union[int,Tuple[int,int]])-要在两边添加的隐式负无穷大填充
  • dilation(Union[int,Tuple[int,int]])-控制窗口中元素步幅的参数
  • return_indices(bool)-如果 True ,将返回最大索引沿着输出。 torch.nn.MaxUnpool2d 以后有用
  • ceil_mode(bool)-当为True时,将使用ceil而不是floor来计算输出形状

最大汇聚层,也叫做最大池化层,代码实现

import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 使用测试数据集
dataset = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=64)

writer = SummaryWriter("logs")


class Model(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=False)

    def forward(self, input):
        output = self.maxpool1(input)
        return output

# 创建模型
model = Model()
step = 0
for data in dataloader:
    imgs, targets = data
    output = model(imgs)
    writer.add_images("input", imgs, step)
    writer.add_images("output_maxpool", output, step)
    step = step + 1

writer.close()

在这里插入图片描述

3. torch.nn.ReLU

非线性激活函数,逐元素应用整流线性单位函数:
输入与输出形状相同
在这里插入图片描述

官方文档:torch.nn.ReLU
CLASS torch.nn.ReLU(inplace=False)
主要参数:

  • inplace(bool):可以选择就地执行操作。默认值: False
    在这里插入图片描述

代码实现

import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 使用测试数据集
dataset = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=64)

writer = SummaryWriter("logs")


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu1 = nn.ReLU()

    def forward(self, input):
        output = self.relu1(input)
        return output


# 创建模型
model = Model()
step = 0
for data in dataloader:
    imgs, targets = data
    output = model(imgs)
    writer.add_images("input", imgs, step)
    writer.add_images("output", output, step)
    step = step + 1

writer.close()

在这里插入图片描述
nn.ReLU()不是很明显,这里用nn.Sigmoid()更为清除
在这里插入图片描述

4. torch.nn.Linear

线性层,对传入数据应用线性变换:在这里插入图片描述

官方文档:torch.nn.Linear
CLASS torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
主要参数:

  • in_features(int):每个输入样本的大小
  • out_features(int):每个输出样本的大小

默认参数:

  • bias(bool):如果设置为 False ,则层将不会学习加性偏置。默认值: True
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=64)


class Module(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        self.linear1 = Linear(196608, 10)

    def forward(self, input):
        output = self.linear1(input)
        return output


module = Module()


for data in dataloader:
    imgs, targets = data
    print(imgs.shape)
    # output = torch.reshape(imgs, (1, 1, 1, -1))
    output = torch.flatten(imgs)
    print(output.shape)
    output = module(output)
    print(output.shape)

在这里插入图片描述

5. torch.nn.Dropout

在训练期间,使用来自伯努利分布的样本以概率 p 随机地将输入张量的一些元素归零。每个信道将在每次前向呼叫时独立地归零。

官方文档:torch.nn.Dropout
CLASS torch.nn.Dropout(p=0.5, inplace=False)
主要参数:

  • p(float)-元素被置零的概率。默认值:0.5
  • inplace(bool)-如果设置为 True ,将就地执行此操作。默认值: False

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/977767.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ChatGPT AIGC 完成超炫酷的大屏可视化

大屏可视化一直各大企业进行数据决策的重要可视化方式,接下来我们先来看一下ChatGPT,AIGC人工智能帮我们实现的综合案例大屏可视化效果: 像这样的大屏可视化使用HTML,JS,Echarts就可以来完成,给ChatGPT,AIGC发送指令的同时可以将数据一起发送给ChatGPT。 第一段指令加数…

Direct3D绘制旋转立方体例程

初始化文件见Direct3D的初始化_direct3dcreate9_寂寂寂寂寂蝶丶的博客-CSDN博客 D3DPractice.cpp #include <windows.h> #include "d3dUtility.h" #include <d3dx9math.h>IDirect3DDevice9* Device NULL; IDirect3DVertexBuffer9* VB NULL; IDirect3…

【C语言】入门——结构体

目录 结构体 为什么有结构体&#xff1f; 1.结构体的声明 1.2结构体变量的访问和初始化 2.结构体成员的访问 结构体 struct 结构体类型 {//相关属性; }结构体变量; 结构体和数组不同&#xff0c;同一类型的数据的集合是数组&#xff1b; 结构体是多种类型的数据的集合&…

NSV60600MZ4T1G 双极型晶体管(BJT)学习总结

双极型晶体管的起源: 双极型晶体管是在1947年发明的&#xff0c;第一个晶体管是将两条具有尖锐端点的金属线与锗衬底(germanium substrate)形成点接触(point contact)&#xff0c;以今天的水准来看&#xff0c;此第一个晶体管虽非常简陋但它却改变了整个电子工业及人类的生活方…

CANdelaStudio CDD编写方法

本文是基于CANdelaStudio12.0讲解 一.把DTC从Excel导入cdd的方法 问题一&#xff1a;当导入DTC的xxx.cdi文件报如下红色错误 可能原因&#xff1a;在设置具有下拉框的属性的内容时&#xff0c;输入的内容不在下拉框列表中 解决办法:在.cddt文件中更新“”Error Code Table“”…

通达信趋向指标DMI公式详解

DMI指标(Directional Movement Index)也称趋向指标或动向指标&#xff0c;是用于衡量市场的趋势方向以及趋势强度的一种技术指标&#xff0c;由著名的技术派大师威尔斯威尔德(Welles Wilder)于1978年发表在《技术交易系统新概念》这本书中。威尔斯威尔德(Welles Wilder)这位大佬…

企微SCRM营销平台MarketGo-ChatGPT助力私域运营

一、前言 ChatGPT是由OpenAI&#xff08;开放人工智能&#xff09;研发的自然语言处理模型&#xff0c;其全称为"Conversational Generative Pre-trained Transformer"&#xff0c;即对话式预训练转换器。它是GPT系列模型的最新版本&#xff0c;GPT全称为"Gene…

springboot项目中application.properties无法变成小树叶问题解决

1.检查我们的resources目录的状态&#xff0c;看看是不是处在普通文件夹的状态&#xff0c;如果是的话&#xff0c;我们需要重新mark一下 右键点击文件夹&#xff0c;选择mark directory as → resources root 此时我们发现配置文件变成了小树叶 2.如果执行了上述方法还是不行…

[uniapp]踩坑日记 unexpected character > 1或‘=’>1 报错

在红色报错文档里下滑&#xff0c;找到Show more 根据提示看是缺少标签&#xff0c;如果不是缺少标签&#xff0c;看看view标签内容是否含有<、>、>、<号,把以上符合都进行以<号为例做{{“<”}}处理

超详细-Vivado配置Sublime+Sublime实现VHDL语法实时检查

目录 一、前言 二、准备工作 三、Vivado配置Sublime 3.1 Vivado配置Sublime 3.2 环境变量添加 3.3 环境变量验证 3.4 Vivado设置 3.5 配置验证 3.6 解决Vivado配置失败问题 四、Sublime配置 4.1 Sublime安装Package Control 4.2 Sublime安装VHDL插件 4.3 语法检查…

排序相关问题

本篇博客在B站做了内部分享,标题为「排序相关问题」 MySQL的ORDER BY有两种排序实现方式&#xff1a; 利用有序索引获取有序数据 (不得不进行)文件排序 在explain中分析时&#xff0c;利用有序索引获取有序数据显示Using index&#xff0c;文件排序显示Using filesort。 1. 能够…

macm1环境下jdk版本切换

macm1环境下jdk版本切换 本文目录 macm1环境下jdk版本切换下载jdk安装动态切换jdk终端生效全局生效 参考 下载jdk oracle官方源下载地址 https://www.oracle.com/java/technologies/downloads/#jdk17-mac Azul下载地址 https://www.azul.com/downloads/?packagejdk#download…

Autosar-Runnables(可运行实体)

文章目录 Runnable entities (简称Runnables)一、Runnables的定义二、Runnables的作用三、DaVinci配置总结Runnable entities (简称Runnables) 包含实际实现的函数(具体的逻辑算法或者操作) Runables由RTE周期性、或事件触发调用(如,当接收到数据、被操作调用) 一、Runna…

时序预测 | MATLAB实现ELM极限学习机时间序列预测未来

时序预测 | MATLAB实现ELM极限学习机时间序列预测未来 目录 时序预测 | MATLAB实现ELM极限学习机时间序列预测未来预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.MATLAB实现ELM极限学习机时间序列预测未来&#xff1b; 2.运行环境Matlab2018及以上&#xff0c;data为数…

MYSQL MHA实现故障转移和自动切换

目录 1、MHA理论&#xff1a; 1.1、MHA概述 1.2、MHA的组成&#xff1a; 1.3、特点&#xff1a; 1.4、传统的MySQL主从架构存在一些常见的问题&#xff1a; 1.5、MHA工作原理总结如下 1.6、 故障切换备选主库的算法&#xff1a; 2、 故障转移实验 2.1、搭建 MySQL MHA…

Linux知识点 -- 网络编程套接字

Linux知识点 – 网络编程套接字 文章目录 Linux知识点 -- 网络编程套接字一、预备知识1.认识端口号2.套接字3.TCP协议与UDP协议4.网络字节序 二、socket编程接口1.socket常见API2.sockaddr结构 三、UDP套接字编程1.直接打印客户端信息2.执行客户端发来的指令3.多用户聊天4.在wi…

ALBEF、VLMO、BLIP、BLIP2、InstructBLIP要点总结(WIP)

ALBEF&#xff08;ALign BEfore Fuse&#xff09; 为什么有5个loss&#xff1f; 两个ITC两个MIM1个ITM。ITM是基于ground truth的&#xff0c;必须知道一个pair是不是ground truth&#xff0c;同时ITM loss是用了hard negative&#xff0c;这个是和Momentum Distillation&…

优化爬虫效率:利用HTTP代理进行并发请求

网络爬虫作为一种自动化数据采集工具&#xff0c;广泛应用于数据挖掘、信息监测等领域。然而&#xff0c;随着互联网的发展和网站的增多&#xff0c;单个爬虫往往无法满足大规模数据采集的需求。为了提高爬虫的效率和性能&#xff0c;我们需要寻找优化方法。本文将介绍一种利用…

(位运算) 剑指 Offer 56 - I. 数组中数字出现的次数 ——【Leetcode每日一题】

❓剑指 Offer 56 - I. 数组中数字出现的次数 难度&#xff1a;中等 一个整型数组 nums 里除两个数字之外&#xff0c;其他数字都出现了两次。请写程序找出这两个只出现一次的数字。要求时间复杂度是 O ( n ) O(n) O(n)&#xff0c;空间复杂度是 O ( 1 ) O(1) O(1)。 示例 …

开源对象存储系统minio部署配置与SpringBoot客户端整合访问

文章目录 1、MinIO安装部署1.1 下载 2、管理工具2.1、图形管理工具2.2、命令管理工具2.3、Java SDK管理工具 3、MinIO Server配置参数3.1、启动参数&#xff1a;3.2、环境变量3.3、Root验证参数 4、MinIO Client可用命令 官方介绍&#xff1a; MinIO 提供高性能、与S3 兼容的对…