分类神经网络1:VGGNet模型复现

news2024/12/23 23:49:33

目录

分类网络的常见形式

VGG网络架构

VGG网络部分实现代码


分类网络的常见形式

常见的分类网络通常由特征提取部分分类部分组成。

特征提取部分实质就是各种神经网络,如VGG、ResNet、DenseNet、MobileNet等。其负责捕获数据的有用信息,一般是通过堆叠多个卷积层和池化层来实现的,这些层有助于检测图像中的边缘、纹理和特征。

分类部分通常是一个全连接层,负责将特征提取部分输出的信息映射到最终的类别或标签。这些全连接层通常包括一个或多个隐藏层,以及一个输出层,其中输出层的节点数量等于任务中的类别数量。

VGG网络架构

论文原址:https://arxiv.org/pdf/1409.1556v6.pdf

VGG 网络是由牛津大学的Visual Geometry Group 开发的,其结构特点在于使用了多个 3x3 的小卷积核,并通过这些小卷积层的重复堆叠来构建网络,从而能够捕捉到更加复杂和抽象的特征表示。VGG 网络的模型结构如下:

VGG网络的核心架构可以分为以下几个部分:

  1. 输入层:VGG网络接受224x224像素的RGB图像作为输入。
  2. 卷积层:网络的前几层由多个卷积层组成,每个卷积层都使用3x3的卷积核来提取图像的特征。这些卷积层后面通常跟着一个2x2 最大池化,用于逐步减小特征图的空间尺寸,同时增加特征深度。
  3. 池化层:在卷积层之后,网络使用最大池化层来降低特征图的空间分辨率,这有助于减少计算量并提取更加抽象的特征。
  4. 全连接层:经过多个卷积和池化层之后,网络的特征图被展平并通过几个全连接层进行处理。全连接层的作用是将学习到的特征映射到最终的分类结果。
  5. 输出层:VGG网络的最后是一个softmax层,它将网络的输出转换为概率分布,以便进行多类别的分类任务。

VGG网络的一个显著特点是其深度,其相关配置信息如下:

VGG系列不同变体内容如下:

  • VGG A:这是一个基础的配置,没有特别独特的设计。
  • VGG A-LRN:在这个版本中,加入了局部响应归一化(LRN),这是一种在AlexNet中首次使用的技术。不过,LRN在当前的深度学习实践中已经较少被采用。
  • VGG B:相较于A版本,B版本增加了两个卷积层,以增强网络的学习能力。
  • VGG C:在B的基础上,C版本进一步增加了三个卷积层,但这些层使用的是1x1的卷积核。1x1卷积核可以看作是对输入特征图进行线性变换,有助于减少参数数量并增加非线性。
  • VGG D:D版本在C版本的基础上做了调整,将1x1的卷积核替换为3x3的卷积核,这个配置后来被称为VGG16,因为它总共有16层。
  • VGG E:在D版本的基础上,E版本进一步增加了三个3x3的卷积层,形成了VGG19,总共有19层。

从图中可以看出,随着网络深度的加深,模型变得更为复杂。通常来说,增加网络的深度可以增加模型的表示能力,使其能够学习到更复杂的特征和模式,从而在某些任务上取得更好的性能。然而,随着网络深度的增加,模型的参数数量也会增加,导致模型的复杂度增加,训练和推理的计算成本也会增加,同时可能会增加过拟合的风险。

VGG网络部分实现代码

废话不多说,直接上干货

import torch
import torch.nn as nn

__all__ = ["VGG", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]

cfg = {
    'A': [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1,  kernel_size=3, padding=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        for layer in self.features:
            x = layer(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

def make_layers(cfg):
    layers = nn.ModuleList()
    in_channels = 3
    for i in cfg:
        if i == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(ConvBNReLU(in_channels=in_channels, out_channels=i))
            in_channels = i
    return layers

def vgg11_bn(num_classes):
    model = VGG(make_layers(cfg['A']), num_classes=num_classes)
    return model

def vgg13_bn(num_classes):
    model = VGG(make_layers(cfg['B']), num_classes=num_classes)
    return model

def vgg16_bn(num_classes):
    model = VGG(make_layers(cfg['C']), num_classes=num_classes)
    return model

def vgg19_bn(num_classes):
    model = VGG(make_layers(cfg['D']), num_classes=num_classes)
    return model

if __name__=='__main__':
    import torchsummary
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input = torch.ones(2, 3, 224, 224).to(device)
    net = vgg16_bn(num_classes=4)
    net = net.to(device)
    out = net(input)
    print(out)
    print(out.shape)
    torchsummary.summary(net, input_size=(3, 224, 224))
    # Total params: 134,285,380

这只是一个网络架构部分实现代码,其中 cfg 列表是 VGG 卷积和池化后的通道数,大家可以结合 VGG 的配置信息图一起对比理解。希望对大家有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1613402.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第24天:安全开发-PHP应用文件管理模块显示上传黑白名单类型过滤访问控制

第二十四天 一、PHP文件管理-显示&上传功能实现 如果被抓包抓到数据包,并修改Content-Type内容 则也可以绕过筛查 正常进行上传和下载 二、文件上传-$_FILES&过滤机制实现 无过滤机制 黑名单过滤机制 使用 explode 函数通过点号分割文件名,…

Spring Boot | Spring Boot 默认 “缓存管理“ 、Spring Boot “缓存注解“ 介绍

目录: 一、Spring Boot 默认 "缓存" 管理 :1.1 基础环境搭建① 准备数据② 创建项目③ 编写 "数据库表" 对应的 "实体类"④ 编写 "操作数据库" 的 Repository接口文件⑤ 编写 "业务操作列" Service文件⑥ 编写 "applic…

【 AIGC 研究最新方向(下)】面向平面、视觉、时尚设计的高可用 AIGC 研究方向总结

目前面向平面、视觉、时尚等设计领域的高可用 AIGC 方向有以下 4 种: 透明图层生成可控生成图像定制化SVG 生成 本篇(下篇)介绍 3、4,上篇在:https://blog.csdn.net/weixin_44212848/article/details/138035279?spm…

【FFmpeg】视频与图片互相转换 ( 视频与 JPG 静态图片互相转换 | 视频与 GIF 动态图片互相转换 )

文章目录 一、视频与 JPG 静态图片互相转换1、视频转静态图片2、视频转多张静态图片3、多张静态图片转视频 二、视频与 GIF 动态图片互相转换1、视频转成 GIF 动态图片2、 GIF 动态图片转成视频 一、视频与 JPG 静态图片互相转换 1、视频转静态图片 执行 ffmpeg -i input.mp4 …

初始化Git仓库时应该运行哪个命令?

文章目录 初始化Git仓库时,你应该运行git init这个命令。这个命令的作用是在你当前所在的目录里创建一个新的Git仓库。这样,你就可以在这个目录里开始使用Git来管理你的文件了。 下面我给你举个详细的例子来说明一下: 首先,你需要…

# 从浅入深 学习 SpringCloud 微服务架构(三)注册中心 Eureka(3)

从浅入深 学习 SpringCloud 微服务架构(三)注册中心 Eureka(3) 段子手168 1、eureka:高可用的引入 Eureka Server 可以通过运行多个实例并相互注册的方式实现高可用部署, Eureka Server 实例会彼此增量地…

Spark和Hadoop的安装

实验内容和要求 1.安装Hadoop和Spark 进入Linux系统,完成Hadoop伪分布式模式的安装。完成Hadoop的安装以后,再安装Spark(Local模式)。 2.HDFS常用操作 使用hadoop用户名登录进入Linux系统,启动…

ChatGPT研究论文提示词集合2-【形成假设、设计研究方法】

点击下方▼▼▼▼链接直达AIPaperPass ! AIPaperPass - AI论文写作指导平台 目录 1.形成假设 2.设计研究方法 3.书籍介绍 AIPaperPass智能论文写作平台 近期小编按照学术论文的流程,精心准备一套学术研究各个流程的提示词集合。总共14个步骤&#…

Llama3新一代 Llama模型

最近,Meta 发布了 Llama3 模型,从发布的数据来看,性能已经超越了 Gemini 1.5 和 Claud 3。 Llama 官网说,他们未来是要支持多语言和多模态的,希望那天赶紧到来。 未来 Llama3还将推出一个 400B大模型,目前…

Linux--链表 第二十五天

1. 链表 t1.next -> data t1.next->next->data .(点号)的优先级比->的大 所以 t1.next->data 就可以了 不用(t1.next)->data 2. 链表的静态增加和动态遍历 打印链表算法, void printLink(struct Test *head) { struct Te…

安装和部署maven

准备工作 maven下载地址:https://maven.apache.org/download.cgi 使用wget将maven包下载到linux环境上,/toos/ 目录下(也可用迅雷) wget https://dlcdn.apache.org/maven/maven-3/3.9.6/binaries/apache-maven-3.9.6-bin.tar.g…

PaddleOCRV4训练自己的模型(4)------模型推理及导出

一、Det模型推理: (1)上一篇文章只讲了推理的实现方法,没有展示结果,这里顺带展示一下结果。 因为训练定位模型的时候是整图训练,所以推理的时候也是整图推理。 (2)在推理的时候可以…

LinkedList和链表

1.ArrayList的缺陷 ArraryList由于底层是一段连续的空间,所以在ArrayList任意位置插入或者删除元素时,就 需要将后续元素往前或者往后搬移,时间复杂度为O(n),效率比较低,因此ArrayList不适合做任意位置插入和删除比较…

断言(Assertion)在IT技术中的确切含义— 基于四类典型场景的分析

当“断言”(Assertion)一词成为IT术语时,语义的混沌性和二义性也随之而生。那么,何为断言?断言何为?实际上,只需分析四种典型场景,确切答案和准确描述就将自然显现。 在SAML&#xf…

【讲解下Spring Boot单元测试】

🌈个人主页: 程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

数据可视化(七):Pandas香港酒店数据高级分析,涉及相关系数,协方差,数据离散化,透视表等精美可视化展示

Tips:"分享是快乐的源泉💧,在我的博客里,不仅有知识的海洋🌊,还有满满的正能量加持💪,快来和我一起分享这份快乐吧😊! 喜欢我的博客的话,记得…

websocket 请求头报错 Provisional headers are shown 的解决方法

今日简单总结 websocket 使用过程中遇到的问题&#xff0c;主要从以下三个方面来分享&#xff1a; 1、前端部分 websocket 代码 2、使用 koa.js 实现后端 websocket 服务搭建 3、和后端 java Netty 库对接时遇到连接失败问题 一、前端部分 websocket 代码 <template>…

B2024 输出浮点数 洛谷题单

首选需要进行了解的就是%a.bf所代表的含义就行了&#xff0c;直接莽了&#xff0c;没啥解释的笑脸&#x1f644; 在 Python 中&#xff0c;%a.bf 中的参数 a 和 b 是用来格式化浮点数的输出的&#xff0c;具体含义如下&#xff1a; a 表示总输出宽度&#xff0c;包括小数点、…

Kubernetes Kubelet 的 Cgroups 资源限制机制分析

前言 容器技术的两大技术基石&#xff0c;想必大家都有所了解&#xff0c;即 namespace 和 cgroups。但你知道 cgroups 是如何在 kubernetes 中发挥作用的吗&#xff1f;kubelet 都设置了哪些 cgroups 参数来实现对容器的资源限制的呢&#xff1f;本文就来扒一扒 Kubernetes k…

Docker - WEB应用实例

原文地址&#xff0c;使用效果更佳&#xff01; Docker - WEB应用实例 | CoderMast编程桅杆Docker - WEB应用实例 在之前的章节中&#xff0c;仅对普通容器进行了演示&#xff0c;但在实际中常常使用到 Docker 容器中的 WEB 应用程序。 运行一个WEB应用 拉取镜像 创建一个容器…