TensorFlow2.0实战:Cats vs Dogs

news2024/11/21 23:17:08

数据集准备

在本文中,我们使用“Cats vs Dogs”的数据集。这个数据集包含了23,262张猫和狗的图像

在这里插入图片描述
你可能注意到了,这些照片没有归一化,它们的大小是不一样的

但是非常棒的一点是,你可以在Tensorflow Datasets中获取这个数据集

所以,确保你的环境里安装了Tensorflow Dataset

pip install tensorflow-dataset

和这个库中的其他数据集不同,这个数据集没有划分成训练集和测试集

所以我们需要自己对这两类数据集做个区分

关于数据集的更多信息:https://www.tensorflow.org/datasets/catalog/cats_vs_dogs


实现

这个实现分成了几个部分

首先,我们实现了一个类,其负责载入数据和准备数据。

然后,我们导入预训练模型,构建一个类用于修改最顶端的几层网络。

最后,我们把训练过程运行起来,并进行评估。

当然,在这之前,我们必须导入一些代码库,定义一些全局常量:

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds

IMG_SIZE = 160
BATCH_SIZE = 32
SHUFFLE_SIZE = 1000
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

数据载入器

这个类负责载入数据和准备数据,用于后续的数据处理。以下是这个类的实现:

class DataLoader(object):
    def __init__(self, image_size, batch_size):
        
        self.image_size = image_size
        self.batch_size = batch_size
        
        # 80% train data, 10% validation data, 10% test data
        split_weights = (8, 1, 1)
        splits = tfds.Split.TRAIN.subsplit(weighted=split_weights)
        
        (self.train_data_raw, self.validation_data_raw, self.test_data_raw), self.metadata = tfds.load(
            'cats_vs_dogs', split=list(splits),
            with_info=True, as_supervised=True)
        
        # Get the number of train examples
        self.num_train_examples = self.metadata.splits['train'].num_examples*80/100
        self.get_label_name = self.metadata.features['label'].int2str
        
        # Pre-process data
        self._prepare_data()
        self._prepare_batches()
        
    # Resize all images to image_size x image_size
    def _prepare_data(self):
        self.train_data = self.train_data_raw.map(self._resize_sample)
        self.validation_data = self.validation_data_raw.map(self._resize_sample)
        self.test_data = self.test_data_raw.map(self._resize_sample)
    
    # Resize one image to image_size x image_size
    def _resize_sample(self, image, label):
        image = tf.cast(image, tf.float32)
        image = (image/127.5) - 1
        image = tf.image.resize(image, (self.image_size, self.image_size))
        return image, label
    
    def _prepare_batches(self):
        self.train_batches = self.train_data.shuffle(1000).batch(self.batch_size)
        self.validation_batches = self.validation_data.batch(self.batch_size)
        self.test_batches = self.test_data.batch(self.batch_size)
   
    # Get defined number of  not processed images
    def get_random_raw_images(self, num_of_images):
        random_train_raw_data = self.train_data_raw.shuffle(1000)
        return random_train_raw_data.take(num_of_images)

这个类实现了很多功能,它实现了很多public方法:

  • _prepare_data:内部方法,用于缩放和归一化数据集里的图像。构造函数需要用到该函数。
  • _resize_sample:内部方法,用于缩放单张图像。
  • _prepare_batches:内部方法,用于将图像打包创建为batches。创建train_batchesvalidation_batchestest_batches,分别用于训练、评估过程。
  • get_random_raw_images:这个方法用于从原始的、没有经过处理的数据中随机获取固定数量的图像。

但是,这个类的主要功能还是在构造函数中完成的。让我们仔细看看这个类的构造函数。

def __init__(self, image_size, batch_size):

    self.image_size = image_size
    self.batch_size = batch_size

    # 80% train data, 10% validation data, 10% test data
    split_weights = (8, 1, 1)
    splits = tfds.Split.TRAIN.subsplit(weighted=split_weights)

    (self.train_data_raw, self.validation_data_raw, self.test_data_raw), self.metadata = tfds.load(
        'cats_vs_dogs', split=list(splits),
        with_info=True, as_supervised=True)

    # Get the number of train examples
    self.num_train_examples = self.metadata.splits['train'].num_examples*80/100
    self.get_label_name = self.metadata.features['label'].int2str

    # Pre-process data
    self._prepare_data()
    self._prepare_batches()

首先我们通过传入参数定义了图像大小和batch大小

然后,由于该数据集本身没有区分训练集和测试集,我们通过划分权值对数据进行划分

一旦我们执行了数据划分,我们就开始计算训练样本数量,然后调用辅助函数来为训练准备数据

在这之后,我们需要做的仅仅是实例化这个类的对象,然后载入数据即可。

data_loader = DataLoader(IMG_SIZE, BATCH_SIZE)

plt.figure(figsize=(10, 8))
i = 0
for img, label in data_loader.get_random_raw_images(20):
    plt.subplot(4, 5, i+1)
    plt.imshow(img)
    plt.title("{} - {}".format(data_loader.get_label_name(label), img.shape))
    plt.xticks([])
    plt.yticks([])
    i += 1
plt.tight_layout()
plt.show()

输出结果
在这里插入图片描述
基础模型 & Wrapper

下一个步骤就是载入预训练模型了

这些模型位于tensorflow.kearas.applications

我们可以用下面的语句直接载入它们

vgg16_base = tf.keras.applications.VGG16(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
googlenet_base = tf.keras.applications.InceptionV3(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
resnet_base = tf.keras.applications.ResNet101V2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')

这段代码就是我们创建上述三种网络结构基础模型的方式

注意,每个模型构造函数的include_top参数传入的是false

这意味着这些模型是用于提取特征的

我们一旦创建了这些模型,我们就需要修改这些模型顶部的网络层,使之适用于我们的具体问题

我们使用Wrapper类来完成这个步骤

这个类接收预训练模型,然后添加一个Global Average Polling Layer和一个Dense Layer

本质上,这最后的Dense Layer会用于我们的二分类问题(猫或狗)

Wrapper类把所有这些元素都放到了一起,放在了同一个模型中

class Wrapper(tf.keras.Model):
    def __init__(self, base_model):
        super(Wrapper, self).__init__()
        
        self.base_model = base_model
        self.average_pooling_layer = tf.keras.layers.GlobalAveragePooling2D()
        self.output_layer = tf.keras.layers.Dense(1)
        
    def call(self, inputs):
        x = self.base_model(inputs)
        x = self.average_pooling_layer(x)
        output = self.output_layer(x)
        return output

然后我们就可以创建Cats vs Dogs分类问题的模型了,并且编译这个模型。

base_learning_rate = 0.0001

vgg16_base.trainable = False
vgg16 = Wrapper(vgg16_base)
vgg16.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss='binary_crossentropy',
              metrics=['accuracy'])

googlenet_base.trainable = False
googlenet = Wrapper(googlenet_base)
googlenet.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss='binary_crossentropy',
              metrics=['accuracy'])

resnet_base.trainable = False
resnet = Wrapper(resnet_base)
resnet.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss='binary_crossentropy',
              metrics=['accuracy'])

注意,我们标记了基础模型是不参与训练的

这意味着在训练过程中,我们只会训练新添加到顶部的网络层,而在网络底部的权重值不会发生变化。

训练

在我们开始整个训练过程之前,让我们思考一下,这些模型的大部头其实已经被训练过了

所以,我们可以执行评估过程来看看评估结果如何:

steps_per_epoch = round(data_loader.num_train_examples)//BATCH_SIZE
validation_steps = 20

loss1, accuracy1 = vgg16.evaluate(data_loader.validation_batches, steps = 20)
loss2, accuracy2 = googlenet.evaluate(data_loader.validation_batches, steps = 20)
loss3, accuracy3 = resnet.evaluate(data_loader.validation_batches, steps = 20)

print("--------VGG16---------")
print("Initial loss: {:.2f}".format(loss1))
print("Initial accuracy: {:.2f}".format(accuracy1))
print("---------------------------")

print("--------GoogLeNet---------")
print("Initial loss: {:.2f}".format(loss2))
print("Initial accuracy: {:.2f}".format(accuracy2))
print("---------------------------")

print("--------ResNet---------")
print("Initial loss: {:.2f}".format(loss3))
print("Initial accuracy: {:.2f}".format(accuracy3))
print("---------------------------")

有意思的是,这些模型在没有预先训练的情况下,我们得到的结果也还过得去(50%的精确度):

———VGG16———
Initial loss: 5.30
Initial accuracy: 0.51
—————————-

——GoogLeNet—–
Initial loss: 7.21
Initial accuracy: 0.51
—————————-

——–ResNet———
Initial loss: 6.01
Initial accuracy: 0.51
—————————-

把50%作为训练的起点已经挺好的了

所以,就让我们把训练过程跑起来吧,看看我们是否能得到更好的结果

首先,我们训练VGG16:

history = vgg16.fit(data_loader.train_batches, epochs=10, validation_data=data_loader.validation_batches)

训练过程历史数据显示大致如下:

在这里插入图片描述
然后我们可以训练GoogLeNet

history = googlenet.fit(data_loader.train_batches,
                    epochs=10,
                    validation_data=data_loader.validation_batches)

在这里插入图片描述
最后是ResNet的训练

history = resnet.fit(data_loader.train_batches,
                    epochs=10,
                    validation_data=data_loader.validation_batches)

在这里插入图片描述
由于我们只训练了顶部的几层网络,而不是整个网络,所以训练这三个模型只用了几个小时


评估

我们看到在训练开始前,我们已经有了50%左右的精确度。让我们来看下训练后是什么情况:

loss1, accuracy1 = vgg16.evaluate(data_loader.test_batches, steps = 20)
loss2, accuracy2 = googlenet.evaluate(data_loader.test_batches, steps = 20)
loss3, accuracy3 = resnet.evaluate(data_loader.test_batches, steps = 20)

print("--------VGG16---------")
print("Loss: {:.2f}".format(loss1))
print("Accuracy: {:.2f}".format(accuracy1))
print("---------------------------")

print("--------GoogLeNet---------")
print("Loss: {:.2f}".format(loss2))
print("Accuracy: {:.2f}".format(accuracy2))
print("---------------------------")

print("--------ResNet---------")
print("Loss: {:.2f}".format(loss3))
print("Accuracy: {:.2f}".format(accuracy3))
print("---------------------------")

结果如下:

——–VGG16———
Loss: 0.25
Accuracy: 0.93
—————————

——–GoogLeNet———
Loss: 0.54
Accuracy: 0.95
—————————
——–ResNet———
Loss: 0.40
Accuracy: 0.97
—————————

我们可以看到这三个模型的结果都相当好,其中ResNet效果最好,精确度高达97%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/138905.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

梦在远方路在脚下,社科院与杜兰大学金融管理硕士项目与你一路相伴

梦想是指引我们飞翔的翅膀,梦想是远方的灯塔指引着我们前进的方向。梦想距离我们很远,但路在脚下,只要朝着梦想前进,终有一天梦想会照进现实。就像拥有读研梦想的我们,在社科院杜兰金融管理硕士项目汲取能量&#xff0…

【Android OpenGL开发】OpenGL ES与EGL介绍

什么是OpenGL ES OpenGL(Open Graphics Library)是一个跨编程语言、跨平台的编程图形程序接口,主要用于图像的渲染。 Android提供了简化版的OpenGL接口,即OpenGL ES。 早先定义 OpenGL ES 是 OpenGL 的嵌入式设备版本&#xff…

Mac上超实用的6款软件,老用户都知道!

今天为大家带来的是6款超实用的Mac软件,让你不再走弯路。第一款:Amphetamine 防休眠的利器Amphetamine for mac是应用在Mac上的一款防休眠工具,可以自定义哪些程序运行时不休眠,做到自定义Mac睡眠时间,可以通过超级简单…

【数据结构】链式存储:链表(无头双向链表实现)

目录 🥇一:无头双向链表 🎒二、无头双向链表的实现 📘1.创建节点类 📒2.创建链表 📗3.打印链表 📕4.查找是否包含关键字key是否在单链表当中 📙5.得到单链表的长度 &#x1…

PCL中常用的高级采样方法

0. 简介 我们在使用PCL时候,常常不满足于常用的降采样方法,这个时候我们就想要借鉴一些比较经典的高级采样方法。这一讲我们将对常用的高级采样方法进行汇总,并进行整理,来方便读者完成使用 1. 基础下采样 1.1 点云随机下采样 …

代码随想录拓展day6 N皇后

代码随想录拓展day6 N皇后 只有这一个内容。一刷的时候也没弄太明白,二刷的时候补上。还有部分内容来自牛客网左老师的算法课程。 总体思路不容易想明白,优化也有很大难度。这要是面试能碰上基本就是故意不给过了吧。 思路 首先来看一下皇后们的约束…

Flink 容错恢复 2.0 2022 最新进展

摘要:本文整理自阿里云 Flink 存储引擎团队负责人,Apache Flink 引擎架构师 & PMC 梅源在 FFA 核心技术专场的分享。主要介绍在 2022 年度,Flink 容错 2.0 这个项目在社区和阿里云产品的进展,内容包括:Flink 容错恢…

基于ssm的个人健康管理系统

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…

简单理解光会产生折射的原因及折射定律的推导

已知 1、光是一种波; 2、光在不同介质中传播速度不同。 构建模型 如下图所示,光是中电磁波,以余弦波为例,取余弦波的极大值点为参考,建立一个平面波(波前为一个平面)。能明显的看出光的传播方…

树形结构——二叉树

前言 前面的章节我们介绍了两种重要的数据结构,数组和链表,由于他们各自的特性使得他们的优缺点非常分明,在查询速度和插入速度上顾此失彼,不能兼顾,那么有没有一种数据结构可以同时高效的完成插入和查询操作呢&#x…

专访 | 刘嘉松:开源,互惠且共赢

本文整理自对 2022 开源之夏 OpenMLDB 社区贡献者刘嘉松同学的采访,欢迎大家关注~ OpenMLDB:可以先请你介绍一下你自己吗? 刘嘉松:我叫刘嘉松,是中南大学计科专业的一名本科生,目前大四,未来将继续在中南…

Hello 2023 D. Boris and His Amazing Haircut

原题链接:Problem - D - Codeforces 题意: 给定长度为 n 的数组 A ,代表 Boris 现在的头发长度,和一个长度为 n 的数组 B ,代表他希望的发型的头发长度。理发师手里有 m 把剪刀,每个都只能用一次&#xff…

计算机硕士论文,盲审的老师都很严吗? - 易智编译EaseEditing

首先,学位论文必须论证严谨 对于一个结果的解读,先老老实实的把得到什么结果讲一遍,基于这个所得出的结论说一说,最后,来个非谓语的短句吹一吹重要性或提示意义。 这其实是有套路的一个句子写下来。但是,在…

AC7811-ACMP模拟比较器

在无感的BLDC方波控制中,AC7811没办法再直接通过PWDT模块检测霍尔信号了。 所以需要先进行ACMP模块的初始化配置,使能ACMP模块正常工作后,ACMP会对输入的三相反电动势与电机中电电压进行轮询模拟,得到各相反电动势过零点&#xf…

分享5款有趣但或许不那么实用的软件

今天我想分享几个有趣但或许不那么实用的软件,各位喜欢的朋友可以自行下载呢。 1.软件音量设置——EarTrumpet 听音乐、看视频、玩游戏,在各应用切换过程中,你可能会频繁调整系统音量大小,以适应自己的耳朵。而 EarTrumpet 则可…

AIGC:BAT、抖快的新掘金口?

配图来自Canva可画 AI辅助绘画估值超十万? 12月28日,山东人民出版社看中一位4岁女孩用百家号AI作画创作的AI绘本《超能外星战队》,认为该画价值超十万元且有出版意向。与此同时,“AI作画”像病毒般在各大短视频平台蔓延&#xf…

年度盘点丨2022年,7大关键词彰显用友成长型企业数智力量

导读:这一年,他们用卓越成绩证明自己,用产品创新回馈客户,用实力开启了中国成长型企业数智化产业的逆袭之路! 2022年,企业级数智化产业经历了翻天覆地的变化。 曾经万家追捧的“舶来品”在成长型企业的主…

项目管理中,如何对各种文件进行统一版本管理?

不知道你在工作中是否也遇到过这样的问题:1、文件先存一个位置,等晚点再整理,结果过了一段时间,就变成了这样:2、想从电脑中找一份重要材料,要花费很长时间,有时查找一通,却一无所获…

Docker中MySQL主从复制

MySQL主从复制 主从搭建步骤 新建主服务器容器实例3307 docker run -p 3307:3306 --name mysql-master \ -v /mydata/mysql-master/log:/var/log/mysql \ -v /mydata/mysql-master/data:/var/lib/mysql \ -v /mydata/mysql-master/conf:/etc/mysql \ -e MYSQL_ROOT_PASSWORD…

SAP年结账务调整过程中的业务改错处理心得

昨天遇到一个问题在这里分享一下。大家知道刚刚跨年,财务还在做一些后续调整。在业财一体化的系统中,这种调整往往涉及到两个方向,财务端去调,还是业务端去调。但大家要了解一个前提。即然需要做调整,肯定是业务端出错…