卷积神经网络(VGG-16)海贼王人物识别

news2024/9/25 9:28:54

文章目录

  • 前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
      • 我的环境:
    • 2. 导入数据
    • 3. 查看数据
  • 二、数据预处理
    • 1. 加载数据
    • 2. 可视化数据
    • 3. 再次检查数据
    • 4. 配置数据集
    • 5. 归一化
  • 三、构建VGG-16网络
    • 1. 官方模型(已打包好)
    • 2. 自建模型
    • 3. 网络结构图
  • 四、编译
  • 五、训练模型
  • 六、模型评估

前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

2. 导入数据

import matplotlib.pyplot as plt
import os,PIL

# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)

# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)

from tensorflow import keras
from tensorflow.keras import layers,models

import pathlib
data_dir = "weather_photos/"
data_dir = pathlib.Path(data_dir)

3. 查看数据

数据集中一共有路飞、索隆、娜美、乌索普、乔巴、山治、罗宾等7个人物角色

文件夹含义数量
lufei路飞117 张
suolong索隆90 张
namei娜美84 张
wusuopu乌索普77张
qiaoba乔巴102 张
shanzhi山治47 张
luobin罗宾105张
image_count = len(list(data_dir.glob('*/*.jpg')))

print("图片总数为:",image_count)

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

batch_size = 32
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 621 files belonging to 7 classes.
Using 497 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 621 files belonging to 7 classes.
Using 124 files for validation.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)
['lufei', 'luobin', 'namei', 'qiaoba', 'shanzhi', 'suolong', 'wusuopu']

2. 可视化数据

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(2, 4, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

plt.imshow(images[1].numpy().astype("uint8"))

在这里插入图片描述

3. 再次检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(32, 224, 224, 3)
(32,)
  • Image_batch是形状的张量(32,180,180,3)。这是一批形状180x180x3的32张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(32,)的张量,这些标签对应32张图片

4. 配置数据集

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

5. 归一化

normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)
normalization_train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]
# 查看归一化后的数据
print(np.min(first_image), np.max(first_image))
0.0 0.9928046

三、构建VGG-16网络

VGG优缺点分析:

  • VGG优点

VGG的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)

  • VGG缺点

1)训练时间过长,调参难度大。2)需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

1. 官方模型(已打包好)

官网模型调用这块我放到后面几篇文章中,下面主要讲一下VGG-16

# model = keras.applications.VGG16()
# model.summary()

2. 自建模型

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout

def VGG16(nb_classes, input_shape):
    input_tensor = Input(shape=input_shape)
    # 1st block
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)
    # 2nd block
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)
    # 3rd block
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)
    # 4th block
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)
    # 5th block
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)
    # full connection
    x = Flatten()(x)
    x = Dense(4096, activation='relu',  name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)

    model = Model(input_tensor, output_tensor)
    return model

model=VGG16(1000, (img_width, img_height, 3))
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

3. 网络结构图

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcXpredictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

在这里插入图片描述

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型

epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
Epoch 1/20
16/16 [==============================] - 14s 461ms/step - loss: 4.5842 - accuracy: 0.1349 - val_loss: 6.8389 - val_accuracy: 0.1129
Epoch 2/20
16/16 [==============================] - 2s 146ms/step - loss: 2.1046 - accuracy: 0.1398 - val_loss: 6.7905 - val_accuracy: 0.2016
Epoch 3/20
16/16 [==============================] - 2s 144ms/step - loss: 1.7885 - accuracy: 0.3531 - val_loss: 6.7892 - val_accuracy: 0.2903
Epoch 4/20
16/16 [==============================] - 2s 145ms/step - loss: 1.2015 - accuracy: 0.6135 - val_loss: 6.7582 - val_accuracy: 0.2742
Epoch 5/20
16/16 [==============================] - 2s 148ms/step - loss: 1.1831 - accuracy: 0.6108 - val_loss: 6.7520 - val_accuracy: 0.4113
Epoch 6/20
16/16 [==============================] - 2s 143ms/step - loss: 0.5140 - accuracy: 0.8326 - val_loss: 6.7102 - val_accuracy: 0.5806
Epoch 7/20
16/16 [==============================] - 2s 150ms/step - loss: 0.2451 - accuracy: 0.9165 - val_loss: 6.6918 - val_accuracy: 0.7823
Epoch 8/20
16/16 [==============================] - 2s 147ms/step - loss: 0.2156 - accuracy: 0.9328 - val_loss: 6.7188 - val_accuracy: 0.4113
Epoch 9/20
16/16 [==============================] - 2s 143ms/step - loss: 0.1940 - accuracy: 0.9513 - val_loss: 6.6639 - val_accuracy: 0.5968
Epoch 10/20
16/16 [==============================] - 2s 143ms/step - loss: 0.0767 - accuracy: 0.9812 - val_loss: 6.6101 - val_accuracy: 0.7419
Epoch 11/20
16/16 [==============================] - 2s 146ms/step - loss: 0.0245 - accuracy: 0.9894 - val_loss: 6.5526 - val_accuracy: 0.8226
Epoch 12/20
16/16 [==============================] - 2s 149ms/step - loss: 0.0387 - accuracy: 0.9861 - val_loss: 6.5636 - val_accuracy: 0.6210
Epoch 13/20
16/16 [==============================] - 2s 152ms/step - loss: 0.2146 - accuracy: 0.9289 - val_loss: 6.7039 - val_accuracy: 0.4839
Epoch 14/20
16/16 [==============================] - 2s 152ms/step - loss: 0.2566 - accuracy: 0.9087 - val_loss: 6.6852 - val_accuracy: 0.6532
Epoch 15/20
16/16 [==============================] - 2s 149ms/step - loss: 0.0579 - accuracy: 0.9840 - val_loss: 6.5971 - val_accuracy: 0.6935
Epoch 16/20
16/16 [==============================] - 2s 152ms/step - loss: 0.0414 - accuracy: 0.9866 - val_loss: 6.6049 - val_accuracy: 0.7581
Epoch 17/20
16/16 [==============================] - 2s 146ms/step - loss: 0.0907 - accuracy: 0.9689 - val_loss: 6.6476 - val_accuracy: 0.6452
Epoch 18/20
16/16 [==============================] - 2s 147ms/step - loss: 0.0929 - accuracy: 0.9685 - val_loss: 6.6590 - val_accuracy: 0.7903
Epoch 19/20
16/16 [==============================] - 2s 146ms/step - loss: 0.0364 - accuracy: 0.9935 - val_loss: 6.5915 - val_accuracy: 0.6290
Epoch 20/20
16/16 [==============================] - 2s 151ms/step - loss: 0.1081 - accuracy: 0.9662 - val_loss: 6.6541 - val_accuracy: 0.6613

六、模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1227833.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【win32_000】视频截图

PPT 编译器不会自己添加unicode定义 v 函数 WinMain int __clrcall WinMain([in] HINSTANCE hInstance ,//应用程序的当前实例的句柄。[in, optional] HINSTANCE hPrevInstance ,//应用程序上一个实例的句柄。 此参数始终为 NULL。[in] …

文心一言 VS 讯飞星火 VS chatgpt (138)-- 算法导论11.4 2题

二、用go语言,试写出HASH-DELETE 的伪代码;修改 HASH-INSERT,使之能处理特殊值DELETED。 文心一言: 首先,让我们定义一个基本的哈希表数据结构。这个结构将包括一个存储键值对的哈希表和一个存储已删除键值对的队列。我们可以用…

mac控制台命令小技巧

shigen日更文章的博客写手,擅长Java、python、vue、shell等编程语言和各种应用程序、脚本的开发。记录成长,分享认知,留住感动。 hello伙伴们,作为忠实的mac骨灰级别的粉丝,它真的给我带来了很多效率上的提升。那作为接…

mysql练习1

-- 1.查询出部门编号为BM01的所有员工 SELECT* FROMemp e WHEREe.deptno BM01; -- 2.所有销售人员的姓名、编号和部门编号。 SELECTe.empname,e.empno,e.deptno FROMemp e WHEREe.empstation "销售人员";-- 3.找出奖金高于工资的员工。 SELECT* FROMemp2 WHE…

日志维护库:loguru

在复杂的项目中,了解程序的运行状态变得至关重要。在这个过程中,日志记录(logging)成为我们追踪、调试和了解代码执行的不可或缺的工具。在python语言中常用logging日志库,但是logging日志库使用相对繁琐,在…

linux系统环境下mysql安装和基本命令学习

此篇文章为蓝桥云课--MySQL的学习记录 块引用部分为自己的实验部分,其余部分是课程自带的知识,链接如下: MySQL 基础课程_MySQL - 蓝桥云课 本课程为 SQL 基本语法及 MySQL 基本操作的实验,理论内容较少,动手实践多&am…

PMCW体制雷达系列文章(4) – PMCW雷达之抗干扰

说明 本文作为PMCW体制雷达系列文章之一,主要聊聊FMCW&PMCW两种体制雷达的干扰问题。事实上不管是通信领域还是雷达领域,对于一切以电磁波作为媒介的信息传递活动,干扰是无处不在的。近年来,随着雷达装车率的提高,…

Shell条件测试练习

1、取出/etc/passwd文件的第6行; [rootshell ~]# head -6 /etc/passwd | tail -1 sync:x:5:0:sync:/sbin:/bin/sync [rootshell ~]# sed -n 6p /etc/passwd sync:x:5:0:sync:/sbin:/bin/sync [rootshell ~]# awk NR6 /etc/passwd sync:x:5:0:sync:/sbin:/bin/sync 2…

FPGA设计时序约束八、others类约束之Set_Case_Analysis

目录 一、序言 二、Set Case Analysis 2.1 基本概念 2.2 设置界面 2.3 命令语法 2.4 命令示例 三、工程示例 四、参考资料 一、序言 在Vivado的时序约束窗口中,存在一类特殊的约束,划分在others目录下,可用于设置忽略或修改默认的时序…

一文带你详细了解JVM运行时内存

一文带你详细了解JVM运行时内存 1. 程序计数器2. 虚拟机栈3. 本地方法栈4. 堆4.1 堆的总括4.1.1 概念4.1.2 特点4.1.3 设置堆内存大小4.1.4 堆的分类 4.2 新生代和老年代4.2.1 对象存储4.2.2 配置新生代和老年代的堆中占比 4.3 对象分配过程4.4 堆GC 5.元空间6.方法区6.1 方法区…

redis问题归纳

1.redis为什么这么快? (1)基于内存操作:redis的所有数据都存在内存中,因此所有的运算都是内存级别的,所以性能比较高 (2)数据结构简单:redis的数据结构是专门设计的&…

Unity2021及以上 启动或者禁用自动刷新

Unity 2021以以上启动自动刷新 Edit---> Preferences--> Asset Pipline --> Auto Refresh 禁用的结果 如果不启动自动刷新在Project面板选择Refresh是不会刷新已经修改后的脚本的。

MIB 6.S081 System calls(1)using gdb

难度:easy In many cases, print statements will be sufficient to debug your kernel, but sometimes being able to single step through some assembly code or inspecting the variables on the stack is helpful. To learn more about how to run GDB and the common iss…

存储过程与触发器

一、存储过程 1.1 概念 把需要重复执行的内容放在存储过程中,实现代码的复用。 create procedure 创建存储过程的关键字 my_proc1:存储过程的名字。 执行下例代码就是创建了一个存储过程 执行存储过程,就是把上图的插入语句重复执行,现…

蓝桥杯每日一题2023.11.19

题目描述 “蓝桥杯”练习系统 (lanqiao.cn) 题目分析 首先想到的方法为dfs去寻找每一个数&#xff0c;但发现会有超时 #include<bits/stdc.h> using namespace std; const int N 2e5 10; int n, cnt, a[N]; void dfs(int dep, int sum, int start) {if(dep 4){if(s…

linux文件IO

文件IO截断 截断对文件的偏移量没有影响。

Cross-View Transformers for Real-Time Map-View Semantic Segmentation 论文阅读

论文链接 Cross-View Transformers for Real-Time Map-View Semantic Segmentation 0. Abstract 提出了 Cross-View Transformers &#xff0c;一种基于注意力的高效模型&#xff0c;用于来自多个摄像机的地图视图语义分割使用相机感知的跨视图注意机制隐式学习从单个相机视…

git基本用法和操作

文章目录 创建版本库方式&#xff1a;Git常用操作命令&#xff1a;远程仓库相关命令分支(branch)操作相关命令版本(tag)操作相关命令子模块(submodule)相关操作命令忽略一些文件、文件夹不提交其他常用命令 创建版本库方式&#xff1a; 创建文件夹 在目录下 右键 Git Bush H…

Kafka性能测试初探

相信大家对Kafka不会陌生&#xff0c;但首先还是要简单介绍一下。 Kafka是一种高性能的分布式消息系统&#xff0c;由LinkedIn公司开发&#xff0c;用于处理海量的实时数据流。它采用了发布/订阅模式&#xff0c;可以将数据流分发到多个消费者端&#xff0c;同时提供了高可靠性…

Shell判断:流程控制—if(三)

一、调试脚本 1、调试脚本的其他方法&#xff1a; [rootlocalhost ~] # sh -n useradd.sh 仅调试脚本中的语法错误。 [rootlocalhost ~]# sh -vx useradd.sh 以调试的方式执行&#xff0c;查询整个执行过程。 2、示例&#xff1a; [rootlocalhost ~]# sh -n useradd.sh #调…