ModNet抠图算法及摄像头实时抠图示例

news2025/1/11 12:34:26

目录

一、视频抠图采用绿幕的原因

1、摄像机成色原因

2、抠图效果原因

3、经济成本

二、抠图背景知识

1、Trimap

2、什么是抠图

3、抠图算法分类

三、Deep Image Matting算法

1、网络结构图

2、算法解读

(1)Encoder-Decoder阶段

(2)Refinement阶段

四、ModNet算法:Trimap-Free Portrait Matting in Real Time

1、网络结构图

2、算法解读

五、ModNet抠图实践


一、视频抠图采用绿幕的原因

1、摄像机成色原因

主流摄像机传感器为RGB三通道,所以为了抠图最精准最好采用三原色中原始颜色。此外,相机的CMOS传感器矩阵多数都是采用拜耳阵列,该阵列中绿色感光点是2个,高于红色和蓝色,所以信息更丰富更容易抠除。

2、抠图效果原因

视频中的人物和皮肤,多数都是绿色的补色,反差大,这样电脑在渲染处理时就更容易区分边缘和纹理毛发,从而减少抠图的工作量。

3、经济成本

绿背景亮度高,拍摄时光可以亮度调小点从而省电。

二、抠图背景知识

人像抠图:算法概述及工程实现(一)-云社区-华为云

1、Trimap

最常用的先验知识,它是一个三元图,每个像素取值为{0, 128, 255}其中之一,分别代表前景、未知与背景。

2、什么是抠图

对于一张图I,我们感兴趣的人像部分称为前景F,其余部分为背景B,则图像I可以视为F与B的加权融合:
I=alphaF+(1−alpha)BalphashapeI一致。

而抠图任务就是找到合适的权重alpha矩阵。

将按照上述公式前景图和背景图融合的过程举例如下:

假如一张图的中间圆圈部分为前景,其余部分为背景。则上述两张图按照公式结合后,中间圆圈都是前景相关的像素,而圆圈之外都是背景相关的像素。Alpha对应的是前景图的概率矩阵。

假如alpha训练完成后,若要完成一张图的抠图,只要alpha*原图 + (1-alpha)*白底图即可。

Alpha是介于[0, 1]之间的连续值,可以理解为像素属于前景的概率,这与人像分割是不同的。在人像分割任务中,alpha只能取0或1,本质上是分类任务,而抠图是回归任务。

抠图任务的ground truth,可以看到值分布在0~1之间。

语义分割的ground truth,可以看到值非0即1。

3、抠图算法分类

目前流行的抠图算法大致可以分为两类。

一种是需要先验信息的Trimap-based的方法,宽泛的先验信息包括Trimap、粗糙mask、无人的背景图像、pose信息等,网络使用先验信息与图片信息共同预测alpha

另一种则是Trimap-free的方法,仅根据图片信息预测alpha,对实际应用更友好,但效果普遍不如Trimap-based的方法。

目前主流是trimap-free算法。

三、Deep Image Matting算法

1、网络结构图

2、算法解读

网络包括Encoder-Decoder阶段和Refinement阶段

(1)Encoder-Decoder阶段

输入为RGB图像的patch和对应trimap的concat,所以包含4通道,经过编码和解码后输出单通道的raw alpha pred。该阶段的loss由两部分组成:

第一部分是预测的alpha和真实的alpha之间的绝对误差,考虑到L1 loss在0处不可微,使用Charbonnier Loss去近似:

第二部分是由预测的alpha、真实的前景和真实的背景组成的RGB图像与真实的RGB图像之间的绝对误差,其作用是对网络施加约束,同样使用Charbonnier Loss去近似:

最终的Loss是两部分的加权求和:

(2)Refinement阶段

它的输入为Encoder-Decoder阶段输出的raw alpha pred与原始RGB图像的concat,同样为4通道,原始RGB能够为refine提供边界细节信息。重点是使用了一个skip connection,将Encoder-Decoder阶段输出的raw alpha pred与Refinement阶段输出的refined alpha pred做一个add操作,然后输出最终的预测结果。其实Refinement阶段就是一个residual block,通过残差学习对边界信息进行建模,与去噪模型对噪声建模如出一辙。

Refinement阶段只有一个loss:refined alpha pred与GT alpha matte计算Charbonnier Loss。

四、ModNet算法:Trimap-Free Portrait Matting in Real Time

1、网络结构图

2、算法解读

网络结构由:语义估计分支、细节预测分支、语义-细节融合分支 组成。

五、ModNet抠图实践

参考文章:

【Matting】MODNet:实时人像抠图模型-onnx python部署_onnx模型下载_嘟嘟太菜了的博客-CSDN博客

原作者的onnix模型链接:https://download.csdn.net/download/qq_40035462/85046509

代码示例:

import cv2
import time
from tqdm import tqdm
import numpy as np
import onnxruntime as rt


class Matting:
    def __init__(self, model_path='onnx_model\modnet.onnx', input_size=(512, 512)):
        self.model_path = model_path
        self.sess = rt.InferenceSession(self.model_path, providers=['CUDAExecutionProvider'])
        # self.sess = rt.InferenceSession(self.model_path)  # 默认使用cpu
        self.input_name = self.sess.get_inputs()[0].name
        self.label_name = self.sess.get_outputs()[0].name
        self.input_size = input_size
        self.txt_font = cv2.FONT_HERSHEY_PLAIN

    def normalize(self, im, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
        im = im.astype(np.float32, copy=False) / 255.0
        im -= mean
        im /= std
        return im

    def resize(self, im, target_size=608, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, list) or isinstance(target_size, tuple):
            w = target_size[0]
            h = target_size[1]
        else:
            w = target_size
            h = target_size
        im = cv2.resize(im, (w, h), interpolation=interp)
        return im

    def preprocess(self, image, target_size=(512, 512), interp=cv2.INTER_LINEAR):
        image = self.normalize(image)
        image = self.resize(image, target_size=target_size, interp=interp)
        image = np.transpose(image, [2, 0, 1])
        image = image[None, :, :, :]
        return image

    def predict_frame(self, bgr_image):
        assert len(bgr_image.shape) == 3, "Please input RGB image."
        raw_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        h, w, c = raw_image.shape
        image = self.preprocess(raw_image, target_size=self.input_size)

        pred = self.sess.run(
            [self.label_name],
            {self.input_name: image.astype(np.float32)}
        )[0]
        pred = pred[0, 0]
        matte_np = self.resize(pred, target_size=(w, h), interp=cv2.INTER_NEAREST)
        matte_np = np.expand_dims(matte_np, axis=-1)
        return matte_np

    def predict_image(self, source_image_path, save_image_path):
        bgr_image = cv2.imread(source_image_path)
        assert len(bgr_image.shape) == 3, "Please input RGB image."
        matte_np = self.predict_frame(bgr_image)
        matting_frame = matte_np * bgr_image + (1 - matte_np) * np.full(bgr_image.shape, 255.0)
        matting_frame = matting_frame.astype('uint8')
        cv2.imwrite(save_image_path, matting_frame)

    def predict_camera(self):
        cap_video = cv2.VideoCapture(0)
        if not cap_video.isOpened():
            raise IOError("Error opening video stream or file.")
        beg = time.time()
        count = 0
        while cap_video.isOpened():
            ret, raw_frame = cap_video.read()
            if ret:
                count += 1
                matte_np = self.predict_frame(raw_frame)
                matting_frame = matte_np * raw_frame + (1 - matte_np) * np.full(raw_frame.shape, 255.0)
                matting_frame = matting_frame.astype('uint8')

                end = time.time()
                fps = round(count / (end - beg), 2)
                if count >= 50:
                    count = 0
                    beg = end

                cv2.putText(matting_frame, "fps: " + str(fps), (20, 20), self.txt_font, 2, (0, 0, 255), 1)

                cv2.imshow('Matting', matting_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            else:
                break
        cap_video.release()
        cv2.destroyWindow()

    def check_video(self, src_path, dst_path):
        cap1 = cv2.VideoCapture(src_path)
        fps1 = int(cap1.get(cv2.CAP_PROP_FPS))
        number_frames1 = cap1.get(cv2.CAP_PROP_FRAME_COUNT)
        cap2 = cv2.VideoCapture(dst_path)
        fps2 = int(cap2.get(cv2.CAP_PROP_FPS))
        number_frames2 = cap2.get(cv2.CAP_PROP_FRAME_COUNT)
        assert fps1 == fps2 and number_frames1 == number_frames2, "fps or number of frames not equal."

    def predict_video(self, video_path, save_path, threshold=2e-7):
        # 使用odf策略
        time_beg = time.time()
        pre_t2 = None  # 前2步matte
        pre_t1 = None  # 前1步matte

        cap = cv2.VideoCapture(video_path)
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        number_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        print("source video fps: {}, video resolution: {}, video frames: {}".format(fps, size, number_frames))
        videoWriter = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc('I', '4', '2', '0'), fps, size)

        ret, frame = cap.read()
        with tqdm(range(int(number_frames))) as t:
            for c in t:
                matte_np = self.predict_frame(frame)
                if pre_t2 is None:
                    pre_t2 = matte_np
                elif pre_t1 is None:
                    pre_t1 = matte_np
                    # 第一帧写入
                    matting_frame = pre_t2 * frame + (1 - pre_t2) * np.full(frame.shape, 255.0)
                    videoWriter.write(matting_frame.astype('uint8'))
                else:
                    # odf
                    error_interval = np.mean(np.abs(pre_t2 - matte_np))
                    error_neigh = np.mean(np.abs(pre_t1 - pre_t2))
                    if error_interval < threshold < error_neigh:
                        pre_t1 = pre_t2

                    matting_frame = pre_t1 * frame + (1 - pre_t1) * np.full(frame.shape, 255.0)
                    videoWriter.write(matting_frame.astype('uint8'))
                    pre_t2 = pre_t1
                    pre_t1 = matte_np

                ret, frame = cap.read()
            # 最后一帧写入
            matting_frame = pre_t1 * frame + (1 - pre_t1) * np.full(frame.shape, 255.0)
            videoWriter.write(matting_frame.astype('uint8'))
            cap.release()
        print("video matting over, time consume: {}, fps: {}".format(time.time() - time_beg, number_frames / (time.time() - time_beg)))


if __name__ == '__main__':
    model = Matting(model_path='onnx_model\modnet.onnx', input_size=(512, 512))
    model.predict_camera()
    # model.predict_image('images\\1.jpeg', 'output\\1.png')
    # model.predict_image('images\\2.jpeg', 'output\\2.png')
    # model.predict_image('images\\3.jpeg', 'output\\3.png')
    # model.predict_image('images\\4.jpeg', 'output\\4.png')
    # model.predict_video("video\dance.avi", "output\dance_matting.avi")

代码中涉及的modnet.onnx文件见最上面的附件。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/629852.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3 -- lottie-web使用

Lottie简介 官方介绍:Lottie是一个库,可以解析使用AE制作的动画(需要用bodymovie导出为json格式),支持web、ios、android、flutter和react native。在web端,lottie-web库可以解析导出的动画json文件,并将其以svg或者canvas的方式将动画绘制在我们的页面上. Lottie的优点 …

chatgpt赋能python:Python版本切换教程

Python版本切换教程 Python是一种高级编程语言&#xff0c;用于多种编程任务。但是&#xff0c;由于Python版本之间的不兼容性&#xff0c;有时候需要切换Python版本以满足特定的需求。在本文中&#xff0c;我们将介绍Python版本切换的方法&#xff0c;包括安装和使用多个版本…

机器学习 | 集成算法 | Bagging | Boosting | 概念向

&#x1f4da;Bagging和Boosting的概念 集成学习&#xff08;Ensemble Learning&#xff09;就是通过某种策略将多个模型集成起来&#xff0c;通过群体决策来提高决策准确率。为什么集成学习会好于单个学习器呢&#xff1f;原因可能有三&#xff1a; 训练样本可能无法选择出最好…

【ARMv8 SIMD和浮点指令编程】NEON 通用算术指令——杂项也不少

算术通用指令杂项包括以下指令: UABA、UABAL/UABAL2、UABD 和 UABDL/UABDL2。无符号向量差值绝对值累加和差值绝对值。 ABS 和 NEG向量绝对值和求反。 UMAX、UMIN、UPMAX、UPMIN、UMAXV 和 UMINV。无符号向量最大值,无符号向量最小值,无符号向量按对最大值,无符号向量按对最…

chatgpt赋能python:Python怎么分行输出?教程来了!

Python怎么分行输出&#xff1f;教程来了&#xff01; Python是一种解释型、面向对象、动态数据类型的高级编程语言。在Python中&#xff0c;分行输出是非常常见的操作&#xff0c;本文将介绍Python分行输出的不同方式以及使用的情况。 一、使用换行符 使用换行符是Python分…

python文字转语音(pyttsx3+flask)

提示&#xff1a;文章结尾有全部代码 目录 前言一、Flaskpyttsx基本使用Flask导入Flask框架配置基础环境初始Flask代码 pyttsx3库基本使用导入pyttsx3初始化pyttsx3文字转语音运行 二、具体实现1.引入库 总结 前言 本文主要讲解如何用python的pyttsx3库flask框架&#xff0c;手…

chatgpt赋能python:Python切换指南:让你无缝转换到Python

Python切换指南&#xff1a;让你无缝转换到Python Python是一个高级的编程语言&#xff0c;它可以用来进行各种各样的应用开发和数据分析。 Python有很多优点&#xff0c;比如它易于学习和使用&#xff0c;它是一个开源的语言&#xff0c;它具有广泛的库和框架。 如果你是处于…

Mysql数据库(六):基本的SELECT语句

基本的SELECT语句 前言一、SELECT...二、SELECT ... FROM三、列的别名四、去除重复行五、空值参与运算六、着重号七、查询常数八、显示表结构九、过滤数据 前言 本博主将用CSDN记录软件开发求学之路上亲身所得与所学的心得与知识&#xff0c;有兴趣的小伙伴可以关注博主&#…

如何监控EMC VNX控制器的启动过程

这里我们要讨论的内容基本上适用于所有的EMC VNX中端存储系统&#xff0c;包含老的Clariion CX3&#xff0c;CX4&#xff0c;VNX1和VNX2&#xff0c;其实VNXe和Unity很多内容也是一样的。当然由于VNXe和Unity 操作系统的大的变化&#xff0c;差异也是比较大的。 导致EMC Clarr…

什么是M-LAG?为什么需要M-LAG?

M-LAG&#xff08;Multichassis Link Aggregation Group&#xff09;提供一种跨设备链路聚合的技术。M-LAG通过将两台接入交换机以同一个状态和用户侧设备或服务器进行跨设备的链路聚合&#xff0c;把链路的可靠性从单板级提升到设备级。同时&#xff0c;由于M-LAG设备可以单独…

如何设计一个完整的交互流程,提升产品用户体验

交互流程设计是一项关乎用户体验的重要工作。通过设计和规划用户与产品或服务的交互方式和流程&#xff0c;我们可以提高用户的满意度和使用效果。在本文中&#xff0c;我们将深入探讨交互流程设计的关键要素以及其对用户体验的重要性。 交互流程设计本质是通过设计和规划用户与…

【Python】Python进阶系列教程-- Python3 SMTP发送邮件(六)

文章目录 前言实例使用Python发送HTML格式的邮件Python 发送带附件的邮件在 HTML 文本中添加图片使用第三方 SMTP 服务发送 前言 往期回顾&#xff1a; Python进阶系列教程-- Python3 正则表达式&#xff08;一&#xff09;Python进阶系列教程-- Python3 CGI编程&#xff08;…

chatgpt赋能python:Python怎么倒序输出字符串

Python怎么倒序输出字符串 Python是一种高级编程语言&#xff0c;它可以让开发人员快速编写代码。在Python中&#xff0c;字符串是一种非常常见的数据类型&#xff0c;其支持各种字符串操作。在这篇文章中&#xff0c;我们将讨论如何在Python中倒序输出字符串。 倒序输出字符…

openGauss5.0企业版使用指南之企业版安装

文章目录 0. 前言1. 安装1.1 获取安装包1.1.1 操作步骤1.1.2 准备软硬件安装环境1.1.3 软硬件环境要求1.1.4 修改操作系统配置1.1.5 **关闭操作系统防火墙**1.1.6 **设置字符集参数**1.1.7 **设置时区和时间**1.1.8 **&#xff08;可选&#xff09;关闭swap交换内存**1.1.9 **关…

MOVEit Transfer 漏洞似乎被广泛利用

Progress Software 已在其文件传输软件 MOVEit Transfer 中发现一个漏洞&#xff0c;该漏洞可能导致权限提升和潜在的未经授权访问环境&#xff0c;该公司在一份安全公告中表示。 在 MOVEit Transfer Web 应用程序中发现了一个 SQL 注入漏洞&#xff0c;可能允许未经身份验证…

【ARMv8 SIMD和浮点指令编程】NEON 逻辑指令——与或非有多少?

NEON 逻辑指令主要包括与、或、异或、位清除、或非、为 False 时按位插入、为 True 时按位插入和按位选择指令,下面我们来详细学习这些指令。 一、逻辑指令 1.1 AND 按位与(向量),该指令将两个源 SIMD&FP 寄存器按位与,并将结果写入目标 SIMD&FP 寄存器。 AND …

基于JDBC的账务管理系统

一、项目介绍 1.1 项目目标 本项目为JAVAEE基础和数据库的综合项目&#xff0c;包含了若干个知识点&#xff0c;达到将从基础班到现在所学的知识综合使用&#xff0c;提高了我们对项目的理解与知识点的运用。熟练View层、Service层、Dao层之间的方法相互调用操作熟练使用工具类…

chatgpt赋能python:Python怎么入侵别人微信:一种黑客行为的技术探讨

Python怎么入侵别人微信&#xff1a;一种黑客行为的技术探讨 随着社交媒体微信的普及和使用程度的不断提高&#xff0c;对微信的攻击和入侵成为了目前互联网安全领域的热点问题之一。其中&#xff0c;Python编程语言的广泛应用和强大的功能使得其逐渐成为了微信黑客行为的利器…

robots.txt的作用是什么,看完了我默默加在了自己网站上

文章目录 背景robots.txt的主要作用使用示范User-agentDisallowAllowSitemap 总结 背景 最近在研究网站SEO相关的东西&#xff0c;第一次接触到robots.txt&#xff0c;才发现实际上很多网站都用到了它&#xff0c;尤其是对搜索引擎依赖特别高的C端系统或者网站&#xff0c;是一…

论文解读:SuperGlue: Learning Feature Matching with Graph Neural Networks

SuperGlue: Learning Feature Matching with Graph Neural Networks 发表时间&#xff1a;2020 论文地址&#xff1a;https://arxiv.org/abs/1911.11763 项目地址&#xff1a;http://github.com/magicleap/SuperGluePretrainedNetwork。 本文介绍了一种通过联合寻找对应和拒绝…