tensorflow-卷积神经网络-图像分类入门demo

news2024/10/5 2:26:19

猫狗识别

  • 数据预处理:图像数据处理,准备训练和验证数据集
  • 卷积网络模型:构建网络架构
  • 过拟合问题:观察训练和验证效果,针对过拟合问题提出解决方法
  • 数据增强:图像数据增强方法与效果
  • 迁移学习:深度学习必备训练策略

导入工具包

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

指定好数据路径(训练和验证)

# 数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

# 训练集
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

# 验证集
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

构建卷积神经网络模型

  • 几层都可以,大家可以随意玩
  • 如果用CPU训练,可以把输入设置的更小一些,一般输入大小更主要的决定了训练速度
  • model = tf.keras.models.Sequential([
        #如果训练慢,可以把数据设置的更小一些
        tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),
        tf.keras.layers.MaxPooling2D(2, 2),
    
        tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2,2),
    
        tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2,2),
        
        #为全连接层准备
        tf.keras.layers.Flatten(),
        
        tf.keras.layers.Dense(512, activation='relu'),
        # 二分类sigmoid就够了
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model.summary()

  • 配置训练器

    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(lr=1e-4),
                  metrics=['acc'])

    数据预处理

  • 读进来的数据会被自动转换成tensor(float32)格式,分别准备训练和验证
  • 图像数据归一化(0-1)区间
    train_datagen = ImageDataGenerator(rescale=1./255)
    test_datagen = ImageDataGenerator(rescale=1./255)
    train_generator = train_datagen.flow_from_directory(
            train_dir,  # 文件夹路径
            target_size=(64, 64),  # 指定resize成的大小
            batch_size=20,
            # 如果one-hot就是categorical,二分类用binary就可以
            class_mode='binary')
    
    validation_generator = test_datagen.flow_from_directory(
            validation_dir,
            target_size=(64, 64),
            batch_size=20,
            class_mode='binary')

    训练网络模型

  • 直接fit也可以,但是通常咱们不能把所有数据全部放入内存,fit_generator相当于一个生成器,动态产生所需的batch数据
  • steps_per_epoch相当给定一个停止条件,因为生成器会不断产生batch数据,说白了就是它不知道一个epoch里需要执行多少个step
    history = model.fit_generator(
          train_generator,
          steps_per_epoch=100,  # 2000 images = batch_size * steps
          epochs=20,
          validation_data=validation_generator,
          validation_steps=50,  # 1000 images = batch_size * steps
          verbose=2)
    Epoch 1/20
    100/100 - 7s - loss: 0.6892 - acc: 0.5325 - val_loss: 0.6705 - val_acc: 0.5970
    Epoch 2/20
    100/100 - 6s - loss: 0.6595 - acc: 0.6055 - val_loss: 0.6346 - val_acc: 0.6470
    Epoch 3/20
    100/100 - 6s - loss: 0.6350 - acc: 0.6515 - val_loss: 0.6358 - val_acc: 0.6320
    Epoch 4/20
    100/100 - 7s - loss: 0.5936 - acc: 0.6865 - val_loss: 0.5906 - val_acc: 0.6780
    Epoch 5/20
    100/100 - 7s - loss: 0.5530 - acc: 0.7170 - val_loss: 0.5978 - val_acc: 0.6670
    Epoch 6/20
    100/100 - 8s - loss: 0.5179 - acc: 0.7490 - val_loss: 0.5484 - val_acc: 0.7140
    Epoch 7/20
    100/100 - 8s - loss: 0.4854 - acc: 0.7725 - val_loss: 0.5686 - val_acc: 0.7080
    Epoch 8/20
    100/100 - 8s - loss: 0.4595 - acc: 0.7905 - val_loss: 0.5452 - val_acc: 0.7150
    Epoch 9/20
    100/100 - 8s - loss: 0.4406 - acc: 0.7885 - val_loss: 0.5453 - val_acc: 0.7210
    Epoch 10/20
    100/100 - 7s - loss: 0.4109 - acc: 0.8170 - val_loss: 0.5317 - val_acc: 0.7270
    Epoch 11/20
    100/100 - 8s - loss: 0.3892 - acc: 0.8285 - val_loss: 0.5384 - val_acc: 0.7220
    Epoch 12/20
    100/100 - 8s - loss: 0.3542 - acc: 0.8570 - val_loss: 0.5480 - val_acc: 0.7180
    Epoch 13/20
    100/100 - 8s - loss: 0.3421 - acc: 0.8580 - val_loss: 0.5355 - val_acc: 0.7420
    Epoch 14/20
    100/100 - 8s - loss: 0.3217 - acc: 0.8665 - val_loss: 0.5572 - val_acc: 0.7340
    Epoch 15/20
    100/100 - 8s - loss: 0.2931 - acc: 0.8805 - val_loss: 0.5545 - val_acc: 0.7400
    Epoch 16/20
    100/100 - 8s - loss: 0.2739 - acc: 0.8870 - val_loss: 0.5540 - val_acc: 0.7360
    Epoch 17/20
    100/100 - 8s - loss: 0.2535 - acc: 0.9040 - val_loss: 0.5564 - val_acc: 0.7380
    Epoch 18/20
    100/100 - 8s - loss: 0.2257 - acc: 0.9245 - val_loss: 0.5710 - val_acc: 0.7420
    Epoch 19/20
    100/100 - 8s - loss: 0.2084 - acc: 0.9350 - val_loss: 0.5734 - val_acc: 0.7460
    Epoch 20/20
    100/100 - 8s - loss: 0.2258 - acc: 0.9130 - val_loss: 0.5897 - val_acc: 0.7300
    

    效果展示

    import matplotlib.pyplot as plt
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    epochs = range(len(acc))
    
    plt.plot(epochs, acc, 'bo', label='Training accuracy')
    plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    
    plt.figure()
    
    plt.plot(epochs, loss, 'bo', label='Training Loss')
    plt.plot(epochs, val_loss, 'b', label='Validation Loss')
    plt.title('Training and validation loss')
    plt.legend()
    
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1033054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网页版的 Redis 可视化工具来了,已开源?

轻量级Redis缓存图形化管理工具,包含redis的5种数据类型的CRUD操作 软件架构 后端 springboot 2.2.2.RELEASEJDK 1.8jedis 3.2.0commons-lang3 3.5hutool-core 5.1.1fastjson 1.2.62h2database 1.4.200 前端 vue-admin 1.0.5axios 0.15.3element-ui 2.13.0font-…

海外媒体发稿:海外汽车媒体推广9个方式解析

根据下列9个国外汽车媒体推广方式,企业能够在国际范围内突破边界,获得领域关心。这将帮助企业完成国际化发展发展战略,扩展市场占有率和提升盈利空间。【华媒舍】国外全媒体发表文章将会成为企业完成这一目标的重要方式,为企业带来…

Caton Media Xstream: 重新定义实时内容交付服务

// 编者按:随着公共互联网愈加复杂,best effort的基本原型已无法满足越来越多的有QoS保障需求的实时内容交付服务。而专线、卫星等传统解决方案存在部署成本高、周期长等问题,无法快速响应各类需求。LiveVideoStackCon邀请到了科腾科技的魏…

解决:Typora上传图片后本地显示不出来

在配置好PicGo、github以及Typora后,为了更好部署博客,将图片的偏好设置改为上传图片,会出现一个问题: github上图片已上传成功,但是本地Typora的图片不显示,这里进行配置: 文件——>偏好设…

Oracle 11g RAC部署笔记

搭了三次才搭好,要记录一下。 1. Oracle 11g RAC部署的相关步骤以及需要的包,可以参考这里。 Oracle 11g RAC部署_12006142的技术博客_51CTO博客Oracle 11g RAC部署,Oracle11gRAC部署操作环境:CentOS7.4Oracle11.2.0.4一、主机网…

安卓玩机搞机----不用刷第三方官改固件即可享受“高级设置”的操作 ChiMi安装使用步骤

很多玩友特别喜欢第三方作者修改的带有高级设置的官改包。因为他可以随意修改系统里面的有关设置选项。包括但不限于修改状态栏 显示日期 秒等等的操作。 第三方带高级设置的官改 一般官改带高级设置的类似与 今天给大家分享下不用刷这些官改包即可享受高级设置的操作。 红米…

Android面试题汇总(三)

Android 四大组件相关 1、Activity与Fragment之间常见的通讯方式 对于Activity与Fragment直接的相互调用: 1、Activity调用Fragment直接调用就好了,Activity一般是持有Fragment实例的。或者通过Fragment的id或者tag获取Fragment的实例 2、Fragment调用A…

CRM系统主要包括哪些功能?

CRM系统应该要包括的功能总结为3大方向—— 核心必须要具备的功能常见尽量要有的功能可选有了自然更好的功能 以我们公司用的简道云CRM系统模板为例:https://www.jiandaoyun.com 01 核心必须要具备的功能 核心功能决定了系统是否能够被纳入CRM类别,这些…

升级pip

升级pip 报错提示: WARNING: You are using pip version 19.1.1, however version 20.0.2 is available. You should consider upgrading via the ‘python -m pip install --upgrade pip’ command. 解决办法 py -m pip install --upgrade pip记得关掉梯子 如果…

全面重构存储系统,释放AI时代全新数据价值

《失控》作者凯文凯利做出判断: “不管你从事任何商业,你的生意现在就是数据生意。”一言蔽之,数据已是当今商业社会和经济发展的基石。企业的生产研发、经营管理和销售服务等所有环节愈发离不开数据,数实融合的加速为产业发展带来无限可能。…

视频编解码器H.264和H265有什么区别?

对于大型视频文件来说,视频编解码器至关重要,它可以将文件压缩为较小的尺寸,从而可以更轻松地存储和加快传输速度。而两种最常用的编解码器是H.264和H.265,那么它们两者之间有什么区别,哪一个更好呢? 1. 什…

https SSL证书使用 git bash 解密

申请域名证书后,有些证书下载时强制加密。 在使用时,比如在AWS ACM中使用时,不能用加密的证书。所以这里讲下怎么解密。 首先,加密一般加密的是公私钥中的私钥,即private.key。 填写密码,下载证书&#x…

玩转蓝牙墨水屏电子标签(一)点灯

对于垃圾佬的生活来说,每天逛海鲜市场是必不可少的生活片段,这不,手抖一下又刷到了一个东付的电子标价签。 价格合理,2块钱一个不包邮,直接买了N个。。。算了一下一个3.5,然后拿到群去炫耀了下,…

Xilinx FPGA 7系列 GTX/GTH Transceivers (4) Aurora 8b10b 递增数收发验证

第一节:Xilinx FPGA 7系列 GTX/GTH Transceivers (1)–了解了GTX硬件的基础知识 第二节:IBERT GTX --通过Ibert IP测试链路通信 第三节:aurora 8b10b single lane 4byte–学习官方历程 递增数验证 自行编写data_gen和data_check 验证aurora 8b10b SFP 1.25G 收发正确。 组…

用过lsof命令的,都竖起了大拇指!!!

lsof(list open files)是一个列出当前系统打开文件的工具。在linux环境下,任何事物都以文件的形式存在,通过文件不仅仅可以访问常规数据,还可以访问网络连接和硬件。所以如传输控制协议 (TCP) 和用户数据报协议 (UDP) …

webpack常用配置与性能优化插件

webpack是一个流行的前端项目构建工具(打包工具),可以解决当前web 开发中所面临的困境。 提供了友好的模块化支持,以及代码压缩混淆、处理js兼容问题、性能优化等强大的功能,从而让程序员把工作的重心放到具体的功能实…

Bigemap如何添加mapbox图源?

会使用到的工具 bigemap gis office,下载链接:BIGEMAP GIS Office-全能版 打开软件,要提示需要授权和添加地图,然后去点击选择地图这个按钮,列表中有个添加按钮点进去选择添加地图的方式。 第一种方式:通…

k8s-2 集群升级

首先导入镜像到本地 然后上传镜像到仓库 在所有集群节点 部署cri-docker k8s从1.24版本开始移除了dockershim,所以需要安装cri-docker插件才能使用docker 配置cri-docker 升级master 节点 升级kubeadm 执行升级计划 修改节点套接字 腾空节点 升级kubelet 配置k…

MySQL备份及恢复

目录 MySQL备份 MySQL备份方法 备份策略 mysql的完全备份 mysql的增量备份 MySQL恢复 mysql完全恢复 mysql增量备份的恢复 MySQL备份 MySQL备份是基于对MySQL的日志进行备份,且恢复也是通过日志进行数据恢复。 MySQL备份方法 物理备份:直接对…

如何在一个月内通过PMP考试?

新版本考纲,有一个最大的难点。那就是知识点来自《PMBOK指南》,以及《敏捷实践指南》;但是考试大纲给了3个域、35个任务,这些只给了条目性的提纲,对应着考试的实践要求。 考试题目全部是基于情境的选择题,…