使用 PyTorch 的计算机视觉简介 (4/6)

news2024/11/19 1:19:59

一、说明

        在本单元中,我们将了解卷 积神经网络(CNN),它是专门为 计算机视觉设计的。 多层卷积层允许我们从图像中提取某些图像模式,池化层,以及在 CIFAR-10上的表现

二、多层 CNN

        在上一个单元中,我们学习了可以从图像中提取模式的卷积滤波器。对于我们的 MNIST 分类器,我们使用了 5 个 5 × 9 个过滤器,产生了 24 × 24 × <> 张量。

        我们可以使用相同的卷积思想来提取图像中更高级别的模式。例如,数字(如 8 和 9)的圆角边缘可以由许多较小的笔画组成。为了识别这些模式,我们可以在第一层的结果之上构建另一层卷积过滤器。

!wget https://raw.githubusercontent.com/MicrosoftDocs/pytorchfundamentals/main/computer-vision-pytorch/pytorchcv.py
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from torchinfo import summary
import numpy as np

from pytorchcv import load_mnist, train, plot_results, plot_convolution, display_dataset
load_mnist(batch_size=128)

三、池化层

        第一个卷积层寻找原始模式,例如水平线或垂直线。它们之上的下一级卷积层会寻找更高级别的模式,例如原始形状。更多的卷积层可以将这些形状组合到图片的某些部分,直到我们试图分类的最终对象。这将创建提取模式的层次结构。

        这样做时,我们还需要应用一个技巧:减小图像的空间大小。一旦我们检测到滑动窗口中存在水平斯托克,它发生在哪个确切像素就不那么重要了。因此,我们可以“缩小”图像的大小,这是使用池化层之一完成的:

  • 平均池化采用滑动窗口(例如,2 × 2 像素)并计算窗口内值的平均值。
  • “最大池化”将窗口替换为最大值。最大池化背后的想法是检测滑动窗口中是否存在某种模式。

        在典型的CNN中,将由几个卷积层组成,它们之间有池化层以减小图像的尺寸。我们还会增加过滤器的数量,因为随着模式变得更加先进,我们需要寻找更多可能的有趣组合。由于空间维度减小和要素/过滤器维度增加,此体系结构也称为金字塔体系结构

        在下一个示例中,我们将使用两层 CNN:

class MultiLayerCNN(nn.Module):
    def __init__(self):
        super(MultiLayerCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc = nn.Linear(320,10)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 320)
        x = nn.functional.log_softmax(self.fc(x),dim=1)
        return x

net = MultiLayerCNN()
summary(net,input_size=(1,1,28,28))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [1, 10, 24, 24]           260
├─MaxPool2d: 1-2                         [1, 10, 12, 12]           --
├─Conv2d: 1-3                            [1, 20, 8, 8]             5,020
├─MaxPool2d: 1-4                         [1, 20, 4, 4]             --
├─Linear: 1-5                            [1, 10]                   3,210
==========================================================================================
Total params: 8,490
Trainable params: 8,490
Non-trainable params: 0
Total mult-adds (M): 0.47
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.03
Estimated Total Size (MB): 0.09
=======================================================================

        请注意有关定义的一些事项:

  • 我们没有使用 Flatten 层,而是使用 view 函数在正向函数内展平张量,这类似于 numpy 中的重塑函数。由于扁平化层没有可训练的权重,因此我们不需要在类中创建单独的层实例 — 我们只需使用 torch.nn.functional 命名空间中的函数即可。
  • 我们在模型中只使用池层的一个实例,也是因为它不包含任何可训练的参数,因此可以有效地重用一个实例。
  • 可训练参数的数量(~8.5K)比以前的情况(感知器中为80K,单层CNN中为50K)要少得多。
    发生这种情况是因为卷积层通常具有很少的参数,与输入图像大小无关。此外,由于池化,在应用最终密集层之前,图像的维度显着降低。少量参数对我们的模型有积极影响,因为它有助于防止过度拟合,即使在较小的数据集大小上也是如此。
hist = train(net,train_loader,test_loader,epochs=5)
Epoch  0, Train acc=0.949, Val acc=0.978, Train loss=0.001, Val loss=0.001

您可能应该观察到的是,我们能够实现更高的精度,而且速度更快 - 只需 1 或 2 个 epoch。这意味着复杂的网络架构需要更少的数据来弄清楚发生了什么,并从我们的图像中提取通用模式。

四、探索 CIFAR-10 数据集

让我们下载不同对象的真实图像数据集,称为
CIFAR-10。它包含60k 32×32彩色图像,分为10类。

transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=14, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=14, shuffle=False)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
display_dataset(trainset,classes=classes)

        CIFAR-10的一个著名架构称为LeNet,由Yann LeCun提出。它遵循与我们上面概述的相同原则。但是,由于所有图像都是彩色的,因此输入张量大小为 3 × 32 × 32,并且 5 × 5 卷积滤波器也应用于整个颜色维度——这意味着卷积核矩阵的大小为

        3 × 5 × 5。

        我们还对这个模型做了一个简化——我们不使用 log_softmax 作为输出激活函数,只返回最后一个全连接层的输出。在这种情况下,我们可以只使用交叉熵损失函数来优化模型。

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16,120,5)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(120,64)
        self.fc2 = nn.Linear(64,10)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = nn.functional.relu(self.conv3(x))
        x = self.flat(x)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = LeNet()

summary(net,input_size=(1,3,32,32))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [1, 6, 28, 28]            456
├─MaxPool2d: 1-2                         [1, 6, 14, 14]            --
├─Conv2d: 1-3                            [1, 16, 10, 10]           2,416
├─MaxPool2d: 1-4                         [1, 16, 5, 5]             --
├─Conv2d: 1-5                            [1, 120, 1, 1]            48,120
├─Flatten: 1-6                           [1, 120]                  --
├─Linear: 1-7                            [1, 64]                   7,744
├─Linear: 1-8                            [1, 10]                   650
==========================================================================================
Total params: 59,386
Trainable params: 59,386
Non-trainable params: 0
Total mult-adds (M): 0.65
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.05
Params size (MB): 0.24
Estimated Total Size (MB): 0.30
==========================================================================================

        正确训练此网络将花费大量时间,最好在启用 GPU 的计算上完成。

        为了获得更好的训练结果,我们可能需要尝试一些训练参数,例如学习率。因此,我们在这里显式定义了一个 S对量梯度下降 (SGD) 优化器并传递训练参数。您可以调整这些参数并观察它们如何影响训练。

opt = torch.optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
hist = train(net, trainloader, testloader, epochs=3, optimizer=opt, loss_fn=nn.CrossEntropyLoss())
Epoch  0, Train acc=0.261, Val acc=0.388, Train loss=0.143, Val loss=0.121
Epoch  1, Train acc=0.437, Val acc=0.491, Train loss=0.110, Val loss=0.101
Epoch  2, Train acc=0.508, Val acc=0.522, Train loss=0.097, Val loss=0.094

        我们通过 3 个时期的训练能够达到的准确性似乎不是很好。但是,请记住,盲猜只能给我们10%的准确率,而且我们的问题实际上比MNIST数字分类要困难得多。在如此短的训练时间内获得超过 50% 的准确率似乎是一个很好的成就。

五、小结

        在本单元中,我们学习了计算机视觉神经网络背后的主要概念——卷积网络。支持图像分类、对象检测甚至图像生成网络的真实架构都基于 CNN,只是具有更多层和一些额外的训练技巧。

        
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1035848.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

拦截|篡改|伪造.NET类库中不限于public的类和方法

大家好&#xff0c;我是沙漠尽头的狼。 本文首发于Dotnet9&#xff0c;介绍使用Lib.Harmony库拦截第三方.NET库方法&#xff0c;达到不修改其源码并能实现修改方法逻辑、预期行为的效果&#xff0c;并且不限于只拦截public访问修饰的类及方法&#xff0c;行文目录&#xff1a;…

Mysql004:用户管理

前言&#xff1a;本章节讲解的是mysql中的用户管理&#xff0c;包括&#xff08;管理数据用户&#xff09;、&#xff08;控制数据库的访问权限&#xff09;。 目录 1. 查询用户 2. 创建用户 3. 修改用户密码 4. 删除用户 5. 权限控制 1. 查询用户 在mysql数据库中&#xff0…

P-GaN栅极HEMT开关瞬态分析中的动态栅极电容模型

标题&#xff1a;Dynamic Gate Capacitance Model for Switching Transient Analysis in P-GaN Gate HEMTs 摘要 在这项工作中&#xff0c;提出了一种用于P-GaN栅极HEMT的高效开关瞬态分析模型&#xff0c;该模型考虑了开关瞬态过程中的动态栅极电容CG(VDS, VGS)特性。同时&a…

【STM32教程】第五章 STM32的定时器

案例代码及相关资料下载链接&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1hsIibEmsB91xFclJd-YTYA?pwdjauj 提取码&#xff1a;jauj 1 定时器综述 1.1 定时器简介 TIM&#xff08;Timer&#xff09;定时器&#xff0c;最基本功能就是定时触发中断&#xff1…

python随手小练3

题目&#xff1a; 写出一个判断闰年的python代码&#xff1a; 闰年的条件&#xff1a; 如果N能够被4整除&#xff0c;并且不能被100整除&#xff0c;则是闰年 或者&#xff1a;N能被400整除&#xff0c;也是闰年 即&#xff1a;4年一润并且百年不润&#xff0c;每400年再润一…

前端第二课,HTML,alt,title,width/heigh,border,<a>超链接,target,tr,td,th

目录 一、title: &#x1f49b; ​二、alt&#x1f499; 三、width/heigh&#x1f49c; 四、border ❤️ 五、超链接&#x1f49a; 六、target &#x1f497; 七、tr&#x1f495; 八、td&#x1f498; 九、th&#x1f49e; 十、rowspan 一、title: &#x1f49b; 快…

漏刻有时数据可视化Echarts组件开发(29):工业图形动画

var nodes = [{nodeName: 新能源,value: [-10, 100],symbolSize: 50,symbol:image://data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAIAAAACACAYAAADDPmHLAAAABGdBTUEAALGPC/xhBQAAACBjSFJNAAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAABmJLR0QA/wD/AP+gvaeTAAAAB3R…

【操作系统笔记四】高速缓存

CPU 高速缓存 存储器的分层结构&#xff1a; 问题&#xff1a;为什么这种存储器层次结构行之有效呢&#xff1f; 衡量 CPU 性能的两个指标&#xff1a; 响应时间&#xff08;或执行时间&#xff09;&#xff1a;执行一条指令平均时间 吞吐量&#xff0c;就是 1 秒内 CPU 可以…

ideal 同一项目启动多实列

点击service(如果没有:点击菜单栏&#xff1a;Views -> Tool Windows -> Services&#xff1b;中文对应&#xff1a;视图 -> 工具窗口 -> 服务&#xff1b;快捷键是Alt F8&#xff0c;但是本地快捷键可能冲突&#xff0c;并未成功。) 刚创建好的窗口是空白的&…

回归预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入单输出回归预测

回归预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入单输出回归预测 目录 回归预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入单输出回归预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 1.MATLAB实现基于RF-Adaboost随机森林结合…

Android 遍历界面所有的View

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、商业变现、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、实践四、 推荐阅读 一、导读 我们…

基于QT实现发送http的get和post请求(post还可以实现上传文件),同时实现接收返回json数据,并对其进行解析

使用到中重要的类&#xff0c;做个简单的介绍 QNetworkAccessManager&#xff1a;这个类是QT帮我们封装好的工具类&#xff0c;主要可以用来发送Http请求 QNetworkReply&#xff1a;这个类主要用来监听发送的请求&#xff0c;并得到请求的响应结果 QHttpMultiPart&#xff1a;这…

RSD处理气象卫星数据——常用投影

李国春 气象卫星扫描刈幅宽覆盖范围广&#xff0c;在地球的不同位置可能需要不同的投影以便更好地表示这些观测数据。这与高分辨率的局地数据有很大不同&#xff0c;高分数据更倾向于用使用处理局地小范围的投影方式。本文选择性介绍几种RSD常用的适合低、中、高纬和极地地区的…

python+nodejs+php+springboot+vue 校园安全车辆人员出入安全管理系统

本校园安全管理系统共包含15个表:分别是表现评分信息表&#xff0c;车辆登记信息表&#xff0c;配置文件信息表&#xff0c;家校互动信息表&#xff0c;监控系统信息表&#xff0c;教师信息表&#xff0c;留言板信息表&#xff0c;校园资讯信息表&#xff0c;人员登记信息表&am…

2023-9-23 合并果子

题目链接&#xff1a;合并果子 #include <iostream> #include <algorithm> #include <queue>using namespace std;int main() {int n;cin >> n;priority_queue<int, vector<int>, greater<int>> heap;for(int i 0; i < n; i){in…

Spring Cloud版本选择

SpringCloud版本号由来 SpringCloud的版本号是根据英国伦敦地铁站的名字进行命名的&#xff0c;由地铁站名称字母A-Z依次类推表示发布迭代版本。 SpringCloud和SpringBoot版本对应关系 注意事项&#xff1a; 其实SpringBoot与SpringCloud需要版本对应&#xff0c;否则可能会造…

墓园导航系统:实现数字化陵园祭扫新模式

墓园导航系统&#xff1a;实现数字化陵园祭扫新模式 随着人口老龄化趋势的加剧&#xff0c;人们对墓地的需求逐渐增加。同时&#xff0c;由于很多墓园面积较大&#xff0c;环境复杂&#xff0c;很多家属在寻找亲人墓地时感到不便和困难。此外&#xff0c;传统墓园的管理和服务水…

论文研读-数据共享-大数据流分析中的共享执行技术

Shared Execution Techniques for Business Data Analytics over Big Data Streams 大数据流分析中的共享执行技术 1、摘要 2020年的一篇共享工作的论文&#xff1a;商业数据分析需要处理大量数据流&#xff0c;并创建物化视图以便给用户实时提供分析结果。物化每个查询&#x…

FPGA——UART串口通信

文章目录 前言一、UART通信协议1.1 通信格式2.2 MSB或LSB2.3 奇偶校验位2.4 UART传输速率 二、UART通信回环2.1 系统架构设计2.2 fsm_key2.3 baud2.4 sel_seg2.5 fifo2.6 uart_rx2.7 uart_tx2.8 top_uart2.9 发送模块时序分析2.10 接收模块的时序分析2.11 FIFO控制模块时序分析…

java框架-Spring-事务

配置 配置事务管理器方法&#xff1a; Beanpublic PlatformTransactionManager platformTransactionManager(){return new DataSourceTransactionManager();}原理