CNNs和视觉Transformer:分析与比较

news2024/9/28 8:10:54

探索视觉Transformer和卷积神经网络(CNNs)在图像分类任务中的有效性。

图像分类是计算机视觉中的关键任务,在工业、医学影像和农业等各个领域得到广泛应用。卷积神经网络(CNNs)是该领域的一项重大突破,被广泛使用。然而,随着论文《Attention is all you need》的出现,行业开始转向Transformer。Transformer在人工智能和数据科学领域取得了显著进展。例如,ChatGPT的出色性能最近就展示了Transformer的有效性。类似地,《ViT》论文提供了Vision Transformer的概述。在本文中,我将尝试比较CNNs和ViTs(Vision Transformer)在Food-101数据集上进行图像分类的性能。需要注意的是,选择使用CNNs还是ViTs取决于多个因素,包括工作类型、训练时间和计算能力,并不能直接断言Transformer比CNNs更好。本分析旨在提供对它们在这个特定任务中性能的见解。

数据集

由于有限的计算能力,我将易于访问的Food-101数据集分成了10个类别,该数据集包含大约101,000张图像。该数据集可以直接从PyTorch和TensorFlow中使用:

    • https://pytorch.org/vision/main/generated/torchvision.datasets.Food101.html

    • https://www.tensorflow.org/datasets/catalog/food101

    • https://huggingface.co/datasets/food101

如果您想下载数据集,可以使用以下链接:

    • https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/

我将数据集分成了以下10个类别:

['samosa','pizza','red_velvet_cake', 'tacos', 'miso_soup', 'onion_rings', 'ramen', 'nachos', 'omelette', 'ice_cream']

注意:类别名称的顺序与上述列表不同。

图像经过转换和调整大小为256x256,并标准化为均值为0,方差为1。在数据集的子集之后,将数据集分为训练集和验证集,其中7500张图像用于训练,2500张图像用于测试。

以下是数据集中的示例图像:

dfec26d0283ae80e6850964ca5c000c7.png

为了比较CNNs和ViTs的性能,我使用了预训练的DenseNet121架构作为CNNs的模型,以及ViT-16作为Vision Transformers的模型。选择DenseNet121是基于其密集的架构,拥有121层,使其成为与ViTs在训练时间、层数以及硬件和内存要求方面进行比较的合适候选模型。对于ViTs,我使用了ViT-Base模型,它由12层和86M个参数组成。

DenseNet121

DenseNet-121是一个非常著名的CNN架构,用于图像分类,它是DenseNet模型系列的一部分,旨在解决深度神经网络中可能出现的梯度消失问题。它有121个层,使用了卷积层、池化层和全连接层的组合。其中有4个稠密块,每个稠密块由多个带有BatchNorm和ReLU激活的卷积层组成。在稠密块之间,有过渡层,使用池化操作来减小特征图的空间维度。以下是DenseNet的架构示意图:

fa2bfe790c44cf621b72df432eb23017.png

DenseNet架构

预训练模型使用了PyTorch提供的模型。模型经过了10个epochs的训练。

# Constants
NUM_CLASSES = 10
LEARNING_RATE = 0.001


# Model
densenet = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True)
for param in densenet.parameters():
  param.requires_grad = False


# Change classifier layer
densenet.classifier = nn.Linear(1024,NUM_CLASSES)


# Loss, Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(densenet.classifier.parameters(), lr=LEARNING_RATE)

准确率 vs Epochs和损失 vs Epochs的图表:

4a1b557cf417e06ee173b10674d6bbff.png

b2cc8e5d51ff5ee02b17649a86682b97.png

在最后一个epoch中,训练损失为0.3671,测试损失为0.3586,训练准确率为88.29%,测试准确率为87.72%。

分类别结果:

83875553a44a2dc4a3c97dde18db7fb2.png

ViT-16

ViT-16是Vision Transformer(ViT)的一种变体,由于在各种图像分类基准测试中能够达到最先进的结果,它在ViT论文之后变得非常受欢迎。ViT-16由一个Transformer编码器和一个用于分类的多层感知机(MLP)组成。Transformer编码器由16个相同的Transformer层组成,每个层包含一个自注意机制和一个前馈神经网络。网络的输入是扁平化的图像块序列,通过将输入图像分成不重叠的块,并将每个块扁平化为向量而获得。

每个Transformer层中的自注意机制允许网络在进行预测时专注于图像的不同部分。特别地,它计算输入序列中每对位置的注意权重,使得网络能够根据它们与当前分类任务的相关性来关注不同的图像块。每个Transformer层中的前馈神经网络然后对自注意机制的输出进行非线性变换。

在Transformer编码器之后,输出被传递到MLP分类器中,该分类器由两个具有ReLU激活的全连接层和一个用于分类的softmax输出层组成。MLP将最终Transformer层的输出作为输入,并将其映射到输出类别上的概率分布。

以下是ViT的架构示意图:

266373426c363780bdce02d4ee15d138.png

Vision Transformer架构

在将图像输入Transformer编码器模型之前,我们需要首先将输入图像分割成块,然后扁平化这些块。下面是图像被分割成块的示例:

1d466cea0b1ee14e15322c7f19d81487.png

将样本输入图像分割成块

我尝试了从头构建Transformer模型,但性能并不好。然后我尝试了迁移学习,使用了预训练的ViT-16模型和PyTorch提供的默认权重。我还对适用于ViT的图像应用了相应的转换操作。

# Default weights
pretrained_weights = torchvision.models.ViT_B_16_Weights.DEFAULT


# Model
vit = vit_b_16(weights=pretrained_weights).to(device)


for parameter in vit.parameters():
  parameter.requires_grad=False


# Change last layer
vit.heads = nn.Linear(in_features=768, out_features=10)


# Auto Transforms
vit_transforms = pretrained_weights.transforms()

准确率 vs Epochs和损失 vs Epochs的图表:

48d94a500965cce23d886bb1c250455f.png

在最后一个epoch中,训练损失为0.1203,测试损失为0.1893,训练准确率为96.89%,测试准确率为93.63%。

分类别的结果:

1bfff11b74d95db10ec0da7aab38bd57.png

预测结果:

以下是对于ViT-16模型的一些使用未见过数据的预测结果 — 

56cf5bef48398a5728fa15dc16b041d7.png

类别:5 名称:比萨

8d02fface249e3b094eb920c9a5d096d.png

类别:6 名称:拉面

7fe62e17f85597bc2e13073e66bca1fc.png

类别:8 名称:萨莫萨饼

注意:类别名称的顺序与上述列表不同

在大多数情况下,ViT-16能够正确分类未见过的数据。

结论

在这个特定任务中,ViT-16在图像分类方面的性能优于DenseNet121。准确率和图表曲线也显示了两者之间的显著差异。分类报告显示,ViT的f1-score相比DenseNet更好。

然而,需要注意的是,虽然Vision Transformer在某些情况下可能优于CNN,但不能一概而论地认为它们比CNN架构更好。每个架构的性能取决于各种因素,如使用情况、数据规模、训练时间、参数调整、硬件的内存和计算能力等。

·  END  ·

HAPPY LIFE

8f38411480cde2c57346b74f57808caf.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/570810.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软考网工计算题总结(一):总共33类题型,进来复习啦!

题型一: 1.地址编号从80000H到BFFFFH且按字节编址的内存容量为(5)KB,若用16KX4bit的存储芯片够成该内存,共需(6)片。 (5)A.128 B.256 C.512 D.1024 (6)A.8 B.16 C.32 D.64 【答案】B C 【解析】本题…

springboot入门简单使用

springboot入门简单使用 1、SpringBoot项目创建并配置mysql数据库创建项目编写Controller测试配置数据库 2、SpringBoot集成mybatis-plus初始化数据库安装mybatis-plus通过mybatis-plus将数据库数据通过接口显示 3、SpringBoot三层架构Controller、Service、Dao4、SpringBoot集…

【鲁棒、状态估计】用于电力系统动态状态估计的鲁棒迭代扩展卡尔曼滤波器研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

自动化测试技术相结合的测试方法

自动化测试技术相结合的测试方法 随着软件开发的不断进步和变革,测试也越来越重要。为了提高测试效率和质量,自动化测试技术相结合的测试方法得到了广泛应用。 自动化测试是一种利用工具和脚本自动执行测试任务的测试方法。通过自动化测试,可…

【产品设计】工具类产品,带一些社交元素

工具类产品要加入社交元素,关键在于找到工具与社交的结合点。 一、工具类的产品,可以这样加入社交元素 1、分开来看:工具类产品和社交类产品 工具类产品,顾名思义,以工具属性为主,核心突出的是一个“用”…

UE5.1.1C++从0开始(11.AI与行为树)

怕有些朋友不知道教程指的是哪一个,我在这里把教程的网址贴出来:https://www.bilibili.com/video/BV1nU4y1X7iQ?p1 这一章开始进入电脑玩家逻辑的编写,因为是第一次接触,所以老师也没有讲什么很难的问题,这里还是老样…

React学习笔记七-事件处理

此文章是本人在学习React的时候,写下的学习笔记,在此纪录和分享。此为第七篇,主要介绍react中的事件处理。 事件处理 (1)通过onXxx属性指定事件处理函数(注意大小写) 1.react使用的是自定义(合…

01_JVM快速入门

从面试开始: 请谈谈你对JVM 的理解?java8 的虚拟机有什么更新? 什么是OOM ?什么是StackOverflowError?有哪些方法分析? JVM 的常用参数调优你知道哪些? 内存快照抓取和MAT分析DUMP文件知道吗…

2023年第十五届电工杯选题浅析

本次电工杯作为2023年上半年度数学建模赛事的收官之战,报名队伍最后截止统计已经达到12000支队伍,同时免费的报名费也让这个收官之战,被很多建模小白当作第一次练手赛。为了帮助大家选题,下面为大家带来AB两题的思路浅析&#xff…

凌恩生物文献分享 | 癌症领域新曙光——肿瘤内微生物

上一期我们给大家介绍了肿瘤胞内菌在癌症转移中发挥的作用。2022年12月,蔡尚老师团队在Cell子刊-Trends in Cell Biology上又发表了一篇总结肿瘤内菌群在癌症转移中最新发现的综述,其中讨论了癌症治疗遇到的新挑战。 研究亮点 1)癌症转移是…

Leetcode 1679. K 和数对的最大数目 双指针法

https://leetcode.cn/problems/max-number-of-k-sum-pairs/ 给你一个整数数组 nums 和一个整数 k 。 每一步操作中,你需要从数组中选出和为 k 的两个整数,并将它们移出数组。 返回你可以对数组执行的最大操作数。 示例 1: 输入&#xff1…

【JS】1693- 重学 JavaScript API - Web Storage API

❝ 前期回顾: 1. Page Visibility API 2. Broadcast Channel API 3. Beacon API 4. Resize Observer API 5. Clipboard API 6. Fetch API 7. Performance API ❞ 在 Web 开发中经常需要在客户端保存和获取数据,Web Storage API 提供了一种在浏览器中存储…

【sop】含储能及sop的多时段配网优化模型

目录 1 主要内容 2 部分代码 3 程序结果 4 下载链接 1 主要内容 之前分享了含sop的配电网优化模型,链接含sop的配电网优化,很多同学在咨询如何增加储能约束,并进行多时段的优化,本次拓展该部分功能,在原代码的基础上增加储能模…

Paragon NTFS2023最新mac免费实用工具磁盘工具

mac虽然系统稳定,但在使用过程中也有一些瑕疵,如当mac连接到ntfs格式移动磁盘时,可能会出现移动磁盘无法在mac被正常读写的状况。遇到移动磁盘无法正常读写的状况,我们可以在mac中使用磁盘工具,以使mac获得对ntfs格式移…

Docker实战1-运行前端Vue项目

本次运行了两个项目,一个是开源的镜像,一个是自己的前端项目镜像 1 在docker中运行 keycloak docker run -p 8080:8080 -e KEYCLOAK_ADMINadmin -e KEYCLOAK_ADMIN_PASSWORDadmin quay.io/keycloak/keycloak:21.1.1 start-dev 这个最简单了&#xff0c…

版图设计IC617 virtuoso启动以及smic18mmrf加载库

一. 启动virtuoso 1.1 创建一个目录用于库管理 mkdir pro3 1.2 拷贝.bashrc到工程目录下,.bashrc存在~目录下,是一个隐藏文件,需要用ls -la查看 1.3 执行.bashrc文件 1.4 启动 virtuoso & 1.5 检查库中是否包含系统基本库,如…

【JavaSE】Java基础语法(十三):Java 中的集合(十分全面)

文章目录 List, Set, Queue, Map 四者的区别?集合框架底层数据结构总结ArrayList 和 Vector 的区别ArrayList 与 LinkedList 区别补充内容:RandomAccess 接⼝ArrayList 的扩容机制comparable 和 Comparator 的区别比较 HashSet、LinkedHashSet 和 TreeSet 三者的异同…

Java jdbcTemplate 获取数据表结构

表结构如图 代码 AutowiredJdbcTemplate jdbcTemplate;Testpublic void getColumnNames() throws Exception {String sql "select * from tb_test where 12 ";SqlRowSet sqlRowSet jdbcTemplate.queryForRowSet(sql);SqlRowSetMetaData sqlRsmd sqlRowSet.getMeta…

高手速成 | 过滤器、监听器的创建与配置

本节讲解过滤器、监听器的创建以及监听事件配置示例。 01、过滤器的创建与配置 【例1】创建过滤器及配置过滤规则。 (1) 在Eclipse中新建一个Web项目,取名为Chapt_09。在src目录下,新建一个名为com.test.filter的包。选中该包并按CtrlN组合键&#xf…

Linux之软件包管理

软件包管理 RPM RPM 概述 RPM(RedHat Package Manager), RedHat软件包管理工具, 类似windows里面的setup.exe,是Linux这系列操作系统里面的打包安装工具, 它虽然是RedHat的标志, 但理念是通用…