基于tensorflow的咖啡豆识别

news2024/11/18 15:45:21
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、前期工作

1. 设置GPU

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
    print("GPU is available")

2. 导入数据

from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlib

data_dir = "F:/host/Data/咖啡豆识别数据/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.png')))

print("图片总数为:",image_count)

在这里插入图片描述

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

batch_size = 8
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

在这里插入图片描述

2. 可视化数据

plt.figure(figsize=(10, 4))  # 图形的宽为10高为5

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(2, 4, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

在这里插入图片描述

3. 配置数据集

  • shuffle() :打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y)) 
image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]

# 查看归一化后的数据
print(np.min(first_image), np.max(first_image))

三、构建VGG-16网络

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout

def VGG16(nb_classes, input_shape):
    # 输入层
    input_tensor = Input(shape=input_shape)
    # 卷积层1
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)
    # 卷积层2
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)
    # 卷积层3
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)
    # 卷积层4
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)
    # 卷积层5
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)
    # 展平层
    x = Flatten()(x)
    # 全连接层1
    x = Dense(4096, activation='relu',name='fc1')(x)
    # 全连接层2
    x = Dense(4096, activation='relu',name='fc2')(x)
    # 输出层
    output_tensor = Dense(nb_classes, activation='softmax',name='predictions')(x)
    # 创建模型
    model = Model(input_tensor, output_tensor)
    return model

# 创建模型
model=VGG16(len(class_names), (img_width, img_height, 3))

# 打印模型结构
model.summary()

在这里插入图片描述

3. 网络结构图

关于卷积的相关知识可以参考文章:https://mtyjkh.blog.csdn.net/article/details/114278995

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcX与predictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

在这里插入图片描述

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=30,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型

epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

在这里插入图片描述

六、可视化结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、个人小结

在本次咖啡豆识别项目中,我们通过设置GPU、导入并预处理数据、构建深度学习模型,以及对模型进行训练和评估,实现了对咖啡豆图像的自动识别。整个过程涵盖了数据加载与可视化、数据集配置、模型构建与优化等关键步骤,最终显著提升了图像分类的准确性,同时也加深了我们对深度学习技术的实践理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1710588.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

远程桌面连接--“发生身份验证错误。要求的函数不受支持”

出现身份验证错误 要求的函数不受支持的问题,可以通过以下几种方法尝试解决:12 对于Windows 10家庭版用户,需要修改注册表信息。具体步骤如下: 按下WIN R,输入regedit,点击确定,打开注册表编辑…

openresty(Nginx) 隐藏 软包名称及版本号 升级版本

1 访问错误或者异常的URL 2 修改配置,重新编译,升级 #修改版本等 vim ./bundle/nginx-1.13.6/src/core/nginx.h #define nginx_version 1013006 #define NGINX_VERSION "1.13.6" #define NGINX_VER "openresty/&q…

python中的-1是什么意思

python中的-1是什么意思? -1指的是索引,即列表的最后一个元素。 比如你输入一个列表: a = [1,2,3,4,5,6,7] a[-1]就代表索引该列表最后一个值,你可以 b a[-1] print(b) 结果如下: 7 索引从左往右是…

5.28学习总结

java复习总结 hashcode()和equals() hashcode():在Object里这个方法是通过返回地址的整数值来生成哈希值。 equals():在Object里这个方法是通过比较他们的内存地址来确定两个对象是否相同。 运行效率:hashcode的时间复杂度为O(1)(因为只要计算一次哈…

SpringCloud之SSO单点登录-基于Gateway和OAuth2的跨系统统一认证和鉴权详解

单点登录(SSO)是一种身份验证过程,允许用户通过一次登录访问多个系统。本文将深入解析单点登录的原理,并详细介绍如何在Spring Cloud环境中实现单点登录。通过具体的架构图和代码示例,我们将展示SSO的工作机制和优势&a…

mysql 8 [HY000][1114] The table ‘/tmp/#sql4c3_3e5a0_2‘ is full

分组有个比较大的表,出现了临时表空间满了的情况; 试用该sql 语句: SHOW GLOBAL VARIABLES LIKE internal_tmp_mem_storage_engine; 可以看到 默认临时结果是用临时表存的,在mysql的my.cnt可以改临时空间的大小 但是磁盘哪有内…

2、python环境的安装-mac系统下

打开官网,downloads下边有macOS,点击: 选择最新版本,点击,进入下边的页面,一直往下滑,看到files中有个macOS的版本,点击下载 点击下载后是pkg的安装包,点击安装。 一步步…

浙江大学数据结构MOOC-课后习题-第九讲-排序3 Insertion or Heap Sort

题目汇总 浙江大学数据结构MOOC-课后习题-拼题A-代码分享-2024 题目描述 测试点 思路分析 和上一题的思路一样&#xff0c;每进行一次迭代&#xff0c;来验证当前序列是否和给定的序列相同 代码展示 #include <cstdlib> #include <iostream> #define MAXSIZE 10…

代码随想录算法训练营第七天| 454.四数相加II 、383. 赎金信、 15. 三数之和、18. 四数之和

454.四数相加II 题目链接&#xff1a; 454.四数相加II 文档讲解&#xff1a;代码随想录 状态&#xff1a;没做出来&#xff0c;没想到考虑重复的情况&#xff01; 题解&#xff1a; public int fourSumCount(int[] nums1, int[] nums2, int[] nums3, int[] nums4) {// 结果计数…

100个 Unity小游戏系列三 -Unity 抽奖游戏专题一 转盘抽奖游戏

一 、效果展示 二、知识点 2.1 布局需要实现功能 1、转动的根目录为itemSpinRoot 2、创建对应的item 3、每个item转动的角度 2.2 代码 public class WheelDialog : UIBase{[SerializeField] Button btnClick;[SerializeField] Button btnClose;[SerializeField] Sprite[] ite…

【错误记录】HarmonyOS 运行报错 ( Failure INSTALL _PARSE _FAILED _USESDK _ERROR )

文章目录 一、报错信息二、问题分析三、解决方案 一、报错信息 在 DevEco Studio 中 , 使用 远程设备 , 向 P40 Failure[INSTALL_PARSE_FAILED_USESDK_ERROR] compileSdkVersion and releaseType of the app do not match the apiVersion and releaseType on the device. 二、…

【蓝桥杯】第十四届蓝桥杯大赛软件赛国赛C/C++ 大学 B 组

答题结果页 - 蓝桥云课 (lanqiao.cn) 0子2023 - 蓝桥云课 (lanqiao.cn)&#xff08;暴力枚举 #include<bits/stdc.h> using lllong long; using ullunsigned long long; #define fir first #define sec second //#define int llconst int N1e510; const int mod1e97;int…

使用prometheus监测MySQL主从同步状态方案

说明&#xff1a;本文介绍如何使用prometheus、alertmanager监测MySQL主从&#xff0c;当从节点中断同步时&#xff0c;发送邮箱报警&#xff0c;并使用grafana将数据视图化。 结构图如下&#xff1a; 安装 &#xff08;1&#xff09;安装应用 首先&#xff0c;来安装promet…

谷歌推出TransformerFAM架构,以更低的消耗处理长序列文本

Transformer对大模型界的影响力不言而喻&#xff0c;ChatGPT、Sora、Stable Difusion等知名模型皆使用了该架构。 但有一个很明显的缺点&#xff0c;其注意力复杂度的二次方增长在处理书籍、PDF等超长文档时会显著增加算力负担。 虽然会通过滑动窗口注意力和稀疏注意力等技术…

全栈实现图片验证码及知识补充 全栈开发之路——全栈篇(4)

全栈开发一条龙——前端篇 第一篇&#xff1a;框架确定、ide设置与项目创建 第二篇&#xff1a;介绍项目文件意义、组件结构与导入以及setup的引入。 第三篇&#xff1a;setup语法&#xff0c;设置响应式数据。 第四篇&#xff1a;数据绑定、计算属性和watch监视 第五篇 : 组件…

100个 Unity小游戏系列六 -Unity 抽奖游戏专题四 翻卡游戏

一、演示效果 二、知识点讲解 2.1 布局 void CreateItems(){reward_data_list reward_data_list ?? new List<RewardData>();reward_data_list.Clear();for (int i 0; i < ItemCount; i){GameObject item;if (i 1 < itemParent.childCount){item itemParent…

AI革命:生活无处不智能

AI革命&#xff1a;生活无处不智能 &#x1f604;生命不息&#xff0c;写作不止 &#x1f525; 继续踏上学习之路&#xff0c;学之分享笔记 &#x1f44a; 总有一天我也能像各位大佬一样 &#x1f3c6; 博客首页 怒放吧德德 To记录领地 &#x1f31d;分享学习心得&#xff0…

PyBullet 物理引擎

PyBullet是一个开源的物理仿真库&#xff0c;基于Bullet Physics SDK这一成熟的、广泛使用的开源物理引擎。它提供了Python接口&#xff0c;使开发者能够利用Bullet强大的物理仿真能力&#xff0c;同时享受Python的易用性。PyBullet支持多种物理学模型&#xff0c;如刚体、骨骼…

CTF流量分析之wireshark使用

01.基本介绍 在CTF比赛中&#xff0c;对于流量包的分析取证是一种十分重要的题型。通常这类题目都是会提供一个包含流量数据的pcap文件&#xff0c;参赛选手通过该文件筛选和过滤其中无关的流量信息&#xff0c;根据关键流量信息找出flag或者相关线索。 pcap流量包的分析通常…

在 GPT-4o 释放完整能力前,听听实时多模态 AI 创业者的一手经验 | 编码人声

「编码人声」是由「RTE开发者社区」策划的一档播客节目&#xff0c;关注行业发展变革、开发者职涯发展、技术突破以及创业创新&#xff0c;由开发者来分享开发者眼中的工作与生活。 5 月中旬 GPT-4o 的发布&#xff0c;让人与 AI 的交互&#xff0c;从对话框的文本交流加速推进…