5 Tensorflow图像识别(下)模型构建

news2025/1/24 2:11:29

上一篇:4 Tensorflow图像识别模型——数据预处理-CSDN博客

1、数据集标签

上一篇介绍了图像识别的数据预处理,下面是完整的代码:

import os
import tensorflow as tf

# 获取训练集和验证集目录
train_dir = os.path.join('cats_and_dogs_filtered/train')
validation_dir = os.path.join('cats_and_dogs_filtered/validation')


# 模型参数设置
BATCH_SIZE = 100

# 图片尺寸统一为150*150
IMG_SHAPE = 150

# 处理图像尺寸
img_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, )

train_data_gen = img_generator.flow_from_directory(directory=train_dir,
                                                   shuffle=True,
                                                   batch_size=BATCH_SIZE,
                                                   target_size=(IMG_SHAPE, IMG_SHAPE),
                                                   class_mode='binary')
val_data_gen = img_generator.flow_from_directory(directory=validation_dir,
                                                 shuffle=True,
                                                 batch_size=BATCH_SIZE,
                                                 target_size=(IMG_SHAPE, IMG_SHAPE),
                                                 class_mode='binary')

上一篇提到系统的输入是“特征-标签”对,特征是输入的图片,标签就是标记该图片是猫还是狗。上面的代码如何知道输入的照片是猫还是狗?

这里用到了keras的一个函数flow_from_directory(),从目录中生成数据流,子目录会自动帮你生成标签。先看看train训练集的这两个子目录生成的标签是什么:

使用下面代码查看

print(train_data_gen.class_indices)

运行结果:

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
{'cats': 0, 'dogs': 1}

从运行结果可以看到,猫的照片系统自动打上了0的标签,狗的标签是1。

2、Relu激活函数

构建模型的完整代码如下:

import os
import tensorflow as tf
import numpy as np



# 获取训练集和验证集目录
train_dir = os.path.join('cats_and_dogs_filtered/train')
validation_dir = os.path.join('cats_and_dogs_filtered/validation')

# 模型参数设置
BATCH_SIZE = 100

# 图片尺寸统一为150*150
IMG_SHAPE = 150

# 处理图像尺寸
img_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, )

# 训练数据
train_data_gen = img_generator.flow_from_directory(directory=train_dir,
                                                   shuffle=True,
                                                   batch_size=BATCH_SIZE,
                                                   target_size=(IMG_SHAPE, IMG_SHAPE),
                                                   class_mode='binary')

# 验证数据
val_data_gen = img_generator.flow_from_directory(directory=validation_dir,
                                                 shuffle=True,
                                                 batch_size=BATCH_SIZE,
                                                 target_size=(IMG_SHAPE, IMG_SHAPE),
                                                 class_mode='binary')


model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(100, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(2)

])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy']
              )

EPOCHS = 20
history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=int(np.ceil(2000 / float(BATCH_SIZE))),
    epochs=EPOCHS,
    validation_data=val_data_gen,
    validation_steps=int(np.ceil(1000 / float(BATCH_SIZE)))
)

model中加入了和之前不一样的代码:

tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),

这里使用了卷积神经,主要是为了突出区分不同对象的特征。一张图片的信息很多的,但往往我们只需要一些特征进行训练就可以了,后续会详细介绍。

现在先介绍 activation='relu',激活函数Relu。

ReLU,全称是线性整流函数(Rectified Linear Unit),是人工神经网络中常用的激活函数。它的图像如下:

当x<=0时,f(x)=0;

当x>0时,f(x)=x;

可以运行代码看看:

例1:

import tensorflow as tf


x = -19
print(tf.nn.relu(x))

运行结果:
tf.Tensor(0, shape=(), dtype=int32)

输入-19,使用relu激活函数后的结果为0

例2:

import tensorflow as tf


x = 8
print(tf.nn.relu(x))

运行结果:

tf.Tensor(8, shape=(), dtype=int32)

3、损失函数

代码:


model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy']
              )
 

其中损失函数为SparseCategoricalCrossentropy,它是用于计算多分类问题的交叉熵,如果是两个或两个以上的分类问题可以始终这样设置。对其原理及计算过程的读者可以自行百度,此处不详细介绍。

4、训练过程详解
(1)训练准确率

运行上面的完整代码:

可以看到训练集和验证集的loss值在慢慢下降,准确率在提升。

划线部分是最后一个epoch的训练结果:

accuracy:0.8175,也就是说你的神经网络在分类训练数据方面的准确率约为82%;

val_accuracy:0.7160,在验证集的准确率约为72%

(2)batch_size批次大小

代码中batche_size设置的大小为100,意思是每批次生成的样本数量为100。

例如上述代码的train训练集一共有2000张图片,一个周期(epoch)分20个批次(2000/100=20)样本数据进行训练,每个批次训练完后利用优化器更新模型参数。

所以一个周期(epoch)的模型参数更新次数就是20:2000/batch_size=20

截图中红色部分,就是一个epoch分了20个批次用来更新模型参数。

训练结果会因为模型的参数的设置、训练集图片的数量等等原因结果大不相同,学习的时候可以自己动手去调整模型参数来看看训练结果。



 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1185130.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

有关python库

官方库 #1、导入某模块 import os #2、导入OS模块中的system方法 from os import system #3、导入某模块中的孙子模块中的xx方法&#xff0c;并重命名 from module.xx.xx import xx as rename #4、导入OS中的所有模块 #不用进行OS.method(),直接method&#xff08;&#xff0…

ST 任意内核 移植freertos系统

FREERTOS系统移植&#xff0c;先下载系统文件并解压如下 keil5移植后效果如图 注意事项 注意内核类型&#xff0c;ST的f1为M3,F4为m4&#xff0c;h7&#xff0c;f7为m7 再include包含.h路径即可 任务函数一定要写到while&#xff08;1&#xff09;否则无法运行 void lvgl_demo…

AVL树 c语言版本 插入部分

目录 引入平衡树 为什么要变平衡 怎么判断是否需要变平衡 怎么变平衡 LL型失衡 RR型失衡 LR型失衡 RL型失衡 补充 左旋补充 右旋补充 Code 开辟一个新节点 初始化 获取树的高度 左旋函数 更新树高 树高的求法 右旋转函数 插入 InsertNode() 更新树高 getbala…

ZZ308 物联网应用与服务赛题第G套

2023年全国职业院校技能大赛 中职组 物联网应用与服务 任 务 书 &#xff08;G卷&#xff09; 赛位号&#xff1a;______________ 竞赛须知 一、注意事项 1.检查硬件设备、电脑设备是否正常。检查竞赛所需的各项设备、软件和竞赛材料等&#xff1b; 2.竞赛任务中所使用…

星岛专栏|从Web3发展看金融与科技的融合之道

11月起&#xff0c;欧科云链与香港主流媒体星岛集团开设Web3.0安全技术专栏&#xff0c;该专栏主要面向香港从业者、交易机构、监管机构输出专业性的安全合规建议&#xff0c;旨在促进香港Web3.0行业向安全与合规发展。 出品&#xff5c;欧科云链研究院 自2016年首届香港金融…

P02项目诊断报警组件(学习操作日志记录、单元测试开发)

★ P02项目诊断报警组件 诊断报警组件的主要功能有&#xff1a; 接收、记录硬件设备上报的报警信息。从预先设定的错误码对照表中找到对应的声光报警和蜂鸣器报警策略&#xff0c;结合当前的报警情况对设备下发报警指示。将报警消息发送到消息队列&#xff0c;由其它组件发送…

Linux常用命令及项目部署

目录 Linux介绍 Linux环境 下载xshell 常见的Linux命令 搭建Java部署环境 1.jdk 2.tomcat 3.mysql 进行部署 Linux介绍 Linux操作系统是和Windows并列的关系&#xff0c;Linux主要通过命令行进行操作的。 Linux环境 1.使用虚拟机&#xff0c;电脑上安装虚拟机软件 …

快来看看如何拿下7+干湿结合生信思路。赶紧拿起笔记本

今天给同学们分享一篇生信文章“CENPF/CDK1 signaling pathway enhances the progression of adrenocortical carcinoma by regulating the G2/M-phase cell cycle”&#xff0c;这篇文章发表在J Transl Med期刊上&#xff0c;影响因子为7.4。 结果解读&#xff1a; CENPF在AC…

Python实现图片与PDF互相转换

目录 图片转PDF文件夹所有图片转为1个PDF文件夹指定图片转为1个PDF文件夹所有图片分别转为PDF举例 PDF转图片指定PDF转为图片文件夹所有PDF转为图片举例 图片转PDF 之前的一篇博客《Python合并拼接图片》&#xff0c;可对图片进行合并拼接 使用前需要安装PyMuPDF库&#xff0c…

lua脚本实现redis分布式锁(脚本解析)

文章目录 lua介绍lua基本语法redis执行lua脚本 - EVAL指令使用lua保证删除原子性 lua介绍 Lua 是一种轻量小巧的脚本语言&#xff0c;用标准C语言编写并以源代码形式开放&#xff0c; 其设计目的是为了嵌入应用程序中&#xff0c;从而为应用程序提供灵活的扩展和定制功能。 设…

最适合后端程序员要学的前端基本知识js版本

本文章适合有后端基础&#xff0c;也对前端知识有所了解的同学&#xff0c;简单而精准的了解后端程序员要了解的前端知识。让自己能看懂前端js&#xff0c;html和css版本后续写出。。。。 自定义对象 在对象中创建对象的行为时若是要调用对象属性&#xff0c;要加this 列如&a…

kubernetes 认证授权

目录 一、kubernetes API访问控制 二、pod绑定sa 三、认证 四、授权 一、kubernetes API访问控制 Authentication&#xff08;认证&#xff09;认证方式现共有8种&#xff0c;可以启用一种或多种认证方式&#xff0c;只要有一种认证方式通过&#xff0c;就不再进行其它方式…

【Axure高保真原型】树切换动态面板案例

今天和大家分享树切换动态面板的原型模板&#xff0c;点击树的箭头可以打开或者收起子节点&#xff0c;点击最后一级人物节点&#xff0c;可以切换右侧面板的状态到对应的页面&#xff0c;左侧的树是通过中继器制作的&#xff0c;使用简单&#xff0c;只需要按要求填写中继器表…

人大金仓三大兼容:SQL Server迁移无忧

SQL Server在数据库领域一直占据着重要地位。作为一款成熟稳定的关系型数据库管理系统&#xff0c;SQL Server在国内有着广泛的用户群体&#xff0c;医疗、海关、政务等行业的核心业务系统多采用SQL Server数据库。随着政策与市场的双重驱动&#xff0c;信息技术应用创新产业的…

【快速使用ShardingJDBC的哈希分片策略进行分表】

文章目录 &#x1f50a;博主介绍&#x1f964;本文内容&#x1f34a;1.引入maven依赖&#x1f34a;2.启动类上添加注解MapperScan&#x1f34a;3.添加application.properties配置&#x1f34a;4.普通的自定义实体类&#x1f34a;5.写个测试类验证一下&#x1f34a;6.控制台打印…

竞赛 目标检测-行人车辆检测流量计数

文章目录 前言1\. 目标检测概况1.1 什么是目标检测&#xff1f;1.2 发展阶段 2\. 行人检测2.1 行人检测简介2.2 行人检测技术难点2.3 行人检测实现效果2.4 关键代码-训练过程 最后 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 行人车辆目标检测计数系统 …

Linux之make/maakefile

access不是实时更新的。 printf打印并不是直接给屏幕而是先放到缓冲区。 可以通过fflush(stdout)强制刷新缓冲区。 换行是指直接到同一位置的下一行&#xff0c;回车是回到开头。

【MySQL数据库】 七

本文主要介绍了Java的JDBC编程的过程. 超详细 !!! 一.JDBC JDBC就是通过Java代码,来操作数据库 由于我们在实际开发中,绝大部分都是用代码来操作数据库的 , 因此一个成熟的数据库,都会提供一些API让程序员来使用. 常见的数据库比如mysql / oracl / sqlserver / SQLite 都会提…

Python类与对象:类的定义、实例化、方法、属性、构造函数

文章目录 类的定义类的实例化方法属性构造函数Python 类和对象是面向对象编程的基础。在 Python 中,几乎所有东西都是对象,拥有属性和方法。类是创建对象的蓝图或模板。让我们一步步来探索类的定义、实例化、方法、属性以及构造函数,并提供详细的代码示例。 类的定义 在 P…

SQL注入漏洞 其他注入

文章目录 宽字节注入案例 HTTP头部注入Cookie注入base64User-Agent注入Referer 注入 SQL注入读写文件条件1.是否拥有读写权限2.文件路径3.secure_file_priv 读取文件写入文件 SQLMap安装sqlmapkail 源安装仓库克隆 参数简介快速入门&#xff1b;SQLmap&#xff08;常规&#xf…