GRFB UNet——基于多尺度注意网络盲道检测算法实现与模型C++部署

news2024/9/22 9:27:44

1. 概述

盲道是视障人士安全出行的重要辅助设施。识别盲道的形状和位置,对于增强视障人士的自主移动能力至关重要,而视觉分割技术正是应对这一挑战的有效工具。为了显著提升盲道分割的精确度和稳定性,本文提出了一种创新的分割方法,该方法融合了UNet网络与多尺度特征提取技术。本方法在UNet架构中引入了组感受野块(GRFB)的设计,用以捕获盲道的多级视觉信息。通过应用组卷积,该方法有效降低了计算的复杂度。此外,在每个组卷积之后引入了小尺度卷积,以促进不同通道间的信息交流和融合,进而提取更为丰富和高层次的特征。

在本研究中,我们构建并标注了一个包含多种环境条件下盲道的数据集,用以进行实验评估。我们还对本方法与现有的典型网络结构和模块进行了详尽的比较分析。实验结果表明,我们提出的网络在盲道分割任务上的表现超越了其他对比网络,为盲道的检测提供了一个有力的参考,这不仅证明了本方法的有效性,也为视障人士的导航辅助技术的发展做出了贡献。
在这里插入图片描述
训练代码:https://github.com/Chon2020/GRFB-Unet

2. 网络架构

GRFB-UNet网络的结构是指在传统的UNet网络基础上增加了组感受野块(Group Receptive Field Block,简称GRFB)的改进型网络。GRFB的设计旨在通过组卷积来捕获图像中的多尺度特征,从而提高网络对不同尺度目标的识别能力。
(1). 输入层:网络接收输入图像,并开始进行特征提取。
(2). GRFB结构:在UNet的多个阶段中引入GRFB,每个GRFB由多个组卷积层组成,这些层可以并行地从不同尺度捕获图像特征。
(3). 组卷积:GRFB中的组卷积允许网络在每个组内独立地学习特征,这有助于网络专注于不同的空间尺度。
(4). 跨组卷积:在组卷积之后,使用小尺度的卷积核进行跨组卷积,以实现组间的信息整合。
(5). 特征融合:UNet网络的上采样和跳跃连接有助于将低层的高分辨率特征与高层的抽象特征进行融合,增强了特征的表达能力。
(6). 多尺度特征提取:通过GRFB结构,网络能够同时提取不同尺度的特征,这对于理解图像中的局部和全局上下文非常重要。
(7). 输出层:网络的最终输出是一个分割图,它将输入图像中的每个像素分类为属于或不属于目标类别(例如,盲道)。

GRFB-UNet网络的设计特别适合于需要精确定位和多尺度特征提取的图像分割任务,如触觉铺路的分割。通过这种结构,网络能够更好地理解和处理图像中的复杂结构,从而提高分割的准确性和鲁棒性。

在这里插入图片描述
在这里插入图片描述

3.环境安装

3.1 环境安装

conda create -n py python=3.7
conda activate py
conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=10.1 -c pytorch
pip install -r requirements.txt

3.2 模型转换

import torch
import numpy as np
import onnx
from models.unet import GRFBUNet

def to_onnx(model_unet):
    with torch.no_grad():
        img_unet = torch.randn(1,3,640,640)
        outputs_unet = model_unet(img_unet)

        # 导出ONNX文件
        torch.onnx.export(
            model_unet,
            img_unet,
            'grfb_unet.onnx',
            opset_version=11,
            input_names=['input'],
            output_names=['output']
        )

        # prediction = outputs_unet['out'].argmax(1)
        
    return 


def main():
    # 加载unet模型
    model_unet = GRFBUNet(in_channels=3, num_classes=2, base_c=32)
    model_unet.load_state_dict(torch.load('./weights/grfb-unet.pth', map_location='cpu')['model'])
    model_unet.eval() 
    to_onnx(model_unet)


if __name__ == "__main__":
    main()

4. 模型推理

#include <provider_options.h>
#include <onnxruntime_cxx_api.h>
#include <opencv2/opencv.hpp>


int main()
{
    int gpu_index = 0;
    int gpu_ram = 4;
    int num_thread = 4;
    std::string model_path = "grfb_unet.onnx";
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    std::vector<std::string> available_providers = Ort::GetAvailableProviders();
    auto cuda_available = std::find(available_providers.begin(),
        available_providers.end(), "CUDAExecutionProvider");

    Ort::SessionOptions session_options = Ort::SessionOptions();

    session_options.SetInterOpNumThreads(num_thread);
    session_options.SetIntraOpNumThreads(num_thread);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

    if (gpu_index >= 0 && (cuda_available != available_providers.end()))
    {
        OrtCUDAProviderOptions cuda_options;
        cuda_options.device_id = gpu_index;
        cuda_options.arena_extend_strategy = 0;

        if (gpu_ram == -1)
        {
            cuda_options.gpu_mem_limit = ~0ULL;
        }
        else
        {
            cuda_options.gpu_mem_limit = gpu_ram * 1024 * 1024 * 1024;
        }

        cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
        cuda_options.do_copy_in_default_stream = 1;

        session_options.AppendExecutionProvider_CUDA(cuda_options);
    }

    float mean[3] = { 0.709, 0.381, 0.224 };
    float std[3] = { 0.127, 0.079, 0.043 };
    try 
    {
        // 加载模型并创建环境空间
        std::wstring widestr = std::wstring(model_path.begin(), model_path.end());
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "grfb_unet");
        Ort::Session session(env, widestr.c_str(), session_options);
        Ort::AllocatorWithDefaultOptions allocator;

        char input[] = "input";
        char output[] = "output";

        std::vector<char*> input_node_names;
        std::vector<char*> output_node_names;

        input_node_names.push_back(input);
        output_node_names.push_back(output);

      
        std::string path = "images";
        std::vector<std::string> filenames;
        cv::glob(path, filenames, false);

        for (auto imgpath : filenames)
        {
            cv::Mat original_image = cv::imread(imgpath, cv::IMREAD_COLOR);
            cv::Mat resized_image;
            cv::resize(original_image, resized_image, cv::Size(640, 640));


            // 确定输入数据维度
            std::vector<int64_t> input_node_dims = { 1,3,640,640 };
            size_t input_tensor_size = 1 * 3 * 640 * 640;

            // 填充数据输入
            std::vector<float> input_tensor_values(input_tensor_size);
            for (int h = 0; h < 640; ++h)
            {
                for (int w = 0; w < 640; ++w)
                {
                    for (int c = 0; c < 3; ++c)
                    {
                        // 均一化像素值
                        float pix = resized_image.at<cv::Vec3b>(h, w)[c];
                        pix = pix / 255.0f;
                        pix = (pix - mean[c]) / std[c];
                        input_tensor_values[640 * 640 * c + h * 640 + w] = pix;
                    }
                }
            }

            Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
                memory_info,
                input_tensor_values.data(),
                input_tensor_size,
                input_node_dims.data(),
                input_node_dims.size()
                );

            assert(input_tensor.IsTensor());

            std::vector<Ort::Value> ort_inputs;
            ort_inputs.push_back(std::move(input_tensor));


            // 启动模型预测并获取输出张量
            auto output_tensors = session.Run(
                Ort::RunOptions{ nullptr },
                input_node_names.data(),
                ort_inputs.data(),
                ort_inputs.size(),
                output_node_names.data(),
                1
            );

            // 解析输出张量
            Ort::Value& output_tensor = output_tensors[0];
            const float* output_data = output_tensor.GetTensorData<float>();
            std::vector<int64_t> output_dims = output_tensor.GetTensorTypeAndShapeInfo().GetShape();

            // 存储输出图像
            cv::Mat result_image(640, 640, CV_8UC1);

            // 对输出的2通道图像进行二分类预测
            for (int h = 0; h < 640; ++h) {
                for (int w = 0; w < 640; ++w) {
                    int index_max = output_data[w + h * 640] > output_data[w + h * 640 + 640 * 640] ? 0 : 1;
                    result_image.at<uchar>(h, w) = 255 * index_max;
                }
            }

            cv::imshow("Resized Image", resized_image);
            // 显示结果
            cv::imshow("Result Image", result_image);
            cv::waitKey(0);
        }

    }
    catch (const Ort::Exception& e) 
    {
        // 打印异常
        std::cerr << "Caught Ort::Exception: " << std::string(e.what()) << std::endl;
        size_t pos = std::string(e.what()).find("ErrorCode: ");
        if (pos != std::string::npos) {
            std::string error_code_str = std::string(e.what()).substr(pos + 12);
            int error_code = std::stoi(error_code_str);
            std::cerr << "Error Code: " << error_code << std::endl;
        }
        return -1;
    }

    return 0;
}

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1974071.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenShift 4 - 用 oc-mirror 为离线 OpenShift 集群的 Mirror Registry 同步容器镜像

《OpenShift / RHEL / DevSecOps 汇总目录》 本文适合 OpenShift 4.11 及其以上版本。 文章目录 在离线环境中用 OpenShift 准备 Mirror Registry环境说明向隔离环境复制镜像准备节点环境bastion 节点操作support 节点操作 网络完全隔离环境-复制镜像bastion 节点操作support …

[图解]掉杠·above...duty -《分析模式》漫谈20

1 00:00:01,650 --> 00:00:05,120 今天我们来说一下《分析模式》和掉杠 1 00:00:00,480 --> 00:00:02,800 还是前言这里&#xff0c;有一句话 2 00:00:02,810 --> 00:00:04,850 I will mention 3 00:00:04,860 --> 00:00:05,250 that 4 00:00:05,680 --> 00…

【Golang 面试 - 进阶题】每日 3 题(十四)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/UWz06 &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏…

实战:MySQL数据同步神器之Canal

1.概叙 场景一&#xff1a;数据增量实时同步 项目中业务数据量比较大&#xff0c;每类业务表都达到千万级别&#xff0c;虽然做了分库分表&#xff0c;每张表数据控制在300W以下&#xff0c;但是效率还是达不到要求&#xff0c;为了提高查询效率&#xff0c;打算使用ES进行数…

Java面试题--JVM大厂篇之破解Java性能瓶颈!深入理解Parallel GC并优化你的应用

目录 引言&#xff1a; 正文&#xff1a; 1. 理解Parallel GC的工作原理 2. 配置Parallel GC 3. 监控和分析GC日志 4. 常见调优技巧 5. 持续迭代和优化 结束语&#xff1a; 补充考虑 1. 综合考虑吞吐量与响应时间 2. 评估和优化垃圾回收频率 3. 动态调整与自适应策…

定期自动巡检,及时发现机房运维管理中的潜在问题

随着信息化技术的迅猛发展&#xff0c;机房作为企业数据处理与存储的核心场所&#xff0c;其运维管理的复杂性和挑战性也与日俱增。为确保机房设备的稳定运行和业务的连续性&#xff0c;运维团队必须定期进行全面的巡检。然而&#xff0c;传统的手工巡检方式不仅效率低下&#…

【卷积神经网络】基于CIFAR10数据集实现图像分类【构建、训练、预测】

文章目录 1、内容简介2、CIFAR10 数据集2.1、数据集概述2.2、代码使用2.2.1、查看数据集基本信息2.2.2、数据加载器2.2.3、完整代码 3、搭建图像分类网络&#x1f53a;3.1、网络结构⭐3.2、代码构建网络⭐ 4、编写训练函数4.1、多分类交叉熵损失函数&#x1f53a;4.2、Adam&…

泛微开发修炼之旅--41Ecology基于触发器实现增量数据同步(人员、部门、岗位、人员关系表、人岗关系表)

一、需求背景 我们在项目上遇到一个需求&#xff0c;需要将组织机构数据&#xff08;包含人员信息、部门信息、分部信息、人岗关系&#xff09;生成的增量数据&#xff0c;实时同步到三方的系统中&#xff0c;三方要求&#xff0c;只需要增量数据即可。 那么基于ecology系统&a…

【C++高阶】:C++11的深度解析上

✨ 心似白云常自在&#xff0c;意如流水任东西 &#x1f30f; &#x1f4c3;个人主页&#xff1a;island1314 &#x1f525;个人专栏&#xff1a;C学习 &#x1f680; 欢迎关注&#xff1a;&#x1f44d;点赞 &#x1f4…

数说故事|引爆社媒的森贝儿IP,品牌如何实现流量变现?

以可爱、雅痞、贱萌......的外表加魔性舞姿出圈的可爱小狗——森贝儿贵宾犬Milo&#xff0c;用“可爱微怒”的表情演绎着当代打工人的“疯态”&#xff0c;并迅速晋升成不少打工人高频使用的表情包。 最近几年&#xff0c;“萌系”爆款IP频出&#xff0c;用小动物的形象、可爱…

一键生成视频并批量上传视频抖音、bilibili、腾讯(已打包)

GenerateAndAutoupload Github地址&#xff1a;https://github.com/cmdch2017/GenerateAndAutoupload 如何下载&#xff08;找到最新的release&#xff09; https://github.com/cmdch2017/GenerateAndAutoupload/releases/download/v1.0.1/v1.0.1.zip 启动必知道 conf.py …

Redis学习[5] ——Redis过期删除和内存淘汰

六、Redis过期键值删除 6.1 Redis的过期键值删除策略 6.1.1 什么是过期键值删除&#xff1f; Redis中是可以对key设置过期时间的&#xff0c;所以需要有相应的机制将已过期的键值对删除&#xff0c;也就是**过期键值删除策略。Redis会用一个过期字典&#xff08;expires dic…

如何改网络的ip地址:实用方法与步骤解析

在数字化时代&#xff0c;网络IP地址作为设备在互联网上的唯一标识&#xff0c;其重要性不言而喻。然而&#xff0c;在某些特定场景下&#xff0c;如网络测试、隐私保护或突破地域限制等&#xff0c;我们可能需要更改网络IP地址。那么&#xff0c;如何安全、有效地实现这一操作…

学习日志:update 没加索引会锁全表

文章目录 前言一、为什么会发生这种的事故如何避免这种事故的发生&#xff1f;总结 前言 在线上执行一条 update 语句修改数据库数据的时候&#xff0c;where 条件没有带上索引&#xff0c;导致业务直接崩了 为什么会发生这种的事故&#xff1f; 又该如何避免这种事故的发生&a…

html+css練習:iconfont使用

1.網址地址&#xff1a;https://www.iconfont.cn/search/index 2.註冊登錄&#xff0c;將需要的圖標添加到購物車 3.下載代碼 4.下載后的代碼有一個html頁面&#xff0c;裡面有詳細的使用方式

Linux进程间通信学习2

文章目录 共享内存信号信号概述以及种类信号的处理信号相关函数&#xff08;简单&#xff09;运用小demo实现ctrlc无法终止进程使用kill函数在程序内部实现一个进程杀死另外一个进程 信号相关函数高级版运用函数小demo 信号量信号量相关函数运用小demo: 共享内存 相比于前三个…

基于微信小程序的宠物服务平台(系统源码+lw+部署文档+讲解等)

文章目录 目录 详细视频演示 系统详细设计截图 微信小程序系统的实现 1.1系统前台功能的实现 2.1微信小程序开发环境搭建 2.2微信开发者工具 2.3程序应用相关技术和知识 2.3.1小程序目录结构以及框架介绍 2.3.2 Java技术 2.3.3 MySQL数据库 2.3.4 SSM框架 源码获…

构建铁路安全防线:EasyCVR视频+AI智能分析赋能铁路上道作业高效监管

一、方案背景 随着我国铁路特别是高速铁路的快速发展&#xff0c;铁路运营里程不断增加&#xff0c;铁路沿线的安全环境对保障铁路运输的安全畅通及人民群众的生命财产安全具有至关重要的作用。铁路沿线安全环境复杂多变&#xff0c;涉及多种风险因素&#xff0c;如人员入侵、…

函数递归超详解!

目录 1.什么是递归调用&#xff1f; 直接调用 间接调用 2.什么是递归&#xff1f; 3.递归举例 3.1求n!的阶乘 3.1.1.非递归法 3.1.2.递归法 3.1.2.1分析和代码实现 3.2顺序打印一个整数的每一位 3.2.1分析和代码实现 4.递归与迭代 4.1举例&#xff1a;斐波那契数列 …

开放式耳机更适合运动的时候使用?开放式耳机推荐指南

开放式耳机确实非常适合运动时使用&#xff0c;原因主要有以下几点。 首先&#xff0c;保持对外界的感知是很重要的一点。在运动的时候&#xff0c;我们需要听到周围的环境声音&#xff0c;比如车辆的行驶声、行人的呼喊等&#xff0c;以便及时做出反应&#xff0c;保证自身安全…