基于PyTorch搭建Mask-RCNN实现实例分割

news2024/11/26 19:31:29

基于PyTorch搭建Mask-RCNN实现实例分割

在这篇文章中,我们将讨论 Mask RCNN Pytorch 背后的理论以及如何在 PyTorch 中使用预训练的 Mask R-CNN 模型。

1. 语义分割、目标检测和实例分割

在之前的博客文章里介绍了语义分割和目标检测(如果感兴趣可以参考以下文章):

  1. 图像语义分割概述
  2. Pytorch实现图像语义分割(初体验)
  3. 基于PyTorch搭建FasterRCNN实现目标检测
  • 语义分割:为图像中的每个像素分配一个类标签(例如狗、猫、人、背景等)。
  • 目标检测:在对象检测中,我们为包含对象的边界框分配一个类标签。
  • 实例分割:将图像中的每个物体分割成独立的实例。

2. Mask R-CNN 架构

Mask R-CNN 的架构是 Faster R-CNN 的扩展。Faster R-CNN 架构具有以下组件

  • 卷积层:输入图像经过多个卷积层以创建特征图。
  • 区域生成网络(RPN:Region Proposal Network)。卷积层的输出用于训练网络,该网络提出包围对象的区域。
  • 分类器:相同的特征图也用于训练分类器,该分类器为框内的对象分配标签。

Faster R-CNN 比 Fast R-CNN 更快,因为特征图计算一次并由 RPN 和分类器重复使用。

Mask R-CNN 将这一想法更进一步。除了将特征图提供给 RPN 和分类器之外,它还用它来预测边界框内对象的二进制掩码。

看待 Mask R-CNN 的掩模预测部分的一种方式是,它是一个用于语义分割的全卷积网络(FCN)。唯一的区别是 FCN 应用于边界框,并且它与 RPN 和分类器共享卷积层。

3. 基于PyTorch搭建Mask R-CNN

3.1 输入和输出

该模型期望输入是形状为 (n, c , h, w) 的张量图像列表,其值范围为 0-1。图像的大小不需要固定。

  • n 是图像数
  • c 是通道数,对于 RGB 图像,为 3
  • h 是图像的高度
  • w 是图像的宽度

该模型返回边界框的坐标、模型预测将出现在输入图像中的类标签、标签的分数、标签中存在的每个类的掩码。

3.2 预训练模型

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

3.3 模型预测

实例分割的标签列表与对象检测任务相同。

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def get_prediction(img_path, threshold):
    img = Image.open(img_path)
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    pred = model([img])
    pred_score = list(pred[0]['scores'].detach().numpy())
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
    masks = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
    pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
    masks = masks[:pred_t + 1]
    pred_boxes = pred_boxes[:pred_t + 1]
    pred_class = pred_class[:pred_t + 1]
    return masks, pred_boxes, pred_class
  • 从图像路径获得图像
  • 使用 PyTorch 的变换将图像转换为图像张量
  • 图像通过模型来获取预测
  • 从模型中获取掩模、预测类和边界框坐标,并将软掩模制成二进制(0或1)。示例:猫的部分设为 1,图像的其余部分设为 0。

每个预测对象的蒙版都会从一组 11 种预定义颜色中随机获得颜色,以便在输入图像上可视化蒙版。

def random_colour_masks(image):
    colours = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180],[250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    r[image == 1], g[image == 1], b[image == 1] = colours[random.randrange(0, 10)]
    coloured_mask = np.stack([r, g, b], axis=2)
    return coloured_mask

3.4 实例分割

def instance_segmentation_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
    masks, boxes, pred_cls = get_prediction(img_path, threshold)
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    for i in range(len(masks)):
        rgb_mask = random_colour_masks(masks[i])
        img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)
        cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th)
        cv2.putText(img,pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)
        plt.figure(figsize=(20,30))
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
        plt.show()
  • 掩码、预测类和边界框通过 get_prediction 获得。
  • 每个蒙版都从 11 种颜色中随机选择一种颜色。
  • 使用 OpenCV 将每个掩模以 1:0.5 的比例添加到图像中。
  • 使用 cv2.rectangle 绘制边界框,并将类名注释为文本。
    显示最终输出

3.5 运行测试

3.6 完整代码

import random
import torchvision
from PIL import Image
from torchvision import transforms as T
import numpy as np
import cv2
from matplotlib import pyplot as plt

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def get_prediction(img_path, threshold):
    img = Image.open(img_path)
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    pred = model([img])
    pred_score = list(pred[0]['scores'].detach().numpy())
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
    masks = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
    pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
    masks = masks[:pred_t + 1]
    pred_boxes = pred_boxes[:pred_t + 1]
    pred_class = pred_class[:pred_t + 1]
    return masks, pred_boxes, pred_class


def random_colour_masks(image):
    colours = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180],[250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    r[image == 1], g[image == 1], b[image == 1] = colours[random.randrange(0, 10)]
    coloured_mask = np.stack([r, g, b], axis=2)
    return coloured_mask


def instance_segmentation_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
    masks, boxes, pred_cls = get_prediction(img_path, threshold)
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    for i in range(len(masks)):
        rgb_mask = random_colour_masks(masks[i])
        img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)
        cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th)
        cv2.putText(img,pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)
        plt.figure(figsize=(20,30))
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
        plt.show()


instance_segmentation_api('../img/cars.jpg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1023594.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【golang】调度系列之P

调度系列 调度系列之goroutine 调度系列之m 在前面两篇中,分别介绍了G和M,当然介绍的不够全面(在写后面的文章时我也在不断地完善前面的文章,后面可能也会有更加汇总的文章来统筹介绍GMP)。但是,抛开技术细…

华为云云耀云服务器L实例使用教学 | 访问控制-安全组配置规则 实例教学

文章目录 访问控制-安全组什么叫安全组安全组配置默认安全组配置安全组配置实例安全组创建安全组模板配置安全组模板:通用Web服务器 配置安全组规则安全组配置规则功能介绍修改允许特定IP地址访问Web 80端口服务建立仅允许访问特定目的地址的安全规则配置网络ACL对实…

开源数字孪生基础设施

开源数字孪生基础设施 开源数字基础设施 开源数字基础设施 开源软件是基础设施发展的一种模式,这是在2007年美国科学基金会发布的《认识基础设施:动力机制、冲突和设计》中得出的结论。在这份55页的报告中三次集中谈到了开源软件(Open Sourc…

1999-2018年地级市经济增长数据

1999-2018年地级市经济增长数据 1、时间:1999-2018年 2、指标: 行政区划代码、城市、年份、地区生产总值_当年价格_全市_万元、地区生产总值_当年价格_市辖区_万元、人均地区生产总值_全市_元、人均地区生产总值_市辖区_元、地区生产总值增长率_全市_…

MySQL使用C语言链接

MySQL使用C语言链接 MySQL connect接口介绍mysql_initmysql_real_connectmysql_querymysql_store_result\mysql_use_result()mysql_num_rowsmysql_num_fieldsmysql_fetch_fieldsmysql_fetch_rowmysql_close MySQL connect 使用C语言来连接数据库,本质上就是利用一些…

「聊设计模式」之命令模式(Command)

🏆本文收录于《聊设计模式》专栏,专门攻坚指数级提升,助你一臂之力,带你早日登顶🚀,欢迎持续关注&&收藏&&订阅! 前言 在面向对象设计中,设计模式是重要的一环。设计…

c:Bubble Sort

/*****************************************************************//*** \file SortAlgorithm.h* \brief 业务操作方法* VSCODE c11* \author geovindu,Geovin Du* \date 2023-09-19 ***********************************************************************/ #if…

前端知识以及组件学习总结

JS 常用方法 js中字符串常用方法总结_15种常见js字符串用法_<a href"#">leo</a>的博客-CSDN博客 <script>var str"heool"console.log(str.length);console.log(str.concat(" lyt"));console.log(str.includes("he&quo…

WebPack5基础使用总结(一)

WebPack5基础使用总结 1、WebPack1.1、开始使用1.2、基本配置 2、处理样式资源2.1、处理Css资源2.2、处理Less资源2.3、处理Sass和Scss资源2.4、处理Styl资源 3、处理图片资源3.1、输出资源情况3.2、对图片资源进行优化 4、修改输出资源的名称和路径4.1、自动清空上次打包资源 …

想了解期权分仓交易和开户?这里告诉你。

期想了解期权分仓交易和开户&#xff1f;这里告诉你。权就是合约交易&#xff0c;通过买卖认购和认沽期权合约实现未来是否能赚钱&#xff0c;具备做多和做空T0双向交易机制&#xff0c;期权分仓开户就是零门槛开通期权账户&#xff0c;下文介绍想了解期权分仓交易和开户&#…

经验分享|作为程序员之后了解到的算法知识

欢迎关注博主 六月暴雪飞梨花 或加入【六月暴雪飞梨花】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和技术…

Java————栈

一 、栈 Stack继承了Vector&#xff0c;Vector和ArrayList类似&#xff0c;都是动态的顺序表&#xff0c;不同的是Vector是线程安全的。 是一种特殊的线性表&#xff0c; 其只允许在固定的一端进行插入和删除元素操作。 进行数据插入和删除操作的一端称为栈顶&#xff0c;另…

《计算机视觉中的多视图几何》笔记(4)

4 Estimation – 2D Projective Transformations 本章主要估计这么几种2D投影矩阵&#xff1a; 2D齐次矩阵&#xff0c;就是从一个图像中的点到另外一个图像中的点的转换&#xff0c;由于点的表示都是齐次的&#xff0c;所以叫齐次矩阵3D到2D的摄像机矩阵基本矩阵三视图之间的…

基于conda的相关命令

conda 查看python版本环境 打开Anaconda Prompt的命令输入框 查看自己的python版本 conda env list激活相应的python版本(环境&#xff09; conda avtivate python_3.9 若输入以下命令可查看python版本 python -V #注意V是大写安装相应的包 pip install 包名5.查看已安装…

智能井盖:提升城市井盖安全管理效率

窨井盖作为城市基础设施的重要组成部分&#xff0c;其安全管理与城市的有序运行和群众的生产生活安全息息相关&#xff0c;体现城市管理和社会治理水平。当前&#xff0c;一些城市已经将智能化的窨井盖升级改造作为新城建的重要内容&#xff0c;推动窨井盖等“城市部件”配套建…

工控机通过Profinet转Modbus RTU网关连接变频器与电机通讯案例

在工业自动化系统中&#xff0c;工控机扮演着重要的角色&#xff0c;它是数据采集、处理和控制的中心。工控机通过Profinet转Modbus RTU网关连接变频器与电机通讯&#xff0c;为工业自动化系统中的设备之间的通信提供了解决方案。工控机通过Profinet转Modbus RTU网关的方式&…

(leetcode)单值二叉树

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 目录 题目&#xff1a; 思路&#xff1a; 代码&#xff1a; 画图与分析&#xff1a; 题目&#xff1a; 如果二叉树每个节点都具有相同的值&#xff0c;那么该二叉树就是单值二叉树。 只有给定的树是单值二叉树时&…

2023年以就业为目的学习Java还有必要吗?(文末送书)

目录 一、活力四射的 Java二、从零开始学会 Java三、准备工作四、基础知识五、进阶知识六、高级知识七、结语参与方式 大家好&#xff0c;我是哪吒。 文末送5本《Java编程动手学》 今天来探讨一个问题&#xff0c;现在学 Java 找工作还有优势吗&#xff1f; 在某乎上可以看到…

MS1861 视频处理与显示控制器 HDMI转MIPI LVDS转MIPI带旋转功能 图像带缩放,旋转,锐化

1. 基本介绍 MS1861 单颗芯片集成了 HDMI 、 LVDS 和数字视频信号输入&#xff1b;输出端可以驱动 MIPI(DSI-2) 、 LVDS 、 Mini-LVDS 以及 TTL 类型 TFT-LCD 液晶显示。可支持对输入视频信号进行滤波&#xff0c;图 像增强&#xff0c;锐化&#xff0c;对比度调节&am…

ai虚拟主播看车线上虚拟三维展示节约成本和资源

线上车展汽车3D展厅突破了前期虚拟和现实的障碍&#xff0c;使用户无论身在哪个城市&#xff0c;都可以随时随地在线3D看车&#xff0c;极大的方便了消费者的看车的线上体验。因此对企业来说&#xff0c;有购车意愿的顾客必然是会提高成交的可能性&#xff0c;那么如何满足顾客…