《动手学深度学习 Pytorch版》 7.1 深度卷积神经网络(AlexNet)

news2025/1/15 16:34:18

7.1.1 学习表征

深度卷积神经网络的突破出现在2012年。突破可归因于以下两个关键因素:

  • 缺少的成分:数据
    数据集紧缺的情况在 2010 年前后兴起的大数据浪潮中得到改善。ImageNet 挑战赛中,ImageNet数据集由斯坦福大学教授李飞飞小组的研究人员开发,利用谷歌图像搜索对分类图片进行预筛选,并利用亚马逊众包标注每张图片的类别。这种数据规模是前所未有的。
  • 缺少的成分:硬件
    2012年,Alex Krizhevsky和Ilya Sutskever使用两个显存为3GB的NVIDIA GTX580 GPU实现了快速卷积运算,推动了深度学习热潮。

7.1.2 AlexNet

2012年横空出世的 AlexNet 首次证明了学习到的特征可以超越手动设计的特征。

AlexNet 和 LeNet 的架构非常相似(此书对模型稍微精简了一下,取出来需要两个小GPU同时运算的设计特点):

全连接层(1000)

↑ \uparrow

全连接层(4096)

↑ \uparrow

全连接层(4096)

↑ \uparrow

3 × 3 3\times3 3×3最大汇聚层,步幅2

↑ \uparrow

3 × 3 3\times3 3×3卷积层(384),填充1

↑ \uparrow

3 × 3 3\times3 3×3卷积层(384),填充1

↑ \uparrow

3 × 3 3\times3 3×3卷积层(384),填充1

↑ \uparrow

3 × 3 3\times3 3×3最大汇聚层,步幅2

↑ \uparrow

5 × 5 5\times5 5×5卷积层(256),填充2

↑ \uparrow

3 × 3 3\times3 3×3最大汇聚层,步幅2

↑ \uparrow

11 × 11 11\times11 11×11卷积层(96),步幅4

↑ \uparrow

输入图像( 3 × 224 × 224 3\times224\times224 3×224×224

AlexNet 和 LeNet 的差异:

- AlexNet 比 LeNet 深的多
- AlexNet 使用 ReLU 而非 sigmoid 作为激活函数

以下为 AlexNet 的细节。

  1. 模型设计

    由于 ImageNet 中的图像大多较大,因此第一层采用了 11 × 11 11\times11 11×11 的超大卷积核。后续再一步一步缩减到 3 × 3 3\times3 3×3。而且 AlexNet 的卷积通道数是 LeNet 的十倍。

    最后两个巨大的全连接层分别各有4096个输出,近 1G 的模型参数。因早期 GPU 显存有限,原始的 AlexNet 采取了双数据流设计。

  2. 激活函数

    ReLU 激活函数是训练模型更加容易。它在正区间的梯度总为1,而 sigmoid 函数可能在正区间内得到几乎为 0 的梯度。

  3. 容量控制和预处理

    AlexNet 通过暂退法控制全连接层的复杂度。此外,为了扩充数据,AlexNet 在训练时增加了大量的图像增强数据(如翻转、裁切和变色),这也使得模型更健壮,并减少了过拟合。

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(
    # 这里使用一个11*11的更大窗口来捕捉对象。
    # 同时,步幅为4,以减少输出的高度和宽度。
    # 另外,输出通道的数目远大于LeNet
    nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 减小卷积窗口,使用填充为2来使得输入与输出的高和宽一致,且增大输出通道数
    nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 使用三个连续的卷积层和较小的卷积窗口。
    # 除了最后的卷积层,输出通道的数量进一步增加。
    # 在前两个卷积层之后,汇聚层不用于减少输入的高度和宽度
    nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Flatten(),
    # 这里,全连接层的输出数量是LeNet中的好几倍。使用dropout层来减轻过拟合
    nn.Linear(6400, 4096), nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Dropout(p=0.5),
    # 最后是输出层。由于这里使用Fashion-MNIST,所以用类别数为10,而非论文中的1000
    nn.Linear(4096, 10))
X = torch.randn(1, 1, 224, 224)
for layer in net:
    X=layer(X)
    print(layer.__class__.__name__,'output shape:\t',X.shape)
Conv2d output shape:	 torch.Size([1, 96, 54, 54])
ReLU output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Conv2d output shape:	 torch.Size([1, 256, 26, 26])
ReLU output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Conv2d output shape:	 torch.Size([1, 384, 12, 12])
ReLU output shape:	 torch.Size([1, 384, 12, 12])
Conv2d output shape:	 torch.Size([1, 384, 12, 12])
ReLU output shape:	 torch.Size([1, 384, 12, 12])
Conv2d output shape:	 torch.Size([1, 256, 12, 12])
ReLU output shape:	 torch.Size([1, 256, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 256, 5, 5])
Flatten output shape:	 torch.Size([1, 6400])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])

7.1.3 读取数据集

如果真用 ImageNet 训练,即使是现在的 GPU 也需要数小时或数天的时间。在此仅作演示,仍使用 Fashion-MNIST 数据集,故在此需要解决图像分辨率的问题。

batch_size = 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

7.1.4 训练 AlexNet

lr, num_epochs = 0.01, 10
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要二十分钟,慎跑
loss 0.330, train acc 0.879, test acc 0.878
592.4 examples/sec on cuda:0

在这里插入图片描述

练习

(1)尝试增加轮数。对比 LeNet 的结果有什么不同?为什么?

lr, num_epochs = 0.01, 15
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要三十分钟,慎跑
loss 0.284, train acc 0.896, test acc 0.887
589.3 examples/sec on cuda:0

在这里插入图片描述

相较于 LeNet 的增加轮次反而导致精度下降,AlexNet 具有更好的抗过拟合能力,增加轮次精度就会上升。


(2) AlexNet 模型对 Fashion-MNIST 可能太复杂了。

a. 尝试简化模型以加快训练速度,同时确保准确性不会显著下降。

b. 设计一个更好的模型,可以直接在 $28\times28$ 像素的图像上工作。
net_Better = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=1),
    nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Flatten(),
    nn.Linear(64 * 5 * 5, 1024), nn.ReLU(),
    nn.Dropout(p=0.3), 
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Dropout(p=0.3),
    nn.Linear(512, 10)
)

X = torch.randn(1, 1, 28, 28)
for layer in net_Better:
    X=layer(X)
    print(layer.__class__.__name__,'output shape:\t',X.shape)
Conv2d output shape:	 torch.Size([1, 64, 14, 14])
ReLU output shape:	 torch.Size([1, 64, 14, 14])
MaxPool2d output shape:	 torch.Size([1, 64, 12, 12])
Conv2d output shape:	 torch.Size([1, 128, 12, 12])
ReLU output shape:	 torch.Size([1, 128, 12, 12])
Conv2d output shape:	 torch.Size([1, 128, 12, 12])
ReLU output shape:	 torch.Size([1, 128, 12, 12])
Conv2d output shape:	 torch.Size([1, 64, 12, 12])
ReLU output shape:	 torch.Size([1, 64, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 64, 5, 5])
Flatten output shape:	 torch.Size([1, 1600])
Linear output shape:	 torch.Size([1, 1024])
ReLU output shape:	 torch.Size([1, 1024])
Dropout output shape:	 torch.Size([1, 1024])
Linear output shape:	 torch.Size([1, 512])
ReLU output shape:	 torch.Size([1, 512])
Dropout output shape:	 torch.Size([1, 512])
Linear output shape:	 torch.Size([1, 10])
batch_size = 128
train_iter28, test_iter28 = d2l.load_data_fashion_mnist(batch_size=batch_size)
lr, num_epochs = 0.01, 10
d2l.train_ch6(net_Better, train_iter28, test_iter28, num_epochs, lr, d2l.try_gpu())  # 快多了
loss 0.429, train acc 0.841, test acc 0.843
6650.9 examples/sec on cuda:0

在这里插入图片描述


(3)修改批量大小,并观察模型精度和GPU显存变化。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs = 0.01, 10
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要二十分钟,慎跑
loss 0.407, train acc 0.850, test acc 0.855
587.8 examples/sec on cuda:0

在这里插入图片描述

4G 显存基本拉满,精度略微下降,过拟合貌似严重了。


(4)分析 AlexNet 的计算性能。

a. 在 AlexNet 中主要是哪一部分占用显存?

b. 在AlexNet中主要是哪部分需要更多的计算?

c. 计算结果时显存带宽如何?

a. 第一个全连接层占用显存最多

b. 倒数第二个卷积层需要更多的计算


(5)将dropout和ReLU应用于LeNet-5,效果有提升吗?再试试预处理会怎么样?

net_try = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
    nn.Dropout(p=0.2), 
    nn.Linear(120, 84), nn.ReLU(),
    nn.Dropout(p=0.2), 
    nn.Linear(84, 10))

lr, num_epochs = 0.6, 10
d2l.train_ch6(net_try, train_iter28, test_iter28, num_epochs, lr, d2l.try_gpu())  # 浅调一下还挺好
loss 0.306, train acc 0.887, test acc 0.883
26121.2 examples/sec on cuda:0

在这里插入图片描述

浅浅调一下,效果挺好,精度有所提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1036715.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring后处理器-BeanPostProcessor

Spring后处理器-BeanPostProcessor Bean被实例化后,到最终缓存到名为singletonObjects单例池之前,中间会经过bean的初始化过程((该后处理器的执行时机)),例如:属性的填充、初始化方…

第 364 场 LeetCode 周赛题解

A 最大二进制奇数 降序排序字符串&#xff0c;然后将最后一个 1 与最后一位交换 class Solution { public:string maximumOddBinaryNumber(string s) {sort(s.begin(), s.end(), greater<>());for (int i s.size() - 1;; i--)if (s[i] 1) {swap(s[i], s.back());break;…

【Oracle】Oracle系列之八--SQL查询

文章目录 往期回顾前言1. 基本查询&#xff08;1&#xff09;All&#xff08;2&#xff09;in/exists 子查询&#xff08;3&#xff09;union/except/intersect&#xff08;4&#xff09;group by&#xff08;5&#xff09;having&#xff08;6&#xff09;聚集函数&#xff1a…

SLAM从入门到精通(用c++实现机器人运动控制)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 之前的一篇文章&#xff0c;我们知道了可以通过wpr_simulation包仿真出机器人和现场环境。如果需要控制机器人&#xff0c;这个时候就需要rqt_robo…

AcWing 5153. 删除(AcWing杯 - 周赛)(结论+枚举)

思路&#xff1a; ACcode: #include<bits/stdc.h> using namespace std; #define int long long string s; void solve() {cin>>s;s"00"s;int lens.size();for(int i0; i<len; i) {for(int ji1; j<len; j) {for(int kj1; k<len; k) {int xs[i]*…

leetcode:2446. 判断两个事件是否存在冲突(python3解法)

难度&#xff1a;简单 给你两个字符串数组 event1 和 event2 &#xff0c;表示发生在同一天的两个闭区间时间段事件&#xff0c;其中&#xff1a; event1 [startTime1, endTime1] 且event2 [startTime2, endTime2] 事件的时间为有效的 24 小时制且按 HH:MM 格式给出。 当两个…

Windows 基于Visual Studio 开发Qt 6 连接MySQL 8

前提条件&#xff1a; 1、Visual Studio 2022 社区版(免费版) 2、Qt-6.5.1版本 3、MySQL 8 Qt 6 配置MySQL 8 动态/静态连接库和MySQL 8 驱动。 libmysql.dll 和libmysql.lib是QT所需的动态和静态链接库&#xff1b;qsqlmysql.dll 和qsqlmysql.dll.debug是Qt所需的mysql驱…

机器人过程自动化(RPA)入门 1. 什么是机器人过程自动化?

如今&#xff0c;我们生活中几乎没有任何方面不受自动化的影响。一些例子包括洗衣机、微波炉、汽车和飞机的自动驾驶模式&#xff0c;雀巢在日本的商店里使用机器人销售咖啡豆&#xff0c;沃尔玛在美国测试无人机送货&#xff0c;我们的银行支票使用光学字符识别&#xff08;OC…

【Linux】调试代码的工具 - gdb

1、安装gdb sudo yum -y install gdb【安装gdb】 2、gdb的使用 在 Linux 下&#xff0c;我们编写代码默认以 release 方式发布如果想让代码以 debug 方式发布&#xff0c;必须给 gcc 添加 -g 选项 (gdb) q / quit【退出gdb】(gdb) l / list&#xff08;list可缩写为 l&#xf…

C++的文件操作

文件操作 程序运行时产生的数据都属于临时数据&#xff0c;通过文件可将数据持久化 C中对文件操作需要包含头文件<fstream> 文件类型分为两种&#xff1a; 文本文件 - 文件以文本的ASCII码形式存储在计算机中二进制文件 - 文件以文本的二进制形式存储在计算机中&…

如何取消自动播放音乐:取消手机汽车连上后汽车自动播放音乐?

背景 手机和汽车通过蓝牙连接上之后&#xff0c;汽车音响会自动播放手机上的音乐&#xff0c;似乎是自动唤醒APP的&#xff0c;因为这些音乐APP在手机上是已经被杀了后台的了。 而且汽车的屏幕的播放列表里头会显示播放的音乐的名称&#xff0c;也有可能是视频的名称&#xf…

安卓备份分区----手动查询安卓系统分区信息 导出系统分区的一些基本操作

在玩机搞机过程中。有时候需要手动查看有些分区信息&#xff0c;或者备份分区的操作。那么今天以小米8为例解析下其中的操作步骤 机型&#xff1a;小米8 adb版本&#xff1a;https://developer.android.com/studio/releases/platform-tools 机型芯片&#xff1a;高通骁龙845…

基于微信小程序的校园商铺系统,附源码、数据库

文章目录 第一章 简介第二章 技术栈第三章&#xff1a;总体设计第四章系统详细设计4.1 前台功能模块4.2后台功能模块4.2.1管理员功能模块 五 源码咨询 第一章 简介 今天&#xff0c;为大家带来的事基于微信小程序的校园商铺系统。本系统的主要意义在于&#xff0c;全力以赴为用…

Redis双写一致性、持久化机制、分布式锁

一.双写一致性: 含义:当数据库中的数据被修改了以后&#xff0c;我们也需要同时修改缓存&#xff0c;使缓存和数据库的数据保持一致 &#xff08;1&#xff09;读操作:当请求发来的时候&#xff0c;先去看redis里面是否有对应的数据&#xff0c;如果有直接返回&#xff0c;如果…

轻量级的日志采集组件 Filebeat 讲解与实战操作

文章目录 一、概述二、Kafka 安装三、Filebeat 安装1&#xff09;下载 Filebeat2&#xff09;Filebeat 配置参数讲解3&#xff09;filebeat.prospectors 推送kafka完整配置1、filebeat.prospectors2、processors3、output.kafka 4&#xff09;filebeat.inputs 与 filebeat.pros…

【STL】vector常见用法及模拟实现(附源码)

目录 前言1. vector介绍及使用1.1vector的介绍1.2 vector的使用1.2.1 构造函数 1.2.2 vector对象遍历1.2.3 reserve和resize1.2.4 insert和erase 2. vector模拟实现2.1 vector迭代器失效问题2.2 模拟实现reserve函数浅拷贝问题2.3模拟实现源码2.3.1 vector.h2.3.2 test.cpp 前言…

org.postgresql.util.PSQLException: Bad value for type long

项目用 springbootmybatis mybatisplus&#xff0c; 数据库是&#xff1a;postgresql 。 执行查询时候返回错误。 org.springframework.dao.DataIntegrityViolationException: Error attempting to get column city_id from result set. Cause: org.postgresql.util.PSQLExce…

如何让ChatGPT为留学生所用?

“我们这一届学Data Analyics和Data Science的没一个找到工作的。”朋友饭桌上的闲话让研究生才算踏入DA圈子的我瑟瑟发抖。 还没开始正式求职的我&#xff0c;似乎已经被宣告失业了。而这一切都要“归功”于以ChatGPT为代表的大语言模型&#xff08;LLMs&#xff09;。 问世不…

接口测试练习步骤

在接触接口测试过程中补了很多课&#xff0c; 终于有点领悟接口测试的根本&#xff1b; 偶是个实用派&#xff5e;&#xff0c;那么现实中没有用的东西&#xff0c;基本上我都不会有很大的概念&#xff1b; 下面给的是接口测试的统一大步骤&#xff0c;其实就是让我们对接口…

第9章 【MySQL】InnoDB的表空间

表空间 是一个抽象的概念&#xff0c;对于系统表空间来说&#xff0c;对应着文件系统中一个或多个实际文件&#xff1b;对于每个独立表空间来说&#xff0c;对应着文件系统中一个名为 表名.ibd 的实际文件。大家可以把表空间想象成被切分为许许多多个 页 的池子&#xff0c;当我…