【深度学习】四大图像分类网络之VGGNet

news2024/12/26 19:15:54

2014年,牛津大学计算机视觉组(Visual Geometry Group)和Google DeepMind公司一起研发了新的卷积神经网络,并命名为VGGNet。VGGNet是比AlexNet更深的深度卷积神经网络,该模型获得了2014年ILSVRC竞赛的第二名,第一名是GoogLeNet。

论文原文:Very Deep Convolutional Networks for Large-Scale Image Recognition

本文将首先介绍 VGGNet 的基本结构,然后讲述VGGNet的创新点,最后给出了基于pytorch的VGGNet代码实现。

一、 网络结构

VGG 的结构与 AlexNet 类似,区别是深度更深,但形式上更加简单,提出了模块化的概念。VGG由5层卷积层、3层全连接层、1层softmax输出层构成,层与层之间使用maxpool(最大化池)分开,所有隐藏层的激活单元都采用ReLU。作者在原论文中,根据卷积层不同的子层数量,设计了A、A-LRN、B、C、D、E这6种网络结构。

VGGNet包含很多级别的网络,深度从11层到19层不等。为了解决初始化(权重初始化)等问题,VGG采用的是一种Pre-training的方式,先训练浅层的的简单网络VGG11,再复用VGG11的权重初始化VGG13,如此反复训练并初始化VGG19,能够使训练时收敛的速度更快。比较常用的是VGGNet-16和VGGNet-19。

二、创新点

1. 使用多个小卷积核构成的卷积层代替较大的卷积层

两个3x3卷积核的堆叠相当于5x5卷积核的视野,三个3x3卷积核的堆叠相当于7x7卷积核的视野。一方面减少参数,另一方面拥有更多的非线性变换,增加了CNN对特征的学习能力。VGGNet还提出,用的卷积核越小,效果会越好(越小的话,可以有更多的卷积核数量,网络构建会越深)。

感受野是什么呢?在卷积神经网络中,决定某一层输出结果中一个元素所对应的输入层的区域大小,被称作感受野。通俗的来说就是,输出特征图上的一个单元对应输入层上的区域大小。


2. 引入1*1的卷积核

在不影响输入输出维度的情况下,引入更多非线性变换,降低计算量,​​​​​同时,还可以用它来整合各通道的信息,并输出指定通道数。并且在网络测试阶段,将训练阶段的3个全连接替换为3个卷积,测试重新用训练时的参数,使得测试得到的全卷积网络因为没有全连接的限制,因而可以接收任意宽或高的输入。

3. 预训练策略

训练时,先训练级别简单(层数较浅)的VGGNet的A级网络,然后使用A网络的权重来初始化后面的复杂模型,加快训练的收敛速度;

4. 采用Multi-Scale方法来做数据增强

在训练时,将同一张图片缩放到不同的尺寸,在随机剪裁到224*224的大小,能够增加数据量。在预测时,将同一张图片缩放到不同尺寸做预测,最后取平均值。增加训练的数据量,防止模型过拟合。

三、 代码

import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import math
 
__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]
 
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}
 
class VGG(nn.Module):
    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        self._initialize_weights()
 
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
 
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
 
def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
 
 
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
 
def vgg11(pretrained=False, model_root=None, **kwargs):
    """VGG 11-layer model (configuration "A")"""
    model = VGG(make_layers(cfg['A']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11'], model_root))
    return model
 
def vgg11_bn(**kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization"""
    kwargs.pop('model_root', None)
    return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
 
def vgg13(pretrained=False, model_root=None, **kwargs):
    """VGG 13-layer model (configuration "B")"""
    model = VGG(make_layers(cfg['B']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13'], model_root))
    return model
 
def vgg13_bn(**kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization"""
    kwargs.pop('model_root', None)
    return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
 
def vgg16(pretrained=False, model_root=None, **kwargs):
    """VGG 16-layer model (configuration "D")"""
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16'], model_root))
    return model
 
def vgg16_bn(**kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization"""
    kwargs.pop('model_root', None)
    return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
 
def vgg19(pretrained=False, model_root=None, **kwargs):
    """VGG 19-layer model (configuration "E")"""
    model = VGG(make_layers(cfg['E']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19'], model_root))
    return model
 
def vgg19_bn(**kwargs):
    """VGG 19-layer model (configuration 'E') with batch normalization"""
    kwargs.pop('model_root', None)
    return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)

参考资料:

CNN经典网络模型(三):VGGNet简介及代码实现(PyTorch超详细注释版)_vggnet是哪年发明的-CSDN博客icon-default.png?t=O83Ahttps://blog.csdn.net/qq_43307074/article/details/126027852

【论文必读#4:VGGNet】小卷积核成就一代名模,视觉领域重要里程碑_哔哩哔哩_bilibiliicon-default.png?t=O83Ahttps://www.bilibili.com/video/BV1cT411T7uL?spm_id_from=333.788.videopod.sections&vd_source=0dc0c2075537732f2b9a894b24578eed

CNN经典:VGGNet网络详解与PyTorch实现【含完整代码】 - 知乎icon-default.png?t=O83Ahttps://zhuanlan.zhihu.com/p/658741272

深度学习图像处理之VGG网络模型 (超级详细)_vgg模型-CSDN博客icon-default.png?t=O83Ahttps://blog.csdn.net/BIgHAo1/article/details/121105934

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2251914.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Pytest框架学习20--conftest.py

conftest.py作用 正常情况下,如果多个py文件之间需要共享数据,如一个变量,或者调用一个方法 需要先在一个新文件中编写函数等,然后在使用的文件中导入,然后使用 pytest中定义个conftest.py来实现数据,参…

【力扣】389.找不同

问题描述 思路解析 只有小写字母,这种设计参数小的,直接桶排序我最开始的想法是使用两个不同的数组,分别存入他们单个字符转换后的值,然后比较是否相同。也确实通过了 看了题解后,发现可以优化,首先因为t相…

HarmonyOS4+NEXT星河版入门与项目实战(23)------组件转场动画

文章目录 1、控件图解2、案例实现1、代码实现2、代码解释3、实现效果4、总结1、控件图解 这里我们用一张完整的图来汇整 组件转场动画的用法格式、属性和事件,如下所示: 2、案例实现 这里我们对上一节小鱼游戏进行改造,让小鱼在游戏开始的时候增加一个转场动画,让小鱼自…

Wireshark常用功能使用说明

此处用于记录下本人所使用 wireshark 所可能用到的小技巧。Wireshark是一款强大的数据包分析工具,此处仅介绍常用功能。 Wireshark常用功能使用说明 1.相关介绍1.1.工具栏功能介绍1.1.1.时间戳/分组列表概况等设置 1.2.Windows抓包 2.wireshark过滤器规则2.1.wiresh…

Vue3 开源UI 框架推荐 (大全)

一 、前言 💥这篇文章主要推荐了支持 Vue3 的开源 UI 框架,包括 web 端和移动端的多个框架,如 Element-Plus、Ant Design Vue 等 web 端框架,以及 Vant、NutUI 等移动端框架,并分别介绍了它们的特性和资源地址。&#…

探索Python词云库WordCloud的奥秘

文章目录 探索Python词云库WordCloud的奥秘1. 背景介绍:为何选择WordCloud?2. WordCloud库简介3. 安装WordCloud库4. 简单函数使用方法5. 应用场景示例6. 常见Bug及解决方案7. 总结 探索Python词云库WordCloud的奥秘 1. 背景介绍:为何选择Wo…

Kali Linux系统一键汉化中文版及基础使用详细教程

Kali Linux系统一键汉化中文版及基础使用详细教程 引言 Kali Linux是一款基于Debian的Linux发行版,专为渗透测试和网络安全而设计。由于其强大的功能和丰富的工具,Kali Linux在安全领域得到了广泛应用。然而,许多用户在使用Kali Linux时会遇…

网络安全(三):网路安全协议

网络安全协议设计的要求是实现协议过程中的认证性、机密性与不可否认性。网络安全协议涉及网络层、传输层与应用层。 1、网络层安全与IPSec协议、IPSec VPN 1.1、IPSec安全体系结构 IP协议本质上是不安全的额,伪造一个IP分组、篡改IP分组的内容、窥探传输中的IP分…

2. STM32_中断

中断 中断是什么: 打断CPU执行正常的程序,转而处理紧急程序,然后返回原暂停的程序继续运行,就叫中断。 中断的意义: 中断可以高效处理紧急程序,不会一直占用CPU资源。如实时控制、故障处理、处理不确定…

【聚类】主成分分析 和 t-SNE 降维

1 主成分分析PCA PCA 是一种线性降维技术,旨在通过选择具有最大方差的特征方向(称为主成分)来压缩数据,同时尽可能减少信息损失。 1.1 原理 1.2 优缺点 from sklearn.decomposition import PCA import matplotlib.pyplot as plt…

ARM 嵌入式处理器内核与架构深度剖析:解锁底层技术逻辑

目录 一、ARM架构概述 1.1. 优势与特点 1.2. 应用领域 二、ARM内核的主要系列及特点 2.1. ARM内核与架构的关系 2.2. Cortex-A系列 2.2.1. 应用场景 2.2.2. 特点 2.3. Cortex-R系列 2.3.1. 应用场景 2.3.2. 特点 2.4. Cortex-M系列 2.4.1. 应用场景 2.4.2. 特点 …

数据结构 (21)树、森林和二叉树的关系

一、树 定义:树是由一个集合以及在该集合上定义的一种关系构成的。集合中的元素称为树的结点,所定义的关系称为父子关系。当集合为空时,是一棵空树;当集合非空时,有且仅有一个特定的称为根的结点。树中的每个结点可以有…

探索温度计的数字化设计:一个可视化温度数据的Web图表案例

随着科技的发展,数据可视化在各个领域中的应用越来越广泛。在温度监控和展示方面,传统的温度计已逐渐被数字化温度计所取代。本文将介绍一个使用Echarts库创建的温度计Web图表,该图表通过动态数据可视化展示了温度值,并通过渐变色…

计算机网络——数据链路层Mac帧详解

目录 前言 一、以太网 二、Mac帧 三、MTU——最大传输单元 四、Mac帧的传输过程 1.ARP协议 2.RARP协议 前言 在之前,我们学习过网络层的IP协议,了解到IP协议解决了从哪里来,到哪里去的问题,也就是提供了将数据从A到B的能力…

LabVIEW将TXT文本转换为CSV格式(多行多列)

在LabVIEW中,将TXT格式的文本文件内容转换为Excel格式(即CSV文件)是一项常见的数据处理任务,适用于将以制表符、空格或其他分隔符分隔的数据格式化为可用于电子表格分析的形式。以下是将TXT文件转换为Excel(CSV&#x…

响应式编程一、Reactor核心

目录 一、前置知识1、Lambda表达式2、函数式接口 Function3、StreamAPI4、Reactive-Stream1)几个实际的问题2)Reactive-Stream是什么?3)核心接口4)处理器 Processor5)总结 二、Reactor核心1、Reactor1&…

Vue3之弹窗

文章目录 第一步、引入JS第二步、弹框 在前端开发语言Vue3&#xff0c;在管理端如何进行弹窗&#xff1f;下面根据API实现效果。 Element API文档&#xff1a; Element-plus文档 搭建环境可参考博客【 初探Vue3环境搭建与nvm使用】 第一步、引入JS <script lang"ts&…

w~大模型~合集24

我自己的原文哦~ https://blog.51cto.com/whaosoft/12707697 #Time Travelling Pixels (TTP) 一种名为“时空旅行”&#xff08;TTP&#xff09;的新方法&#xff0c;该方法将SAM基础模型的通用知识整合到变化检测任务中。该方法有效地解决了在通用知识转移中的领域偏移问题…

git的简单使用与gdb

版本控制器git 为了能够更方便管理这些不同版本的文件&#xff0c;有了版本控制器&#xff0c;可以了解一个文件的历史&#xff0c;以及它的发展过程的系统&#xff0c;通俗的说就是一个可以记录工程的每一次改动和版本迭代的一个管理系统&#xff0c;同时也方便多人协作。 三…

从0开始学PHP面向对象内容之常用设计模式(策略,观察者)

PHP设计模式——行为型模式 PHP 设计模式中的行为模式&#xff08;Behavioral Patterns&#xff09;主要关注对象之间的通信和交互。行为模式的目的是在不暴露对象之间的具体通信细节的情况下&#xff0c;定义对象的行为和职责。它们常用于解决对象如何协调工作的问题&#xff…