Tensorflow入门图像分类-猫狗分类-安卓

news2024/10/2 12:20:22

        最近在温习 Tensorflow,写了一篇笔记,记录了使用 Tensorflow 训练一个猫狗图像分类器的模型并在安卓应用上使用的全过程。

一、数据集准备

1.1 数据集来源

        我采用的是微软的猫狗数据集,链接:Download Kaggle Cats and Dogs Dataset from Official Microsoft Download CenterDownload Kaggle Cats and Dogs Dataset from Official Microsoft Download Center

 

1.2 数据集查看

        下载数据集后,解压如下:

         Cat 数据如下:

        Dog数据集如下:

 

 1.3 删除脏数据        

        这个数据集有一些文件是有问题的,需要删除掉:

import os
from PIL import Image

# Set the directory to search for corrupted files
directory = 'path/to/directory'

# Loop through all files in the directory
for filename in os.listdir(directory):

    # Check if file is an image
    if filename.endswith('.jpg') or filename.endswith('.png'):
        
        # Attempt to open image with PIL
        try:
            img = Image.open(os.path.join(directory, filename))
            img.verify()
            img.close()
        except (IOError, SyntaxError) as e:
            print(f"Deleting {filename} due to error: {e}")
            os.remove(os.path.join(directory, filename))

说明:上面的python代码会把目录下的文件给删除掉

二、训练模型

        下面是完整代码,包括数据集划分、数据增强、模型训练、模型评估、模型保存等。

2.1 数据准备

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 定义数据集路径
dataset_dir = 'I:/数据集/kagglecatsanddogs_5340/PetImages'
img_height, img_width = 224, 224
batch_size = 32

2.2 数据集划分

        将数据集划分为训练集、测试集、验证集:

# 将数据集分为训练集、测试集、验证集
import os
import shutil
import numpy as np

train_dir = os.path.join(os.path.dirname(dataset_dir), 'PetImages_train')
val_dir = os.path.join(os.path.dirname(dataset_dir), 'PetImages_validation')
test_dir = os.path.join(os.path.dirname(dataset_dir), 'PetImages_test')

if not os.path.exists(train_dir):
    train_ratio = 0.7  # 训练集比例
    val_ratio = 0.15   # 验证集比例
    test_ratio = 0.15  # 测试集比例

    classfiers = ['Cat', 'Dog']
    for cls in classfiers:
        # 获取数据集中所有文件名
        filenames = os.listdir(os.path.join(dataset_dir, cls))
        
        # 计算拆分后的数据集大小
        num_samples = len(filenames)
        num_train = int(num_samples * train_ratio)
        num_val = int(num_samples * val_ratio)
        num_test = num_samples - num_train - num_val
        
        # 将文件名打乱顺序
        shuffle_indices = np.random.permutation(num_samples)
        filenames = [filenames[i] for i in shuffle_indices]
        
        os.makedirs(os.path.join(train_dir, cls), exist_ok=True)
        os.makedirs(os.path.join(val_dir, cls), exist_ok=True)
        os.makedirs(os.path.join(test_dir, cls), exist_ok=True)
    
        # 拆分数据集并复制文件到相应目录
        for i in range(num_train):
            src_path = os.path.join(dataset_dir, cls, filenames[i])
            dst_path = os.path.join(train_dir, cls, filenames[i])
            shutil.copy(src_path, dst_path)

        for i in range(num_train, num_train+num_val):
            src_path = os.path.join(dataset_dir, cls, filenames[i])
            dst_path = os.path.join(val_dir, cls, filenames[i])
            shutil.copy(src_path, dst_path)

        for i in range(num_train+num_val, num_samples):
            src_path = os.path.join(dataset_dir, cls, filenames[i])
            dst_path = os.path.join(test_dir, cls, filenames[i])
            shutil.copy(src_path, dst_path)

2.3 数据加载和增强处理

# 数据增强处理
train_datagen = ImageDataGenerator(rescale=1./255,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True)

# 加载数据集
train_generator = train_datagen.flow_from_directory(train_dir,
                                                    target_size=(img_height, img_width),
                                                    batch_size=batch_size,
                                                    class_mode='binary')

         这段代码是用于图像分类模型训练的数据预处理部分。其中,ImageDataGenerator类是Keras中用于数据增强的工具,可以对图像进行一系列随机变换来扩充训练集,以提高模型泛化能力。

        在这里,train_datagen对象将使用以下三个参数进行数据增强处理:

  • rescale:将图像像素值缩放到0-1之间,加速模型收敛;
  • shear_range:随机剪切变换范围;
  • zoom_range:随机缩放变换范围;
  • horizontal_flip:随机水平翻转。

        然后通过调用 flow_from_directory 方法来生成一个 train_generator 对象,它会从指定的 train_dir 文件夹中动态生成包含批次大小为 batch_size 的训练数据样本,并将其转成二进制标签格式(class_mode='binary')。同时,所有图像都将被调整为指定的 img_height img_width 大小,以保证输入神经网络的数据形状相同。

代码输出:Found 17498 images belonging to 2 classes.

        

2.4 模型定义

# 定义模型结构
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(img_height, img_width, 3)),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

        这段代码是用于构建卷积神经网络模型的部分,使用了Keras中的Sequential模型,该模型按照一定顺序堆叠不同类型的层。

        具体而言,这个模型包含了如下一系列层:

  • Conv2D层:使用32个3x3的滤波器进行二维卷积,激活函数为ReLU;
  • MaxPooling2D层:使用大小为2x2的窗口进行最大池化操作;
  • 再次添加一个 Conv2D 层:使用64个3x3的滤波器进行卷积,激活函数为ReLU;
  • 再次添加一个 MaxPooling2D 层;
  • 再次添加一个 Conv2D 层:使用128个3x3的滤波器进行卷积,激活函数为ReLU;
  • 再次添加一个 MaxPooling2D 层;
  • Flatten 层:将卷积部分输出的特征图展平成一维向量,以便后续连接全连接层;
  • 全连接 Dense 层:包含128个神经元,激活函数为ReLU;
  • 输出层 Dense:只有一个神经元,用Sigmoid激活函数输出二分类结果。

        这个模型的输入形状是(img_height, img_width, 3),其中3表示RGB图像的三个通道。第一层 Conv2D 和池化层之后,每个卷积层和池化层都会减小特征图的空间大小,并增加滤波器数量来提取更高层次的特征。最后使用全连接层对特征进行分类。输出层仅有一个神经元,用于二分类任务的概率输出。

2.4 编译模型与训练

        编译模型:

# 编译模型
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

        在这里,使用了Adam优化器(optimizer='adam'),它是一种自适应学习率的优化方法,能够有效解决梯度消失和梯度爆炸问题,加快收敛速度。同时,使用了二元交叉熵作为损失函数(loss='binary_crossentropy'),也是常见的用于二分类问题的损失函数。最后,使用准确率(accuracy)作为评价指标(metrics=['accuracy']),用于评估模型预测的准确性。

        训练模型:

# 训练模型
model.fit(train_generator,
          epochs=15)

         我电脑用的是魔改版 2080Ti 22GB,训练的时候显存占用为 21GB,每一个 epoch 约 2分钟,全部跑完大约30分钟。

代码输出:

Epoch 1/15
547/547 [================] - 121s 220ms/step - loss: 0.6190 - accuracy: 0.6480
Epoch 2/15
547/547 [================] - 120s 220ms/step - loss: 0.5057 - accuracy: 0.7540
Epoch 3/15
547/547 [================] - 120s 220ms/step - loss: 0.4398 - accuracy: 0.7949
Epoch 4/15
547/547 [================] - 120s 220ms/step - loss: 0.4021 - accuracy: 0.8180
Epoch 5/15
547/547 [================] - 120s 219ms/step - loss: 0.3753 - accuracy: 0.8302
Epoch 6/15
547/547 [================] - 120s 220ms/step - loss: 0.3512 - accuracy: 0.8435
Epoch 7/15
547/547 [================] - 120s 220ms/step - loss: 0.3256 - accuracy: 0.8580
Epoch 8/15
547/547 [================] - 155s 283ms/step - loss: 0.3046 - accuracy: 0.8677
Epoch 9/15
547/547 [================] - 127s 233ms/step - loss: 0.2897 - accuracy: 0.8764
Epoch 10/15
547/547 [================] - 122s 224ms/step - loss: 0.2705 - accuracy: 0.8861
Epoch 11/15
547/547 [================] - 120s 220ms/step - loss: 0.2519 - accuracy: 0.8930
Epoch 12/15
547/547 [================] - 122s 223ms/step - loss: 0.2363 - accuracy: 0.9011
Epoch 13/15
...
Epoch 14/15
547/547 [================] - 121s 221ms/step - loss: 0.2154 - accuracy: 0.9144
Epoch 15/15
547/547 [================] - 121s 221ms/step - loss: 0.1986 - accuracy: 0.9200

2.5 模型准确率测试

# 测试模型准确率
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(test_dir,
                                                  target_size=(img_height, img_width),
                                                  batch_size=batch_size,
                                                  class_mode='binary')
test_loss, test_acc = model.evaluate(test_generator)
print('Test accuracy:', test_acc)

代码输出:

Found 3752 images belonging to 2 classes.

40/118 [=========>....................] - ETA: 3s - loss: 0.3541 - accuracy: 0.8672

118/118 [==============================] - 6s 47ms/step - loss: 0.3433 - accuracy: 0.8681 Test accuracy: 0.8680703639984131

2.6 模型验证

# 验证模型准确率、损失
val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory(val_dir,
                                                  target_size=(img_height, img_width),
                                                  batch_size=batch_size,
                                                  class_mode='binary')
val_loss, val_acc = model.evaluate(val_generator)
print('Validation accuracy:', val_acc)
print('Validation loss:', val_loss)

代码输出:

Found 3748 images belonging to 2 classes. 118/118 [==============================] - 6s 46ms/step - loss: 0.3384 - accuracy: 0.8714 Validation accuracy: 0.8713980913162231 Validation loss: 0.338356614112854

2.6 模型保存

2.6.1 保存为 tf 格式

# save model
model.save('cat_dog_classfier_v1.tf', overwrite=True, include_optimizer=True)

        tf格式是一个文件夹,该文件夹里面包含了计算图结构、模型元数据信息、模型参数等。 

        

2.6.2 保存为 tflite 格式

        这种格式可以在移动端上运行:

# 导出TensorFlow Lite模型
covertor = tf.lite.TFLiteConverter.from_keras_model(model)
covertor.optimizations = [tf.lite.Optimize.DEFAULT]
tflife_model = covertor.convert()

with open('cat_dog_classfier_v1.tflite', 'wb') as f:
  f.write(tflife_model)

        tflite 格式大小约为 10.6MB; 

        

 

三、将tflite模型应用在安卓App上

        启动 Android Studio,打开一个工程。

        然后右击 app 模块,在弹出的右键菜单中选择 “New” - “Other” - "TensorFlow Lite Model",如下图所示:

        

         接着,选中模型:

         导入之后,如下图所示:

         Android Studio 会自动生成调用该模型的相关代码了。

        接下来我们就调用就可以了。

        先准备一些猫狗图片,

        放到 drawable-nodpi 目录下:

        

 

        接着开始写代码,以下代码写在 MainActivity 中:

    
override fun onCreate(savedInstanceState: Bundle?) {
    super.onCreate(savedInstanceState)
    setContentView(R.layout.activity_main)

    val map = mapOf(
        R.drawable.cat_0 to "R.drawable.cat_0",
        R.drawable.cat_1 to "R.drawable.cat_1",
        R.drawable.cat_2 to "R.drawable.cat_2",
        R.drawable.cat_3 to "R.drawable.cat_3",
        R.drawable.cat_4 to "R.drawable.cat_4",
        R.drawable.cat_5 to "R.drawable.cat_5",
        R.drawable.dog_0 to "R.drawable.dog_0",
        R.drawable.dog_1 to "R.drawable.dog_1",
        R.drawable.dog_2 to "R.drawable.dog_2",
        R.drawable.dog_3 to "R.drawable.dog_3",
        R.drawable.dog_4 to "R.drawable.dog_4",
        R.drawable.dog_5 to "R.drawable.dog_5"
    )
    val model = CatDogClassfierV1.newInstance(this)
    map.forEach {
        doit(model, it.key, it.value)
    }
    model.close()
}



private fun doit(model: CatDogClassfierV1, id: Int, name: String) {
        // Load the input image from resources
        val bmp = BitmapFactory.decodeResource(resources, id)

        // Resize the input image to 224x224
        val resizedBmp = Bitmap.createScaledBitmap(bmp, 224, 224, true)

        // Convert the resized bitmap to a ByteBuffer
        val byteBuffer = ByteBuffer.allocateDirect(4 * resizedBmp.width * resizedBmp.height * 3)
        byteBuffer.order(ByteOrder.nativeOrder())
        val pixels = IntArray(resizedBmp.width * resizedBmp.height)
        resizedBmp.getPixels(pixels, 0, resizedBmp.width, 0, 0, resizedBmp.width, resizedBmp.height)
        var pixel = 0
        for (i in 0 until 224) {
            for (j in 0 until 224) {
                val pixelValue = pixels[pixel++]
                byteBuffer.putFloat(((pixelValue shr 16) and 0xFF) / 255.0f)
                byteBuffer.putFloat(((pixelValue shr 8) and 0xFF) / 255.0f)
                byteBuffer.putFloat((pixelValue and 0xFF) / 255.0f)
            }
        }

        // Creates inputs for reference.
        val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 224, 224, 3), DataType.FLOAT32)
        inputFeature0.loadBuffer(byteBuffer)

        // Runs model inference and gets result.
        val outputs = model.process(inputFeature0)
        val outputFeature0 = outputs.outputFeature0AsTensorBuffer

        // Get the index of the predicted class
        val probabilities = outputFeature0.floatArray
        val predictedClassIndex = probabilities.indices.maxByOrNull { probabilities[it] } ?: -1

        // Map the predicted class index to "Cat" or "Dog"
        val predictedClass = if (predictedClassIndex == 0) "Cat" else if (predictedClassIndex == 1) "Dog" else "Unknown"

        Log.i("MainActivity", "$name, probabilities=${probabilities.joinToString()}")
        Log.i("MainActivity", "$name, ${if(probabilities[0] < 0.5) '猫' else '狗'}")
        // Releases model resources if no longer used.
    }

         运行起来看日志输出:

 

 四、总结

        以上是如何使用 tensorflow 训练一个猫狗分类器、并在移动端上部署的完整过程。

        tensorflow 已经很完备了。在项目上可以多思考一下怎么应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/484664.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023华中杯数学建模C题完整模型代码

已完成全部模型代码&#xff0c;文末获取。 摘要 随着工业化和城市化的快速发展&#xff0c;空气污染已经成为全球性的环境问题。细颗粒物&#xff08;PM2.5&#xff09;等污染物对人类健康、生态环境和社会经济造成了严重影响。本研究旨在深入探究影响PM2.5浓度的主要因素&a…

【Android入门到项目实战-- 8.4】—— 如何解析JSON格式数据

目录 一、准备工作 二、使用JSONObject 三、使用GSON 比起XML&#xff0c;JSON的主要优势在于它的体积更小&#xff0c;在网络上传输的时候可以更省流量&#xff0c;但缺点是语义性较差&#xff0c;看起来不直观。 一、准备工作 还是使用前面文章的方法&#xff0c;在服务器…

【C++】STL标准库之vector

STL标准库之vector vector类的简介常用的vector类的接口构造容量遍历及访问增删查改迭代器迭代器失效问题 vector类的简介 vector是大小可变数组的序列容器&#xff0c;与string相比&#xff0c;vector中可以存任何类型的数据&#xff0c;而string中存储的只能是字符类型。 因为…

第二十九章 使用消息订阅发布实现组件通信

PubSubJS库介绍 如果你想在React中使用第三方库来实现Pub/Sub机制&#xff0c;PubSubJS是一个不错的选择。它是一个轻量级的库&#xff0c;可以在浏览器和Node.js环境中使用。 PubSubJS提供了一个简单的API&#xff0c;可以让你在应用程序中订阅和发布消息。你可以使用npm来安…

大数据Doris(十):Doris基础介绍

文章目录 Doris基础介绍 一、基本概念 二、建表语法及参数解释 1、column_definition_list 2、index_definition_list 3、engine_type 4、key_type 5、table_comment 6、partition_desc 7、distribution_desc 8、rollup_list 9、properites 三、数据类型 Doris基础…

java合并数组的方法

在 Java中&#xff0c;数组是一种重要的数据结构&#xff0c;在 Java中数组的操作方式有两种&#xff0c;一种是直接使用数组来操作&#xff0c;另一种是通过引用计数或者双指针对数组进行操作。对于直接使用数组来操作的方式&#xff0c;我们可以通过两个方法来实现。 一种是将…

C++(多态中)

目录&#xff1a; 1.多态实现原理&#xff08;再剖析&#xff09; 2.析构函数加virtual 3.C11新增两个关键字 override 和 final 4.重载、覆盖&#xff08;重写&#xff09;、隐藏的对比 5.抽象类 1.多态实现原理&#xff08;再剖析&#xff09; 实现出多态的效果&#xff0c;我…

Docker之Docker Compose技术

目录 一、什么是docker compose? 二、安装docker compose 三、使用案例&#xff1a;部署一个简单的fastapi服务 (以下教程是基于环境已将安装了docker服务) 一、什么是docker compose? Compose是一个将多个docker容器组合部署的技术&#xff0c;能通过编写yaml配置文件…

IJCAI2023 | A Systematic Survey of Chemical Pre-trained Models(化学小分子预训练模型综述)

IJCAI_A Systematic Survey of Chemical Pre-trained Models 综述资料汇总(更新中&#xff0c;原文提供)&#xff1a;GitHub - junxia97/awesome-pretrain-on-molecules: [IJCAI 2023 survey track]A curated list of resources for chemical pre-trained models 参考资料&…

『python爬虫』09. bs4实战之下载精美壁纸(保姆级图文)

目录 爬取思路代码思路1.拿到主页面的源代码. 然后提取到子页面的链接地址, href2.通过href拿到子页面的内容. 从子页面中找到图片的下载地址 img -> src3.下载图片 3. 完整实现代码总结 欢迎关注 『python爬虫』 专栏&#xff0c;持续更新中 欢迎关注 『python爬虫』 专栏&…

docker 非持久化存储 tmpfs mounts

docker 非持久化存储 tmpfs mounts 简介tmpfs mounts 限制--tmpfs 和 --mount 之间的差异在容器中使用 tmpfs mounts指定 tmpfs 选项 简介 官方文档&#xff1a;https://docs.docker.com/storage/tmpfs/ 与 volume 和 bind mounts 不同&#xff0c;tmpfs mounts 是临时的&…

jdk中juc多线程编程工具

jdk线程池实现原理分析 目录 CompletionService CompletableFuture 基本原理 CompletableFuture的接口 静态方法 handle() vs whenComplete() xxxEither() 异常处理exceptionally() 获取任务结果 结束任务 Semaphore CyclicBarrier CountDownLatch jdk线程池实…

《斯坦福数据挖掘教程·第三版》读书笔记(英文版)Chapter 4 Mining Data Streams

来源&#xff1a;《斯坦福数据挖掘教程第三版》对应的公开英文书和PPT Chapter 4 Mining Data Streams &#x1f4a1; Skip this chapter due to its difficulty and for me, it is hard to understand. Summary of Chapter 4 The Stream Data Model: This model assumes da…

【微机原理】半导体存储器

目录 一.半导体存储器的分类 二、半导体存储器性能指标 三、半导体存储器的结构 一.半导体存储器的分类 半导体存储器的分类方法有很多种。 1.按器件原理来分&#xff1a;有双极型存储器和MOS型存储器。 双极型&#xff1a;速度快、集成度低、功耗大MOS型&#xff1a;速度慢、集…

“ 探索迷局:解密广度寻路算法 “

专栏文章&#xff0c;自下而上 数据结构与算法——二叉搜索树 数据结构与算法——深度寻路算法 数据结构与算法——二叉树实现表达式树 数据结构与算法——树(三指针描述一棵树&#xff09; 数据结构与算法——栈和队列&#xff1c;也不过如此&#xff1e; 数据结构与算法——八…

C++的智能指针

文章目录 1. 内存泄漏1.1 什么是内存泄漏1.2 内存泄漏分类 2. 为什么需要智能指针3. 智能指针的使用及原理3.1 RAII3.2 使用RAII思想设计的SmartPtr类3.3 让SmartPtr像指针一样3.3 SmartPtr的拷贝3.4 auto_ptr3.5 unique_ptr3.6 shared_ptr3.6.1 shared_ptr的循环引用3.6.2 wea…

MYSQL-数据库管理(上)

一、数据库概述 一、数据库基本概念 1.1 数据 1&#xff09; 描述事物的符号记录称为数据&#xff08;Data&#xff09;。数字、文字、图形、图像、声音、档案记录等 都是数据。 2&#xff09;数据是以“记录”的形式按照统一的格式进行存储的&#xff0c;而不是杂乱无章的。…

机器学习之分类决策树与回归决策树—基于python实现

大家好&#xff0c;我是带我去滑雪&#xff01; 本期为大家介绍决策树算法&#xff0c;它一种基学习器&#xff0c;广泛应用于集成学习&#xff0c;用于大幅度提高模型的预测准确率。决策树在分区域时&#xff0c;会考虑特征向量对响应变量的影响&#xff0c;且每次仅使用一个分…

vs编译生成动态库

说明 windows版本&#xff0c;vs2019 创建一个动态库 新建一c项目&#xff0c;创建一个dll类型项目。 在头文件中添加一个mylib.h文件&#xff1a; #pragma once#ifndef MYLIB_H #define MYLIB_Hextern "C" __declspec(dllexport) void Hello(); extern "C…

UG NX二次开发(C++)-建模-修改NXObject或者Feature的颜色(一)

文章目录 1、前言2、在UG NX中修改Feature的颜色操作3、采用NXOpen(C)实现3.1 创建修改特征的方法3.2 调用ModifyFeatureColor方法3.3 测试结果 1、前言 在UG NX中&#xff0c;改变NXObject和Feature的操作是不相同的&#xff0c;所以其二次开发的代码也不一样&#xff0c;我们…