UNET-RKNN分割眼底血管

news2025/1/22 23:01:42

前言

        最近找到一个比较好玩的Unet分割项目,Unet的出现就是为了在医学上进行分割(比如细胞或者血管),这里进行眼底血管的分割,用的backbone是VGG16,结构如下如所示(项目里面的图片,借用的!借用标记出处,尊重别人的知识产权),模型比较小,但是效果感觉还不错的。

         相关的算法发介绍就不写了接下来从PYTORCH、ONNX、rknn三个方面看看效果

全部代码地址: https://pan.baidu.com/s/1QkOz5tvRSF-UkJhmpI__lA 提取码: 8twv 

检测原图

1. Pytroch推理代码

        gpu_test文件夹

├── predict.py:推理代码
├── test_result_cuda.png: 检测结果
├── save_weights:模型文件夹
├── images:图片文件夹
├── src:相关库文件夹
└── mask:mask图片文件夹 

import os
import time

import torch
from torchvision import transforms
import numpy as np
from PIL import Image

from src import UNet


def time_synchronized():
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    return time.time()


def main():
    classes = 1  # exclude background
    # 模型路径
    weights_path = "./save_weights/best_model.pth"
    # 检测图片路径
    img_path = "./images/01_test.tif"
    # mask图片路径
    roi_mask_path = "./mask/01_test_mask.gif"
    assert os.path.exists(weights_path), f"weights {weights_path} not found."
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."

    mean = (0.709, 0.381, 0.224)
    std = (0.127, 0.079, 0.043)

    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 用cpu推理
    # device = "cpu"
    print("using {} device.".format(device))

    # create model
    model = UNet(in_channels=3, num_classes=classes+1, base_c=32)

    # load weights
    model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
    model.to(device)

    # dummy_input = torch.randn(1, 3, 584, 565)
    # torch.onnx.export(model, dummy_input, 'eyes_unet.onnx', verbose=True, opset_version=11)
    # load roi mask
    roi_img = Image.open(roi_mask_path).convert('L')
    roi_img = np.array(roi_img)

    # load image
    original_img = Image.open(img_path).convert('RGB')

    # from pil image to tensor and normalize
    data_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(mean=mean, std=std)])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    model.eval()  # 进入验证模式
    with torch.no_grad():
        # init model
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        output = model(img.to(device))
        print(output["out"].shape)
        t_end = time_synchronized()
        print("inference time: {}".format(t_end - t_start))

        prediction = output['out'].argmax(1).squeeze(0)
        prediction = prediction.to("cpu").numpy().astype(np.uint8)
        # np.save("cuda_unet.npy", prediction)
        print(prediction.shape)
        # 将前景对应的像素值改成255(白色)
        prediction[prediction == 1] = 255
        # 将不敢兴趣的区域像素设置成0(黑色)
        prediction[roi_img == 0] = 0
        mask = Image.fromarray(prediction)
        mask.save("test_result_cuda.png")


if __name__ == '__main__':
    main()

        检测结果

2. ONNX代码推理

        onnx_test文件夹

├── images : 检测图片文件夹
├── test_result_onnx.png: 检测结果
├── predict_onnx.py:推理代码
├── mask:mask图片文件夹
└── eyes_unet-sim.onnx :模型文件

import os
import time
from torchvision import transforms
import numpy as np
from PIL import Image
import onnxruntime as rt


def main():
    # classes = 1  # exclude background
    img_path = "./images/01_test.tif"
    roi_mask_path = "./mask/01_test_mask.gif"
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."

    mean = (0.709, 0.381, 0.224)
    std = (0.127, 0.079, 0.043)

    # load roi mask
    roi_img = Image.open(roi_mask_path).convert('L')
    roi_img = np.array(roi_img)
    # load image
    original_img = Image.open(img_path).convert('RGB')
    # from pil image to tensor and normalize
    data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    img = data_transform(original_img)
    # expand batch dimension
    img = img.numpy()
    img = img[np.newaxis, :]
    t_start = time.time()
    sess = rt.InferenceSession('./eyes_unet-sim.onnx')
    # 模型的输入和输出节点名,可以通过netron查看
    input_name = 'input.1'
    outputs_name = ['437']
    # 模型推理:模型输出节点名,模型输入节点名,输入数据(注意节点名的格式!!!!!)
    output = sess.run(outputs_name, {input_name: img})
    output = np.array(output).reshape(1, 2, 584, 565)
    t_end = time.time()
    print("inference time: {}".format(t_end - t_start))
    prediction = np.squeeze(np.argmax(output, axis=1))
    print(prediction.shape)
    prediction = prediction.astype(np.uint8)
    # 将前景对应的像素值改成255(白色)
    prediction[prediction == 1] = 255
    # 将不敢兴趣的区域像素设置成0(黑色)
    prediction[roi_img == 0] = 0
    mask = Image.fromarray(prediction)
    mask.save("test_result_onnx.png")


if __name__ == '__main__':
    main()

        检测结果 

 3. RKNN模型转化

        rknn_trans_1808_3588文件夹

├── dataset.txt: 量化数据集路径 
├── images :量化数据集
├── trans_1808.py :适用1808的rknn模型
├── trans_3588.py :适用3588的rknn模型
├── mask:没用到 
└── eyes_unet-sim.onnx:原始onnx模型
        这个没什么好说的,装好环境,直接在相应的环境里面转就好啦,大家应该都会的(不会就拉出去,或者收藏留言,嘿嘿,我看看,出不出教程呢)

 4. RKNN模型推理

        4.1 rk1808_test文件夹

├── 01_test_mask.gif:mask图片
├── eyes_unet-sim-1808.rknn:rk1808适用模型
├── predict_rknn_1808.py:推理代码
├── test_result_1808.png :检测结果
└── 01_test.tif:检测图片

import os
import time
import numpy as np
from PIL import Image
from rknn.api import RKNN


def main():
    # classes = 1  # exclude background
    RKNN_MODEL = "./eyes_unet-sim-1808.rknn"
    img_path = "./01_test.tif"
    roi_mask_path = "./01_test_mask.gif"
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."


    # load roi mask
    roi_img = Image.open(roi_mask_path).convert('L')
    roi_img = np.array(roi_img)


    # load image
    original_img = Image.open(img_path).convert('RGB')
    # from pil image to tensor and normalize
    # data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    # img = data_transform(original_img)
    # expand batch dimension
    img = np.array(original_img)
    img = img[np.newaxis, :]


    # Create RKNN object
    rknn = RKNN(verbose=False)
    ret = rknn.load_rknn(RKNN_MODEL)


    # Init runtime environment
    print('--> Init runtime environment')

    ret = rknn.init_runtime(target='rk1808')
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')
    
    t_start = time.time()
    output = rknn.inference(inputs=[img])
    t_end = time.time()
    print("inference time: {}".format(t_end - t_start))
    output = np.array(output).reshape(1, 2, 584, 565)
    prediction = np.squeeze(np.argmax(output, axis=1))
    print(prediction.shape)


    prediction = prediction.astype(np.uint8)
    # 将前景对应的像素值改成255(白色)
    prediction[prediction == 1] = 255
    # 将不敢兴趣的区域像素设置成0(黑色)
    prediction[roi_img == 0] = 0
    mask = Image.fromarray(prediction)
    mask.save("test_result_1808.png")

    rknn.release()


if __name__ == '__main__':
    main()

        检测结果

      4.2 rk3588_test文件夹

├── 01_test_mask.gif:mask图片
├── test_result_3588.png:检测结果
├── eyes_unet-sim-3588.rknn:rk1808适用模型
├── 01_test.tif:检测图片
└── predict_3588.py:推理代码

import os
import time
import numpy as np
from PIL import Image
from rknnlite.api import RKNNLite


def main():
    # classes = 1  # exclude background
    RKNN_MODEL = "./eyes_unet-sim.rknn"
    img_path = "./01_test.tif"
    roi_mask_path = "./01_test_mask.gif"
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."


    # load roi mask
    roi_img = Image.open(roi_mask_path).convert('L')
    roi_img = np.array(roi_img)


    # load image
    original_img = Image.open(img_path).convert('RGB')
    # from pil image to tensor and normalize
    # data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    # img = data_transform(original_img)
    # expand batch dimension
    img = np.array(original_img)
    img = img[np.newaxis, :]


    # Create RKNN object
    rknn_lite = RKNNLite(verbose=False)
    ret = rknn_lite.load_rknn(RKNN_MODEL)


    # Init runtime environment
    print('--> Init runtime environment')

    ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_AUTO)
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')
    
    t_start = time.time()
    output = rknn_lite.inference(inputs=[img])
    t_end = time.time()
    print("inference time: {}".format(t_end - t_start))
    output = np.array(output).reshape(1, 2, 584, 565)
    prediction = np.squeeze(np.argmax(output, axis=1))
    print(prediction.shape)


    prediction = prediction.astype(np.uint8)
    np.save("int8_unet.npy", prediction)
    # 将前景对应的像素值改成255(白色)
    prediction[prediction == 1] = 255
    # 将不敢兴趣的区域像素设置成0(黑色)
    prediction[roi_img == 0] = 0
    mask = Image.fromarray(prediction)
    mask.save("test_result_3588.png")

    rknn_lite.release()


if __name__ == '__main__':
    main()

检测结果

5. 所有结果对比 

原图

GPU/ONNX

RK1808

RK3588

         其实比对了一下数据,量化的效果还不错,精度在99.5%左右,还是蛮好的!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/417472.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言函数大全--h开头的函数

C语言函数大全 本篇介绍C语言函数大全–h开头的函数或宏 1. hypot,hypotf,hypotl 1.1 函数说明 函数声明函数功能double hypot(double x, double y);计算直角三角形的斜边长(double)float hypotf (float x, float y);计算直角…

UPA/URA双极化天线的协方差矩阵结构

文章目录UPA的阵列响应向量(暂不考虑双极化天线)UPA阵列响应:从单极化天线到双极化天线UPA双极化天线的协方差矩阵结构参考文献UPA的阵列响应向量(暂不考虑双极化天线) 下图形象描述了UPA阵列的接收信号 UPA阵列的水平…

【springcloud 微服务】Spring Cloud 微服务网关Gateway使用详解

目录 一、微服务网关简介 1.1 网关的作用 1.2 常用网关 1.2.1 传统网关 1.2.2 云原生网关 二、gateway网关介绍 2.1 问题起源 2.2 引发的问题 2.2.1 重复造轮子 2.2.2 调用低效 2.2.3 重构复杂 2.3 gateway改进 三、Spring Cloud Gateway 介绍 3.1 Gateway 概述 …

【JSON学习笔记】3.JSON.parse()及JSON.stringify()

前言 本章介绍JSON.parse()及JSON.stringify()。 JSON.parse() JSON 通常用于与服务端交换数据。 在接收服务器数据时一般是字符串。 我们可以使用 JSON.parse() 方法将数据转换为 JavaScript 对象。 语法 JSON.parse(text[, reviver])参数说明: text:必需&…

Angular可视化指南 - 用Kendo UI图表组件创建数据可视化

Kendo UI for Angular是专业级的Angular UI组件库,不仅是将其他供应商提供的现有组件封装起来,telerik致力于提供纯粹高性能的Angular UI组件,而无需任何jQuery依赖关系。无论您是使用TypeScript还是JavaScript开发Angular应用程序&#xff0…

【机器学习(二)】线性回归之梯度下降法

文章目录专栏导读1、梯度下降法原理2、梯度下降法原理代码实现3、sklearn内置模块实现专栏导读 ✍ 作者简介:i阿极,CSDN Python领域新星创作者,专注于分享python领域知识。 ✍ 本文录入于《数据分析之术》,本专栏精选了经典的机器…

1漏洞发现

漏洞发现-操作系统之漏洞探针类型利用修复 一、操作系统漏洞思维导图 相关名词解释: CVSS(Common Vulnerability Scoring System,即“通用漏洞评分系统”) CVSS是安全内容自动化协议(SCAP)的一部分通常C…

rockchip rk3588添加uvc及uvc,adb的复合设备

软硬件环境: 软件基础:我目前拿到的rk3588 sdk :gitwww.rockchip.com.cn:2222/Android_S/rk3588- manifests.git硬件基础:RK3588 LP4X EVB uvc_app: 从rv1126 sdk中rv1126_sdk/rv1126/external/uvc_app 目录移植而来。移植后&…

能翻译大量文字的软件-正规的翻译软件

复制自动翻译软件是一种能够复制并自动翻译文本的工具。当您阅读某一种语言的文本时,这种软件可以快速识别并翻译出来,以方便您更好地理解内容。与其他翻译软件不同的是,复制自动翻译软件可以直接在游览网站的过程中,直接对用户正…

【C++】命名空间,缺省参数,函数重载,引用,内联函数

目录1. 命名空间2. 输入输出3. 缺省参数4. 函数重载为什么C支持函数重载?5. 引用5.1 引用作函数参数(输出型参数)5.2 作函数的返回值关于函数的返回值:5.3 引用权限关于类型转换:5.4 引用和指针6. 内联函数6.1 C推荐的…

【千题案例】TypeScript获取两点之间的距离 | 中点 | 补点 | 向量 | 角度

我们在编写一些瞄准、绘制、擦除等功能函数时,经常会遇到计算两点之间的一些参数,那本篇文章就来讲一下两点之间的一系列参数计算。 目录 1️⃣ 两点之间的距离 ①实现原理 ②代码实现及结果 2️⃣两点之间的中点 ①实现原理 ②代码实现及结果 3…

JUC结构

JUC是java.util.concurrent包的简称在Java5.0添加,目的就是为了更好的支持高并发任务。让开发者进行多线程编程时减少竞争条件和死锁的问题!进程与线程的区别:进程 : 一个运行中的程序的集合; 一个进程往往可以包含多个线程,至少包含一个线程…

count、sum、avg、max、min函数MySQL数据库 - 使用聚合函数查询(头歌实践教学平台)

文章目的初衷是希望学习笔记分享给更多的伙伴,并无盈利目的,尊重版权,如有侵犯,请官方工作人员联系博主谢谢。 目录 第1关:COUNT( )函数 任务描述 相关知识 COUNT()函数基本使用 编程要求 第2关:SUM(…

3.Java运算符

Java运算符 运算符基本分为六类:算数运算符、赋值运算符、关系运算符、逻辑运算符、位运算符、三元(条件)运算符。 一、算术运算符 算数运算符,是指在Java运算中,计算数值类型的计算符号,既然是操作数值…

ubuntu下安装与配置samba

参考文章: https://blog.csdn.net/xurongxin2006/article/details/127740629 https://blog.csdn.net/weixin_42758707/article/details/129855529 https://www.linuxidc.com/Linux/2018-11/155466.htm https://blog.csdn.net/flyingcys/article/details/50673167 1、…

SGD,Adam,AdamW,LAMB优化器

一. SGD,Adam,AdamW,LAMB优化器 优化器是用来更新和计算影响模型训练和模型输出的网络参数,使其逼近或达到最优值,从而最小化(或最大化)损失函数。 1. SGD 随机梯度下降是最简单的优化器,它采用了简单的…

Qt音视频开发37-识别鼠标按下像素坐标

一、前言 在和视频交互过程中,用户一般需要在显示视频的通道上点击对应的区域,弹出对应的操作按钮,将当前点击的区域或者绘制的多边形区域坐标或者坐标点集合,发送出去,通知其他设备进行处理。比如识别到很多人脸&…

使用 gzip 压缩数据

gzip 是GNU/Linux平台下常用的压缩软件,处理后缀名.gz的文件。 gzip 、 gunzip 和 zcat 都可以处理这种格式的。但这些工具只能压缩/解压缩单个文件或数据流,无法直接归档目录和多个文件。但是, gzip 可以同tar 和 cpio 这类归档工具配合使用…

JavaWeb——网络的基本概念

目录 一、IP地址 1、定义 2、格式 (1)、A类地址 (2)、B类地址 (3)、C类地址 (4)、特殊地址 二、端口号 三、协议 四、协议分层 1、定义 2、分类 (1&#xf…

pytorch进阶学习(六):如何对训练好的模型进行优化、验证并且对训练过程进行准确率、损失值等的可视化,新手友好超详细记录

课程资源: 7、模型验证与训练过程可视化【小学生都会的Pytorch】【提供源码】_哔哩哔哩_bilibili 推荐与上一节笔记搭配食用~: pytorch进阶学习(五):神经网络迁移学习应用的保姆级详细介绍,如何将训练好…