PyTorch实战2:彩色图片识别(CIFAR10)

news2024/10/6 22:25:25
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P2周:彩色图片识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

目录

    • 一、数据准备
    • 二、构建简单CNN网络
        • ⭐1. `torch.nn.Conv2d()`详解
        • ⭐2. torch.nn.Linear()详解
        • ⭐3. torch.nn.MaxPool2d()详解
        • ⭐4. torch.nn.BatchNorm2d()详解
        • ⭐5. 关于卷积层、池化层的计算:
        • 6.构建稀疏卷积的CNN
    • 三、总结

一、数据准备

torchvision.datasets详解 :http://t.csdn.cn/DCqMk

本次案例依然使用Pytorch自带的一个数据库torchvision.datasets,通过代码在线下载数据,这里使用的是torchvision.datasets中的CIFAR10数据集。

具体代码:

train_ds = torchvision.datasets.CIFAR10('data', 
                                      train=True, 
                                      transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensor
                                      download=True)

test_ds  = torchvision.datasets.CIFAR10('data', 
                                      train=False, 
                                      transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensor
                                      download=True)

在这里简单展示一下部分彩色图片:
在这里插入图片描述

后面具体实现操作可参考:http://t.csdn.cn/DCqMk,这里直接进入构建网络部分。

二、构建简单CNN网络

⭐1. torch.nn.Conv2d()详解

函数原型如下:

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

参数说明

  • in_channels:输入特征图的通道数;
  • out_channels:输出特征图的通道数,即卷积核的个数;
  • kernel_size:卷积核的大小,可以是int、tuple型变量。如kernel_size=3表示使用3x3的卷积核进行卷积;
  • stride:卷积核的步长,可以是int、tuple型变量。如stride=2表示每隔1行/列卷积一次;
  • padding:填充的长度,可以是int、tuple型变量。如padding=1表示在输入特征图的四周各加1圈0,以减小特征图尺寸;
  • dilation:空洞卷积操作的空洞率(dilation rate),可以是int、tuple型变量;
  • groups:组卷积的分组数量,默认为1,表示普通卷积操作;
  • bias:是否添加偏置项,默认添加;
  • padding_mode:填充模式,可取’zeros’或’circular’。

例子

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
input_data = torch.randn(1, 3, 32, 32)  # 输入特征图大小为[1, 3, 32, 32]
output = conv(input_data)  # 输出特征图大小为[1, 16, 32, 32]

以上代码定义了一个输入特征图通道数为3,输出特征图通道数为16,卷积核大小为3x3,步长为1,填充长度为1的卷积层。将随机生成的大小为[1, 3, 32, 32]的张量作为输入,经过一次卷积操作后得到输出特征图大小为[1, 16, 32, 32]的张量。

⭐2. torch.nn.Linear()详解

torch.nn.Linear()用于实现线性变换或全连接层。它将大小为in_features的输入张量映射到大小为out_features的输出张量,通过以下公式实现:

y = xA^T + b

函数原型

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

参数说明

  • in_features:每个输入样本的大小
  • out_features:每个输出样本的大小

其中,x是输入张量,A是权重矩阵,b是偏置向量,y是输出张量。

使用torch.nn.Linear()可以方便地定义神经网络模型中的全连接层,并自动管理权重和偏置等参数。例如:

import torch
import torch.nn as nn

# 定义一个输入维度为3,输出维度为4的全连接层
linear_layer = nn.Linear(3, 4)

# 随机生成一个大小为(2, 3)的输入张量
input_tensor = torch.randn(2, 3)

# 将输入张量传入全连接层进行前向计算
output_tensor = linear_layer(input_tensor)

# 查看输出张量的形状
print(output_tensor.shape)

上述代码定义了一个输入维度为3,输出维度为4的全连接层,并随机生成了一个大小为(2,3)的输入张量进行前向计算。输出张量的形状应该为(2, 4)

除了输入维度和输出维度之外,torch.nn.Linear()还可以设置其他参数,如是否包括偏置项、权重初始化方法等。这些参数可以通过传递关键字参数进行设置。

⭐3. torch.nn.MaxPool2d()详解

torch.nn.MaxPool2d() 用于进行 2D 最大池化操作的函数。它可以将输入的二维数据张量按照指定大小进行划分,并在每个子区域中取最大值,从而得到一个更小的输出张量。

函数原型

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

以下是 torch.nn.MaxPool2d() 的常用参数:

  • kernel_size:池化窗口的大小,可以是一个整数(表示正方形)或一个元组(表示长方形)。如果设置为 ( k , k ) (k,k) (k,k) 或者 k k k,则表示使用 k × k k\times k k×k 的池化窗口。
  • stride:池化窗口的步幅,可以是一个整数(表示横向和纵向相同的步幅),也可以是一个元组(表示横向和纵向不同的步幅)。
  • padding:填充的大小,可以是一个整数(表示正方形)或一个元组(表示长方形),与卷积的 padding 参数类似。
  • dilation:卷积核的扩张率,即卷积核中各个元素之间的间隔距离。
  • return_indices:是否返回最大值的索引。
  • ceil_mode:当 stride 不被整除时,是否向上取整,可以避免出现边界像素没有参与池化的情况。

具体用法如下:

import torch.nn as nn

maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
input = torch.randn(1, 3, 64, 64)

output = maxpool(input)

上面的代码中,我们创建了一个 3x3 的池化窗口,步幅为 2,填充大小为 1 的最大池化层,并将其应用于一个输入张量。最后得到的输出张量形状会因为池化操作而变小。

⭐4. torch.nn.BatchNorm2d()详解

torch.nn.BatchNorm2d()用于批量标准化(Batch Normalization)操作。它可以对输入数据进行标准化,并将其缩放和移位以使其均值为0,方差为1。该层通常用于神经网络中,可使训练更稳定且加快收敛速度。

该层的输入是形状为 (batch_size, num_channels, height, width) 的4D张量,其中 batch_size 表示批次大小,num_channels 表示通道数,heightwidth 分别表示输入数据的高度和宽度。

参数说明:

  • num_features (int):输入特征的数量(即 num_channels)。
  • eps (float, optional):防止除以0的小数,默认为1e-5。
  • momentum (float, optional):用于计算统计信息的动量,应在0到1之间,默认为0.1。
  • affine (bool, optional):是否使用可学习的仿射变换,默认为True。
  • track_running_stats (bool, optional):是否计算并跟踪运行时统计数据,默认为True。

首先通过 nn.Conv2d 进行卷积操作,然后传递给 nn.BatchNorm2d 层进行标准化操作,接着再使用ReLU激活函数进行非线性变换。最后,将做过标准化和非线性变换的输出传递到全连接层,以生成最终的预测结果。

⭐5. 关于卷积层、池化层的计算:

下面的网络数据shape变化过程为:

3, 32, 32(输入数据)
-> 64, 30, 30(经过卷积层1)-> 64, 15, 15(经过池化层1)
-> 64, 13, 13(经过卷积层2)-> 64, 6, 6(经过池化层2)
-> 128, 4, 4(经过卷积层3) -> 128, 2, 2(经过池化层3)
-> 512 -> 256 -> num_classes(10)

(注:此处计算过程只是作为例子参考)

6.构建稀疏卷积的CNN

定义一个基于稀疏卷积神经网络的分类器,包括了三个主要的组成部分:卷积层、批量归一化层以及全连接层。

首先,在初始化函数(__init__)中,我们定义了卷积层conv1,使用的卷积核大小为3x3,有16个输出通道。然后,我们加入了一个批量归一化层bn1,其输入通道数为16。接下来,我们定义了一个稀疏卷积层sparse_conv,其输入通道数为16,输出通道数为32,且不使用padding。最后,我们添加了另一个批量归一化层bn2,其输入通道数为32,用于归一化稀疏卷积层的输出。最终,我们加入了一个全连接层fc1,将32个特征图转换为10个类别的概率值。

在前向传递函数(forward)中,我们首先对输入数据x应用第一层卷积操作,然后使用ReLU激活函数和批量归一化对其进行处理。接着,我们将输出结果再次通过稀疏卷积层、批量归一化层及ReLU激活函数处理。之后,我们使用平均池化层将特征图压缩成一个1x1的向量,以便我们可以将其送入一个全连接层进行分类。在最后一步中,我们应用softmax激活函数,并返回对数值作为输出结果,这个输出结果包含每个类别出现的概率。

class SparseConvNet(nn.Module):
    def __init__(self):
        super(SparseConvNet, self).__init__()
        # 第一层卷积
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)   # 批量归一化层
        # 稀疏卷积层
        self.sparse_conv = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(32)   # 批量归一化层
        # 全连接层
        self.fc1 = nn.Linear(32, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))     # 第一层卷积 + 批量归一化 + 激活函数(ReLU)
        x = F.relu(self.bn2(self.sparse_conv(x)))    # 稀疏卷积层 + 批量归一化 + 激活函数(ReLU)
        x = F.avg_pool2d(x, kernel_size=x.size()[2:])    # 平均池化层
        x = x.view(-1, 32)   # 将特征图拉成向量
        x = self.fc1(x)  # 全连接层
        return F.log_softmax(x, dim=1)  # 输出层应用softmax激活函数,并返回对数值

网络结构展示:
在这里插入图片描述

训练模型与结果可视化在上一篇PyTorch实战1:实现mnist手写数字识别已有详细的赘述

在此就直接摆出本文案例的运行结果图:
在这里插入图片描述
在这里插入图片描述

三、总结

本文实战并没有使用深度学习训练营中的网络结构进行模型训练,而是自己设计了一个较为简单的、易于理解的网络结构,发现亲手设计从0到1的网络会遇到一些问题,比如每个层的参数该如何设置,卷积层、池化层如何计算,使用多少个卷积层、池化层、全连接层,尝试不用正规卷积而改用稀疏卷积如何去实现等等。

本次实战运用自己设计的网络结构开始训练模型,最终结果证明效果一般般,毕竟这是第一次且刚入门的小案例,后续熟练了再试着去调参优化,将模型的精度提高至80,甚至90。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/451739.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

图扑软件 | 数字孪生智慧水泥工厂

前言 近年来,随着我国经济的发展和人民生活水平的提高,我国对于水泥行业的关注程度也越来越高,为了保证水泥行业的健康稳定发展,许多地方都在大力推动水泥生产技术创新工作。当前水泥行业的发展正处于新旧动能更迭的关键阶段&…

JavaWeb开发 —— SpringBootWeb综合案例

通过综合案例,我们来了解和学习在开发Web程序时,前端程序和后端程序以及数据库三者之间是如何交互、如何协作的,而通过综合案例也来掌握通过接口文档开发服务端接口的能力。 目录 一、准备工作 1. 需求说明 2. 环境搭建 3. 开发规范 二…

Postgis导出shp和gdb数据库(Postgre入门九)

背景 有时候我们需要将postgis数据库中的空间数据表导出shp格式,而PG自带的PostGIS Shapefile Import/Export Manager 导出shp大部分时候是可以用的,但是它有个缺点是,当shp字段名称超过10个字节时,字段会被切掉,如字段“afdskskkfkfjdj”被切掉后是“afdskskkfk”,所以…

文案自动修改软件-文案自动改写的免费软件下载

文章生成器ai写作机器人 随着人工智能技术的飞速发展,越来越多的新型产品被推向市场。其中,文章生成器AI写作机器人是一个备受关注的新兴行业。它使用机器学习和自然语言处理等技术,为用户自动生成高质量的文章和内容,帮助用户在…

基于OpenCV与深度神经网络——实现证件识别扫描并1比1还原证件到A4纸上

前言 1.用拍照的证件照片正反面,实现用证件去复印到A4纸上的效果,还有证件的格式化识别。 图1:把拍照的证件1比1还原证件到A4纸上 图2:证件OCR格式化识别 2.使用Yolo做目标识别,Enet做边缘检测,Paddle OCR做文字识别&…

【数据结构与算法】常用数据结构(一)

😀大家好,我是白晨,一个不是很能熬夜,但是也想日更的人✈。如果喜欢这篇文章,点个赞👍,关注一下👀白晨吧!你的支持就是我最大的动力!💪&#x1f4…

燃气管道定位83KHZ地下电子标识器探测仪ED-8000操作说明1

1、功能简要说明 ED-8000地下电子标识器探测仪是华翔天诚推出的一款可支持模拟电子标识器(无 ID)探测和数字 ID 电子标识器 探测两种工作模式,在模拟电子标识器(无 ID)探测模式下,可探测 所有按标准频率生…

Unity-ML-Agents安装

目录 1.下载ML-Agents 1.1 前往官网 1.2 选择版本 1.3 下载文件 2.下载Anaconda 3.虚拟环境 3.1 构建虚拟环境 3.2 创建项目,导入package.json 3.2.1 创建项目,导入package.json 3.2.2 导入成功 3.2.3 将模板项目拖入unity项目中 3.3 开始训练 …

低代码感觉很能打——可视化搭建系统,把格局做大

有人说「可视化搭建系统」说到底只是重复造轮子产生的玩具; 有人说「可视化搭建系统」本质是组件枚举,毫无意义。 片面的认知必有其产生的道理,但我们不妨从更高的角度出发,并真切落地实践,也许你会发现:我…

Java面试题总结 | Java面试题总结5- 数据结构模块(持续更新)

数据结构 文章目录 数据结构顺序表和链表的区别HashMap 和 Hashtable 的区别Java中用过哪些集合,说说底层实现,使用过哪些安全的集合类Java中线程安全的基本数据结构有哪些ArrayList、Vector和LinkedList有什么共同点与区别?ArrayList和Linke…

怎样正确做web应用的压力测试?

web应用,通俗来讲就是一个网站,主要依托于浏览器实现其功能。 提到压力测试,我们想到的是服务端压力测试,其实这是片面的,完整的压力测试包含服务端压力测试和前端压力测试。 下文将从以下几部分内容展开&#xff1a…

源码简读 - AlphaFold2的2.3.2版本源码解析 (1)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/130323566 时间:2023.4.22 官网:https://github.com/deepmind/alphafold AlphaFold2是一种基于深度学习的方法…

torch中fft和ifft新旧版本使用

pytorch旧版本(1.7之前)中有一个函数torch.rfft(),但是新版本(1.8、1.9)中被移除了,添加了torch.fft.rfft(),但它并不是旧版的替代品。 torch.fft label_fft1 torch.rfft(label_img4, signal…

25岁走出外包后,感到迷茫了.....

我认识一个老哥,他前段时间从外包出来了,他在外包干了3年左右的点工,可能也是自身的原因,也没有想到提升自己的技术水平,后面觉得快废了,待着没意思就出来了,出来后他自己更迷茫了,本…

Linux安装Jenkins搭配Gitee自动化部署Springboot项目

目录 前言一、环境准备二、全局工具配置jdk、maven、git三、配置Gitee四、新建任务-部署Springboot项目 前言 Jenkins是一款流行的开源持续集成(CI)和持续交付(CD)工具。它可以帮助开发人员自动构建、测试和部署软件应用程序&…

广州蓝景分享—快速了解Typescript 5.0 中重要的新功能

作为一种在开发人员中越来越受欢迎的编程语言,TypeScript 不断发展,带来了大量的改进和新功能。在本文中,我们将深入研究 TypeScript 的最新迭代版本 5.0,并探索其最值得注意的更新。 1.装饰器 TypeScript 5.0 引入了改进的装饰…

二、SQLServer 的适配记录

SQLServer 适配记录 说明:由于 SQLSERVER 数据库本身和MYSQL数据库有一定的语法,创表结构,物理模式等差别,在适配过程中,可能会出现各种错误情况,可参考本次适配记录。 当前环境: 适配项目:JDK11,SpringBoot服务。 适配数据库:SELECT @@VERSION,得 Microsoft SQL …

ProtocolBuffer入门和使用

<<<<<<< HEAD 基础 入门 优势 protocol buffer主要用于结构化数据串行化的灵活、高效、自动的方法&#xff08;简单来说就是结构化数据的可串行化传输&#xff0c;类似JSON、XML等&#xff09;。 比XML解析更快&#xff1a;解析的层数更少&#xff0c;…

【技术发烧】MySqlServer,MySQL WorkBench安装详细教程

目录 一.下载安装MySQLSever 1.下载 2.安装 1.解压 2.编写配置文件 二.初始化数据库 1.以管理员身份打开命令提示符 2.初始化数据库 3.安装mysql服务并启动 4.连接MySQL 5. 修改密码 三.MySQL WorkBench下载 一.下载安装MySQLSever 1.下载 下载路径&#xff1a;https:/…

java导入导出excel数据图片合成工具

目录 java导出和导入excel数据java读取excel数据java数据导出成excel表格 java服务端图片合成的工具 java导出和导入excel数据 可以使用hutool的ExcelUtil工具。 在项目中加入以下依赖&#xff1a; <dependency><groupId>cn.hutool</groupId><artifactI…