pytorch搭建VGG网络

news2024/10/1 15:12:49

pytorch搭建VGG网络

  • CNN 感受野
  • VGG-16
  • 搭建VGG网络
    • model.py
    • train.py
    • predict.py

VGG 网络的创新点:通过堆叠多个小卷积核来替代大尺度卷积核,可以减少训练参数,同时能保证相同的感受野。
例如,可以通过堆叠两个 3×3 的卷积核替代 5x5 的卷积核,堆叠三个 3×3 的卷积核替代 7x7 的卷积核。

CNN 感受野

输出层 feature map 上的一个单元对应输入层上的区域大小,被称作感受野(receptive field)
如下图所示,输出层 layer3 中一个单元对应输入层 layer2 上区域大小为 2×2 (池化操作),对应输入层 layer1 上大小为5×5
在这里插入图片描述
感受野的计算公式为:

F(i)=(F(i+1)−1)×Stride +Ksize
其中:
  • F(i) 为第 i 层感受野
  • Stride 为第 i 层的步距
  • Ksize 为 卷积核 或 池化核 尺寸

以图中例子计算:可得F(3) = 1、F(2)=(1 − 1) × 2 + 2 = 2、F (1) = (2 − 1) × 2 + 3 = 5,可以理解为 layer2 中 2×2 区域中的每一块对应一个 3×3 的卷积核,又因为 stride=2,所以 layer1 的感受野为 5×5

思考:堆叠3×3卷积核后训练参数是否真的减少了?

CNN参数个数 = 卷积核尺寸 × 卷积核深度 × 卷积核组数 = 卷积核尺寸 × 输入特征矩阵深度 × 输出特征矩阵深度

现假设 输入特征矩阵深度 = 输出特征矩阵深度 = C

使用7×7卷积核所需参数个数:7 × 7 × C × C = 49C²
堆叠三个3×3的卷积核所需参数个数:3 × 3 × C × C + 3 × 3 × C × C + 3 × 3 × C × C = 27C²

VGG-16

网络结构如下图所示:
在这里插入图片描述

搭建VGG网络

model.py

跟 AlexNet 网络模型一样,VGG网络也是分为 卷积层提取特征 和 全连接层进行分类

import torch.nn as nn
import torch

class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features			# 卷积层提取特征
        self.classifier = nn.Sequential(	# 全连接层进行分类
            nn.Dropout(p=0.5),
            nn.Linear(512*7*7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

VGG网络有 VGG-13、VGG-16等多种网络结构,其全连接层完全一样,卷积层只有卷积核个数稍有不同,可以将这几种结构统一在一个模型中。

# vgg网络模型配置列表,数字表示卷积核个数,'M'表示最大池化层
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],											# 模型A
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],									# 模型B
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],					# 模型D
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 	# 模型E
}

# 卷积层提取特征
def make_features(cfg: list): # 传入的是具体某个模型的参数列表
    layers = []
    in_channels = 3		# 输入的原始图像(rgb三通道)
    for v in cfg:
        # 最大池化层
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        # 卷积层
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)  # 单星号(*)将参数以元组(tuple)的形式导入


def vgg(model_name="vgg16", **kwargs):  # 双星号(**)将参数以字典的形式导入
    try:
        cfg = cfgs[model_name]
    except:
        print("Warning: model number {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg), **kwargs)
    return model

train.py

训练脚本跟另一篇 AlexNet 基本一致,需要注意的是实例化网络的过程:

model_name = "vgg16"
net = vgg(model_name=model_name, num_classes=5, init_weights=True)

predict.py

预测脚本跟另一篇 AlexNet 一致

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/686894.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringSecurity6.0+Redis+JWT基于token认证功能开发(可用于实际生产项目,保证API安全)

基于token认证功能开发 引子:最近做项目时遇到了一个特殊的需求,需要写共享接口把本系统的一些业务数据共享给各地市的自建系统,为了体现公司的专业性以及考虑到程序的扩展性(通过各地市的行政区划代码做限制)&#xf…

Java框架之spring AOP 和 IOC

写在前面 本文一起看下spring aop 和 IOC相关的内容。 1:spring bean核心原理 1.1:spring bean的生命周期 spring bean生命周期,参考下图: 我们来一步步的看下。 1 其中1构造函数就是执行类的构造函数完成对象的创建&#x…

第八十四天学习记录:Linux基础:初识Linux

流行的Linux发行版: 任何人都可以封装Linux,目前市面上有非常多的Linux发行版,常用的知名的如下: VMware WorkStations安装 安装完成后,要通过下图方式查看网络适配器是否正常配置: 配置成功&#xff1a…

软件需求分析文档怎么写?

什么是软件需求规范文档 (SRS)? 软件需求规范 (SRS) 文档列出了未来项目的需求、期望、设计和标准。其中包括规定项目目标的高级业务需求、最终用户要求和需求以及产品在技术方面的功能。简而言之,SRS 提供…

Vue-Element-Admin项目学习笔记(8)配置表单校验规则

前情回顾: vue-element-admin项目学习笔记(1)安装、配置、启动项目 vue-element-admin项目学习笔记(2)main.js 文件分析 vue-element-admin项目学习笔记(3)路由分析一:静态路由 vue-element-adm…

软考:中级软件设计师:校验码,汉明码纠错,信息位L和校验位r的关系

软考:中级软件设计师:校验码,汉明码纠错 提示:系列被面试官问的问题,我自己当时不会,所以下来自己复盘一下,认真学习和总结,以应对未来更多的可能性 关于互联网大厂的笔试面试,都是…

Linux通过crontab定时执行脚本任务

Linux通过crontab定时执行脚本任务 前言1. 创建写入脚本2. 设置执行权限3. 添加定时任务定时任务语法格式每分钟写入一条信息到指定文件 4. 查看日志文件5. 定时执行脚本的作用和用途 前言 在Linux中可以使用crontab来定时执行脚本。crontab是一个用于管理定时任务的工具&…

我这样回答多线程并发,面试官直接惊叹!

目录 前言: 1.单线程执行 2、多线程执行 3.守护线程 4.阻塞线程 前言: 多线程并发是一种处理任务的方式,它可以在同一时间内执行多个任务。多线程并发通常应用于需要同时处理多个任务或同时运行多个程序的情况下。 1.单线程执行 Pyth…

便携式水污染检测设备可以分析多少项污水指标

便携式水污染检测设备可以分析多少项污水指标(以下只是一部分) 水质检测仪可检测范围 1、饮用水检测:生活用水(自来水)、(瓶、桶装)矿泉水、天然矿泉水等; 2、工业用水检测&#xf…

人机融合智能的现状与展望

本篇文章是博主在人工智能等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在学习摘录和笔记专…

【开源库剖析】Shadow v2.3.0 源码解析

作者:Stan_Z 一、框架介绍 Shadow是19年腾讯开源的自研Android插件化框架,经过线上亿级用户量检验。 Shadow不仅开源分享了插件技术的关键代码,还完整的分享了上线部署所需要的所有设计。 优点: 1)复用独立安装app源…

Python可视化库之Matplotlib详解及使用方法

Matplotlib是Python中最常用的可视化工具之一,可以非常方便地创建海量类型的2D图表和一些基本的3D图表。本文主要推荐一个学习使用Matplotlib的步骤。 基本前提 如果你除了本文之外没有任何基础,建议用以下几个步骤学习如何使用matplotlib: 学习基本的matplotlib术语,尤其是…

第二十二章Java一维数组的定义、赋值和初始化

当数组中每个元素都只带有一个下标时,这种数组就是“一维数组”。一维数组(one-dimensional array)实质上是一组相同类型数据的线性集合,是数组中最简单的一种数组。 数组是引用数据类型,引用数据类型在使用之前一定要…

Reactor的概念

一、Reactor的概念 ​ Reactor模式是一种事件驱动模式,由一个或多个并发输入源(input),一个消息分发处理器(Initiation Dispatcher),以及每个消息对应的处理器(Request Handler)构成…

Linux安装nodejs

一、下载包 https://registry.npmmirror.com/binary.html?pathnode/ 比如:10.9.0 https://registry.npmmirror.com/binary.html?pathnode/v10.9.0/ 按需下载 https://registry.npmmirror.com/-/binary/node/v10.9.0/node-v10.9.0-linux-x64.tar.gz 二、上传到…

使用Nginx+Lua实现自定义WAF(Web application firewall)

转载https://github.com/unixhot/waf WAF 使用NginxLua实现自定义WAF(Web application firewall) 功能列表: 支持IP白名单和黑名单功能,直接将黑名单的IP访问拒绝。 支持URL白名单,将不需要过滤的URL进行定义。 支持…

解析vcruntime140.dll文件,缺失了要怎么去修复?

在计算机的世界中,vcruntime140.dll是一个重要的动态链接库文件。然而,有时候这个文件可能会引发一系列问题,影响应用程序的正常运行。如果你缺少了vcruntime140.dll,那么你的程序就会打不开,今天我们一起来聊聊vcrunt…

408数据结构第四章

串 定义存储结构模式匹配 小题形式考,比较简单,拿两个题来练手就会了 定义 字符串简称串 由零个或多个字符组成的有限序列 S是串名n称为串的长度,n0称为空串 串中多个连续的字符组成的子序列称为该串的子串 串的逻辑结构和线性表极为相似&am…

ByteHouse+Apache Airflow:高效简化数据管理流程

Apache Airflow 与 ByteHouse 相结合,为管理和执行数据流程提供了强大而高效的解决方案。本文突出了使用 Apache Airflow 与 ByteHouse 的主要优势和特点,展示如何简化数据工作流程并推动业务成功。 主要优势 可扩展可靠的数据流程:Apache Ai…

使用MASA Stack+.Net 从零开始搭建IoT平台 第五章 使用时序库存储上行数据

目录 前言分析实施步骤时序库的安装解决playload没有时间戳问题代码编写 总结 前言 我们可以将设备上行数据存储到关系型数据库中,我们需要两张带有时间戳的表(最新数据表 和 历史数据表),历史数据表存储所有设备上报的数据&…