Tensorflow实现深度学习案例7:咖啡豆识别

news2024/12/24 21:22:09

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

一、前期工作

1. 导入数据

from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlib
import tensorflow as tf
import warnings as w
w.filterwarnings('ignore')

data_dir = "./coffee/"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.png')))

print("图片总数为:",image_count)
图片总数为: 1200

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

batch_size = 32
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.
Using 960 files for training.

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.
Using 240 files for validation.

class_names = train_ds.class_names
print(class_names)
['Dark', 'Green', 'Light', 'Medium']

2.数据可视化 

plt.figure(figsize=(10, 4))  # 图形的宽为10高为5

for images, labels in train_ds.take(1):
    for i in range(10):
        
        ax = plt.subplot(2, 5, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(32, 224, 224, 3)
(32,)

3. 配置数据集

  • shuffle() :打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

 并且将数据归一化

normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y))

image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]

# 查看归一化后的数据
print(np.min(first_image), np.max(first_image))

 0.0 1.0

三、构建VGG-16网络

1.VGG优缺点分析:

  • VGG优点

VGG的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)

  • VGG缺点

1)训练时间过长,调参难度大。2)需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

2.网络结构图

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcXpredictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

 

model = tf.keras.applications.VGG16(weights='imagenet')
model.summary()
Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 fc1 (Dense)                 (None, 4096)              102764544 
                                                                 
 fc2 (Dense)                 (None, 4096)              16781312  
                                                                 
 predictions (Dense)         (None, 1000)              4097000   
                                                                 
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=30,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型 

epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
Epoch 1/20
30/30 [==============================] - 346s 11s/step - loss: 1.7546 - accuracy: 0.2625 - val_loss: 1.4646 - val_accuracy: 0.2125
Epoch 2/20
30/30 [==============================] - 352s 12s/step - loss: 1.3637 - accuracy: 0.3104 - val_loss: 1.0428 - val_accuracy: 0.4583
Epoch 3/20
30/30 [==============================] - 338s 11s/step - loss: 0.7237 - accuracy: 0.6458 - val_loss: 0.4818 - val_accuracy: 0.7833
Epoch 4/20
30/30 [==============================] - 336s 11s/step - loss: 0.3633 - accuracy: 0.8479 - val_loss: 1.1034 - val_accuracy: 0.6167
Epoch 5/20
30/30 [==============================] - 340s 11s/step - loss: 0.2880 - accuracy: 0.8927 - val_loss: 0.1480 - val_accuracy: 0.9500
Epoch 6/20
30/30 [==============================] - 338s 11s/step - loss: 0.1802 - accuracy: 0.9333 - val_loss: 0.4709 - val_accuracy: 0.8458
Epoch 7/20
30/30 [==============================] - 334s 11s/step - loss: 0.1468 - accuracy: 0.9490 - val_loss: 0.0214 - val_accuracy: 1.0000
Epoch 8/20
30/30 [==============================] - 339s 11s/step - loss: 0.0174 - accuracy: 0.9969 - val_loss: 0.0196 - val_accuracy: 0.9875
Epoch 9/20
30/30 [==============================] - 329s 11s/step - loss: 0.0399 - accuracy: 0.9875 - val_loss: 0.2539 - val_accuracy: 0.9292
Epoch 10/20
30/30 [==============================] - 330s 11s/step - loss: 0.2606 - accuracy: 0.9073 - val_loss: 0.0737 - val_accuracy: 0.9917
Epoch 11/20
30/30 [==============================] - 334s 11s/step - loss: 0.0610 - accuracy: 0.9812 - val_loss: 0.0070 - val_accuracy: 1.0000
Epoch 12/20
30/30 [==============================] - 341s 11s/step - loss: 0.0296 - accuracy: 0.9917 - val_loss: 0.0256 - val_accuracy: 0.9875
Epoch 13/20
30/30 [==============================] - 335s 11s/step - loss: 0.0252 - accuracy: 0.9917 - val_loss: 0.0431 - val_accuracy: 0.9833
Epoch 14/20
30/30 [==============================] - 345s 12s/step - loss: 0.0058 - accuracy: 0.9979 - val_loss: 0.0088 - val_accuracy: 0.9958
Epoch 15/20
30/30 [==============================] - 557s 19s/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0144 - val_accuracy: 0.9917
Epoch 16/20
30/30 [==============================] - 340s 11s/step - loss: 3.6823e-04 - accuracy: 1.0000 - val_loss: 0.0052 - val_accuracy: 0.9958
Epoch 17/20
30/30 [==============================] - 347s 12s/step - loss: 5.9116e-05 - accuracy: 1.0000 - val_loss: 0.0064 - val_accuracy: 0.9958
Epoch 18/20
30/30 [==============================] - 347s 12s/step - loss: 2.5309e-05 - accuracy: 1.0000 - val_loss: 0.0048 - val_accuracy: 0.9958
Epoch 19/20
30/30 [==============================] - 350s 12s/step - loss: 1.0864e-05 - accuracy: 1.0000 - val_loss: 0.0033 - val_accuracy: 1.0000
Epoch 20/20
30/30 [==============================] - 341s 11s/step - loss: 6.0013e-06 - accuracy: 1.0000 - val_loss: 0.0045 - val_accuracy: 0.9958

 六 可视化结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

预测图片 

import numpy as np

# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5
plt.suptitle("预测结果展示")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(1,8, i + 1)  
        
        # 显示图片
        plt.imshow(images[i].numpy())
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测图片中的人物
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")
1/1 [==============================] - 0s 279ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 118ms/step
1/1 [==============================] - 0s 109ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 104ms/step
1/1 [==============================] - 0s 111ms/step
1/1 [==============================] - 0s 115ms/step

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2056031.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

地平线旭日X3开发板--图像获取时间戳问题

需求 需要获得图像接收完成后的帧时间戳。 sensor f37, MIPI 通信 问题 按我的了解,一般是在内核中产生MIPI数据接收完成中断并打印时间戳, 一般是CLOCK_MONOTONIC方式的时间 , X3无法获得MIPI数据接收完成的时间戳。 X3平台HB_VIN_GetC…

4 - Linux远程访问及控制

目录 一、SSH远程管理 1. SSH概述 2.SSH的优点 3.配置OpenSSH客户端 4.sshd服务支持的两种验证方式 5. 使用SSH客户端程序 5.1 ssh - 远程登录 5.2 scp - 远程复制 6.配置密钥对验证 二、TCP Wrappers访问控制 1.TCP Wrappers 概述 2. TCP Wrappers 机制的基本原则 …

MS SQL Server partition by 函数实战二 编排考场人员

目录 需求 输出效果 范例运行环境 表及视图样本设计 功能实现 生成考场数据 生成重复的SQL语句 封装为统计视图 编写存储过程实现统计 小结 需求 假设有若干已分配准考证号的考生,准考证号示例(01010001)共计8位,前4位…

ZeroEA阅读笔记

ZeroEA阅读笔记 摘要 实体对齐(EA)是知识图(KG)研究中的一项关键任务,旨在识别不同知识图谱中的等效实体,以支持知识图谱集成、文本到SQL和问答系统等下游任务。考虑到KG中丰富的语义信息,预训练语言模型(PLM)凭借其卓越的上下文感知编码功…

使用SSMS操作AdventureWorks 示例数据库

简介 AdventureWorks 示例数据库,官方文档:https://learn.microsoft.com/zh-cn/sql/samples/adventureworks-install-configure?viewsql-server-ver16&tabsssms 下载备份文件 OLTP 数据适用于大多数典型的联机事务处理工作负载。数据仓库 (DW) 数据…

网络设备监控工具 PIGOSS BSM 网络设备-Ruijie设备SNMP配置及监控

目录 1. 全局模式 2. 配置SNMP V2 3. 配置SNMP V3 4. 配置SNMP Trap 5. 保存配置 6. 查看配置结果 7. 锐捷设备监控 1. 全局模式 SNMP 的配置工作在网络设备的全局配置模式下完成,在进行SNMP 配置前,请先进入全局配置模式。 Ruijie>enable …

Excel“取消工作表保护”忘记密码并恢复原始密码

文章目录 1.前言2.破解步骤3. 最终效果4.参考文献 1.前言 有时候别人发来的Excel中有些表格不能编辑,提示如下,但是又不知道原始密码 2.破解步骤 1、打开您需要破解保护密码的Excel文件; 2、依次点击菜单栏上的视图—宏----录制宏&#xf…

Spring Boot内嵌Tomcat处理请求的链接数和线程数

Spring Boot内嵌Tomcat处理请求的连接数和线程数 处理请求的连接数和线程数配置 Spring Boot的配置项 #等待连接数 server.tomcat.accept-count100 #最大链连接数 server.tomcat.max-connections8192#最小备用线程数 server.tomcat.threads.min-spare10 #最大工作线程数 ser…

【git命令相关】git上传和删除文件步骤

(一)git登录 1. git bash窗口输入 git config --global user.name "你的Git账号" git config --global user. Email "你的Git邮箱"2. 生成密钥 ssh-keygen -t rsa -C "你的Git邮箱"在此命令执行的返回结果中找到key存放…

海康VisionMaster使用学习笔记11-VisionMaster基本操作

VisionMaster基本操作 VM示例方案 1. 工具拖拽及使用方式 分别从采集和定位栏里拖拽图像源,快速匹配,Blob分析工具 2. 模块连线 依次连线 3.如何订阅 点击快速匹配,可以看到输入源已订阅了图像1的图像,Blob分析类似 4. 方案操作及全局触发 点击快速匹配,创建特征模版,框选…

vue-cli搭建过程,elementUI搭建使用过程

vue-cli vue-cli 官方提供的一个脚手架,用于快速生成一个 vue 的项目模板;预先定义 好的目录结构及基础代码,就好比咱们在创建 Maven 项目时可以选择创建一个 骨架项目,这个骨架项目就是脚手架,我们的开发更加的快速。…

深兰科技荣获CFS第十三届财经峰会“2024杰出出海品牌引领奖”

近日,以“向新而行,新质生产力激发新活力”为主题的“CFS2024第十三届财经峰会暨Amazing 2024创新企业家节”在北京隆重开幕。峰会揭晓了第十三届“CFS 2024企业奖”的评选结果,深兰科技凭借自身在AI机器人出口和海外市场开拓等品牌全球化方面…

60KW~180KW一体式充电桩电路方案!

本次小编给大家带来了一款60KW~180KW的一体式充电桩电路方案,本方案包含接线图,电路原理图,PCB图,BOM,协议说明,产品标准等资料! 下载链接!https://t.1yb.co/KW1R 本方案采用STM32F…

std::wcout,std::cout控制台输出中文乱码,std::cerr字符串的字符无效

系列文章目录 文章目录 系列文章目录前言一、中文乱码原因二、解决方法1.如果是windos11下,使用英文语言,需要加以下代码2.如果是中文语言只需要一行关键代码3.如果在异常处理中显示宽字符中文4.完整代码如下:实现文件测试代码输出打印 前言 …

【图像特效系列】图像毛玻璃特效的实践 | 包含代码和效果图

目录 一 毛玻璃特效 1 代码 2 效果图 图像特效系列主要是对输入的图像进行处理,生成指定特效效果的图片。图像素描特效会将图像的边界都凸显出来;图像怀旧特效是指图像经历岁月的昏暗效果;图像光照特效是指图像存在一个类似于灯光的光晕特效,图像像素值围绕光照中心点呈…

极光推送(JPush)携手中大英才,打造智慧教育新模式

随着互联网技术的快速发展,在线教育行业蓬勃兴起,用户对学习体验的要求也越来越高。作为国内领先的职业技能知识培训服务商,中大英才(北京)网络教育科技有限公司(简称“中大英才”)始终致力于为多层次求知学习人士提供专业化、智能化和科学化…

实战演练:通过API获取商品详情并展示

实战演练:通过API获取商品详情并展示,通常涉及以下几个步骤:确定API接口、发送HTTP请求、处理响应数据、以及将数据展示给用户。这里我们以一个假想的商品详情API为例,使用Python语言和requests库来完成这个任务。 步骤 1: 确定A…

DMHS数据同步工具

DMHS数据同步工具 ​ 本章节主要介绍DM数据同步工具DMHS的使用,通过将oracle11g的数据同步到DM8的过程来理解DMHS的功能和作用。 安装前的准备 端口、服务信息 IP地址服务名称版本端口安装路径192.168.19.136OracleOracle11.0.21521/opt/oracle/DMHS源端dmhs_V3…

第100+22步 ChatGPT学习:概率校准 Platt Scaling

基于Python 3.9版本演示 一、写在前面 最近看了一篇在Lancet子刊《eClinicalMedicine》上发表的机器学习分类的文章:《Development of a novel dementia risk prediction model in the general population: A large, longitudinal, population-based machine-learn…

MapBox Android版开发 1 配置

MapBox Android版开发 1 配置 前言MapBox V9 配置创建工程配置地图配置私钥配置公钥配置仓库配置依赖配置权限地图初始化 显示地图布局文件地图Activity 运行效果 MapBox V11 配置创建工程配置地图配置私钥配置公钥配置仓库配置依赖配置权限 显示地图布局文件 运行效果 前言 本…