经典卷积神经网络-VGGNet

news2024/10/5 22:21:45

经典卷积神经网络-VGGNet

一、背景介绍

VGG是Oxford的Visual Geometry Group的组提出的。该网络是在ILSVRC 2014上的相关工作,主要工作是证明了增加网络的深度能够在一定程度上影响网络最终的性能。VGG有两种结构,分别是VGG16和VGG19,两者并没有本质上的区别,只是网络深度不一样。

在这里插入图片描述

二、VGG-16网络结构

在这里插入图片描述

其中VGG系列具体的网络结构如下表所示:

在这里插入图片描述

如图所示,这是论文中所有VGG网络的详细信息,D列对应的为VGG-16网络。16指的是在这个网络中包含16个卷积层和全连接层(不算池化层和Softmax)。

  • VGG-16的卷积层没有那么多的超参数,在整个网络模型中,所有卷积核的大小都是 3 × 3的,并且padding为same,stride为1。所有池化层的池化核大小都是 2 × 2 的,并且步长为2。在几次卷积之后紧跟着池化,整个网络结构很规整。

  • 总共包含约1.38亿个参数,但其结构并不复杂,结构很规整,都是几个卷积层后面跟着可以压缩图像大小的池化层,同时,卷积层的卷积核数量的变化也存在一定的规律,都是池化之后图像高度宽度减半,但在下一个卷积层中通道数翻倍,这正是这种简单网络结构的一个规则。

  • VGG16相比AlexNet的一个改进是采用连续的几个3x3的卷积核代替AlexNet中的较大卷积核(11x11,7x7,5x5)。对于给定的感受野(与输出有关的输入图片的局部大小),采用堆积的小卷积核是优于采用大的卷积核,因为多层非线性层可以增加网络深度来保证学习更复杂的模式,而且代价还比较小(参数更少)。在VGG中,使用了3个3x3卷积核来代替7x7卷积核,使用了2个3x3卷积核来代替5×5卷积核,这样做的主要目的是在保证具有相同感受野的条件下,提升了网络的深度,在一定程度上提升了神经网络的效果。

  • 它的主要缺点就是需要训练的特征数量非常大。有些文章介绍了VGG-19,但通过研究发现VGG-19和VGG-16的性能表现几乎不分高下,所以很多人还是使用VGG-16,这也说明了单纯的增加网络深度,其性能不会有太大的提升。

  • 论文中还介绍了权重初始化方法,即预训练低层模型参数为深层模型参数初始化赋值。原文:网络权重初始化是非常重要的,坏的初始化会使得深度网络的梯度的不稳定导致无法学习。为了解决这个问题,我们首先在网络A中使用随机初始化进行训练。然后到训练更深的结构时,我们将第一层卷积层和最后三层全连接层的参数用网络A中的参数初始化(中间层的参数随机初始化)。

  • 论文中揭示了,随着网络深度的增加,图像的高度和宽度都以一定规律不断缩小,每次池化之后刚好缩小一半,而通道数量在不断增加,而且刚好也是在每组卷积操作后增加一倍。也就是说,图像缩小和通道增加的比例是有规律的,从这个角度看,这篇论文很吸引人。

三、VGG-16的Pytorch实现

我们可以根据:https://dgschwend.github.io/netscope/#/preset/vgg-16,来搭建VGG-16。

在这里插入图片描述

后面要将VGG-16Net应用到CIFAR10数据集上,所以对网络做了一些修改,具体代码如下:

from torch import nn


class Vgg16_Net(nn.Module):
    def __init__(self):
        super(Vgg16_Net, self).__init__()

        self.layer1 = nn.Sequential(
            # input_size = (3, 32, 32)
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # input_size = (64, 32, 32)
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # input_size = (64, 32, 32)
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer2 = nn.Sequential(
            # input_size = (64, 16, 16)
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # input_size = (128, 16, 16)
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # input_size = (128, 16, 16)
            nn.MaxPool2d(2, 2)
        )

        self.layer3 = nn.Sequential(
            # input_size = (128, 8, 8)
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # input_size = (256, 8, 8)
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # input_size = (256, 8, 8)
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # input_size = (256, 8, 8)
            nn.MaxPool2d(2, 2)
        )

        self.layer4 = nn.Sequential(
            # input_size = (256, 4, 4)
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # input_size = (512, 4, 4)
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # input_size = (512, 4, 4)
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # input_size = (512, 4, 4)
            nn.MaxPool2d(2, 2)
        )

        self.layer5 = nn.Sequential(
            # input_size = (512, 2, 2)
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # input_size = (512, 2, 2)
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # input_size = (512, 2, 2)
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # input_size = (512, 2, 2)
            nn.MaxPool2d(2, 2)
            # output_size = (512, 1, 1)
        )

        self.conv = nn.Sequential(
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.layer5
        )

        self.fc = nn.Sequential(
            # input_size = 512
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Linear(256, 10)
        )

    def forward(self, x):
        x = self.conv(x)
        # -1表示自动计算行数
        # -1也可以改成x.size(0) 表示batch_size的大小
        x = x.view(-1, 512 * 1 * 1)
        x = self.fc(x)
        return x

四、案例:CIFAR-10分类问题

import time
import torch
import torchvision
from model import *
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from matplotlib import pyplot as plt

# 加载数据集 拿到dataloader
def load_dataset(batch_size):
    train_data = torchvision.datasets.CIFAR10("../dataset/CIFAR10", train=True, download=True, transform=transforms.ToTensor())
    test_data = torchvision.datasets.CIFAR10("../dataset/CIFAR10", train=False, download=True, transform=transforms.ToTensor())
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_dataloader, test_dataloader


# 模型训练
def train(model, train_dataloader, criterion, optimizer, epochs, device, num_print, lr_scheduler=None, test_dataloader=None):
    # 记录train和test的acc方便绘制学习曲线
    record_train = list()
    record_test = list()

    # 开始训练
    model.train()
    for epoch in range(epochs):
        print("========== epoch: [{}/{}] ==========".format(epoch + 1, epochs))
        # total记录样本数 correct记录正确预测样本数
        total, correct, train_loss = 0, 0, 0
        start = time.time()

        # 结合enumerate函数和迭代器的unpacking 可以在获取数据的同时获取该批次数据对应的索引
        for i, (image, target) in enumerate(train_dataloader):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            total += target.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()
            train_acc = 100.0 * correct / total

            if (i + 1) % num_print == 0:
                print("step: [{}/{}], train_loss: {:.3f} | train_acc: {:6.3f}% | lr: {:.6f}".format(i + 1,
                                 len(train_dataloader), train_loss / (i + 1), train_acc, get_cur_lr(optimizer)))

        # 更新当前优化器的学习率
        if lr_scheduler is not None:
            lr_scheduler.step()

        print("--- cost time: {:.4f}s ---".format(time.time() - start))

        if test_dataloader is not None:
            record_test.append(test(model, test_dataloader, criterion, device))
        record_train.append(train_acc)

        # 保存当前模型
        torch.save(model.state_dict(), "train_model/VGG-16Net_{}.pth".format(epoch + 1))

    return record_train, record_test


# 模型测试
def test(model, test_dataloader, criterion, device):
    # total记录样本数 correct记录正确预测样本数
    total, correct = 0, 0

    # 开始测试
    model.eval()
    with torch.no_grad():
        print("*************** test ***************")
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)

            output = model(X)
            loss = criterion(output, y)

            total += y.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()

    test_acc = 100.0 * correct / total

    print("test_loss: {:.3f} | test_acc: {:6.3f}%".format(loss.item(), test_acc))
    print("************************************\n")

    # 记得重新调用model.train()
    model.train()

    return test_acc

# 获取当前的学习率 这里直接返回了第一个参数分组的学习率
def get_cur_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

# 绘制学习曲线
def learning_curve(record_train, record_test=None):
    # 设置 Matplotlib 图形样式
    # ggplot2 是一个用于数据可视化的流行 R 语言包,以其优雅和灵活的语法而闻名
    plt.style.use("ggplot")

    plt.plot(range(1, len(record_train) + 1), record_train, label="train acc")
    if record_test is not None:
        plt.plot(range(1, len(record_test) + 1), record_test, label="test acc")

    plt.legend(loc=4)
    plt.title("learning curve")
    plt.xticks(range(0, len(record_train) + 1, 5))
    plt.yticks(range(0, 101, 5))
    plt.xlabel("epoch")
    plt.ylabel("accuracy")

    plt.show()

# 定义超参数
BATCH_SIZE = 128
NUM_EPOCHS = 20
NUM_CLASSES = 10
LEARNING_RATE = 0.02
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
NUM_PRINT = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def main():
    model = Vgg16_Net()
    model = model.to(DEVICE)

    # 加载数据
    train_dataloader, test_dataloader = load_dataset(BATCH_SIZE)

    # 定义损失函数
    criterion = nn.CrossEntropyLoss()

    # 定义优化器
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
        nesterov=True
    )
    # 定义学习率调度器
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # 进行训练 返回训练集正确率和测试集正确率
    record_train, record_test = train(model, train_dataloader, criterion, optimizer, NUM_EPOCHS, DEVICE, NUM_PRINT, lr_scheduler, test_dataloader)

    # 绘制学习曲线
    learning_curve(record_train, record_test)

if __name__ == '__main__':
    main()

查看训练结果可以发现,测试集正确率基本保持在87.3%左右,训练集正确率接近100%:

在这里插入图片描述

学习曲线如下:

在这里插入图片描述

参考链接:

  • https://cloud.tencent.com/developer/article/1638597

  • https://blog.csdn.net/m0_50127633/article/details/117047057?spm=1001.2014.3001.5502

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1350693.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JRT代码结构调整和示例

之前一直没建表专门使用ORM的api,做模板设计器需要建表,就一边开发设计器一般测试和调整ORM的api,只有做业务才能知道哪些api使用别扭,写了设计器之后改进了ORM的api以方便业务操作数据库。新写法差不多是ORM操作数据库的稳定api了…

【jmeter】将上一个请求的结果作为下一个请求的参数

1、简介 ApacheJMeter是Apache组织开发的基于Java的压力测试工具。用于对软件做压力测试,它最初被设计用于Web应用测试但后来扩展到其他测试领域。它可以用于测试静态和动态资源例如静态文件、Java小服务程序、CGI脚本、Java对象、数据库,FTP服务器&…

Think-on-Graph—基于知识图谱的LLM推理

文章目录 背景动机LLM模型存在的问题LLM ⊕ \oplus ⊕KG范式的局限性 LLM ⊗ \otimes ⊗KG范式(Think on Graph,ToG)LLM ⊗ \otimes ⊗KG范式的过程ToG的三个阶段初始化实体提取关系及实体探索推理 例子及效果相关结论搜索深度和波束宽度对To…

Centos安装Kafka(KRaft模式)

1. KRaft引入 Kafka是一种高吞吐量的分布式发布订阅消息系统,它可以处理消费者在网站中的所有动作流数据。其核心组件包含Producer、Broker、Consumer,以及依赖的Zookeeper集群。其中Zookeeper集群是Kafka用来负责集群元数据的管理、控制器的选举等。 由…

大模型通向AGI,腾讯云携手业界专家探索创新应用新风向

引言 一年过去,ChatGPT 引发的 AGI 热潮丝毫未减。只是相对于最初推出时掀起的全民大模型热,如今关于该如何落地的讨论更多了起来。 随着算力、数据库、大数据等底层技术的发展,大模型的建设与在各个领域的应用正在加速推进,那么…

SaleSmartly获得了Meta Business Partners认证徽章

近日,SaleSmartly通过了社交网络服务巨头Meta在消息领域的Business Partners认证,这项权威且重要的认证进一步证实了SaleSmartly在消息管理领域的卓越实力和卓越成果。 Meta是一家美国互联网公司,旗下拥有Facebook、Instagram、WhatsApp等社交…

YOLOv8改进 | 注意力篇 | ACmix自注意力与卷积混合模型(提高FPS+检测效率)

一、本文介绍 本文给大家带来的改进机制是ACmix自注意力机制的改进版本,它的核心思想是,传统卷积操作和自注意力模块的大部分计算都可以通过1x1的卷积来实现。ACmix首先使用1x1卷积对输入特征图进行投影,生成一组中间特征,然后根…

项目引入Jar包的几种方式

目录 背景 方式一 前提 创建一个jar包 使用 方式二 背景 通常情况下,使用SpringBoot框架开发项目的过程中,需要引入一系列依赖,首选的就是在项目的 pom.xml 文件里面通过Maven坐标进行引入(可以通过Maven的坐标引入jar包的前…

设计模式之工厂设计模式【创造者模式】

学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。各位小伙伴,如果您: 想系统/深入学习某技术知识点… 一个人摸索学习很难坚持,想组团高效学习… 想写博客但无从下手,急需…

今天用vite新建的vue3的项目 启动遇到报错

UnhandledPromiseRejectionWarning: SyntaxError: Unexpected token ??at Loader.moduleStrategy (internal/modules/esm/translators.js:145:18) (Use node --trace-warnings ... to show where the warning was created) (node:30304) UnhandledPromiseRejectionWarning: U…

数据库索引、三范式、事务

索引 索引(Index)是帮助 MySQL 高效获取数据的数据结构。常见的查询算法,顺序查找,二分查找,二叉排序树查找,哈希散列法,分块查找,平衡多路搜索树 B 树(B-tree)。 常见索引原则有 选择唯一性索引:唯一性索引的值是唯…

听GPT 讲Rust源代码--library/panic_unwind

File: rust/library/panic_unwind/src/seh.rs 在Rust源代码中,rust/library/panic_unwind/src/seh.rs这个文件的作用是实现Windows操作系统上的SEH(Structured Exception Handling)异常处理机制。 SEH是Windows上的一种异常处理机制&#xff…

c++ / day04

1. 整理思维导图 2. 全局变量,int monster 10000;定义英雄类hero,受保护的属性string name,int hp,int attck;公有的无参构造,有参构造,虚成员函数 void Atk(){blood-0;},法师类继承自英雄类&a…

七功能遥控编解码芯片

一、基本概述 TT6/TR6 是一对为遥控玩具车设计的 CMOS LSI 芯片。TT6 为发射编码芯片,TR6 为接收解码芯片。TT6/TR6 提供七个功能按键控制前进、后退、左转、右转、加速、独立功能 F1,独立功能 F2 的动作。除此以外,还有这五种常规小车功能(…

valgrind跨平台调试及其问题分析

背景 同事在项目中遇到了内存泄漏问题,长时间没有解决,领导临时让我支援一下。心想,应该不难,毕竟我之间做过valgrind的使用总结。并输出内存泄漏问题分析思路(案例篇)和快速定位内存泄漏的套路两篇文章&a…

关于Github部分下载的方法

一、问题 在Github中,我需要下载部分文件,而github只有下载最原始文件夹和单独文件的功能。 比如我想下载头四个文件,难以操作。 二、方法 推荐使用谷歌浏览器,进入扩展程序界面: 在应用商店获取GitZip for github…

理解SQL中not in 与null值的真实含义

A not in B的原理是拿A表值与B表值做是否不等的比较, 也就是a ! b. 在sql中, null是缺失未知值而不是空值。 当你判断任意值a ! null时, 官方说, “You cannot use arithmetic comparison operators such as , <, or <> to test for NULL”, 任何与null值的对比都将返…

Java基础综合练习(飞机票,打印素数,验证码,复制数组,评委打分,数字加密,数字解密,抽奖,双色球)

练习一&#xff1a;飞机票 需求: ​ 机票价格按照淡季旺季、头等舱和经济舱收费、输入机票原价、月份和头等舱或经济舱。 ​ 按照如下规则计算机票价格&#xff1a;旺季&#xff08;5-10月&#xff09;头等舱9折&#xff0c;经济舱8.5折&#xff0c;淡季&#xff08;11月到来…

菜鸟之MATLAB学习——QPSK OQPSK信号生成及频谱分析

本人MATLAB学习小白&#xff0c;仅做笔记记录和分享~~ % qpsk && oqpsk clc; close all;Ts1; fc10;N_sample16; N_sum100; dt1/fc/N_sample; t0:dt:N_sum*Ts-dt; Tdt*length(t);d1sign(randn(1,N_sum)); d2sign(randn(1,N_sum));gtones(1,fc*N_sample); …

反转链表、链表的中间结点、合并两个有序链表【LeetCode刷题日志】

一、反转链表 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 思路一&#xff1a;翻转单链表指针方向 这里解释一下三个指针的作用&#xff1a; n1&#xff1…