TensorFlow2实战-系列教程5:猫狗识别2------数据增强

news2024/11/28 19:01:38

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

猫狗识别1
数据增强
猫狗识别2------数据增强
猫狗识别3------迁移学习

1、猫狗识别任务

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),

    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),

    tf.keras.layers.Flatten(),

    tf.keras.layers.Dense(512, activation='relu'),

    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',
              optimizer=Adam(lr=1e-4),
              metrics=['acc'])

依次是导包、指定数据路径、构建模型、配置训练器等,这些都与前面TensorFlow2实战-系列教程3:猫狗识别1完全一致

2、数据增强

train_datagen = ImageDataGenerator(
      rescale=1./255,
      rotation_range=40,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest')

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        train_dir,  
        target_size=(64, 64),  
        batch_size=20,
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(64, 64),
        batch_size=20,
        class_mode='binary')

history = model.fit_generator(
      train_generator,
      steps_per_epoch=100,  # 2000 images = batch_size * steps
      epochs=100,
      validation_data=validation_generator,
      validation_steps=50,  # 1000 images = batch_size * steps
      verbose=2)

train_datagen:

  1. 这里将rescale重新缩放、旋转、平移变换、剪切变换、缩放、水平翻转、以临近方式填充等多种方式对训练数据进行数据增强
  2. shear_range=0.2 表示图像将在 -0.2 到 +0.2 弧度的范围内随机剪切

test_datagen:

  1. 验证数据,没有进行数据增强,这里只进行了归一化操作

train_generator:

  1. train_dir 目录加载训练图像,并应用前面定义的数据增强
  2. target_size=(64, 64):调整图像大小为 64x64 像素
  3. batch_size=20:每批次处理 20 张图像
  4. class_mode='binary':因为是二分类任务。

validation_generator:

  1. validation_dir 目录加载验证图像,只应用缩放

history:

  1. fit_generator 方法在 TensorFlow 2.2 之后已经被弃用,建议使用 fit 方法替代)
  2. 开始训练
  3. validation_data=validation_generator:指定验证数据生成器
  4. verbose=2:用于控制训练过程中输出的详细程度

3、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

在这里插入图片描述

在这里插入图片描述

很显然经过数据增强后的模型表现对比原本效果有显著提升

3、加入Dropout

Dropout就是指定比例,对这一层随机杀死一下神经元,这里我们只需要在构建网络的时候在全连接层加上一层Dropout就可以了:

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),

    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),

    tf.keras.layers.Flatten(),

    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',
              optimizer=Adam(lr=1e-4),
              metrics=['acc'])
train_datagen = ImageDataGenerator(
      rescale=1./255,
      rotation_range=40,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest')

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        train_dir,  
        target_size=(64, 64),  
        batch_size=20,
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(64, 64),
        batch_size=20,
        class_mode='binary')

history = model.fit_generator(
      train_generator,
      steps_per_epoch=100,  # 2000 images = batch_size * steps
      epochs=100,
      validation_data=validation_generator,
      validation_steps=50,  # 1000 images = batch_size * steps
      verbose=2)

Epoch 100/100
100/100 - 3s - loss: 0.4145 - acc: 0.8145 - val_loss: 0.4269 - val_acc: 0.7830 - 3s/epoch - 33ms/step

在这里插入图片描述
在这里插入图片描述
这效果又提升了一点

猫狗识别1
数据增强
猫狗识别2------数据增强
猫狗识别3------迁移学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1425357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

国家级专精特新“小巨人”第一至五批名单

国家级专精特新“小巨人”第一至五批名单 1、来源:工信部 2、样本量:1.29W第一批企业共248家,A股上市35家;第二批企业共1744家,A股上市157家;第三批企业共2930家,A股上市119家;第四…

【C++干货铺】哈希结构在C++中的应用

目录 unordered系列关联式容器 unordered_map unordered_map的接口说明 1.unordered_map的构造 2. unordered_map的容量 3. unordered_map的迭代器 4. unordered_map的元素访问 5. unordered_map的查询 6. unordered_map的修改操作 7. unordered_map的桶操作 底层结构 …

【知识点】设计模式

创建型 单例模式 Singleton:确保一个类只有一个实例,并提供该实例的全局访问点 使用一个私有构造方法、一个私有静态变量以及一个公有静态方法来实现。私有构造方法确保了不能通过构造方法来创建对象实例,只能通过公有静态方法返回唯一的私…

Qt实现窗口吸附屏幕边缘 自动收缩

先看效果: N年前的QQ就可以吸附到屏幕边缘,聊天时候非常方便,不用点击状态栏图标即可呼出QQ界面 自己尝试做了一个糙版的屏幕吸附效果。 关键代码: void Widget::mouseMoveEvent(QMouseEvent *e) {int dx e->globalX() - l…

C语言基础:写一个函数,输入一行字符,将此字符串最长的单词输出

方法一&#xff1a; #include<string.h> int find_longest(char line[])//把数组传过来 {int is_alphabetic(char word);int i 0;int length 0;//统计每个字符串的长度int max 0;//比max长就把值赋值给maxint place 0;//最长单词的起始位置int point;//每个字符串第…

暴搜,回溯,剪枝

力扣77.组合 class Solution {List<List<Integer>>retnew ArrayList<>();List<Integer>pathnew ArrayList<>();int n; int k;public List<List<Integer>> combine(int _n, int _k) {n_n;k_k;dfs(1);return ret;}public void dfs(int…

2024斋月大促跨境卖家准备指南

市场覆盖西欧、中东、东南亚、北非地区的跨境电商卖家注意了&#xff0c;2024年的斋月即将开启&#xff0c;较往年日期&#xff0c;今年提前了10天左右&#xff0c;斋月的第一天预测在3月11日星期一到来。 根据Google搜索数据可知&#xff0c;目前已经进入高频“斋月”搜索期&…

小米商城服务治理之客户端熔断器(Google SRE客户端熔断器)

目录 前言 一、什么是Google SRE熔断器 二、Google SRE 熔断器的工作流程&#xff1a; 三、客户端熔断器 (google SRE 熔断器) golang GRPC 实现 四、客户端熔断器 (google SRE 熔断器) golang GRPC单元测试 大家可以关注个人博客&#xff1a;xingxing – Web Developer …

K8S网络

一、介绍 k8s不提供网络通信&#xff0c;提供了CNI接口(Container Network Interface&#xff0c;容器网络接口)&#xff0c;由CNI插件实现完成。 1.1 Pod通信 1.1.1 同一节点Pod通信 Pod通过虚拟Ethernet接口对&#xff08;Veth Pair&#xff09;与外部通信&#xff0c;Veth…

银河麒麟v10服务器版,specvirt测试

1 两台服务器&#xff0c;一台为SUT&#xff0c;一台为Phyclient。 1.1 两台服务器均编译安装gcc和qemu 按银河麒麟v10服务器arm版&#xff0c;qemugcc&#xff0c;跨架构安装虚拟机中步骤&#xff0c;编译安装gcc-9.3.0和qemu-7.0.0。 2 SUT服务器操作 2.1 mount数据盘到/…

如何发布自己的npm包:

1.创建一个打包组件或者库&#xff1a; 安装weback&#xff1a; 打开项目&#xff1a; 创建webpack.config.js,创建src目录 打包好了后发现两个js文件都被压缩了&#xff0c;我们想开发使用未压缩&#xff0c;生产使用压缩文件。 erserPlugin&#xff1a;&#xff08;推荐使用…

搭建 idea 插件仓库私服

正常情况下&#xff0c;我们开发的 idea 插件会发布到 idea 官方商城中&#xff0c;这样用户就可以在 idea 的 Marketplace 中搜索安装。 但是在企业内部&#xff0c;有可能我们开发了很多内部插件&#xff0c;而不能发布到公共市场中&#xff0c;这种情况下我们就需要搭建一个…

css多行文本擦拭效果

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>多行文本擦拭效果</title><style>* …

black--一键格式化Python代码

black black是一个Python代码格式化程序&#xff0c;使用它可以免于在调整代码格式上花费时间。black被许多大大小小的项目成功使用&#xff0c;包括pytest, tox, Pyramid, Django等。 格式化效果&#xff1a; 可以在线查看格式化效果&#xff1a;https://black.vercel.app/…

ERP系统助力车间生产:班组、设备、工序一网打尽!实现生产全流程可视化!

​随着企业生产规模的扩大和业务复杂性的增加&#xff0c;车间管理在企业运营中的地位日益突出。ERP系统作为企业资源管理的核心平台&#xff0c;为车间管理提供了全面的解决方案。通过合理配置和使用ERP系统的功能模块&#xff0c;企业可以优化生产流程、提高生产效率、确保产…

【SparkML系列3】特征提取器TF-IDF、Word2Vec和CountVectorizer

本节介绍了用于处理特征的算法&#xff0c;大致可以分为以下几组&#xff1a; 提取&#xff08;Extraction&#xff09;&#xff1a;从“原始”数据中提取特征。转换&#xff08;Transformation&#xff09;&#xff1a;缩放、转换或修改特征。选择&#xff08;Selection&…

一文看懂动态住宅代理IP,附常见使用问题解答

动态住宅代理IP在保护在线隐私和个人数据安全方面发挥着重要作用。通过隐藏用户的真实IP地址和地理位置&#xff0c;它为网络用户提供了一个更安全、更私密的网络环境。这对于希望保护自己免受网络监控和个人信息泄露的用户来说&#xff0c;是一项不可或缺的网络工具。 一、动态…

RT-Thread:STM32的PB3,PB4 复用IO配置为GPIO

说明&#xff1a;在使用 STM32F103CBT6 配置了 PB3 为IO&#xff0c;测试时发现读取这个IO的电平时钟是0&#xff0c;即便单管脚上的电平是1&#xff0c;读取的数据任然是0,查规格书后发现PB3,PB4是JTAG复用口&#xff0c;要当普通IO用需要配置。 配置工具&#xff1a;STM32Cu…

React中封装大屏自适应(拉伸)仿照 vue2-scale-box

0、前言 仿照 vue2-scale-box 1、调用示例 <ScreenAutoBox width{1920} height{1080} flat{true}>{/* xxx代码 */}</ScreenAutoBox> 2、组件代码 import { CSSProperties, ReactNode, RefObject, useEffect, useRef, useState } from react//数据大屏自适应函数…

36万的售价,蔚来理想卖得,小米卖不得?

文 | AUTO芯球 作者 | 雷歌 Are you OK&#xff1f;雷军被网友们叫“小雷”&#xff01; 被网友一猜再猜的小米SU7的价格&#xff0c;因为一份保险上牌价格单的曝光被网友吵得热热闹闹&#xff0c;曝出的小米汽车顶配上牌保险价格为36.14万。 20万以下&#xff0c;人们愿称…