图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)

news2024/11/25 5:36:18

概述

DIS(Dichotomous Image Segmentation)是一种新的图像分割任务,旨在从自然图像中分割出高精度的物体。与传统的图像分割任务相比,DIS更侧重于具有单个或几个目标的图像,因此可以提供更丰富准确的细节。

为了研究DIS任务,研究人员创建了一个名为DIS5K的大规模、可扩展的数据集。DIS5K数据集包含了5,470张高分辨率图像,每张图像都配有高精度的二值分割掩码。这个数据集的建立有助于推动多个应用方向的发展,如图像去背景、艺术设计、模拟视图运动、基于图像的增强现实(AR)应用、基于视频的AR应用、3D视频制作等。

通过研究DIS任务和使用DIS5K数据集,研究人员可以探索新的图像分割方法,并为各种应用领域提供更精确、更可靠的图像分割技术,从而推动分割技术在更广泛的领域中的应用。

官网:https://xuebinqin.github.io/dis/index.html
Github:https://github.com/xuebinqin/DIS

数据集

图像二类分割是将图像分割成两个主要区域:前景和背景。在这种情况下,前景代表图像中的某个类别的物体,而背景则是除了该物体之外的所有内容。
官方公布了算所使用的数据集DIS5K, DIS5K数据集中的每张图像都经过了像素级别的手工标注,标注的真值掩码非常精确,每张图像的标记时间相当长。这种高精度的标注使得数据集中的每个像素都与其相应的类别关联起来,从而为模型提供了可靠的训练数据。这种高精度的标注是实现图像二类分割的关键,因为模型需要能够准确地识别和分割出前景物体。

在DIS5K数据集中,标注对象的类型多样,包括透明和半透明的物体,标注使用单个像素的二值掩码进行。这种精确的标注确保了模型训练的有效性和准确性,并且使得模型能够预测出高精度的物体分割结果。

DIS5K数据集网盘地址:https://pan.baidu.com/s/1umNk2AeBG5aB5kXlHTHdIg
提取码:7qfs

模型训练

模型训练可参考git上的官方的文档

模型推理

模型C++使用onnxruntime进行推理

#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>


class DIS
{
public:
	DIS(std::string model_path);
	void inference(cv::Mat& cv_src, cv::Mat& cv_mask);
private:
	std::vector<float> input_image_;
	int inpWidth;
	int inpHeight;
	int outWidth;
	int outHeight;
	const float score_th = 0;

	Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "DIS");
	Ort::Session* ort_session = nullptr;
	Ort::SessionOptions sessionOptions = Ort::SessionOptions();
	std::vector<char*> input_names;
	std::vector<char*> output_names;
	std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputs
	std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
};



DIS::DIS(std::string model_path)
{
	std::wstring widestr = std::wstring(model_path.begin(), model_path.end());
	//OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
	sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
	ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions);
	size_t numInputNodes = ort_session->GetInputCount();
	size_t numOutputNodes = ort_session->GetOutputCount();
	Ort::AllocatorWithDefaultOptions allocator;
	for (int i = 0; i < numInputNodes; i++)
	{
		input_names.push_back(ort_session->GetInputName(i, allocator));
		Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
		auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
		auto input_dims = input_tensor_info.GetShape();
		input_node_dims.push_back(input_dims);
	}
	for (int i = 0; i < numOutputNodes; i++)
	{
		output_names.push_back(ort_session->GetOutputName(i, allocator));
		Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
		auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
		auto output_dims = output_tensor_info.GetShape();
		output_node_dims.push_back(output_dims);
	}
	this->inpHeight = input_node_dims[0][2];
	this->inpWidth = input_node_dims[0][3];
	this->outHeight = output_node_dims[0][2];
	this->outWidth = output_node_dims[0][3];
}


void DIS::inference(cv::Mat& cv_src, cv::Mat& cv_mask)
{
	cv::Mat cv_dst;
	cv::resize(cv_src, cv_dst, cv::Size(this->inpWidth, this->inpHeight));
	this->input_image_.resize(this->inpWidth * this->inpHeight * cv_dst.channels());
	for (int c = 0; c < 3; c++)
	{
		for (int i = 0; i < this->inpHeight; i++)
		{
			for (int j = 0; j < this->inpWidth; j++)
			{
				float pix = cv_dst.ptr<uchar>(i)[j * 3 + 2 - c];
				this->input_image_[c * this->inpHeight * this->inpWidth + i * this->inpWidth + j] = pix / 255.0 - 0.5;
			}
		}
	}
	std::array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };

	auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
	Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info,
		input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());

	std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, &input_names[0],
		&input_tensor_, 1, output_names.data(), output_names.size());   // 开始推理
	float* pred = ort_outputs[0].GetTensorMutableData<float>();
	cv::Mat mask(outHeight, outWidth, CV_32FC1, pred);
	double min_value, max_value;
	minMaxLoc(mask, &min_value, &max_value, 0, 0);
	mask = (mask - min_value) / (max_value - min_value);
	cv::resize(mask, cv_mask, cv::Size(cv_src.cols, cv_src.rows));
}

void show_img(std::string name, const cv::Mat& img)
{
	cv::namedWindow(name, 0);
	int max_rows = 500;
	int max_cols = 600;
	if (img.rows >= img.cols && img.rows > max_rows) {
		cv::resizeWindow(name, cv::Size(img.cols * max_rows / img.rows, max_rows));
	}
	else if (img.cols >= img.rows && img.cols > max_cols) {
		cv::resizeWindow(name, cv::Size(max_cols, img.rows * max_cols / img.cols));
	}
	cv::imshow(name, img);
}

cv::Mat replaceBG(const cv::Mat cv_src, cv::Mat& alpha, std::vector<int>& bg_color)
{
	int width = cv_src.cols;
	int height = cv_src.rows;

	cv::Mat cv_matting = cv::Mat::zeros(cv::Size(width, height), CV_8UC3);

	float* alpha_data = (float*)alpha.data;
	for (int i = 0; i < height; i++)
	{
		for (int j = 0; j < width; j++)
		{
			float alpha_ = alpha_data[i * width + j];
			cv_matting.at < cv::Vec3b>(i, j)[0] = cv_src.at < cv::Vec3b>(i, j)[0] * alpha_ + (1 - alpha_) * bg_color[0];
			cv_matting.at < cv::Vec3b>(i, j)[1] = cv_src.at < cv::Vec3b>(i, j)[1] * alpha_ + (1 - alpha_) * bg_color[1];
			cv_matting.at < cv::Vec3b>(i, j)[2] = cv_src.at < cv::Vec3b>(i, j)[2] * alpha_ + (1 - alpha_) * bg_color[2];
		}
	}

	return cv_matting;
}

int main()
{
	DIS dis_net("isnet_general_use_720x1280.onnx");

	std::string path = "images";
	std::vector<std::string> filenames;
	cv::glob(path, filenames, false);

	for (auto file_name : filenames)
	{
		cv::Mat cv_src = cv::imread(file_name);
		//std::vector<cv::Mat> cv_dsts;
		cv::Mat cv_dst, cv_mask;
		dis_net.inference(cv_src, cv_mask);
		std::vector<int> color{255, 0, 0};
		cv_dst=replaceBG(cv_src, cv_mask, color);

		show_img("src", cv_src);
		show_img("mask", cv_mask);
		show_img("dst", cv_dst);

		cv::waitKey(0);
	}
}

python推理代码也依赖onnxruntime

import argparse
import cv2
import numpy as np
import onnxruntime
### onnxruntime load ['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx']  inference failed
class DIS():
    def __init__(self, modelpath, score_th=None):
        so = onnxruntime.SessionOptions()
        so.log_severity_level = 3
        self.net = onnxruntime.InferenceSession(modelpath, so)
        self.input_height = self.net.get_inputs()[0].shape[2]
        self.input_width = self.net.get_inputs()[0].shape[3]
        self.input_name = self.net.get_inputs()[0].name
        self.output_name = self.net.get_outputs()[0].name
        self.score_th = score_th

    def detect(self, srcimg):
        img = cv2.resize(srcimg, dsize=(self.input_width, self.input_height))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0 - 0.5
        blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)
        outs = self.net.run([self.output_name], {self.input_name: blob})
        
        mask = np.array(outs[0]).squeeze()
        min_value = np.min(mask)
        max_value = np.max(mask)
        mask = (mask - min_value) / (max_value - min_value)
        if self.score_th is not None:
            mask = np.where(mask < self.score_th, 0, 1)
        mask *= 255
        mask = mask.astype('uint8')

        mask = cv2.resize(mask, dsize=(srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR)
        return mask

def generate_overlay_image(srcimg, mask):
    overlay_image = np.zeros(srcimg.shape, dtype=np.uint8)
    overlay_image[:] = (255, 255, 255)
    mask = np.stack((mask,) * 3, axis=-1).astype('uint8') 
    mask_image = np.where(mask, srcimg, overlay_image)
    return mask, mask_image

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", type=str, default='images/cam_image47.jpg')
    parser.add_argument("--modelpath", type=str, default='weights/isnet_general_use_480x640.onnx')
    args = parser.parse_args()
    
    mynet = DIS(args.modelpath)
    srcimg = cv2.imread(args.imgpath)
    mask = mynet.detect(srcimg)
    mask, overlay_image = generate_overlay_image(srcimg, mask)

    winName = 'Deep learning object detection in onnxruntime'
    cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
    cv2.imshow(winName, np.hstack((srcimg, mask)))
    cv2.waitKey(0)
    cv2.destroyAllWindows()

推理结果
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
资源和模型下载地址:https://download.csdn.net/download/matt45m/89024664

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1545738.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

cuda安装和下载for windows

cuda下载 英伟达cuda官方下载地址 https://developer.nvidia.com/cuda-downloads?target_osWindows&target_archx86_64&target_version11&target_typeexe_local 安装 直接一直点下一步即可&#xff0c;注意要注册账号&#xff0c;用微信扫码直接登录即可 win…

一篇文章给你讲清楚正常卷积与深度可分离卷积

文章目录 正常卷积深度可分离卷积深度卷积逐点卷积 对比代码实现查看&#xff08;torch实现&#xff09;结果 正常卷积 也就是我们平常用的比较普遍的卷积&#xff1a; 它的参数量是&#xff1a;112&#xff0c;即&#xff1a; ( 卷积核大小&#xff09; ∗ 输入通道 ∗ 输出…

【随笔】Git -- 常用命令(四)

&#x1f48c; 所属专栏&#xff1a;【Git】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢迎大…

【python】flask模板渲染引擎Jinja2,使得前后端交互更加便捷

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

2010年之前电脑ubuntu安装nvidia驱动黑屏处理

装好驱动 仿真fps直接到60Hz 陈旧设备 都是非常老旧的电脑&#xff0c;没钱换新电脑&#xff0c;就这么穷…… 电脑详细配置&#xff1a; 冲动 想装显卡驱动提升一下性能&#xff0c;结果……黑了 黑习惯了也无所谓&#xff0c;几分钟就能解决&#xff0c;关键还是太穷&…

【C】盛最多水的容器(双指针)

盛最多水的容器 原题目链接:点击跳转 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和(i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容器可以储存的最大水量。 说…

数据结构-树-006

1二叉树 1.1目标二叉树 前序遍历&#xff1a;ABDHIEJCFKG 中序遍历&#xff1a;HDIBEJAFKCG 后序遍历&#xff1a;HIDJEBKFGCA 层序遍历&#xff1a;ABCDEFGHIJK运行结果&#xff1a; 运行结果符合目标二叉树的深度优先&#xff08;前序遍历&#xff0c;中序遍历&#xff0c;…

【c++】【STL】stack类、queue类、deque类详解及模拟

&#x1fa90;&#x1fa90;&#x1fa90;欢迎来到程序员餐厅&#x1f4ab;&#x1f4ab;&#x1f4ab; 今日主菜&#xff1a;stack和queue&#xff0c;deque类 主厨&#xff1a;邪王真眼 所属专栏&#xff1a;c专栏 主厨的主页&#xff1a;Chef‘s blog 这可是…

Endnote(作者,年份)文中引用显示‘and etal‘与‘和 等‘

软件版本&#xff1a;Endnote X9.1&#xff0c;样式&#xff1a;Harvard&#xff0c;其余使用(作者&#xff0c;年份)的样式均可&#xff0c;GBT7714就有作者年份类型 本教程适用于X系列~ Endnote20及以上版本请移步另一条博文&#xff0c;指路&#xff1a;&#xff08;我还没…

“双碳”目标下资源环境中的可计算一般均衡(CGE)模型教程

原文链接&#xff1a;“双碳”目标下资源环境中的可计算一般均衡&#xff08;CGE&#xff09;模型https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247599079&idx4&sn82ea6c6f506cd20d1e0cd590faaa0611&chksmfa820200cdf58b16dc5b79746901cc9a4048b46db5…

《自动机理论、语言和计算导论》阅读笔记:p5-p27

《自动机理论、语言和计算导论》学习第2天&#xff0c;p5-p27总结&#xff0c;总计23页。 一、技术总结 1.集合 (1)commutative law of union. (2)distribute law of union. 2.归纳法(induction) & 演绎法(deduction) (1)归纳法&#xff1a;从许多个别的事实或原理中…

【zlm】问题记录:chrome更新引起的拉不出webrtc; 证书校验引起的放几秒中断

目录 chrome更新引起的拉不出webrtc 证书校验引起的放几秒中断 chrome更新引起的拉不出webrtc 【zlm】最新的chrome版本中的报错&#xff1a; 我有个问题event.js:8 [RTCPusherPlayer] DOMException: Failed to execute setRemoteDescription on RTCPeerConnection: Failed …

LabVIEW焓差试验室流量计现场自动校准系统

LabVIEW焓差试验室流量计现场自动校准系统 在现代工业和科研领域&#xff0c;流量计的准确性对于保证生产过程的质量和效率非常重要。开发了一种基于LabVIEW的焓差试验室流量计现场自动校准系统&#xff0c;通过提高流量计校准的准确性和效率。 在空调器空气焓值法能效测量装…

hololens 2 投屏 报错

使用Microsoft HoloLens投屏时&#xff0c;ip地址填对了&#xff0c;但是仍然报错&#xff0c;说hololens 2没有打开&#xff0c; 首先检查 开发人员选项 都打开&#xff0c;设备门户也打开 然后检查系统–体验共享&#xff0c;把共享都打开就可以了

【k8s】kubeasz 3.6.3 + virtualbox 搭建本地虚拟机openeuler 22.03 三节点集群 离线方案

kubeasz项目源码地址 GitHub - easzlab/kubeasz: 使用Ansible脚本安装K8S集群&#xff0c;介绍组件交互原理&#xff0c;方便直接&#xff0c;不受国内网络环境影响 拉取代码&#xff0c;并切换到最近发布的分支 git clone https://github.com/easzlab/kubeasz cd kubeasz gi…

C++细节

背景知识&#xff1a; 面向对象的编程中&#xff0c;类&#xff08;Class&#xff09;是创建对象的蓝图或模板&#xff0c;它包含了数据&#xff08;通常称为属性或变量&#xff09;和行为&#xff08;通常称为方法或函数&#xff09;。将数据封装为私有&#xff08;private&am…

【Java】:类和对象

1.面向对象的初步认知 1.1 什么是面向对象 Java是一门面向对象的语言&#xff0c;在面向对象的世界里&#xff0c;一切皆为对象。面向对象是解决问题的一种思想&#xff0c;主要依靠对象之间的交互完成一件事情。用面向对象的思想来涉及程序&#xff0c;更符合人们对事物的认知…

CDNS PCIe VIP debug info

1. TLP payload的顺序是反向的&#xff0c;即大小端反的&#xff0c;比如下面的denalirc打印的信息看&#xff0c;pl是我们发TLP时的配置&#xff0c;Cfg才是真正的data顺序。 而seq写的时候如下&#xff1a;可以看到payload[2]时第三个8bit payload&#xff0c;但是我们是想配…

JS加密解密之应用如何保存到桌面书签

前言 事情起因是这样的&#xff0c;有个客户解密了一个js&#xff0c;然后又看不懂里边的一些逻辑&#xff0c;想知道它是如何自动拉起谷歌浏览器和如何保存应用到书签的&#xff0c;以及如何下载应用的。继而诞生了这篇文章&#xff0c;讲解一下他的基本原理。 渐进式Web应用…

UE5学习日记——蓝图节点前缀关键字整理

一、起因 节点如海&#xff0c;中英文翻译的时候还是有差别的&#xff0c;比如&#xff1a; 同一个中文&#xff0c;可能在英文里完全不同&#xff0c;连出现位置可能都不一样 附加 Attach Actor To Component&#xff08;将Actor附加到组件&#xff09;Append Array&#xf…