基于Vision Transformer的mini_ImageNet图片分类实战

news2024/9/20 7:54:51

【图书推荐】《PyTorch深度学习与计算机视觉实践》-CSDN博客

PyTorch计算机视觉之Vision Transformer 整体结构-CSDN博客

mini_ImageNet数据集简介与下载

mini_ImageNet数据集节选自ImageNet数据集。ImageNet是一个非常有名的大型视觉数据集,它的建立旨在促进视觉识别研究。ImageNet为超过1400万幅图像进行了注释,而且给至少100万幅图像提供了边框。同时,ImageNet包含2万多个类别,比如“气球”“轮胎”和“狗”等类别,ImageNet的每个类别均不少于500幅图像。

训练这么多图像需要消耗大量的资源,为了节约资源,后续的研究者在全ImageNet的基础上提取出了mini_ImageNet数据集。Mini_ImageNet包含100类共60000幅彩色图片,其中每类有600个样本,每幅图片的规格为84×84。通常而言,这个数据集的训练集和测试集的类别划分为80:20。相比于CIFAR-10数据集,mini_ImageNet数据集更加复杂,但更适合进行原型设计和实验研究。

mini_ImageNet的下载也很容易,读者可以使用提供的库包完成对应的下载操作,安装命令如下:

pip install MLclf

Vision Transformer模型设计

下面就是对训练过程的Vision Transformer进行模型设计,在11.1.4节完成的Vision Transformer模型的设计,针对的是224维度大小的图片,而此时使用的是mini版本的ImageNet,因此在维度上会有所变换。本例Vision Transformer模型的完整代码如下:

import torch
from vit import PatchEmbed,Block

class VisionTransformer(torch.nn.Module):
    def __init__(self,num_patches = 1,image_size = 84,patch_size = 14,embed_dim = 588,num_heads = 6,
                 qkv_bias = True,depth = 3,num_class = 64):
        super().__init__()

        #初始化PatchEmbed层
        self.patch_embed  = PatchEmbed(img_size = image_size,patch_size=patch_size,embed_dim=embed_dim)
        #增加一个作为标志物的参数
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))

        #建立位置向量,计算embedding的长度
        self.num_tokens = (image_size * image_size) // (patch_size * patch_size)
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))

        #这里在使用block模块时采用了指针的方式,注意*号
        self.blocks = torch.nn.Sequential(
            *[Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.0, qkv_bias=qkv_bias) for _ in range(depth)]
        )
        #最终的logits推断层
        self.logits_layer = torch.nn.Sequential(torch.nn.Linear(embed_dim, 512),torch.nn.GELU(),torch.nn.Linear(512, num_class))

    def forward(self,x):

        embedding = self.patch_embed(x)

        #添加标志物
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        embedding = torch.cat((cls_token, embedding), dim=1)  #[B, 197, 768]
        embedding += self.pos_embed

        embedding = self.blocks(embedding)

        embedding = embedding[:,0]
        embedding = torch.nn.Dropout(0.1)(embedding)
        logits = self.logits_layer(embedding)
        return logits

if __name__ == '__main__':
    image = torch.randn(size=(2,3,84,84))
    VisionTransformer()(image)

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1948688.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

旗晟机器人仪器仪表识别AI智慧算法

在当今迅猛发展的工业4.0时代,智能制造和自动化运维已然成为工业发展至关重要的核心驱动力。其中智能巡检运维系统扮演着举足轻重的角色。工业场景上不仅要对人员行为监督进行监督,对仪器仪表识别分析更是不可缺少的一个环节。那么我们说说旗晟仪器仪表识…

AI模型大比拼:Claude 3系列 vs GPT-4系列最新模型综合评测

AI模型大比拼:Claude 3系列 vs GPT-4系列最新模型综合评测 引言 人工智能技术的迅猛发展带来了多款强大的语言模型。本文将对六款领先的AI模型进行全面比较:Claude 3.5 Sonnet、Claude 3 Opus、Claude 3 Haiku、GPT-4、GPT-4o和GPT-4o Mini。我们将从性能…

【Gin】精准应用:Gin框架中工厂模式的现代软件开发策略与实施技巧(下)

【Gin】精准应用:Gin框架中工厂模式的现代软件开发策略与实施技巧(下) 大家好 我是寸铁👊 【Gin】精准应用:Gin框架中工厂模式的现代软件开发策略与实施技巧(下)✨ 喜欢的小伙伴可以点点关注 💝 前言 本次文章分为上下两部分&…

智能家居全在手机端进行控制,未来已来!

未来触手可及:智能家居,手机端的全控时代 艾斯视觉的观点是:在不远的将来,家,这个温馨的港湾,将不再只是我们休憩的场所,而是科技与智慧的结晶。想象一下,只需轻触手机屏幕&#xf…

如何实现CPU最大处理效率

如何实现CPU最大处理效率 CPU,或称为中央处理器,是计算机中负责执行指令和处理数据的核心部件。它的工作原理可简单概括为"取指、译码、执行、存储"四个步骤,也称为计算机的指令周期。 取指(Fetch):在取指阶段,CPU从内存中获取下一条要执行的指令,并存放在指…

回顾网络路,心率就过速

笔者上网写作已满16年,其间加盟过国内互联网的知名网站自媒体至少在40至50家之多,但由于有的被已被勒令停刊了(如《天涯论坛》),有的则因其改版而只保留了极少数擅于唱颂的写手(如《强国论坛》)…

【SpringCloud】企业认证、分布式事务,分布式锁方案落地-1

目录 HR企业入驻 HR企业入驻 - 认证流程解析 HR企业入驻 - 查询企业是否存在 HR企业入驻 - 上传企业logo与营业执照 HR企业入驻 - 新企业(数据字典与行业tree结构解析) 行业tree 行业tree - 创建节点 行业tree - 查询一级分类 行业tree - 查询子分…

计算存储背景与发展

随着云计算、企业级应用以及物联网领域的飞速发展,当前的数据处理需求正以前所未有的规模增长,以满足存储行业不断变化的需求。这种增长导致网络带宽压力增大,并对主机计算资源(如内存和CPU)造成极大负担,进…

Redis的使用场景——热点数据缓存

热点数据缓存 Redis的使用场景——热点数据的缓存 1.1 什么是缓存 为了把一些经常访问的数据,放入缓存中以减少对数据库的访问效率,从而减少数据库的压力,提高程序的性能。【在内存中存储】 1.2 缓存的原理 查询缓存中是否存在对应的数据如…

05 capture软件创建元器件库(以STM32为例)

05 创建元器件库_以STM32为例 一、新建原理图库文件二、新建器件三、开始创建元器件 一些IC类元件,需要自己创建元器件库。 先看视频,然后自己创建STM32F103C8T6的LQFP48的元器件。 STM32F103C8T6是目前为止,自己用的最多的芯片。 先要有数据…

nodejs安装及环境配置建材商城管理系统App

✌网站介绍:✌10年项目辅导经验、专注于计算机技术领域学生项目实战辅导。 ✌服务范围:Java(SpringBoo/SSM)、Python、PHP、Nodejs、爬虫、数据可视化、小程序、安卓app、大数据等设计与开发。 ✌服务内容:免费功能设计、免费提供开题答辩P…

文件包涵条件竞争(ctfshow82)

Web82 利用 session.upload_progress 包含文件漏洞 <!DOCTYPE html> <html> <body> <form action"https://09558c1b-9569-4abd-bf78-86c4a6cb6608.challenge.ctf.show//" method"POST" enctype"multipart/form-data"> …

C语言的发展过程介绍

引言 C语言&#xff0c;由丹尼斯里奇&#xff08;Dennis Ritchie&#xff09;在20世纪70年代初期于贝尔实验室开发&#xff0c;是计算机科学史上最具影响力的编程语言之一。本文将概述C语言的发展历程&#xff0c;并提供一些代码示例来展示其演变。 起源&#xff1a;UNIX和C语言…

自动化测试--WebDriver API

1. 元素定位方法 通过 ID 定位&#xff1a;如果元素具有唯一的 ID 属性&#xff0c;可以使用 findElement(By.id("elementId")) 方法来定位元素。通过 Name 定位&#xff1a;使用 findElement(By.name("elementName")) 来查找具有指定名称的元素。通过 Cl…

重生之“我打数据结构,真的假的?”--5.堆(无习题)

1.堆的概念与结构 如果有⼀个关键码的集合 &#xff0c;把它的所有元素按完全⼆叉树的顺序存储⽅ 式存储&#xff0c;在⼀个⼀维数组中&#xff0c;并满⾜&#xff1a; &#xff08; 且 &#xff09;&#xff0c; i 0、1、2... &#xff0c;则称为⼩堆(或⼤堆)。将根结点最⼤的…

逻辑处理模块:FPGA复旦微JFM7VX690T36+网络加速器:雄立XC13080-500C

逻辑处理模块通常是指在计算机系统、软件应用或电子设备中负责执行逻辑运算和决策过程的组件。 在不同的领域和技术中&#xff0c;逻辑处理模块可能有不同的实现方式和名称&#xff0c;但它们的核心功能都是基于输入数据进行逻辑判断和处理&#xff0c;并产生相应的输出结果。下…

GO-学习-03-基本数据类型

数据类型&#xff1a;基本数据类型和复合数据类型 基本数据类型&#xff1a;整型、浮点型、布尔型、字符串 复合数据类型&#xff1a;数组、切片、结构体、函数、map、通道&#xff08;channel&#xff09;、接口 整型&#xff1a; package main import "fmt" im…

react-native从入门到实战系列教程一环境安装篇

充分阅读官网的环境配置指南&#xff0c;严格按照他的指导作业&#xff0c;不然你一直只能在web或沙箱环境下玩玩 极快的网络和科学上网&#xff0c;必备其中的一个较好的心理忍受能力&#xff0c;因为上面一点就可以让你放弃坚持不懈&#xff0c;努力尝试 成功效果 三大件 …

「Unity3D」场景中的距离单位Unit与相关设置PixelsToUnits、PixelsPerUnit

GameObject在场景的位置Position&#xff0c;并没有明确是什么具体单位——如&#xff1a;Transform的x、y、z&#xff0c;或RectTransform的PosX、PosY、PosZ。而RectTransform在面板上显示的Width和Height&#xff0c;也没有具体单位&#xff0c;其实并不是像素。 事实上&am…

python+vue3+onlyoffice在线文档系统实战20240725笔记,首页开发

解决遗留问题 内容区域的高度没有生效&#xff0c;会随着菜单的高度自动变化。 解决方案&#xff1a;给侧边加上一个最小高度。 首页设计 另一种设计&#xff1a; 进来以后&#xff0c;是所有的文件夹和最近的文件。 有一张表格&#xff0c;类似于Windows目录详情&…