基于华为atlas的unet分割模型探索

news2024/9/23 11:20:59

Unet模型使用官方基于kaggle Carvana Image Masking Challenge数据集训练的模型。

模型输入为572*572*3,输出为572*572*2。分割目标分别为,0:背景,1:汽车。

Pytorch的pth模型转化onnx模型:

import torch

from unet import UNet

model = UNet(n_channels=3, n_classes=2, bilinear=False)
model = model.to(memory_format=torch.channels_last)

state_dict = torch.load("unet_carvana_scale1.0_epoch2.pth", map_location="cpu")
#del state_dict['mask_values']
model.load_state_dict(state_dict)

dummy_input = torch.randn(1, 3, 572, 572)

torch.onnx.export(model, dummy_input, "unet.onnx", verbose=True)

模型输入输出节点分析:

使用工具Netron查看模型结构,确定模型输入节点名称为input.1,输出节点名称为/outc/conv/Conv

onnx模型转化atlas模型:

atc --model=./unet.onnx --framework=5 --output=unet --soc_version=Ascend310P3  --input_shape="input.1:1,3,572,572" --output_type="/outc/conv/Conv:0:FP32" --out_nodes="/outc/conv/Conv:0"

推理代码实现:

import base64
import json
import os
import time

import numpy as np
import cv2

import MxpiDataType_pb2 as mxpi_data
from StreamManagerApi import InProtobufVector
from StreamManagerApi import MxProtobufIn
from StreamManagerApi import StreamManagerApi


def check_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)


class SDKInferWrapper:
    def __init__(self): # 完成初始化
        self._stream_name = None
        self._stream_mgr_api = StreamManagerApi()

        if self._stream_mgr_api.InitManager() != 0:
            raise RuntimeError("Failed to init stream manager.")

        pipeline_name = './nested_unet.pipeline'

        self.load_pipeline(pipeline_name)

        self.width = 572
        self.height = 572

    def load_pipeline(self, pipeline_path):
        with open(pipeline_path, 'r') as f:
            pipeline = json.load(f)

        self._stream_name = list(pipeline.keys())[0].encode() # 'unet_pytorch'
        if self._stream_mgr_api.CreateMultipleStreams(
                json.dumps(pipeline).encode()) != 0:
            raise RuntimeError("Failed to create stream.")

    def do_infer(self, img_bgr):

        # preprocess
        image = cv2.resize(img_bgr, (self.width, self.height))
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
        image = image.astype('float32') / 255.0
        image = image.transpose(2, 0, 1)


        tensor_pkg_list = mxpi_data.MxpiTensorPackageList()
        tensor_pkg = tensor_pkg_list.tensorPackageVec.add()
        tensor_vec = tensor_pkg.tensorVec.add()
        tensor_vec.deviceId = 0
        tensor_vec.memType = 0

        for dim in [1, *image.shape]:
            tensor_vec.tensorShape.append(dim) # tensorshape属性为[1,3,572,572]

        input_data = image.tobytes()
        tensor_vec.dataStr = input_data
        tensor_vec.tensorDataSize = len(input_data)

        protobuf_vec = InProtobufVector()
        protobuf = MxProtobufIn()
        protobuf.key = b'appsrc0'
        protobuf.type = b'MxTools.MxpiTensorPackageList'
        protobuf.protobuf = tensor_pkg_list.SerializeToString()
        protobuf_vec.push_back(protobuf)

        unique_id = self._stream_mgr_api.SendProtobuf(
            self._stream_name, 0, protobuf_vec)

        if unique_id < 0:
            raise RuntimeError("Failed to send data to stream.")

        infer_result = self._stream_mgr_api.GetResult(
            self._stream_name, unique_id)

        if infer_result.errorCode != 0:
            raise RuntimeError(
                f"GetResult error. errorCode={infer_result.errorCode}, "
                f"errorMsg={infer_result.data.decode()}")
        
        output_tensor = self._parse_output_data(infer_result)
        output_tensor = np.squeeze(output_tensor)
        output_tensor = softmax(output_tensor)

        mask = np.argmax(output_tensor, axis =0)
        score = np.max(output_tensor, axis = 0)


        mask = cv2.resize(mask, [img_bgr.shape[1], img_bgr.shape[0]], interpolation=cv2.INTER_NEAREST)
        score = cv2.resize(score, [img_bgr.shape[1], img_bgr.shape[0]], interpolation=cv2.INTER_NEAREST)

        return mask, score



    def _parse_output_data(self, output_data):
        infer_result_data = json.loads(output_data.data.decode())
        content = json.loads(infer_result_data['metaData'][0]['content'])
        tensor_vec = content['tensorPackageVec'][0]['tensorVec'][0]
        data_str = tensor_vec['dataStr']
        tensor_shape = tensor_vec['tensorShape']
        infer_array = np.frombuffer(base64.b64decode(data_str), dtype=np.float32)
        return infer_array.reshape(tensor_shape)



    def draw(self, mask):
        color_lists = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]

        drawed_img = np.stack([mask, mask, mask], axis = 2)
        for i in np.unique(mask):
            drawed_img[:,:,0][drawed_img[:,:,0]==i] = color_lists[i][0]
            drawed_img[:,:,1][drawed_img[:,:,1]==i] = color_lists[i][1]
            drawed_img[:,:,2][drawed_img[:,:,2]==i] = color_lists[i][2]

        return drawed_img

def softmax(x):
    exps = np.exp(x - np.max(x))
    return exps/np.sum(exps)



def sigmoid(x):
    y = x.copy()
    y[x >= 0] = 1.0 / (1 + np.exp(-x[x >= 0]))
    y[x < 0] = np.exp(x[x < 0]) / (1 + np.exp(x[x < 0]))
    return y



def check_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)


def test():
    dataset_dir = './sample_data'
    output_folder = "./infer_result"   
    os.makedirs(output_folder, exist_ok=True)

    sdk_infer = SDKInferWrapper()


    # read img
    image_name = "./sample_data/images/111.jpg"
    img_bgr = cv2.imread(image_name)
    
    # infer
    t1 = time.time()
    mask, score = sdk_infer.do_infer(img_bgr)
    t2 = time.time()
    print(t2-t1, mask, score)
    drawed_img = sdk_infer.draw(mask)
    cv2.imwrite("infer_result/draw.png", drawed_img)
    

if __name__ == "__main__":
    test()

运行代码:

set -e
. /usr/local/Ascend/ascend-toolkit/set_env.sh
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }

#export MX_SDK_HOME=/home/work/mxVision
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins

#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python

python3 unet.py
exit 0

运行效果:

个人思考:

华为atlas的参考案例细节不到位,步骤缺失较多,摸索困难,代码写法较差,信创化道路任重而道远。

参考资料:

GitHub - milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images

https://gitee.com/ascend/samples/tree/master/python/level2_simple_inference/3_segmentation/unet++

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1502815.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【蓝桥杯】单词分析 (BF)

一.题目描述 二.问题分析 //单词分析 #include <iostream> using namespace std;const int N1e42; char s[N]; int c[26]{0};int main(int argc, const char * argv[]) {ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);cin>>s;int max0,i0;char xa;while(s[i]){…

【Python使用】python高级进阶知识md总结第1篇:My Awesome Book【附代码文档】

python高级进阶全知识知识笔记总结完整教程&#xff08;附代码资料&#xff09;主要内容讲述&#xff1a;My Awesome Book&#xff0c;My Awesome Book。My Awesome Book&#xff0c;MySQL数据库。My Awesome Book&#xff0c;聚合函数。My Awesome Book&#xff0c;创建表并给…

Mysql深入学习 基础篇 Ss.06 事务

青青子衿&#xff0c;悠悠我心 纵我不往&#xff0c;子宁不嗣音 —— 24.3.9 事务&#xff1a; 事务简介 事务操作 事务四大特性 并发事务问题 事务隔离级别 一、事务简介 事务是一组操作的集合&#xff0c;它是一个不可分割的工作单位&#xff0c;事务会把所有的操作作为一个整…

开源的Java图片处理库介绍

在 Java 生态系统中&#xff0c;有几个流行的开源库可以用于图片处理。这些库提供了丰富的功能&#xff0c;如图像缩放、裁剪、颜色调整、格式转换等。以下是几个常用的 Java 图片处理库的介绍&#xff0c;包括它们的核心类、主要作用和应用场景&#xff0c;以及一些简单的例子…

MATLAB2020a安装编译器mingw-64(6.3.0)

MATLAB2020a指定安装mingw-64&#xff08;6.3.0&#xff09;版本编译器 记录一下几个要点 mingw-64&#xff08;6.3.0&#xff09; 找到对应的mingw-64安装包 设置mingw的bin文件路径到环境变量 变量名&#xff1a;MW_MINGW64_LOC MATLAB设置路径

java 哨兵线性搜索

顾名思义&#xff0c;哨兵线性搜索是线性搜索的一种&#xff0c;与传统线性搜索相比&#xff0c;比较次数减少了。在传统的线性搜索中&#xff0c;仅进行N次比较&#xff0c;而在哨兵线性搜索中&#xff0c;哨兵值用于避免任何越界比较&#xff0c;但没有专门针对正在搜索的元素…

快速瓦斯封孔器请满载希望出发

不论昨天如何&#xff0c;今天请满载希望出发&#xff01;每一个微笑、每一次服务&#xff0c;都是我们通往成功巅峰的阶梯。 一、 用途&#xff1a; CKF&#xff0d;I型快速瓦斯封孔器用以快速封闭采面卸压抽放钻孔&#xff0c;具有重量轻、速度快、操作简便的特点&#xff1…

Python中的装饰器详解及实际应用【第120篇—装饰器详解】

Python中的装饰器详解及实际应用 在Python编程中&#xff0c;装饰器&#xff08;Decorator&#xff09;是一种强大而灵活的工具&#xff0c;用于修改函数或方法的行为。它们广泛应用于许多Python框架和库&#xff0c;如Flask、Django等。本文将深入探讨装饰器的概念、使用方法…

C/C++实现代码雨效果

C/C实现代码雨效果 目录 C/C实现代码雨效果 说明使用的库说明测试代码效果图 说明 最近整理电脑资料&#xff0c;翻出了以前写的代码&#xff0c;顺便整理一下到博客上&#xff0c;当做一次备份记录 先看看静态效果 需要分为以下步骤实现 生成代码串把代码串绘制到窗口中使…

NUMA架构

UMA架构 在单cpu的时代&#xff0c;cpu与内存的交互需要通过北桥芯片来完成。cpu通过前端总线(FSB&#xff0c; front Side Bus)连接到北桥芯片&#xff0c;由北桥芯片连接到内存&#xff08;内存控制器是集成在北桥芯片里的&#xff09;。为了提升性能&#xff0c;cpu的频率不…

Web APIs 4 日期对象、节点操作

Web APIs 4 一、日期对象实例化日期对象方法案例&#xff1a;页面显示时间 时间戳 二、节点操作查找结点①父节点查找②子节点查找③兄弟节点查找 增加节点克隆节点删除节点 三、M端事件四、JS插件 一、日期对象 学习路径&#xff1a;实例化、日期对象方法、时间戳 实例化 …

论文学习——一种新的具有分层响应系统的动态多目标优化算法

论文题目&#xff1a;A Novel Dynamic Multiobjective Optimization Algorithm With Hierarchical Response System 一种新的具有分层响应系统的动态多目标优化算法&#xff08;Han Li , Zidong Wang , Fellow, IEEE, Chengbo Lan, Peishu Wu , and Nianyin Zeng , Member, IE…

c++ 常用的STL

前言 写这篇博客目的是为了记录在刷算法题中使用过的STL&#xff0c;因为有些不太常用的会遗忘。这篇博客只是作为笔记&#xff0c;不是详细的STL&#xff0c;因此只会对常用方法说明&#xff0c;不会详细介绍。此外在后面用到新的STL内容时会再补充。 列队 基础列队 基本列…

YOLOv8-Seg改进:特征融合篇 | GELAN(广义高效层聚合网络)结构来自YOLOv9

🚀🚀🚀本文改进:使用GELAN改进架构引入到YOLOv8 🚀🚀🚀YOLOv8-seg创新专栏:http://t.csdnimg.cn/KLSdv 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研; 1)手把手教你如何训练YOLOv8-seg; 2)模型创新,提升分割性能; 3)独家自研模块助力分割; 1.YO…

申请公众号上限是多少

一般可以申请多少个公众号&#xff1f;公众号申请限额在过去几年内的经历了很多变化。对公众号申请限额进行调整是出于多种原因&#xff0c;确保公众号内容的质量和合规性。企业公众号的申请数量从50个到5个最后到2个&#xff0c;对于新媒体公司来说&#xff0c;这导致做不了公…

基于深度视觉实现机械臂对目标的识别与定位

机械臂手眼标定 根据相机和机械臂的安装方式不同&#xff0c;手眼标定分为眼在手上和眼在手外两种方式&#xff0c;双臂机器人的相机和机械臂基座的相对位置固定&#xff0c;所以应该采用眼在手外的手眼标定方式。 后续的视觉引导机械臂抓取测试实验基于本实验实现&#xf…

CentOS 7 devtoolset编译addressSanitizer版本失败的问题解决

在我的一个Cent OS7开发环境中&#xff0c;按https://yeyongjin.blog.csdn.net/article/details/134178420的方法升级GCC版本到8.3.1。 这两天&#xff0c;要用Google的addressSanitizer检验内存问题&#xff0c;加上编译参数后&#xff0c;却发现编译不通过。configure时直接退…

微服务韧性工程:利用Sentinel实施有效服务容错与限流降级

目录 一、雪崩效应 二、Sentinel 服务容错 2.1 Sentinel容错思路 2.2 内部异常兼容 2.3 外部流量控制 三、Sentinel 项目搭建 四、Sentinel 工作原理 服务容错是微服务设计中一项重要原则和技术手段&#xff0c;主要目标是在服务出现故障、网络波动或其他不可预见的异常情况…

5G 网络切片VLAN ID配置错误导致业务不可用

【摘要】随着电联5G共建共享工作的开展&#xff0c;无法及时有效观测到单逻辑站点的相关指标&#xff0c;导致单运营商用户业务出现异常。本案例中着重对单运营商用户无法使用网络进行相关参数排查&#xff0c;从KPI性能指标结合故障告警发生时间&#xff0c;从而分析由于网络切…

Web APIs 5 Window对象、本地存储

Web APIs 5 一、Window对象1、BOM2、定时器-延时函数3、JS执行机制4、location对象案例&#xff1a;5秒钟之后跳转的页面 5、navigator对象6、histroy对象 二、本地存储本地存储 localStorage本地存储 sessionStorage存储复杂数据类型案例&#xff1a;学生就业统计表字符串拼接…