Vision Transformer(VIT 网络架构)

news2024/11/20 20:26:40

论文下载链接:https://arxiv.org/abs/2010.11929

文章目录

  • 引言
    • 1. VIT与传统CNN的比较
    • 2. 为什么需要Transformer在图像任务中?
  • 1. 深入Transformer
    • 1.1 Transformer的起源:NLP领域的突破
    • 1.2 Transformer的基本组成
      • 1.2.1 自注意机制 (Self-Attention Mechanism)
      • 1.2.2 前馈神经网络 (Feed-forward Neural Networks)
      • 1.2.3 残差连接 (Residual Connections)
      • 1.2.4 层标准化 (Layer Normalization)
  • 2. 从CNN到Vision Transformer
    • 2.1 CNN的局限性
    • 2.2 Vision Transformer的出现与动机
  • 3. Vision Transformer的工作原理
    • 3.1 输入:将图像分割成patches
    • 3.2 嵌入:linear embedding和位置嵌入
    • 3.3 Transformer编码器
    • 3.4 输出头:分类任务
  • 4. ViT的变种和相关工作
    • 4.1 DeiT (Data-efficient Image Transformer)
      • 4.1.1 概述
      • 4.1.2 知识蒸馏
      • 4.1.3 利用知识蒸馏进行优化的Transformer模型
    • 4.2 Hybrid models (ViT + CNN)
      • 4.2.1 为什么使用混合模型?
      • 4.2.2 基础架构
      • 4.2.3 示例
    • 4.3 Swin Transformer
      • 4.3.1 主要特点
      • 4.3.2 基础架构
      • 4.3.3 代码示例
  • 5. ViT的优点与缺点
    • 5.1 与CNN相比的优点
    • 5.2 ViT的挑战和限制

引言

1. VIT与传统CNN的比较

ViT(Vision Transformer)与传统的卷积神经网络(CNN)在图像处理方面有几个关键的不同点:

1. 模型结构:

  • ViT:主要基于Transformer结构,没有使用卷积层。
  • CNN:使用卷积层、池化层和全连接层。

2. 输入处理:

  • ViT:将图像分为多个固定大小的块并一次性处理。
  • CNN:通过卷积窗口逐渐扫描整个图像。

3. 计算复杂性:

  • ViT:由于自注意力机制,计算复杂性可能更高。
  • CNN:通常更易于优化,计算复杂性相对较低。

4. 数据依赖性:

  • ViT:通常需要更多的数据和计算资源来进行有效的训练。
  • CNN:相对更容易在小数据集上进行训练。

2. 为什么需要Transformer在图像任务中?

在深度学习的历史中,卷积神经网络(Convolutional Neural Networks, CNNs)长期以来一直是处理图像任务的主流架构。然而,随着Transformer的成功应用于自然语言处理(NLP)任务,研究人员开始考虑其在计算机视觉中的潜力。

灵活的全局注意机制

  • 全局上下文: 与局部感受野的CNN不同,Transformer具有全局的感受野,这使其可以在整个图像上进行信息融合。这种全局上下文可能在某些任务中非常有用,如图像分割、物体检测和多物体交互等。

可解释性和注意可视化

  • 更好的可解释性: 由于自注意机制,我们可以很容易地可视化模型在做决策时关注的区域,这增加了模型的可解释性。

序列到序列任务

  • 更容易处理序列输出: 在像图像字幕这样的任务中,同时考虑图像和文本信息变得更为直接,因为两者都可以用相似的Transformer架构来处理。

适应性

  • 更容易适应不同尺度和形状: Transformer不依赖于固定尺寸的滤波器,因此理论上更容易适应各种各样的输入。

1. 深入Transformer

1.1 Transformer的起源:NLP领域的突破

Transformer模型最初是由Google的研究人员在2017年的论文《Attention Is All You Need》中提出的。这个模型引入了一种全新的架构,主要以自注意(Self-Attention)机制为基础,并成功地解决了当时自然语言处理(NLP)中的一系列任务。这里列举一些Transformer在NLP领域的重要突破和影响:

1. 序列建模问题的新视角
传统的RNN(循环神经网络)和LSTM(长短时记忆)网络因为其递归的特性,在处理长序列时会遇到梯度消失或梯度爆炸的问题。Transformer通过自注意机制成功地捕获了序列内部的依赖关系,并且能够并行处理整个序列,从而在很多方面超过了RNN和LSTM。

2. 自注意机制
Transformer模型中的自注意机制允许模型在不同位置的输入之间建立直接的依赖关系,这让模型能更容易地理解句子或文档内部的上下文关系。这种机制特别适用于诸如机器翻译、文本摘要、问答系统等需要捕获长距离依赖的任务。

3. 可扩展性
由于其并行性和相对较少的时间复杂性,Transformer架构能更有效地利用现代硬件。这使得研究人员能够训练更大、更强大的模型,从而取得更好的性能。

4. 多模态和多任务学习
Transformer的架构具有高度的灵活性,可以容易地扩展到其他类型的数据和任务,包括图像、音频和多模态输入。这一点在后续的研究和应用中得到了广泛的证实。

5. 预训练和微调
Transformer架构适用于预训练和微调的工作流程。大型的预训练模型如BERT、GPT和T5都是基于Transformer构建的,并在多种NLP任务上设立了新的性能基准。

1.2 Transformer的基本组成

1.2.1 自注意机制 (Self-Attention Mechanism)

从心理学上来讲

  • 动物需要在复杂环境下有效关注值得注意的点
  • 心理学框架:人类根据随意(volitional)线索和不随意线索选择注意点(注意:这里的随意不是随便的意思,因为是翻译过来的,这里的随意应当为主动观察和不主动观察的意思,也可以理解为刻意无意

想象一下,假如我们面前有五个物品: 一份报纸、一篇研究论文、一杯咖啡、一本笔记本和一本书。所有纸制品都是黑白印刷的,但咖啡杯是红色的。 换句话说,这个咖啡杯在这种视觉环境中是突出和显眼的, 不由自主地引起人们的注意。 所以我们会把视力最敏锐的地方放到咖啡上
在这里插入图片描述

而想读书就成了随意线索
在这里插入图片描述

注意力机制

  • 传统的CNN架构中。卷积,池化,全连接层都只考虑不随意线索
  • 注意力机制则显示的考虑随意线索
    • 随意线索被称之为查询(query)
    • 每个输入是一个值(value)和不随意线索(key)的对这里可以把输入理解为环境
    • 通过注意力池化层来有偏向性的选择某些输入,因为我们加入了一些随意线索,我们可以在这里面有偏向性地选择某些输入。

计算过程

  1. 点积计算: 对于给定的查询,与每一个键进行点积,用以衡量查询和各个键之间的相似度。
  2. 缩放: 将点积的结果缩放(通常是除以键向量维度的平方根)。
  3. 激活函数: 应用Softmax激活函数,使权重和为1且介于0和1之间。
  4. 加权和: 使用得到的权重对值向量进行加权求和。
  5. 输出: 将加权和通过一个可选的全连接(Linear)层进行转换,生成该位置的输出。

多头注意力(Multi-Head Attention)
为了更丰富地捕捉不同的依赖关系,通常会使用多头注意力。在多头注意力中,模型维护多组独立的查询、键和值的权重矩阵,并进行并行计算。各个头的输出会被拼接并通过一个全连接层进行整合。

1.2.2 前馈神经网络 (Feed-forward Neural Networks)

前馈神经网络(Feed-forward Neural Networks, FFNNs)是最早的、最简单的神经网络架构。这种网络的特点是数据在网络中只有一个方向进行传播:从输入层,经过隐藏层,最终到输出层。这种单向的数据流动是“前馈”名字的由来。

结构和组件

  1. 输入层 (Input Layer): 这一层接收原始的输入数据,并将其传递给下一层。
  2. 隐藏层 (Hidden Layers): 网络可以包含一个或多个隐藏层,每个层由多个神经元组成。这些层捕获输入数据的复杂模式。
  3. 输出层 (Output Layer): 根据任务的需求(如分类、回归等),输出层生成网络的最终输出。

激活函数
为了引入非线性特性,每个神经元通常会有一个激活函数。常用的激活函数有:

  • ReLU (Rectified Linear Unit)
  • Sigmoid
  • Tanh (Hyperbolic Tangent)
  • Leaky ReLU, Parametric ReLU, etc.

训练
前馈神经网络通常使用反向传播(Backpropagation)算法进行训练,这涉及到:

  1. 前向传播 (Forward Propagation): 从输入层开始,数据通过网络流动,生成预测输出。
  2. 损失计算 (Loss Calculation): 根据预测输出和实际目标计算损失。
  3. 反向传播 (Backward Propagation): 计算损失关于每个权重的梯度,并更新网络中的权重。

在Transformer中的应用
虽然Transformer架构主要着重于自注意机制,但它在每个注意力模块之后都有一个前馈神经网络(通常是两层的网络)。这为模型引入了额外的计算能力,并帮助捕获数据的不同特征。

1.2.3 残差连接 (Residual Connections)

在Transformer架构中,残差连接起到了非常关键的作用。它们出现在自注意力(Self-Attention)层和前馈神经网络(Feed-forward Neural Networks)层的后面,通常与层归一化(Layer Normalization)一起使用。

结构与功能
在Transformer中,每一个子层(如多头自注意力或前馈神经网络)的输出都会与该子层的输入相加,形成一个残差连接。这种连接结构可以表示为:

Output=Sublayer(x)+x
或者更一般地:
Output=LayerNorm(Sublayer(x)+x)

这里的Sublayer(x)是子层(例如多头自注意力或前馈神经网络)的输出,而LayerNorm是层归一化。

1.2.4 层标准化 (Layer Normalization)

基本原理
层标准化的核心思想是对每一层的每一个样本独立进行标准化,以便每一层的输出具有大致相同的尺度。在全连接层或者卷积层之后,但通常在激活函数之前应用层标准化。
数学表示为:
在这里插入图片描述

在Transformer中的应用
在Transformer架构中,层标准化通常与残差连接(Residual Connections)结合使用。每个残差连接后面都会跟一个层标准化步骤,以稳定模型训练。这种组合有助于模型在训练期间保持数值稳定性,尤其是对于非常深的模型。

class AddNorm(nn.Module):
    """残差连接后进行层规范化"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)

优点

  1. 数值稳定性: 层标准化有助于防止梯度消失或梯度爆炸问题,从而使模型更容易训练。
  2. 加速收敛: 通过调整各层的尺度,层标准化可以加速模型的收敛速度。
  3. 可适应性: 层标准化适用于不同类型和深度的网络架构,包括循环神经网络(RNNs)。

缺点

  1. 序列长度依赖: 在处理可变长度序列时,层标准化可能不如批标准化(Batch Normalization)有效。
  2. 模型复杂性: 引入了额外的可学习参数,这可能会增加模型的复杂性。

2. 从CNN到Vision Transformer

卷积神经网络(CNN)和Vision Transformer(ViT)都是用于处理图像任务的流行模型,但它们有着不同的设计哲学和应用范围。下面简要介绍这两者之间的演进。

2.1 CNN的局限性

1. 局部感受野
CNN通过局部感受野(receptive fields)来处理图像,这在某些任务中是一个局限性。虽然这种设计有助于识别图像中的局部结构,但它可能不适合捕捉远距离的依赖关系。

2. 计算成本
当处理高分辨率图像时,卷积操作的计算成本可能会非常高。

3. 空间结构假设
CNN假设输入数据具有某种固有的空间或时间结构。这使得CNN不容易适用于没有明确空间结构的数据。

4. 参数效率
在参数效率方面,即使使用了各种技巧(如批标准化、残差连接等),CNN仍然可能不如Transformer模型。

2.2 Vision Transformer的出现与动机

Vision Transformer是由Google Research在2020年首次提出的,它的设计灵感来自于用于自然语言处理的Transformer模型。

1. 全局注意力
与CNN不同,ViT使用全局自注意力机制,可以更好地处理图像中的远距离依赖关系

2. 计算效率
ViT通过自注意力前馈神经网络来实现计算效率,特别是在处理高分辨率图像时。

3. 模块化和可扩展性
ViT具有很好的模块化和可扩展性,可以容易地调整模型大小和复杂性。

4. 参数效率
在大量数据集上进行预训练后,ViT通常表现出高度的参数效率,即在相同数量的参数下,性能比CNN更好。

5. 跨模态应用
由于ViT没有硬编码的空间假设,它也更容易应用于其他类型的数据和任务。

3. Vision Transformer的工作原理

3.1 输入:将图像分割成patches

输入:将图像分割成patches

  1. 图像分割: Vision Transformer(ViT)首先将输入图像分割成多个固定大小的小块(patches)。这些小块通常是方形的,例如16x16像素。
  2. 一维化: 每个小块都被拉平成一个一维向量
  3. 合并: 所有这些一维向量然后被串联成一个序列,作为Transformer编码器的输入。

3.2 嵌入:linear embedding和位置嵌入

  1. Linear Embedding: 小块通过一个线性层(通常是一个全连接层)进行嵌入,以将它们转换成合适维度的向量。这相当于通过一个很浅的CNN层进行特征提取。
  2. 位置嵌入: 由于小块的原始位置信息在一维化过程中丢失了,因此需要添加位置嵌入以帮助模型识别这些小块的相对或绝对位置。
  3. 合并: 线性嵌入和位置嵌入通常会被加在一起,以生成一个包含位置信息的嵌入序列。

3.3 Transformer编码器

  1. 自注意力层: 这一层使用自注意力机制来分析输入序列中的每个元素(即每个小块和其对应的位置嵌入),以便更好地表示各个小块之间的关系。
  2. 前馈神经网络: 自注意力层的输出会被送入一个前馈神经网络(Feed-forward Neural Network)。
  3. 残差连接与层标准化: 在自注意力层和前馈神经网络之后,都会有残差连接和层标准化操作,以促进模型训练的稳定性和效率。
  4. 堆叠编码器: 上述所有组件会被堆叠多次(例如,12次或24次等),以形成完整的Transformer编码器。
  5. 分类头: 对于分类任务,通常会取编码器输出序列的第一个元素(通常对应于一个特殊的“[CLS]”标记)并通过一个全连接层进行分类。
class EncoderBlock(nn.Module):
    """Transformer编码器块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

Transformer编码器中的任何层都不会改变其输入的形状。

3.4 输出头:分类任务

在Vision Transformer(ViT)模型中,用于分类任务的输出头通常是一个全连接(线性)层,该层将Transformer编码器的输出映射到类别标签的数量。在多数实现中,通常会使用Transformer编码器输出的第一个位置(通常与添加的特殊 [CLS] 标记对应)的特征

4. ViT的变种和相关工作

随着Vision Transformer(ViT)在图像分类任务中的成功,很多研究者开始探索其变种和改进方案。这里选择一些值得关注的变种和相关工作进行概述解析:

4.1 DeiT (Data-efficient Image Transformer)

4.1.1 概述

  • 概念: DeiT关注于如何更有效地使用数据。标准的ViT需要大量的数据和计算资源来进行预训练,但DeiT通过更高效的训练策略,尤其是数据增强知识蒸馏,来改善这一点。
  • 主要特点: 使用知识蒸馏和不同的训练技巧,如学习率调度和数据增强,以减少对大量标签数据的依赖。
import torch
import torch.nn as nn
import torch.nn.functional as F

# 分割图像到patch
class PatchEmbedding(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # [B, C, H, W]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        return x

# DeiT 模型主体
class DeiT(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim, num_heads, num_layers, num_classes):
        super().__init__()

        # 分割图像到patch并嵌入
        self.patch_embed = PatchEmbedding(patch_size, in_channels, embed_dim)

        # 特殊的 [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # 位置嵌入
        num_patches = (224 // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Transformer 编码器
        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # 分类器头
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.size(0)

        # 分割图像到patch并嵌入
        x = self.patch_embed(x)

        # 添加 [CLS] token
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)

        # 添加位置嵌入
        x += self.pos_embed

        # 通过 Transformer
        x = self.transformer(x)

        # 只取 [CLS] 对应的输出用于分类任务
        x = x[:, 0]

        # 分类器
        x = self.fc(x)

        return x

# 参数
patch_size = 16
in_channels = 3
embed_dim = 768
num_heads = 12
num_layers = 12
num_classes = 1000  # 假设是一个1000分类问题

# 初始化模型
model = DeiT(patch_size, in_channels, embed_dim, num_heads, num_layers, num_classes)

# 假数据
x = torch.randn(32, 3, 224, 224)  # 32张3通道224x224大小的图片

# 模型前向推断
logits = model(x)

4.1.2 知识蒸馏

知识蒸馏(Knowledge Distillation, KD)是一种模型压缩技术,用于将一个大型、复杂模型(通常称为“教师模型”)的知识转移到一个更小、更简单的模型(通常称为“学生模型”)中。这样做的目的是在保持与大型模型相近的性能的同时,降低模型大小和推断时间。

工作原理

  • 教师模型: 通常是一个预先训练好的大型模型,用于生成软标签(soft labels),即类别概率分布。
  • 学生模型: 通常是一个相对较小的模型,需要被训练来模仿教师模型
  • 蒸馏损失: 在最基础的知识蒸馏中,学生模型的训练不仅要最小化与真实标签之间的损失(如交叉熵损失),还要最小化与教师模型预测的软标签之间的损失

简单的知识蒸馏代码示例
假设我们有一个教师模型(teacher_model)和一个学生模型(student_model),下面是一个使用PyTorch进行知识蒸馏的简单示例:

import torch
import torch.nn.functional as F

# 假定 teacher_model 和 student_model 已经定义并初始化
# teacher_model = ...
# student_model = ...

# 数据加载器
# data_loader = ...

# 优化器
optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

# 温度参数和软标签权重
temperature = 2.0
alpha = 0.9

# 训练循环
for data, labels in data_loader:
    optimizer.zero_grad()

    # 正向传播:教师和学生模型
    teacher_output = teacher_model(data).detach()  # 注意:通常不会计算教师模型的梯度
    student_output = student_model(data)

    # 计算损失
    hard_loss = F.cross_entropy(student_output, labels)  # 与真实标签的损失
    soft_loss = F.kl_div(F.log_softmax(student_output/temperature, dim=1),
                         F.softmax(teacher_output/temperature, dim=1))  # 与软标签的损失

    loss = alpha * soft_loss + (1 - alpha) * hard_loss

    # 反向传播和优化
    loss.backward()
    optimizer.step()

应用场景
知识蒸馏不仅适用于模型压缩,在一些特定应用中也能用于提高小型模型的性能,例如在DeiT(Data-efficient Image Transformer)中用于提高数据效率。

4.1.3 利用知识蒸馏进行优化的Transformer模型

以下我们假设有一个已经训练好的大型 Transformer 模型(教师模型),以及一个更小的 Transformer 模型(学生模型)。

注意:这里为了简单,我们使用 nn.Transformer 模块作为 Transformer 的简单实现。你也可以根据需要替换为更复杂的模型。

损失函数包含两部分:一部分是学生模型和实际标签之间的损失,另一部分是学生和教师模型输出之间的 Kullback-Leibler 散度。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# 定义简单的 Transformer 模型
class SimpleTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers, num_classes):
        super(SimpleTransformer, self).__init__()
        self.encoder = nn.Transformer(d_model, nhead, num_layers)
        self.classifier = nn.Linear(d_model, num_classes)
    
    def forward(self, x):
        x = self.encoder(x)
        x = x.mean(dim=1)
        x = self.classifier(x)
        return x

# 定义损失函数
def distillation_loss(y, labels, teacher_output, T=2.0, alpha=0.5):
    return nn.CrossEntropyLoss()(y, labels) * (1. - alpha) + (alpha * T * T) * nn.KLDivLoss()(F.log_softmax(y/T, dim=1),
                                                     F.softmax(teacher_output/T, dim=1))

# 假设我们有一些数据
# 注意:这里使用随机数据仅作为示例
N = 100  # 数据点数量
d_model = 32  # 嵌入维度
nhead = 2  # 多头注意力的头数
num_layers = 2  # Transformer 层的数量
num_classes = 10  # 分类数
T = 2.0  # 温度参数
alpha = 0.5  # 蒸馏损失的权重因子

x = torch.randn(N, 10, d_model)
labels = torch.randint(0, num_classes, (N,))

# 初始化教师和学生模型
teacher_model = SimpleTransformer(d_model, nhead, num_layers, num_classes)
student_model = SimpleTransformer(d_model, nhead, num_layers, num_classes)

# 设置优化器
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# 模拟训练过程
for epoch in range(10):
    # 前向传播
    teacher_output = teacher_model(x).detach()  # 通常来说,教师模型是预先训练好的,因此不需要计算梯度
    student_output = student_model(x)
    
    # 计算损失
    loss = distillation_loss(student_output, labels, teacher_output, T, alpha)
    
    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

4.2 Hybrid models (ViT + CNN)

混合模型(Hybrid models)结合了 Vision Transformer(ViT)和卷积神经网络(CNN)的优点,以实现更强大的图像识别能力。这类模型通常使用 CNN 作为特征提取器,将其输出用作 ViT 的输入。

4.2.1 为什么使用混合模型?

  1. 局部与全局特性: CNN 非常擅长捕获局部特性,而 Transformer 能够处理全局依赖关系。将两者结合可以更全面地理解图像。
  2. 计算效率: CNN 在处理图像数据方面通常更加高效。通过在模型前端使用 CNN,可以降低 Transformer 的计算复杂性。
  3. 数据效率: 使用 CNN 的预训练特征可以提高模型的数据效率,这对于训练数据较少的任务特别有用。

4.2.2 基础架构

在一个典型的混合模型中,CNN 通常用作特征提取器,而 ViT 用作特征编码和分类。

  1. 特征提取: 使用 CNN 层(可能是一个预训练的网络,比如 ResNet 或 VGG)从输入图像中提取特征。
  2. 图像分块与嵌入: 将 CNN 的输出分块,并通过线性嵌入层(或其他方法)转换为适用于 Transformer 的序列。
  3. Transformer 编码: 使用 ViT 进行特征的进一步编码。
  4. 分类头: 最后,使用全连接层进行分类。

4.2.3 示例

import torch
import torch.nn as nn

# 假设使用 ResNet 的某个版本作为特征提取器
class FeatureExtractor(nn.Module):
    def __init__(self, ...):
        super().__init__()
        # 定义 CNN 结构,例如一个简化的 ResNet
        ...

    def forward(self, x):
        # 通过 CNN 提取特征
        return x

# ViT 作为编码器
class ViTEncoder(nn.Module):
    def __init__(self, ...):
        super().__init__()
        # 定义 Transformer 结构
        ...

    def forward(self, x):
        # 通过 Transformer 编码特征
        return x

# 混合模型
class HybridModel(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.feature_extractor = FeatureExtractor(...)
        self.vit_encoder = ViTEncoder(...)
        self.classifier = nn.Linear(...)

    def forward(self, x):
        x = self.feature_extractor(x)  # CNN 特征提取
        x = self.vit_encoder(x)  # Transformer 编码
        x = self.classifier(x)  # 分类头
        return x

4.3 Swin Transformer

Swin Transformer 是一种用于计算机视觉任务的 Transformer 架构,提出了一种基于滑窗(sliding window)的自注意机制。这种方法结合了卷积神经网络(CNN)和 Transformer 的优点,旨在实现更高的模型效率和性能。

4.3.1 主要特点

  1. 分层特征提取: 与 CNN 类似,Swin Transformer 进行多层特征提取,每一层都会降采样,但是这里是通过 Transformer 实现的。
  2. 滑窗自注意: Swin Transformer 使用了滑窗自注意机制,该机制只考虑局部的上下文信息,而不是传统 Transformer 中的全局上下文信息。这减少了计算复杂性
  3. 分块与合并: 在多个层级中,Swin Transformer 通过分块和合并的方式,逐步减少序列的长度,并增加特征维度,以达到更高级别的特征提取。
  4. 灵活性: Swin Transformer 可以被用于多种计算机视觉任务,如图像分类、目标检测和语义分割等。

4.3.2 基础架构

  1. Patch Embedding: 将图像分割成多个小块(patches),然后用线性嵌入层进行嵌入。
  2. Swin Transformer Blocks: 包括多个 Swin Transformer 层,每一层都有一个或多个滑窗自注意机制和前馈神经网络。
  3. Head: 根据具体任务(如分类、检测等),在 Swin Transformer 的最后一层添加不同的头部结构。

4.3.3 代码示例

  • PatchEmbedding: 这部分负责将输入图像切割成小块并进行嵌入。
  • WindowAttention: 这是 Swin Transformer 特有的,用于在局部窗口内进行自注意力计算。
  • SwinBlock: 包括一个窗口注意力层和一个多层感知机(MLP)。
  • SwinTransformer: 最终的模型架构。
import torch
import torch.nn as nn
import torch.nn.functional as F

# 切分图像为patches
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, out_dim, patch_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.conv(x)
        x = x.flatten(2).transpose(1, 2)
        return x

# 滑窗注意力
class WindowAttention(nn.Module):
    def __init__(self, dim, heads, window_size):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.window_size = window_size

        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

    def forward(self, x):
        # 假设 x 的形状为 [batch_size, num_patches, dim]
        # 分割为多个窗口
        windows = x.view(x.size(0), self.window_size, self.window_size, self.dim)

        # 计算 q, k, v
        q = self.query(windows)
        k = self.key(windows)
        v = self.value(windows)

        # 注意力计算
        attn = torch.einsum('bqhd,bkhd->bhqk', q, k)
        attn = F.softmax(attn, dim=-1)

        # 输出
        out = torch.einsum('bhqk,bkhd->bqhd', attn, v)
        out = out.contiguous().view(x.size(0), self.window_size * self.window_size, self.dim)

        return out

# Swin Transformer Block
class SwinBlock(nn.Module):
    def __init__(self, dim, heads, window_size):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# Swin Transformer 模型
class SwinTransformer(nn.Module):
    def __init__(self, in_channels, out_dim, patch_size, num_classes):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels, out_dim, patch_size)

        # 假设我们有 4 个 Swin Blocks 和窗口大小为 8
        self.blocks = nn.ModuleList([
            SwinBlock(out_dim, 8, 8) for _ in range(4)
        ])

        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(out_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        for block in self.blocks:
            x = block(x)
        x = self.global_avg_pool(x.mean(dim=1))
        x = self.fc(x.squeeze(-1))
        return x

# 测试模型
if __name__ == '__main__':
    model = SwinTransformer(3, 128, 4, 10)
    x = torch.randn(16, 3, 32, 32)  # 假设有 16 张 32x32 的图像
    y = model(x)
    print(y.shape)  # 应该输出 torch.Size([16, 10])

5. ViT的优点与缺点

5.1 与CNN相比的优点

  1. 更好的长距离依赖处理: Transformer 架构设计初衷就是用来捕捉长距离依赖,这在某些复杂的图像识别任务中是非常有用的。
  2. 参数效率: ViT 有潜力以较少的参数量达到与 CNN 相同的性能
  3. 可解释性: 自注意力机制的输出可用于分析模型对于图像各部分的关注程度,有助于模型解释。
  4. 灵活性和泛化: Transformer 不依赖于固定大小的滤波器或局部区域,因此有潜力更好地泛化到不同类型和结构的视觉数据。
  5. 端到端训练: 与某些需要特别设计的 CNN 架构相比,ViT 可以从头到尾用一个统一的架构进行训练。

5.2 ViT的挑战和限制

  1. 计算复杂性: 对于大型图像,全局自注意力机制的计算复杂性可能非常高。这也是为什么一开始 ViT 主要用在 NLP 领域的原因之一。
  2. 数据依赖: ViT 通常需要大量的标注数据来进行有效训练。这在没有大量带标签数据的场景下可能是一个问题。
  3. 训练不稳定: Transformer 架构通常比 CNN 更难训练,尤其是在没有充足计算资源和数据的情况下。
  4. 局部特征处理不如 CNN: 由于没有内置的卷积操作,ViT 可能在某些依赖于局部特征的任务(例如纹理识别)中不如专门设计的 CNN。
  5. 内存消耗: 尤其是在大图像或长序列上,Transformer 模型(包括 ViT)通常需要更多的内存
  6. 过拟合风险: 由于模型复杂性和参数量通常较大,ViT 更容易过拟合,尤其是在数据量较少的情况下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/971264.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++将派生类赋值给基类

在 C/C++ 中经常会发生数据类型的转换,例如将 int 类型的数据赋值给 float 类型的变量时,编译器会先把 int 类型的数据转换为 float 类型再赋值;反过来,float 类型的数据在经过类型转换后也可以赋值给 int 类型的变量。 数据类型转换的前提是,编译器知道如何对数据进行取舍…

星际争霸之小霸王之小蜜蜂(十)--鼠道

系列文章目录 星际争霸之小霸王之小蜜蜂(九)--狂鼠之灾 星际争霸之小霸王之小蜜蜂(八)--蓝皮鼠和大脸猫 星际争霸之小霸王之小蜜蜂(七)--消失的子弹 星际争霸之小霸王之小蜜蜂(六)-…

大数据课程K20——Spark的SparkSql概述

文章作者邮箱:yugongshiye@sina.cn 地址:广东惠州 ▲ 本章节目的 ⚪ 了解Spark的SparkSQL由来; ⚪ 了解Spark的SparkSQL特点; ⚪ 了解Spark的SparkSQL优势; ⚪ 掌握Spark的SparkSQL入门; 一、SparkSQL概述 1. 概述 Spark为结构化数据处理引入了一个称…

SQLI-labs-第四关

知识点:get双引号(")和括号注入 思路: 1、判断注入点 首先,输入?id1 --,看看正常的回显状态 接着输入?id1 --,结果还是正常回显,说明这里不存在单引号问题 试试双引号,这里爆出了sql语…

深入了解GCC编译过程

关于Linux的编译过程,其实只需要使用gcc这个功能,gcc并非一个编译器,是一个驱动程序。其编译过程也很熟悉:预处理–编译–汇编–链接。在接触底层开发甚至操作系统开发时,我们都需要了解这么一个知识点,如何…

C# 如何读取dxf档案

需求来源: 工作中,客户提供一张CAD导出的dxf 档案,然后需要机器人将其转成点位,走到对应的位置。 下面介绍一下dxf档案到底是什么?以及语法规则。 dxf 格式介绍:DXF 格式 dxf LINE 格式。 其实上述文档…

软考:中级软件设计师:多媒体基础,音频,图像,颜色,多媒体技术的种类,图像音频视频的容量计算,常见的多媒体标准

软考:中级软件设计师:多媒体基础 提示:系列被面试官问的问题,我自己当时不会,所以下来自己复盘一下,认真学习和总结,以应对未来更多的可能性 关于互联网大厂的笔试面试,都是需要细心准备的 &am…

在公网上使用SSH远程连接安卓手机Termux:将Android手机变身为远程服务器

文章目录 前言1.安装ssh2.安装cpolar内网穿透3.远程ssh连接配置4.公网远程连接5.固定远程连接地址 前言 使用安卓机跑东西的时候,屏幕太小,有时候操作不习惯。不过我们可以开启ssh,使用电脑PC端SSH远程连接手机termux。 本次教程主要实现在…

介绍OpenCV

OpenCV是一个开源计算机视觉库,可用于各种任务,如物体识别、人脸识别、运动跟踪、图像处理和视频处理等。它最初由英特尔公司开发,目前由跨学科开发人员社区维护和支持。OpenCV可以在多个平台上运行,包括Windows、Linux、Android和…

高等数学刷题

分段函数主要看在临界点处的左右极限是否相等,若相等则整段函数即为连续 (反之若是连续函数,在某一点为间断点,则可推导出一定为可去间断点) 无定义的点一定为间断点 如果该点有极限则为可去间断点 由于x的不确定导…

RK3568-spi-适配1.8寸TFT彩屏驱动芯片st7735s

RK3568-spi-适配1.8寸TFT彩屏 驱动芯片st7735s 显示分辨率128x160硬件连接 VCC -- 3.3V GND -- GND BL -- 背光控制 CS -- 片选引脚 DC -- 数据/命令控制 RES -- 屏幕复位 SCL -- i2c时钟引脚 SDA -- i2c数据引脚设备树编写 &spi0 {pinctrl-names = "default"…

docker安装Apache NIFI

说明 系统:CentOS7.9 nifi版本:1.23.2 下载镜像 nifi的镜像比较大,大概有2G左右,下载时间根据个人网速而定 docker pull apache/nifi:1.23.2 查看下载好的镜像 docker images 复制容器数据 创建挂载目录 创建挂载目录的目…

【Yolov5+Deepsort】训练自己的数据集(3)| 目标检测追踪 | 轨迹绘制 | 报错分析解决

📢前言:本篇是关于如何使用YoloV5Deepsort训练自己的数据集,从而实现目标检测与目标追踪,并绘制出物体的运动轨迹。本章讲解的为第三部分内容:数据集的制作、Deepsort模型的训练以及动物运动轨迹的绘制。本文中用到的数…

qt day 6

登录界面 #include "window.h" #include<QDebug> #include<QIcon> Window::Window(QWidget *parent) //构造函数的定义: QWidget(parent) //显性调用父类的构造函数 {//判断数据库对象是否包含了自己使用的数据库Student.dbif(!db.contains(&…

20230904 QT客户端服务器搭建聊天室

Ser cpp#include "app.h" #include "ui_app.h"APP::APP(QWidget *parent):QWidget(parent),ui(new Ui::APP) {ui->setupUi(this);this->resize(550,400);ui->Line->setAlignment(Qt::AlignCenter);//标签文本对齐方式 居中ui->Line->se…

爬虫源码---爬取自己想要看的小说

前言&#xff1a; 小说作为在自己空闲时间下的消遣工具&#xff0c;对我们打发空闲时间很有帮助&#xff0c;而我们在网站上面浏览小说时会被广告和其他一些东西影响我们的观看体验&#xff0c;而这时我们就可以利用爬虫将我们想要观看的小说下载下来&#xff0c;这样就不会担…

突破瓶颈:如何应对高级职位的面试

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

K8S自动化运维容器Docker集群

K8S&#xff1a;K8S自动化运维容器化(Docker)集群 一.k8s概述 1.k8s是什么 &#xff08;1&#xff09;K8S全程为Kubernetes&#xff0c;由于K到S直接有8个字母简称为K8S。 &#xff08;2&#xff09;版本&#xff1a;目前一般是1.18~1.2.0&#xff0c;后续可能会到1.24-1.2…

Qt 开发 CMake工程

Qt 入门实战教程&#xff08;目录&#xff09; 为何要写这篇文章 目前CMake作为C/C工程的构建方式在开源社区已经成为主流。 企业中也是能用CMake的尽量在用。 Windows 环境下的VC工程都是能不用就不用。 但是&#xff0c;这个过程是非常缓慢的&#xff0c;所以&#xff0…

Linux 学习笔记(2)—— 关于文件和目录

目录 1、切换目录 2、查看系统信息 3、文本的创建和编辑 3-1&#xff09;创建文件 3-2&#xff09;查看文件 3-3&#xff09;输出重定向和追加重定向 3-4&#xff09;使用 vi 编辑器编辑文件 4、文件和文件夹的处理 4-1&#xff09;对文件的处理 4-2&#xff09;查看…