pytorch初学笔记(九):神经网络基本结构之卷积层

news2024/10/6 8:36:03

目录

一、torch.nn.CONV2D

 1.1 参数介绍

 1.2 stride 和 padding 的可视化

1.3 输入、输出通道数

1.3.1 多通道输入

1.3.2 多通道输出

二、卷积操作练习

2.1 数据集准备

2.2 自定义神经网络

2.3 卷积操作控制台输出结果

2.4 tensorboard可视化

三、完整代码 


一、torch.nn.CONV2D

官方文档: 

torch.nn — PyTorch 1.13 documentation

 1.1 参数介绍

 常用的参数主要是前五个。

(31条消息) Pytorch中dilation(Conv2d)参数详解_MaZhe丶的博客-CSDN博客_conv2d参数解释

  • in_channels (int) – Number of channels in the input image,输入图片的通道数

  • out_channels (int) – Number of channels produced by the convolution,输出图片的通道数, 代表卷积核的个数,使用n个卷积核输出的特征矩阵深度即channel就是n

  • kernel_size (int or tuple) – Size of the convolving kernel,卷积核的大小

                e.g. if kernel size = 3, 则卷积核的大小是3*3                                        

  • stride (int or tupleoptional) – Stride of the convolution. Default: 1,步径大小

  • padding (inttuple or stroptional) – Padding added to all four sides of the input. Default: 0

  • padding_mode (stroptional) – 'zeros''reflect''replicate' or 'circular'. Default: 'zeros'

  • dilation (int or tupleoptional) – Spacing between kernel elements. Default: 1

  • groups (intoptional) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (booloptional) – If True, adds a learnable bias to the output. Default: True

如何进行输出图片长和宽的计算?

 

 1.2 stride 和 padding 的可视化

conv_arithmetic/README.md at master · vdumoulin/conv_arithmetic · GitHub

1.3 输入、输出通道数

(31条消息) 【卷积神经网络】多输入通道和多输出通道(channels)_ZSYL的博客-CSDN博客_多输出通道

(31条消息) (pytorch-深度学习系列)CNN的多输入通道和多输出通道_我是一颗棒棒糖的博客-CSDN博客

彩色图像在高和宽2个维度外还有RGB(红、绿、蓝)3个颜色通道。假设彩色图像的高和宽分别是h hh和w ww(像素),那么它可以表示为一个3 × h × w 的多维数组。我们将大小为3的这一维称为通道(channel)维

1.3.1 多通道输入

1.3.2 多通道输出

多通道输出的通道数 = 卷积核的个数

二、卷积操作练习

2.1 数据集准备

使用cifar10数据集的验证集进行操作。 

import torch
import torchvision.datasets
from torch.nn import Conv2d
from torch.utils.data import DataLoader
#数据集准备
dataset = torchvision.datasets.CIFAR10(root=".\CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#使用dataloader加载数据集,批次数为64
dataloader = DataLoader(dataset,batch_size=64)

2.2 自定义神经网络

  • 自定义名为Maweiyi的神经网络。
  • 自定义的神经网络要继承nn.Module框架。
  • 重写init和forward方法。
  • 该神经网络调用conv2d进行一层卷积,输入通道为3层(彩色图像为3通道),卷积核大小为3*3,输出通道为6。
  • 设置步长为1,padding为0,不进行填充。
import torch
import torchvision.datasets
from torch.nn import Conv2d
from torch.utils.data import DataLoader

class Maweiyi(torch.nn.Module):
    def __init__(self):
        super(Maweiyi, self).__init__()
        self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)

    def forward(self,x):
        x = self.conv1(x)
        return x

maweiyi = Maweiyi()
print(maweiyi)

输出:

  

2.3 卷积操作控制台输出结果

 使用for循环输出结果。

# 输出卷积前的图片大小和卷积后的图片大小
for data in dataloader:
    imgs, labels = data
    print(imgs.shape)
    # 卷积操作
    outputs = maweiyi(imgs)
    print(outputs.shape)

输出:

2.4 tensorboard可视化

 注意:使用tensorboard输出时需要重新定义图片大小

  • 对于输入的图片集imgs来说,tensor.size([64,3,32,32]),即一批次为64张,一张图片为三个通道,大小为32*32
  • 对于经过卷积后输出的图片集output来说,tensor.size([64,6,30,30]),通道数变成了6,tensorboard不知道怎么显示通道数为6的图片,所以如果直接输出会报错

解决方案:

  • 使用reshape方法对outputs进行重定义,把通道数改成3,如果不知道批次数大小,可以使用-1代替,程序会自动匹配批次大小。

outputs = torch.reshape(outputs,(-1,3,30,30))

# 生成日志
writer = SummaryWriter("logs")
step = 0

# 输出卷积前的图片大小和卷积后的图片大小
for data in dataloader:
    imgs, labels = data
    # 显示输入的图片
    writer.add_images("inputs",imgs,step)

    #卷积操作
    outputs = maweiyi(imgs)

    #重定义输出图片的大小
    outputs = torch.reshape(outputs,(-1,3,30,30))
    # 显示输出的图片
    writer.add_images("outputs",outputs,step)
    step+=1

writer.close()

 输出:

input:一个批次64张图片

output:一个批次128张 (64*2=128)

三、完整代码 

import torch
import torchvision
from torch.nn import Conv2d
from torch.utils.data import DataLoader
# 数据集下载
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(root=".\CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
# 数据加载器
dataloader = DataLoader(dataset, batch_size=64)


class Maweiyi(torch.nn.Module):
    def __init__(self):
        super(Maweiyi, self).__init__()
        # 卷积层
        self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        return x


maweiyi = Maweiyi()
print(maweiyi)

writer = SummaryWriter("logs")
step = 0

for data in dataloader:
    imgs, labels = data
    # 卷积操作
    output = maweiyi(imgs)
    print(imgs.shape)
    print(output.shape)
    writer.add_images("input", imgs, step)

    output = torch.reshape(output, (-1, 3, 30, 30))
    writer.add_images("output", output, step)
    step = step + 1

writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/23302.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NestJS 使用体验 | 不如 Spring Boot

本博客站点已全量迁移至 DevDengChao 的博客 https://blog.dengchao.fun , 后续的新内容将优先在自建博客站进行发布, 欢迎大家访问. 文章目录前言正文开发体验运行体验总结相关内容推广前言 公司里近期在尝试部署一些业务到阿里云的函数计算上, 受之前迁移已有的 Spring Boot…

TestStand-调试VI

文章目录调试VI调试VI 在LabVIEW PASS/FAIL TEST步骤中放置一个断点。 ExecuteRun MainSequence。执行在LabVIEW PASS/FAIL处暂停测试步骤。 3.完成以下步骤来调试LabVIEW PASS/FAIL TEST VI步骤。 a.在TestStand的调试工具栏上单击step into(步进)…

System V IPC+消息队列

多进程与多线程 使用有名管道实现双向通信时,由于读管道是阻塞读的,为了不让“读操作”阻塞“写操作”,使用了父子进程来多线操作, 1)父进程这条线:读管道1 2)子进程这条线:写管道2…

【二叉树的顺序结构:堆 堆排序 TopK]

努力提升自己,永远比仰望别人更有意义 目录 1 二叉树的顺序结构 2 堆的概念及结构 3 堆的实现 3.1 堆向下调整算法 3.2 堆向上调整算法 3.3堆的插入 3.4 堆的删除 3.5 堆的代码实现 4 堆的应用 4.1 堆排序 4.2 TOP-K问题 总结: 1 二叉树的顺序结…

分享几招教会你怎么给图片加边框

大家平时分享图片的时候,会不会喜欢给照片加点装饰呢?比如加些边框、文字或者水印之类的。我就喜欢给图片加上一些边框,感觉加了边框的照片像裱在相框中的感觉似的,非常有趣。那么你知道如何给图片加边框吗?不知道的话…

【Nginx】01-什么是Nginx?Nginx技术的功能及其特性介绍

目录1. 介绍1.1 常见服务器的对比1)IIS2)Tomcat3)Apache4)Lighttpd1.2 Nginx的优点(1) 速度更快、并发更高(2) 配置简单、扩展性强(3) 高可靠性(4) 热部署(5) 成本低、BSD许可证2. Nginx常用功能2.1 基本HTTP服务2.2 高级HTTP服务…

华为数通2022年11月 HCIP-Datacom-H12-821 第二章

142.以下关于状态检测防火墙的描述,正确是哪一项? A.状态检测防火墙需要对每个进入防火墙的数据包进行规则匹配 B.因为UDP协议为面向无连接的协议,因此状态检测型防火墙无法对UDP报文进行状态表的匹配 C.状态检测型防火墙只需要对该连接的第一…

性能测试-CPU性能分析,IO密集导致系统负载高

目录 IO密集导致系统负载高 使用top命令-观察服务器资源状态 使用vmstat命令-观察服务器资源状态 使用pidstat命令-观察服务器资源状态 使用iostat命令-观察服务器资源状态 IO密集导致系统负载高 stress-ng -i 10 --hdd 1 --timeout 100-i :有多少个工作者进行&#…

函数的极限:如何通过 δ 和 ϵ 来定义一个连续的函数

连续的定义 维基百科给出的定义: 连续函数(英语:Continuous function)是指函数在数学上的属性为连续。直观上来说,连续的函数就是当输入值的变化足够小的时候,输出的变化也会随之足够小的函数。 所以不要直…

51单总线控制SV-5W语音播报模块

单总线控制SV-5W语音播报模块SV-5W语音播报模块SV-5W语音播报模块简介工作模式说明模块配置接线驱动部分代码效果展示SV-5W语音播报模块 SV-5W语音播报模块简介 DY-SV5W是一款智能语音模块,集成IO分段触发,UART串口控制,ONE_line单总线串口控…

macOS monterey 12.6.1安装homebrew + nginx + php + mysql

效果图 主要步骤 安装homebrew使用brew安装nginxphpmysql详细步骤 参考“Homebrew国内如何自动安装(国内地址)(Mac & Linux)”安装brew, 命令: /bin/zsh -c "$(curl -fsSL https://gitee.com/cu…

[附源码]java毕业设计网上学车预约系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

教你十分钟在Linux系统上快速装机并安装Ansible

PS:本教程建立在VMware软件上的使用上,Linux版本为centos7或者centos8都可以。 一、看发行版本 cat /etc/redhat-release 二、修改主机名 hostnamectl set-hostname centos8 三、自动获取IP地址 nmcli connection modify ens160 autoconnect yes 四、设置…

软件设计(一):统一建模语言基础知识

1.UML简介 1.1 UML简介 UML语言是一种可视化的标准建模语言,它是一种分析和设计语言,通过UML可以构造软件系统的蓝图。 1.2 UML的结构 1.2.1 视图(view) 1.2.2 图(daigram) 用例图 类图 对象图 包图…

C/C++ 语言怎么保留n位小数并且四舍五入

1、普通的printf输出打印 printf()函数的用例 float date=123.456; printf("date=%.2f\n", date);//保留2位 printf("date=%.1f\n", date);//保留1位 输出 2、获取四舍五入后的数据 1、使用round函数 C ++ round()函数 (C++ round() function) round(…

ELK技术栈简介

ELK技术栈简介ELK是什么ELK组件ElasticsearchES基本概念ES适用场景LogstashInput插件Filter插件Output插件CodecsKibanaBeatsELK是什么 ELK 即 Elasticsearch Logstash Kibana,是指Elastic公司开发的三种免费开源软件。其中,Elasticsearch是一个基于A…

基于PHP+MYSQL在线小说阅读网的设计与实现

随着互联网信息的发展,人们在闲暇的时候更多的原因选择小说来进行阅读,一方面扩展自己的阅读圈,另一方面消磨闲暇时光,但是当下的很多小说网站,要么是要收取高昂的阅读法,要么就是整个网站多充斥着大量的广告,为了给广大网友一个健康,免费的阅读空间我们开发了本系统 本在线小说…

【JS】数据结构之树结构

文章目录树结构二叉树二叉搜索树平衡树(AVL树)红黑树回顾其他数据结构(每种数据结构都有自己特定的应用场景): 数组:通过下标查询很快,插入和删除数据的时候,效率会很低,…

新品上线 | 企企通推出达人管理系统,助力达人营销提效增速

01、直播市场发展迅速 企企通达人管理系统应运而生 近年来,直播凭借其即时性、互动性、多样化的优势,迅速在互联网占据一席之地,“直播”模式不断扩展,直播电商应运而生。 在技术发展与市场需求双重驱动下,中国直播市…

day04 springmvc

day04 springmvc 第一章 SpringMVC运行原理 第一节 启动过程 1. Servlet 生命周期回顾 生命周期环节调用的方法时机次数创建对象无参构造器默认:第一次请求 修改:Web应用启动时一次初始化init(ServletConfig servletConfig)创建对象后一次处理请求se…