VGG网络分析与demo实例

news2024/9/27 21:20:19

参考自 

  • up主的b站链接:霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频
  • 这位大佬的博客 Fun'_机器学习,pytorch图像分类,工具箱-CSDN博客

VGG 在2014年由牛津大学著名研究组 VGGVisual Geometry Group)提出,斩获该年 ImageNet 竞赛中 Localization Task(定位任务)第一名和 Classification Task(分类任务)第二名。

VGG 的创新点:

通过堆叠多个小卷积核来替代大尺度卷积核,可以减少训练参数,同时能保证相同的感受野。
论文中提到,可以通过堆叠两个3×3的卷积核替代5x5的卷积核,堆叠三个3×3的卷积核替代7x7的卷积核

1. CNN感受野

在卷积神经网络中,决定某一层输出结果中一个元素所对应的输入层的区域大小,被称作感受野(receptive field)。
通俗的解释是,输出feature map上的一个单元 对应 输入层上的区域大小。

以下图为例,输出层 layer3 中一个单元 对应 输入层 layer2 上区域大小为2×2(池化操作),对应输入层 layer1 上大小为5×5
(可以这么理解,layer2中 2×2区域中的每一块对应一个3×3的卷积核,又因为 stride=2,所以layer1的感受野为5×5)
 

现在,我们来验证下VGG论文中的两点结

VGG网络有多个版本,一般常用的是VGG-16模型,其网络结构如下如所示:

pytorch搭建VGG网络

import torch.nn as nn
import torch

class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features			# 卷积层提取特征
        self.classifier = nn.Sequential(	# 全连接层进行分类
            nn.Dropout(p=0.5),
            nn.Linear(512*7*7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

这里有一点需要注意的是:

VGG网络有 VGG-13、VGG-16等多种网络结构

# vgg网络模型配置列表,数字表示卷积核个数,'M'表示最大池化层
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],											# 模型A
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],									# 模型B
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],					# 模型D
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 	# 模型E
}

# 卷积层提取特征
def make_features(cfg: list): # 传入的是具体某个模型的参数列表
    layers = []
    in_channels = 3		# 输入的原始图像(rgb三通道)
    for v in cfg:
        # 最大池化层
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        # 卷积层
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)  # 单星号(*)将参数以元组(tuple)的形式导入


def vgg(model_name="vgg16", **kwargs):  # 双星号(**)将参数以字典的形式导入
    try:
        cfg = cfgs[model_name]
    except:
        print("Warning: model number {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg), **kwargs)
    return model

train.py

model_name = "vgg16"
net = vgg(model_name=model_name, num_classes=5, init_weights=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1334245.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【交叉编译环境】安装arm-linux交叉编译环境到虚拟机教程(简洁版本)

就是看到了好些教程有些繁琐,我就写了一个 我这个解压安装的交叉编译环境是Linaro GCC的一个版本,可以用于在x86_64的主机上编译arm-linux-gnueabihf的目标代码 步骤来了 在你的Ubuntu系统中创建一个目录,例如/usr/local/arm,然后…

day48算法训练|动态规划part09

198.打家劫舍 1. dp数组(dp table)以及下标的含义 dp[i]:考虑下标i(包括i)以内的房屋,最多可以偷窃的金额为dp[i]。 2.递推公式 决定dp[i]的因素就是第i房间偷还是不偷。 如果偷第i房间,那么…

通信原理 | 通信中有哪些量的单位是dB?

在通信领域中,分贝(dB)被用来表示各种不同的量,包括 信号强度功率电压以下是通信中常用的几种用分贝表示的量: 信号强度(Signal Strength) 通信设备发送或接收信号时,信号的强度可以用分贝来表示,通常以分贝毫瓦(dBm)为单位。 0 dBm表示1毫瓦的信号强度,负值表示…

Wafer晶圆封装工艺介绍

芯片封装的目的(The purpose of chip packaging): 芯片上的IC管芯被切割以进行管芯间连接,通过引线键合连接外部引脚,然后进行成型,以保护电子封装器件免受环境污染(水分、温度、污染物等)&…

基于ssm出租车管理系统的设计与实现论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本出租车管理系统就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息&…

【视觉实践】使用Mediapipe进行图像分割实践

目录 1 Mediapipe 2 Solutions 3 安装依赖库 4 实践 1 Mediapipe Mediapipe是google的一个开源项目,可以提供开源的、跨平台的常用机器学习(machine learning,ML)方案。MediaPipe是一个用于构建机器学习管道的框架,用于处理视频、音频等时间序列数据。与资源消耗型的机…

Linux开发工具——gcc篇

gcc的使用 文章目录 gcc的使用 历史遗留问题(普通用户sudo) gcc编译过程 预处理(进行宏替换) 编译(生成汇编) 汇编(生成机器可识别代码) 链接(生成可执行文件或库文件&a…

使用 OpenTelemetry 和 Loki 实现高效的应用日志采集和分析

在之前的文章陆续介绍了 如何在 Kubernetes 中使用 Otel 的自动插桩 以及 Otel 与 服务网格协同实现分布式跟踪,这两篇的文章都将目标聚焦在分布式跟踪中,而作为可观测性三大支柱之一的日志也是我们经常使用的系统观测手段,今天这篇文章就来体…

MySQL学生向笔记以及使用过程问题记录(内含8.0.34安装教程

MySQL 只会写代码 基本码农 要学好数据库,操作系统,数据结构与算法 不错的程序员 离散数学、数字电路、体系结构、编译原理。实战经验, 高级程序员 去IOE:去掉IBM的小型机、Oracle数据库、EMC存储设备,代之以自己在开源…

通达OA header身份认证绕过漏洞复现

通达OA是中国通达公司的一套协同办公自动化软件,通达OA2013,通达OA2016,通达OA2017 存在身份认证绕过漏洞,攻击者可以利用漏洞生成cookie,实现未授权访问。 1.漏洞级别 高危 2.漏洞搜索 fofa title"office An…

【小白专用】Apache下禁止显示网站目录结构的方法 更新23.12.25

给我一个网站地址,我点开后显示的是目录格式,把网站的目录结构全部显示出来了 这个显示结果不正确,不应该让用户看到我们的目录结构 配置文件的问题,apache配置文件里有一项可以禁止显示网站目录的配置项,禁止掉就好了 在apache…

【Mathematical Model】Ransac线性回归Python代码

Ransac算法,也称为随机抽样一致性算法,是一种迭代方法,用于从一组包含噪声或异常值的数据中估计数学模型。Ransac算法特别适用于线性回归问题,因为它能够处理包含异常值的数据集,并能够估计出最佳的线性模型。 1 简介 …

Java 的 8 种异步实现方式

一、前言 异步执行对于开发者来说并不陌生,在实际的开发过程中,很多场景多会使用到异步,相比同步执行,异步可以大大缩短请求链路耗时时间,比如:「发送短信、邮件、异步更新等」 ,这些都是典型的…

将ipynb文件转为py的简单方法(图文并茂)

打开可以使用jupyter命令的命令窗口(如果没有jupyter则需要先安装jupyter),cd 命令进入到 ipynb 文件所在的文件夹,执行 jupyter nbconvert --to script xxx.ipynb 即可完成 ipynb 文件到 py 文件的转化,执行 jupyter …

每秒生成110张图像!StreamDiffusion开源 实时图像生成更强了

StreamDiffusion是一个开源项目,最近在推特上引起了热烈讨论。这个项目基于LCM和SDXL Turbo技术,每秒能够生成110张图像,为想要开发实时图像生成产品的人提供了一个值得关注的资源。这个项目主要是为了实时图像生成服务而设计的,并…

基于java的汽车维修保养智能预约系统论文

摘 要 信息数据从传统到当代,是一直在变革当中,突如其来的互联网让传统的信息管理看到了革命性的曙光,因为传统信息管理从时效性,还是安全性,还是可操作性等各个方面来讲,遇到了互联网时代才发现能补上自古…

探索 HTTP 请求的世界:get 和 post 的奥秘(上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

kubevela 安装(windows、minikube)

minikube 启动指定版本的 k8s minikube start --kubernetes-versionv1.23.4 根据 kubevela 的 官网 提示,执行 powershell -Command "iwr -useb https://kubevela.net/script/install.ps1 | iex" 出现如下问题,根据提示执行 Set-ExecutionPo…

【Linux系统基础】(2)在Linux上部署MySQL、RabbitMQ、ElasticSearch等各类软件

实战章节:在Linux上部署各类软件 前言 为什么学习各类软件在Linux上的部署 在前面,我们学习了许多的Linux命令和高级技巧,这些知识点比较零散,同学们跟随着课程的内容进行练习虽然可以基础掌握这些命令和技巧的使用,…

移动开发git版本控制经验之谈

移动开发git版本控制经验之谈 团队或应用规模是否会影响发布流程?这取决于具体情况。让我们来想象一下一个小型团队的创业公司。在这种情况下,通常是团队开发一个功能,然后直接发布。现在我们再来想象一个大型项目,比如一个银行应…