利用torchvision库实现目标检测与语义分割

news2025/1/9 2:18:02

一、介绍

利用torchvision库实现目标检测与语义分割。

二、代码

1、目标检测

from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torchvision
import numpy as np
import cv2
import random


COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def get_prediction(img_path, threshold):
    # 加载 mask_r_cnn 模型进行目标检测
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    img = Image.open(img_path)
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    pred = model([img])
    pred_score = list(pred[0]['scores'].detach().numpy())
    print(pred[0].keys())  # ['boxes', 'labels', 'scores', 'masks']
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]  # num of boxes
    pred_masks = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()
    pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
    pred_masks = pred_masks[:pred_t + 1]
    pred_boxes = pred_boxes[:pred_t + 1]
    pred_class = pred_class[:pred_t + 1]
    return pred_masks, pred_boxes, pred_class


def random_colour_masks(image):
    colours = [[0, 255, 0], [0, 0, 255], [255, 0, 0], [0, 255, 255], [255, 255, 0], [255, 0, 255], [80, 70, 180],
               [250, 80, 190], [245, 145, 50], [70, 150, 250], [50, 190, 190]]
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    r[image == 1], g[image == 1], b[image == 1] = colours[random.randrange(0, 10)]
    coloured_mask = np.stack([r, g, b], axis=2)
    return coloured_mask


def instance_segmentation_api(img_path, threshold=0.5, rect_th=3, text_size=2, text_th=2):
    masks, boxes, cls = get_prediction(img_path, threshold)
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    for i in range(len(masks)):
        rgb_mask = random_colour_masks(masks[i])
        randcol = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)
        cv2.rectangle(img, boxes[i][0], boxes[i][1], color=randcol, thickness=rect_th)
        cv2.putText(img, cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, randcol, thickness=text_th)
    plt.figure(figsize=(20, 30))
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    cv2.imwrite('result_det.jpg', img)


if __name__ == '__main__':
    instance_segmentation_api('horse.jpg')

 

 

2、语义分割

import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from torchvision import models
from torchvision import transforms


def pre_img(img):
    if img.mode == 'RGBA':
        a = np.asarray(img)[:, :, :3]
        img = Image.fromarray(a)
    return img


def decode_seg_map(image, nc=21):
    label_colors = np.array([(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0),
                             (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128),
                             (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
                             (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
                             (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    for l in range(0, nc):
        idx = image == l
        r[idx] = label_colors[l, 0]
        g[idx] = label_colors[l, 1]
        b[idx] = label_colors[l, 2]

    return np.stack([r, g, b], axis=2)


if __name__ == '__main__':
    # 加载 deep_lab_v3 模型进行语义分割
    model = models.segmentation.deeplabv3_resnet101(pretrained=True)
    model = model.eval()

    img = Image.open('horse.jpg')
    print(img.size)  # (694, 922)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

    im = pre_img(img)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_img = transform(im).unsqueeze(0)  # resize
    tt = np.transpose(input_img.detach().numpy()[0], (1, 2, 0))  # transpose
    print(tt.shape)  # (224, 224, 3)
    plt.imshow(tt)
    plt.axis('off')
    plt.show()

    output = model(input_img)
    print(output.keys())  # odict_keys(['out', 'aux'])
    print(output['out'].shape)  # torch.Size([1, 21, 224, 224])
    output = torch.argmax(output['out'].squeeze(), dim=0).detach().cpu().numpy()
    result_class = set(list(output.flat))
    print(result_class)  # {0, 13, 15}

    rgb = decode_seg_map(output)
    print(rgb.shape)  # (224, 224, 3)
    img = Image.fromarray(rgb)
    img.save('result_seg.jpg')
    plt.axis('off')
    plt.imshow(img)
    plt.show()

 

 

三、参考

Pytorch预训练模型、内置模型实现图像分类、检测和分割

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/918286.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[保研/考研机试] KY188 哈夫曼树 北京邮电大学复试上机题 C++实现

题目链接: 哈夫曼树_牛客题霸_牛客网哈夫曼树,第一行输入一个数n,表示叶结点的个数。需要用这些叶结点生成哈夫曼树,根据哈夫曼树。题目来自【牛客题霸】https://www.nowcoder.com/share/jump/437195121692781391110 描述 哈夫…

Halcon错误 #2021: System clock has been set back.

修复"Halcon#2021 System clock has been set back."一键即可解决。

照片怎么换背景,换背景的简单方法

你是否曾经为了照片背景不合适而苦恼?是否曾经因为照片背景影响美观而错过了重要的纪念时刻?今天,来为你介绍以后照片抠图换背景的简单方法,让你不再需要担心照片背景的问题!一起来看看吧! 有时候照片背景…

Pycharm 控制台 输出 中文 乱码 黑方块

问题: 解决: 打开设置 》编辑器 》常规 》控制台》默认编码(由系统默认GBK改为UTF-8)

【回味“经典”】DFS练习题解(工作分配问题,最大平台)

这篇文章是一年前写的 走进“深度搜索基础训练“,踏入c算法殿堂(四)和 走进“深度搜索基础训练“,踏入c算法殿堂(二)的重编版。 希望以此,唤起对那位故人的回忆。 【搜索与回溯算法】工作分配问…

Authing 官网新升级,「客户第一」是我们的方法论

赶在立秋前,我们上线了全新一版官网。 官网链接:http://www.authing.com 如果你说,在几个月前我会怎么描述我们的官网,我会说:它很好,很标准。和其它绝大多数企业的官网一样,它作为展示信息的页…

MQ消息队列(主要介绍RabbitMQ)

消息队列概念&#xff1a;是在消息的传输过程中保存消息的容器。 作用&#xff1a;异步处理、应用解耦、流量控制..... RabbitMQ&#xff1a; SpringBoot继承RabbitMQ步骤&#xff1a; 1.加入依赖 <dependency><groupId>org.springframework.boot</groupId&g…

汽配企业MES管理系统如何追溯产品质量问题

随着汽车行业的快速发展&#xff0c;汽配行业也面临着越来越严格的质量要求。为了满足客户需求并提高产品质量&#xff0c;汽配企业需要实现生产过程的可追溯性。MES管理系统解决方案作为生产过程的核心管理系统&#xff0c;可以通过记录生产数据和流程&#xff0c;实现产品质量…

寡肽-54/Oligopeptide-54, CG-Nokkin---------一种新型的促进头发生长的多肽

功效与应用----寡肽-54 1. 头发色素沉积和逆转头发变白过程 2. 刺激头发生长 1. Hair pigment deposition and reversal of hair whitening process 2. Stimulate hair growth 作用机理----寡肽-54 寡肽-54&#xff0c;oligopeptide-54&#xff08;CG nokkin&#xff09;增…

基于量子粒子群算法(QPSO)优化LSTM的风电、负荷等时间序列预测算法(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

【C++杂货铺】探索string的底层实现

文章目录 一、成员变量二、成员函数2.1 默认构造函数2.2 拷贝构造函数2.3 operator2.4 c_str()2.5 size()2.6 operator[ ]2.7 iterator2.8 reserve2.9 resize2.10 push_back2.11 append2.12 operator2.13 insert2.14 erase2.15 find2.16 substr2.17 operator<<2.18 opera…

【数据结构】 LinkedList的模拟实现与使用

文章目录 &#x1f340;什么是LinkedList&#x1f334;LinkedList的模拟实现&#x1f6a9;创建双链表&#x1f6a9;头插法&#x1f6a9;尾插法&#x1f6a9;任意位置插入&#x1f6a9;查找关键字&#x1f6a9;链表长度&#x1f6a9;打印链表&#x1f6a9;删除第一次出现关键字为…

【技术】安防视频监控平台EasyNVR平台启用国标级联的操作步骤

安防视频监控汇聚EasyNVR视频集中存储平台&#xff0c;是基于RTSP/Onvif协议的安防视频平台&#xff0c;可支持将接入的视频流进行全平台、全终端分发&#xff0c;分发的视频流包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等格式。 为提高用户体验&#xff0c;让用户更加便捷…

【Midjourney电商与平面设计实战】创作效率提升300%

不得不说&#xff0c;最近智能AI的话题火爆圈内外啦。这不&#xff0c;战火已经从IT行业燃烧到设计行业里了。 刚研究完ChatGPT&#xff0c;现在又出来一个AI作图Midjourney。 其视觉效果令不少网友感叹&#xff1a;“AI已经不逊于人类画师了!” 现如今&#xff0c;在AIGC 热…

CSS实现一个交互感不错的卡片列表

0、需求分析 横向滚动鼠标悬停时突出显示 默认堆叠展示鼠标悬停时&#xff0c;完整展示当前块适当旋出效果 移动端样式优化、磁吸效果美化滚动条 1、涉及的主要知识块 flex 布局css 简单变换过渡 transform、transition 渐变色函数 linear-gradient… 伪类、伪元素 滚动条、…

突破欧美技术垄断,国产磁悬浮人工心脏再闯关

“现在身体状态还不错&#xff0c;一些不太剧烈的运动也可以参加。”一年前&#xff0c;湖北武汉市东西湖区的李女士突发暴发性心肌炎&#xff0c;出现心力衰竭。植入国产全磁悬浮人工心脏治疗后&#xff0c;现在李女士能正常生活。 心力衰竭是全球医学的重大挑战。据统计&…

猫云域名防红系统源码

大致功能&#xff1a;支持会员充值功能&#xff0c;对接的易支付&#xff0c;本站可以自行搭建。支持添加广告信息&#xff0c;例如进入网站前&#xff0c;先跳转个广告支持设置访问流量限制等支持设置伪域名&#xff0c;长短后缀支持屏蔽ip支持添加多个入口与落地域名支持对接…

信息安全史:半个世纪以来飞跃发展的信息安全

从20世纪60年代开始信息技术稳步上升&#xff0c;信息安全现已成为一个重要的现代问题。在过去的十年中&#xff0c;美国的雅虎、微软和Equifax等大公司都曾遭到黑客攻击。尽管近年来网络安全得到极大提高&#xff0c;但2017年的WannaCry勒索蠕虫攻击证明&#xff0c;不仅仅是信…

多个微信号怎么定时发圈?

多个微信号怎么定时发圈&#xff1f;https://mp.weixin.qq.com/s?__bizMzg2Nzg4NjEzNg&mid2247487136&idx2&sn036e1d5f9d3790b12a103a90de474957&chksmceb5fbf7f9c272e1f8e9acf644ad3d4d97fb8fdce77ec5e2a2976527d4d180ad1c277b4336c8&token495803628&…

OpenGL —— 2.5、绘制第一个三角形(附源码,glfw+glad)(更新:纹理贴图)

源码效果 C++源码 纹理图片 需下载stb_image.h这个解码图片的库,该库只有一个头文件。 具体代码: vertexShader.glsl #version 330 corelayout(location = 0) in vec3 aPos; layout(location = 1) in vec3 aColor; layout(location = 2) in vec2 aUV;out vec4 outColor; ou…