深度学习_VGG_3

news2025/1/4 20:33:03

目标

  • 知道VGG网络结构的特点
  • 能够利用VGG完成图像分类

2014年,牛津大学计算机视觉组(Visual Geometry Group)和Google DeepMind公司的研究员一起研发出了新的深度卷积神经网络:VGGNet,并取得了ILSVRC2014比赛分类项目的第二名,主要贡献是使用很小的卷积核(3×3)构建卷积神经网络结构,能够取得较好的识别精度,常用来提取图像特征的VGG-16和VGG-19。

1.VGG的网络架构

VGG可以看成是加深版的AlexNet,整个网络由卷积层和全连接层叠加而成,和AlexNet不同的是,VGG中使用的都是小尺寸的卷积核(3×3),其网络架构如下图所示:

VGGNet使用的全部都是3x3的小卷积核和2x2的池化核,通过不断加深网络来提升性能。VGG可以通过重复使用简单的基础块来构建深度模型。

在tf.keras中实现VGG模型,首先来实现VGG块,它的组成规律是:连续使用多个相同的填充为1、卷积核大小为3×33×3的卷积层后接上一个步幅为2、窗口形状为2×22×2的最大池化层。卷积层保持输入的高和宽不变,而池化层则对其减半。我们使用vgg_block函数来实现这个基础的VGG块,它可以指定卷积层的数量num_convs和每层的卷积核个数num_filters:

# 定义VGG网络中的卷积块:卷积层的个数,卷积层中卷积核的个数
def vgg_block(num_convs, num_filters):
    # 构建序列模型
    blk = tf.keras.models.Sequential()
    # 遍历所有的卷积层
    for _ in range(num_convs):
        # 每个卷积层:num_filter个卷积核,卷积核大小为3*3,padding是same,激活函数是relu
        blk.add(tf.keras.layers.Conv2D(num_filters,kernel_size=3,
                                    padding='same',activation='relu'))
    # 卷积块最后是一个最大池化,窗口大小为2*2,步长为2
    blk.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
    return blk

VGG16网络有5个卷积块,前2块使用两个卷积层,而后3块使用三个卷积层。第一块的输出通道是64,之后每次对输出通道数翻倍,直到变为512。

# 定义5个卷积块,指明每个卷积块中的卷积层个数及相应的卷积核个数
conv_arch = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))

因为这个网络使用了13个卷积层和3个全连接层,所以经常被称为VGG-16,通过制定conv_arch得到模型架构后构建VGG16:

# 定义VGG网络
def vgg(conv_arch):
    # 构建序列模型
    net = tf.keras.models.Sequential()
    # 根据conv_arch生成卷积部分
    for (num_convs, num_filters) in conv_arch:
        net.add(vgg_block(num_convs, num_filters))
    # 卷积块序列后添加全连接层
    net.add(tf.keras.models.Sequential([
        # 将特征图展成一维向量
        tf.keras.layers.Flatten(),
        # 全连接层:4096个神经元,激活函数是relu
        tf.keras.layers.Dense(4096, activation='relu'),
        # 随机失活
        tf.keras.layers.Dropout(0.5),
        # 全连接层:4096个神经元,激活函数是relu
        tf.keras.layers.Dense(4096, activation='relu'),
        # 随机失活
        tf.keras.layers.Dropout(0.5),
        # 全连接层:10个神经元,激活函数是softmax
        tf.keras.layers.Dense(10, activation='softmax')]))
    return net
# 网络实例化
net = vgg(conv_arch)

我们构造一个高和宽均为224的单通道数据样本来看一下模型的架构:

# 构造输入X,并将其送入到net网络中
X = tf.random.uniform((1,224,224,1))
y = net(X)
# 通过net.summay()查看网络的形状
net.summay()

网络架构如下:

Model: "sequential_15" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= sequential_16 (Sequential) (1, 112, 112, 64) 37568 _________________________________________________________________ sequential_17 (Sequential) (1, 56, 56, 128) 221440 _________________________________________________________________ sequential_18 (Sequential) (1, 28, 28, 256) 1475328 _________________________________________________________________ sequential_19 (Sequential) (1, 14, 14, 512) 5899776 _________________________________________________________________ sequential_20 (Sequential) (1, 7, 7, 512) 7079424 _________________________________________________________________ sequential_21 (Sequential) (1, 10) 119586826 ================================================================= Total params: 134,300,362 Trainable params: 134,300,362 Non-trainable params: 0 __________________________________________________________________

2.手写数字势识别

因为ImageNet数据集较大训练时间较长,我们仍用前面的MNIST数据集来演示VGGNet。读取数据的时将图像高和宽扩大到VggNet使用的图像高和宽224。这个通过tf.image.resize_with_pad来实现。

2.1 数据读取

首先获取数据,并进行维度调整:

import numpy as np
# 获取手写数字数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 训练集数据维度的调整:N H W C
train_images = np.reshape(train_images,(train_images.shape[0],train_images.shape[1],train_images.shape[2],1))
# 测试集数据维度的调整:N H W C
test_images = np.reshape(test_images,(test_images.shape[0],test_images.shape[1],test_images.shape[2],1))

由于使用全部数据训练时间较长,我们定义两个方法获取部分数据,并将图像调整为224*224大小,进行模型训练:

# 定义两个方法随机抽取部分样本演示
# 获取训练集数据
def get_train(size):
    # 随机生成要抽样的样本的索引
    index = np.random.randint(0, np.shape(train_images)[0], size)
    # 将这些数据resize成22*227大小
    resized_images = tf.image.resize_with_pad(train_images[index],224,224,)
    # 返回抽取的
    return resized_images.numpy(), train_labels[index]
# 获取测试集数据 
def get_test(size):
    # 随机生成要抽样的样本的索引
    index = np.random.randint(0, np.shape(test_images)[0], size)
    # 将这些数据resize成224*224大小
    resized_images = tf.image.resize_with_pad(test_images[index],224,224,)
    # 返回抽样的测试样本
    return resized_images.numpy(), test_labels[index]

调用上述两个方法,获取参与模型训练和测试的数据集:

# 获取训练样本和测试样本
train_images,train_labels = get_train(256)
test_images,test_labels = get_test(128)

为了让大家更好的理解,我们将数据展示出来:

# 数据展示:将数据集的前九个数据集进行展示
for i in range(9):
    plt.subplot(3,3,i+1)
    # 以灰度图显示,不进行插值
    plt.imshow(train_images[i].astype(np.int8).squeeze(), cmap='gray', interpolation='none')
    # 设置图片的标题:对应的类别
    plt.title("数字{}".format(train_labels[i]))

结果为:

我们就使用上述创建的模型进行训练和评估。

2.2 模型编译

# 指定优化器,损失函数和评价指标
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0)

net.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

2.3 模型训练

# 模型训练:指定训练数据,batchsize,epoch,验证集
net.fit(train_images,train_labels,batch_size=128,epochs=3,verbose=1,validation_split=0.1)

训练输出为:

Epoch 1/3 2/2 [==============================] - 34s 17s/step - loss: 2.6026 - accuracy: 0.0957 - val_loss: 2.2982 - val_accuracy: 0.0385 Epoch 2/3 2/2 [==============================] - 27s 14s/step - loss: 2.2604 - accuracy: 0.1087 - val_loss: 2.4905 - val_accuracy: 0.1923 Epoch 3/3 2/2 [==============================] - 29s 14s/step - loss: 2.3650 - accuracy: 0.1000 - val_loss: 2.2994 - val_accuracy: 0.1538

2.4 模型评估

# 指定测试数据
net.evaluate(test_images,test_labels,verbose=1)

输出为:

4/4 [==============================] - 5s 1s/step - loss: 2.2955 - accuracy: 0.1016 [2.2955007553100586, 0.1015625]

如果我们使用整个数据集训练网络,并进行评估的结果

[0.31822608125209806, 0.8855]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1509486.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OKHttpRetrofit

完成一个get请求 1.导入依赖 implementation("com.squareup.okhttp3:okhttp:3.14.")2.开启viewBinding android.buildFeatures.viewBinding true 3.加网络权限 和 http明文请求允许配置文件 <?xml version"1.0" encoding"utf-8"?> &l…

【精选】30+Redis面试题整理(2024)附答案

目录 前言Redis基础项目有用到redis吗&#xff1f;你们项目为什么用redis?redis为什么这么快&#xff1f;了解Redis的线程模型吗&#xff1f;Redis优缺点?redis如何实现持久化&#xff1f;RDB持久化过程&#xff1f;AOF持久化过程&#xff1f;AOF持久化会出现阻塞吗&#xff…

[Angular 基础] - 表单:响应式表单

[Angular 基础] - 表单&#xff1a;响应式表单 之前的笔记&#xff1a; [Angular 基础] - routing 路由(下) [Angular 基础] - Observable [Angular 基础] - 表单&#xff1a;模板驱动表单 开始 其实这里的表单和之前 Template-Driven Forms 没差很多&#xff0c;不过 Tem…

vue-创建vue项目记录

安装node.js 先安装node.js的运行环境node.js的下载地址 安装后就可以使用npm命令 1、清除npm缓存&#xff1a;npm cache clean --force 2、禁用SSL&#xff1a;npm config set strict-ssl false 3、手动设置npm镜像源&#xff1a;npm config set registry https://registry.…

Python AI 之Stable-Diffusion-WebUI

Stable-Diffusion-WebUI简介 通过Gradio库&#xff0c;实现Stable Diffusion web 管理接口 Windows 11 安装Stable-Diffusion-WebUI 个人认为Stable-Diffusion-WebUI 官网提供的代码安装手册/自动安装不适合新手安装&#xff0c;我这边将一步步讲述我是如何搭建Python Conda…

centos 系统 yum 无法安装(换国内镜像地下)

centos 系统 yum 因为无法连接到国外的官网而无法安装&#xff0c;问题如下图&#xff1a; 更换阿里镜像&#xff0c;配置文件路径&#xff1a;/etc/yum.repos.d/CentOS-Base.repo&#xff08;如果目录有多余的文件可以移动到子目录&#xff0c;以免造成影响&#xff09; bas…

php CI框架异常报错通过钉钉自定义机器人发送

php CI框架异常报错通过钉钉自定义机器人发送 文章目录 php CI框架异常报错通过钉钉自定义机器人发送前言一、封装一个异常监测二、封装好钉钉信息发送总结 前言 我们在项目开发中&#xff0c;经常会遇到自己测试接口没问题&#xff0c;上线之后就会测出各种问题&#xff0c;主…

弹性盒子布局 Flexbox Layout

可以嵌套下去 1.display 属性 默认行排列 <style>.flex-item{ height: 20px;width: 10px;background-color: #f1f1f1;margin: 10px;}</style> </head> <body> <div class"flex-container"><div class"flex-item">1&l…

ajax异步访问及跨域处理

文章目录 1 认识同步和异步1.1 什么是同步交互1.2 什么是异步交互 2 AJAX介绍3 案例开发之验证用户名4 JSON格式4.1 响应普通文本数据4.2 JSON的介绍和应用4.3 JSON 与 JS 对象的关系4.4 JSON 和 JS 对象互转4.5 GSON工具类的使用 5 AJAX结合jQuery实现5.1 jQuery.ajax()的简单…

问题解决:NPM 安装 TypeScript出现“sill IdealTree buildDeps”

一、原因&#xff1a; 使用了其他镜像&#xff08;例如我使用了淘宝镜像 npm config set registry https://registry.npm.taobao.org/ &#xff09; 二、解决方法&#xff1a; 1.切换为原镜像 npm config set registry https://registry.npmjs.org 安装typescript npm i …

vscode设置setting.json

{ // vscode默认启用了根据文件类型自动设置tabsize的选项 "editor.detectIndentation": false, // 重新设定tabsize "editor.tabSize": 2, // #每次保存的时候自动格式化 // "editor.formatOnSave": true, // #每次保存的时候将代码按eslint格式…

java学习(集合)

一.集合(主要是单列集合和双列集合) 1.集合的框架体系&#xff08;两大类&#xff09; 2.collection接口是实现类的特点&#xff1a; 1)collection实现子类可以存放多个元素&#xff0c;每个元素可以是Object 2)有效Collection的实现类&#xff0c;可以存放重复的元素&#…

交叉编译x264 zlib ffmpeg以及OpenCV等 以及解决交叉编译OpenCV时ffmpeg始终为NO的问题

文章目录 环境编译流程nasm编译x264编译zlib编译libJPEG编译libPNG编译libtiff编译 FFmpeg编译OpenCV编译问题1解决方案 问题2解决方案 总结 环境 系统&#xff1a;Ubutu 18.04交叉编译链&#xff1a;gcc-arm-10.2-2020.11-x86_64-aarch64-none-linux-gnu 我的路径/opt/toolch…

解释“RNN encode-decode”

“RNN encode-decode” 涉及使用循环神经网络&#xff08;Recurrent Neural Network&#xff0c;RNN&#xff09;来执行编码和解码操作。这种结构常用于处理序列数据&#xff0c;例如自然语言处理、语音识别和时间序列预测等任务。 以下是 “RNN encode-decode” 的一般概念&a…

week07day01(窗口函数)

一. 窗口函数的定义和一些规范&#xff1a; 对数据进行分区&#xff0c;数据的样式是不改变的&#xff0c;但是会多添加一列。窗口函数只能写在"结果集"中。 二. 排名函数 1. rank() over() 例题&#xff1a;对每个人的消费金额进行排名&#xff1a; rank() …

基于Android的教学课程系统设计与开发

摘 要 移动应用已经成为人们生活必不可缺的一部分&#xff0c;大学生身为移动应用的最大用户群体&#xff0c;在生活学习娱乐各个方面都与移动应用有着紧密联系&#xff0c;然而针对大学生校园学习的移动应用却寥寥无几&#xff0c;因为不同的学校&#xff0c;甚至不同的院系&…

unity显示当前时间

1建立文本组件和一个空对象 2创建一个脚本并复制下面代码 using System.Collections; using System.Collections.Generic; using TMPro; using UnityEngine;public class showtime: MonoBehaviour {public TextMeshProUGUI time;private void Update(){string currentTime Sy…

抖音开放平台第三方开发,实现代小程序备案申请

大家好&#xff0c;我是小悟 抖音小程序备案整体流程总共分为五个环节&#xff1a;备案信息填写、平台初审、工信部短信核验、通管局审核和备案成功。 服务商可以代小程序发起备案申请。在申请小程序备案之前&#xff0c;需要确保小程序基本信息已填写完成、小程序至少存在一个…

Clearview X for mac v3.5.0 电子书阅读器 兼容 M1/M2/M3

应用介绍 Clearview X 是 macOS 上的一款简洁易用且美观大方的电子书阅读器。直观好用的图书管理功能&#xff0c;支持 PDF, Epub, MOBI, CHM, FB2, CBR, CBZ 等流行的电子书格式&#xff0c;可以方便地添加注解&#xff0c;插入书签&#xff0c;及迅速的搜索查找。支持在不同…

基于Springboot的高校竞赛管理系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的高校竞赛管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构…