RegionCLIP网络结构解析 Region-based Language-Image Pretraining

news2024/10/6 14:35:08

1、简单介绍

主要是关注目标检测方面的工作,现在纯CV已经前景黯淡,即使前段时间的YOLOv9发布也是关注一般。
现在大模型已成热点,而大模型要求的数据量和算力和算法复杂度,显然让很多人却步。但是具有大模型特点的多模态算法也算是研究的趋势,所以目前主要是关注多模态方面的目标检测工作。

其中目标检测领域,目前和多模态相关的主要是 开集、开放词汇、描述性目标检测以及情景理解等。相关的研究工作已经越来越多,这里权当学习记录。

RegionCLIP作为OVD检测算法,也是具有一定的代表性。

RegionCLIP的官方网址:https://github.com/microsoft/RegionCLIP
RegionCLIP的论文网址:https://arxiv.org/pdf/2112.09106.pdf

在这里插入图片描述

文章概述(摘自GitHub):

我们提出了 RegionCLIP,它显着扩展了 CLIP 以学习区域级的视觉表示。RegionCLIP支持图像区域和文本概念之间的细粒度对齐,从而支持基于区域的推理任务,包括零样本对象检测和开放词汇对象检测。

①预训练:我们利用 CLIP 模型将图像区域与模板标题进行匹配,然后预训练模型以对齐这些区域-文本对。

②零样本推理:预训练后,学习区域表示支持用于对象检测的零样本推理。

③迁移学习:学习的 RegionCLIP 模型可以通过额外的对象检测注释进一步微调,从而允许我们的模型用于完全监督或开放词汇的对象检测。

④结果:我们的方法展示了零样本目标检测和开放词汇目标检测的最新结果。

在这里插入图片描述

概括一下:核心思想就是把之前 图像特征和文本特征匹配的方式 聚焦到了 图像的局部区域特征 和文本特征的匹配

2、网络结构

大致看了代码,RegionCLIP是基于detectron2写的,包括预训练模型的训练和Fast RCNN结构的网络

2.1 预训练配置:

在这里插入图片描述
可以看到,这个预训练模型的结构是 PretrainFastRCNN
代码在
在这里插入图片描述
可以看到他的forward函数:

	def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
       
        if not self.training:
            return self.inference(batched_inputs)
        gt_instances = None
        losses = {}
        
        # localization branch: offline modules to get the region proposals
        proposals = self.get_region_proposals(batched_inputs)
        global_proposals = self.create_global_proposals(batched_inputs)

        # recognition branch: get 2D feature maps using the backbone of recognition branch and extract region features
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        region_feats = self.get_region_features(images, features, proposals, gt_instances)
        global_feats = self.get_region_features(images, features, global_proposals, gt_instances)

        # image-text level matching
        if self.img_txt_level:
            self.image_text_matching(batched_inputs, proposals, region_feats, losses, global_feats=global_feats)

        # region-concept level matching
        if self.concept_emb is not None:
            self.region_concept_matching(images, proposals, gt_instances, region_feats, losses)

        return losses

从上可以看到区域选取是通过 self.get_region_proposals(batched_inputs) 实现的 ,

self.get_region_features(images, features, proposals, gt_instances) 这个是获取区域图像的特征

self.region_concept_matching(images, proposals, gt_instances, region_feats, losses) 是 区域图像特征和文本特征匹配的

2.2 CLIPFastRCNN 结构

如下配置文件,可以看到整体的网络配置
在这里插入图片描述

class CLIPFastRCNN(nn.Module):
    """
    Fast R-CNN style where the cropping is conducted on feature maps instead of raw images.
    It contains the following two components: 
    1. Localization branch: pretrained backbone+RPN or equivalent modules, and is able to output object proposals
    2. Recognition branch: is able to recognize zero-shot regions
    """
    @configurable
    def __init__(
        self,
        *,
        offline_backbone: Backbone,
        backbone: Backbone,
        offline_proposal_generator: nn.Module,
        language_encoder: nn.Module, 
        roi_heads: nn.Module,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        input_format: Optional[str] = None,
        vis_period: int = 0,
        clip_crop_region_type: str = 'GT',
        use_clip_c4: False,
        use_clip_attpool: False,
        offline_input_format: Optional[str] = None,
        offline_pixel_mean: Tuple[float],
        offline_pixel_std: Tuple[float],
    ):

这是定义的 CLIPFastRCNN 的初始内容,包含要传递的参数模块
其中backbone 是 build_clip_resnet_backbone,这个可以在如下找到

def build_backbone(cfg, input_shape=None):
    """
    Build a backbone from `cfg.MODEL.BACKBONE.NAME`.

    Returns:
        an instance of :class:`Backbone`
    """
    if input_shape is None:
        input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))

    backbone_name = cfg.MODEL.BACKBONE.NAME
    backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape)
    assert isinstance(backbone, Backbone)
    return backbone

也就是通过 cfg.MODEL.BACKBONE.NAME 来定位到定义的backbone,如下:
在这里插入图片描述

可以看到,最终返回一个 ModifiedResNet

其中用了配置文件中的 MODEL.BACKBONE.FREEZE_ATMODEL.RESNETS.OUT_FEATURESMODEL.RESNETS.DEPTH具体如下:

def build_clip_resnet_backbone(cfg, input_shape):
    """
    Create a CLIP-version ResNet instance from config.

    Returns:
        ModifiedResNet: a :class:`ModifiedResNet` instance.
    """
    # port standard ResNet config to CLIP ModifiedResNet
    freeze_at           = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features        = cfg.MODEL.RESNETS.OUT_FEATURES
    depth               = cfg.MODEL.RESNETS.DEPTH
    # num_groups          = cfg.MODEL.RESNETS.NUM_GROUPS
    # width_per_group     = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    # bottleneck_channels = num_groups * width_per_group
    # in_channels         = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    # out_channels        = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    # stride_in_1x1       = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    # res5_dilation       = cfg.MODEL.RESNETS.RES5_DILATION
    # deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
    # deform_modulated    = cfg.MODEL.RESNETS.DEFORM_MODULATED
    # deform_num_groups   = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
    
    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [4, 6, 10, 6], # flag for ResNet50x4
    }[depth]
    vision_layers = num_blocks_per_stage
    vision_width = {
        50: 64,
        101: 64,
        200: 80, # flag for ResNet50x4
    }[depth]  # cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    
    # default configs of CLIP ModifiedResNet, but not used if only building ModifiedResNet as backbone
    embed_dim = {
        50: 1024,
        101: 512,
        200: 640, # flag for ResNet50x4
    }[depth] 
    vision_heads = vision_width * 32 // 64
    image_resolution = {
        50: 224,
        101: 224,
        200: 288, # flag for ResNet50x4
    }[depth] 

    # if combine {ModifiedResNet of CLIP, C4, text emb as classifier}, then has to use att_pool to match dimension
    create_att_pool = True if (cfg.MODEL.ROI_HEADS.NAME in ['CLIPRes5ROIHeads', 'CLIPStandardROIHeads'] and cfg.MODEL.CLIP.USE_TEXT_EMB_CLASSIFIER)\
                           or cfg.MODEL.ROI_HEADS.NAME == 'PretrainRes5ROIHeads' else False

    return ModifiedResNet(layers=vision_layers, 
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width,
                out_features=out_features, 
                freeze_at=freeze_at,
                depth=depth,
                pool_vec=False,
                create_att_pool=create_att_pool,
                )

继续看 CLIPFastRCNN ,其中 @classmethod ----- 类方法让类模板具有记忆力,用@classmethod描述类方法,然后用"cls"代表本类。

    @classmethod
    def from_config(cls, cfg):
        # create independent backbone & RPN
        if cfg.MODEL.CLIP.CROP_REGION_TYPE == "RPN": 
            # create offline cfg for the pretrained backbone & RPN
            from detectron2.config import get_cfg
            offline_cfg = get_cfg()
            offline_cfg.merge_from_file(cfg.MODEL.CLIP.OFFLINE_RPN_CONFIG)
            if cfg.MODEL.CLIP.OFFLINE_RPN_LSJ_PRETRAINED: # large-scale jittering (LSJ) pretrained RPN
                offline_cfg.MODEL.BACKBONE.FREEZE_AT = 0 # make all fronzon layers to "SyncBN"
                offline_cfg.MODEL.RESNETS.NORM = "SyncBN" # 5 resnet layers
                offline_cfg.MODEL.FPN.NORM = "SyncBN" # fpn layers
                offline_cfg.MODEL.RPN.CONV_DIMS = [-1, -1] # rpn layers
            if cfg.MODEL.CLIP.OFFLINE_RPN_NMS_THRESH:
                offline_cfg.MODEL.RPN.NMS_THRESH = cfg.MODEL.CLIP.OFFLINE_RPN_NMS_THRESH  # 0.9
            if cfg.MODEL.CLIP.OFFLINE_RPN_POST_NMS_TOPK_TEST:
                offline_cfg.MODEL.RPN.POST_NMS_TOPK_TEST = cfg.MODEL.CLIP.OFFLINE_RPN_POST_NMS_TOPK_TEST # 1000

            # create offline backbone and RPN
            offline_backbone = build_backbone(offline_cfg)
            offline_rpn = build_proposal_generator(offline_cfg, offline_backbone.output_shape())

            # convert to evaluation mode
            for p in offline_backbone.parameters(): p.requires_grad = False
            for p in offline_rpn.parameters(): p.requires_grad = False
            offline_backbone.eval()
            offline_rpn.eval()
        # region proposals are ground-truth boxes
        elif cfg.MODEL.CLIP.CROP_REGION_TYPE == "GT":
            offline_backbone = None
            offline_rpn = None
            offline_cfg = None
        
        backbone = build_backbone(cfg)
        # build language encoder
        if cfg.MODEL.CLIP.GET_CONCEPT_EMB: # extract concept embeddings
            language_encoder = build_clip_language_encoder(cfg)
        else:
            language_encoder = None
        roi_heads = build_roi_heads(cfg, backbone.output_shape())

        return {
            "offline_backbone": offline_backbone,
            "offline_proposal_generator": offline_rpn, 
            "backbone": backbone,
            "language_encoder": language_encoder, 
            "roi_heads": roi_heads, 
            "input_format": cfg.INPUT.FORMAT,
            "vis_period": cfg.VIS_PERIOD,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            "clip_crop_region_type" : cfg.MODEL.CLIP.CROP_REGION_TYPE,
            "use_clip_c4": cfg.MODEL.BACKBONE.NAME == "build_clip_resnet_backbone",
            "use_clip_attpool": cfg.MODEL.ROI_HEADS.NAME in ['CLIPRes5ROIHeads', 'CLIPStandardROIHeads'] and cfg.MODEL.CLIP.USE_TEXT_EMB_CLASSIFIER,
            "offline_input_format": offline_cfg.INPUT.FORMAT if offline_cfg else None,
            "offline_pixel_mean": offline_cfg.MODEL.PIXEL_MEAN if offline_cfg else None,
            "offline_pixel_std": offline_cfg.MODEL.PIXEL_STD if offline_cfg else None,
        }

从上面可以看到 backbone ,language_encoder,roi_heads 构建相应的模块,基本上CLIPFastRCNN 的模块都在里面了。不过里面的 offline_backbone 让我疑惑,不知道这个是如何起作用的,发挥什么功能?我判断是加载离线模型 就是做过预训练的模型,用来生成proposals的,感觉这段代码不太好看,而且后面也不清晰怎么处理的。

还可以进一步看forward函数,直观了解数据处理:

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None
        
        # localization branch: offline modules to get the region proposals
        with torch.no_grad():  
            if self.clip_crop_region_type == "GT":  # from ground-truth
                proposals = []
                for r_i, b_input in enumerate(batched_inputs): 
                    this_gt = copy.deepcopy(b_input["instances"])  # Instance
                    gt_boxes = this_gt._fields['gt_boxes'].to(self.device)
                    this_gt._fields = {'proposal_boxes': gt_boxes, 'objectness_logits': torch.ones(gt_boxes.tensor.size(0)).to(self.device)}
                    proposals.append(this_gt)                
            elif self.clip_crop_region_type == "RPN": # from the backbone & RPN of standard Mask-RCNN, trained on base classes
                if self.offline_backbone.training or self.offline_proposal_generator.training:  #  was set to True in training script
                    self.offline_backbone.eval() 
                    self.offline_proposal_generator.eval()  
                images = self.offline_preprocess_image(batched_inputs)
                features = self.offline_backbone(images.tensor)
                if self.offline_proposal_generator is not None:
                    proposals, _ = self.offline_proposal_generator(images, features, None)     

        # recognition branch: get 2D feature maps using the backbone of recognition branch
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)

        # Given the proposals, crop region features from 2D image features and classify the regions
        if self.use_clip_c4: # use C4 + resnet weights from CLIP
            if self.use_clip_attpool: # use att_pool from CLIP to match dimension
                _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, res5=self.backbone.layer4, attnpool=self.backbone.attnpool)
            else: # use mean pool
                _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, res5=self.backbone.layer4)
        else:  # regular detector setting
            if self.use_clip_attpool: # use att_pool from CLIP to match dimension
                _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, attnpool=self.backbone.bottom_up.attnpool)
            else: # use mean pool
                _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)
        #visualize_proposals(batched_inputs, proposals, self.input_format)

        losses = {}
        losses.update(detector_losses)
        return losses

可以看到数据输入 batched_inputs 的处理,features = self.backbone(images.tensor) 这一步完成特征提取,features里包含了文本特征,后续进入 roi_heads 进行损失计算。

以上就是RegionCLIP的 CLIPFastRCNN 的网络结构对应代码解析。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1566461.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3从入门到实战:路由的query和params参数

在Vue 3中,我们可以通过路由的查询参数来传递数据。这意味着我们可以在不同的页面之间传递一些信息,以便页面可以根据这些信息来显示不同的内容或执行不同的操作。 查询参数的使用方式类似于在URL中添加附加信息,以便页面之间可以根据这些信息…

AI技术创业:把握机遇,提升能力,迎接未来挑战

文章目录 人工智能三次浪潮人工智能时代机遇提升核心能力AI时代的长期赛道和早期优势实践应用:让AI工具为你所用学会变通:适应AI领域的快速变化提升核心能力:想象力、创造力和提问能力结语 人工智能三次浪潮 第一次黄金期:1956年…

基于SSM的网络视频播放器

目录 背景 技术简介 系统简介 界面预览 背景 互联网的迅猛发展彻底转变了全球各类组织的管理策略。自20世纪90年代起,中国政府和企业便开始探索利用网络系统进行信息管理。然而,早期的网络覆盖不广泛、用户接受度不高、相关法律法规不完善以及技术开…

WPF文本框TextEdit不以科学计数法显示

WPF文本框TextEdit不以科学计数法显示 一个float或者double类型的数值,如果小数点后0的个数≥4,在界面上就会自动以科学计数法显示, 比如:0.00003会显示成这样 但是很多时候我并不希望它这样显示,因为这样不方便编辑…

成都欣丰洪泰文化传媒有限公司引领电商新风向

在当今数字化时代,电子商务行业日新月异,竞争激烈。然而,在这股浪潮中,成都欣丰洪泰文化传媒有限公司凭借其独特的战略眼光和创新精神,正引领着电商领域的新浪潮。本文将探讨成都欣丰洪泰文化传媒有限公司如何在激烈的…

实战webSocket压测(二)jmeter配置webSocket连接

背景 我们可以通过Jmeter添加插件实现webSocket脚本编写。WebSocket的插件较多,我选择以WebSocket Samplers by Peter Doornbosch为例来进行配置。 步骤1、WebSocket Samplers插件安装 下载地址:JMeter WebSocket Samplers,建议下载最新版本…

阿里巴巴25届实习生内推

#阿里巴巴 #春招实习 阿里国际春季2025届实习生招聘4月1日已正式启动!学生网申投递、师兄师姐内推通道均已开放 整体介绍(含在招岗位) 内推投递方式 方式一:内推码自行投递 方式二:通过简历投递 简历发邮箱&#xf…

Lumos学习王佩丰Excel第一讲:认识Excel

最近发现自己在操作excel的一些特殊功能时会有些不顺手,所以索性找了一个比较全的教程(王佩丰excel24讲)拿来学习,刚好形成文档笔记,分享给有需要但没有时间看视频的朋友们。整体笔记以王老师授课的知识点去记录&#…

蓝桥杯备考

目录 P8823 [传智杯 #3 初赛] 期末考试成绩 题目描述 输入格式 输出格式 输入输出样例 说明/提示 代码 P8828 [传智杯 #3 练习赛] 直角三角形 题目描述 输入格式 输出格式 输入输出样例 代码 P8833 [传智杯 #3 决赛] 课程 题目背景 题目描述 输入格式 输出格式…

【热门话题】文言一心与ChatGPT-4:一场跨时代智能对话系统的深度比较

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 文言一心与ChatGPT-4:一场跨时代智能对话系统的深度比较一、技术背景…

西电计科大三下SOC微体系结构设计作业合集

目录 一.VHDL设计作业 1.基于硬件描述语言的3-8译码器逻辑电路设计 2.8位双向移位寄存器设计 3.基于有限状态机的自助售票系统设计 4.按键消抖电路设计 5.同步环形FIFO设计 6.线上实验——时钟模块设计 7.线上实验——原码二位乘法器设计 8.线上实验——布斯乘法器设…

新产品机会的两大来源:分析当前产品组合与创意生成工具或创造性思维技术

一、引言 在快速变化的商业世界中,企业/组织若想保持竞争力并持续繁荣,就必须不断寻找新产品机会。这些机会并非凭空而来,而是源于:1. 分析当前产品组合,找出可以进行产品改进或产品线延伸的领域。2.创意生成工具或创…

【JAVA】基础学习03变量和关键字

文章目录 JAVA变量与运算符1.关键字(keyword)2.标识符( identifier)2.1命名规则2.2命名规范2.3变量作用和类型2.3.1整型变量2.3.2补充:计算机存储单位2.3.3浮点类型:float、double2.3.4 关于浮点型精度的说明2.3.5 字符类型&#…

docker部署python

1.部署python 1.1安装docker(按这个操作就可以) http://t.csdnimg.cn/cezmt 1.2拉取python镜像,一般拉取收藏量最高的那个 sudo docker search python 1.3拉取python镜像 #可以设置版本号,也可以不设置版本号,不…

外卖配送时间预测项目

注意:本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 ([www.aideeplearning.cn]) 项目背景 外卖服务的兴起: 随着互联网技术和移动应用的发展,外卖成为一种日益普及的餐饮服务方式。顾客通过餐厅、杂货店的网站或移…

全球范围内2nm晶圆厂建设加速

随着人工智能浪潮席卷而来,先进制程芯片的重要性日益凸显。当前,3nm工艺节点是行业内最先进的节点。与此同时,台积电、三星、英特尔、Rapidus等厂商正积极布局建设2nm晶圆厂。台积电与三星此前计划于2025年量产2nm芯片,而Rapidus则…

嵌入式中常见的面试题分享

1.关键字static的作用是什么?为什么static变量只初始化一次? 1)修饰局部变量:使得变量变成静态变量,存储在静态区,存储在静态区的数据周期和程序相同, 在main函数开始前初始化,在退…

java:6 数组(3)

文章目录 14. 二维数组14.1 定义14.2 二维数组的使用14.3 练习 【老韩视频p175-】 14. 二维数组 14.1 定义 多维数组我们只介绍二维数组: 二维数组的应用场景:比如我们开发一个五子棋游戏,棋盘就是需要二维数组来表示。请用二维数组输出如下…

【漏洞复现】某科技X2Modbus网关多个漏洞

漏洞描述 最近某科技X2Modbus网关出了一个GetUser的信息泄露的漏洞,但是经过审计发现该系统80%以上的接口均是未授权的,没有添加相应的鉴权机制,以下列举多个未授权接口以及获取相关敏感信息的接口。 免责声明 技术文章仅供参考,任何个人和组织使用网络应当遵守宪法法律…

[C++初阶]初识C++(一)—————命名空间和缺省函数

声明: 本篇文献内容选自百度文库、比特就业课 代码内容部分选自比特就业课 一、命名空间 1.什么是命名空间 在编程语言中,命名空间是一种特殊的作用域,它包含了处于该作用域中的所有标示符,而且其本身也是由标示符表示的。命名空间的使用目…