昇思25天学习打卡营第16天 | Vision Transformer图像分类

news2024/9/22 17:22:58

昇思25天学习打卡营第16天 | Vision Transformer图像分类

文章目录

  • 昇思25天学习打卡营第16天 | Vision Transformer图像分类
    • Vision Transform(ViT)模型
      • Transformer
        • Attention模块
        • Encoder模块
      • ViT模型输入
    • 模型构建
      • Multi-Head Attention模块
      • Encoder模块
      • Patch Embedding模块
      • ViT网络
    • 总结
    • 打卡

Vision Transform(ViT)模型

ViT是NLP和CV领域的融合,可以在不依赖于卷积操作的情况下在图像分类任务上达到很好的效果。

ViT模型的主体结构是基于Transformer的Encoder部分。

Transformer

Transformer由很多Encoder和Decoder模块构成,包括多头注意力(Multi-Head Attention)层,Feed Forward层,Normalization层和残差连接(Residual Connection)。
encoder-decoder
多头注意力结构基于自注意力机制(Self-Attention),是多个Self-Attention的并行组成。

Attention模块

Attention的核心在于为输入向量的每个单词学习一个权重。

  1. 最初的输入向量首先经过Embedding层映射为Q(Query),K(Key),V(Value)三个向量。
  2. 通过将Q和所有K进行点乘初一维度平方根,得到向量间的相似度,通过softmax获取每词向量之间的关系权重。
  3. 利用关系权重对词向量的V加权求和,得到自注意力值。
    self-attention
    多头注意力机制只是对self-attention的并行化:
    multi-head-attention
Encoder模块

ViT中的Encoder相对于标准Transformer,主要在于将Normolization放在self-attention和Feed Forward之前,其他结构与标准Transformer相同。
vit-encoder

ViT模型输入

传统Transformer主要应用于自然语言处理的一维词向量,而图像时二维矩阵的堆叠。
在ViT中:

  1. 通过卷积将输入图像在每个channel上划分为 16 × 16 16\times 16 16×16个patch。如果输入 224 × 224 224\times224 224×224的图像,则每一个patch的大小为 14 × 14 14\times 14 14×14
  2. 将每一个patch拉伸为一个一维向量,得到近似词向量堆叠的效果。如将 14 × 14 14\times14 14×14展开为 196 196 196的向量。
    这一部分Patch Embedding用来替换Transformer中Word Embedding,用作网络中的图像输入。

模型构建

Multi-Head Attention模块

from mindspore import nn, ops


class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out

Encoder模块

from typing import Optional, Dict


class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)

        return x


class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x

class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)

Patch Embedding模块

class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        """Path Embedding construct."""
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))

        return x

ViT网络

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter


def init(init_type, shape, dtype, name, requires_grad):
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)


class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls') -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

总结

这一节对Transformer进行介绍,包括Attention机制、并行化的Attention以及Encoder模块。由于传统Transformer主要作用于一维的词向量,因此二维图像需要被转换为类似的一维词向量堆叠,在ViT中通过将Patch Embedding解决这一问题,并用来代替传统Transformer中的Word Embedding作为网络的输入。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1932818.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java对象转换为JSON字符串

0 写在前面 业务中有很多场景需要 把一个带有数据的 Java对象/Java集合转换为JSON 存入数据库中。 在需要的时候还需要吧和这个JSON字符串拿出来再次转换成Java对象/集合 1 Java对象与JSON字符串互转 引入依赖: <dependency><groupId>com.alibaba</groupId&…

解决VMware虚拟机在桥接模式下无法上网的问题

解决VMware虚拟机在桥接模式下无法上网的问题 windows11系统自动启动了热点功能&#xff0c;开启热点可能会干扰虚拟机的桥接设置。 方法一&#xff1a;windows11可以提供网络热点服务 方法二&#xff1a;手动指定桥接的物理网卡 方法一&#xff1a;关闭热点功能 优点&#xff…

【Java项目笔记】01项目介绍

一、技术框架 1.后端服务 Spring Boot为主体框架 Spring MVC为Web框架 MyBatis、MyBatis Plus为持久层框架&#xff0c;负责数据库的读写 阿里云短信服务 2.存储服务 MySql redis缓存数据 MinIO为对象存储&#xff0c;存储非结构化数据&#xff08;图片、视频、音频&a…

【开发指南】HTML和JS编写多用户VR应用程序的框架

1.概述 Networked-Aframe 的工作原理是将实体及其组件同步到连接的用户。要连接到房间&#xff0c;您需要将networked-scene组件添加到a-scene元素。对于要同步的实体&#xff0c;请向其添加networked组件。默认情况下&#xff0c;position和rotation组件是同步的&#xff0c;…

【Spring Cloud】掌握Gateway核心技术,实现高效路由与转发

目录 前言示例创建一个服务提供者创建网关 创建common子项目 前言 Spring Cloud Gateway 是一个基于 Spring Boot 的非阻塞 API 网关服务&#xff0c;它提供了动态路由、请求断言、过滤器等功能。 以下是关于 Spring Cloud Gateway 的示例&#xff1a; 示例 创建一个服务提…

什么是 std::ios::sync_with_stdio(false)

介绍 std::ios::sync_with_stdio(false) 是 C 中的一个配置设置&#xff0c;用于控制标准 I/O 流&#xff08;如 std::cin, std::cout&#xff09;的行为。这个设置主要用于优化输入输出操作的性能&#xff0c;尤其是在处理大量数据时。 在 C 中&#xff0c;标准流库&#xf…

PHP连接MySQL数据库

PHP本身不具备操作MySQL数据库的能力&#xff0c;需要借助MySQL扩展来实现。 1、PHP加载MySQL扩展&#xff1a;php.ini文件中。&#xff08;不要用记事本打开&#xff09; 2、PHP中所有扩展都是在ext的文件夹中&#xff0c;需要指定扩展所在路径&#xff1a;extension_dir。 3、…

3D问界—MAYA制作铁丝栅栏(透明贴图法)

当然&#xff0c;如果想通过建立模型法来实现铁丝栅栏的效果&#xff0c;也不是不行&#xff0c;可以找一下栅栏建模教程。本篇文章主要是记录一下如何使用透明贴图来实现创建铁丝栅栏&#xff0c;主要应用于场景建模&#xff0c;比如游戏场景、建筑场景等大环境&#xff0c;不…

Spring3(代理模式 Spring1案例补充 Aop 面试题)

目录 一、代理模式 介绍 意图 主要解决的问题 使用场景 实现方式 关键代码 应用实例 优点 缺点 使用建议 注意事项 结构 什么是代理模式&#xff1f; 为什么要用代理模式&#xff1f; 有哪几种代理模式&#xff1f; 1. 静态代理 实现 2. 基于接口的动态代理…

基于python旅游景点满意度分析设计与实现

1.1研究背景与意义 1.1.1研究背景 随着旅游业的快速发展&#xff0c;满意度分析成为评估旅游景点质量和提升游客体验的重要手段。海口市作为中国的旅游城市之一&#xff0c;其旅游景点吸引了大量游客。然而&#xff0c;如何科学评估和提升海口市旅游景点的满意度&#xff0c;…

Qt创建列表,通过外部按钮控制列表的选中下移、上移以及左侧图标的显现

引言 项目中需要使用列表QListWidget,但是不能直接拿来使用。需要创建一个列表,通过向上和向下的按钮来向上或者向下移动选中列表项,当当前项背选中再去点击确认按钮,会在列表项的前面出现一个图标。 实现效果 本实例实现的效果如下: 实现思路 思路一 直接采用QLis…

Spring Security之安全异常处理

前言 在我们的安全框架中&#xff0c;不管是什么框架&#xff08;包括通过过滤器自定义&#xff09;都需要处理涉及安全相关的异常&#xff0c;例如&#xff1a;登录失败要跳转到登录页&#xff0c;访问权限不足要返回页面亦或是json。接下来&#xff0c;我们就看看Spring Sec…

海外营销推广:快速创建维基百科(wiki)词条-大舍传媒

一、维基百科的永久留存问题 许多企业和个人关心维基百科是否能永久留存。实际上&#xff0c;只要企业和个人的行为没有引起维基百科管理方的反感&#xff0c;词条就可以长期保存。如果有恶意行为或被投诉&#xff0c;维基百科可能会对词条进行删除或修改。 二、创建维基百科…

为fooocus v2.5.0安装groundingdino

在win10下折就fooocus&#xff0c;使用git pull命令更新本地&#xff0c;然后…\python_embeded\python.exe -m pip install -r .\requirements_versions.txt更新依赖关系包。 卡在groundingdino的安装上&#xff0c;先在requirements_versions.txt中删除它&#xff0c;安装其他…

第十课:telnet(远程登入)

如何远程管理网络设备&#xff1f; 只要保证PC和路由器的ip是互通的&#xff0c;那么PC就可以远程管理路由器&#xff08;用telnet技术管理&#xff09;。 我们搭建一个下面这样的简单的拓扑图进行介绍 首先我们点击云&#xff0c;把云打开&#xff0c;点击增加 我们绑定vmn…

线程的中断和同步问题

1、自动终断【完成】&#xff1a;一个线程完成执行后&#xff08;即run方法执行完毕&#xff09;&#xff0c;不能再次运行 。 2、手动中断&#xff1a; stop( ) —— 已过时&#xff0c;基本不用。&#xff08;不安全&#xff0c;就像是突然停电&#xff09; interrupt( ) …

VTK----3D picking的原理、类型及实现

目录 3D picking概述 3D射线投射原理 VTK picking框架 vtkPicker(选Actor) vtkPointPicker(选点) vtkCellPicker(选单元) vtkAreaPicker(框选) 3D picking概述 3D picking 是一种在三维场景中确定用户点击或指向的对象的技术。这在3D应用程序和游戏中非常常见,…

CentOS 7 初始化环境配置详细

推荐使用xshell远程连接&#xff0c;如链接不上 请查看 CentOS 7 网络配置 修改主机名 hostname hostnamectl set-hostname xxx bash 关闭 SElinux 重启之后生效 配置yum源&#xff08;阿里&#xff09; 先备份CentOS-Base.repo&#xff0c;然后再下载 mv /etc/yum.repos…

MySQL学习记录 —— 이십이 MySQL服务器日志

文章目录 1、日志介绍2、一般、慢查询日志1、一般查询日志2、慢查询日志FILE格式TABLE格式 3、错误日志4、二进制日志5、日志维护 1、日志介绍 中继服务器的数据来源于集群中的主服务。每次做一些操作时&#xff0c;把操作保存到重做日志&#xff0c;这样崩溃时就可以从重做日志…

STM32(六):STM32指南者-定时器实验

目录 一、基本概念1、常规定时器2、内核定时器 二、基本定时器实验1、实验说明2、编程过程&#xff08;1&#xff09;配置LED&#xff08;2&#xff09;配置定时器&#xff08;3&#xff09;设定中断事件&#xff08;4&#xff09;主函数计数 3、工程代码 三、通用定时器实验实…