昇思25天学习打卡营第16天|应用实践之Vision Transformer图像分类

news2024/9/27 19:15:38

基本介绍

        今天同样是图像分类任务,也更换了模型,使用的时候计算机视觉版的Transformer,即Vision Transformer,简称ViT。Transformer本是应用于自然语言处理领域的模型,用于处理语言序列,而要将其应用于图像,关键在于如何把图像转化为序列数据。ViT对图像进行分割,分割成一个一个的patch,然后再加上空间编码,由此便把图像转化为序列数据,使得Transformer也可以应用于计算机视觉领域。下面会对ViT进行简单介绍,然后使用ImageNet数据集进行训练,简单训练10轮,并进行推理,以进一步了解ViT。

Vision Transformer简介

        ViT模型的主体结构是基于Transformer模型的Encoder部分,结构如下图所示

ViT由Transformer变换而成,而Transformer的核心是Self-Attention,要学习ViT就得搞懂Self-Attention。Self-Attention的核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。具体如下(以下的Self-Attention计算过程来自MindSpore官方教程,并非本人原创):

1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列(𝑥1,𝑥2,𝑥3),其中的𝑥1,𝑥2,𝑥3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。

2. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重

3. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结果。

有了Self-Attention结构之后,通过与Feed Forward,Residual Connection等结构的拼接就可以形成Transformer的基础结构,如下图所示

ViT就是由上述的结构搭建而成。ViT的完整使用流程如下:

对ViT有了基本了解后,我们上手代码,加深理解。ViT(MindSpore版)的代码如下:

class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls') -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

模型训练

       由于数据集准备并不难,所以不做展示,直接使用模型进行训练,训练代码如下:

# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# construct model
network = ViT()

# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"

vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)


# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)

完整训练的话起码有80个轮次,时间太长,再加上我们使用了预训练参数,所以我们只训练10轮

模型验证

        与训练过程相似,首先进行数据增强,然后定义ViT网络结构,加载预训练模型参数。随后设置损失函数,评价指标等,编译模型后进行验证。本案例采用了业界通用的评价标准Top_1_Accuracy和Top_5_Accuracy评价指标来评价模型表现。模型表现如下:

因为预训练参数的原因,效果还是不错的

模型推理

        使用一张杜宾犬的图片进行预测,结果如下,是准确的。

总结

        今日学习使用ViT,若之前对Attention完全没有了解,直接上手难度很大的,不过官方文档写的很好,加上本人有些Transformer的基础,所以认真花费一些时间,结合代码,对ViT的结构和流程有了一个基本了解。ViT可以应用的任务很多,希望下次可以尝试将其应用到目标检测。

Jupyter在线运行情况

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1914402.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

百度网盘青春版网页版上线

不知道还有多少小伙伴记得百度网盘曾经出过一个青春版,原因是21年相关部门发布通知《工业和信息化部关于开展信息通信服务感知提升行动的通知》其中就有一条: 明确指出网盘向免费用户提供的上传和下载最低速率应满足基本的下载需求 正所谓上有政策下有对…

常用的设计模式和使用案例汇总

常用的设计模式和使用案例汇总 【一】常用的设计模式介绍【1】设计模式分类【2】软件设计七大原则(OOP原则) 【二】单例模式【1】介绍【2】饿汉式单例【3】懒汉式单例【4】静态内部类单例【5】枚举(懒汉式) 【三】工厂方法模式【1】简单工厂模式&#xf…

AI绘画小白必备!Stable Diffusion常用插件合集,好用推荐!(附插件下载)

前言 宝子们,早上好啊~Stable Diffusion 常用插件,月月已经给大家整理好了,自取就好。 拥有这些SD常用插件,让您的图像生成和编辑过程更加强大、直观、多样化。以下插件集成了一系列增强功能,覆盖从自动补全提示词到…

设置DepthBufferBits和设置DepthStencilFormat的区别

1)设置DepthBufferBits和设置DepthStencilFormat的区别 2)Unity打包exe后,游戏内拉不起Steam的内购 3)Unity 2022以上Profiler.FlushMemoryCounters耗时要怎么关掉 4)用GoodSky资产包如何实现昼夜播发不同音乐功能 这是…

东旭蓝天被控股股东占用78亿:近七年业绩奇差,或面临退市

《港湾商业观察》施子夫 张楠 在7月5日一口气发了超过30份公告后,终于让投资者对于东旭蓝天2023年和今年一季度经营业绩有了更清晰的观察。 与此同时,东旭蓝天(下称)也收到了深交所的关注函。种种不利因素之下,上市…

【竞技宝 】欧洲杯:赛事水货盘点

本届欧洲杯接近尾声,有些球员抓住机会趁势崛起,踢出了身价。可惜还有一些球员的表现无法让球迷和媒体满意,下面我们就来盘点下本届欧洲杯的水货球员,看看哪些人因为糟糕的表现上榜? 格瓦迪奥尔(克罗地亚) 本届欧洲杯是克罗地亚黄金一代球员的谢幕之战,原本格瓦迪奥尔作为球队…

凌凯科技前五大客户依赖症加剧:研发费用率骤降,应收账款大增

《港湾商业观察》黄懿 6月13日,上海凌凯科技股份有限公司(下称“凌凯科技”)在港交所提交上市申请,拟于主板上市,华泰国际为其独家保荐人。 凌凯科技致力于提供小分子化合物技术和产品解决方案,专注于制药…

探索东芝 TCD1304DG 线性图像传感器的功能

主要特性 高灵敏度和低暗电流 TCD1304DG 具有高灵敏度和低暗电流,非常适合需要精确和可靠图像捕捉的应用。传感器包含 3648 个光敏元件,每个元件尺寸为 8 m x 200 m,确保了出色的光灵敏度和分辨率。 电子快门功能 内置的电子快门功能是 T…

【onnx】onnxruntime-gpu无法使用问题

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 onnxruntime-gpu无法使用 1. 正文 CUDA版本:12.1 nvcc -VCUDNN的版本 cat /usr/include/cudnn_version.h |grep CUDNN_MAJOR -A 2说明: 可…

免费的SSL证书能使用吗

SSL证书为网站提供数据安全加密,保护数据传输,提升用户信任。 现在免费的SSL证书还能使用吗?答案是肯定的。个人博客、个人的网站目前使用免费SSL证书的居多,另外一些单位在网站上线前,也会使用免费SSL证书对网站进行…

品牌策划学习资源全攻略:从入门到精通的推荐清单!

这里再分享一些网站书籍和杂志给大家。 TOPYS创意内容平台: 专注于创意内容分享,涵盖广告、设计、艺术等多个领域,是广告设计人寻找创意灵感的好去处。 Dribbble: 设计师社区,用户可以浏览到全球设计师的优秀作品&…

低代码技术革新:高效构建现代人事管理系统

引言 在快速变化的商业环境中,企业必须不断提升其内部管理效率,以保持竞争力和灵活性。人事管理系统作为企业核心业务系统之一,承担着招聘、培训、绩效管理等重要功能,直接影响着企业的人才管理和运营效率。传统的人事管理系统通常…

Vue核心 — Vue2响应式原理和核心源码解析(核心中的核心)

一、前置知识 1、Vue 核心概念 Vue 是什么? Vue 是一款用于构建用户界面的 JavaScript 框架。它基于标准 HTML、CSS 和 JavaScript 构建,并提供了一套声明式的、组件化的编程模型,帮助你高效地开发用户界面。 Vue 核心特点是什么? 响应式数据绑定:…

Springboot助农农产品销售系统-计算机毕业设计源码16718

摘要 SpringBoot助农农产品销售系统旨在通过利用SpringBoot框架开发一个便捷高效的农产品销售平台。该系统包括用户注册登录、商品浏览、购物车管理、订单生成、支付功能等模块。通过整合支付接口、地图定位、推荐系统等技术,提供给用户更好的购物体验。本文介绍了…

考完软考之后,如何评职称?是否有有效期?

一、软考和职称之间的关系 软考和职称之间的关系可以这样理解:拿到软考证书并不意味着就能获得职称。软考证书是技术等级证书,而职称则是一种资格。如果单位聘用你做工程师,那么你的软考证书就可以发挥作用,相当于获得了职称证。…

私域运营从0到1冷启动

私域社群的冷启动是一个从无到有的过程,需要策略和耐心来吸引并维护用户。以下是一些步骤和策略,可以帮助你的私域社群实现从0到1的冷启动: 1. **明确目标和定位**: - 确定社群的目标用户和他们的需求。 - 明确社群的主题和…

【全面的LangChain入门指南】

🌈个人主页: 程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

简过网:工程专业最吃香的6个证书,你考了几个了?

工程专业最吃香的6个证书,你考了几个了?我们一起来看看吧! 1、二级建造师 报考条件:工程类大专及以上学历/从事相关职业 考试时间:3月报名、6月考试 就业前景:建筑设计院、房产开发公司、施工单位 2、一…

叠纸游戏被“偷跑”的一生

已经数不清叠纸是第几次被偷跑了。 刚刚经历了一次大规模拆包偷跑的叠纸,在7月4日,又遭遇了如出一辙的恶性事件,叠纸旗下的乙女游戏《恋与深空》新男主秦彻再次被偷跑,#秦彻偷跑#、#秦彻建模#等多个话题登上热搜。 同时被偷跑的…

交流调压电路和交流调功电路的区别

交流调压电路和交流调功电路的区别 一、指代不同 1、交来流调压:对单相交流电的电压进行调节的电路。 2、交流调功:是一种以晶闸管(电力电子功率器件)为基础,以智能数字控制电路为核心的电源功率控制电路。 二、原…