Python做个猫狗识别系统,给人美心善的邻居

news2024/10/5 19:20:27

嗨害大家好鸭!我是爱摸鱼的芝士❤

请添加图片描述

宠物真的看着好治愈

谁不想有一只属于自己的乖乖宠物捏~

这篇文章中我放弃了以往的model.fit()训练方法,
改用model.train_on_batch方法。

两种方法的比较:

  • model.fit():用起来十分简单,对新手非常友好
  • model.train_on_batch():封装程度更低,可以玩更多花样。

此外我也引入了进度条的显示方式,更加方便我们及时查看模型训练过程中的情况,可以及时打印各项指标。

🚀 我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1
  • 显卡(GPU):NVIDIA GeForce RTX 3080

请添加图片描述

一、前期工作

1. 设置GPU

如果使用的是CPU可以注释掉这部分的代码。

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
 
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  
    tf.config.set_visible_devices([gpus[0]],"GPU")
 
print(gpus)
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2. 导入数据

import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  
import os,PIL
import numpy as np
np.random.seed(1)
import tensorflow as tf
tf.random.set_seed
import warnings
warnings.filterwarnings('ignore')
 
import pathlib
data_dir = "./data/train"
data_dir = pathlib.Path(data_dir)

3. 查看数据

image_count = len(list(data_dir.glob('*/*')))
 
print("图片总数为:",image_count)
图片总数为:3400

请添加图片描述

二、数据预处理

1. 加载数据

使用image_dataset_from_directory
方法将磁盘中的数据加载到tf.data.Dataset中

batch_size = 8
img_height = 224
img_width = 224

TensorFlow版本是2.2.0的同学可能会遇到
module ‘tensorflow.keras.preprocessing’ has no attribute 'image_dataset_from_directory’的报错,
升级一下TensorFlow就OK了

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 3400 files belonging to 2 classes.
Using 2720 files for training.

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3400 files belonging to 2 classes.
Using 680 files for validation.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)
['cat', 'dog']

2. 再次检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(8, 224, 224, 3)
(8,)

Image_batch是形状的张量(8, 224, 224, 3)。这是一批形状224x224x3的8张图片(最后一维指的是彩色通道RGB)。

Label_batch是形状(8,)的张量,这些标签对应8张图片

3. 配置数据集

  • shuffle() :打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE
 
def preprocess_image(image,label):
    return (image/255.0,label)
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
 
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

如果报 AttributeError: module ‘tensorflow._api.v2.data’ has no attribute ‘AUTOTUNE’ 错误,就将 AUTOTUNE = tf.data.AUTOTUNE 更换为 AUTOTUNE = tf.data.experimental.AUTOTUNE,这个错误是由于版本问题引起的。

4. 可视化数据

plt.figure(figsize=(15, 10)) 
 
for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(5, 8, i + 1) 
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

请添加图片描述

请添加图片描述

三、构建VG-16网络

VGG优缺点分析:

  • VGG优点

VGG的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)。

  • VGG缺点

1)训练时间过长,调参难度大。2)需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcX与predictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

请添加图片描述
请添加图片描述

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout
 
def VGG16(nb_classes, input_shape):
    input_tensor = Input(shape=input_shape)
    # 1st block
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)
    # 2nd block
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)
    # 3rd block
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)
    # 4th block
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)
    # 5th block
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)
    # full connection
    x = Flatten()(x)
    x = Dense(4096, activation='relu',  name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)
 
    model = Model(input_tensor, output_tensor)
    return model
 
model=VGG16(1000, (img_width, img_height, 3))
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

请添加图片描述

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 评价函数(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
model.compile(optimizer="adam",
              loss     ='sparse_categorical_crossentropy',
              metrics  =['accuracy'])

请添加图片描述

五、训练模型

from tqdm import tqdm
import tensorflow.keras.backend as K
 
epochs = 10
lr     = 1e-4
 
# 记录训练数据,方便后面的分析
history_train_loss     = []
history_train_accuracy = []
history_val_loss       = []
history_val_accuracy   = []
 
for epoch in range(epochs):
    train_total = len(train_ds)
    val_total   = len(val_ds)
    
    """
    total:预期的迭代数目
    ncols:控制进度条宽度
    mininterval:进度更新最小间隔,以秒为单位(默认值:0.1)
    """
    with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}',mininterval=1,ncols=100) as pbar:
        
        lr = lr*0.92
        K.set_value(model.optimizer.lr, lr)
        
        for image,label in train_ds:      
            """
            训练模型,简单理解train_on_batch就是:它是比model.fit()更高级的一个用法
            
            想详细了解 train_on_batch 的同学,
            可以看看我的这篇文章:https://mtyjkh.blog.csdn.net/article/details/119506151
            """
            history = model.train_on_batch(image,label)
            
            train_loss     = history[0]
            train_accuracy = history[1]
            
            pbar.set_postfix({"loss": "%.4f"%train_loss,
                              "accuracy":"%.4f"%train_accuracy,
                              "lr": K.get_value(model.optimizer.lr)})
            pbar.update(1)
        history_train_loss.append(train_loss)
        history_train_accuracy.append(train_accuracy)
            
    print('开始验证!')
    
    with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}',mininterval=0.3,ncols=100) as pbar:
 
        for image,label in val_ds:      
            
            history = model.test_on_batch(image,label)
            
            val_loss     = history[0]
            val_accuracy = history[1]
            
            pbar.set_postfix({"loss": "%.4f"%val_loss,
                              "accuracy":"%.4f"%val_accuracy})
            pbar.update(1)
        history_val_loss.append(val_loss)
        history_val_accuracy.append(val_accuracy)
            
    print('结束验证!')
    print("验证loss为:%.4f"%val_loss)
    print("验证准确率为:%.4f"%val_accuracy)
Epoch 1/10: 100%|████████| 340/340 [00:23<00:00, 14.36it/s, loss=1.1077, accuracy=0.6250, lr=9.2e-5]
开始验证!
Epoch 1/10: 100%|█████████████████████| 85/85 [00:02<00:00, 36.55it/s, loss=0.9331, accuracy=0.6250]
结束验证!
验证loss为:0.9331
验证准确率为:0.6250
Epoch 2/10: 100%|███████| 340/340 [00:19<00:00, 17.49it/s, loss=0.4633, accuracy=0.6250, lr=8.46e-5]

......

Epoch 9/10: 100%|███████| 340/340 [00:19<00:00, 17.36it/s, loss=0.0112, accuracy=1.0000, lr=4.72e-5]
开始验证!
Epoch 9/10: 100%|█████████████████████| 85/85 [00:01<00:00, 43.46it/s, loss=0.0302, accuracy=1.0000]
结束验证!
验证loss为:0.0302
验证准确率为:1.0000
Epoch 10/10: 100%|██████| 340/340 [00:19<00:00, 17.22it/s, loss=0.0000, accuracy=1.0000, lr=4.34e-5]
开始验证!
Epoch 10/10: 100%|████████████████████| 85/85 [00:02<00:00, 42.15it/s, loss=0.0231, accuracy=1.0000]
结束验证!
验证loss为:0.0231
验证准确率为:1.0000
# 这是我们之前的训练方法。
# history = model.fit(
#     train_ds,
#     validation_data=val_ds,
#     epochs=epochs
# )

请添加图片描述

六、模型评估

epochs_range = range(epochs)
 
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
 
plt.plot(epochs_range, history_train_accuracy, label='Training Accuracy')
plt.plot(epochs_range, history_val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, history_train_loss, label='Training Loss')
plt.plot(epochs_range, history_val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

请添加图片描述

请添加图片描述

七、保存and加载模型

这是最简单的模型保存与加载方法哈

# 保存模型
model.save('model/21_model.h5')
# 加载模型
new_model = tf.keras.models.load_model('model/21_model.h5')

八、预测

plt.figure(figsize=(18, 3))  
plt.suptitle("预测结果展示")
 
for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(1,8, i + 1)  

        plt.imshow(images[i].numpy())

        img_array = tf.expand_dims(images[i], 0) 
       
        predictions = new_model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])
 
        plt.axis("off")

请添加图片描述

今天的文章就是这样啦~

祝大家早日有属于自己的爱宠~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/411226.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【剧前爆米花--爪哇岛寻宝】java文件操作和io流

作者&#xff1a;困了电视剧 专栏&#xff1a;《JavaEE初阶》 文章分布&#xff1a;这是一篇关于文件操作的文件&#xff0c;介绍了文件读写以及相关对象的内容&#xff0c;希望对你有所帮助&#xff01; 目录 文件操作 文件路径 绝对路径 相对路径 File类 File类的构造方…

REDIS Hash 槽 与 一致性hash

开头还是介绍一下群&#xff0c;如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;CTO&#xff0c;可以解决你的问题。加群请联系 liuaustin3 &#xff0c;在新加的朋友会分到2群&#xff08;共…

Spring boot+Vue3博客平台:文章发布与编辑功能的技术实现

本文将详细介绍如何实现一个博客平台中的文章发布与编辑功能&#xff0c;包括前端的Vue组件设计和后端的Spring Boot接口实现。在阅读本文后&#xff0c;您将了解如何设计和实现高效、易用的文章发布与编辑功能。 一、发布文章 设计思路 在设计文章发布功能时&#xff0c;我们…

vscode中调试rust程序

文章目录一、vscode运行和调式rust程序二、常见问题1.rust: Request textDocument/formatting failed.2.cargo命令3.使用rust-gdb调试rust程序4.cargo build太慢一、vscode运行和调式rust程序 环境&#xff1a;在WSL&#xff08;ubuntu20.04&#xff09;中使用vscode &#xf…

数据技术嘉年华星光璀璨,云和恩墨全栈数据技术能力闪耀会场

导语 2023年4月7-8日&#xff0c;由中国DBA联盟&#xff08;ACDU&#xff09;和墨天轮社区联合主办的第十二届『数据技术嘉年华』&#xff08;DTC 2023&#xff09;在北京成功举办。云和恩墨作为大会的协办方和深度参与者&#xff0c;以6场演讲2大展台全面呈现公司的全栈数据技…

几何算法——4.交线(intersection curve)的表达与参数化、微分性质

几何算法——4.曲面求交的交线&#xff08;intersection curve&#xff09;的表达与参数化、微分性质1 关于曲面求交的交线表达2 交线的微分性质3 交线的参数化4 修正弦长参数化的微分性质1 关于曲面求交的交线表达 两个曲面求交&#xff0c;比较经典的方法是用跟踪法&#xf…

wsl使用vscode搭建自己的MySQL

装wsl装MySQL装wsl 我已经装好了,就不说了 装MySQL 安装 MySQL 服务器&#xff1a;终端命令行输入sudo apt install mysql-server 安装完成后&#xff0c;MySQL 服务器会自动启动并在 Ubuntu 启动时启动。您可以使用以下命令检查 MySQL 服务器是否正在运行&#xff1a;sudo ser…

【三十天精通Vue 3】第六天 Vue 3 计算属性和监听器详解

✅创作者&#xff1a;陈书予 &#x1f389;个人主页&#xff1a;陈书予的个人主页 &#x1f341;陈书予的个人社区&#xff0c;欢迎你的加入: 陈书予的社区 &#x1f31f;专栏地址: 三十天精通 Vue 3 文章目录引言一、Vue 3 计算属性概述1.1 计算属性的简介1.2 计算属性的分类…

第二十章 案例TodoList之动态数据

我们之前已经实现了静态的组件拆分&#xff0c;既然是静态说明数据就是死的&#xff0c;显然这不是我们需要的结果&#xff0c;之前我们学习了React组件&#xff0c;知道组件里面的状态数据驱动了页面的显示&#xff0c;每个组件都有属于自己的状态数据。接下来我们改造组件使得…

SAR ADC系列25:作业和上机实践

作业&#xff1a; 异步SAR逻辑中VALID信号如何产生&#xff1f;答&#xff1a;OUTP和OUTN与非。单纯通过减小“比较器环路”的延时(t1t22*t32*t4)的方式来提升ADC的转换速率可行吗&#xff1f;为什么&#xff1f;答&#xff1a;不可行&#xff0c;还要考虑CDAC建立的速度&…

【ARMv8 编程】A64 数据处理指令——位域字节操作指令

有些指令将字节、半字或字扩展到寄存器大小&#xff0c;可以是 X 或 W。这些指令存在于有符号&#xff08;SXTB、SXTH、SXTW&#xff09;和无符号&#xff08;UXTB、UXTH&#xff09;变体中&#xff0c;并且是适当的位域操作指令。 这些指令的有符号和无符号变体都将字节、半字…

【失业即将到来?】AI时代会带来失业潮吗?

文章目录前言一、全面拥抱AIGC二、AI正在取代这类行业总结前言 兄弟姐妹们啊&#xff0c;AI时代&#xff0c;说抛弃就抛弃&#xff0c;真的要失业了。 一、全面拥抱AIGC 蓝色光标全面暂停外包&#xff1f; 一份文件截图显示&#xff0c;中国知名4A广告公司&#xff0c;蓝色光…

汇编第二次上机实验(续第一次,字符串比较及双重循环)【嵌入式系统】

汇编第二次上机实验&#xff08;续第一次&#xff0c;字符串比较及双重循环&#xff09;【嵌入式系统】前言推荐说明汇编第二次上机实验&#xff08;续第一次&#xff0c;字符串比较及双重循环&#xff09;内容1 sort说明流程图代码编写结果分析2 string流程图代码编写结果分析…

Nginx的安装、反向代理、负载均衡及部署项目

Nginx 一、Nginx简介 Nginx称为:负载均衡器或 静态资源服务器:html,css,js,img ​ Nginx(发音为“engine X”)是俄罗斯人编写的十分轻量级的HTTP服务器,是一个高性能的HTTP和反向代理服务器&#xff0c;同时也是一个IMAP/POP3/SMTP 代理服务器。Nginx是由俄罗斯人 Igor Syso…

DOM(上)

DOM&#xff08;文档对象模型&#xff09;&#xff1a;处理可扩展标记语言(HTML或XML&#xff09;的标准编程接口&#xff0c;可以改变网页的内容、结构和样式。DOM树&#xff1a; …

大数据项目实战之数据仓库:电商数据仓库系统——第2章 数据仓库建模概述

第2章 数据仓库建模概述 2.1 数据仓库建模的意义 如果把数据看作图书馆里的书&#xff0c;我们希望看到它们在书架上分门别类地放置&#xff1b;如果把数据看作城市的建筑&#xff0c;我们希望城市规划布局合理&#xff1b;如果把数据看作电脑文件和文件夹&#xff0c;我们希…

CMake——从入门到百公里加速6.7s

目录 一、前言 二、HelloWorld 三、CMAKE 界面 3.1 gui正则表达式 3.2 GUI构建 四 关键字 4.1 add_library 4.2 add_subdirectory 4.3 add_executable 4.4 aux_source_directory 4.5 SET设置变量 4.6 INSTALL安装 4.7 ADD_LIBRARY 4.8 SET_TARGET_PROPERTIES 4.9…

[JavaEE]----Spring03

文章目录Spring_day031&#xff0c;AOP简介1.1 什么是AOP?1.2 AOP作用1.3 AOP核心概念2&#xff0c;AOP入门案例2.1 需求分析2.2 思路分析2.3 环境准备2.4 AOP实现步骤步骤1:添加依赖步骤2:定义接口与实现类步骤3:定义通知类和通知步骤4:定义切入点步骤5:制作切面步骤6:将通知…

测试-子查询及数据更新

测试-子查询及数据更新 目录测试-子查询及数据更新1、修改borrow表增加一列&#xff1b;修改日期数据&#xff08;两条语句完成&#xff09;题目代码题解2、 SQL更新&#xff1a;删除-删除“吴宾”的所有成绩记录题目代码3、SQL查询&#xff1a;查询没有被订购的商品题目代码4、…

CMake GUI工具使用 MinGW 64构建工程

系列文章目录 文章目录系列文章目录前言一、open Project是灰色&#xff1f;前言 CMake GUI 打开 CMake GUI。 在 “Where is the source code” 字段中&#xff0c;选择 Krita 源代码目录&#xff1a;E:/krita-dev/krita。 在 “Where to build the binaries” 字段中&#x…