【学习笔记】【Pytorch】一、卷积层

news2025/1/14 1:20:54

【学习笔记】【Pytorch】一、卷积层

  • 学习地址
  • 主要内容
    • 一、卷积操作示例
    • 二、Tensor(张量)是什么?
    • 三、functional.conv2d函数的使用
      • 1.使用说明
      • 2.代码实现
    • 四、torch.Tensor与torch.tensor区别
    • 五、nn.Conv2d类的使用
      • 1.使用说明
      • 2.代码实现
    • 六、卷积公式

学习地址

PyTorch深度学习快速入门教程【小土堆】.

主要内容

一、卷积操作示例
二、Tensor(张量)是什么?
三、functional.conv2d函数的使用
作用:对几个输入平面组成的输入信号应用2D卷积。
四、torch.Tensortorch.tensor区别
作用:图片尺寸缩放。
五、nn.Conv2d类的使用
作用:二维卷积层, 输入的尺度是(N, C_in,H,W),输出尺度(N,C_out,H_out,W_out)。
六、卷积公式

一、卷积操作示例

在这里插入图片描述
:卷积核的大小是自己设置的,初始参数是自定义的,卷积核上每个位置相当于一个权重w,比如一个3*3的卷积核,就是9个w,训练网络的目的就是学习这9个权值。

二、Tensor(张量)是什么?

参考

  • 什么是张量?

总结

  • 我们通常需要处理的数据有零维的(单纯的一个数字)、一维的(数组)、二维的(矩阵)、三维的(空间矩阵)、还有很多维的。Pytorch为了把这些各种维统一起来,所以起名叫张量
  • 标量视为零阶张量,矢量视为一阶张量,矩阵视为二阶张量。多一个维度,我们就多加一个[]。例如三维张量,torch.tensor([[[9,1,8],[6,7,5],[3,4,2]],[[2,9,1],[8,6,7],[5,3,4]],[[1,5,9],[7,2,6],[4,8,3]]])。
  • PyTorch中的Tensor支持超过一百种操作,包括转置、索引、切片、数学运算(加法、减法、点乘…)、线性代数、随机数等等,总之,凡是你能想到的操作,在pytorch里都有对应的方法去完成。PyTorch学习笔记(二):Tensor操作。

三、functional.conv2d函数的使用

import torch.nn.functional as F

1.使用说明

【实例化】torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

  • 作用:对几个输入平面组成的输入信号应用2D卷积。
  • input: 输入张量, (minibatch, in_channels, iH, iW),即tensor_data.shape 应该有四个参数。
    • minibatch -即batch_size,每个batch(批次)要加载多少个样本
    • in_channels - 输入通道数
    • iH、iW - 数据形状
  • weight – 卷积核(过滤器)张量 ,(out_channels, in_channels/groups 整除, kH, kW) ,即tensor_data.shape 应该有四个参数。
  • bias – 可选偏置张量 (out_channels)
  • stride – 卷积核的步长,左移和下移的步长,可以是单个数字或一个元组 (sh x sw)。默认为1。
  • padding – 输入上隐含零填充。可以是单个数字或元组。 默认值:0。
  • dilation – 核元素之间的间距。默认值:1。
  • groups – 将输入分成组,in_channels应该被组数除尽。默认值:1。

2.代码实现

import torch
import torch.nn.functional as F

# 二维张量
input = torch.tensor([[1, 2, 0, 3, 1],
                     [0, 1, 2, 3, 1],
                     [1, 2, 1, 0, 0],
                     [5, 2, 3, 1, 1],
                     [2, 1, 0, 1, 1]])

# 3*3的卷积核,二维张量
kernel =torch.tensor([[1, 2, 1],
                     [0, 1, 0],
                     [2, 1, 0]])

# 转化为四维张量,通道数是1,batch_size(数据个数)的大小为1,数据维度是5*5
input = torch.reshape(input, (1, 1, 5, 5))
# 转化为四维张量,通道数是1,batch_size(数据个数)的大小为1,数据维度是5*5
kernel = torch.reshape(kernel, (1, 1, 3, 3))

# print(input.shape) == print(input.size())
print(input.shape)  # torch.Size([1, 1, 5, 5])
print(kernel.size())  # torch.Size([1, 1, 3, 3])

# 上下步进为1的卷积操作
output = F.conv2d(input, kernel, stride=1)
print('\n', output)  # 四维张量

# 上下步进为2的卷积操作
output = F.conv2d(input, kernel, stride=2)
print('\n', output)  # 四维张量

# 上下步进为1、启用 1层0填充卷积
output = F.conv2d(input, kernel, stride=1, padding=1)
print('\n', output)  # 四维张量

控制台输出

torch.Size([1, 1, 5, 5])
torch.Size([1, 1, 3, 3])

 tensor([[[[10, 12, 12],
          [18, 16, 16],
          [13,  9,  3]]]])

 tensor([[[[10, 12],
          [13,  3]]]])

 tensor([[[[ 1,  3,  4, 10,  8],
          [ 5, 10, 12, 12,  6],
          [ 7, 18, 16, 16,  8],
          [11, 13,  9,  3,  4],
          [14, 13,  9,  7,  4]]]])

总结
将二维张量转化为四维张量才能满足torch.nn.functional.conv2d函数的输入要求,转化结果也就是添加了两个[ ]。

四、torch.Tensor与torch.tensor区别

参考
torch.Tensor()与torch.tensor()
总结

  • torch.Tensor()是python类,调用torch.Tensor([1,2, 3, 4, 5])来构造一个tensor的时候,会调用Tensor类的构造函数,生成一个单精度浮点类型的张量。它不能指定数据类型,除非转成一个已知数据类型的张量,使用type_as(tesnor)将张量转换为给定类型的张量。

  • torch.tensor()是python的函数,其中data可以是list,tuple,NumPy,ndarray等其他类型,torch.tensor(data)会从data中的数据部分做拷贝(而不是直接引用),根据原始数据类型生成相应的torch.LongTensor torch.FloatTensor和torch.DoubleTensor。通过设置dtype的函数参数值,生成对应类型的张量。

五、nn.Conv2d类的使用

from torch.nn import Conv2d

在这里插入图片描述

作用:二维卷积层, 输入的尺度是(N, C_in,H,W),输出尺度(N,C_out,H_out,W_out)。

1.使用说明

【实例化】Conv2d(
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = ‘zeros’, # TODO: refine this type
device=None,
dtype=None
)

  • 作用:创建一个二维卷积层,的实例,卷积核初始参数随机初始化。
  • in_channels(int) – 输入信号的通道
    out_channels(int) – 卷积产生的通道
    kerner_size(int or tuple) - 卷积核的尺寸
    stride(int or tuple, optional) - 卷积步长
    padding(int or tuple, optional) - 输入的每一条边补充0的层数
    dilation(int or tuple, optional) – 卷积核元素之间的间距
    groups(int, optional) – 从输入通道到输出通道的阻塞连接数
    bias(bool, optional) - 如果bias=True,添加偏置
  • 例子:
# 用conv1变量存储一个 Conv2d 实例
# in_channels表示3通道图片数据
 # out_channels表示输出通道数(这里3通道变6通道,一般是2个卷积核进行卷积得到的)
conv1 = Conv2d(in_channels=3, out_channels=6,
                            kernel_size=3, stride=1, padding=0)
  • 注:in_channels,彩色就是3,灰度就是1。

【_call_】conv1(x)
例子:

    def forward(self, x):
        x = self.conv1(x)
        return x

2.代码实现

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


class Model(nn.Module):
    def __init__(self):
        super().__init__()  # 父类参数初始化
        # 用conv1变量存储一个 Conv2d 实例
        # in_channels表示3通道图片数据
        # out_channels表示输出通道数(这里3通道变6通道,一般是2个卷积核进行卷积得到的)
        self.conv1 = Conv2d(in_channels=3, out_channels=6,
                            kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        return x




dataset = torchvision.datasets.CIFAR10(root="./dataset", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)  # 创建一个 CIFAR10 实例
dataloader = DataLoader(dataset=dataset, batch_size=64)  # 创建一个 DataLoader 实例

nn_model = Model()  # 创建一个 Model 实例
# print(model)
# Model((conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)))

writer = SummaryWriter("./dataloader_logs")  # 创建一个SummaryWriter实例
step = 0
for data in dataloader:
    imgs, targets = data
    output = nn_model(imgs)
    print(imgs.size())
    # torch.Size([64, 3, 32, 32]) batch_size=64,in_channels=3,图片尺寸
    print(output.shape)
    # torch.Size([64, 6, 30, 30]) batch_size=64,out_channels=6,图片尺寸

    writer.add_images("input", imgs, step)

    # 因为不能显示6通道的图片,所以使用reshape()转化为3通道图片,-1表示自动设置batch_size
    output = torch.reshape(output, (-1, 3, 30, 30))
    writer.add_images("output", output, step)

    step += 1


# tensorboard命令:tensorboard --logdir=dataloader_logs --port=6007

输出

Files already downloaded and verified
torch.Size([64, 3, 32, 32])
torch.Size([64, 6, 30, 30])
torch.Size([64, 3, 32, 32])
torch.Size([64, 6, 30, 30])
....
....

TensorBoard输出
在这里插入图片描述
output的batch_size变成了2*64=128。

六、卷积公式

参考:torch.nn.Conv2d
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/162706.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于servlet+mysql+jsp实现鞋子商城系统

基于servletmysqljsp实现鞋子商城系统一、系统介绍1、系统主要功能:2、环境配置二、功能展示1.主页(客户)2.用户登陆、个人中心(客户)3.商品分类(客户)3.我的购物车(客户)4.我的订单(客户)5.订单…

微信小程序页面导航、编程式导航、页面事件、生命周期和WXS脚本

文章目录页面导航1.导航到tarBar页面2.导航到非 tabBar 页面3.后退导航编程式导航1.导航到tabBar页面2.导航到非 tabBar 页面3.后退导航导航传参1. 声明式导航传参2. 编程式导航传参3. 在 onLoad 中接收导航参数页面事件下拉刷新上拉触底数据请求获取中添加loading效果,请求完毕…

一本修炼秘籍,带你打穿文件上传的21层妖塔(1)

目录 前言 引子 第一层:JS限制——你在玩一种很新的防御 第二层:Content-Type限制——我好像在哪见过你 第三层:黑名单绕过——让我康康! 前言 🍀作者简介:被吉师散养、喜欢前端、学过后端、练过CTF、…

Jetpack Compose中的副作用Api

Compose的生命周期 每个Composable函数最终会对应LayoutNode节点树中的一个LayoutNode节点,可简单的为其定义生命周期: onActive: 进入重组作用域, Composable对应的LayoutNode节点被挂接到节点树上onUpdate:触发重组&#xff0c…

Dolphin scheduler在Windows环境下的部署与开发

这里写自定义目录标题环境介绍WSL2工程下载修改POM文件java版本mysql驱动修改mysql密码IDEA配置JDK8模块导出运行配置环境介绍 MySql:8.0.31 JDK:17 需要安装windows的wsl2 WSL2 首先安装好WSL2,并且通过 sudo apt-get install openjdk-17…

类模板与模板类

#include <stdio.h>#include <iostream>using namespace std;//注意必须将类的声明和定义写在同一个.h文件中 未来把它包含进来//写上关键字template 和模板参数列表template<typename T, int KSize, int KVal>class MyArray{public:MyArray();//当在类内定义…

正点原子STM32(基于HAL库)2

目录STM32 基础知识入门寄存器基础知识STM32F103 系统架构Cortex M3 内核& 芯片STM32 系统架构存储器映射寄存器映射新建寄存器版本MDK 工程STM32 基础知识入门 寄存器基础知识 寄存器&#xff08;Register&#xff09;是单片机内部一种特殊的内存&#xff0c;它可以实现…

【自学Docker】Docker HelloWorld

Docker HelloWorld Docker服务 查看Docker服务状态 使用 systemctl status docker 命令查看 Docker 服务的状态。 haicoder(www.haicoder.net)# systemctl status docker我们使用 systemctl status docker 命令查看 Docker 服务的状态&#xff0c;显示结果如下图所示&#…

HotPDF Delphi PDF编译器形成PDF文档

HotPDF Delphi PDF编译器形成PDF文档 HotPDF Delphi PDF编译器支持通过内部和外部链接完全形成PDF文档。计算机还完全支持Unicode。此外&#xff0c;在您的产品和软件中使用此计算机的最新功能&#xff0c;您可以指定加密、打印和编辑PDF文档的能力。当您加密PDF文档时&#xf…

Markdown总结

为什么要使用Markdowm 什么是Markdown?为什么需要使用Markdown&#xff1f; Markdown 是一种轻量级标记语言&#xff0c;它允许人们使用易读易写的纯文本格式编写文档。 Markdown 语言在 2004 由约翰格鲁伯&#xff08;英语&#xff1a;John Gruber&#xff09;创建。 Markdo…

openEuler 社区 2022 年 12 月运作报告

社区活跃度在社区所有开发者和用户的共同参与下&#xff0c;openEuler的3年持续迸发活力&#xff01;从0到超过1.27万名开发者&#xff0c;从0到超过100万的社区用户&#xff0c;从0到超过750家企业伙伴加入社区……截至目前&#xff0c;在大家的持续贡献下&#xff0c;openEul…

GemBox.Bundle 47.0.1012 VS Spire.Office Platinum 8.1.1

GemBox.Bundle 是一个 .NET 组件包&#xff0c;使您能够简单高效地处理办公文件&#xff08;电子表格、文档、演示文稿和电子邮件&#xff09;。 使用我们的组件&#xff0c;您可以以易于使用的形式快速获得可靠的结果。只需要 .NET&#xff0c;因此您可以轻松部署您的应用程序…

收官!OceanBase第五届技术征文大赛获奖名单公布!

OceanBase 一直在思考&#xff0c;什么样的数据库对用户而言更易用&#xff1f; 更易用&#xff0c;除了功能完善、性能优秀、运行稳定的数据库系统&#xff0c;丰富多样的生态工具也必不可少。 作为一款完全自主研发的原生分布式数据库&#xff0c;OceanBase 的生态工具经历…

基于Java SSM springboot+VUE+redis实现的前后端分类版网上商城项目

基于Java SSM springbootVUEredis实现的前后端分类版网上商城项目 博主介绍&#xff1a;5年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言…

路由 OSPF LSA介绍、1~7类LSA详细介绍

1.0.0 路由 OSPF LSA介绍、1~7类LSA详细介绍 OSPF LSA 链路状态通告( Link status announcement)&#xff0c;作用于 向其它邻接OSPF路由器 传递拓扑信息与路由信息。 LSA如何去描述拓扑信息与路由信息的呢&#xff1f; 其实是基于不同类型LSA进行描述&#xff0c;而常见的LS…

EquiBind模型源码分析

EquiBind模型源码分析 使用提供的模型权重来预测你自己的蛋白质配体对的结合结构 第 1 步:你需要什么作为输入 mol2或.sdf或.pdbqt或.pdb格式的配体文件&#xff0c;其名称包含字符串配体(配体文件应包含所有氢)。 .pdb格式的受体文件&#xff0c;其名称包含字符串protein。我…

leetcode.1819 序列中不同最大公约数的数目 - gcd + 枚举

1819. 序列中不同最大公约数的数目 目录 1、java版 2、c版 思路&#xff1a; 有n个元素的数组&#xff0c;则其子序列有 个&#xff0c;而1 ≤ n ≤ &#xff0c;则不可能枚举每一个子序列计算它的gcd&#xff0c;那样会tle我们可以逆转思路&#xff0c;因为1 ≤ nums[i]…

PhysioNet2017数据集介绍

一、数据集下载 PhysioNet2017为短单导联心电图记录的房颤分类数据集&#xff0c;下载地址如下&#xff1a;https://www.physionet.org/content/challenge-2017/1.0.0/ 二、数据集介绍 PhysioNet2017数据集主要用于对记录是否显示正常窦性心律、心房颤动&#xff08;AF&…

背包问题= =

一、01背包有 N 件物品和一个容量是 V 的背包。每件物品只能使用一次。第 i件物品的体积是 vi&#xff0c;价值是 wi。求解将哪些物品装入背包&#xff0c;可使这些物品的总体积不超过背包容量&#xff0c;且总价值最大。输出最大价值。&#xff08;下图是例子&#xff0c;一下…

14、ThingsBoard-自定义华为云SMS规则节点

1、概述 一个物联网平台承载着很多设备的连接,当设备出现异常的时候,能够快速的通知到运维管理员是非常重要的,thingsboard提供了自定义配置邮箱,但是它对支持发送短信的不是很友好,都是国外的sms服务商,我反正是不用那个,在国内常见就是阿里、腾讯、华为、七牛常用的s…