(a)Mask RCNN总体流程

news2024/11/26 18:53:00

(a)Mask RCNN总体流程

一.Mask RCNN 架构

自己整理了一份Mask RCNN架构图如下,其中绿色模块只有推理过程才会涉及。

Mask RCNN网络架构

核心模块包括:数据预处理,骨干网络,区域提议网络,FastRCNN分支,Mask分支,数据后处理等。

二.网络核心流程

class FasterRCNNBase(nn.Module):

   def __init__(self, backbone, rpn, roi_heads, transform):
       super(FasterRCNNBase, self).__init__()
       self.transform = transform
       self.backbone = backbone
       self.rpn = rpn
       self.roi_heads = roi_heads
       # used only on torchscript mode
       self._has_warned = False

   @torch.jit.unused
   def eager_outputs(self, losses, detections):
       # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
       if self.training:
           return losses

       return detections

   def forward(self, images, targets=None):
       # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
       if self.training and targets is None:
           raise ValueError("In training mode, targets should be passed")

       if self.training:
           assert targets is not None
           for target in targets:         # 进一步判断传入的target的boxes参数是否符合规定
               boxes = target["boxes"]
               if isinstance(boxes, torch.Tensor):
                   if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                       raise ValueError("Expected target boxes to be a tensor"
                                        "of shape [N, 4], got {:}.".format(
                                         boxes.shape))
               else:
                   raise ValueError("Expected target boxes to be of type "
                                    "Tensor, got {:}.".format(type(boxes)))

       original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
       for img in images:
           val = img.shape[-2:]
           assert len(val) == 2  # 防止输入的是个一维向量
           original_image_sizes.append((val[0], val[1]))
       # original_image_sizes = [img.shape[-2:] for img in images]

       images, targets = self.transform(images, targets)  # 对图像进行预处理
       # print(images.tensors.shape)
       features = self.backbone(images.tensors)  # 将图像输入backbone得到特征图
       if isinstance(features, torch.Tensor):  # 若只在一层特征层上预测,将feature放入有序字典中,并编号为‘0’
           features = OrderedDict([('0', features)])  # 若在多层特征层上预测,传入的就是一个有序字典

       # 将特征层以及标注target信息传入rpn中
       # proposals: List[Tensor], Tensor_shape: [num_proposals, 4],
       # 每个proposals是绝对坐标,且为(x1, y1, x2, y2)格式
       proposals, proposal_losses = self.rpn(images, features, targets)

       # 将rpn生成的数据以及标注target信息传入fast rcnn后半部分
       detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

       # 对网络的预测结果进行后处理(主要将bboxes还原到原图像尺度上)
       detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

       losses = {}
       losses.update(detector_losses)
       losses.update(proposal_losses)

       if torch.jit.is_scripting():
           if not self._has_warned:
               warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
               self._has_warned = True
           return losses, detections
       else:
           return self.eager_outputs(losses, detections)

       # if self.training:
       #     return losses
       #
       # return detections

FasterRCNNBase是RCNN检测算法的基类,FasterRCNN类要继承FasterRCNNBase类,而MaskRCNN类又要继承FasterRCNN类,所以当实例化一个model并传入数据x时,会调用FasterRCNNBase的forward函数:

model = MaskRCNN(backbone,num_classes)
model(images,targets)

FasterRCNNBase的 init() 函数:

    def __init__(self, backbone, rpn, roi_heads, transform):
        super(FasterRCNNBase, self).__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

传入参数包括:
(1)backbone:
resnet50
resnet101
resnet50+fpn
resnet101+fpn
(2)rpn:
区域提议网络
(3)roi_haeds:
box roi pooling/align
two MLP head
box predictor
mask roi pool
mask head
mask predictor
(4)transforms:
GeneraRCNNtransforms类的实例,用于数据预处理

FasterRCNNBase的 forward() 函数:

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        if self.training:
            assert targets is not None
            for target in targets:         # 进一步判断传入的target的boxes参数是否符合规定
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError("Expected target boxes to be a tensor"
                                         "of shape [N, 4], got {:}.".format(
                                          boxes.shape))
                else:
                    raise ValueError("Expected target boxes to be of type "
                                     "Tensor, got {:}.".format(type(boxes)))

        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2  # 防止输入的是个一维向量
            original_image_sizes.append((val[0], val[1]))
        # original_image_sizes = [img.shape[-2:] for img in images]

        images, targets = self.transform(images, targets)  # 对图像进行预处理
        # print(images.tensors.shape)
        features = self.backbone(images.tensors)  # 将图像输入backbone得到特征图
        if isinstance(features, torch.Tensor):  # 若只在一层特征层上预测,将feature放入有序字典中,并编号为‘0’
            features = OrderedDict([('0', features)])  # 若在多层特征层上预测,传入的就是一个有序字典

        # 将特征层以及标注target信息传入rpn中
        # proposals: List[Tensor], Tensor_shape: [num_proposals, 4],
        # 每个proposals是绝对坐标,且为(x1, y1, x2, y2)格式
        proposals, proposal_losses = self.rpn(images, features, targets)

        # 将rpn生成的数据以及标注target信息传入fast rcnn后半部分
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

        # 对网络的预测结果进行后处理(主要将bboxes还原到原图像尺度上)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)

首先增加一些容错机制,保住输入数据格式符合模型要求,然后将images和targets输入transforms中进行数据格式的预处理;然后将images输入到backbone中,得到特征图features;将features,images,targets输入rpn网络中,得到proposals和proposals_loss;然后将proposals,images,features等输入到roi_heads得到detections和detector_loss;如果在训练模式下,则返回loss(proposals_loss和detection_loss),在推理模式下,则返回detections。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1188620.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MATLAB】设置图形透明度

1 Scatter散点图 % 设置散点大小 s.SizeData 100;散点标记符号设置如下: 在绘制散点图时,想设置透明度: 更改散点透明度后,图形如下: 相关绘图代码如下: figure(1) set(gcf, Units, figureUnits, Po…

剪贴板管理软件 Paste Wizard mac中文版功能特色

Paste Wizard mac是一款剪贴板管理工具,它可以帮助用户更高效地管理剪贴板中的文本、图片、链接等内容。 Paste Wizard mac特色功能 提供了多种方式来保存和管理剪贴板中的内容。用户可以创建自定义的标签,将内容按照标签进行分类,方便快速查…

【RPC】前传

前传 本地程序用的go语言,想把main.go程序当中一些计算工作放到服务器上进行,而只需要把结果给我即可。由于平台上暂时不能运行Go代码,所以写的是python文件。 1、主要是使用ssh依赖进行连接,但是大概率是需要手动添加的&#xf…

android手机平板拓展电脑屏幕

有这么两个软件 spacedesk_driver_Win_10_64_v1065_BETA.msi 安装在电脑上 spacedeskv0.91.1_chinese.apk 安装在android设备上 同一个局域网投屏就好了。 局域网无限投屏是很吃带宽的。 建议usb共享网络,不占用带宽、延迟低。 下载地址: https:/…

OpenCV实现手势虚拟拖拽

前言: Hello大家好,我是Dream。 今天来学习一下如何使用OpenCV实现手势虚拟拖拽,欢迎大家一起前来探讨学习~ 一、主要步骤及库的功能介绍 1.主要步骤 要实现本次实验,主要步骤如下: 导入OpenCV库。通过OpenCV读取摄…

使用Drupal管理小型项目?试试Docker快速部署Drupal结合内网穿透实现远程访问

🎬 鸽芷咕:个人主页 🔥个人专栏:《Linux深造日志》《C干货基地》 ⛺️生活的理想,就是为了理想的生活! 文章目录 前言1. Docker安装Drupal2. 本地局域网访问3 . Linux 安装cpolar4. 配置Drupal公网访问地址5. 公网远程访问Drupal…

四川思维跳动商务信息咨询有限公司可信吗?

在今天的数字化时代,抖音带货已成为一种全新的商业模式。许多公司都在通过这种形式进行产品推广和销售,其中,四川思维跳动商务信息咨询有限公司以其专业的服务和良好的信誉,在抖音带货领域赢得了广泛赞誉。 四川思维跳动商务信息…

抖店怎么做才会快速起店?跟着这个思路来,一周搞定!

大家好,我是电商糖果 有不少朋友,自己开了一家抖店。 因为不懂运营,店铺一直没有流量,也不出单。 糖果做抖店三年多了,不敢自吹有多么优秀,但是做店还是有一套自己的方法的。 按照糖果这个思路做店&…

echarts 图表 地图实例

效果&#xff1a; 代码实现&#xff1a; draw(data) {var option {tooltip: {trigger: item,icon: query,// triggerOn: click,formatter: function (e, t, n) {let string ;string <div style"padding:10px"><span style"padding-right:10px"…

Microsoft SDKs 有文件重定义导致编译失败的处理

一个32位的mfc项目&#xff0c;之前采用vs2019编译&#xff0c;现在换了电脑(系统是win10)&#xff0c;采用vs2022编译时&#xff0c;提示如下错误&#xff1a; 1>------ 已启动生成: 项目: aAnsys, 配置: Debug Win32 ------ 1>cl : 命令行 warning D9035: “Gm”选项…

【ubuntu】ubuntu系统查看服务命令

查看正在运行的服务 sudo service --status-all [] 代表服务是在启动运行的状态 [-] 代表服务是在关闭停止的状态

使用Go语言抓取酒店价格数据的技术实现

目录 一、引言 二、准备工作 三、抓取数据 四、数据处理与存储 五、数据分析与可视化 六、结论与展望 一、引言 随着互联网的快速发展&#xff0c;酒店预订已经成为人们出行的重要环节。在选择酒店时&#xff0c;价格是消费者考虑的重要因素之一。因此&#xff0c;抓取酒…

opencv读取图片的方式影响图像绘制的颜色

圆圈的颜色设置不变&#xff0c;仅仅更改imread读取图片的方式 #frame cv2.imread(img_path,2)##flag2,单通道&#xff0c;原深度 **frame cv2.imread(img_path)##flag2,单通道&#xff0c;原深度** #cv2.circle(frame, (int(lmx), int(lmy)), 8, (0, 0, 125), 3) ### open…

优思学院|推行精益六西格玛困难重重?7大原因分析助你避坑

六西格玛&#xff0c;是一种让企业在绩效管理的舞台上跳得更高更远的方法。它不仅仅是一套原则和技术&#xff0c;更是一种对完美的执着追求。 在这个舞台上&#xff0c;企业的流程管理得以严格、集中&#xff0c;质量得以高效提升。优思学院总结出六西格玛的核心是&#xff1…

互联网金融风控常见知识点

1.怎么做互联网金融风控 首先风险不是都是坏的&#xff0c;风险是有价值的。也就是风险的VaR值(Value at Risk) 对于互联网信贷风控&#xff0c;是要把风险和收益做到更合理的平衡&#xff0c;在控制风险水平的情况下使得收益更高。 所以&#xff0c;做风控的不是一味地追求耕…

VS Code + VUE 代码自动格式化配置

插件列表 ESLintVetur setting.json { "[vue]": { "editor.defaultFormatter": "octref.vetur" }, "[javascript]": { "editor.defaultFormatter": "vscode.typescript-language-features" }, …

抖音双11进入决赛圈,爆款王炸单品竟是.....

今年&#xff0c;抖音将双11战线拉长&#xff0c;给足品牌和消费者时间备战&#xff0c;第一轮抢跑期战绩亮眼&#xff0c;多项双11销售增长记录被刷新&#xff0c;引爆全域流量。最后几天&#xff0c;抖音商城全面进入终局厮杀阶段&#xff0c;爆发期下半程对比抢跑期增速放缓…

Scala爬虫实战:采集网易云音乐热门歌单数据

导言 网易云音乐是一个备受欢迎的音乐平台&#xff0c;汇集了丰富的音乐资源和热门歌单。这些歌单涵盖了各种音乐风格和主题&#xff0c;为音乐爱好者提供了一个探索和分享音乐的平台。然而&#xff0c;有时我们可能需要从网易云音乐上获取歌单数据&#xff0c;以进行音乐推荐…

后端面试问题(学习版)

JAVA相关 JAVA语言概述 1. 一个".java"源文件中是否可以包含多个类&#xff1f;有什么限制&#xff1f; 可以。 一个源文件可以声明多个类&#xff0c;但是最多只能有一个类使用public进行声明 且要求声明public的类的类名与源文件相同。 2. Java的优势&#xff…

Python中的del用法

大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 如果有什么疑惑/资料需要的可以点击文章末尾名片领取源码 python中的del用法比较特殊&#xff0c;新手学习往往产生误解&#xff0c;弄清del的用法&#xff0c;可以帮助深入理解python的内存方面的问题。 python的del不同于C的fre…