MobileViT神经网络模型

news2024/11/16 19:24:26

  • 官方源码(Pytorch实现) : https://github.com/apple/ml-cvnets
  •  原文链接:https://blog.csdn.net/qq_37541097/article/details/126715733

  • 霹雳吧啦Wz从ml-evnets仓库中剥离的代码: deep-learning-for-image-processing/pytorch_classification/MobileViT at master · WZMIAOMIAO/deep-learning-for-image-processing · GitHub
  • MobileViT对应博文: MobileViT模型简介_太阳花的小绿豆的博客-CSDN博客

1. Transformer模型存在的问题

  • Transformer参数多,算力要求高,很难部署到移动端。
  • Transformer缺少空间偏置。计算某个token的attention时如果将其他token的顺序打乱对最终结果没有任何影响。但在图像数据中,空间信息是很重要且有意义的。为了解决这个问题,常见的方法是加上位置偏置(position bias)/位置编码,比如Vision Transformer中使用的绝对位置偏置,Swin Transformer中的相对位置偏置,加上位置偏置虽然在一定程度上解决了空间位置的信息丢失的问题,但又引入了一个新的问题。迁移到别的任务上时,位置偏执信息往往需要调整。
  • Transformer迁移到其他任务(输入图像分辨率发生改变)比较繁琐
  • Transformer模型很难训练(需要更多的训练数据,需要迭代更多的epoch,需要更大的正则项(L2正则),需要更多的数据增强(且对数据增强很敏感))

2. Vision Transformer结构

MobileViT论文中绘制的Standard visual Transformer。首先将输入的图片划分成一个个Patch,然后通过线性变化将每个Patch映射到一个一维向量中(视为一个个Token),接着加上位置偏置信息(可学习参数),再通过一系列Transformer Block,最后通过一个全连接层得到最终预测输出。

 2. MobileViT结构

如图 b 所示的 MobileViT 块的作用是使用包含较少参数的输入张量学习局部和全局信息。MobileViT 应用一个 n×n 标准卷积层,然后是逐点(1×1)卷积层来特征提取。n×n 卷积层编码局部空间信息,而逐点卷积通过学习输入通道的线性组合将张量投影到高维空间。

总而言之, MobileViT 使用标准卷积和 transformer 分别学习局部和全局表示,使得MobileViT 既具有类似卷积的属性,又同时允许全局处理。

MobileViT结构:普通卷积,MV2(MobiletNetV2中的Inverted Residual block),MobileViT block,全局池化以及全连接层共同组成。 

mobileNetV2

                                               

Expansion layer是使用1x1卷积将低维空间映射到高维空间(扩大通道数),这里Expansion有一个超参数是维度扩展几倍,可以根据实际情况来做调整的,默认值是6,也就是扩展6倍。

Projection layer也是使用1x1卷积,他的目的是希望把高维特征映射到低维空间去(减小通道数)。需要注意的是residual connection是在输入和输出的部分进行连接。另外,前面已经说过,因为从高维向低维转换,使用ReLU激活函数可能会造成信息丢失或破坏,所以在Projection convolution这一部分,我们不再使用ReLU激活函数而是使用线性激活函数。

3.模型代码

class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
    
    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)
        
        # Global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x
import torch
import torch.nn as nn

from einops import rearrange


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))
    
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MV2Block(nn.Module):
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
    
    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)
        
        # Global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x


class MobileViT(nn.Module):
    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        L = [2, 4, 3]

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))   # Repeat
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
        
        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0]*2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1]*4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2]*4)))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        self.pool = nn.AvgPool2d(ih//32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.mv2[0](x)

        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)      # Repeat

        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)

        x = self.pool(x).view(-1, x.shape[1])
        x = self.fc(x)
        return x


def mobilevit_xxs(img_size=(256, 256), num_classes=1000):
    dims = [64, 80, 96]
    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
    return MobileViT((img_size[0], img_size[1]), dims, channels, num_classes=num_classes, expansion=2)


def mobilevit_xs(img_size=(256, 256), num_classes=1000):
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
    return MobileViT((img_size[0], img_size[1]), dims, channels, num_classes=num_classes)


def mobilevit_s(img_size=(256, 256), num_classes=1000):
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
    return MobileViT((img_size[0], img_size[1]), dims, channels, num_classes=num_classes)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == '__main__':
    img = torch.randn(5, 3, 256, 256)

    vit = mobilevit_xxs(img_size=(256, 256))
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

    vit = mobilevit_xs()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

    vit = mobilevit_s()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

注意,输入图片不能是224x224,因为reshape时会报错,可以设置成256x256或其他。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/487126.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AWE 2023:科技与艺术的结晶 三星展台亮点回顾

2023年4月27~30日,AWE 2023中国家电及消费电子博览会在上海新国际博览中心盛大举行。 作为全球三大消费电子展之一,每一年的AWE都汇聚了全球家电及消费电子领域最前沿、最尖端的科技和产品,因而向来都被业界人士视为整个行业的风向标。本届AW…

【读书笔记】《深入浅出数据分析》

我最大的收获 试想你在经历一场英语考试,还有两分钟就要交卷了,而你还没有开始写作文。此时,你会怎么做? 利用2分钟时间写出的第一段,还是只写关键句子,搭出文章的开头、过程、结尾? 后者更加明…

【LeetCode】91. 解码方法

91. 解码方法(中等) 思路 这其实是一道字符串类的动态规划题,不难发现对于字符串s的某个位置i而言,我们只关心「位置 i 自己能否形成独立 item」和「位置 i 能够与上一位置(i-1)能否形成item 」&#xff0c…

无人机测试二维码降落。

一、首先要做的就是让NX板卡驱动usb摄像头: 1. 下载usb_cam软件包 sudo apt install ros-melodic-usb-cam2. 启动相机节点: A. 查找摄像头接口 ls /dev/video*B. 在该路径下修改launch文件,换成你自己的摄像头接口 我这里的摄像头接口为…

PySpark基础入门(3):RDD持久化

RDD的持久化 RDD 的数据是过程数据,因此需要持久化存储; RDD之间进行相互迭代的计算,新的RDD的生成代表着旧的RDD的消失;这样的特性可以最大化地利用资源,老旧地RDD可以及时地从内存中清理,从而给后续地计…

电脑发挥极致,畅游永恒之塔sf

随着22寸显示器的普及,玩永恒之塔势必会对显示卡造成了很大负担。不要说效果全开,就连简洁的玩,都成了问题,那是不是就要重金把才买的显示卡又要拿掉呢? 最出众的解决办法,是超频。 主要就具有以下条件最佳…

Azure DevOps Server 数据还原方式三:增量还原

Contents 1. 概述2. 操作方式 2.1 创建共享文件夹,并将备份文件复制到共享文件夹中2.2 还原数据3 验证还原的数据库 3.1 方式一:核对工作项所在的表的数据3.2 方式二:将数据库配置到应用层,在应用中验证数据4. 常见问题&#xff1…

FAST协议解析1 通过输入输出逆解析

一、前言 FAST协议可以支持金融机构间高吞吐量、低延迟的数据通讯,目前我知道的应用领域是沪深交易所的Level-2行情传输。网络上无论是FAST协议本身,还是使用相关工具(openfast、quickfast)对FAST行情进行解析,相关的…

MC9S12G128开发板—解决小车九宫格方位移动功能实现遇到的一些问题

接着我的上一篇文章:MC9S12G128开发板—实现按键发送CAN报文指示小车移动功能。本篇文章主要记录下在实现小车九宫格方位移动功能过程中,遇到的一些程序问题以及解决措施。 1. 上位机小车响应开发板按键CAN报文指令的响应出错问题 问题现象描述&#x…

自动驾驶行业观察之2023上海车展-----车企发展趋势(1)

新势力发展趋势 小鹏汽车:发布新车G6(中型SUV),将于2023年年中上市 发布新车G6:车展上,小鹏G6正式首发亮相,定位中型SUV,对标Tesla Model Y,将于2023年年中上市并开始交…

基于web的商场商城后台管理系统

该系统用户分为两类:普通员工和管理员。普通员工是指当前系统中的需要对商品和客户的信息进行查询的人。此类用户只能查看自己的信息,以及对商品和客户的信息进行查看。管理员用户可以对自己和他人的信息进行维护,包括对商品入库、销售、库存…

Redis缓存过期淘汰策略

文章目录 1、如何设置 Redis 最大运行内存?2、过期删除策略3、内存淘汰策略 1、如何设置 Redis 最大运行内存? 在配置文件 redis.conf 中,可以通过参数 maxmemory 来设定最大运行内存,只有在 Redis 的运行内存达到了我们设置的最…

代码命名规范

日常编码中,代码的命名是个大的学问。能快速的看懂开源软件的代码结构和意图,也是一项必备的能力。那它们有什么规律呢? Java项目的代码结构,能够体现它的设计理念。Java采用长命名的方式来规范类的命名,能够自己表达…

消息称苹果Type-C口充电未设MFi限制,iOS17将更新Find My服务

根据国外科技媒体 iMore 报道,基于消息源 analyst941 透露的信息,苹果公司目前并未开发 MFi 限制。 根据推文信息内容,两款 iPhone 15 机型的最高充电功率为 20W,而 iPhone 15 Pro 机型的最高支持 27W 充电。 此前古尔曼表示苹…

Python趋势外推预测模型实验完整版

趋势外推预测模型实验完整版 实验目的 通过趋势外推预测模型(佩尔预测模型),掌握预测模型的建立和应用方法,了解趋势外推预测模型(佩尔预测模型)的基本原理 实验内容 趋势外推预测模型 实验步骤和过程…

第4章-虚拟机栈(多使用到jclasslib工具查看字节码)

虚拟机栈 简介 虚拟机栈的出现背景 由于跨平台性的设计,Java的指令都是根据栈来设计的。不同平台CPU架构不同,所以不能设计为基于寄存器的【如果设计成基于寄存器的,耦合度高,性能会有所提升,因为可以对具体的CPU架…

警惕免杀版Gh0st木马!

https://github.com/SecurityNo1/Gh0st2023 经过调查发现,这款开源的高度免杀版Gh0st木马目前正在大范围传播,据称可免杀多种主流杀软:开发者不仅制作了新颖的下载页面,还设法增加了搜索引擎的收录权重,吸引了许多免…

Python基础合集 练习17(类与对象)

class Dog: pass papiDog() print(papi) print(type(papi)) 构建方法 创建类过后可以定义一个特殊的方法。在python中构建方法是__init__(),init()必须包含一个self参数 class pig(): #def__init__(self) -> None: print(‘你好’) pipgpig() 属性和方法 cl…

JDBC详解(六):数据库事务(超详解)

JDBC详解(六):数据库事务(超详解) 前言一、数据库事务介绍二、JDBC事务处理三、事务的ACID属性1、数据库的并发问题2、四种隔离级别3、在MySql中设置隔离级别 前言 本博主将用CSDN记录软件开发求学之路上亲身所得与所…

MySQL基础——数据模型·数据库操作

♥️作者:小刘在C站 ♥️个人主页:小刘主页 ♥️每天分享云计算网络运维课堂笔记,努力不一定有收获,但一定会有收获加油!一起努力,共赴美好人生! ♥️树高千尺,落叶归根人生不易&…