ViT的极简pytorch实现及其即插即用

news2024/12/23 4:47:17

先放一张ViT的网络图
在这里插入图片描述
可以看到是把图像分割成小块,像NLP的句子那样按顺序进入transformer,经过MLP后,输出类别。每个小块是16x16,进入Linear Projection of Flattened Patches, 在每个的开头加上cls token和位置信息,也就是position embedding。
去掉数据读取部分,直接上一个极简的ViT代码:

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)## 对tensor张量分块 x :1 197 1024   qkv 最后是一个元祖,tuple,长度是3,每个元素形状:1 197 1024
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)   # 224*224
        patch_height, patch_width = pair(patch_size)   # 16 * 16

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            # (b,3,224,224) -> (b,196,768)    14*14=196  16*16*3=768
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),    # (b,196,1024)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)        # img 1 3 224 224  输出形状x : 1 196 1024
        b, n, _ = x.shape                       # 1 196

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)    # (1,1,1024)
        x = torch.cat((cls_tokens, x), dim=1)   # (1,197,1024)
        x += self.pos_embedding[:, :(n + 1)]    # (1,197,1024)
        x = self.dropout(x)                     # (1,197,1024)

        x = self.transformer(x)                 # (1,197,1024)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]     # (1,1024)

        x = self.to_latent(x)      # (1,1024)
        return self.mlp_head(x)    # (1,1000)


if __name__ == '__main__':
    v = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 224, 224)

    preds = v(img)        # (1, 1000)

    print(preds.shape)

去掉cls和最后的全连接分类头,变成即插即用的模块:

import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)## 对tensor张量分块 x :1 197 1024   qkv 最后是一个元祖,tuple,长度是3,每个元素形状:1 197 1024
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, dim = 1024, depth = 3, heads = 16, mlp_dim = 2048, dim_head = 64, dropout = 0.1, emb_dropout = 0.1):
        super().__init__()
        channels, image_height, image_width = image_size   # 256,64,80
        patch_height, patch_width = pair(patch_size)       # 4*4

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)     # 16*20
        patch_dim = 64 * patch_height * patch_width    # 64*8*10

        self.conv1 = nn.Conv2d(256, 64, 1)

        self.to_patch_embedding = nn.Sequential(
            # (b,64,64,80) -> (b,320,1024)    16*20=320  4*4*64=1024
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),    # (b,320,1024)
        )

        self.to_img = nn.Sequential(
            # b c (h p1) (w p2) -> (b,64,64,80)      16*20=320  4*4*64=1024
            Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', \
                      p1 = patch_height, p2 = patch_width, h = image_height // patch_height, w = image_width // patch_width),
            nn.Conv2d(64, 256, 1),      # (b,64,64,80) -> (b,256,64,80)
        )
        # 位置编码
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    def forward(self, img):
        x = self.conv1(img)                     # img 1 256 64 80 -> 1 64 64 80
        x = self.to_patch_embedding(x)          # 1 320 1024
        b, n, _ = x.shape                       # 1 320

        x += self.pos_embedding[:, :(n + 1)]    # (1,320,1024)
        x = self.dropout(x)                     # (1,320,1024)

        x = self.transformer(x)                 # (1,320,1024)

        x = self.to_img(x)

        return x                                # (1 256 64 80)


if __name__ == '__main__':

    v = ViT(image_size = (256,64,80), patch_size = 4)

    img = torch.randn(1, 256, 64, 80)

    preds = v(img)         # (1, 256, 64, 80)

    print(preds.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1343915.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云原生|对象存储|minio分布式集群的搭建和初步使用(可用于生产)

前言: minio作为轻量级的对象存储服务安装还是比较简单的,但分布式集群可以大大提高存储的安全性,可靠性。分布式集群是在单实例的基础上扩展而来的 minio的分布式集群有如下要求: 所有运行分布式 MinIO 的节点需要具有相同的访…

【Java 进阶篇】Redis 缓存优化:提升应用性能的不二选择

在现代的软件开发中,性能一直是开发者们追求的目标之一。对于数据库访问频繁、数据读取较慢的场景,使用缓存是提升性能的有效手段之一。而 Redis 作为一款高性能的内存数据库,被广泛用作缓存工具。本文将围绕 Redis 缓存优化进行详解&#xf…

《深入理解JAVA虚拟机笔记》并发与线程安全原理

除了增加高速缓存之外,为了使处理器内部的运算单元能尽量被充分利用,处理器可能对输入代码进行乱序执行(Out-Of-Order Execution)优化。处理器会在计算之后将乱序执行的结果重组,保证该结果与顺序执行的结果一致&#…

三台CentOS7.6虚拟机搭建Hadoop完全分布式集群(三)

这个是笔者大学时期的大数据课程使用三台CentOS7.6虚拟机搭建完全分布式集群的案例,已成功搭建完全分布式集群,并测试跑实例。 9 安装hbase 温馨提示:安装hbase先在master主节点上配置,然后远程复制到slave01或slave02 &#xf…

远程网络唤醒家庭主机(openwrt设置)

远程网络唤醒家庭主机(openwrt设置) 前提: 1.配置好主板bios的网络唤醒功能(网络教程自己百度一下找) 2.电脑开启网络唤醒功能(网络教程自己百度一下找) 3.路由器通过ddns实现域名和动态IP绑定内网穿透方法汇总_不修改光猫进行内网穿透-C…

【办公软件】Excel双坐标轴图表

在工作中整理测试数据,往往需要一个图表展示两个差异较大的指标。比如共有三个数据,其中两个是要进行对比的温度值,另一个指标是两个温度的差值,这个差值可能很小。 举个实际的例子:数据如下所示,NTC检测温…

JavaScript中实现页面跳转的几种常用方法

Hi i,m JinXiang ⭐ 前言 ⭐ 本篇文章主要介绍在JavaScript中实现页面跳转的几种常用方法以及部分理论知识 🍉欢迎点赞 👍 收藏 ⭐留言评论 📝私信必回哟😁 🍉博主收将持续更新学习记录获,友友们有任何问题…

2024收入最高的编程语言

我的新书《Android App开发入门与实战》已于2020年8月由人民邮电出版社出版,欢迎购买。点击进入详情 1.Python Python 是最流行、用途最广泛的语言之一。它通常用于网络开发、数据科学、机器学习等。 以下是 Python 编程语言的一些主要用途: Web 开发&…

【AIGC科技展望】预测AIGC2025年的机会与挑战

2025年,AIGC的机会与挑战 在未来的五年里,AIGC(AI Generated Content)将会成为一个越来越重要的领域。但是,伴随着机会而来的是挑战。在这篇文章中,我们将一起探讨AIGC的机会与挑战,并预测2025…

WinForm开发 - C# RadioButton(单选框) 设置默认选中或取消默认选中

WinForm开发中RadioButton组件使用过程中的小技巧。 1、属性界面操作 如果有多个组件,希望不显示默认选中单选框只需要将其Checked属性全部设置为False即可, 如果希望默认多个组件中显示默认选中,将其Checked属性设置为True。 2、代码实…

GEE错误——‘xxx‘ did not match any bands.

这里我们在进行影像展示的时候会出现下面的错误,主要的原因是我们虽然进行了波段运算,但是依旧无法加载,主要原因是我们没有将计算过后的波段信息进行添加到我们的一个多波段影像,这里我们首先来看看代码出现的错误提示。当然这里只是给出了主要的问题,其实在进行波段运算…

Java——猫猫图鉴微信小程序(前后端分离版)

目录 一、开源项目 二、项目来源 三、使用框架 四、小程序功能 1、用户功能 2、管理员功能 五、使用docker快速部署 六、更新信息 审核说明 一、开源项目 猫咪信息点-ruoyi-cat: 1、一直想做点项目进行学习与练手,所以做了一个对自己来说可以完成的…

ES6的默认参数和rest参数

✨ 专栏介绍 在现代Web开发中,JavaScript已经成为了不可或缺的一部分。它不仅可以为网页增加交互性和动态性,还可以在后端开发中使用Node.js构建高效的服务器端应用程序。作为一种灵活且易学的脚本语言,JavaScript具有广泛的应用场景&#x…

PowerShell Instal 一键部署TeamCity

前言 TeamCity 是一个通用的 CI/CD 软件平台,可实现灵活的工作流程、协作和开发实践。允许在您的 DevOps 流程中成功实现持续集成、持续交付和持续部署。 系统支持 Centos7,8,9/Redhat7,8,9及复刻系列系统支持 Windows 10,11,2012,2016,2019,2022高版本建议使用9系列系统…

权重函数设计

1e4/(0.5*s^21*s1) 1.4125 抑制驱动器饱和 1e-4*(s1)/(0.001*s1) 10*(s10)/(s1000) 1e-6*(s1)/(0.001*s1) [0.01,0.1],[0.001,1] 0.5*(0.05*s1)/(0.0001*s0.01)

Android笔记(二十三):Paging3分页加载库结合Compose的实现分层数据源访问

在Android笔记(二十二):Paging3分页加载库结合Compose的实现网络单一数据源访问一文中,实现了单一数据源的访问。在实际运行中,往往希望不是单纯地访问网络数据,更希望将访问的网络数据保存到移动终端的SQL…

windows高效的文件夹管理工具QTTabBar

1、介绍: 功能介绍视频https://www.bilibili.com/video/BV13y4y1V782/?spm_id_from333.880.my_history.page.click&vd_sourceef5003729aa0776908befb4d20108803 这个视频很坑,介绍了功能,但没介绍安装方法 2、下载: 下载链…

自检服务器,无需服务器、不用编程。

自检服务器,无需服务器、不用编程。 大家好,我是JavaPub. 这几年自媒体原来热,很多人都知道了个人 IP 的重要性。连一个搞中医的朋友都要要做一个自己的网站,而且不想学编程、还不想花 RMB 租云服务。 老读者都知道&#xff0c…

Spring Boot 基于Redisson实现注解式分布式锁

依赖版本 JDK 17 Spring Boot 3.2.0 Redisson 3.25.0 源码地址&#xff1a;Gitee 导入依赖 <properties><redisson.version>3.25.0</redisson.version> </properties><dependencies><dependency><groupId>org.projectlombok</…

OpenHarmony南向之Camera简述

Camera驱动框架 该驱动框架模型内部分为三层&#xff0c;依次为HDI实现层、框架层和设备适配层&#xff1a; HDI实现层&#xff1a;实现OHOS&#xff08;OpenHarmony Operation System&#xff09;相机标准南向接口。框架层&#xff1a;对接HDI实现层的控制、流的转发&#x…