《Pytorch深度学习和图神经网络(卷 2)》学习笔记——第二章

news2024/11/25 6:59:26

基于图片内容的处理任务

主要包括目标检测、图片分割两大任务。
目标检测:精度相对较高,主要是以检测框的方式,找出图片中目标物体所在坐标。模型运算量相对较小,相对较快。
图片分割:精度相对较低,主要是以像素点的集合方式,找出图片中目标物体边缘的具体像素点。模型运算量相对较大,相对较慢。

目标检测

单阶段:也叫Region-free方法,直接从模型获得预测结果,有YOLO、SSD、RetinaNet等。
两阶段:先检测包含实物的区域,再对该区域内的实物进行分类识别。有R-CNN、Faster R-CNN、Mask R-CNN等。
两阶段检测模型在检测框方面表现的精度更高,单阶段在分类方面表现出的精度更高。

图片分割

语义分割:能将图片中具有不同语义的部分分开
实例分割:能描述出目标物体的轮廓(比检测框更为精细),比语义分割还能识别出单个的具体个体。

非极大值抑制算法(Non-Max Suppression)

目标检测中,会检测出很多个结果,可能会出现重复物体(中心和大小略有不同),要用NMS对检测结果进行去重。具体过程如下:
从所有的检测框中找到置信度较大(大于某个阈值)的检测框
逐一计算其与剩余检测框的区域面积重叠率(IOU)
如果IOU大于一定阈值则剔除
重复上述过程
IOU(Intersection Over Union)交并比

Mask R-CNN 模型

属于两阶段模型,具体步骤如下:
用NMS将一张图分成多个子框,称作锚点,不同尺寸存在重叠。
在图片中为具体实物标注坐标。
根据坐标和IOU计算那些锚点是前景(IOU高的),哪些是背景(IOU低的)。
计算前景的锚点坐标和实物标注的坐标,计算二者的相对位移和长宽的缩放比例。
最终检测区域会转换为一堆锚点的分类(前景和背景)和回归任务(偏移和缩放),每张图片会将其自身标注的信息转化为锚点对应的标签,让模型对已有的锚点进行训练或识别。
在模型中实现区域检测功能的网络被称作区域生成网络(Region Proposal Network),实际处理中会从RPN的输出结果中选取前景概率较高的一定数量的锚点作为感兴趣区域(Region Of Interest),送到第二阶段的网络中进行计算。

完整步骤:
提取主特征:又称作骨干网络,从图片中提取出一些不同尺寸的特征,通常用一些预训练好的模型(VGG、Inception、ResNet等),这些获得的特征数据被称作特征图。
特征融合:用特征融合金字塔(Feature Pyramin Network)整合骨干网络中的不同尺寸,最终的特征信息用于后面的RPN和最终的分类器网络的计算。
提取ROI:主要通过RPN来实现,在众多锚点计算前景背景的预测值,基于锚点的便宜,然后对前景概率较大的ROI用NMS去重,最终结果取出指定个数的ROI用于后续的计算。
ROI池化:用区域对齐(ROI Align)的方式实现,将特征融合的结果当做图片,按照ROI中的区域框位置从图中取出对应内容,将形状统一成指定大小,用于后面的计算。
最终检测:对上一步的结果一次进行分类,设置矩形坐标、实物像素分割处理。得到最终结果。

实例:使用Mask R-CNN模型进行目标检测与语义分割

cv2.error: OpenCV(4.8.0) 👎 error: (-5:Bad argument)

in function ‘putText’
Overload resolution failed:
Can’t parse ‘org’. Sequence item with index 0 has a wrong type

in function ‘rectangle’
Overload resolution failed:
Can’t parse ‘pt1’. Sequence item with index 0 has a wrong type
argument for rectangle() given by name (‘color’) and position (3)
上述两种报错是因为这两个函数坐标只能是int类型
cv2.rectangle()
cv2.putText()

可以用这种方式对元组里的数据强制类型转换

# initializing list
test_list = [(4, 5), (6, 7), (1, 4), (8, 10)]
  
# printing original list
print("The original list is : " + str(test_list))
  
# Change Datatype of Tuple Values
# Using enumerate() + loop
# converting to string using str()
for idx, (x, y) in enumerate(test_list):
    test_list[idx] = (x, str(y))
  
# printing result 
print("The converted records : " + str(test_list)) 

完整代码:

from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torchvision
import numpy as np
import cv2
import random
import torch

#加载模型
model = torchvision.models.detection.maskrcnn_resnet50_fpn()
model.load_state_dict(torch.load(r"pytorch\2-chapter1\some3\maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth")) 	#true 代表下载
model = model.eval()
model.eval()

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' ]

len(COCO_INSTANCE_CATEGORY_NAMES) #91

def get_prediction(img_path, threshold):
  img = Image.open(img_path)
  transform = T.Compose([T.ToTensor()])
  img = transform(img)
  pred = model([img])
  print('pred')
  print(pred)
  pred_score = list(pred[0]['scores'].detach().numpy())
  pred_t = [pred_score.index(x) for x in pred_score if x>threshold][-1]
  print("masks>0.5")
  print(pred[0]['masks']>0.5)
  masks = (pred[0]['masks']>0.5).squeeze().detach().cpu().numpy()
  print("this is masks")
  print(masks)
  pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
  pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().numpy())]
  masks = masks[:pred_t+1]
  pred_boxes = pred_boxes[:pred_t+1]
  pred_class = pred_class[:pred_t+1]
  return masks, pred_boxes, pred_class


def random_colour_masks(image):
  colours = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180],[250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]
  r = np.zeros_like(image).astype(np.uint8)
  g = np.zeros_like(image).astype(np.uint8)
  b = np.zeros_like(image).astype(np.uint8)
  randcol = colours[random.randrange(0,10)]
  r[image == 1] = randcol[0]
  g[image == 1] = randcol[1]
  b[image == 1] = randcol[2]
  coloured_mask = np.stack([r, g, b], axis=2)
  return coloured_mask,randcol

def instance_segmentation_api(img_path, threshold=0.5, rect_th=3, text_size=5, text_th=5):
  masks, boxes, pred_cls = get_prediction(img_path, threshold) #调用模型
  img = cv2.imread(img_path)
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  for i in range(len(masks)):
    rgb_mask,randcol = random_colour_masks(masks[i])   #为掩码区填充随机值
    img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)

    test_list = [boxes[i][0], boxes[i][1]]

    for idx, (x, y) in enumerate(test_list):
        test_list[idx] = (int(x), int(y))
    # print(test_list)

    cv2.rectangle(img, test_list[0], test_list[1],color=randcol, thickness=rect_th)

    test_list = [boxes[i][0]]

    for idx, (x, y) in enumerate(test_list):
        test_list[idx] = (int(x), int(y))
    print(test_list[0])

    # print(img)
    cv2.putText(img,pred_cls[i], test_list[0], cv2.FONT_HERSHEY_SIMPLEX, text_size, randcol,thickness=text_th)
    plt.figure(figsize=(20,30))
  plt.imshow(img)
  plt.xticks([])
  plt.yticks([])
  plt.show()

#显示模型结果
instance_segmentation_api(r'E:\desktop\Home_Code\pytorch\2-chapter1\some1\horse.jpg')

在这里插入图片描述
本书后面的内容与我目前的学习需求关系不大,所以在此结束!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/783832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【工具-jmeter】jmeter 入门级 demo 练习

目录 前言: 1. Jmeter 准备 1.1 jmeter 安装包下载 1.2 jmeter 启动 1.3 jmeter 语言选择 2. Jmeter 运行 1 个 Web 请求的 demo 2.1 添加 1 个 Thread Group 线程组 2.2 添加 1 个 HTTP Request 请求 2.3 乱码问题 2.4 添加 1 个 HTTP Header 请求头 2.…

开发中遇到的 cookie 问题

1. cookie 无法跨域携带问题 尽管已经登录,但是请求接口返回状态码:202,msg: 未登录,如下图所示; 1.1 XMLHttpRequest.withCredentials未设置 如果需要跨域 AJAX 请求发送 Cookie,需要withCre…

【UE】虚幻网络同步

UE网络官方文档链接:https://docs.unrealengine.com/5.2/zh-CN/networking-overview-for-unreal-engine/ 虚幻的网络模式 服务器作为游戏主机,保留一个真实授权的游戏状态。换句话说,服务器是多人游戏实际发生的地方。客户端会远程控制其在服…

SpringBoot Redis 使用Lettuce和Jedis配置哨兵模式

Redis 从入门到精通【应用篇】之SpringBoot Redis 配置哨兵模式 Lettuce 和Jedis 文章目录 Redis 从入门到精通【应用篇】之SpringBoot Redis 配置哨兵模式 Lettuce 和Jedis前言Lettuce和Jedis区别1. 连接方式2. 线程安全性 教程如下1. Lettuce 方式配置1.1. 添加 Redis 和 Let…

Java项目里添加python解析器

java项目里配置了SDK为1.8,添加python文件时会无法解析。 提示让模块配置Python解析器,点击 配置python解析器 ,弹出如下: 应用即可。

【机器学习】异常检测

异常检测 假设你是一名飞机涡扇引擎工程师,你在每个引擎出厂之前都需要检测两个指标——启动震动幅度和温度,查看其是否正常。在此之前你已经积累了相当多合格的发动机的出厂检测数据,如下图所示 我们把上述的正常启动的数据集总结为 D a t…

【Linux命令200例】chattr改变文件的扩展属性

🏆作者简介,黑夜开发者,全栈领域新星创作者✌,2023年6月csdn上海赛道top4。 🏆本文已收录于专栏:Linux命令大全。 🏆本专栏我们会通过具体的系统的命令讲解加上鲜活的实操案例对各个命令进行深入…

【人工智能】博弈、极小极大值、α-β剪枝、截断测试

文章目录 博弈极小极大值α-β剪枝截断测试博弈 极小极大值 假设两个玩家都以最大化自身利用进行博弈举例: 计算机假设在它移动后,对手会选择最小化的行动计算机在考虑自己的行动和对手的最佳行动后选择最佳行动算法实现

【python】在matlab中调用python

参考 Matlab调用Python - 知乎 (zhihu.com) 说一下我犯的错误: 1、电脑上有没有python都可以,我以为anaconda里的python不行,又重新下了一个python3.8 实际上导入的时候可以用 pyversion(D:\myDownloads\anaconda\envs\pytorch38\pytho…

Docker 全栈体系(五)

Docker 体系(高级篇) 二、DockerFile解析 1. 是什么? Dockerfile是用来构建Docker镜像的文本文件,是由一条条构建镜像所需的指令和参数构成的脚本。 1.1 概述 1.2 官网 https://docs.docker.com/engine/reference/builder/ 1…

freeBSD:ssh登录root

/etc/inetd.conf ee /etc/inetd.conf 去掉# /etc/rc.conf ee /etc/rc.conf 添加一句 sshd_enable"YES" /etc/ssh/sshd_config vi /etc/ssh/sshd_config 22行可以修改端口号,非必要就默认22 36行 去掉# 后面修改成 yes 61 PasswordAuthentication…

Python处理Elasticsearch

简介:Elasticsearch 是一个分布式、高扩展、高实时的搜索与数据分析引擎。它能很方便的使大量数据具有搜索、分析和探索的能力。充分利用Elasticsearch的水平伸缩性,能使数据在生产环境变得更有价值。Elasticsearch 的实现原理主要分为以下几个步骤&…

Golang数据库连接池技术原理与实现

1 为什么需要连接池? 如果不用连接池,而是每次请求都创建一个连接是比较昂贵的,因此需要完成3次tcp握手。同时在高并发场景下,由于没有连接池的最大连接数限制,可以创建无数个连接,耗尽文件描述符。连接池…

【软件测试】什么是selenium

1.seleniumJava环境搭建 前置条件: Java最低版本要求为8,浏览器使用chrome浏览器 1.1下载chrome浏览器 https://www.google.cn/chrome/ 1.2查看浏览器版本 点击关于Google chrome. 记住版本的前三个数. 1.3下载浏览器驱动 http://chromedriver.chromium.org/downloads 下载…

JS案例:在浏览器实现自定义菜单

目录 前言 设计思路 BaseElem Menu CustomElement BaseDrag Drag Resize 最终效果 总结 相关代码 前言 分享一下之前公司实现自定义菜单的思路,禁用浏览器右键菜单,使用自定义的菜单将其代替,主要功能有:鼠标右键调出菜…

二、基本数据类型和表达式

2.1数据类型 数据类型占用字节数取值范围bool1true 或 falsechar1-128 到 127 或 0 到 255 (取决于是否带符号)unsigned char10 到 255short2-32,768 到 32,767unsigned short20 到 65,535int4-2,147,483,648 到 2,147,483,647unsigned int40 到 4,294,…

ES6基础知识二:ES6中数组新增了哪些扩展?

一、扩展运算符的应用 ES6通过扩展元素符…&#xff0c;好比 rest 参数的逆运算&#xff0c;将一个数组转为用逗号分隔的参数序列 console.log(...[1, 2, 3]) // 1 2 3console.log(1, ...[2, 3, 4], 5) // 1 2 3 4 5[...document.querySelectorAll(div)] // [<div>, &l…

12 扩展Spring MVC

12.1 实现页面跳转功能 页面跳转功能&#xff1a;访问localhost:8081/jiang会自动跳转到另一个页面。 首先&#xff0c;在config包下创建一个名为MyMvcConfig的配置类&#xff1a; 类上加入Configuration注解&#xff0c;类实现WebMvcConfiger接口&#xff0c;实现里面的视图跳…

Python入门【列表元素访问和计数 、切片操作、列表的遍历、复制列表所有的元素到新列表对象、多维列表、元组tuple】(五)

&#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是爱敲代码的小王&#xff0c;CSDN博客博主,Python小白 &#x1f4d5;系列专栏&#xff1a;python入门到实战、Python爬虫开发、Python办公自动化、Python数据分析、Python前后端开发 &#x1f4e7;如果文章知识点有错误…

OSI 和 TCP/IP 网络分层模型详解(基础)

OSI模型: 即开放式通信系统互联参考模型&#xff08;Open System Interconnection Reference Model&#xff09;&#xff0c;是国际标准化组织&#xff08;ISO&#xff09;提出的一个试图使各种计算机在世界范围内互连为网络的标准框架&#xff0c;简称OSI。 OSI 七层模型 OS…