TensorFlow入门图像分类-猫狗分类-MobileNet优化

news2024/9/30 1:26:08

        在上一篇文章中《Tensorflow入门图像分类-猫狗分类-安卓》,介绍了使用TensorFlow训练一个猫狗图像分类器的模型并在安卓应用上使用的全过程。

        在这一篇文章中,将采用 MobileNet 来重新训练一个猫狗图像分类器。

一、 MobileNet 介绍 

        MobileNet是一种轻量级的神经网络架构,主要用于移动和嵌入式设备上的计算机视觉应用。它由Google Brain团队开发,旨在通过减少模型参数数量和计算复杂性来实现高效的图像分类、目标检测和语义分割等任务。

        MobileNet采用了深度可分离卷积(depthwise separable convolution)来替代传统卷积操作,从而大幅降低了计算成本。深度可分离卷积将卷积操作分为两个步骤:首先对每个输入通道进行单独的空间卷积,然后再对通道之间的结果进行逐点卷积。这种方法可以显著地减少模型中的参数数量和计算量,并且可以在保持较高准确率的同时,将模型压缩到原始模型的很小部分。

        MobileNet还使用了全局平均池化层来代替全连接层,以进一步减少模型大小和计算复杂度。此外,它还引入了线性瓶颈结构和批规范化(batch normalization)技术来提高性能和稳定性。

        总体来说,MobileNet是一种非常有效的神经网络架构,可以在移动和嵌入式设备上实现高效的计算机视觉应用。

二、数据准备

        本实例依然是采用上一篇文章介绍的猫狗数据集。链接:Download Kaggle Cats and Dogs Dataset from Official Microsoft Download Center

        同样需要把数据集中的脏数据清除掉:

import os
from PIL import Image
 
# Set the directory to search for corrupted files
directory = 'path/to/directory'
 
# Loop through all files in the directory
for filename in os.listdir(directory):
 
    # Check if file is an image
    if filename.endswith('.jpg') or filename.endswith('.png'):
        
        # Attempt to open image with PIL
        try:
            img = Image.open(os.path.join(directory, filename))
            img.verify()
            img.close()
        except (IOError, SyntaxError) as e:
            print(f"Deleting {filename} due to error: {e}")
            os.remove(os.path.join(directory, filename))

        接着需要把数据集分为:训练集、测试集、验证集。

        参考代码:

# 将数据集分为训练集、测试集、验证集
import os
import shutil
import numpy as np
 
train_dir = os.path.join(os.path.dirname(dataset_dir), 'PetImages_train')
val_dir = os.path.join(os.path.dirname(dataset_dir), 'PetImages_validation')
test_dir = os.path.join(os.path.dirname(dataset_dir), 'PetImages_test')
 
if not os.path.exists(train_dir):
    train_ratio = 0.7  # 训练集比例
    val_ratio = 0.15   # 验证集比例
    test_ratio = 0.15  # 测试集比例
 
    classfiers = ['Cat', 'Dog']
    for cls in classfiers:
        # 获取数据集中所有文件名
        filenames = os.listdir(os.path.join(dataset_dir, cls))
        
        # 计算拆分后的数据集大小
        num_samples = len(filenames)
        num_train = int(num_samples * train_ratio)
        num_val = int(num_samples * val_ratio)
        num_test = num_samples - num_train - num_val
        
        # 将文件名打乱顺序
        shuffle_indices = np.random.permutation(num_samples)
        filenames = [filenames[i] for i in shuffle_indices]
        
        os.makedirs(os.path.join(train_dir, cls), exist_ok=True)
        os.makedirs(os.path.join(val_dir, cls), exist_ok=True)
        os.makedirs(os.path.join(test_dir, cls), exist_ok=True)
    
        # 拆分数据集并复制文件到相应目录
        for i in range(num_train):
            src_path = os.path.join(dataset_dir, cls, filenames[i])
            dst_path = os.path.join(train_dir, cls, filenames[i])
            shutil.copy(src_path, dst_path)
 
        for i in range(num_train, num_train+num_val):
            src_path = os.path.join(dataset_dir, cls, filenames[i])
            dst_path = os.path.join(val_dir, cls, filenames[i])
            shutil.copy(src_path, dst_path)
 
        for i in range(num_train+num_val, num_samples):
            src_path = os.path.join(dataset_dir, cls, filenames[i])
            dst_path = os.path.join(test_dir, cls, filenames[i])
            shutil.copy(src_path, dst_path)

三、模型训练

3.1 准备

import os
import tensorflow as tf
from tensorflow.keras import layers

3.2 数据加载和数据增强

# 加载数据集
train_dir = 'I:/数据集/kagglecatsanddogs_5340/PetImages_train'
val_dir = 'I:/数据集/kagglecatsanddogs_5340/PetImages_validation'
test_dir = 'I:/数据集/kagglecatsanddogs_5340/PetImages_test'
batch_size = 32
image_size = 224

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(image_size, image_size),
    batch_size=batch_size,
    class_mode='categorical')

val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)


validation_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(image_size, image_size),
    batch_size=batch_size,
    class_mode='categorical')

        这段代码主要用于数据预处理和数据生成器的创建,用于训练和验证深度学习模型。

        首先,指定了训练、验证和测试数据集所在的文件夹路径。batch_size变量表示每个批次的图像数,image_size变量表示将图像调整为的统一大小。

        然后,创建一个ImageDataGenerator对象,train_datagen,来对输入图像进行数据增强(rescale、shear_range、zoom_range和horizontal_flip),从而扩充我们的训练数据集。

        接着,使用train_datagen.flow_from_directory函数来创建一个训练集的数据生成器train_generator,它将自动从train_dir目录中读取图像,并实时地将其转换为张量批次,以便于训练模型。class_mode参数设为'categorical'表示使用分类标签。

        同样的,也创建了一个ImageDataGenerator对象val_datagen,只应用了图像缩放操作。接着使用val_datagen.flow_from_directory函数创建了一个验证集的数据生成器validation_generator,class_mode参数也设置为'categorical',以便于评估模型的精度。

3.3 模型设计

input_shape = (image_size, image_size, 3)
num_classes = 2

# 加载预训练模型
base_model = tf.keras.applications.MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')

# 冻结前面的层
for layer in base_model.layers:
    layer.trainable = False

# 添加新的全连接层
x = base_model.output
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
predictions = layers.Dense(num_classes, activation='sigmoid')(x)

# 构造完整模型
model = tf.keras.models.Model(inputs=base_model.input, outputs=predictions)
    
# 编译模型
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.0001), loss='binary_crossentropy', metrics=['accuracy'])

        首先,指定了输入图像的形状(input_shape)和分类数目(num_classes)。

        然后,使用tf.keras.applications.MobileNetV2函数加载了一个已经预先训练好的MobileNetV2模型,该模型作为神经网络的基础架构。include_top参数被设置为False,表示只需要模型的卷积部分,而不需要其全连接层。weights参数被设置为'imagenet',表示使用在ImageNet上预训练的权重值。

        接着,将MobileNetV2模型中的所有层都冻结起来,即将它们的trainable属性设为False,这样在训练过程中它们的权重将不会被更新。

        添加新的全连接层,其中 x = base_model.output 表示将MobileNetV2的输出作为新模型的输入;layers.GlobalAveragePooling2D() 将每个特征图中所有位置的值取平均值,得到一个固定长度的向量;Dense(128, activation='relu') 表示添加一个包含128个神经元的全连接层,并使用ReLU激活函数进行非线性变换;layers.Dropout(0.5) 表示在全连接层后面加上一个dropout层,以减少过拟合风险;最后一层layer.Dense(num_classes, activation='sigmoid')则根据我们的分类任务,使用sigmoid激活函数输出概率值。

        最后,使用tf.keras.models.Model将MobileNetV2模型和新添加的全连接层拼接在一起,生成完整的深度学习模型。对生成的模型进行编译,设置优化器(Adam)和损失函数(binary_crossentropy),并指定评估指标(accuracy)。

3.4 模型训练

# 训练模型
epochs = 5
steps_per_epoch = train_generator.n // batch_size
validation_steps = validation_generator.n // batch_size

history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps)

        首先,指定了epochs参数表示要遍历整个训练集的次数。steps_per_epoch表示每个epoch中要进行的步骤数,这里通过train_generator.n // batch_size来确定。同样地,validation_steps也是通过validation_generator.n // batch_size来确定。

        接着,使用model.fit函数开始训练模型。train_generator和validation_generator分别是训练集和验证集的数据生成器;steps_per_epoch、epochs和validation_steps则是前面定义好的参数;validation_data表示使用哪个数据集用来做验证;history = model.fit返回一个History对象,包含了训练过程中的损失值和评估指标等信息。

        在模型训练期间,每个epoch都会将训练集中的所有样本送入模型进行训练,并将验证集的结果返回给我们,以便于我们查看模型的性能表现。最后,我们可以使用得到的History对象来分析模型在训练和验证阶段的表现情况,并据此进行调优。

代码输出:
Epoch 1/10 546/546 [=======] - 135s 239ms/step - loss: 0.1517 - accuracy: 0.9488 - val_loss: 0.0612 - val_accuracy: 0.9797

Epoch 2/10 546/546 [=======] - 125s 230ms/step - loss: 0.0720 - accuracy: 0.9749 - val_loss: 0.0544 - val_accuracy: 0.9826

Epoch 3/10 546/546 [=======] - 128s 234ms/step - loss: 0.0615 - accuracy: 0.9788 - val_loss: 0.0531 - val_accuracy: 0.9818

Epoch 4/10 546/546 [=======] - 124s 228ms/step - loss: 0.0557 - accuracy: 0.9805 - val_loss: 0.0525 - val_accuracy: 0.9810

Epoch 5/10 546/546 [=======] - 126s 231ms/step - loss: 0.0510 - accuracy: 0.9802 - val_loss: 0.0499 - val_accuracy: 0.9829

说明:

  • 从上面的输出来看,在训练的第一个 epoch 结束的时候,模型的 accuracy 就已经达到了 94%,作为对比,在上一篇文章中,完整从头训练的情况下,15个 epoch 才能达到 90% 的 accuracy。
  • 这说明了使用 MobileNet 作为基础模型,可以有效地提高准确率,缩短训练时间。

        由于使用了预训练的 MobileNet 模型作为基础,那么它的权重已经经过了充分的训练和调整,具有较强的特征提取能力,因此在猫狗图像分类任务中表现良好是可以预料的。而如果从零开始训练模型,则需要更多的样本和更长的训练时间才能达到类似的性能。

3.5 评估模型

# 评估模型在测试集上的性能
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(image_size, image_size),
    batch_size=batch_size,
    class_mode='categorical')

# 评估模型在测试集上的性能
loss, acc = model.evaluate(test_generator)
print(f'Test loss: {loss}, Test accuracy: {acc}')

代码输出:

Found 3752 images belonging to 2 classes.

49/118 [=====>..................] - ETA: 2s - loss: 0.0645 - accuracy: 0.9758

118/118 [==============] - 5s 44ms/step - loss: 0.0590 - accuracy: 0.9784

Test loss: 0.05904101952910423, Test accuracy: 0.9784114956855774

说明:

  • 从输出信息来看,无论是训练时的准确率、还是测试的准备率,都已经达到了97%以上,可以说是非常高的了。

3.6 模型保存

        保存为 tf 
 

# 保存模型参数
model.save('cat_dog_classfier_v2.tf', overwrite=True, include_optimizer=True)

        保存为 tflite 格式,以便在移动端上使用:

# 导出TensorFlow Lite模型
covertor = tf.lite.TFLiteConverter.from_keras_model(model)
covertor.optimizations = [tf.lite.Optimize.DEFAULT]
tflife_model = covertor.convert()

with open('cat_dog_classfier_v2.tflite', 'wb') as f:
  f.write(tflife_model)

        导出的tflite文件 如下:

       从 上图对比得出:使用 MobileNet 训练的 v2 模型,只有 2.54 MB,而不使用 MobileNet 的 v1 版本,有 10 MB 。

四、总结

        本文介绍了采用 MobileNet 作为神经网络的基础架构来训练猫狗图像分类器的方法,该方法十分适合移动端。它不但减少了训练时间、提高了准确率,同时还减少了模型文件大小。

        .

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/485826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

服务(第十六篇)mysql①基础

什么是数据库? 数据: ①描述事物的符号记录称为数据(Data),数字、文字、图形、图像、声音、档案记录等; ②数据是以“记录”的形式按照统一的格式进行存储的,而不是杂乱无章的。 行&#xf…

35岁以10亿美元身价登上《财富》杂志亿万富豪榜的电商传奇谢家华

Zappos的介绍 Zappos可谓是电商的传奇,国内同类电商是乐淘。Zappos是一家在线卖鞋和服装的公司,1999年创立,2009年被亚马逊以12亿元收购, 多次入选财富杂志最佳雇主公司top100。 Zappos的创始人及CEO 提到Zappos就不得不介绍下…

SQL知识汇总

什么时候用存储过程合适 当一个事务涉及到多个SQL语句时或者涉及到对多个表的操作时就要考虑用存储过程;当在一个事务的完成需要很复杂的商业逻辑时(比如,对多个数据的操作,对多个状态的判断更改等)要考虑&#xff1b…

05.rabbitMQ之搭建的各种坑

1.持久化需要重新设置队列 2.异步发布确认的坑, 生产者发消息太快只会确认最大的编号 1.消费者还是要确认消息 channel.basicAck(message.getEnvelope().getDeliveryTag(), false); 因为你发送的太快了,只会返回成功接收的最大的编号 3.消费者消息堆积(开启了消息手…

InnoDB 磁盘结构之数据字典和双写缓冲区

数据字典(InnoDB Data Dictionary) MySQL中,数据字典包括了: 表结构、数据库名或表名、字段的数据类型、视图、索引、表字段信息、MySQL版本信息、存储过程、触发器等内容 InnoDB数据字典由内部系统表组成,这些表包含用于查找表…

7万字水利数字孪生工程解决方案(word可编辑)

本资料来源公开网络,仅供个人学习,请勿商用,如有侵权请联系删除。 1.1 系统开发方案 1.1.1 系统设计开发思路 (1)基于层次分解的设计 xxx水利数字孪生工程将采用基于层次分解的系统模型,系统采用这种方式进行层次划…

【P5】JMeter CSV Data Set Config(CSV 数据文件设置)

文章目录 一、测试计划演示二、CSV Data Set Config(CSV 数据文件设置)主要参数说明2.1、忽略首行:True2.2、是否允许带引号?:False2.3、遇到文件结束符再次循环?:False2.4、遇到文件结束符停止…

Apache 可能会出手接盘 Google Wave

尽管Google计划在明年终止Google Wave项目,但他们提供Wave in a Box开源项目允许你在自己的服务器上跑一个Google Wave服务玩。据The Register报道,Apache Software Group正在试图将Wave in a Box移植到目前的管理系统里。尽管目前还处于早期孵化阶段&am…

AI模型部署概述

心口如一,犹不失为光明磊落丈夫之行也。——梁启超 文章目录 :smirk:1. AI模型部署方法:blush:2. AI模型部署框架ONNXNCNNOpenVINOTensorRTMediapipe如何选择 :satisfied:3. AI模型部署平台 😏1. AI模型部署方法 在AI深度学习模型的训练中,…

链游“风暴之年”已来 一文解读Web3游戏的前生今世

链上世界进入游戏市场,让越来越多游戏厂商不由得感叹区块链游戏(简称“链游”)的风暴之年正在加速到来。如今,游戏活动转变了单一的休闲娱乐理念,逐渐走向Web3发展个性化、可定义的未来。 前不久,阿里云作为…

S3C6410 中的 irqdomain 之 gpio

文章目录 VIC domain 与 gpio domain 的硬件拓扑图描述linux cascaded irq domainlinux irq domain 实例VIC domain 与 gpio domain 的硬件拓扑语言描述VIC 与 INT_EINTx 的关系INT_EINTx 与 GPIO的关系INT_EINT0INT_EINT1INT_EINT2INT_EINT3INT_EINT4INT_EINT4 与 External in…

【Elasticsearch】DSL操作相关

文章目录 DSL操作索引操作新建索引查询索引查看所有索引删除索引 映射操作创建映射查看映射索引映射关联(同创建映射类似) 文档操作创建文档查询指定ID文档查询所有文档全局修改文档局部修改文档删除文档条件删除 数据搜索数据准备条件查询(match)多字段条件查询(multi_match)关…

VMware 虚拟机中 Linux 系统Centos7磁盘空间扩容(亲测)

1.修改虚拟机磁盘容量 例如之前虚拟机磁盘空间为30G,现要将磁盘容量设置为50G 打开虚拟机(必须处于关机状态),点击【编辑虚拟机设置】,然后点击【磁盘】,接着点击【扩展】,输入修改后的最大磁盘…

LangChain入门(二)-通过 Google 搜索并返回答案

GitHub - liaokongVFX/LangChain-Chinese-Getting-Started-Guide: LangChain 的中文入门教程LangChain 的中文入门教程. Contribute to liaokongVFX/LangChain-Chinese-Getting-Started-Guide development by creating an account on GitHub.https://github.com/liaokongVFX/La…

js实现继承属性和方法

js实现继承属性和方法 1 使用extends实现继承2 原型链继承3 组合继承4 寄生组合继承5 实例继承6 拷贝继承7 扩展7.1 函数中方法定义在函数内部、函数外、prototype上的区别7.2 class创建实例与构造函数创建实例 首先定义一个父类 function Animal (name, age) {this.name nam…

Java Web案例:实现用户登录功能

文章目录 零、本节学习目标一、纯JSP方式实现用户登录功能(一)实现思路(二)实现步骤1、创建Web项目2、创建登录页面3、创建登录处理页面4、创建登录成功页面5、创建登录失败页面6、编辑项目首页 (三)测试结…

【JavaEE】CSS基础知识

文章目录 1.CSS概念1.1CSS是干啥的?1.2基础语法规范1.2基础语法规范1.3引入格式✨内部样式表✨行内样式表✨外部样式(最常用的样式) 1.4代码风格✨样式格式✨样式大小写 2.选择器2.1选择器的功能2.2基础选择器有哪些?&#x1f6e0…

GPT 告诉你请求到达 Tomcat 是怎么处理的

tomcat如何监听请求到达 没有SpringMVC,tomcat 如何处理请求 Tomcat 线程池的作用是什么 如何配置tomcat 线程池 tomcat 线程池的主要任务是处理连接请求 tomcat线程池是怎么实现的 到这里可以看出来,tomcat线程池的实现方式也是通过ThreadPoolExecutor 实现 如何根…

Git 使用教程整理

一、配置Git 编码为utf-8 设置登陆账号 使用Git GUI操作 二、获取远程仓库代码 推荐使用使用 git bash 命令:git clone xxx git clone https://github.com/jeromeetienne/jquery-qrcode.git 其他参考:使用Git获取最新版本到本地_gitgui 获取新版本_天…

【Java开发】Spring Cloud 11 :Gateway 配置 ssl 证书(https、http 访问)

最近研究给微服务项目配置 ssl 证书,如此才可以对接微信小程序(需要使用 https 请求)。传统单体项目来说,首先往项目中添加证书文件,然后在配置文件中配置 ssl 证书路径、密码等相关信息;那么微服务这么多项…