基于CNN+ViT的蔬果图像分类实验

news2025/4/19 15:12:57

本文只是做一个简单融合的实验,没有任何新颖,大家看看就行了。

1.数据集

本文所采用的数据集为Fruit-360 果蔬图像数据集,该数据集由 Horea Mureșan 等人整理并发布于 GitHub(项目地址:Horea94/Fruit-Images-Dataset),广泛应用于图像分类和目标识别等计算机视觉任务。该数据集共包含141 类水果和蔬菜图像,总计 94,110 张图像,每张图像的尺寸统一为 100×100 像素,且背景已统一处理为白色背景,以减少背景噪声对模型训练的影响。

数据集中涵盖了大量常见和不常见的果蔬品类,主要包括:

  1. 苹果(多个品种:如深雪、金苹果、金红、青奶奶、粉红女士、红苹果、红美味等)
  2. 香蕉(黄色、红色、淑女手指等)
  3. 葡萄(蓝色、粉红色、白色多个品种)
  4. 柑橘类(橙子、柠檬、酸橙、葡萄柚、柑橘等)
  5. 热带水果(芒果、木瓜、红毛丹、百香果、番石榴、荔枝、菠萝、火龙果等)
  6. 浆果类(蓝莓、覆盆子、草莓、黑加仑、红醋栗、桑葚等)
  7. 核果类与坚果类(桃子、李子、杏、椰子、榛子、核桃、栗子、山核桃等)
  8. 蔬菜类(黄瓜、茄子、胡椒、番茄、洋葱、花椰菜、甜菜根、玉米、土豆等)
  9. 其他类如:仙人掌果实、杨布拉、姜根、格兰纳迪拉、Physalis(灯笼果)、油桃、佩皮诺、罗望子、大头菜等。

在数据划分方面,本研究按照如下比例进行数据集划分:

(1)训练集:70,491 张图像
          其中按照 8:2 的比例划分出验证集,得到最终:

                        训练子集:56,432 张

                        验证集:14,059 张

(2)测试集:23,619 张图像

2.模型简述

在图像分类任务中,深度学习方法已经取得了显著的进展,如残差神经网络(ResNet),Vision Transformer展现了较强的性能。ResNet作为CNN下的网络架构,在局部特征提取方面具有优势,能够有效地捕捉图像中的空间结构信息。而Vision Transformer作为Transformer的变种,在捕捉全局依赖关系和建模长程依赖性方面的具有更好的优势。

由于CNN的卷积操作本质上能够生成具有空间局部关联性的特征图,实际上可以视为一种变相的patch操作。因此,在将CNN与Transformer相结合时,可以避免传统ViT中对输入图像进行切分patch的操作,只需对图像进行位置编码,从而使得Transformer能够有效处理这些具有空间结构的特征图。这种设计不仅减少了计算开销,还使得整个模型在处理图像时更具效率与准确性。

同时,与原始ViT框架中描述的技术不同,原始框架通常会将一个可学习的位置嵌入向量预先添加到编码后的patch序列中,作为图像的位置信息进行表示。然而,为了简化模型的实现并提高计算效率,本文在架构设计上有所调整,省略了额外的位置编码步骤。具体来说,本文的模型通过直接输入编码后的patch序列到Transformer块中,跳过了对每个patch进行独立位置编码的操作。

基于这一思路,结合了残差神经网络(ResNet)和Vision Transformer(ViT)两种网络架构,将它们以串行连接的方式进行融合。具体模型架构图如下图所示

3.实验

模型代码(基于tensorflow2.X)

import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers,models
import warnings
warnings.filterwarnings('ignore')
import os

Train = r"D:\archive (1)\fruits-360_dataset_100x100\fruits-360\Training"
Test = r"D:\archive (1)\fruits-360_dataset_100x100\fruits-360\Test"

IMAGE_SIZE = 100
NUM_CLASSES = 141
BATCH_SIZE = 32

imagegenerator = ImageDataGenerator(rescale=1.0 / 255.0, validation_split=0.2, rotation_range=10, horizontal_flip=True)

# Training and validation data generators
Train_Data = imagegenerator.flow_from_directory(
    Train,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training'
)
Validation_Data = imagegenerator.flow_from_directory(
    Train,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation'
)

# Test data generator (no augmentation)
test_imagegenerator = ImageDataGenerator(rescale=1.0 / 255.0)
Test_Data = test_imagegenerator.flow_from_directory(
    Test,
    target_size=(IMAGE_SIZE,IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    # subset='test'
)
class ResidualBlock(layers.Layer):
    def __init__(self, filters, kernel_size=(3, 3), strides=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = layers.Conv2D(filters, kernel_size, strides=strides, padding="same", activation='relu')
        self.conv2 = layers.Conv2D(filters, kernel_size, strides=1, padding='same', activation='relu')
        self.shortcut = layers.Conv2D(filters, (1, 1), strides=strides, padding='same', activation='relu')
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()
        self.relu = layers.ReLU()

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        shortcut = self.shortcut(inputs)
        x = layers.add([x, shortcut])
        x = self.relu(x)
        return x


# ResNet Model definition
class ResNetModel(layers.Layer):
    def __init__(self):
        super(ResNetModel, self).__init__()
        self.conv1 = layers.Conv2D(32, (5, 5), activation='relu', input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),padding='same')
        self.maxpool1 = layers.MaxPooling2D((2, 2))

        # Residual Blocks
        self.resblock1 = ResidualBlock(32,strides=1)
        self.resblock2 = ResidualBlock(64,strides=2)
        self.resblock3 = ResidualBlock(128,strides=2)
        self.resblock4 = ResidualBlock(256, strides=2)
        # self.global_avg_pool = layers.GlobalAveragePooling2D()

    def call(self, inputs):
        print(inputs.shape)
        x = self.conv1(inputs)
        print(x.shape)
        x = self.maxpool1(x)
        print(x.shape)
        # Apply Residual Blocks
        x = self.resblock1(x)
        print(x.shape)
        x = self.resblock2(x)
        print(x.shape)
        # x = self.resblock3(x)
        # print(x.shape)
        # x = self.resblock4(x)
        # x = self.global_avg_pool(x)
        # print(x.shape)
        return x
class TransformerEncoder(layers.Layer):
    def __init__(self, num_heads=8, key_dim=64, ff_dim=256, dropout_rate=0.1):
        super(TransformerEncoder, self).__init__()
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.norm1 = layers.LayerNormalization()

        self.ff = layers.Dense(ff_dim, activation='relu')
        self.ff_output = layers.Dense(key_dim*num_heads)
        self.dropout2 = layers.Dropout(dropout_rate)
        self.norm2 = layers.LayerNormalization()

    def call(self, x):
        # Multi-head self-attention
        attention_output = self.attention(x, x)
        attention_output = self.dropout1(attention_output)
        x = self.norm1(attention_output + x)  # Residual connection

        # Feed Forward Network
        ff_output = self.ff(x)
        ff_output = self.ff_output(ff_output)
        ff_output = self.dropout2(ff_output)
        x = self.norm2(ff_output + x)  # Residual connection

        return x


# Vision Transformer (ViT) 模型
class VisionTransformer(models.Model):
    def __init__(self, input_shape=(100, 100, 3), num_classes=141, num_encoders=3, patch_size=8, num_heads=16,
                 key_dim=4, ff_dim=256, dropout_rate=0.2):
        super(VisionTransformer, self).__init__()
        self.patch_size = patch_size

        #Resnet
        self.resnet=ResNetModel()
        # Patch Embedding
        self.conv = layers.Conv2D(64, (patch_size, patch_size), strides=(patch_size, patch_size), padding='valid')
        self.reshape = layers.Reshape((-1, 64))
        self.norm = layers.LayerNormalization()
        # 位置编码层
        self.position_encoding = self.add_weight("position_encoding", shape=(1, 625, 64))
        # Stack multiple Transformer Encoder layers
        self.encoders = [
            TransformerEncoder(num_heads=num_heads, key_dim=key_dim, ff_dim=ff_dim, dropout_rate=dropout_rate) for _ in
            range(num_encoders)]

        # Global Average Pooling
        self.global_avg_pooling = layers.GlobalAveragePooling1D()

        # Fully connected layer
        self.fc1 = layers.Dense(256, activation='relu')
        self.dropout = layers.Dropout(0.2)
        self.fc2 = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        #resnet
        x = self.resnet(inputs)
        # print("===========================")
        # print(x.shape)
        # Patch Embedding
        x = self.reshape(x)
        # 添加位置编码
        x = x + self.position_encoding  # 将位置编码加到Patch嵌入向量中
        # print(x.shape)
        # x = self.norm(x)

        # Apply multiple Transformer encoders
        for encoder in self.encoders:
            x = encoder(x)

        # Global Average Pooling
        x = self.global_avg_pooling(x)

        # Fully connected layers
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x
# 构建 Vision Transformer 模型

vit_model = VisionTransformer(input_shape=(100, 100, 3), num_classes=141, num_encoders=3)
vit_model.build(input_shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3))  # 手动构建模型
# 打印模型摘要
vit_model.summary()
# 编译模型
vit_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
checkpoint_path = "training_checkpoints_1/vit_model_checkpoint_epoch_{epoch:02d}.h5"

# 创建ModelCheckpoint回调
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_accuracy',  # 你可以选择监控验证集的损失或准确度
    save_best_only=True,  # 只保存验证集损失最小的模型
    save_weights_only=True,  # 只保存权重(而不是整个模型)
    verbose=1  # 打印日志
)
# 检查是否有保存的模型权重文件
checkpoint_dir = "training_checkpoints_1/"
# 查找所有的 .h5 文件
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "vit_model_checkpoint_epoch_*.h5"))
# print(latest_checkpoint)
if checkpoint_files:
    # 使用 os.path.getctime() 获取文件创建时间(或者使用 getmtime() 获取修改时间)
    latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
    print(f"Loading model from checkpoint: {latest_checkpoint}")
    # 加载模型权重
    vit_model.load_weights(latest_checkpoint)
else:
    print("No checkpoint found, starting from scratch.")

# 训练模型
history = vit_model.fit(
    Train_Data,
    epochs=20,
    validation_data=Validation_Data,
    shuffle=True,
    callbacks=[checkpoint_callback]
)

# 评估模型
test_loss, test_acc = vit_model.evaluate(Test_Data)
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_acc}")
# 训练和验证的准确率和损失历史记录
def plot_training_history(history):
    # 创建子图
    plt.figure(figsize=(14, 6))

    # 准备训练准确率和验证准确率的图
    plt.subplot(1, 2, 1)
    plt.title('Accuracy History')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.plot(history.history['accuracy'], label='Training Accuracy', marker='o')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy', color='green', marker='o')
    plt.legend()

    # 准备训练损失和验证损失的图
    plt.subplot(1, 2, 2)
    plt.title('Loss History')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.plot(history.history['loss'], label='Training Loss', marker='o')
    plt.plot(history.history['val_loss'], label='Validation Loss', color='green', marker='o')
    plt.legend()

    # 显示图形
    plt.tight_layout()
    plt.show()


# 绘制训练过程
plot_training_history(history)
for i in range(16):
    # 获取测试数据的下一个批次
    img_batch, labels_batch = Test_Data.next()
    img = img_batch[0]  # 获取当前批次的第一张图像
    true_label_idx = np.argmax(labels_batch[0])  # 获取真实标签的索引

    # 获取真实标签的名称
    true_label = [key for key, value in Train_Data.class_indices.items() if value == true_label_idx]

    # 扩展维度以匹配模型输入
    EachImage = np.expand_dims(img, axis=0)

    # 进行预测
    prediction = vit_model.predict(EachImage)

    # 获取预测标签
    predicted_label = [key for key, value in Train_Data.class_indices.items() if
                       value == np.argmax(prediction, axis=1)[0]]

    # 获取预测的概率
    predicted_prob = np.max(prediction, axis=1)[0]

    # 绘制图像
    plt.subplot(4, 4, i + 1)
    plt.imshow(img)
    plt.title(f"True: {true_label[0]} \nPred: {predicted_label[0]} \nProb: {predicted_prob:.2f}")
    plt.axis('off')

plt.tight_layout()
plt.show()

做了如下参数实验

ResNet层数

Encoder层数

num_heads

test_accuracy

2(32,64)

3

4

92.14%

3(32,64,128)

3

4

94.53%

2(32,64)

3

8

96.19%

3(32,64,128)

3

8

97.46%

2(32,64)

3

16

93.32%

3(32,64,128)

3

16

93.17%

 分类效果图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2337105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【虚幻C++笔记】接口

目录 概述创建接口 概述 简单的说,接口提供一组公共的方法,不同的对象中继承这些方法后可以有不同的具体实现。任何使用接口的类都必须实现这些接口。实现解耦解决多继承的问题 创建接口 // Fill out your copyright notice in the Description page o…

【MCP】第一篇:MCP协议深度解析——大模型时代的“神经连接层“架构揭秘

【MCP】第一篇:MCP协议深度解析——大模型时代的"神经连接层"架构揭秘 一、什么是MCP?二、为什么需要MCP?三、MCP的架构四、MCP与AI交互的原理4.1 ReAct(Reasoning Acting)模式4.2 Function Calling 模式 五…

实时模式下 libaom 与 x264 编码对比实验

前沿 理论基础:在相同视频质量下,AV1的压缩率比H.264高约30%-50%。实时模式:视频编码中的实时模式,其核心目标是平衡编码效率与延迟要求,尤其在视频会议、直播、实时通信等场景中至关重要。 低延迟要求:编…

学习海康VisionMaster之矩形检测

这几天太忙了,好几天没有学习了,今天终于空下来了,继续学习之路吧。 一:进一步学习了 今天学习下VisionMaster中的矩形检测,这个一开始我以为是形态学方面的检测,实际操作下来其实还是边缘直线的衍生应用&…

C Primer Plus 第6版 编程练习——第3章

1、通过试验(即编写带有此类问题的程序)观察系统如何处理整数上道、浮占数上溢和浮点数下溢的 int main(int argc, char** argv) {int intMax 2147483647;float floatMax 3.402823466e38f;float floatMin -3.402823466e38f;printf("intMax:%d, …

十倍开发效率 - IDEA插件之 Mybatis Log Free

提高效率不是为了完成更多任务,而是为了有充足的时间摸鱼 快速体验 MyBatis Log Free 支持打印执行的 SQL(完整的SQL,没有占位符的)。 没有使用 MyBatis Log Free 的开启日志打印是这样的: 用了 MyBatis Log Free 后…

手动安装 VMware Tools 并设置虚拟机共享 Windows 文件夹

前言:在当今数字化的工作环境中,虚拟机技术为我们提供了强大的灵活性和便利性。VMware 作为虚拟化领域的佼佼者,其虚拟机软件被广泛应用于开发、测试和日常工作中。然而,许多用户在使用 VMware 虚拟机时,会遇到一个常见…

(免费)flask调用讯飞星火AI,实现websocket

本文章可借鉴学习,不可直接盗用 接入ai要获取ID,Secret,Key,和接口地址,由于我们服务接口类型是websocket,所以要获取相应的接口地址。(千万不要复制粘贴到http的了) 还要获取doma…

2021-11-09 C++三位数平方含有该数

缘由求解&#xff0c;运算函数&#xff0c;哪位大神教一下-编程语言-CSDN问答 void 三位数平方含有该数() {//缘由https://ask.csdn.net/questions/7560152?spm1005.2025.3001.5141int a 100, aa 1000, f 0;while (a < aa){f a*a;while (f > a)if ((f - a) % aa)f …

StarRocks Community Monthly Newsletter (Mar)

版本动态 3.4.1 版本更新 核心功能升级 数据安全与权限管控 支持「安全视图」功能&#xff0c;严格管控视图查询权限 MySQL协议连接支持SSL认证&#xff0c;保障数据传输安全 存算分离架构增强 支持自动创建Snapshot&#xff08;集群恢复更便捷&#xff09; Storage Volu…

Github 2FA(Two-Factor Authentication/两因素认证)

Github 2FA认证 多因素用户认证(Multi-Factor Authentication)&#xff0c;基本上各个大互联网平台&#xff0c;尤其是云平台厂商&#xff08;如&#xff1a;阿里云的MFA、华为云、腾讯云/QQ安全中心等&#xff09;都有启用了&#xff0c;Github算是搞得比较晚些了。 双因素身…

动态规划 -- 简单多状态dp,打家劫舍问题

1 按摩师 面试题 17.16. 按摩师 - 力扣&#xff08;LeetCode&#xff09; 本题的意思简单理解就是&#xff0c;如果我们接受了第 i 个预约&#xff0c;那么第 i -1 个预约和第 i1 个预约我们都是无法接受的&#xff0c;只能至少间隔一个选择。 按照以前的经验&#xff0c;我们…

C++学习之游戏服务器开发⑤AOI业务逻辑

目录 1.项目进度回顾 2.完善整体架构 3.AOI网格思路 4.网络AOI数据结构 5.游戏世界类添加&#xff0c;删除和构造 6.AOI查找实现 7.GAMEROLE类结合AOI 8.登陆时发送ID和姓名 9.登陆时发送周围玩家位置 10.玩家上线完成 11.玩家下线处理 1.项目进度回顾 时间轮调度处理…

Python 实现日志备份守护进程

实训背景 假设你是一名运维工程师&#xff0c;需要为公司的监控系统开发一个简单的日志备份守护进程。该进程需满足以下需求&#xff1a; 后台运行&#xff1a;脱离终端&#xff0c;长期监控指定目录&#xff08;如 /var/log/app/&#xff09;中的日志文件。自动备份&#xf…

Electricity Market Optimization 探索系列(VII)- 直流潮流方程的推导及例题

本文参考书籍&#xff1a;电力经济与电力市场&#xff0c;甘德强&#xff0c;杨莉&#xff0c;冯冬涵 著 link \hspace{1.6em} 文章的结构如下&#xff1a;围绕电力传输系统中短线路的等值等效模型&#xff0c;从节点注入功率的角度和线路功率的角度分析电网中的潮流&#xff0…

路由过滤实验

实验拓扑以及要求 此实验总结 1.ip-prefix 拒绝192.168.4.1 32,这样写的话并不会匹配192.168.4.1 32,需要加上范围less-eq 32,也就是说,192.168.4.1 32只是规则的范围,匹配还是得写范围 2.router-policy适合用在边界路由器引入 filter-policy都可以用 配置IP 配置ospf,rip …

Idea连接远程云服务器上的MySQL,开放云服务器端口

1.开放云服务器的3306端口 &#xff08;1&#xff09;进入到云服务器的控制台 &#xff08;2&#xff09;点击使用的云服务器 &#xff08;3&#xff09;点击 配置安全组规则 &#xff08;4&#xff09;添加规则 &#xff08;5&#xff09;开放端口 2.创建可以远程访问…

Oracle查询大表的全部数据

2000w的大表 表结构如下&#xff0c;其中id是索引 查询处理慢的写法 List<String> queryLoidForPage(Integer startNum,Integer endNum){try {Connection oracleConnection initBean.oracleConnection;Statement stmt oracleConnection.createStatement();// 4.执行查…

PyTorch生成式人工智能实战(1)——神经网络与模型训练过程详解

PyTorch生成式人工智能实战&#xff08;1&#xff09;——神经网络与模型训练过程详解 0. 前言1. 传统机器学习与人工智能2. 人工神经网络基础2.1 人工神经网络组成2.2 神经网络的训练 3. 前向传播3.1 计算隐藏层值3.2 执行非线性激活3.3 计算输出层值3.4 计算损失值3.5 实现前…

基于X86/Nvidia+FPGA大模型具身智能机器人控制器解决方案,同时拥有算力与实时的便利

2025年成为人形机器人产业化元年&#xff0c;行业已突破早期实验室研发阶段&#xff0c;进入"场景验证量产爬坡"新周期&#xff0c;预计2031年具身智能市场规模有望突破万亿元。这一进程的背后&#xff0c;是硬件算力、实时控制、环境适应等底层技术的系统性突破——…