VIT 和Swin Transformer

news2025/1/12 9:47:49

VIT:https://blog.csdn.net/qq_37541097/article/details/118242600
Swin Transform:https://blog.csdn.net/qq_37541097/article/details/121119988
一、VIT
模型由三个模块组成:
Linear Projection of Flattened Patches(Embedding层)
Transformer Encoder(图右侧有给出更加详细的结构)
MLP Head(最终用于分类的层结构)
在这里插入图片描述
Embedding模块:
ViT-B/16为例,每个token向量长度为768。要求输入的token必须是二维的。需要把三维的图片信息转成二维。
以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。
还要有一个用于分类的token,长度与其他token保持一致。与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]。
Transformer Encoder模块:

vit使用
总结构

class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self, nattr=1, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()

        self.nattr = nattr
        self.use_checkpoint = use_checkpoint
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)    ###第一步
        num_patches = self.patch_embed.num_patches

        # modify
        # self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, self.nattr, embed_dim))     ##创建类别token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.nattr, embed_dim))    ##总的token
        self.pos_drop = nn.Dropout(p=drop_rate)    ##使用Dropout

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        # self.repr = nn.Linear(embed_dim, representation_size)
        # self.repr_act = nn.Tanh()

        # Classifier head
        # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)  # (bt, num_patches + nattr, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:

            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.norm(x)
        # return x[:, :self.nattr]
        return x[:, 1:]

第一步Embedding层,相当于一层卷积

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.num_x = img_size[1] // patch_size[1]  # 28
        self.num_y = img_size[0] // patch_size[0]

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

第二步+第三步,Transformer Encoder+MLP Head

self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])   ##创建12个Block,每个Block都是:归一化+attention+dropout+归一化+mlp(2个fc层)。
            
class Mlp(nn.Module):  ##两个全连接层
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
        
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)   ##层归一化,LayerNorm
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)   ##注意力模块,需要设置头个数
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

最后一步,搭建分类器:

@CLASSIFIER.register("linear")
class LinearClassifier(BaseClassifier):
    def __init__(self, nattr, c_in, bn=False, pool='avg', scale=1):
        super().__init__()

        self.pool = pool
        if pool == 'avg':
            self.pool = nn.AdaptiveAvgPool2d(1)
        elif pool == 'max':
            self.pool = nn.AdaptiveMaxPool2d(1)

        self.logits = nn.Sequential(
            nn.Linear(c_in, nattr),
            nn.BatchNorm1d(nattr) if bn else nn.Identity()
        )


    def forward(self, feature, label=None):

        if len(feature.shape) == 3:  # for vit (bt, nattr, c)

            bt, hw, c = feature.shape
            # NOTE ONLY USED FOR INPUT SIZE (256, 192)
            h = 16
            w = 12
            feature = feature.reshape(bt, h, w, c).permute(0, 3, 1, 2)    ##(32,768,16,12)

        feat = self.pool(feature).view(feature.size(0), -1)    ##(32,768)
        x = self.logits(feat)    ##(32,num_class)

        #return [x],feature,feat
        return [x], feature
classifier = build_classifier(cfg.CLASSIFIER.NAME)(
    nattr=train_set.attr_num,
    c_in=c_output,
    bn=cfg.CLASSIFIER.BN,
    pool=cfg.CLASSIFIER.POOLING,
    scale =cfg.CLASSIFIER.SCALE
)

model = FeatClassifier(backbone, classifier, bn_wd=cfg.TRAIN.BN_WD)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/945309.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

中文情感分类

本文通过ChnSentiCorp数据集介绍了文本分类任务过程,主要使用预训练语言模型bert-base-chinese直接在测试集上进行测试,也简要介绍了模型训练流程,不过最后没有保存训练好的模型。 一.任务和数据集介绍 1.任务 中文情感分类本质还是一个文本…

会员管理系统实战开发教程05-会员开卡

上一篇我们讲解了如何点击按钮弹出层,已经罗列了会员管理的一些常见功能。本篇我们介绍一下会员开卡的业务。 1 创建变量 我们会员开卡的业务的话,也是要在本页面弹出,弹出其实只是让组件是否显示和隐藏,我们先定义一个布尔值类…

柠檬水找零【贪心算法-】

柠檬水找零 在柠檬水摊上,每一杯柠檬水的售价为 5 美元。顾客排队购买你的产品,(按账单 bills 支付的顺序)一次购买一杯。 每位顾客只买一杯柠檬水,然后向你付 5 美元、10 美元或 20 美元。你必须给每个顾客正确找零&…

读书笔记-《ON JAVA 中文版》-摘要23[第二十章 泛型-2]

文章目录 第二十章 泛型5. 泛型擦除5.1 泛型擦除5.2 迁移兼容性5.3 擦除的问题5.4 边界处的动作 6. 补偿擦除7. 边界8. 通配符8.1 通配符8.2 逆变 9. 问题10. 动态类型安全11. 泛型异常 第二十章 泛型 普通的类和方法只能使用特定的类型:基本数据类型或类类型。如果…

驱动代码验证

要求 代码 #include <stdio.h> #include <unistd.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <stdlib.h> #include <string.h> #include <sys/ioctl.h>int main(int argc,const char * a…

移动端h5项目的兼容和适配问题

解决兼容性问题的关键在于对移动端产品的生存环境进行梳理&#xff0c;在此基础之上制定应对策略。 所谓生存环境主要分为三个维度&#xff1a; 硬件环境&#xff0c;细分为品牌和机型&#xff0c;决定了屏幕大小、性能等硬件限制 操作系统&#xff0c;比如iOS6和iOS7&#xf…

Redis数据类型(list\set\zset)

"maybe its why" List类型 列表类型是⽤来存储多个有序的字符串&#xff0c;列表中的每个字符串称为元素&#xff08;element&#xff09;&#xff0c;⼀个列表最多可以存储个2^32 - 1个元素。在Redis中&#xff0c;可以对列表两端插⼊&#xff08;push&#xff09…

555定时器

一、定义 定时器是一种多用途的数字-模拟混合集成电路&#xff0c;可极方便的构成施密特触发器、单稳态触发器和多谐振荡器&#xff0c;其简化原理图及引脚定义如下所示 3个绿色电阻&#xff0c;电阻值为5K&#xff1b;2个黄色和粉色比较器&#xff1b;1个紫色SR触发器&#x…

WPF实战项目十三(API篇):备忘录功能api接口、优化待办事项api接口

1、新建MenoDto.cs /// <summary>/// 备忘录传输实体/// </summary>public class MenoDto : BaseDto{private string title;/// <summary>/// 标题/// </summary>public string Title{get { return title; }set { title value; }}private string con…

实验三十一、OCL 电路输出功率和效率的研究

一、题目 研究 OCL 功率放大电路的输出功率和效率。 二、仿真电路 OCL 功率放大电路如图1所示。 图 1 OCL 功率放大电路 图1\,\,\,\textrm{OCL}\,功率放大电路 图1OCL功率放大电路图中采用 NPN 型低频功率晶体管 2SC2001&#xff0c;其参数为&#xff1a; I C M 700 mA I_…

5G NR:RACH流程 -- Msg1之选择正确的PRACH时频资源

PRACH的时域资源是如何确定的 PRACH的时域资源主要由参数“prach-ConfigurationIndex”决定。拿着这个参数的取值去协议38211查表6.3.3.2-2/3/4&#xff0c;需要注意根据实际情况在这三张表中进行选择&#xff1a; FR1 FDD/SULFR1 TDDFR2 TDD Random access preambles can onl…

信号和槽的相关操作

目录 信号和槽 connect()函数 自定义信号槽 例子 自定义信号槽需要注意的事项 信号槽的更多用法 Lambda表达式 ① 函数对象参数 ② 操作符重载函数参数 ③ 可修改标示符 ④ 错误抛出标示符 ⑤ 函数返回值 ⑥ 是函数体 所谓信号槽&#xff0c;实际就是观察者模式。当…

AVS3变换:PBT、ST和SBT

前面的文章介绍了AVS3中的变换工具IST和ISTS&#xff0c;本文将介绍AVS3中剩余的几种变换工具&#xff1a;基于位置的变换&#xff08;PBT,Position Based Transform&#xff09;、二次变换&#xff08;ST, Secondary Transform&#xff09;和子块变换&#xff08;SBT, Sub-Blo…

SmartInspect Professional .Net Delphi Crack

SmartInspect Professional .Net & Delphi Crack SmartInspect Professional是一个用于调试和跟踪.NET、Java和Delphi软件的高级日志记录工具。它使您能够识别错误&#xff0c;找到客户问题的解决方案&#xff0c;并让您清楚地了解软件在不同环境和条件下的工作方式。可以轻…

给oracle逻辑导出clob大字段、大数据量表提提速

文章目录 前言一、大表数据附&#xff1a;查询大表 二、解题思路1.导出排除大表的数据2.rowid切片导出大表数据Linux代码如下&#xff08;示例&#xff09;&#xff1a;Windows代码如下&#xff08;示例&#xff09;&#xff1a;手工执行代码如下&#xff08;示例&#xff09;&…

java八股文面试[多线程]——Synchronized的底层实现原理

笔试&#xff1a;画出Synchronized 线程状态流转实现原理图 synchronized关键字解决的是多个线程之间访问资源的同步性&#xff0c;synchronized 翻译为中文的意思是同步&#xff0c;也称之为”同步锁“。 synchronized的作用是保证在同一时刻&#xff0c; 被修饰的代码块或方…

任意文件上传

文章目录 渗透测试漏洞原理任意文件上传1. 任意文件上传概述1.1 漏洞成因1.2 漏洞原理1.3 漏洞危害1.4 漏洞的利用方法1.5 漏洞的验证 2. WebShell解析2.1 Shell2.1.1 命令解释器 2.2 WebShell2.2.1 大马2.2.2 小马2.2.3 GetShell2.2.4 WebShell项目 3. 任意文件上传攻防3.1 毫…

注册字符设备

五、注册字符设备 struct cdev {struct kobject kobj;//表示该类型实体是一种内核对象struct module *owner;//填THIS_MODULE&#xff0c;表示该字符设备从属于哪个内核模块const struct file_operations *ops;//指向空间存放着针对该设备的各种操作函数地址struct list_head …

RAD Installer Crack,集成到RAD Studio IDE支持

RAD & Installer Crack,集成到RAD Studio IDE支持 用于创建NSIS和Inno Setup安装程序的RAD Studio扩展。它将NSIS(Nullsoft Scriptable Install System)和Inno Setup与Embarcadero RAD Studio IDE结合在一起。它允许您在RAD Studio中设计和构建NSIS和Inno Setup项目&#x…

错误的迷宫:探索开发中的异常管理之旅

引言&#xff1a;为什么我们需要谈论错误处理&#xff1f; 在软件开发的世界中&#xff0c;错误是不可避免的。它们是我们编程旅程中的挑战&#xff0c;但也是我们成长的机会。正确地处理错误不仅可以确保软件的稳定性和可靠性&#xff0c;还可以为开发者提供宝贵的反馈。本文…