【实验】SegViT: Semantic Segmentation with Plain Vision Transformers

news2024/10/5 18:32:56

在这里插入图片描述
想要借鉴SegViT官方模型源码部署到本地自己代码文件中

1. 环境配置

官网要求安装mmcv-full=1.4.4和mmsegmentation=0.24.0
在这之前记得把mmcv和mmsegmentation原来版本卸载

pip uninstall mmcv
pip uninstall mmcv-full
pip uninstall mmsegmentation

安装mmcv

其中,mmcv包含两个版本:一个是完整版mmcv(原来叫mmcv-full),一个是精简版mmcv-lite(原来叫mmcv),2.0.0版本之后更名了,具体的区别可以看mmcv官网手册和博客
安装mmcv-full(也就是mmcv完整版)主要参考mmcv官网手册。
如果你要安装mmcv>=2.0.0直接根据官网手册安装即可,不再赘述。
如果你要安装历史版本,例如我安装mmcv-full==1.4.4,可以参考我的记录。
在安装mmcv前,首先要知道自己的pytorch和cuda对应版本。
查看pytorch版本:

python -c 'import torch;print(torch.__version__)'

如果输出版本信息则已经安装pytorch
查看cuda版本:
注意要查你这个环境下pytorch对应的cuda版本
例如
这是我使用nvidia-smi命令查看的cuda版本:
在这里插入图片描述
这是我使用查看pytorch对应cuda版本命令:

python -c 'import torch;print(torch.version.cuda)'

也可以写成:

参考博客:https://blog.csdn.net/qq_49821869/article/details/127700187

python

>>>import torch
>>>torch.version.cuda

在这里插入图片描述
在这里我的pytorch版本应该是1.11.0,对应cuda版本是11.3

参考博客:https://blog.csdn.net/qq_41661809/article/details/125345690

于是,我输入命令:

pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html

不成功,于是我访问了这个网址查看,发现我能用的最低版本也就是1.4.7
在这里插入图片描述
于是我把命令换成了:

pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html

mmcv-full安装结束

安装mmsegmentation

mmsegmentation原本我是按照官网指导安装的,
但是要求mmcv>=2.0.0,而且安装的版本是mmsegmentation==1.0.0,这和我的要求冲突了。
注意mmsegmentation要和mmcv版本匹配:

参考博客:https://blog.csdn.net/CharilePuth/article/details/122909620

在这里插入图片描述

于是我直接:

pip install mmsegmentation==0.24.0

安装成功。
“pip安装包像喝水一样简单”——曾经一位大佬如是说道。

2. 搞代码!

找模型配置文件

进入官网,在Training中找到模型对应的config文件:
在这里插入图片描述
Highlights中我知道了本文的一大亮点就是收缩结构,可以减小计算成本,因此接下来我会选择收缩结构:
在这里插入图片描述

由于我要跑的图片大小为512,因此我在这个代码的Results中找到同样512*512的COCO数据集对应模型:
在这里插入图片描述

返回configs文件夹找到这个数据集对应网络模型:
在这里插入图片描述
在这里插入图片描述
观察其代码得知所用backbone为vit_shrink,解码头为TPNATMHead:
在这里插入图片描述
注意其中的参数设置,同时还要关注__base__的配置文件,其中的参数在模型声明的时候要输入进去。

找模型代码

进入backbone文件夹下找到vit_shrink网络:
在这里插入图片描述
复制粘贴到自己的py文件中
在decode_heads文件夹下找到解码头代码:
在这里插入图片描述
复制粘贴到自己的py文件中

对代码缝缝补补

  1. 补充库文件
    库文件缺什么补什么,例如在tpn_atm_head解码器代码中需要引用另外两个解码器代码中的内容,直接把另外两个解码器的代码ctrl C+V进来,将需要使用的模块留下来即可:
    在这里插入图片描述
    在这里插入图片描述
  2. 检查输入输出
    backbone的输入和输出:
    在这里插入图片描述
    解码器部分的输入输出如图:
    在这里插入图片描述
    写一个SegViT来测试输入输出,注意参考配置文件将对应配置提前声明好:
class SegViT(nn.Module):
    def __init__(self, num_class):
        super(SegViT, self).__init__()
        out_indices = [7,23]
        in_channels = 1024
        img_size = 512
        # checkpoint = './pretrained/vit_large_p16_384_20220308-d4efb41d.pth'
        checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth'

        # self.backbone = get_vit_shrink()
        self.backbone = vit_shrink(
            img_size=(img_size,img_size),embed_dims=in_channels,num_layers=24,drop_path_rate=0.3,num_heads=16,out_indices=out_indices)
        self.decoder = TPNATMHead(
            img_size=img_size,in_channels=in_channels,channels=in_channels,embed_dims=in_channels//2,num_heads=16,num_classes=num_class,num_layers=3, use_stages=len(out_indices))

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        # if self.training:
            # return out['pred'], out['ce_aux']
        # else:
            # return out
        return out
 

运行检查out的类型

if __name__ == "__main__":
    x = torch.randn(4, 3, 512, 512)
    net = SegViT(6)
    # flops, params = profile(net, (x,))
    # print('flops: %.2f G, params: %.2f M' % (flops / 1000000000.0, params / 1000000.0))
    # res, aux = net(x)
    res = net(x)
    print(res)

然后发现输出是一个字典类型,prediction是其中键名为pred对应的值,该值为tensor类型,shape大小为(4,6,512,512),输出正确。
接下来要找辅助分支的输出。
在解码器头的forward中发现:
在这里插入图片描述
将注释去掉,得到辅助分支的输出(会将辅助分支的输出结果以字典元素形式加入到atm_out中,可以调试看看),记得要把对应的初始化函数的注释也去掉:
在这里插入图片描述
其中,由于我是单卡运行,于是把SyncBN改成了BN,否则报错。
另外,训练阶段和测试阶段的输出是不一样的,可以调试检查:

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        if self.training:
            return out['pred'], out['ce_aux']
        else:
            return out
  1. 加载权重文件
    权重文件注意可以提前下载好
def get_vit_shrink(pretrained=True, img_size=512, in_channels=1024, out_indices=[7,23]):
    model = vit_shrink(
            img_size=(img_size,img_size),embed_dims=in_channels,num_layers=24,drop_path_rate=0.3,num_heads=16,out_indices=out_indices)
    if pretrained:
        checkpoint = '权重文件所在路径'
        # if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict']
        # else: state_dict = checkpoint
        model.load_state_dict(checkpoint, strict=False)
    return model

最终的模型是:

class SegViT(nn.Module):
    def __init__(self, num_class):
        super(SegViT, self).__init__()
        out_indices = [7,23]
        in_channels = 1024
        img_size = 512
        # checkpoint = './pretrained/vit_large_p16_384_20220308-d4efb41d.pth'
        # checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth'

        self.backbone = get_vit_shrink()
        self.decoder = TPNATMHead(
            img_size=img_size,in_channels=in_channels,channels=in_channels,embed_dims=in_channels//2,num_heads=16,num_classes=num_class,num_layers=3, use_stages=len(out_indices))

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        if self.training:
            return out['pred'], out['ce_aux']
        else:
            return out
 
  1. 检查最终的输入输出
    结束。

3. 运行模型

在自己的框架里,配置参数,然后运行即可。

结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/563815.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

你若在患难之日胆怯,你的力量便微小

如果你在做一件事情之前就过分放大它的困难,这就会逐渐降低自己去做它的动机和动力,还没开始你就已经削弱了自己的行动能力,在气势上就已经输了。 不要害怕困难,勇敢的去面对问题,解决问题,你就会在气势上更…

RK平台烧录固件的几种模式

在RK平台开发过程中,我们在使用烧录工具烧写固件的时候经常可以看到烧录工具会显示当前PC识别到的设备类型,一般有:MASKROM,LOADER,ADB,MSC等等。能烧录固件的模式有MASKROM模式和LOADER模式,下…

Python基础教程:第八章_Python文件操作

文件的编码 学习目标 掌握文件编码的概念和常见编码 思考:计算机只能识别:0和1,那么我们丰富的文本文件是如何被计算机识别,并存储在硬盘中呢? 答案:使用编码技术(密码本)将内…

STM32WB55_NUCLEO开发(12)----FUS 更新

概述 在 STM32WB 微控制器中,FUS(Firmware Upgrade Services)是用于固件升级的一种服务。这项服务可以让你更新设备上的无线栈固件(如蓝牙、Zigbee或 Thread 栈),以及无线 MCU (microcontroller unit) 的系…

day5 - 利用阈值勾勒

阈值处理在计算机视觉技术中占有十分重要的位置,他是很多高级算法的底层逻辑之一。本实验将练习使用图像阈值处理技术来处理不同的情况的图像,并获得图像轮廓。 完成本期内容,你可以: 了解图像阈值处理技术的定义和作用 掌握各阈…

PyQt5 使用 pyinstaller打包文件(speed)

编写界面 import sys,math from PyQt5.QtWidgets import * from PyQt5.QtCore import Qt from PyQt5.QtGui import *class RightBottomButton(QWidget) :def __init__(self):super(RightBottomButton,self).__init__()self.setWindowTitle("界面One")self.resize(400…

1.8G专网工业路由器促进4G智能电力建设,赋能配电站远程监控管理

TD-LTE已是当下成熟的4G通信技术,应用无线专网的场景也越来越多,4G技术在电力物联网中也得到了广泛应用。依托传统的人工监管方式,效率低、成本高、维护难,为促进4G智能电力建设迫切需要方便快捷的在线监控方式来及时发现电力配网…

xss跨站,订单,shell箱子反杀记

打开一个常见的订单靶场,老师自己搭建的 这个是可以进行xss漏洞的测试,凡是有这种数据交互的地方,前端有一个数据的接受,后端是数据的显示,这个过程就符合漏洞产生的前提条件,将输入的数据进行个显示&#…

BUUCTF-Basic部分(4道)

目录 Linux Labs BUU LFI COURSE 1 BUU BRUTE 1 BUU SQL COURSE 1 Linux Labs 第一个界面,给出了SSH ssh 用户名:root 密码:123456 地址和端口为动态分配的 以及映射地址和端口(这个地址端口是随机的) node4.buuoj.c…

R语言实践——使用 rWCVP 生成自定义清单

使用 rWCVP 生成自定义清单 介绍1. 特有物种清单2. 近特有物种清单2.1 在塞拉利昂和另一地区出现的物种2.2 在塞拉利昂和相邻地区出现的物种 3. 生成自定义报告 介绍 除了允许用户从世界维管植物名录(WCVP)创建清单外,rWCVP还提供了修改清单…

在小公司“混”了2年,我只认真做了5件事,如今顺利拿到字节 Offer

前言 是的,我一家小公司工作了整整两年时间,在入职这家公司前,也就是两年前,我就开始规划了我自己的人生,所以在两年时间里,我并未懈怠。 现如今,我已经跳槽到了字节,入职字节测试…

傅里叶级数 傅里叶变换 及应用

傅里叶级数和傅立叶变换是傅里叶分析的两个主要工具,它们之间有密切的关系。 什么是傅里叶级数 傅里叶级数是将一个周期函数分解为一系列正弦和余弦函数的和。它适用于周期性信号,可以将周期函数表示为一组振幅和相位不同的谐波分量的和。傅里叶级数展…

Netty编解码机制(二)

1.Netty入站和出站机制 1.1.基本介绍 1>.netty的组件设计: Netty的主要组件有Channel、EventLoop、ChannelFuture、ChannelHandler、ChannelPipe等; 2>.ChannelHandler充当了处理入站和出站数据的应用程序逻辑的容器.例如,实现ChannelInboundHandler接口(或ChannelInb…

Unity之如何接入google cardboard-xr-plugin实现android手机VR

前言 我们提到VR,总是会想到Oculus,HTC Vive,Pico等头戴VR设备,但是别忘了,最早Google就通过再手机端实现VR了,而且还推出过Cardboard手机盒子,让我们可以用最低的成本体验到VR效果。 插件下载 先说明一下,Unity在1028版本之前,支持过GoogleVR,但是后来因为统一…

Chapter8 :Physical Constraints(ug903)

8.1About Physical Constraints(关于物理约束) XilinxVivado集成设计环境(IDE)允许通过设置对象属性值对设计对象进行物理约束。示例包括: •I/O约束,如位置和I/O标准 •布局约束&…

惨败字节,苦心备战两个月斩获阿里offer,这份“258页软件测试面试宝典”也太顶了

测试三年有余,很多新学到的技术不能再项目中得到实践,同时薪资的涨幅很低,于是萌生了跳槽大厂的想法。 但大厂不是那么容易进的,前面惨败字节,为此我辛苦准备了两个月,又从小公司开始面试了半个月有余&…

k8s pv pvc的介绍|动态存储|静态存储

k8s pv pvc的介绍|动态存储|静态存储 1 emptyDir存储卷2 hostPath存储卷3 nfs共享存储卷4 PVC 和 PVNFS使用PV和PVC 4 搭建 StorageClass NFS,实现 NFS 的动态 PV 创建 1 emptyDir存储卷 当Pod被分配给节点时,首先创建emptyDir卷,并且只要该…

FPGA—可乐机拓展训练题(状态机)

题目:以可乐机为背景,一瓶可乐的价格还是 2.5 元。用按键控制投币(加入按键消抖功能),可以投 0.5 元硬币和 1 元硬币,投入 0.5 元后亮一个灯,投入 1 元后亮 2 个灯,投入 1.5 元后亮 …

【统计模型】学生课程类型选择影响因素分析

目录 学生课程类型选择影响因素分析 一、研究目的 二、数据来源和相关说明 三、描述性分析 3.1 样本描述 3.2 样本可视化 3.2.1 直方图 3.2.2 列联表 3.2.3 箱线图与折线图 3.2.4 相关性热力图 四、数学建模 4.1 无序多分类logistic回归模型 4.1.1 无序多分类logist…

STM32F030C8T6最小系统板和流水灯(原理图和PCB)

STM32F030C8T6最小系统板和流水灯。 嵌入式课的课程设计,要做个流水灯,我就顺便画个最小系统板,开源出来了,各位大佬指点指点,有哪里需要优化改进的。 那个WS2812的RGB灯用错引脚了,所以没法用PWM来控制&…