在Rust中使用torch------day1环境配置Yolov8推理

news2025/1/14 1:10:42

现在不管什么专业,貌似多多少少都得和深度学习扯上点关系才好写文章(水文章).其中最常用的深度学习框架应该是tensorflow,pytorch这些python的三方库,既然最近在学Rust那就来借机讲讲torch的Rust绑定—tch-rs

其实tch实际上是基于libtorch的封装,而不是pytorch.因此使用起来如果熟悉libtorch的api的话会更容易,不熟悉也没关系毕竟pytorch的api也差不多;况且我相信愿意弄这个的人应该都是对技术感兴趣的,而不是那些为了造学术垃圾抄点代码改改参数点个run.

废话不多说,今天的主题就是配置好环境然后写一个推理demo

环境配置

1. libtorch安装(Ubuntu)

上面说了tch是基于libtorch的封装,因此底层还是需要使用libtorch的相关api,而且最新版本的要求安装libtorch版本为2.0.0.Torch官网目前给出的libtorch版本是2.1,如果安装之后运行会显示

image-20230630101839557

具体可以看这个issue,同时也给出了解决办法设置一个LIBTORCH_BYPASS_VERSION_CHECK的环境变量来避免版本检查,但是作者也不确定不同版本的api是否会有差异,因此还是建议下载libtorch 2.0版本

根据tch官网的介绍,我们将libtorch下载解压之后在环境变量中配置相应的路径

export PATH="xxx/libtorch:$PATH"
export LIBTORCH=xxx/libtorch
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"xxx/libtorch/lib:$LD_LIBRARY_PATH"
export LIBTORCH_INCLUDE=xxx/libtorch
export LIBTORCH_LIB=xxx/libtorch

这样的话,就可以进行tch的引用.不过这里有个小插曲,由于我之前偷懒把平时需要的python库都装在Base中,而不是虚拟环境.这就造成直接设置为环境变量,在使用python导入torch的时候出现段错误

image-20230630102755375

为了找到错误原因,直接gdb python调试

r -c "import torch"&&bt

image-20230630102929529

引用的是我写在环境变量中的lib,既然是这样我们就不能随便将libtorch直接写入,而是在bashrc中写个函数封装一下

function set_libtorch(){
    export PATH="xxx/libtorch:$PATH"
    export LIBTORCH=xxx/libtorch
    export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"xxx/libtorch/lib:$LD_LIBRARY_PATH"
    export LIBTORCH_INCLUDE=xxx/libtorch
    export LIBTORCH_LIB=xxx/libtorch
    echo "LIBTORCH is set to $LIBTORCH"
}
​

这样,我们需要使用tch时运行一下这个函数就可以顺利编译,平时也能愉快的用python

image-20230630103311585

2. tch使用

我们首先在.toml中添加依赖

[dependencies]
tch="0.13.0"

然后写一个demo mod

use tch::Tensor;
​
pub fn demo() {
    let mut t = Tensor::from_slice(&[3, 1, 4, 1, 5]);
    t.print();
    t = t * 2;
    t.print()
}
​
​
pub fn cuda_is_available(){
    println!("Cuda available: {}", tch::Cuda::is_available());
    println!("Cudnn available: {}", tch::Cuda::cudnn_is_available());
    let device = tch::Device::cuda_if_available();
    println!("Device :{:?}",device);
    let t = Tensor::from_slice(&[1,2,3,4,5]).to(device);
    t.print();
}

在main中调用函数

mod demo1;
​
fn main() {
    demo1::demo();
    demo1::cuda_is_available();
}
​

image-20230630103818605

上面这是在terminal中用命令cargo run main运行,如果更喜欢点一下run运行需要在运行设置中添加环境变量

image-20230630103938067

3. Yolov8推理实战

既然是实战,那就必须得来点硬货.今天,我们就用tch实现一下Yolov8的推理.

大致思路:

  1. 从Yolov8中导出torchscript权重文件.
  1. 构建YOLO结构体,封装前处理,预测,后处理等方法.
  2. 对输出进行绘制并保存

3.1 torchscript导出

这里直接clone源码,然后导出的方法已经封装很好了

from ultralytics import YOLO
model = YOLO("yolov8s.pt") 
success = model.export(format="torchscript",imgsz=(640,640)) 

注意,这里导出不要使用optimize,否则torchscript加载模型会出错.

3.2 YOLO结构体设计

这部分集中了三个部分,模型加载导入与图像加载的前处理,模型推理,模型的后处理(NMS).

3.2.1 模型加载与前处理

模型加载与图像加载我们写在主函数中

let args: Vec<_> = std::env::args().collect();
    let (weights, img_path) = match args.as_slice() {
        [_, w, i] => (std::path::Path::new(w), i.to_owned()),
        _ => bail!("usage: main yolov8.torchscipt image.jpg"),
    };
    let device = tch::Device::cuda_if_available();
    println!("Run inference by device={:?}", device);
    let mut yolov8 = yolo::YOLO::new(weights, 640, 640, 0.25, 0.65, 100, device);
    let img = yolo::YOLO::preprocess(&mut yolov8, &img_path).to_device(yolov8.device);

其中,YOLO类的初始化以及前处理代码如下

pub struct YOLO{
    model:tch::CModule,
    pub device:tch::Device,
    h:i64,
    w:i64,
    conf_threshold:f64,
    iou_threshold:f64,
    top_k:i64,
}
​
​
​
impl YOLO{
    pub fn new(weights: &Path, h: i64, w: i64, conf_threshold:f64, iou_threshold:f64, top_k:i64, device: tch::Device) -> YOLO {
        let mut model = tch::CModule::load_on_device(weights, device).unwrap();
        model.set_eval();
        YOLO {
            model,
            h,
            w,
            device,
            conf_threshold,
            iou_threshold,
            top_k,
        }
    }
​
    pub fn preprocess(&mut self, image_path: &String) ->tch::Tensor{
        let origin_image=tch::vision::image::load(image_path).unwrap();
        let (_,ori_h,ori_w)=origin_image.size3().unwrap();
        self.w=ori_w;self.h=ori_h;
        let img = tch::vision::image::resize(&origin_image, 640,640).unwrap().unsqueeze(0).to_kind(tch::Kind::Float)/255.;
        return img;
    }
}

这里YOLO结构体中保留模型信息,图像大小,推理设备以及NMS的阈值,然后前处理部分只需要将图片resize到(640,640)写死.需要注意的是,我们可以去看python中predict的源码,模型推理的时候对输入的前处理仅仅是resize和规一化,并没有使用normalize.因此,这里前处理的时候千万不要加normalize,也不要使用tch::vision::imagenet::load_image_and_resize()函数,因为这个函数默认会进行normalize

image-20230702075427412

后果就是模型推理的输出无法对齐,导致后续处理结果全部出错.

3.2.2 模型推理

模型推理很简单,只需要调用forward_t()就可以进行正向传播推理.这里为了统计推理耗时,额外加了一些代码

pub fn predict(&self, image: &Tensor) -> Vec<Bbox> {
        let start_time=Instant::now();
        let pred = self
            .model
            .forward_t(image,false);
        let end_time=Instant::now();
        let elapsed_time=end_time.duration_since(start_time);
        println!("YOLOv8 inference time:{} ms",elapsed_time.as_millis());
​
        let pred=pred.to_device(tch::Device::Cpu);
        let start_time=Instant::now();
        let result = self.non_max_suppression(&pred);
        let end_time=Instant::now();
        let elapsed_time=end_time.duration_since(start_time);
        println!("YOLOv8 nms time:{} ms",elapsed_time.as_millis());
        result
    }

关于为什么要计时,这里有个很有趣的问题留到最后再分析.

3.2.3 NMS后处理

为了删除冗余的预测框,提高检测精度,通常都会进行NMS操作.原理就是把置信度低的预测框与邻近冗余的预测框全部删除,只保留筛选后的预测框.这里筛选的条件就是self.conf_thresholdself.iou_threshold.

这里稍微说明一下,Yolov8与之前的Yolov5等模型的输出不同之处在于,它的输出格式为(Batch_size,84,80*80+40*40+20*20=8400),也就是说输出只有回归框+80个类别对应的置信度,而且回归框总数放到了最后一个维度.通常为了遍历的时候内存对齐,还是交换这里的输出维度比较好.

第一步进行置信度筛选

let pred= & pred.transpose(2, 1).squeeze();
let (npreds,pred_size) = pred.size2().unwrap();
let nclasses=pred_size-4;
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
​
let class_index=pred.i((..,4..pred_size));
let (pred_conf,class_label)=class_index.max_dim(-1, false);
// pred_conf.save("pred_conf.pt").expect("pred_conf save err");
// class_label.save("class_label.pt").expect("class_labe; save err");
​
for index in 0..npreds {
    let pred = Vec::<f64>::try_from(pred.get(index)).unwrap();
    let conf = pred_conf.double_value(&[index]);
    if conf>self.conf_threshold{
        let label=class_label.int64_value(&[index]);
        if pred[(4 + label) as usize] > 0. {
            let bbox = Bbox {
                xmin: pred[0] - pred[2] / 2.,
                ymin: pred[1] - pred[3] / 2.,
                xmax: pred[0] + pred[2] / 2.,
                ymax: pred[1] + pred[3] / 2.,
                confidence: conf,
                cls: label,
            };
            bboxes[label as usize].push(bbox);
        }
    }
}

首先通过下标索引得到类别对应的列,然后取最大值作为置信度与类别标签.将置信度大于阈值的检测框进行保存.这里注释掉的两行就是因为一开始输出结果一直不对,为了看输出是否对齐然后进行保存的.然后再进行iou筛选

 for Bboxes_for_class in bboxes.iter_mut() {
     Bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
​
    let mut current_index = 0;
    for index in 0..Bboxes_for_class.len() {
        let mut drop = false;
        for prev_index in 0..current_index {
            let iou = self.iou(&Bboxes_for_class[prev_index], &Bboxes_for_class[index]);
            if iou > self.iou_threshold {
            drop = true;
            break;
            }
        }
        if !drop {
            Bboxes_for_class.swap(current_index, index);
            current_index += 1;
        }
    }
    Bboxes_for_class.truncate(current_index);
}
fn iou(&self, b1: &Bbox, b2: &Bbox) -> f64 {
        let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
        let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
        let i_xmin = b1.xmin.max(b2.xmin);
        let i_xmax = b1.xmax.min(b2.xmax);
        let i_ymin = b1.ymin.max(b2.ymin);
        let i_ymax = b1.ymax.min(b2.ymax);
        let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
        i_area / (b1_area + b2_area - i_area)
}

当然,有可能最后输出的结果还是很多,所以最终用top_k参数作为最终输出个数限制,输出的目标数不应该超过top_k的值.

let mut result = vec![];
let mut count=0;
for Bboxes_for_class in bboxes.iter() {
    for Bbox in Bboxes_for_class.iter() {
        if count>=self.top_k {
            break;
        }
        result.push(*Bbox);
        count+=1;
    }
}
​
return result;

3.3 保存结果

首先查看tch官网yolo的demo,把绘制矩形框部分的代码借用

fn draw_line(&self, t: &mut tch::Tensor, x1: i64, x2: i64, y1: i64, y2: i64) {
        let color = tch::Tensor::from_slice(&[255,0,0]).view([3, 1, 1]);
        t.narrow(2, x1, x2 - x1)
            .narrow(1, y1, y2 - y1)
            .copy_(&color)
    }
​
    pub fn show(&self, image: &mut Tensor, bboxes: &Vec<Bbox>) {
        let w_ratio = self.w as f64 / 640 as f64;
        let h_ratio = self.h as f64 / 640 as f64;
​
        for bbox in bboxes.iter() {
            let xmin= ((bbox.xmin * w_ratio) as i64).clamp(0, self.w);
            let ymin = ((bbox.ymin * h_ratio) as i64).clamp(0, self.h);
            let xmax = ((bbox.xmax * w_ratio) as i64).clamp(0, self.w);
            let ymax = ((bbox.ymax * h_ratio) as i64).clamp(0, self.h);
            self.draw_line(image, xmin, xmax, ymin, ymax.min(ymin + 2));
            self.draw_line(image, xmin, xmax, ymin.max(ymax - 2), ymax);
            self.draw_line(image, xmin, xmax.min(xmin + 2), ymin, ymax);
            self.draw_line(image, xmin.max(xmax - 2), xmax, ymin, ymax);
        }
        tch::vision::image::save(&image, "./result.jpg").unwrap();
    }

但是这样我们仅仅只能绘制出矩形检测框,我如果还想绘制出类别也就是在图上绘制文字,就需要引入别的依赖

image="0.24.6"
rusttype="0.9.3"
imageproc="0.23.0"

然后读取保存的结果图片进行二次加工绘制

fn text(bboxes: Vec<Bbox>) {
    let mut image = open("./result.jpg").unwrap().into_rgb8();
    let font = Vec::from(include_bytes!("./DejaVuSans.ttf") as &[u8]);
    let font = Font::try_from_vec(font).unwrap();
    let size = 20.;
    let scale = Scale {
        x: size * 1.5,
        y: size * 2.,
    };
    let w_ratio = image.width() as f64 / 640 as f64;
    let h_ratio = image.height() as f64 / 640 as f64;
    for bbox in bboxes.iter() {
        println!(
            "xmin={},ymin={},xmax={},ymax={},class_label={},confidence={}",
            bbox.xmin,
            bbox.ymin,
            bbox.xmax,
            bbox.ymax,
            coco_names::COCO_NAMES[bbox.cls as usize],
            bbox.confidence
        );
        let text = coco_names::COCO_NAMES[bbox.cls as usize];
        draw_text_mut(
            &mut image,
            Rgb([255u8, 0u8, 0u8]),
            (bbox.xmin * w_ratio) as i32 + 10,
            (bbox.ymin * h_ratio) as i32 + 10,
            scale,
            &font,
            text,
        );
        let _ = image.save("./result.jpg").unwrap();
    }
}

几乎同样的操作重复两次实属无奈,在tch中貌似没有找到相关的api实现在Tensor上绘制文字信息.

经过上面的努力,运行看看效果

cargo run main yolov8s.torchscript bus.jpg

总结

我通过tch用Rust实现了对Yolov8的推理,并且最终输出结果与实际推理结果一致,很好的做到了精度对齐.不过这里面还是存在一点点小问题的,模型推理讲究的是推理时间与实时性要求,来看看在不同device下我推理时间的结果.

image

image

在cpu比较正常的0.17s左右,而在GPU上推理居然花费了1s左右.这个完全不能理解,即是是一张图片而不是batch作为推理,按道理GPU和CPU推理速度的差异也不应该这么大.即使是由于数据upload到GPU中造成一些耗时,但是我明明在传入之前就已经保证数据在device中了,可推理耗时依旧在1s左右.详情可以参考官网的issue.目前我的猜想是某些机制导致数据的重复拷贝上传到device,不过还需要进一步验证.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/720328.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云原生——什么是云原生数据库?

❄️作者介绍&#xff1a;奇妙的大歪❄️ &#x1f380;个人名言&#xff1a;但行前路&#xff0c;不负韶华&#xff01;&#x1f380; &#x1f43d;个人简介&#xff1a;云计算网络运维专业人员&#x1f43d; 前言 突然间&#xff0c;云原生数据库就火了。根据IDC《2021年下半…

WinForm中使用AnyCAD控件

一、添加DLL程序集 AnyCAD.Foundation.Net.dll AnyCAD.Presentation.Net.dll AnyCAD.Exchange.Net.dll 二、初始化控件 1.首先创建一个窗体 2.在窗体上放置一个Panel用来放置三维控件 3.初始化控件 完整代码如下&#xff1a; using AnyCAD.Presentation; using System; …

Generalized Category Discovery(论文翻译)

Generalized Category Discovery 摘要1.导言2.相关工作3.广义类发现3.1 我们的方法 图1.我们提出一个新的设置&#xff1a;“广义类别发现”及其解决方法。我们的设置可以简洁地描述为&#xff1a;给定一个子集具有类标签的数据集&#xff0c;对数据集中所有未标记的图像进行分…

【嵌入式Qt开发入门】如何使用Qt进行绘图——QPainter 绘图

绘图与图表在嵌入式里有的比较多&#xff0c;尤其是图表&#xff0c;我们常在股票里看到的“图表折线/曲线 图/饼状图等”都可以用 Qt 的图表来实现。绘图和图表的内容本章主要介绍绘图和图表的基本操作&#xff0c;以简单的例子呈现绘图与图表的用法&#xff0c;目的就是快速入…

抖音怎么私信发名片

抖音怎么私信发名片&#xff0c;抖音私信卡片制作教程来了&#xff0c;视频版教程#新媒体运营工具#软件#抖音消息卡片 hello&#xff0c;大家&#xff0c;我是百收网SEO&#xff0c;今天给大家说一下个人号自动回复卡片&#xff0c;相比企业号自动回复卡片&#xff0c;它的优势…

MATLAB | 终于找到了修改图例图标的方法(可以自己设计图例啦?)

讲一点扒MATLAB底裤的事情叭&#xff0c;就是之前写的有一些绘图函数&#xff0c;比如阴影柱状图&#xff0c;想要把图例里的图标进行修改让其也带着阴影&#xff0c;我采取的是直接绘制一些会检测图例框移动的阴影图标来冒充图例的图标&#xff0c;那么有没有办法真正的自定义…

如何轻松应对广泛存在开源“0Day”

跟不少安全人员讨论过一个很通俗的话题 作为网安从业者最讨厌的是什么&#xff1f; 不同的人给了很多不同的答案 有的人说&#xff1a; 但更多的人都在说&#xff1a; “零日漏洞”(zero-day)又叫零时差攻击&#xff0c;是指被发现后立即被恶意利用的安全漏洞。通俗地讲&…

QScintilla自制代码编辑器系列(1)编译库文件与运行测试例子

1.下载工程源码 我本人机器上的QT是6.4 可以下载最新的代码 https://www.riverbankcomputing.com/static/Downloads/QScintilla/2.14.0/QScintilla_src-2.14.0.zip 2. 编译生成文件 无需改动可以顺利生成库文件 3. 运行例子 1&#xff09;拷贝头文件 将整个Qsci文件夹拷…

面试官:你的项目有什么亮点?我:解决了JS脚本加载失败的问题!

前后端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★ 地址&#xff1a;前端面试题库 web前端面试题库 VS java后端面试题库大全 面试官&#xff1a;你的项目有什么亮点&#xff1f;解决了什么问题&#xff1f; 你&#xff1a;嗯...... 面试官&#…

7.3 【Linux】磁盘的分区、格式化、检验与挂载

想在系统中新增一颗磁盘时&#xff0c;需要做&#xff1a; 1.对磁盘进行分区&#xff0c;以创建可用的partition&#xff1b; 2.对该partition进行格式化&#xff08;format),以创建系统可用的filesystem&#xff1b; 3.可对刚刚创建好的filesystem进行检验&#xff1b; 4.…

Acrel-5000重点用能单位能耗在线监测系统在湖南三立集团的案例分析

安科瑞 崔丽洁 摘要&#xff1a;根据《重点用能节能办法》&#xff08;国家发展改革委等第七部委2018年15号令&#xff09;、《重点用能单位能耗在线监测系统推广建设工作方案》&#xff08;发改环资[2017]1711号&#xff09;和《关于加速推进重点用能单位能耗在线监测系统建设…

介绍几种OPTIONS检测的方法

概述 日常的VOIP开发中&#xff0c;OPTIONS检测是常用的网络状态检测工具。 OPTIONS原本是作为获取对方能力的消息&#xff0c;也可以检测当前服务状态。正常情况下&#xff0c;UAS收到OPTIONS心跳&#xff0c;直接回复200即可。 与ping不同的是&#xff0c;OPTIONS检测不仅…

前后端分离开发

目录 前后合开发&#xff08;不推荐&#xff09; 前后端分离开发&#xff08;主流&#xff09; 项目开发的基本步骤 接口文档的管理平台--YApi 前后合开发&#xff08;不推荐&#xff09; 沟通成本高分工不明确不便于管理不便于扩展 前后端分离开发&#xff08;主流&…

让浮动元素在一行显示

&#x1f4dd;个人主页&#xff1a;爱吃炫迈 &#x1f48c;系列专栏&#xff1a;HTMLCSS &#x1f9d1;‍&#x1f4bb;座右铭&#xff1a;道阻且长&#xff0c;行则将至&#x1f497; <div class"wrap"><div class"item">1</div><di…

ubuntu实现自动挂载u盘

ubuntu实现自动挂载u盘 但是&#xff0c;有些设施可以在没有图形工具的情况下进行复制&#xff0c;并且在系统上占用的空间非常小。 例如&#xff0c;在我的设置中&#xff0c;我已经实现了USB自动挂载服务&#xff0c;而无需使用任何外部工具/服务&#xff0c;只有udev和syst…

Junit5相关技术

Selenium自动化测试框架 Junit针对Java的单元测试框架 拿一个技术写自动化测试用例&#xff08;Selenium3&#xff09; 拿一个技术管理已经编写好的测试用例(Junit5) 写代码前需要添加依赖&#xff1a;Junit5 一、注解 1.1 Test 表示当前这个方法是一个测试用例 1.2 Di…

DCN v2阅读笔记

Deformable ConvNets v2: More Deformable, Better Results 是 Deformable Convolutional Networks 研究的续作&#xff0c;发表在 CVPR 2019上。 作者对 DCNv1 的自适应行为进行研究&#xff0c;观察到虽然其神经特征的空间支持比常规的卷积神经网络更符合物体结构&#xff0…

2023年的无线蓝牙耳机哪些牌子好,真无线蓝牙耳机品牌排名

本文将为您详细介绍每款蓝牙耳机的设计特点、音质表现、续航能力和智能功能等关键信息。我们将提供客观、全面的分析&#xff0c;帮助您更好地了解每款产品的优势和适用场景&#xff1b;无论您是追求高保真音质的音乐发烧友&#xff0c;还是需要轻便舒适的耳机进行运动&#xf…

Kafka入门, 消费者工作流程(十八)

kafka消费方式 pull(拉)模式&#xff1a; consumer采用从broker中主动拉取数据。 Kafka采用这种方式。 push(推)模式&#xff1a; Kafka没有采用这种方式&#xff0c;因为由broker决定消息发送速率&#xff0c;很难适应所有消费者的速率。例如推送速度是50m/s&#xff0c;consu…

rocketmq客户端日志过大造成磁盘使用率占用过高

目录 问题现象 排查占用 自定义客户端日志配置未生效 总结 问题现象 收到项目报警&#xff1a;磁盘占用率超标通知 排查占用 从上述可以看出&#xff0c;实质是跟正常业务日志无关的&#xff0c;/home/work/log挂出来了&#xff0c;与/根目录下无关 查看根目录下日志占用…