第二章 VGG网络详解

news2025/1/9 15:52:17

系列文章目录

第一章 AlexNet网络详解

第二章 VGG网络详解

第三章 GoogLeNet网络详解 

第四章 ResNet网络详解 

第五章 ResNeXt网络详解 

第六章 MobileNetv1网络详解 

第七章 MobileNetv2网络详解 

第八章 MobileNetv3网络详解 

第九章 ShuffleNetv1网络详解 

第十章 ShuffleNetv2网络详解 

第十一章 EfficientNetv1网络详解 

第十二章 EfficientNetv2网络详解 

第十三章 Transformer注意力机制

第十四章 Vision Transformer网络详解 

第十五章 Swin-Transformer网络详解 

第十六章 ConvNeXt网络详解 

第十七章 RepVGG网络详解 

第十八章 MobileViT网络详解 


文章目录

  • VGG
  • 0. 前言
  • 1. 摘要
  • 2. AlexNet网络架构
    • 1.AlexNet_Model.py(pytorch实现)
    • 2.
  • 总结


0、前言


1、摘要

       在本研究中,我们探讨了卷积网络深度对于大规模图像识别准确性的影响。我们的主要贡献是对于使用非常小的(3 ×3)卷积过滤器的逐渐加深的网络进行彻底评估。研究结果表明,通过将深度推到16到19个权重层,可以显著提高传统网络配置的表现。这些研究成果是我们2014年ImageNet挑战赛提交的基础,我们的团队在定位和分类赛道中分别获得了第一和第二名。我们还展示了我们的表现良好的方法可以泛化到其他数据集,其中它们取得了最先进的结果。我们公开了我们表现最佳的ConvNet模型,以便促进关于深度视觉表征在计算机视觉中的使用的进一步研究。

2、VGG网络结构

"Very Deep Convolutional Networks for Large-Scale Image Recognition" (VGG) 是2014年由Karen Simonyan和Andrew Zisserman发表的一篇论文。该论文提出了一种卷积神经网络,它具有非常深的架构,被称为 VGG16 和 VGG19。 VGG网络是一种端到端的模型,可以接受任何大小的图像,但通常情况下,输入的图像大小为 224*224,类别数为1000。VGG网络中的卷积操作均使用3*3 的卷积核,并且在每个卷积层之间使用池化层进行下采样。这种架构的主要优点是能够实现非常深的模型,并且可以容易地训练,因为所有的卷积层和池化层的结构都非常简单和相似。此外,在使用 3*3的卷积核时,可以使用更多的层来增加网络的深度,而不会使参数数量增加过快。 VGG网络通过在两个 GPU 上进行训练,在 ILSVRC2014 图像识别竞赛中取得了第一名,同时还刷新了当时的记录。模型在该数据集上的 top-5 错误率仅为 7.3%。 总的来说,VGG网络通过简单而有效的模型架构取得了显著的性能提升,同时为后续深度学习模型设计提供了有力的参考与借鉴。

使用7x7的卷积核所需要的参数数量:7*7*C*C=49C^{^{2}}

使用3个3x3的卷积核代替一个7x7卷积核所需要的参数数量:3*3*3*C*C=27C^{^{2}}

经过卷积后的矩阵尺寸大小计算公式为: N = (W - F+2P)/S+1

1.VGG_Model.py(pytorch实现)

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self,num_classes=1000,init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128*6*6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

2.train.py

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    data_root = 'D:/100_DataSets/'
    image_path = os.path.join(data_root, "03_flower_data")
    assert os.path.exists(image_path), "{} path does not exits.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform = data_transform['train'])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    batch_size = 6
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloder workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                 batch_size=4,
                                                 shuffle=False,
                                                 num_workers=nw)
    print("using {} image for train, {} images for validation.".format(train_num, val_num))
    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_fuction = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)
    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_fuction(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:,.3f}".format(epoch+1, epochs, loss)
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch % d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch+1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(),save_path)
    print("Finished Training")

if __name__ == '__main__':
    main()






3.predict.py

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
    )
    img_path = "D:/20_Models/01_AlexNet_pytorch/image_predict/tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path,"r") as f:
        class_indict = json.load(f)
    model = AlexNet(num_classes=5).to(device)
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
    print_res = "class: {} prob: {:.3f}".format(class_indict[str(predict_cla)],
                                                predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))
    plt.show()

if __name__ == '__main__':
    main()

4.predict.py

import os
from shutil import copy, rmtree
import random

def mk_file(file_path: str):
    if os.path.exists(file_path):
        rmtree(file_path)
    os.makedirs(file_path)

def main():
    random.seed(0)
    split_rate = 0.1
    #cwd = os.getcwd()
    #data_root = os.path.join(cwd, "flower_data")
    data_root = 'D:/100_DataSets/03_flower_data'
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist".format(origin_flower_path)
    flower_class = [cla for cla in os.listdir(origin_flower_path) if os.path.isdir(os.path.join(origin_flower_path, cla))]
    train_root = os.path.join(data_root,"train")
    mk_file(train_root)
    for cla in flower_class:
        mk_file(os.path.join(train_root, cla))
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        mk_file(os.path.join(val_root,cla))
    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path,cla)
        images = os.listdir(cla_path)
        num = len(images)
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{} / {}]".format(cla, index+1, num), end="")
        print()
    print("processing done!")
    
if __name__ == "__main__":
    main()

总结

提示:这里对文章进行总结:

每天一个网络,网络的学习往往具有连贯性,新的网络往往是基于旧的网络进行不断改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/672604.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Debian12中Grub2识别Windows

背景介绍:windows10 debian11,2023年6月,Debian 12正式版发布了。抵不住Debian12新特性的诱惑,我将Debian11升级至Debian12。升级成功,但Debian12的Grub2无法识别Window10。于是执行如下命令: debian:~# update-grub G…

图片加载错误的捕获及处理

引言 前端开发中,图片是我们在网页中加载最多的静态资源类型之一,但是图片加载过程中也有可能出现加载失败的情况,这是十分影响用户体验的。那么如何正确的判断图片是否成功加载,以及图片加载失败的时候,如何处理&…

了解 Dockerfile 和搭建 Docker 私有仓库:让容器化部署变得更简单

目录 1、Dockerfile 1.1什么是Dockerfile 1.2常用命令 1.3使用脚本创建镜像 2、Docker私有仓库 2.1私有仓库介绍: 2.2私有仓库搭建与配置 2.3上传镜像到私有仓库: 1、Dockerfile 1.1什么是Dockerfile Dockerfile是由一些列命令和参数构成的脚本…

【服务器数据恢复】IBM存储分配的卷无法访问的数据恢复案例

服务器故障: 一台IBM DS存储出现故障,存储分配给aix小机的卷无法访问。从底层查看分配给aix小机的3个卷的lvm信息丢失。 服务器数据恢复过程: 1、将存储中所有磁盘编号后取出,以只读方式做全盘镜像,后续的数据分析和数…

【C++篇】字符串:标准库string类

友情链接:C/C系列系统学习目录 知识总结顺序参考C Primer Plus(第六版)和谭浩强老师的C程序设计(第五版)等,内容以书中为标准,同时参考其它各类书籍以及优质文章,以至减少知识点上的…

Eyeshot 2023 Added NuGet packages.

Added Microsoft Visual Studio 2022 Extensions menu item.Microsoft .NET 6 Windows Toolbox items.Added NuGet packages.Planar curve projection on Sketch plane.Improved fillet surfaces quality and speed.Added ICurve.ConverToLinearPath() family of methods.   …

RabbitMQ学习笔记4(小滴课堂)RabbitMQ工作队列模型实战

Java项目整合RabbitMQ 创建一个maven项目。 然后我们在maven里加上jdk和rabbitmq的依赖设置: 我们写一段生产者的代码: 然后我们去运行它: 可以看到这里有一个队列。 现在我们是可以查看到队列的。 我们去写消费者代码: 这里的之…

LeetCode 2090. K Radius Subarray Averages【前缀和,滑动窗口,数组】中等

本文属于「征服LeetCode」系列文章之一,这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁,本系列将至少持续到刷完所有无锁题之日为止;由于LeetCode还在不断地创建新题,本系列的终止日期可能是永远。在这一系列刷题文章…

Redis是什么,如何学习,如何整合SpringBoot?

目录 一、Redis是什么? 二、如何学习Redis 三、如何整合SpringBoot 一、Redis是什么? Redis 是一个高性能的开源 NoSQL 数据库,支持多种数据结构,包括字符串、哈希、列表、集合和有序集合等。它采用内存存储,可以快…

Python3 函数与数据结构 | 菜鸟教程(十一)

目录 一、Python3 函数 (一)定义一个函数 1、你可以定义一个由自己想要功能的函数,以下是简单的规则: 2、语法 3、实例 ①让我们使用函数来输出"Hello World!": ②更复杂点的应用&#xff…

axios 发送请求请求头信息不包含Cookie信息

问题 axios 发送请求请求头信息不包含Cookie信息 详细问题 使用VueSpringBoot进行项目开发,axios进行网络请求,发送请求,请求头信息不包含Cookie信息 具体如下 实际效果 预期效果 解决方案 作用域 Vue项目全局配置 打开Vue项目的入口…

SpringSecurity是什么,如何学习SpringSecurity?

目录 一、SpringSecurity是什么 二、如何学习SpringSecurity 三、SpringSecurity整合springboot 一、SpringSecurity是什么 Spring Security是一个功能强大的安全框架,它为企业级应用程序提供了完整的身份验证和授权管理。它是一个开源项目,由Pivota…

Linux:第五章课后习题及答案

第五章 Linux常用命令 Q1:常用的文本内容显示命令有哪些?区别是什么? 文本内容显示的命令有cat,more,less, head,tailcat:显示文本文件, 也可以把几个文件 内容附加到另…

第一章 AlexNet网络详解

系列文章目录 第一章 AlexNet网络详解 第二章 VGG网络详解 第三章 GoogLeNet网络详解 第四章 ResNet网络详解 第五章 ResNeXt网络详解 第六章 MobileNetv1网络详解 第七章 MobileNetv2网络详解 第八章 MobileNetv3网络详解 第九章 ShuffleNetv1网络详解 第十章…

python绘画多边形(turtle)

目录 前言 正三角形 正四边形 正多边形 总结: 前言 事情的起因是,我今天心血来潮想让openai生成路飞的图像看效果怎么样,他是这样回我的。 我这一想,这不稳了吗,这么轻松。结果…… import turtle# 定义画笔颜色…

【Visual Studio】报错 C1083,使用 C++ 语言,配合 Qt 开发串口通信界面

知识不是单独的,一定是成体系的。更多我的个人总结和相关经验可查阅这个专栏:Visual Studio。 文章目录 问题解决方案Ref. 问题 使用 C 语言,配合 Qt 开发串口通信界面时,报错代码 C1083。 复制一下错误信息,方便别人…

【C语言进阶】编译链接

文章目录 📖程序的两种环境 🔖翻译环境🔖执行环境 📖详解翻译环境🔖从人的角度去看编译链接🔖预编译🔖编译🔖汇编🔖链接🔖符号表的作用 📖执行环境…

外贸订单管理平台有哪些?

外贸订单的管理是外贸出口业务中非常重要的一项管理工作,订单能否实现准时交付则需要涉及到各种流程顺畅的支持。那么外贸订单管理平台有哪些?有孚盟软件。 首先,外贸订单管理平台主要是解决外贸公司的订单查询与管理,面对大量不…

2.4G无线收发芯片 XL2400,SOP8封装,外挂MCU使用

XL2400 芯片是工作在 2.400~2.483GHz 世界通用 ISM 频段的单片无线收发芯片。该芯片集成射频收发机、频率收生器、晶体振荡器、调制解调器等功能模块,并且支持一对多组网和带 ACK 的通信模式。发射输出率、工作频道以及通信数据率均可配置。芯片已将多颗外围贴片阻容…

mysql错误-1055 - Expression #1 of SELECT list is not in GROUP BY clause 解决方案

目录 业务场景发现问题表结构表数据sql查询 分析问题验证 解决问题方案一方案二方案三 注意事项 业务场景 当遇到数据库重复数据,就要将数据进行分组,取其中一条来展示,此时就要用到group by语句。 但当mysql的版本高于5.7时,在执…