TensorFlow2实战-系列教程3:猫狗识别1

news2025/1/9 23:43:53

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、项目介绍

  • 数据预处理:图像数据处理,准备训练和验证数据集
  • 卷积网络模型:构建网络架构
  • 过拟合问题:观察训练和验证效果,针对过拟合问题提出解决方法
  • 数据增强:图像数据增强方法与效果
  • 迁移学习:深度学习必备训练策略

2、数据读取

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

# 训练集
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

# 验证集
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
  1. 导包
  2. 指定数据路径
  3. 训练数据路径
  4. 验证数据路径
  5. 训练数据猫类别路径
  6. 训练数据狗类别路径
  7. 验证数据猫类别路径
  8. 训练数据狗类别路径

3、构建卷积神经网络

model = tf.keras.models.Sequential([
    #如果训练慢,可以把数据设置的更小一些
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),

    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    
    #为全连接层准备
    tf.keras.layers.Flatten(),
    
    tf.keras.layers.Dense(512, activation='relu'),
    # 二分类sigmoid就够了
    tf.keras.layers.Dense(1, activation='sigmoid')
])

3个3x3卷积,穿插3个2x2池化,拉平操作,两个全连接层

model.summary()

打印一下模型架构:

配置训练器:

model.compile(loss='binary_crossentropy', optimizer=Adam(lr=1e-4), metrics=['acc'])

4、数据预处理

  • 读进来的数据会被自动转换成tensor(float32)格式,分别准备训练和验证
  • 图像数据归一化(0-1)区间
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        train_dir,  # 文件夹路径
        target_size=(64, 64),  # 指定resize成的大小
        batch_size=20,
        # 如果one-hot就是categorical,二分类用binary就可以
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(64, 64),
        batch_size=20,
        class_mode='binary')

打印结果:
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

5、模型训练

  • 直接fit也可以,但是通常咱们不能把所有数据全部放入内存,fit_generator相当于一个生成器,动态产生所需的batch数据
  • steps_per_epoch相当给定一个停止条件,因为生成器会不断产生batch数据,说白了就是它不知道一个epoch里需要执行多少个step
history = model.fit_generator(
      train_generator,
      steps_per_epoch=100,  # 2000 images = batch_size * steps
      epochs=20,
      validation_data=validation_generator,
      validation_steps=50,  # 1000 images = batch_size * steps
      verbose=2)

部分打印结果:
Epoch 1/20 100/100 - 9s - loss: 0.6909 - acc: 0.5240 - val_loss: 0.6952 - val_acc: 0.5000
Epoch 2/20 100/100 - 9s - loss: 0.6645 - acc: 0.5960 - val_loss: 0.6906 - val_acc: 0.5360

Epoch 19/20 100/100 - 9s - loss: 0.1750 - acc: 0.9460 - val_loss: 0.6277 - val_acc: 0.7390
Epoch 20/20 100/100 - 9s - loss: 0.1593 - acc: 0.9505 - val_loss: 0.5901 - val_acc: 0.7490

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

在这里插入图片描述
在这里插入图片描述
将训练损失、准确率和对应的epoch分别画图展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1412236.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

无人值守变电所运维在海南市某住宅区的应用

1 前言 随着国家电网改革政策的逐步推进和落实,AcrelCloud-1000变电所运维云平台运用互联网和大数据技术,为电力运维公司提供变电所运维云平台。该平台作为连接运维单位和用电企业的纽带,监视用户配电系统的运行状态和电量数据,为…

HCIA学习第二天OSI七层协议与网络协议_操纵网络设备

第一天总结: 对等网——网络变大——无限的传输距离 无冲突 单播 为满足以上问题,出现了--网桥--紧接着出现了交换机——介质访问控制层(二层设备)——识别MAC地址 (认识有记录-单播 不认识无记录-泛洪(泛…

【java面试】Spring

目录 1. Spring 介绍1.1 Spring 的优点1.2 Spring 的缺点1.3 详细讲解一下核心容器(spring context应用上下文) 模块 2. Spring俩大核心概念IOC,Inversion of Control,控制反转AOP(Aspect-OrientedProgramming),面向切面编程Sprin…

如何实现高效一键群发1000人?

对于职场人来说,微信里的客户越多,成交的概率越大。但客户太多,群发是个难题,因为微信本身的群发是有数量限制的。 我们都知道,群发消息可以帮助我们: 1. 高效传递信息:群发短信的方式能够迅速…

Vue<圆形旋转菜单栏效果>

效果图: 大家不一定非要制成菜单栏,可以看下人家的华丽效果😝,参考地址 https://travelshift.com/ 大佬写的效果可比我的强多了,但是无从下手,所以就自己琢磨怎么写了,只能说效果勉强差不多 可以通过更改data值和注释我标注的css样式处部分,就可以实现全圆的效果😄…

vue3 + antd 封装动态表单组件(二)

传送带: vue3 antd 封装动态表单组件(一) 前置条件: vue版本 v3.3.11 ant-design-vue版本 v4.1.1 vue3 antd 封装动态表单组件(一)是基础版本,但是并不好用, 因为需要配置很多表…

echarts + gauge + 半圆效果

请注意以下配置需要echarts 5.0.0以上版本,4是不行的 option {series: [{center: [50%, 80%],type: gauge,startAngle: 180,endAngle: 0,min: 0,max: 150,axisTick: {show: false},splitLine: {show: false},detail: {color: #3096fe,offsetCenter: [0, -10],form…

Docker部署Stable-Diffusion-webui

前排提示:如果不想折腾,可直接跳到最后获取封装好的容器,一键运行 :D 前言 乘上AI生成的快车,一同看看沿途的风景。 启一个miniconda容器 docker run -itd -v 宿主机内SD项目路径:/tmp --gpus all --ipc host -p 7860:7860 con…

echarts柱状图添加白色柱状图背景+滚动+柱状顶部添加横线

echarts柱状图添加白色柱状图背景滚动 <template><div class"stream-water-wrapper"><div id"biologybgchart" class"bck-chart"></div><div id"biologychart" class"biology-chart"></…

【JS基础】定时器的使用、事件监听

文章目录 前言一、定时器1.1定时器是什么1.2 setInterval函数1.3 关闭定时器clearInterval 二、事件监听2.1 事件监听是什么2.2 事件监听的使用基本语法点击事件鼠标事件焦点事件键盘事件 2.3 事件对象event 总结 前言 JavaScript 中的定时器和事件监听是 Web 开发中至关重要的…

如何在今日头条广告中轻松唤起微信?这个方法你一定不能错过

要在今日头条的广告中调起微信&#xff0c;实现加好友的功能&#xff0c;可以参考以下步骤&#xff1a; 首先&#xff0c;通过搜索引擎找到“数灵通”外链工具的官网&#xff0c;并进入其后台。在后台填写相关参数&#xff0c;生成一条能够跳转到微信的链接。这个链接的作用是…

云手机哪一款好用?

随着海外市场的不断发展&#xff0c;云手机市场也呈现蓬勃的态势&#xff0c;众多云设备软件纷纷涌现。企业在选择云手机软件时&#xff0c;如何找到性能卓越的软件成为一项关键任务。在众多选择中&#xff0c;OgPhone云手机凭借其卓越的性能和独特功能脱颖而出。以下是OgPhone…

机器学习分类模型评价指标总结(准确率、精确率、召回率、Fmax、TPR、FPR、ROC曲线、PR曲线,AUC,AUPR)

为了看懂论文&#xff0c;不得不先学一些预备知识&#xff08;&#xff08;55555 主要概念 解释见图 TP、FP、TN、FN 准确率、精确率&#xff08;查准率&#xff09;、召回率&#xff08;查全率&#xff09; 真阳性率TPR、伪阳性率FPR F1-score2TP/(2*TPFPFN) 最大响应分…

【python】爬取豆瓣影评保存到Excel文件中【附源码】

欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 【往期相关文章】 爬取豆瓣电影排行榜Top250存储到Excel文件中 爬取豆瓣电影排行榜TOP250存储到CSV文件中 爬取知乎热榜Top50保存到Excel文件中 爬取百度热搜排行榜Top50可视化 爬取斗鱼直播照片保存到本地目录 爬…

司铭宇老师:汽车销售培训:汽车销售员培训:汽车销售技巧培训:汽车销售技巧和话术

汽车销售培训&#xff1a;汽车销售员培训&#xff1a;汽车销售技巧培训&#xff1a;汽车销售技巧和话术 汽车销售是一项充满挑战性的工作&#xff0c;它需要销售人员具备良好的沟通技巧、谈判技巧以及产品讲解能力。在这篇文章中&#xff0c;我们将详细探讨汽车销售中的技巧和话…

无状态应用管理Deployment

无状态应用管理Deployment 1、Deployment介绍 Deployment一般用于部署公司的无状态服务。 格式&#xff1a; apiVersion: apps/v1 kind: Deployment metadata: name: nginx-deployment labels: app: nginx spec: replicas: 3 selector: matchLabels: app: nginx …

openlayers+vue实现缓冲区

文章目录 前言一、准备二、初始化地图1、创建一个地图容器2、引入必须的类库3、地图初始化4、给地图增加底图 三、创建缓冲区1、引入需要的工具类库2、绘制方法 四、完整代码总结 前言 缓冲区是地理空间目标的一种影响范围或服务范围,是对选中的一组或一类地图要素(点、线或面…

2024年最新MacBook苹果电脑安装JDK8、JDK11教程,配置环境变量 + 快速切换JDK版本

本帖发布日期&#xff1a;2024年01月26日&#xff0c;全网最新教程整理。 1、概述 本文主要为在MacBook苹果电脑系统下安装JDK及环境变量配置。 教程并非原创&#xff0c;摘抄自互联网&#xff0c;本人作为更新整理亲测。&#xff08;也算给自己记录一贴&#xff09; 本帖分…

python之异常的捕获、模块、包

目录 1.了解异常 2.异常的捕获 3.异常的传递性 4.模块的概念和导入 5.自定义模块并导入 6.自定义python包 7.安装第三方包 1.了解异常 2.异常的捕获 直接报错了&#xff0c;说明我们捕获的就是名字的异常而没有捕获除0的异常。 这样就可以打印出异常 捕获全部的异常可以使…

HTML-表单

表单 概念&#xff1a;一个包含交互的区域&#xff0c;用于收集用户提供的数据。 1.基本结构 示例代码&#xff1a; <form action"https://www.baidu.com/s" target"_blank" method"get"><input type"text" name"wd&q…