Vision Transfomer系列第一节---从0到1的源码实现

news2025/1/16 13:36:26

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考

这里写目录标题

  • 准备
  • 逐步源码实现
    • 数据集读取
    • VIt模型搭建
      • hand
      • 类别和位置编码
          • 类别编码
          • 位置编码
      • blocks
      • head
      • VIT整体
    • Runner(参考mmlab)
    • 可视化
  • 总结

准备

在这里插入图片描述
本博客完成Vision Transfomer(VIT)模型的搭建和flowers数据集的训练测试.整个源码包括如下几个任务:
1.读取flowers数据集的dataset类,对应文件dataset.py
2.VIT模型搭建,主要依赖于上几篇博客,对应model.py

1.transfomer中Multi-Head Attention的源码实现的MultiheadAttention类,用于搭建BaseTransformerLayer类,实现encoder和decoder功能
2.transfomer中Decoder和Encoder的base_layer的源码实现的BaseTransformerLayer类,帮助我们丝滑地搭建各类transformer网络
3.transfomer中正余弦位置编码的源码实现[可选]

3.设置优化器学习率和训练/验证模型,对应runner.py和train.py
4.可视化测试单个图片的预测结果,对应demo.py

逐步源码实现

源码结构如下
在这里插入图片描述

数据集读取

主要原理:根据dataset的路径,存储各个图片对应的路径,label隐藏在路径中.
在getitem函数中完成指定index图片和label的读取和数据增强功能

class Flowers(Dataset):
    # 用于读取flower数据集
    def __init__(self, dataset_path: str, transforms=None):
        '''
        存储所有数据 data路径和label
        :param dataset_path:
        '''
        super(Dataset, self).__init__()
        flowers = os.listdir(dataset_path)
        flowers = sorted(flowers) # 必须排序,否在每一次顺序不一样训练测试类别就会乱
        self.flower_paths = []
        self.class2label = {}  # 类别str 转 label
        label = 0
        for _, flower in enumerate(flowers):
            flowers_path = os.path.join(dataset_path, flower)
            if os.path.isdir(flowers_path):
                self.class2label[flower] = label
                label +=1
                sub_flowers = os.listdir(flowers_path)
                for sub_flower in sub_flowers:
                    self.flower_paths.append(os.path.join(flowers_path, sub_flower))
        self.label2class = label2class(self.class2label)  # label 转 类别str
        self.transforms = transforms
    ''''''
    def __getitem__(self, item):
        # 读取数据和label
        img = Image.open(self.flower_paths[item])
        label = self.class2label[self.flower_paths[item].split('/')[-2]]
        if self.transforms is not None:
            img = self.transforms(img)  # 数据增强
        return img, label

VIt模型搭建

将整个深度学习模型按照人体分为hand+backbone+neck+head 4个部分,Vit模型不同CNN模型,它的backbone+neck为多个MultiHeadAttention堆叠组成,称之为blocks.

hand

hand主用完成预处理,将数据用"手"揉捏成想要的类型.本处主要完成图片的patch操作,将图片分割成一个个小块,使用大核的卷积完成.然后把w和h拉平后shape就和NLP(b,n,d)一样了.

class PatchLayer(nn.Module):
    def __init__(self, img_size, patch_size=20, embeding_dim=64):
        super(PatchLayer, self).__init__()
        self.grid_size = (img_size[0] // patch_size, img_size[1] // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(in_channels=3,
                              out_channels=embeding_dim,
                              kernel_size=(patch_size, patch_size),
                              stride=patch_size,
                              padding=0)
        self.norm = nn.LayerNorm(normalized_shape=embeding_dim)
    def forward(self, img):
        img = self.proj(img)  # 图片分割
        img = img.flatten(start_dim=2)  # wh拉平
        img = img.permute(0, 2, 1)  # [b wh c]
        img = self.norm(img)
        return img

类别和位置编码

类别编码

直接cat到input上面,那么最后也取出对应的那一列作为类别输出.这是transformer类型网络的常用手段.
个人解释:训练出类别的访问者,这个访问者可以从特征信息(原input)中提取类别信息.训练访问者方法就是类别loss回归,训练时候先推出,推理时推出

位置编码

add到input上,可以使用可学习式的位置编码也可以使用正余弦位置编码.这是transformer类型网络的常用手段,还要特征层编码等
个人解释:训练出位置的标记者

        # 类别编码
        self.cls_token = nn.Parameter(torch.zeros(size=[1, 1, embed_dim]))
        # 固定位置编码和可学习位置编码
        # self.pos_embed = posemb_sincos_1d(len=num_patches + 1, dim=embed_dim,temperature=1000).unsqueeze(0)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

blocks

blocks使用注意力机制完成特征提取,
个人解释:
input线性映射为[query,key,value],需求侧(query)从供给侧(value)中取值,取值的根据是qurey@key转置生成的注意力矩阵(需求侧和供给侧每个像素之间的相似度),最后输出与输入shape相同.所以我们重复depth次,多次特征提取.
源码直接调用:transfomer中Decoder和Encoder的base_layer的源码实现的BaseTransformerLayer类

head

主要对transfomer输出的类别特征进行映射,embed维度映射为num_class维度

self.head = nn.Linear(embed_dim, num_classes)

VIT整体

主要是上述几个模块的集合及其正向传播过程:
完成二维图片变一维特征,一维特征transfomer特征提取,分类头输出.

class Vit(nn.Module):
    def __init__(self, img_size=[224, 224], patch_size=16, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12):
        super(Vit, self).__init__()

        self.patch_embed = PatchLayer(img_size, patch_size, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.blocks = nn.Sequential(*[
            BaseTransformerLayer(attn_cfgs=[dict(embed_dim=embed_dim, num_heads=num_heads)],
                                 fnn_cfg=dict(embed_dim=embed_dim, feedforward_channels=4 * embed_dim, act_cfg='ReLU',
                                              ffn_drop=0.),
                                 operation_order=('self_attn', 'norm', 'ffn', 'norm'))
            for _ in range(depth)
        ])

        # 类别编码
        self.cls_token = nn.Parameter(torch.zeros(size=[1, 1, embed_dim]))
        # 固定位置编码和可学习位置编码
        # self.pos_embed = posemb_sincos_1d(len=num_patches + 1, dim=embed_dim,temperature=1000).unsqueeze(0)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # 分类头
        self.head = nn.Linear(embed_dim, num_classes)

        self.loss_class = nn.CrossEntropyLoss()  # 内置softmax
        self.init_weights()
   ''''''
   def forward(self, img):
        query = self.hand(img)
        query = self.extract_feature(query)
        cls_fea = query[:, -1, :]  # 刚刚class_token被cat到了dim1的最后一个数
        x = self.head(cls_fea)
        return x

Runner(参考mmlab)

建立优化前,设置学习率,根据指定的work_flow顺序进行训练的测试,并保留最优权重

class Runner:
    def __init__(self, arg, model, device):
        self.arg = arg
        # 建立优化器
        params = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.SGD(params=params, lr=arg.lr, momentum=0.9, weight_decay=5E-5)
        lf = lambda x: ((1 + math.cos(x * math.pi / arg.epochs)) / 2) * (1 - arg.lrf) + arg.lrf  # cosine
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf)
        self.model = model.to(device)
        self.device = device
        if arg.load_from is not None and arg.load_from != '':
            weight_dict = torch.load(arg.load_from, map_location=device)
            model.load_state_dict(weight_dict)

    def run(self, dataloaders: dict):
        # 开始训练和验证
        assert 'train' in self.arg.work_flow.keys(), '必须要用训练任务'
        epoch_start = 0
        best_accuracy = 0.0
        while epoch_start < self.arg.epochs:
            for task, times in self.arg.work_flow.items():
                if task == 'train':  # 开始训练
                    for _ in range(times):
                        epoch_start += 1  # epoch只记录训练轮
                        self.model.train()
                        loss_sum = 0.0
                        data_loader = tqdm(dataloaders['train'], file=sys.stdout)
                        for step, data_dict in enumerate(data_loader):
                            img, label = data_dict
                            instance = {
                                'data': img.to(self.device),
                                'label': label.to(self.device)
                            }
                            loss = self.model.loss(**instance)
                            loss_sum += loss.detach()  # 要十分注意 避免往计算图中引入新的东西
                            loss.backward()
                            self.optimizer.step()
                            self.optimizer.zero_grad()
                            data_loader.desc = "[train epoch {}] loss: {:.3f}".\
                                format(epoch_start,loss_sum.item() / (step + 1))
                        self.scheduler.step()
                        print('train: epoch={}, loss={}'.format(epoch_start, loss_sum / (step + 1.0)))

                elif task == 'val':  # 开始验证
                    ''''''
                else:
                    raise ValueError('task must be in [train, val, test]')

可视化

读取单张图片,转换格式输入模型,输出的label,转化为class名和置信度,显示图像,class名和置信度.

if __name__ == '__main__':
    # 建立数据集
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    img = Image.open('*****daisy/21652746_cc379e0eea_m.jpg')
    input = data_transform(img).unsqueeze(0)
    label2class = Flowers(dataset_path='../datasets/flower_photos-mini').label2class

    device = torch.device('cuda:0')

    # 建立模型
    model = Vit(img_size=[224, 224],
                patch_size=16,
                embed_dim=768,
                depth=12,
                num_heads=12,
                num_classes=5).to(device)
    weight_dict = torch.load('weights/vit.pth', map_location=device)
    model.load_state_dict(weight_dict)

    model.eval()
    with torch.no_grad():
        output = model(input.to(device))
        output = output.detach().cpu()

        label = output[0].numpy().argmax()
        cnf = torch.softmax(output[0],dim=0).numpy().max()*100.0
        cnf = np.around(cnf, decimals=2) #保留2位小数
    plt.imshow(img)
    plt.title('{} : {}%'.format(label2class[label],cnf))
    plt.show()

总结

vit是视觉transfomer最经典的模型,复现一次代码十分有必要,中间会产生很多思考和问题.
后面章节将会更有价值,我将会:

1.利用本次的代码进行很多思考和trick的验证
2.总结本次代码的BUG们,及其产生的原理和解决方法

如需获取全套代码请参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1432561.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024机械工程师面试题

1.常用的机械画图软件有哪些 SolidWorks、Pro/e、CATIA、UG、Creo、CAD、inventor。CAXA电子图板. 2.第一视角是___&#xff0c;第三视角是___&#xff1b; 只要区别是&#xff1a;物体所处的位置不同。一般中国都使用第一视角的。 3.气缸属于_____执行元件&#xff0c;电磁…

Multisim14.0仿真(五十一)基于LM555定时器的分频器设计

一、1KHz脉冲设置&#xff1a; 二、555脉冲电路&#xff1a; 三、仿真电路&#xff1a; 四、运行仿真&#xff1a;

【Linux笔记】缓冲区的概念到标准库的模拟实现

一、缓冲区 “缓冲区”这个概念相信大家或多或少都听说过&#xff0c;大家其实在C语言阶段就已经接触到“缓冲区”这个东西&#xff0c;但是相信大家在C语言阶段并没有真正弄懂缓冲区到底是个什么东西&#xff0c;也相信大家在C语言阶段也因为缓冲区的问题写出过各种bug。 其…

【计算机视觉】万字长文详解:卷积神经网络

以下部分文字资料整合于网络&#xff0c;本文仅供自己学习用&#xff01; 一、计算机视觉概述 如果输入层和隐藏层和之前一样都是采用全连接网络&#xff0c;参数过多会导致过拟合问题&#xff0c;其次这么多的参数存储下来对计算机的内存要求也是很高的 解决这一问题&#x…

2024.2.4

双向链表的头插 头删 尾插 尾删 //头插插入 Doublelink insert_head(Doublelink head,datatype element) {Doublelink screat_Node();s->dataelement;//判断是否有空链表if(NULLhead){heads;}else{s->nexthead;head->priors;heads;}return head; } //头删 Doublelink…

sql相关子查询

1.什么是相关子查询 相关子查询是一个嵌套在外部查询中的查询&#xff0c;它使用了外部查询的某些值。每当外部查询处理一行数据时&#xff0c;相关子查询就会针对那行数据执行一次&#xff0c;因此它的结果可以依赖于外部查询中正在处理的行。 2.为什么要使用相关子…

微信小程序之本地生活案例的实现

学习的最大理由是想摆脱平庸&#xff0c;早一天就多一份人生的精彩&#xff1b;迟一天就多一天平庸的困扰。各位小伙伴&#xff0c;如果您&#xff1a; 想系统/深入学习某技术知识点… 一个人摸索学习很难坚持&#xff0c;想组团高效学习… 想写博客但无从下手&#xff0c;急需…

图论练习3

内容&#xff1a;过程中视条件改变边权&#xff0c;利用树状数组区间加处理 卯酉东海道 题目链接 题目大意 个点&#xff0c;条有向边&#xff0c;每条边有颜色和费用总共有种颜色若当前颜色与要走的边颜色相同&#xff0c;则花费为若当前颜色与要走的边颜色不同&#xff0c;…

Android学习之路(27) ProGuard,混淆,R8优化

前言 使用java编写的源代码编译后生成了对于的class文件&#xff0c;但是class文件是一个非常标准的文件&#xff0c;市面上很多软件都可以对class文件进行反编译&#xff0c;为了我们app的安全性&#xff0c;就需要使用到Android代码混淆这一功能。 针对 Java 的混淆&#x…

【快速上手QT】01-QWidgetQMainWindow QT中的窗口

总所周知&#xff0c;QT是一个跨平台的C图形用户界面应用程序开发框架。它既可以开发GUI程序&#xff0c;也可用于开发非GUI程序&#xff0c;当然我们用到QT就是要做GUI的&#xff0c;所以我们快速上手QT的第一篇博文就讲QT的界面窗口。 我用的IDE是VS2019&#xff0c;使用QTc…

Leetcode高频题:213打家劫舍II

题目链接&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 题目描述 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋&#xff0c;每间房内都藏有一定的现金。这个地方所有的房屋都 围成一圈 &#xff0c;这意味着第一个房屋和最后一个…

MySQL知识点总结(三)——事务

MySQL知识点总结&#xff08;三&#xff09;——事务 事务事务的四大特性ACID原子性一致性隔离性持久性 脏读、幻读、不可重复读脏读不可重复读幻读 隔离级别读未提交读已提交可重复读串行化 事务的原理InnoDB如何实现事务的ACID事务的两阶段提交redo log与binlog的区别事务两阶…

基于SpringBoot+Vue的在线教育平台设计与实现

目录 项目介绍 技术栈 项目介绍 项目截图 搭建 代码截取 代码获取 项目介绍 近年由于疫情影响&#xff0c;线下教育行业受到较大冲击&#xff0c;因此线上教育培训有较好的发展势头&#xff0c;其中建筑行业考证培训是一个前景良好的发展方向&#xff0c;该行业不仅需要…

权威认可|亚数强势入围FreeBuf《CCSIP 2023中国网络安全产业全景图》10大细分领域

近日&#xff0c;国内安全行业门户FreeBuf旗下FreeBuf咨询正式发布《CCSIP&#xff08;China Cyber Security Industry Panorama&#xff09;2023中国网络安全行业全景册&#xff08;第六版&#xff09;》。 凭借卓越的技术产品能力、市场影响力及领先的综合实力&#xff0c;亚…

C++泛编程(3)

类模板基础 1.类模板的基本概念2.类模板的分文件编写3.类模板的嵌套 &#xff08;未完待续...&#xff09; 在往节内容中&#xff0c;我们详细介绍了函数模板&#xff0c;这节开始我们就来聊一聊类模板。C中&#xff0c;类的细节远比函数多&#xff0c;所以这个专题也会更复杂。…

【Crypto | CTF】BUUCTF 萌萌哒的八戒

天命&#xff1a;这年头连猪都有密码&#xff0c;真是奇葩&#xff0c;怪不得我一点头绪都没有 拿到软件&#xff0c;发现是.zip的压缩包&#xff0c;打不开&#xff0c;改成7z后缀名&#xff0c;打开了 发现是一张图片 也只有下面这行东西是感觉是密码了&#xff0c;又不可能…

[leetcode] 22. 括号生成

文章目录 题目描述解题方法方法一&#xff1a;dfs遍历java代码 方法二&#xff1a;按照卡特兰数的思路递归求出有效括号组合java代码 相似题目 题目描述 数字 n 代表生成括号的对数&#xff0c;请你设计一个函数&#xff0c;用于能够生成所有可能的并且 有效的 括号组合。 示…

计算机编码:原码、反码、补码的思想、原理和实例(详细版)

​ 目录 收起 一、原码、反码、补码的意义 意义&#xff1a; 三、原码 原码的特点&#xff1a; 原码存在的问题&#xff1a; 四、反码 反码的特点&#xff1a; 存在的问题&#xff1a; 五、补码 六、补码的思想&#xff08;模&&同余数&#xff09; 模 && 同余数…

exF2FS: Transaction Support in Log-Structured Filesystem——泛读笔记

FAST 2022 Paper 分布式元数据论文汇总 问题 现代应用程序努力以崩溃一致的方式保护其数据&#xff0c;这通常分布在多个文件抽象之上。在底层文件系统缺乏事务支持的情况下&#xff0c;应用程序使用复杂的协议来确保跨多个文件的事务性更新&#xff0c;产生长序列的写操作和…

小林Coding_操作系统_读书笔记

一、硬件结构 1. CPU是如何执行的 冯诺依曼模型&#xff1a;中央处理器&#xff08;CPU&#xff09;、内存、输入设备、输出设备、总线 CPU中&#xff1a;寄存器&#xff08;程序计数器、通用暂存器、指令暂存器&#xff09;&#xff0c;控制单元&#xff08;控制CPU工作&am…