【深度学习实战(6)】搭建通用的语义分割推理流程

news2024/9/21 11:12:12

一、代码

#---------------------------------------------------#
#   检测图片
#---------------------------------------------------#
def detect_image(self, image, count=False, name_classes=None):
    #---------------------------------------------------------#
    #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
    #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
    #---------------------------------------------------------#
    image       = cvtColor(image)
    #---------------------------------------------------#
    #   对输入图像进行一个备份,后面用于绘图
    #---------------------------------------------------#
    old_img     = copy.deepcopy(image)
    orininal_h  = np.array(image).shape[0]
    orininal_w  = np.array(image).shape[1]
    #---------------------------------------------------------#
    #   给图像增加灰条,实现不失真的resize
    #   也可以直接resize进行识别
    #---------------------------------------------------------#
    image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
    #---------------------------------------------------------#
    #   添加上batch_size维度
    #---------------------------------------------------------#
    image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

    with torch.no_grad():
        images = torch.from_numpy(image_data)
        if self.cuda:
            images = images.cuda()
            
        #---------------------------------------------------#
        #   图片传入网络进行预测
        #---------------------------------------------------#
        pr = self.net(images)[0]
        #---------------------------------------------------#
        #   取出每一个像素点的种类
        #---------------------------------------------------#
        pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
        #--------------------------------------#
        #   将灰条部分截取掉
        #--------------------------------------#
        pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
        #---------------------------------------------------#
        #   进行图片的resize
        #---------------------------------------------------#
        pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
        #---------------------------------------------------#
        #   取出每一个像素点的种类
        #---------------------------------------------------#
        pr = pr.argmax(axis=-1)

        seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
        #------------------------------------------------#
        #   将新图片转换成Image的形式
        #------------------------------------------------#
        image   = Image.fromarray(np.uint8(seg_img))
        #------------------------------------------------#
        #   将新图与原图及进行混合
        #------------------------------------------------#
        image   = Image.blend(old_img, image, 0.7)

二、代码逐步debug调试

(1)读图

#---------------------------------------------------------#
#   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
image       = cvtColor(image)

在这里插入图片描述

(2) Letterbox

在这里插入图片描述
在这里插入图片描述
无论输入的图片尺寸多大,都会经过letter_box后,变为512x512尺寸

(3) 归一化、HWC 转 CHW,并expand维度到NCHW,转tensor

def preprocess_input(image):
    image /= 255.0
    return image
    
#---------------------------------------------------------#
#   添加上batch_size维度
#---------------------------------------------------------#
image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

在这里插入图片描述

(4) 前向传播

#---------------------------------------------------#
#   图片传入网络进行预测
#---------------------------------------------------#
pr = self.net(images)[0]

在这里插入图片描述
21个channel代表(20+1)个类别,512x512为模型输入及输入尺寸

(5) softmax 计算像素类别概率

#---------------------------------------------------#
#   取出每一个像素点的种类
#---------------------------------------------------#
pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()

在这里插入图片描述

经过softmax后,512x512的mask图中,每个位置(x,y)对应的21个channel的值和为1。

(6) 截取灰条部分,并resize到原图尺寸(逆letter_box)

            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   进行图片的resize
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)

pr类型是np,array,所以可以通过这种方式进行逆letter_box操作,将mask的宽高,还原到原始输入图片的宽高。

(7) 利用argmax,计算每个像素属于的类别

#---------------------------------------------------#
#   取出每一个像素点的种类
#---------------------------------------------------#
pr = pr.argmax(axis=-1)

返回最后一个维度(channel)中,最大值所对应的索引,即类别。例如,像素点(x1,y1)所对应的21个channel中,第5个channel的值最大,则像素点(x1,y1)对应类别则是class=5。

(8) 可视化

seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
#------------------------------------------------#
#   将新图片转换成Image的形式
#------------------------------------------------#
image   = Image.fromarray(np.uint8(seg_img))
#------------------------------------------------#
#   将新图与原图及进行混合
#------------------------------------------------#
image   = Image.blend(old_img, image, 0.7)

在这里插入图片描述
将预测的结果与原图进行混合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1597150.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【题目】【信息安全管理与评估】2022年国赛高职组“信息安全管理与评估”赛项样题6

【题目】【信息安全管理与评估】2022年国赛高职组“信息安全管理与评估”赛项样题5 信息安全管理与评估 网络系统管理 网络搭建与应用 云计算 软件测试 移动应用开发 任务书,赛题,解析等资料,知识点培训服务 添加博主wx:liuliu548…

Decorator 装饰

意图 动态的给一个对象添加一些额外的职责。就增加功能而言,Decorator模式比生成子类更加灵活 结构 其中: Component定义一个对象接口,可以给这些对象动态的添加职责。ConcreteComponent定义一个对象,可以给这个对象添加一些职…

C++修炼之路之list模拟实现--C++中的双向循环链表

目录 引言 一:STL源代码中关于list的成员变量的介绍 二:模拟实现list 1.基本结构 2.普通迭代器 const迭代器的结合 3.构造拷贝构造析构赋值重载 清空 4.inserterase头尾插入删除 5.打印不同数据类型的数据《使用模板加容器来完成》 三&#xf…

水库之大坝安全监测系统解决方案

一、系统介绍 水库之大坝安全监测系统主要包括渗流监测系统、流量监测系统、雨量监测系统、沉降监测系统组成。每一个监测系统由监测仪器及自动化数据采集装置(内置通信装置、防雷设备)、附件(电缆、通信线路、电源线路)等组成&a…

YOLO算法改进Backbone系列之:HAT-Net

本文旨在解决ViT中与多头自我关注(MHSA)相关的高计算/空间复杂性问题。为此,我们提出了分层多头自注意(H-MHSA),这是一种以分层方式计算自注意的新方法。具体来说,我们首先按照通常的方法将输入…

llama-factory SFT系列教程 (二),大模型在自定义数据集 lora 训练与部署

文章目录 简介支持的模型列表2. 添加自定义数据集3. lora 微调4. 大模型 lora 权重,部署问题 参考资料 简介 文章列表: llama-factory SFT系列教程 (一),大模型 API 部署与使用llama-factory SFT系列教程 (二),大模型在自定义数…

ClickHouse--18--argMin() 和argMax()函数

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 argMin() 和argMax()函数业务场景使用案例1.准备表和数据:业务场景一:查看salary 最高和最小的user业务场景二:根据更新时间获取…

一种基于OpenCV的图片倾斜矫正方法

需求描述: 对倾斜的图片进行矫正,返回倾斜角度和矫正后的图片。 解决方法: 1、各种角度点被投影到一个累加器阵列中,其中倾斜角度可以定义为在最大化对齐的搜索间隔内的投影角度。 2、以不同的角度旋转图像,并为每…

Chatgpt掘金之旅—有爱AI商业实战篇|编写代码业务|(十九)

演示站点: https://ai.uaai.cn 对话模块 官方论坛: www.jingyuai.com 京娱AI 一、程序员使用 ChatGPT 进行编码搞副业 程序员不仅拥有将抽象概念转化为实际应用的能力,还通常具备强大的逻辑思维和问题解决能力。然而,许多程序员并…

宝塔面板安装软件 提示需要[xxxMB]内存 强制不能安装

解决方法: 第一步: 编辑修改/www/server/panel/class/下的文件panelPlugin.py vi /www/server/panel/class/panelPlugin.py注释以下判断的内容: ## 第二步: 重启宝塔面板,然后安装即可 bash bt 1

ROS 2边学边练(25)-- 将多个节点组合到一个进程

前言 在ROS 2中,将多个节点(Nodes)组合到一个单独的进程(Process)中通常指的是使用“Composable Nodes”的特性。这个特性允许你定义可复用的组件(Components),然后将这些组件加…

如何在MobaXterm上使用rz命令

1、首先输入命令和想下载的文件,如下图: 2、按住ctrl鼠标右键,选择如下选项: 上传命令是rz,选择Receive...... 下载命令是sz,选择Send...... 3、我这里是要把Linux上的文件下载到我的本地window磁盘&…

Django之rest_framework(三)

一、GenericAPIView的使用 rest_framework.generics.GenericAPIView 继承自APIVIew,主要增加了操作序列化器和数据库查询的方法,作用是为下面Mixin扩展类的执行提供方法支持。通常在使用时,可搭配一个或多个Mixin扩展类 1.1、属性 serializer_class 指明视图使用的序列化器…

记录一下买了腾讯云服务器后如何第一次连MobaXterm

首先是你要用SwitchHost把hosts的映射地址改成你新买的服务器的(如果你没这个软件,可以直接在etc/hosts里改 ) 再连MobaXterm 然后,关键的来了 成功!

2024/4/15 网络编程day3

一、TCP机械臂测试 通过w(红色臂角度增大)s(红色臂角度减小)d(蓝色臂角度增大)a(蓝色臂角度减小)按键控制机械臂 注意:关闭计算机的杀毒软件,电脑管家,防火墙 1&#…

openGauss学习笔记-261 openGauss性能调优-使用Plan Hint进行调优-将部分Error降级为Warning的Hint

文章目录 openGauss学习笔记-261 openGauss性能调优-使用Plan Hint进行调优-将部分Error降级为Warning的Hint261.1 功能描述261.2 语法格式261.3 示例261.3.1 忽略非空约束261.3.2 忽略唯一约束261.3.3 忽略分区表无法匹配到合法分区261.3.4 更新/插入值向目标列类型转换失败 o…

3.MMD快捷键操作及人物绑定配饰

快捷键 1. 模型界面切换 按一下TAB键,就从人物模型切换到照明模型 再按一下TAB键,就能从照明模型切换回人物模型 2. 选中全部模型 当模型界面是人物模型时 而且电脑输入法时英文时 按一下A键,可以把人物骨骼全部选中,方便旋转…

互联网轻量级框架整合之MyBatis配置详解

MyBatis核心配置文件mybatis-config.xml里有诸多配置项&#xff0c;但常用的就无非就如下这么多 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE configuration PUBLIC "-//mybatis.org//DTDConfig3.0//EN" "https://mybati…

【爬虫开发】爬虫从0到1全知识md笔记第5篇:Selenium课程概要,selenium的其它使用方法【附代码文档】

爬虫开发从0到1全知识教程完整教程&#xff08;附代码资料&#xff09;主要内容讲述&#xff1a;爬虫课程概要&#xff0c;爬虫基础爬虫概述,,http协议复习。requests模块&#xff0c;requests模块1. requests模块介绍,2. response响应对象,3. requests模块发送请求,4. request…

“成像光谱遥感技术中的AI革命:ChatGPT在遥感领域中的应用“

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境&#xff0c;是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型&#xff0c;在理解和生成人类语言方面表现出了非凡的能力。本文重点介绍ChatGPT在遥感中的应用&#xff0c;人工智能…