基于Tensorflow和Keras实现卷积神经网络CNN并进行猫狗识别

news2025/1/10 10:54:51

文章目录

  • 一、环境配置
    • 1、安装Anaconda
    • 2、配置TensorFlow、Keras
  • 二、猫狗数据集分类建模
    • 3.1 猫狗图像预处理
    • 3.2 猫狗分类的实例——基准模型
      • 3.1 构建神经网络
      • 3.2 配置优化器
      • 3.3 图片格式转化
      • 3.4 训练模型
      • 3.5 保存模型
      • 3.6 可视化
  • 三、数据增强
  • 四、dropout 层
  • 五、参考资料

一、环境配置

1、安装Anaconda

具体可参考此博文

2、配置TensorFlow、Keras

创建虚拟环境:
输入下面命令

conda create -n tf1 python=3.6
#tf1是自己为创建虚拟环境取的名字,后面python的版本可以根据自己需求进行选择
  • 如果命令无效,可参考此博文
  • 如果提示:CondaSSLError: OpenSSL appears to be unavailable on this machine.可参考此博文

运行结果如下:
在这里插入图片描述

激活环境:

activate
conda activate tf1

安装 tensorflow、keras 库:
在新建的虚拟环境 tf1 内,使用以下命令安装两个库:

pip install tensorflow==1.14.0 -i “https://pypi.doubanio.com/simple/”
pip install keras==2.2.5 -i “https://pypi.doubanio.com/simple/”

结果如下:

在这里插入图片描述

安装 nb_conda_kernels 包:

conda install nb_conda_kernels

安装 1.16.4 版本的 numpy:

pip install numpy==1.16.4 -i "https://pypi.doubanio.com/simple/"

结果如下:

在这里插入图片描述
安装 pillow 库:

pip install pillow -i “https://pypi.doubanio.com/simple/

在这里插入图片描述
安装matplotlib库:

pip install matplotlib -i “https://pypi.doubanio.com/simple/

在这里插入图片描述

打开 Jupyter Notebook((tf1)环境下的):
在这里插入图片描述

点击【New】→【Python[tf1环境下的]】创建 python 文件:

在这里插入图片描述

二、猫狗数据集分类建模

  • 猫狗图片数据集下载:https://pan.baidu.com/s/1f-MvZl7_J6DF7P9CGBY3SQ—提取码:ruyn
  • 数据集下载完毕后,解压缩,并放在一个没有中文路径下,如下图所示:
    在这里插入图片描述

3.1 猫狗图像预处理

  • 对猫狗图像进行分类,代码如下:
import os, shutil 
# 原始目录所在的路径
original_dataset_dir = 'D:\\Cat_And_Dog\\train\\'

# 数据集分类后的目录
base_dir = 'D:\\Cat_And_Dog\\train1'
os.mkdir(base_dir)

# # 训练、验证、测试数据集的目录
train_dir = os.path.join(base_dir, 'train')
os.mkdir(train_dir)
validation_dir = os.path.join(base_dir, 'validation')
os.mkdir(validation_dir)
test_dir = os.path.join(base_dir, 'test')
os.mkdir(test_dir)

# 猫训练图片所在目录
train_cats_dir = os.path.join(train_dir, 'cats')
os.mkdir(train_cats_dir)

# 狗训练图片所在目录
train_dogs_dir = os.path.join(train_dir, 'dogs')
os.mkdir(train_dogs_dir)

# 猫验证图片所在目录
validation_cats_dir = os.path.join(validation_dir, 'cats')
os.mkdir(validation_cats_dir)

# 狗验证数据集所在目录
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
os.mkdir(validation_dogs_dir)

# 猫测试数据集所在目录
test_cats_dir = os.path.join(test_dir, 'cats')
os.mkdir(test_cats_dir)

# 狗测试数据集所在目录
test_dogs_dir = os.path.join(test_dir, 'dogs')
os.mkdir(test_dogs_dir)

# 将前1000张猫图像复制到train_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(1000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(train_cats_dir, fname)
    shutil.copyfile(src, dst)

# 将下500张猫图像复制到validation_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(1000, 1500)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(validation_cats_dir, fname)
    shutil.copyfile(src, dst)
    
# 将下500张猫图像复制到test_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(1500, 2000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(test_cats_dir, fname)
    shutil.copyfile(src, dst)
    
# 将前1000张狗图像复制到train_dogs_dir
fnames = ['dog.{}.jpg'.format(i) for i in range(1000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(train_dogs_dir, fname)
    shutil.copyfile(src, dst)
    
# 将下500张狗图像复制到validation_dogs_dir
fnames = ['dog.{}.jpg'.format(i) for i in range(1000, 1500)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(validation_dogs_dir, fname)
    shutil.copyfile(src, dst)
    
# 将下500张狗图像复制到test_dogs_dir
fnames = ['dog.{}.jpg'.format(i) for i in range(1500, 2000)]
for fname in fnames:
    src = os.path.join(original_dataset_dir, fname)
    dst = os.path.join(test_dogs_dir, fname)
    shutil.copyfile(src, dst)

运行效果如下:

在这里插入图片描述

  • 查看分类后,对应目录下的图片数量:
#输出数据集对应目录下图片数量
print('total training cat images:', len(os.listdir(train_cats_dir)))
print('total training dog images:', len(os.listdir(train_dogs_dir)))
print('total validation cat images:', len(os.listdir(validation_cats_dir)))
print('total validation dog images:', len(os.listdir(validation_dogs_dir)))
print('total test cat images:', len(os.listdir(test_cats_dir)))
print('total test dog images:', len(os.listdir(test_dogs_dir)))

在这里插入图片描述
可从上图看出猫狗训练图片各 1000 张,验证图片各 500 张,测试图片各 500 张。

3.2 猫狗分类的实例——基准模型

3.1 构建神经网络

#网络模型构建
from keras import layers
from keras import models
#keras的序贯模型
model = models.Sequential()
#卷积层,卷积核是3*3,激活函数relu
model.add(layers.Conv2D(32, (3, 3), activation='relu',
                        input_shape=(150, 150, 3)))
#最大池化层
model.add(layers.MaxPooling2D((2, 2)))
#卷积层,卷积核2*2,激活函数relu
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
#最大池化层
model.add(layers.MaxPooling2D((2, 2)))
#卷积层,卷积核是3*3,激活函数relu
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
#最大池化层
model.add(layers.MaxPooling2D((2, 2)))
#卷积层,卷积核是3*3,激活函数relu
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
#最大池化层
model.add(layers.MaxPooling2D((2, 2)))
#flatten层,用于将多维的输入一维化,用于卷积层和全连接层的过渡
model.add(layers.Flatten())
#全连接,激活函数relu
model.add(layers.Dense(512, activation='relu'))
#全连接,激活函数sigmoid
model.add(layers.Dense(1, activation='sigmoid'))
  • 查看模型各层参数情况:
#输出模型各层的参数状况
model.summary()

在这里插入图片描述

3.2 配置优化器

  • loss:计算损失,这里用的是交叉熵损失
  • metrics:列表,包含评估模型在训练和测试时的性能的指标
from keras import optimizers

model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['acc'])

3.3 图片格式转化

  • 所有图片(2000张)重设尺寸大小为 150x150 大小,并使用ImageDataGenerator 工具将本地图片 .jpg 格式转化成 RGB 像素网格,再转化成浮点张量上传到网络上
    from keras.preprocessing.image import ImageDataGenerator
# 所有图像将按1/255重新缩放
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        # 这是目标目录
        train_dir,
        # 所有图像将调整为150x150
        target_size=(150, 150),
        batch_size=20,
        # 因为我们使用二元交叉熵损失,我们需要二元标签
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='binary')

在这里插入图片描述

  • 查看上述图像预处理生成其中的输出
#查看上面对于图片预处理的处理结果
for data_batch, labels_batch in train_generator:
    print('data batch shape:', data_batch.shape)
    print('labels batch shape:', labels_batch.shape)
    break

在这里插入图片描述

3.4 训练模型

#模型训练过程
history = model.fit_generator(
      train_generator,
      steps_per_epoch=100,
      epochs=30,
      validation_data=validation_generator,
      validation_steps=50)

在这里插入图片描述

3.5 保存模型

#保存训练得到的的模型
model.save('D:\\Cat_And_Dog\\kaggle\\cats_and_dogs_small_1.h5')

3.6 可视化

#对于模型进行评估,查看预测的准确性
import matplotlib.pyplot as plt

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

在这里插入图片描述

  • 训练结果如上图所示,很明显模型上来就过拟合了,主要原因是数据不够,或者说相对于数据量,模型过复杂(训练损失在第30个epoch就降为0了),训练精度随着时间线性增长,直到接近100%,而我们的验证精度停留在70-72%。我们的验证损失在5个epoch后达到最小,然后停止,而训练损失继续线性下降,直到接近0。

  • 这里先解释下什么是过拟合?
    举个例子:我们设计了一个模型来判断 一件物品是否为树叶。喂养这个模型的数据集中含有几张带有尖刺边缘的树叶。模型的设计者希望模型能满足每一个训练数据,模型就将尖刺边缘也纳入了参数中。当我们测试这个模型的泛化性能时,就会发现效果很差,因为模型钻牛角尖,它认为树叶必须带有尖刺边缘,所以它排除了所有没有带有尖刺边缘的树叶,但事实上,我们知道树叶并不一定带有尖刺边缘。

  • 过拟合常见解决方法:
    (1)在神经网络模型中,可使用权值衰减的方法,即每次迭代过程中以某个小因子降低每个权值。
    (2)选取合适的停止训练标准,使对机器的训练在合适的程度;
    (3)保留验证数据集,对训练成果进行验证;
    (4)获取额外数据进行交叉验证;
    (5)正则化,即在进行目标函数或代价函数优化时,在目标函数或代价函数后面加上一个正则项,一般有L1正则与L2正则等。

  • 不过接下来将使用一种新的方法,专门针对计算机视觉,在深度学习模型处理图像时几乎普遍使用——数据增强。

三、数据增强

数据集增强主要是为了减少网络的过拟合现象,通过对训练图片进行变换可以得到泛化能力更强的网络,更好的适应应用场景。

常用的数据增强方法有:
在这里插入图片描述
重新构建模型:

  • 我们重新建一个 .ipynb 文件,重新开始建模
  • 首先猫狗图像预处理,只不过这里将分类好的数据集放在 train2 文件夹中,其它的都一样
    在这里插入图片描述
  • 然后配置网络模型、构建优化器,然后进行数据增强,代码如下:

图像数据生成器增强数据

from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
      rotation_range=40,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest')
  • 参数解释:
    在这里插入图片描述
  • 查看数据增强后的效果:
import matplotlib.pyplot as plt
# This is module with image preprocessing utilities
from keras.preprocessing import image
fnames = [os.path.join(train_cats_dir, fname) for fname in os.listdir(train_cats_dir)]
# We pick one image to "augment"
img_path = fnames[3]
# Read the image and resize it
img = image.load_img(img_path, target_size=(150, 150))
# Convert it to a Numpy array with shape (150, 150, 3)
x = image.img_to_array(img)
# Reshape it to (1, 150, 150, 3)
x = x.reshape((1,) + x.shape)
# The .flow() command below generates batches of randomly transformed images.
# It will loop indefinitely, so we need to `break` the loop at some point!
i = 0
for batch in datagen.flow(x, batch_size=1):
    plt.figure(i)
    imgplot = plt.imshow(image.array_to_img(batch[0]))
    i += 1
    if i % 4 == 0:
        break
plt.show()

在这里插入图片描述
在这里插入图片描述

  • 图片格式转换
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,)
# Note that the validation data should not be augmented!
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        # This is the target directory
        train_dir,
        # All images will be resized to 150x150
        target_size=(150, 150),
        batch_size=32,
        # Since we use binary_crossentropy loss, we need binary labels
        class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
  • 训练模型
history = model.fit_generator(
      train_generator,
      steps_per_epoch=100,
      epochs=100,
      validation_data=validation_generator,
      validation_steps=50)
model.save('E:\\Cat_And_Dog\\kaggle\\cats_and_dogs_small_2.h5')

在这里插入图片描述

  • 可视化
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

在这里插入图片描述

  • 由于数据量的增加,对比基准模型,可以很明显的观察到曲线没有过度拟合了,训练曲线紧密地跟踪验证曲线,这也就是数据增强带来的影响,但是可以发现它的波动幅度还是比较大的。
  • 下面在此数据增强的基础上,再增加一层 dropout 层,再来训练看看。

四、dropout 层

什么是dropout层?

Dropout层在神经网络层当中是用来干嘛的呢?它是一种可以用于减少神经网络过拟合的结构,那么它具体是怎么实现的呢?
假设下图是我们用来训练的原始神经网络:
在这里插入图片描述
在这里插入图片描述

实现:
在构建网络模型时额外加入以下代码:

#退出层
model.add(layers.Dropout(0.5))

在这里插入图片描述
运行结果如下:
在这里插入图片描述

五、参考资料

https://blog.csdn.net/ssj925319/article/details/117787737

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/709854.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Openresty原理概念篇(十五)Lua 规则和 NGINX 配置文件产生冲突怎么办?

一 Lua 规则和 NGINX 配置文件产生冲突怎么办? ① OpenResty 的名字和语言 说明: 了解openresty的发展史 ② 配置文件的规则优先级 1) 如何各司其职2) 都能满足功能,该如何取舍 理解: 1) rewrite ... break 到POST_WRITE阶段2) 而rewrite_by_lua*…

JAVA的DIFF算法

首先看一下我的文件结构 1.EnumType 类 public enum EnumType {ADD("ADD"),MODIFIED("MODIFIED"), DELETED("DELETED");//创建私有变量private String type;EnumType(String type) {this.type type;} }2.OperationType类 public class Operati…

vue封装svg组件来修改svg图片颜色

文章目录 1、引入依赖2、根目录的vue.config.js配置3、在组件文件夹(compontents)中创建svgIcon.vue4、在src目录下创建icons文件5、处理svg格式的图片6、在main.js文件中引入icons文件中的index.js文件7、使用8、效果图1、项目成功运行后的样子2、直接在html上添加样式&#x…

DEBUG系列三:使用 F9 和 watch point

首先是我随便找了个报错。 报销消息号信息: No pricing procedure could be determined Message No. V1212 1)首先可以直接SE91 来追溯这个消息号哪儿报出来的 可以看到下面两个地方可能会报这个消息,可以直接在这两个地方打断点,…

开发一个RISC-V上的操作系统(一)—— 环境搭建

在前面我们使用Verilog实现了一个简易的RISC-V处理器,并且能烧录到板子上跑一些简单C程序,传送门: RISC-V处理器的设计与实现(一)—— 基本指令集_risc_v处理器_Patarw_Li的博客-CSDN博客 RISC-V处理器的设计与实现&…

电子器件系列41:扁平高压电阻

这种电阻和其他的高压电阻不同,不是绕线电阻而是陶瓷电阻 找到一个大神,他的专栏也得很详细了,贴在这里 https://blog.csdn.net/wkezheng/category_12059870.html 阻容感基础03:电阻器分类(1)-片式电阻器…

如何快速判断是否在容器环境

在渗透测试过程中,我们的起始攻击点可能在一台虚拟机里或是一个Docker环境里,甚至可能是在K8s集群环境的一个pod里,我们应该如何快速判断当前是否在容器环境中运行呢? 当拿到shell权限,看到数字和字母随机生成的主机名…

软考A计划-系统集成项目管理工程师-项目范围管理(二)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例点击跳转>软考全系列 👉关于作者 专注于Android/Unity和各种游戏开发技巧&#xff…

HTML、Markdown、Word、Excel等格式的文档转换为PDF

工具:gotenberg,docker部署 github:https://github.com/gotenberg/gotenberg 文档:https://gotenberg.dev/docs/about https://gotenberg.dev/docs/modules/libreoffice docker运行: docker run -d --rm -p 3000:30…

kubernete部署prometheus监控sring-boot程序

目录 1、kubernete集群环境以及prometheus基础环境 2、kubernetes监控集群内部的spring-boot程序 2.1、application.yml系统配置,endpoints相关设置 2.2、引入监控的相关依赖文件 pom.xml ,主要是spring-boot-starter-actuator和micrometer-registr…

ModaHub魔搭社区:向量数据库Milvus产品问题(二)

目录 为什么向量距离计算方式是内积时,搜索出来的 top1 不是目标向量本身? 对集合分区的查询是否会受到集合大小的影响,尤其在集合数据量高达一亿数据量时? 如果只是搜索集合中的部分分区,整个集合的数据会全部加载…

表单(form) post 方式提交时的编码与乱码(上)

在上一篇章中谈论了表单以 get 提交时的编码与乱码问题, 这一章中将讨论以 post 方式提交时的编码与乱码问题. 在前面也同时提到, 表单有一个叫 enctype 的属性, 它有两个值, application/x-www-form-urlencoded 和 multipart/form-data. 这一属性实际只对 post 方式起作用, …

@Configuration 和 @Component 的区别 ,别再瞎用了!

一句话概括就是 Configuration 中所有带 Bean 注解的方法都会被动态代理,因此调用该方法返回的都是同一个实例。 理解:调用Configuration类中的Bean注解的方法,返回的是同一个示例;而调用Component类中的Bean注解的方法&#xff…

List, Set, Ordered-SetHash

前言 本文小结Redis中List,Set,ZSet和Hash四种数据类型的,基本特点,使用场景和实现方式。 一、List 1. 基本特点 a. 作为数组,基于下标索引操作, 但支持正向索引和反向索引; b. 作为链表, 支持高效插入&#xff1b…

信息安全-应用安全-定制化白盒检测 | 越权漏洞治理分享

目录 一、背景 二、面临的挑战 三、治理目标 四、解决方案 4.1 系统架构 4.2 鉴权函数 4.3 告警识别 4.4 鉴权分 五、未来的白盒检测方向 六、越权治理 七、小结 一、背景 在漏洞扫描领域,主流的扫描方式分为黑盒扫描和白盒扫描,其中源代码安…

MYSQL-主键外键约束

主键语法: 在创建表指定列数据类型时在后面加(可以结合AUTO_INCREMENT) PRIMARY KEY 主键要短,可唯一标识记录,且永不改变。 外键语法: 第一个column_name是被指定外键的本表列名 table_name是主键的表名 第二个column_name是主键列名 FOREIGN KE…

使用DataX同步数据(小白步骤,一看就懂)

详细文档说明,及图文讲解 ​​​​​​datax的异构数据同步资源-CSDN文库 Datax简介 下载datax软件,从开源镜像下载

python接口自动化(七)--状态码详解对照表(详解)

简介 我们为啥要了解状态码,从它的作用,就不言而喻了。如果不了解,我们就会像个无头苍蝇,横冲直撞。遇到问题也不知道从何处入手,就是想找别人帮忙,也不知道是找前端还是后端的工程师。 状态码的作用是&…

串口接收不定长数据的实现

使用串口进行数据的收发在嵌入式产品中是很常用的一种通信方式,因为串口的简单使用,很容易就被选为产品中数据交互的通信手段。 基于串口进行开发的功能有很多,比如同类/不同类产品之间的通信,RS485通信,RS232通信方式…