【深度学习|基础算法】3.VggNet(附训练|推理代码)

news2025/1/10 22:13:13

这里写目录标题

  • 1.摘要
  • 2.Vgg的网络架构
  • 3.代码
    • backbone
    • train
    • predict
  • 4.训练记录
  • 5.推理
    • onnxruntime推理
      • export_onnx
    • openvino推理
    • tensorrt推理

1.摘要

  vgg是由牛津大学视觉几何组(Visual Geometry Group)的这篇论文中提出的模型,并且在2014年的ILSVRC物体分类和目标检测任务中分别斩获第二和第一名的好成绩,被cv界以VGG的名称所熟知。VGG中提出以小卷积核多层卷积的核心思想,显著的提升了模型的性能(后面细讲),并且vgg的网络架构十分的清晰,卷积核与池化层的参数稳定,容易迁移到不同的数据集中来适应不同的任务。

2.Vgg的网络架构


   以下的vgg的网络架构图是以vgg-16为例来展示的(共16层权重层,包含卷积层和全连接层)
  • 输入尺寸为224×224×3。
  • 使用了3×3的小卷积核来代替7×7的大卷积核,通过计算我们可以发现,设置stride=1pooling_size=0的前提下,我们可以根据感受野的公式(rfsize=(out_size-1)×stride+ksize),假定我们现在得到一个特征图,我们以某一个像素为开始向下映射感受野,可以得到3个3 × 3的卷积核得到的感受野和1个7 × 7的卷积核的感受野是一样的,但是我们可以通过使用多个卷积核来将网络在维度上设计的更深,能够增强网络的非线性能力,并且能够构建除更具判别能力的决策函数(不同维度的信息不同,例如纹理,形状,抽象具象特征等)。
  • 根据参数量的计算方式,假设特征图的数量为c,那么3个3×3的卷积核得到的特征图的参数量为3 × (3 × 3 + 1)× c2 = 30c2,而1个7×7的卷积核得到的特征图的参数量为1 × (7 × 7 + 1) ×c2 = 50c2,从参数量看,多层的小卷积核比单层的大卷积核的参数要少了不少,因此缩短了训练时间。
  • 使用了5层最大池化,都使用的是same卷积的方式(卷积的三种模式:full、same、valid),每经过一次池化之后特征图尺寸缩小一半,最后一层的特征层尺寸即为224/2^5 = 7,因此进入线性层之前的最后一个特征层为7 × 7的尺寸,这是一个比价合理的尺寸大小。
  • 然后跟着的是两个隐藏层节点数为4096的全连接层,最后是一个class_nums维度的softmax分类层。(可以根据我们的分类需求来自定义最后的分类个数)
    在这里插入图片描述
    在这里插入图片描述

3.代码

backbone

import torch.nn as nn
import torch

cfgs = {
    'vgg11' : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13' : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, feature, num_class=1000, pre_trained=False):
        super(VGG, self).__init__()
        self.feature = feature
        self.L1 = nn.Sequential(
            nn.Linear(7*7*512, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU()
        )
        self.L2 = nn.Sequential(
            nn.Linear(4096, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU()
        )
        self.FC = nn.Linear(4096, num_class)

    def forward(self, x):
        features = self.feature(x)
        out = features.view(x.size(0), -1)
        # x = torch.flatten(x, start_dim=1)
        out_1 = self.L1(out)
        out_2 = self.L2(out_1)
        out_3 = self.FC(out_2)
        return out_3
def Feature(cfg:list):
    layers = []
    in_channels = 3
    for c in cfg:
        if c == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            out_channels = c
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
            layers += conv_block
            in_channels = c
    return nn.Sequential(*layers)


def vgg(model_name:str, pre_trained=False, **kwargs, ):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    features = Feature(cfg)

    model = VGG(features, **kwargs)
    if pre_trained:
        model.load_state_dict(torch.load("best.pth"))
    return model

train

import argparse
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm
from vgg import vgg

def parse_args():
    parser = argparse.ArgumentParser(description="vgg train by pytorch")
    parser.add_argument('--train_root', default='', type=str, help="the path of train-data folder")
    parser.add_argument('--val_root', default='', type=str, help="the path of val-data folder")
    parser.add_argument('--model_name', default="vgg16", type=str, help="the name of backbone")
    parser.add_argument('--class_num', default=2, type=int, help="the number of classes")
    parser.add_argument('--batch_size', default=10, type=int, help="the size of one batch")
    parser.add_argument('--epochs', default=100, type=int, help="the times of iteration")
    parser.add_argument('--workers', default=10, type=int, help="the nums of process")
    parser.add_argument('--learn_rate', default=0.00005, type=int, help="the speed of learning rate")
    parser.add_argument('--momentum', default=0.9, type=int, help="momentum")
    parser.add_argument('--weight_decay', default=None, type=int, help="weight decay")
    parser.add_argument('--gpu_id', default=0, type=int, help="the index of your gpu")
    parser.add_argument('--save_dir', default='best.pth', type=str, help="the best accuracy pth saved path")
    args = parser.parse_args()
    return args

def train():
    args = parse_args()

    train_dir = os.path.join(args.train_root, "dataset")
    val_dir = os.path.join(args.val_root, "dataset")
    assert os.path.exists(train_dir), "{} path does not exist!".format(train_dir)
    assert os.path.exists(val_dir), "{} path does not exist!".format(val_dir)

    device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")
    # device = 'cpu'
    print('using {} to train.'.format(device))

    # nw = 1
    # if args.class_num == None:
    #     nw = min([os.cpu_count(), args.batchsize if args.batchsize > 1 else 0, 8])
        # args.workers = nw
    print('using {} workers to train'.format(args.workers))

    vgg_transform = {
        "train" : transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
                                 std  = [ 0.229, 0.224, 0.225 ]),
        ]),

        "val" : transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
                                 std  = [ 0.229, 0.224, 0.225 ]),
        ])
    }

    train_dataset = datasets.ImageFolder(root=train_dir, transform=vgg_transform["train"])
    train_loader  = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    val_dataset   = datasets.ImageFolder(root=val_dir, transform=vgg_transform["val"])
    val_loader    = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    train_len     = len(train_dataset)
    val_len       = len(val_dataset)
    print('using {} images to train and {} images to val'.format(train_len, val_len))

    model = vgg(model_name=args.model_name, num_class=args.class_num)
    model.to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learn_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    train_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        cnt = 0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for time, data in enumerate(train_bar):
            images, labels = data
            outputs = model(images.to(device))
            loss = loss_function(outputs, labels.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            cnt += 1
            train_bar.desc = "train_epoch[{}]/total_epoch[{}] —— avg_loss:{}".format(epoch+1, args.epochs, running_loss/cnt)
        scheduler.step(running_loss)

        model.eval()
        val_acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for step, data in enumerate(val_bar):
                images, labels = data
                outputs = model(images.to(device))
                pred_label = torch.max(outputs, dim=1)[1]
                val_acc += torch.eq(pred_label, labels.to(device)).sum().item()

        average_acc = val_acc / val_len
        print("[{}-epoch] train_loss: {}  val_accuracy: {}".format(epoch + 1, running_loss/train_steps, average_acc))

        if average_acc > train_acc:
            torch.save(model.state_dict(), args.save_dir)

    print('finish training')

if __name__ == "__main__":
    train()

predict

import torch
import vgg
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from train import parse_args
import cv2
from PIL import Image
from torchvision.transforms import ToPILImage
from torch.autograd import Variable



# 前处理
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

classes = [
    "cat",
    "dog",
]

args   = parse_args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model  = vgg.vgg(model_name=args.model_name, pre_trained=False, num_class=args.class_num).to(device)
model.load_state_dict(torch.load(r'best.pth'))
model.eval()

image = cv2.imread('dataset/dog/dog.12386.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = Image.fromarray(image)
input = val_transform(image)
# image.show()

input = Variable(torch.unsqueeze(input, dim=0).float(), requires_grad=False).to(device)
with torch.no_grad():
    pred = model(input)
    class_idx = torch.argmax(pred[0])
    class_idx = class_idx.detach().cpu()
    print("This is a {}".format(classes[class_idx]))

4.训练记录

  • 需要设置较小的初始learning_rate来进行反向传播,否则会导致step太大,无法收敛,我一开始设置的是0.01,结果无法收敛,后来设置成0.0001,效果会好很多。
  • 随着我们训练的不断迭代,每个参数逐渐的接近收敛到梯度最小值处,这个位置往往是梯度变化较大的位置,小小的步长就能带来巨大的参数变化,导致模型精度的影响,因此在后期,我们需要进行学习率的调整,使模型能够在一个较小的学习率下去更新各项参数,因此我引入了
  • scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)来实现一个学习率的自动衰减,效果很好,可以尝试一下不同的方式,比如余弦退火法

5.推理

onnxruntime推理

首先,需要加载pth模型权重到实例化的模型中,然后调用自己写的导出onnx模型的代码:

export_onnx

import os.path

import vgg
import onnx
from train import parse_args
import torch

def export_onnx(pt_path:str, onnx_path:str):
    args  = parse_args()
    model = vgg.vgg(model_name=args.model_name, pre_trained=False, num_class=args.class_num)
    model.load_state_dict(torch.load(pt_path))

    dummy_inputs = {
        "input" : torch.randn(1, 3, 224, 224, dtype=torch.float),
    }

    output_names = {
        "classes"
    }

    # if not os.path.exists(os.path.dirname(onnx_path)):
    #     os.makedirs(os.path.dirname(onnx_path))

    with open(onnx_path, "wb") as f:
        print(f"Exporting onnx model to {onnx_path}...")
        torch.onnx.export(
            model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=11,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            # dynamic_axes=dynamic_axes,
        )

if __name__ == "__main__":
    pt_path = "best.pth"
    onnx_path = "best.onnx"
    export_onnx(pt_path, onnx_path)

openvino推理

tensorrt推理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1596505.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C#版Facefusion ,换脸器和增强器

C#版Facefusion ,换脸器和增强器 目录 说明 效果 项目 调用代码 说明 Facefusion是一款最新的开源AI视频/图片换脸项目。是原来ROOP的项目的延续。项目官方介绍只有一句话,下一代换脸器和增强器。 代码实现参考 https://github.com/facefusion/f…

Windows中通过cmd查看以保存的WiFi密码

1、要以管理员身份运行CMD 要以管理员身份运行CMD 2、查看以保存的所有WiFi 执行命令 netsh wlan show profiles 会显示所有已保存的wifi 3、这里查看已经保存的WiFi HLS_HMD 的密码 执行命令查看: netsh wlan show profile name"HLS_HMD&qu…

【Java开发指南 | 第四篇】Java常量、自动类型转换、修饰符

读者可订阅专栏:Java开发指南 |【CSDN秋说】 文章目录 Java常量自动类型转换Java修饰符 Java常量 常量在程序运行时是不能被修改的。 在 Java 中使用 final 关键字来修饰常量,声明方式和变量类似: final double PI 3.1415927;自动类型转换…

app证书在设置在哪

根据近日工业和信息化部发布的《工业和信息化部关于开展移动互联网应用程序备案工作的通知》,相信不少要进行IOS平台App备案的朋友遇到了一个问题,就是apple不提供云管理式证书的下载,也就无法获取公钥及证书SHA-1指纹。 已经上架的应用不想重…

ruoyi的一些gateway传token的一些问题

ruoyi的一些gateway传token的一些问题 1、gateway会处理授权信息 2、authorization: Bearer 253f2bd990754ca4aae22e1f755b17fe 是一个很简单的信息。 去授权信息,用这个token做验证。然后封装一个user在内部传递。 这个是不会暴露在最外面的请求中的。 还会放一…

RMAN数据迁移方案

数据迁移 Oracle环境检查 开启归档 1.首先关闭数据库 shutdown immediate; 2.打开mount状态 startup mount; 3.更改数据库为归档模式 alter database archivelog; 4.打开数据库 alter database open; 5.再次检查 archive log list; 查看构造的表和数据 由于数据会有中文&…

redis复习笔记08(小滴课堂)

案例实战需求之大数据下的用户画像标签去重 我们就简单的做到了去重了。 案例实战社交应用里面之关注、粉丝、共同好友案例 这就是我们set的一个应用。 案例实战之SortedSet用户积分实时榜单最佳实践 准备积分类对象: 我们加上构造方法和判断相等的equals和hascod…

C语言是不是要跟不上社会需求了?

C 语言是否已经难以跟上社会需求的步伐了呢?有这么一位网友曾提及,就在几天前,他遭受了老板的严厉批评,原因便是他仅仅精通 C 语言编程,已然无法满足老板的实际需求。事实上,C 语言在嵌入式领域仍旧拥有着极…

Depth maps转点云

前言 本文主要记录一下如何可视化相机位姿,如何用Blender得到的深度图反投影到3D空间,得到相应的点云。 Refernce https://github.com/colmap/colmap/issues/1106 https://github.com/IntelRealSense/librealsense/issues/12090 https://medium.com/yod…

图灵奖简介及2023年获奖者Avi Wigderson的贡献

No.内容链接1Openlayers 【入门教程】 - 【源代码示例300】 2Leaflet 【入门教程】 - 【源代码图文示例 150】 3Cesium 【入门教程】 - 【源代码图文示例200】 4MapboxGL【入门教程】 - 【源代码图文示例150】 5前端就业宝典 【面试题详细答案 1000】 文章目录 2023年的…

C语言基础入门案例(1)

目录 第一题:实现大衍数列的打印 第二题:生成所有由1、2、3、4组成的互不相同且无重复数字的三位数,并计算总数 第三题:整数加法计算器 第四题:实现一个范围累加和函数 第五题:编写一个函数计算整数的阶…

Webscoket简单demo介绍

前言 WebSocket 是从 HTML5 开始⽀持的⼀种⽹⻚端和服务端保持⻓连接的 消息推送机制. 理解消息推送: 传统的 web 程序, 都是属于 “⼀问⼀答” 的形式. 客⼾端给服务器发送了⼀个 HTTP 请求, 服务器给客 ⼾端返回⼀个 HTTP 响应.这种情况下, 服务器是属于被动的⼀⽅. 如果客⼾…

分类预测 | Matlab实现OOA-BP鱼鹰算法优化BP神经网络数据分类预测

分类预测 | Matlab实现OOA-BP鱼鹰算法优化BP神经网络数据分类预测 目录 分类预测 | Matlab实现OOA-BP鱼鹰算法优化BP神经网络数据分类预测分类效果基本介绍程序设计参考资料 分类效果 基本介绍 1.Matlab实现OOA-BP鱼鹰算法优化BP神经网络多特征分类预测(完整源码和数…

Win7开机进不了系统一直再自动修复,只能选禁用驱动签名才能进系统 其它模式都不行

环境: Win7专业版 问题描述: Win7开机进不了系统一直再修复,只能选禁用驱动签名才能进系统 其它模式都不行 解决方案: 1.开机按F8,选择禁用驱动签名进系统 2.查看系统日志文件定位错误原因 3.我这是DsArk64.sys导…

JS-32-jQuery01-jQuery的引入

一、初识jQuery jQuery是JavaScript世界中使用最广泛的一个库。鉴于它如此流行,又如此好用,所以每一个入门JavaScript的前端工程师都应该了解和学习它。 jQuery是一个优秀的JS函数库。 (对BOM和DOM的封装) jQuery这么流行&#x…

Flink设计运行原理 | 大数据技术

⭐简单说两句⭐ ✨ 正在努力的小新~ 💖 超级爱分享,分享各种有趣干货! 👩‍💻 提供:模拟面试 | 简历诊断 | 独家简历模板 🌈 感谢关注,关注了你就是我的超级粉丝啦! &…

计算机毕业设计springboot小区物业报修管理系统m8x57

该物业报修管理系统实施的目的在于帮助物业管理企业升级员工管理、住户管理、报修问题管理等内部管理平台,整合物业管理企业物力和人力,全面服务于维修人员管理的内部管理需求,并重视需求驱动、管理创新、与业主交流等外部需求,通过物业管理企业各项资源…

ArrayList部分底层源码分析

JDK版本为1.8.0_271,以插入和删除元素为例,部分源码如下: // 部分属性 transient Object[] elementData; // 底层数组 private int size; // 记录元素个数 private static final Object[] DEFAULTCAPACITY_EMPTY_ELEMENTDATA {}; // 空Obje…

异地组网怎么安装?

异地组网安装是指在不同地域的多个设备之间建立网络连接,以便实现数据传输和协同工作的过程。在如今的数字化时代,异地组网安装已经成为了许多企业和组织所必需的一项技术。 天联的使用场景 在异地组网安装中,天联是一种常用的工具。它具有以…

得物 Zookeeper SLA 也可以 99.99% | 得物技术

一、背景 ZooKeeper(ZK)是一个诞生于2007年的分布式应用程序协调服务。尽管出于一些特殊的历史原因,许多业务场景仍然不得不依赖它。比如,Kafka、任务调度等。特别是在 Flink 混合部署 ETCD 解耦 时,业务方曾要求绝对…