经典卷积神经网络 - VGG

news2024/11/18 13:37:08

使用块的网络 - VGG。

使用多个 3 × 3 3\times 3 3×3的要比使用少个 5 × 5 5\times 5 5×5的效果要好。
在这里插入图片描述
在这里插入图片描述
VGG全称是Visual Geometry Group,因为是由Oxford的Visual Geometry Group提出的。AlexNet问世之后,很多学者通过改进AlexNet的网络结构来提高自己的准确率,主要有两个方向:小卷积核和多尺度。而VGG的作者们则选择了另外一个方向,即加深网络深度。

网络架构

卷积网络的输入是224 * 224RGB图像,整个网络的组成是非常格式化的,基本上都用的是3 * 3的卷积核以及 2 * 2max pooling,少部分网络加入了1 * 1的卷积核。因为想要体现出“上下左右中”的概念,3*3的卷积核已经是最小的尺寸了。

VGG16相比之前网络的改进是3个33卷积核来代替7x7卷积核,2个33卷积核来代替5*5卷积核,这样做的主要目的是在保证具有相同感知野的条件下,减少参数,提升了网络的深度。

多个VGG块后接全连接层。

不同次数的重复块得到不同的架构,如VGG-16,VGG-19等。

VGG:更大更深的AlexNet。

总结:

  • VGG使用可重复使用的卷积块来构建深度卷积神经网络
  • 不同的卷积块个数和超参数可以得到不同复杂度的变种

代码实现

使用数据集CIFAR

model.py

import torch
from torch import nn

class Vgg16(nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(64,64,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(64,128,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(128,128,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(128,256,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(256,256,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(256,256,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(256,512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Flatten(),
            nn.Linear(7*7*512,4096),
            nn.Dropout(0.5),
            nn.Linear(4096,4096),
            nn.Dropout(0.5),
            nn.Linear(4096,10)
        )

    def forward(self,x):
        return self.model(x)

# 验证模型正确性
if __name__ == '__main__':
    net = Vgg16()
    x = torch.ones((64,3,244,244))
    output = net(x)
    print(output)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import Vgg16

# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0


# 定义图像转换
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))

# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = Vgg16()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)

writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):
    print("-------------------第 {} 轮训练开始-------------------".format(epoch))
    net.train()
    for data in train_dataloader:
        train_step = train_step + 1
        images,targets = data
        images = images.to(device)
        targets = targets.to(device)
        outputs = net(images)
        loss_out = loss(outputs,targets)
        optimizer.zero_grad()
        loss_out.backward()
        optimizer.step()

        if train_step%100==0:
            writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)
            print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))

    # 测试
    net.eval()
    total_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            test_step = test_step + 1
            images, targets = data
            images = images.to(device)
            targets = targets.to(device)
            outputs = net(images)
            loss_out = loss(outputs, targets)
            total_loss = total_loss + loss_out
            accuracy = (targets == torch.argmax(outputs,dim=1)).sum()
            total_accuracy = total_accuracy + accuracy
        # 计算精确率
        print(total_accuracy)
        accuracy_rate = total_accuracy / test_size

        print("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))
        print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))
        writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)
        writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)
    torch.save(net,"./model/net_{}.pth".format(epoch+1))
    print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1124904.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

day03_pandas_demo

文章目录 pandas介绍为什么使用pandasDataFrameDataFrame属性DataFrame的索引修改行列的索引值重设索引值以某列设置新索引 MultiIndexSerias索引操作直接索引按名字索引按数值索引 赋值操作排序对内容排序按索引排序 DataFrame的运算算术运算逻辑运算逻辑运算符号 < > |…

uni-app 小宠物 - 会说话的小鸟

在 template 中 <view class"container"><view class"external-shape"><view class"face-box"><view class"eye-box eye-left"><view class"eyeball-box eyeball-left"><span class"…

搭建nexus私服部署项目

目录 1、前言 2、添加release和snapshot版本库 3、配置idea中的Maven设置 4、配置maven的settings.xml文件 5、项目中使用maven部署 1、前言 前文主要讲述了maven私服nexus的搭建&#xff1a;maven私服nexus搭建mybatisplus使用-CSDN博客 本文将继续讲述搭建nexus私服有…

img标签如何将<svg></svg>数据渲染出来

要将 ​​<svg></svg>​​​ 数据插入到 ​​<img>​​ 标签中&#xff0c;你可以使用以下两种方法&#xff1a; 方法一&#xff1a;使用 Data URL 你可以将 ​​<svg></svg>​​ 数据编码为 Data URL&#xff0c;并将其作为 ​​<img>​​…

“游蛇”黑产团伙专题分析报告

目录 ​编辑 01概览 02黑产团伙攻击手段 2.1 恶意程序传播 双击类恶意程序 跳图类恶意程序 损坏类恶意程序 2.2 恶意程序执行 可信站点 黑产团伙基础设施 03黑产团伙的几种变现方式 3.1 伪装身份后实施诈骗 3.2 恶意拉群后实施诈骗 04防护、排查与处置 01概览 “…

手写SVG图片

有时候QT中可能会需要一些简单的SVG图片,但是网上的质量参差不齐,想要满意的SVG图片,我们可以尝试直接手写的方法. 新建文本文档,将以下代码复制进去,修改后缀名为.svg,保存 <?xml version"1.0" encoding"utf-8"?> <svg xmlns"http://www…

QTday06(人脸识别项目前置知识)

qt版本5.4.0&#xff1a;旧版本的qt&#xff0c;为啥要用旧版本的我也不知道 实现结果&#xff1a; 调用系统摄像头&#xff0c;用红框框住画面中的人头 代码&#xff1a; pro&#xff1a; #------------------------------------------------- # # Project created by QtC…

通过热敏电阻计算温度(二)---ODrive实现分析

文章目录 通过热敏电阻计算温度&#xff08;二&#xff09;---ODrive实现分析测量原理图计算分析计算拟合的多项式系数根据多项式方程计算温度的函数温度计算调用函数 通过热敏电阻计算温度&#xff08;二&#xff09;—ODrive实现分析 ODrive计算热敏电阻的温度采用的时B值的…

计算机基础知识37

针对记录的SQL语句 记录: 表中的一行一行的数据称之为是一条记录 先有库---->表---->记录 C:\Users\26647>mysql -u root -p # 先登录 mysql> show databases&#xff1b; # 查看所有库 mysql> create database db1; # 创造库 mysql> use db1; # 引用…

Java逻辑运算符(、||和!),Java关系运算符

逻辑运算符把各个运算的关系表达式连接起来组成一个复杂的逻辑表达式&#xff0c;以判断程序中的表达式是否成立&#xff0c;判断的结果是 true 或 false。 逻辑运算符是对布尔型变量进行运算&#xff0c;其结果也是布尔型&#xff0c;具体如表 1 所示。 表 1 逻辑运算符的用…

科大讯飞星火认知大模型

哈喽&#xff0c;大家好&#xff01; 前段时间「科大讯飞版ChatGPT」上线&#xff0c;给大家推荐了一波&#xff0c;演示了其强大的功能&#xff0c;不少小伙伴都立马申请体验了一把&#xff0c;也有私信说非常强大&#xff0c;工作效率提高不少&#xff0c;支持国产大模型&am…

【Python · PyTorch】数据基础

数据基础 1. 数据操作1.1 入门1.2 运算符1.3 广播机制1.4 索引和切片1.5 节省内存1.6 转化为其他Python对象 2. 数据预处理2.1 读取数据集2.2 处理缺失值2.3 转换为张量格式 本文介绍了PyTorch数据基础&#xff0c;Python版本3.9.0&#xff0c;代码于Jupyter Lab中运行&#xf…

linux系统安装jdk

1.从官网下载jdk包,Java Archive Downloads - Java SE 8u211 and later 2.创建java目录并上传jdk包 mkdir -p /home/local/java 3.解压jdk包 cd /home/local/java tar -zxvf /home/local/java/jdk-8u381-linux-x64.tar.gz 4.配置环境变量 vim /etc/profile i export JAV…

关于数据可视化那些事

干巴巴的数据没人看&#xff0c;数据可视化才能直观展现数据要点&#xff0c;提升数据分析、数字化运营决策效率。那关于可视化的实现方式、技巧、工具等&#xff0c;你了解几分&#xff1f;接下来&#xff0c;我们就来聊聊数据可视化那些事。 1、什么是数据可视化&#xff1f…

酒精壁炉:独特的室内取暖方式

酒精壁炉是一种现代而引人注目的室内取暖方式&#xff0c;其独特之处在于使用酒精作为唯一的燃料源。这种现代壁炉设计旨在为家庭带来温暖和舒适&#xff0c;同时呈现出简约而时尚的外观。 1、无需烟囱的壁炉 传统壁炉通常需要烟囱或排气系统&#xff0c;以排除燃烧过程中产生…

Java赋值运算符(=)

赋值运算符是指为变量或常量指定数值的符号。赋值运算符的符号为“”&#xff0c;它是双目运算符&#xff0c;左边的操作数必须是变量&#xff0c;不能是常量或表达式。 其语法格式如下所示&#xff1a; 变量名称表达式内容 在 Java 语言中&#xff0c;“变量名称”和“表达式…

RISC-V架构——物理内存保护机制设置函数(pmp_set)解析

1、物理内存保护机制 参考博客&#xff1a;《RISC-V架构——物理内存属性和物理内存保护》&#xff1b; 2、pmp_set函数源码 int pmp_set(unsigned int n, unsigned long prot, unsigned long addr,unsigned long log2len) {int pmpcfg_csr, pmpcfg_shift, pmpaddr_csr;unsign…

【C++和数据结构】位图和布隆过滤器

目录 一、位图 1、位图的概念 2、位图的实现 ①、基本结构 ②、set ③、reset&#xff1a; ④、test ⑤、问题&#xff1a; ⑥、位图优缺点及应用&#xff1a; ⑦、完整代码及测试 二、布隆过滤器 1、布隆过滤器的提出 2、布隆过滤器的实现 ①、基本结构 ②…

初识测开/测试

前言 在进入软件测试的正式讲解之前&#xff0c;我们需要对这个行业有一个整体的了解。 当我们从软件开发转向软件测试的时候&#xff0c;多数公司是欢迎的&#xff0c;而且难度也小。 反之&#xff0c;当我们从软件测试转向软件开发的时候&#xff0c;难度将会变得很大。 关于…