【从零开始学习深度学习】27.卷积神经网络之VGG11模型介绍及其Pytorch实现【含完整代码】

news2025/1/17 1:17:26

目录

    • 1. VGG块介绍
    • 2. 构造VGG网络模型
    • 3. 获取Fashion-MNIST数据并用VGG-11训练模型
    • 4.总结

AlexNet在LeNet的基础上增加了3个卷积层。但AlexNet对卷积窗口、输出通道数和构造顺序均做了大量的调整。虽然AlexNet模型表明深度卷积神经网络可以取得出色的结果,但并没有提供相应规则以指导后来的研究者如何设计新的网络。我们将在后续介绍几种不同的深度网络设计思路。

本文将介绍VGG网络模型,VGG主要思路是通过重复使用简单的基础块来构建深度模型。

1. VGG块介绍

VGG块的组成规律是:连续使用数个相同的填充为1、窗口形状为 3 × 3 3\times 3 3×3的卷积层后接上一个步幅为2、窗口形状为 2 × 2 2\times 2 2×2的最大池化层。卷积层保持输入的高和宽不变,而池化层则对其减半

3x3卷积的优点:

多个3×3的卷积层比一个大尺寸的filter有更少的参数,假设卷基层的输入和输出的特征图大小相同为C,那么三个3×3的卷积层参数个数3×(3×3×C×C)=27CC;一个7×7的卷积层参数为49CC;所以可以把三个3×3的filter看成是一个7×7filter的分解(中间层有非线性的分解)。

下面我们定义一个vgg_block函数来实现这个基础的VGG块,它可以指定卷积层的数量和输入输出通道数。

对于给定的感受野(与输出有关的输入图片的局部大小),采用堆积的小卷积核优于采用大的卷积核,因为可以增加网络深度来保证学习更复杂的模式,而且代价还比较小(参数更少)。例如,在VGG中,使用了3个3x3卷积核来代替7x7卷积核,使用了2个3x3卷积核来代替5*5卷积核,这样做的主要目的是在保证具有相同感知野的条件下,提升了网络的深度,在一定程度上提升了神经网络的训练效果。

import time
import torch
from torch import nn, optim
import sys
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def vgg_block(num_convs, in_channels, out_channels):
    blk = []
    for i in range(num_convs):
        if i == 0:
            blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        else:
            blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        blk.append(nn.ReLU())
    blk.append(nn.MaxPool2d(kernel_size=2, stride=2)) # 这里会使宽高减半
    return nn.Sequential(*blk)

2. 构造VGG网络模型

VGG采用的是一种Pre-training的方式,先训练浅层的的简单网络 VGG11,再复用 VGG11 的权重来初始化 VGG13,如此反复训练并初始化 VGG19,能够使训练时收敛的速度更快。整个网络都使用卷积核尺寸为 3×3 和最大池化尺寸 2×2。比较常用的VGG-16的16指的是conv+fc的总层数是16,是不包括max pool的层数!

下图中最左侧的A列表示最原始的VGG11,因为这个网络使用了8个卷积层和3个全连接层,所以被称为VGG-11。

img

VGG与AlexNet和LeNet一样,VGG网络由卷积层模块后接全连接层模块构成。卷积层模块串联数个vgg_block,其超参数由变量conv_arch定义。该变量指定了每个VGG块里卷积层个数和输入输出通道数。全连接模块则跟AlexNet中的一样。

现在我们构造一个VGG网络。它有5个卷积块,前2块使用单卷积层,而后3块使用双卷积层。第一块的输入输出通道分别是1(因为下面要使用的Fashion-MNIST数据的通道数为1)和64,之后每次对输出通道数翻倍,直到变为512。

conv_arch = ((1, 1, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))
# 经过5个vgg_block, 宽高会减半5次, 变成 224/32 = 7
fc_features = 512 * 7 * 7 # c * w * h
fc_hidden_units = 4096 # 任意

下面我们实现VGG-11。

def vgg(conv_arch, fc_features, fc_hidden_units=4096):
    net = nn.Sequential()
    # 卷积层部分
    for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):
        # 每经过一个vgg_block都会使宽高减半
        net.add_module("vgg_block_" + str(i+1), vgg_block(num_convs, in_channels, out_channels))
    # 全连接层部分
    net.add_module("fc", nn.Sequential(d2l.FlattenLayer(),
                                 nn.Linear(fc_features, fc_hidden_units),
                                 nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(fc_hidden_units, fc_hidden_units),
                                 nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(fc_hidden_units, 10)
                                ))
    return net

下面构造一个高和宽均为224的单通道数据样本来观察每一层的输出形状。

net = vgg(conv_arch, fc_features, fc_hidden_units)
X = torch.rand(1, 1, 224, 224)

# named_children获取一级子模块及其名字(named_modules会返回所有子模块,包括子模块的子模块)
for name, blk in net.named_children(): 
    X = blk(X)
    print(name, 'output shape: ', X.shape)

输出:

vgg_block_1 output shape:  torch.Size([1, 64, 112, 112])
vgg_block_2 output shape:  torch.Size([1, 128, 56, 56])
vgg_block_3 output shape:  torch.Size([1, 256, 28, 28])
vgg_block_4 output shape:  torch.Size([1, 512, 14, 14])
vgg_block_5 output shape:  torch.Size([1, 512, 7, 7])
fc output shape:  torch.Size([1, 10])

可以看到,每次我们将输入的高和宽减半,直到最终高和宽变成7后传入全连接层。与此同时,输出通道数每次翻倍,直到变成512。因为每个卷积层的窗口大小一样,所以每层的模型参数尺寸和计算复杂度与输入高、输入宽、输入通道数和输出通道数的乘积成正比。VGG这种高和宽减半以及通道翻倍的设计使得多数卷积层都有相同的模型参数尺寸和计算复杂度。

3. 获取Fashion-MNIST数据并用VGG-11训练模型

因为VGG-11计算比AlexNet更加复杂,出于测试的目的我们构造一个通道数更小的网络在Fashion-MNIST数据集上进行训练。

ratio = 8
small_conv_arch = [(1, 1, 64//ratio), (1, 64//ratio, 128//ratio), (2, 128//ratio, 256//ratio), 
                   (2, 256//ratio, 512//ratio), (2, 512//ratio, 512//ratio)]
net = vgg(small_conv_arch, fc_features // ratio, fc_hidden_units // ratio)
print(net)

输出:

Sequential(
  (vgg_block_1): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_2): Sequential(
    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_3): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_4): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_5): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): FlattenLayer()
    (1): Linear(in_features=3136, out_features=512, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.5)
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.5)
    (7): Linear(in_features=512, out_features=10, bias=True)
  )
)

模型训练过程与之前的AlexNet类似。

batch_size = 64
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0101, train acc 0.755, test acc 0.859, time 255.9 sec
epoch 2, loss 0.0051, train acc 0.882, test acc 0.902, time 238.1 sec
epoch 3, loss 0.0043, train acc 0.900, test acc 0.908, time 225.5 sec
epoch 4, loss 0.0038, train acc 0.913, test acc 0.914, time 230.3 sec
epoch 5, loss 0.0035, train acc 0.919, test acc 0.918, time 153.9 sec

4.总结

  • VGG-11通过5个可以重复使用的卷积块来构造网络。根据每块里卷积层个数和输出通道数的不同可以定义出不同的VGG模型。

对文章存在的问题,或者其他关于Python相关的问题,都可以在评论区留言或者私信我哦

如果文章内容对你有帮助,感谢点赞+关注!

关注下方GZH:阿旭算法与机器学习,可获取更多干货内容~欢迎共同学习交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/93733.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++ Reference: Standard C++ Library reference: Containers: map: map: emplace

C官网参考链接&#xff1a;https://cplusplus.com/reference/map/map/emplace/ 公有成员函数 <map> std::map::emplace template <class... Args> pair<iterator,bool> emplace (Args&&... args);构造并插入元素 如果元素的键是唯一的&#xff0c;…

【沙拉查词】沙拉查词配置教程——如何实现截图OCR翻译、截图翻译?

一、问题背景 2022年12月16日&#xff0c;沙拉查词仍然没有截图翻译的功能。 这个功能&#xff0c;在百度翻译上虽然能够实现&#xff0c;但是要额外下一个软件和挂在后台&#xff0c;总是觉得麻烦。 二、解决方法 如果你是一个quicker软件使用者&#xff0c;那么通过添加「…

python中的字典详解

目录 一.思考 二.字典定义 注意 三.字典数据的获取 注意 字典的嵌套 四.字典常用操作 1.新增、更新元素 2.删除元素 3.清空字典 4.获取全部Key 5.利用Key遍历字典 五.字典总结 六.字典实例 一.思考 为什么需要字典? 生活中的字典我们可以根据【字】来找到对应的【含…

QT-Linux安装

1、在虚拟机Ubuntu的环境安装好之后&#xff0c;详细看&#xff1a; QT Linux环境搭建——VM虚拟机和Ubuntu的安装_sgmcy的博客-CSDN博客 下面就可以直接安装linux环境下的QT了 2、唯一要注意的一点是&#xff0c;之前安装虚拟机的时候&#xff0c;磁盘空间一定要大一点&…

第十四届蓝桥杯集训——JavaC组第十四篇——嵌套循环

第十四届蓝桥杯集训——JavaC组第十四篇——循环嵌套 目录 第十四届蓝桥杯集训——JavaC组第十四篇——循环嵌套 循环嵌套是逻辑程序中的方法 对应嵌套的循环复杂度 嵌套循环示例&#xff1a; 名词解析&#xff1a; 笛卡尔积 循序命名 循环嵌套是逻辑程序中的方法 循环嵌…

“无人区”行驶8年,李诞的脱口秀路在何方?

刚从《脱口秀大会5》的舞台上下来的李诞&#xff0c;给自己找了份“新工作”——脱口秀直播带货。 12月10日晚&#xff0c;李诞入淘。讲段子、玩梗手到擒来&#xff0c;李诞将自己风趣幽默的脱口秀风格沿用到了这场“来个彩诞”直播首秀中&#xff0c;给交个朋友贡献了超3200万…

5.单点登录(Vue2.x)

概况 百度百科 单点登录&#xff08;Single Sign On&#xff09;&#xff0c;简称为 SSO&#xff0c;是比较流行的企业业务整合的解决方案之一。SSO的定义是在多个应用系统中&#xff0c;用户只需要登录一次就可以访问所有相互信任的应用系统。 关键词 token、session、cooki…

Java基础之《netty(13)—任务队列taskQueue》

一、任务队列 1、用户程序自定义的普通任务 2、用户自定义定时任务 3、非当前Reactor线程调用Channel的各种方法 例如在推送系统的业务线程里面&#xff0c;根据用户的标识&#xff0c;找到对应的Channel引用&#xff0c;然后调用Write类方法向该用户推送消息&#xff0c;就…

基于java+springmvc+mybatis+vue+mysql的养老院管理系统

项目介绍 管理员后台页面&#xff1a; 功能&#xff1a;主页、个人中心、护工管理、家属管理、楼房资料管理、房间资料管理、床位管理、老人入住管理、老人档案管理、身体状态管理、用药情况管理、转房登记管理、外出登记管理、药品信息管理、药品入库管理、药品出库管理、物品…

【C语言】整型的存储方式(大小端,原码,反码,补码)

目录 一、基本类型 二、原码&#xff0c;反码&#xff0c;补码 2.1 原&#xff0c;反&#xff0c;补的计算方式 2.1.1 正数的原&#xff0c;反&#xff0c;补 2.1.2 负数的原&#xff0c;反&#xff0c;补 2.2 为什么要用补码存放 2.3 大小端是什么&#xff1f; 2.3.1 …

明道云联合契约锁共建人事场景电子签约解决方案

背景介绍 在每个组织的人事管理工作中&#xff0c;从招聘、入职、在职、调岗到离职&#xff0c;整个过程中存在大量的合同、证明、函件、通知等文件需要签字盖章。HR每天都要在“核对文件、敲章、通知员工签合同、催进度、给外地员工寄合同、关注合同到期时间等”繁琐的签署工…

使用vite和Element Plus,实现部署后不修改代码/打包,新增主题/皮肤包

Web前端界面切换主题/皮肤&#xff0c;是一个常见的需求。如果希望在打包部署后实现皮肤的修改甚至增加皮肤&#xff0c;不需要修改源码或者重新打包&#xff0c;类似于我们常见的皮肤包扩展&#xff0c;又该如何实现呢&#xff1f; 我使用类似上一期多语言包功能中介绍的方法来…

基于Xlinx的时序分析与约束(3)----基础概念(下)

1、4种基本的时序路径 下图是一张典型的FPGA与上游器件、下游器件通信的示意图&#xff1a; 其可以划分为4条基本的数据路径&#xff0c;这4条路径也是需要进行时序约束的最基本路径。 &#xff08;1&#xff09;寄存器到寄存器 路径2&#xff0c;FPGA内部的寄存器到另一个寄存…

[附源码]Node.js计算机毕业设计高校医疗健康服务系统的设计与实现Express

项目运行 环境配置&#xff1a; Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术&#xff1a; Express框架 Node.js Vue 等等组成&#xff0c;B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境&#xff1a;最好是Nodejs最新版&#xff0c;我…

【C++初阶】类和对象(下)再谈构造函数、static成员、C++11的成员初始化新玩法、友元类、内部类

文章目录再谈构造函数static成员C11的成员初始化新玩法友元类内部类再谈构造函数 1.构造函数体赋值 在创建对象时&#xff0c;编译器通过调用构造函数&#xff0c;给对象中各个成员变量一个合适的初始值。 虽然上述构造函数调用之后&#xff0c;对象中已经有了一个初始值&am…

客户管理系统如何提升体验

数字化时代&#xff0c;客户与企业交互的触点爆炸式增长&#xff0c;客户体验正从单一触点走向端到端旅程。众多的产品、海量的数据&#xff0c;导致客户对体验的要求越来越多......CRM客户管理系统是企业提升客户体验的有效工具&#xff0c;它不仅可以帮助您进一步了解客户&am…

App自动化之dom结构和元素定位方式(包含滑动列表定位)

先来看几个名词和解释&#xff1a; dom: Document Object Model 文档对象模型dom应用: 最早应用于html和js的交互。界面的结构化描述&#xff0c; 常见的格式为html、xml。核心元素为节点和属性xpath: xml路径语言&#xff0c;用于xml 中的节点定位&#xff0c;XPath 可在 xml…

javaSE - 三个常用的接口(Comparable,Comparator,Cloneable)

1、Comparable 英 [ˈkɒmpərəbl] 美 [ˈkɑːmpərəbl] 可比较的;可比的;可比性;可比;可比较   2、Comparator 美 [kəmˈpɜrətər] n. 比较器&#xff0c;比色器&#xff0c;比较电路&#xff0c;比长仪&#xff0c;场强计 3、 Cloneable 可复制的 一、Comparable …

MySQL性能优化浅析

1. 硬件 1.1 CPU IO密集型&#xff0c;提升CPU核心数 计算密集型&#xff0c;提升CPU频率 1.2 磁盘 机械硬盘在随机访问时&#xff0c;由于受磁针移动速度的限制&#xff0c;性能会大幅降低。使用固态硬盘可以大幅提升随机访问的能力。按需选择。 1.3 其他 带宽、内存频…

Superset 安装配置

文章目录Superset 安装配置一、Superset 概述1. Superset简介2. 功能概述3. 支持的数据库二、Superset 环境部署步骤三、创建虚拟机&#xff0c;安装CentOS1.下载CentOS2.创建虚拟机3.编辑虚拟机设置4.安装centos7.9mini版本5.启动centos&#xff0c;并进行登录四、CentOS配置1…