【NLP】视觉变压器与卷积神经网络

news2024/11/24 9:44:18

一、说明

        本篇是 变压器因其计算效率和可扩展性而成为NLP的首选模型。在计算机视觉中,卷积神经网络(CNN)架构仍然占主导地位,但一些研究人员已经尝试将CNN与自我注意相结合。作者尝试将标准变压器直接应用于图像,发现在中型数据集上训练时,与类似ResNet的架构相比,这些模型的准确性适中。然而,当在更大的数据集上进行训练时,视觉转换器(ViT)取得了出色的结果,并在多个图像识别基准上接近或超过了最先进的技术。本文记录这种结论,等有时机去验证。

二、CNN卷积网络transformer起源

        这篇博文的灵感来自谷歌研究团队的一篇题为“图像价值16X16字:大规模图像识别的变形金刚”的论文。本文建议使用直接应用于图像补丁的纯转换器来完成图像分类任务。视觉转换器 (ViT) 在多个基准测试中优于最先进的卷积网络,同时在对大量数据进行预训练后,需要更少的计算资源进行训练。

        变压器因其计算效率和可扩展性而成为NLP的首选模型。在计算机视觉中,卷积神经网络(CNN)架构仍然占主导地位,但一些研究人员已经尝试将CNN与自我注意相结合。作者尝试将标准变压器直接应用于图像,发现在中型数据集上训练时,与类似ResNet的架构相比,这些模型的准确性适中。然而,当在更大的数据集上进行训练时,视觉转换器(ViT)取得了出色的结果,并在多个图像识别基准上接近或超过了最先进的技术。

        图 1(取自原始论文)描述了一个模型,该模型通过将 2D 图像转换为展平的 2D 补丁序列来处理 <>D 图像。然后将补丁映射到具有可训练线性投影的恒定潜在矢量大小。一个可学习的嵌入被附加到补丁序列之前,它在转换器编码器输出端的状态用作图像表示。然后将图像表示通过分类头进行预训练或微调。添加位置嵌入以保留位置信息,嵌入向量序列用作变压器编码器的输入,该编码器由多头自注意和 MLP 块的交替层组成。

        过去,CNN长期以来一直是图像处理任务的首选。它们擅长通过卷积层捕获局部空间模式,从而实现分层特征提取。CNN擅长从大量图像数据中学习,并在图像分类,对象检测和分割等任务中取得了显着的成功。

        虽然CNN在各种计算机视觉任务中拥有良好的记录,并且可以有效地处理大规模数据集,但视觉转换器在全局依赖关系和上下文理解至关重要的情况下具有优势。然而,视觉变压器通常需要大量的训练数据才能实现与CNN相当的性能。此外,CNN由于其可并行化的性质而具有计算效率,使其对于实时和资源受限的应用程序更加实用。

三、示例:CNN 与视觉转换器

        在本节中,我们将使用 CNN 和视觉转换器方法,在 Kaggle 中可用的猫和狗数据集上训练视觉分类器。首先,我们将从 Kaggle 下载包含 25000 张 RGB 图像的猫和狗数据集。如果您还没有,可以阅读此处的说明,了解如何设置 Kaggle API 凭据。以下 Python 代码会将数据集下载到当前工作目录中。


from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()

# we write to the current directory with './'
api.dataset_download_files('karakaggle/kaggle-cat-vs-dog-dataset', path='./')

        下载文件后,您可以使用以下命令解压缩文件。

!unzip -qq kaggle-cat-vs-dog-dataset.zip
!rm -r kaggle-cat-vs-dog-dataset.zip

        使用以下命令克隆视觉转换器 GitHub 存储库。此存储库包含vision_tr目录下的视觉转换器所需的所有代码。

!git clone https://github.com/RustamyF/vision-transformer.git
!mv vision-transformer/vision_tr .

        下载的数据需要清理并准备训练我们的图像分类器。创建以下实用程序函数以 Pytorch 的 DataLoader 格式清理和加载数据。

import torch.nn as nn
import torch
import torch.optim as optim

from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from sklearn.model_selection import train_test_split

import os


class LoadData:
    def __init__(self):
        self.cat_path = 'kagglecatsanddogs_3367a/PetImages/Cat'
        self.dog_path = 'kagglecatsanddogs_3367a/PetImages/Dog'

    def delete_non_jpeg_files(self, directory):
        for filename in os.listdir(directory):
            if not filename.endswith('.jpg') and not filename.endswith('.jpeg'):
                file_path = os.path.join(directory, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                    print('deleted', file_path)
                except Exception as e:
                    print('Failed to delete %s. Reason: %s' % (file_path, e))

    def data(self):
        self.delete_non_jpeg_files(self.dog_path)
        self.delete_non_jpeg_files(self.cat_path)

        dog_list = os.listdir(self.dog_path)
        dog_list = [(os.path.join(self.dog_path, i), 1) for i in dog_list]

        cat_list = os.listdir(self.cat_path)
        cat_list = [(os.path.join(self.cat_path, i), 0) for i in cat_list]

        total_list = cat_list + dog_list

        train_list, test_list = train_test_split(total_list, test_size=0.2)
        train_list, val_list = train_test_split(train_list, test_size=0.2)
        print('train list', len(train_list))
        print('test list', len(test_list))
        print('val list', len(val_list))
        return train_list, test_list, val_list


# data Augumentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


class dataset(torch.utils.data.Dataset):

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    # dataset length
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    # load an one of images
    def __getitem__(self, idx):
        img_path, label = self.file_list[idx]
        img = Image.open(img_path).convert('RGB')
        img_transformed = self.transform(img)
        return img_transformed, label

四、CNN方法

        此图像分类器的 CNN 模型由三层 2D 卷积组成,内核大小为 3,步幅为 2,最大池化层为 2。在卷积层之后,有两个全连接层,每个层由 10 个节点组成。下面是说明此结构的代码片段:

class Cnn(nn.Module):
    def __init__(self):
        super(Cnn, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc1 = nn.Linear(3 * 3 * 64, 10)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(10, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return out

        训练是用特斯拉T4(g4dn-xlarge)GPU机器进行的,训练了10个训练周期。Jupyter Notebook 在项目的 GitHub 存储库中可用,其中包含训练循环的代码。以下是每个纪元的训练循环的结果。

五、视觉转换器方法

        视觉变压器架构设计有可定制的尺寸,可以根据特定要求进行调整。对于这种大小的图像数据集,此体系结构仍然很大。

from vision_tr.simple_vit import ViT
model = ViT(
    image_size=224,
    patch_size=32,
    num_classes=2,
    dim=128,
    depth=12,
    heads=8,
    mlp_dim=1024,
    dropout=0.1,
    emb_dropout=0.1,
).to(device)

        视觉转换器中的每个参数都起着关键作用,如下所述:

  • image_size=224:此参数指定模型输入图像的所需大小(宽度和高度)。在这种情况下,图像的大小应为 224x224 像素。
  • patch_size=32:图像被分成较小的补丁,此参数定义每个补丁的大小(宽度和高度)。在本例中,每个修补程序为 32x32 像素。
  • num_classes=2:此参数表示分类任务中的类数。在此示例中,模型旨在将输入分为两类(猫和狗)。
  • dim=128:它指定模型中嵌入向量的维数。嵌入捕获每个图像修补程序的表示形式。
  • depth=12:此参数定义视觉转换器模型(编码器模型)中的深度或层数。更高的深度允许更复杂的特征提取。
  • heads=8:此参数表示模型自注意机制中的注意力头数。
  • mlp_dim=1024:指定模型中多层感知器 (MLP) 隐藏层的维数。MLP 负责在自我注意后转换令牌表示。
  • dropout=0.1:此参数控制辍学率,这是一种用于防止过度拟合的正则化技术。它在训练期间将输入单位的一部分随机设置为 0。
  • emb_dropout=0.1:它定义了专门应用于令牌嵌入的辍学率。此丢弃有助于防止在训练期间过度依赖特定令牌。

使用Tesla T4(g4dn-xlarge)GPU机器对分类任务的视觉转换器进行了20个训练周期的训练。训练进行了20个epoch(而不是CNN使用的10个epoch),因为训练损失的收敛速度很慢。以下是每个纪元的训练循环的结果。

        CNN 方法在 75 个时期内达到了 10% 的准确率,而视觉转换器模型的准确率达到了 69%,训练时间要长得多。

六、结论

        总之,在比较CNN和Vision Transformer模型时,在模型大小,内存要求,准确性和性能方面存在显着差异。CNN 型号传统上以其紧凑的尺寸和高效的内存利用率而闻名,使其适用于资源受限的环境。事实证明,它们在图像处理任务中非常有效,并在各种计算机视觉应用中表现出出色的精度。另一方面,视觉变压器提供了一种强大的方法来捕获图像中的全局依赖关系和上下文理解,从而提高某些任务的性能。然而,与CNN相比,视觉变压器往往具有更大的模型尺寸和更高的内存要求。虽然它们可能会达到令人印象深刻的准确性,尤其是在处理较大的数据集时,但计算需求可能会限制它们在资源有限的场景中的实用性。最终,CNN 和 Vision Transformer 模型之间的选择取决于手头任务的特定要求,考虑可用资源、数据集大小以及模型复杂性、准确性和性能之间的权衡等因素。随着计算机视觉领域的不断发展,预计这两种架构将取得进一步进展,使研究人员和从业者能够根据他们的特定需求和限制做出更明智的选择。

变形金刚
计算机视觉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/781346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4.2 Bootstrap HTML编码规范

文章目录 Bootstrap HTML编码规范语法HTML5 doctype语言属性IE 兼容模式字符编码引入 CSS 和 JavaScript 文件HTML5 spec links 实用为王属性顺序布尔&#xff08;boolean&#xff09;型属性减少标签的数量JavaScript 生成的标签 Bootstrap HTML编码规范 语法 用两个空格来代替…

【Verilog】乒乓操作

文章目录 乒乓操作乒乓操作简单介绍乒乓操作的处理流程代码参考功能代码testbench波形文件 乒乓操作应用场景何时考虑使用乒乓操作乒乓操作的三个优点具体实现分析不间断地处理数据&#xff0c;无缝缓冲与处理可以节约缓冲区空间用低速模块处理高速数据流 乒乓操作 乒乓操作简…

光电器件的种类、原理和应用

光电器件是指能够将光信号转换成电信号或者将电信号转换成光信号的器件。它们广泛应用于通信、计算机、医疗、能源和环保等领域。本文将从光电器件的种类、原理和应用三个方面进行论述。 一、光电器件的种类 根据其功能和结构特点&#xff0c;光电器件可以分为多种类型&#…

【基于CentOS 7 的iscsi服务】

目录 一、概述 1.简述 2.作用 3. iscsi 4.相关名称 二、使用步骤 - 构建iscsi服务 1.使用targetcli工具进入到iscsi服务器端管理界面 2.实现步骤 2.1 服务器端 2.2 客户端 2.2.1 安装软件 2.2.2 在认证文件中生成iqn编号 2.2.3 开启客户端服务 2.2.4 查找可用的i…

Spring Boot-3

学习笔记&#xff08;今天又读了好多篇的博客&#xff0c;做个今天的总结&#xff0c;加油&#xff01;&#xff01;&#xff01;&#xff09; PS&#xff1a;快到中伏了&#xff0c;今天还是好热 使用阿里巴巴 FastJson 的设置 1、jackson 和 fastJson 的对比 有很多人已经…

spmvc基本要求

Mvc第一天 今天目标就是将所有的接口相关的注解理解,并且所有的注解都举例写出代码(cdsn上查) 1 mvc的基本概念: mvc: model,view,Controller,简单解释就是模型,视图,控制器.面试会问,md文档这些描述很到位,你可以看看 2 接口注解: (1) Controller/RestController表示当前…

【100天精通python】Day12:面向对象编程_属性和继承的基本语法与示例

目录 1 属性&#xff08;Attributes&#xff09; 1.1 属性的基本语法 1.2 创建用于计算的属性 1.3 属性的安全保护机制 2 继承&#xff08;Inheritance&#xff09; 2.1 继承的基本语法 2.2 方法的重写 2.3 派生类中调用基类的_init_()方法 3 总结 属性是类的特征或数…

记录WordPress安装后我常用的插件

记录WordPress安装后我常用的插件 一、WordPress安装二、插件使用1.添加Astra主题2.Easy Updates Manager2.WP Githuber MD3.WP-Optimize – Cache, Clean, Compress4. WP-PostViews或Post Views Counter5. Easy Table of Contents5. UpdraftPlus Backup/Restore6.WP Super Cac…

【开源项目】低代码数据可视化开发平台-Datav

Datav 基本介绍 Datav是一个Vue3搭建的低代码数据可视化开发平台&#xff0c;将图表或页面元素封装为基础组件&#xff0c;无需编写代码即可完成业务需求。 它的技术栈为&#xff1a;Vue3 TypeScript4 Vite2 ECharts5 Axios Pinia2 在线预览 账号: admin 密码: 123123预…

4.1 Bootstrap UI 编辑器

文章目录 1. Bootstrap Magic2. BootSwatchr3. Bootstrap Live Editor4. Fancy Boot5. Style Bootstrap6. Lavish7. Bootstrap ThemeRoller8. LayoutIt!9. Pingendo10. Kickstrap11. Bootply12. X-editable13. Jetstrap14. DivShot15. PaintStrap 以下是 15 款最好的 Bootstrap…

Zabbix监控安装grafana并配置图形操作

第三阶段基础 时 间&#xff1a;2023年7月20日 参加人&#xff1a;全班人员 内 容&#xff1a; Zabbix监控安装grafana 目录 安装并配置grafana 一、安装Grafana 二、下载安装插件 三、配置grafana 四、Web访问并配置&#xff1a; 安装并配置grafana 一、安装Graf…

微服务之服务注册与发现原理

1. 前言 在传统的开发中&#xff0c;由于提供服务的地址是相对静态的&#xff0c;所以我们只需要找到对应服务的开发人员&#xff0c;然后了解到对应的服务接口地址就可以了。 而在微服务架构开发过程中&#xff0c;如果我们需要调用一个RESTFul风格的API接口&#xff0c;我们…

deep a wavelet 深度自适应小波网络

深度自适应小波网络 1.  原理说明 1.1 the Lifting scheme 提升方案&#xff0c;也称为第二代小波[25]&#xff0c;是定义与第一代小波[6]具有相同属性的小波的一种简单而强大的方法。 提升方案将信号x作为输入&#xff0c; 生成小波变换的近似分量C, 和细节分量d 这两类…

网络安全基础知识解析:了解常见的网络攻击类型、术语及其防范方法

目录 1、网络安全常识和术语 1.1资产 1.2网络安全 1.3漏洞 1.4 0day 1.5 1day 1.6后门 1.7exploit 1.8攻击 1.9安全策略 1.10安全机制 1.11社会工程学 2、为什么会出现网络安全问题&#xff1f; 2.1网络的脆弱性 2.4.1缓冲区溢出攻击原理&#xff1a; 2.4.2缓冲…

NLP(六十一)使用Baichuan-13B-Chat模型构建智能文档问答助手

在文章NLP&#xff08;六十&#xff09;Baichuan-13B-Chat模型使用体验中&#xff0c;我们介绍了Baichuan-13B-Chat模型及其在向量嵌入和文档阅读上的初步尝试。   本文将详细介绍如何使用Baichuan-13B-Chat模型来构建智能文档问答助手。 文档问答流程 智能文档问答助手的流…

手机+App=电脑静音无线鼠标 - WiFimouse初体验

应用情景 大晚上的别人在睡觉&#xff0c;自己又不得不使用电脑&#xff08;台式&#xff09;&#xff0c;鼠标点点点又吵。 专门买个静音鼠标又没钱&#xff0c;咋办~ 效果图 手机app 电脑无线触控板&#xff0c;零噪音&#xff01; 可以单击、移动鼠标光标、可以上下滚动…

什么是Redis?

什么是Redis 什么是Redis一、特性1. 支持多种数据结构2. 读/写速度快&#xff0c;性能高。3. 支持持久化。4. 实现高可用主从复制&#xff0c;主节点做数据副本。5. 实现分布式集群和高可用。 二、基本数据类型string&#xff08;字符串&#xff09;list(双向链表)set(集合)zse…

22matlab数据分析 拉格朗日插值(matlab程序)

1.简述 第一部分&#xff1a;问题分析 &#xff08;1&#xff09;实验题目&#xff1a;拉格朗日插值算法 具体实验要求&#xff1a;要求学生运用拉格朗日插值算法通过给定的平面上的n个数据点&#xff0c;计算拉格朗日多项式Pn(x)的值&#xff0c;并将其作为实际函数f(x)的估…

文心千帆为你而来

1. 前言 3月16号百度率先发布了国内第一个人工智能大语言模型—文心一言。文心一言的发布在业界引起了不小的震动。而文心一言的企业服务则由文心千帆大模型平台提供。文心千帆大模型平台是百度智能云打造出来的一站式大模型开发与应用平台&#xff0c;提供包括文心一言在内的…

文件被识别为病毒,被删除,如何解决

我们的文件有时候有用&#xff0c;但是电脑却识别为病毒&#xff0c;直接给我删除掉了&#xff0c;这让人是真的很XX&#xff0c;那该怎么办呐。 我最近用了这个方法很多次&#xff0c;蛮好用&#xff0c;分享给大家&#xff01; 1、先找到安全中心 2、找不到排除项 3、点击添…