LaneATT推理详解及部署实现(上)

news2025/1/20 22:46:23

目录

    • 前言
    • 1. 概述
    • 2. 环境配置
    • 3. Demo测试
    • 4. ONNX导出初探
    • 5. ONNX导出优化
    • 6. ONNX导出总结
    • 结语
    • 下载链接
    • 参考

前言

最近想关注下车道线检测任务,在 GitHub 上找了一个模型 LaneATT,想通过调试分析 LaneATT 代码把 LaneATT 模型导出来,并在 tensorRT 上推理得到结果,这篇文章主要分析 LaneATT 模型的 ONNX 导出以及解决导出过程中遇到的各种问题。若有问题欢迎各位看官批评指正😄

paper:Keep your Eyes on the Lane: Real-time Attention-guided Lane Detection

repo:https://github.com/lucastabelini/LaneATT

1. 概述

车道线检测Lane Detection)是一项计算机视觉任务,涉及在道路场景的视频或图像中识别行车道的边界。 其目标是实时准确地定位和跟踪车道标记,即使在光线不足、眩光或道路布局复杂等恶劣条件下也不例外。

车道线检测是高级驾驶辅助系统(ADAS)和自动驾驶汽车的重要组成部分,因为它能提供有关道路布局和车辆在车道内位置的信息,这对导航和安全至关重要。 这些算法通常结合使用边缘检测、色彩过滤和霍夫变换等计算机视觉技术,来识别和跟踪道路场景中的车道标记。

车道线检测的数据集有很多,包括 CULane、TuSimple、CurveLanes、LLAMAS、OpenLane 等等,我们这里主要介绍下 LaneATT 模型中使用到的 CULane、TuSimple 以及 LLAMAS 数据集

CULane 是一个用于交通车道线检测学术研究的大型挑战性数据集,该数据集由安装在北京六辆不同司机驾驶的不同车辆上的摄像头收集的,收集的视频时长超过 55 小时,提取的帧数为 133,235 帧,数据集分为 88,880 张训练集图像、9,675 张验证集图像和 34,680 张测试集图像,测试集分为正常类别和 8 个挑战类别,获取地址:https://xingangpan.github.io/projects/CULane.html

在这里插入图片描述

TuSimple 数据集包含 6,408 张美国高速公路的道路图像,图像分辨率为 1280×720,数据集由 3,626 张训练图像、358 张验证图像和 2,782 张测试图像组成,其中的图像处于不同的天气条件下,获取地址:https://github.com/TuSimple/tusimple-benchmark

在这里插入图片描述

无监督标记车道线数据集(LLAMAS)是一个用于车道检测和分割的数据集,它包含 100,000 多张标注图像,标注距离超过 100 米,分辨率为 1276 x 717,获取地址:https://unsupervised-llamas.com/llamas

在这里插入图片描述

值得注意的是还有一些用于 3D Lane Detection 的车道线数据集,例如 OpenLane、OpenLane-V2、Apollo Synthetic 3D Lane、ONCE-3DLanes 等等

关于车道线检测数据集的更详细介绍大家可以参考:车道线检测数据集介绍

车道线检测任务主要的难点有:

  • 实时性
  • 非结构化与非标准化
  • 相比目标检测的 corner case 更多
  • 强依赖视觉传感器,光线敏感

车道线检测方法(2D)主要可以分为以下几类:

在这里插入图片描述

值得注意的是 down-top 分割方案的后处理部分可能有些复杂但是工程实用性比较强,表达能力更好,row-wise 分类方案学术界可能使用偏多,端到端多项式预测方案直接预测车道线参数方程一般是三次曲线,但是其局限性比较大,anchor-based 方案需要考虑先验信息。从上图中可知博主这里分享的 LaneATT 是一种基于 anchor-based 的车道线检测方法

LaneATT 使用 resnet 作为特征提取,生成一个特征映射,然后汇集起来提取每个 anchor 的特征。 这些特性与一组由注意力模块产生的全局特征相结合,通过结合局部和全局特征,这在遮挡或没有可见车道标记的情况下可以更容易地使用来自其他车道的信息。 最后,将组合的特征传递给全连接层,以预测最终的输出车道,整个框架如下图所示,还是比较清晰的

在这里插入图片描述

那现在主流的 2D Lane Detection 方法有哪些呢?我们来看下排行榜(CULane):

在这里插入图片描述

从排行榜中我们看到的最多的是 CLRerNet、CondLSTR 以及 CLRNet,那像我们前面列出来的几种方案都排在后面,比如 GANet 排在 14,LaneATT 排在 31,更多内容大家可以参考:https://paperswithcode.com/sota/lane-detection-on-culane

看完 2D 我们再来看看 3D Lane Detection(OpenLane):

在这里插入图片描述

3D 车道线检测排名靠前的方案都是最近提出来的,比如 PVALane、LATR、RFTR 等等,相比 2D 而言还是比较火的,而且大部分方案都是基于 BEV、transformer 这些东西

值得注意的是上述榜单可能并没有那么多的关注速度、部署的难度以及工程实用性,那这些是我们在实际工程应用中需要考虑的

2. 环境配置

在开始之前我们有必要配置下环境,LaneATT 的环境可以通过 LaneATT/README.md 文档中安装,这里有个点需要大家注意,那就是 LaneATT 官方后处理的 NMS 部分是在 CUDA 上实现的,因此需要编译,这个在 Windows 上面折腾可能比较麻烦,博主直接在 Linux 上操作的

博主这里准备了一个可以运行 demo 和导出 ONNX 的环境,大家可以按照这个环境来,也可以自己参考文档进行相关环境配置

博主的环境安装指令如下所示:

conda create -n laneatt python=3.10
conda activate laneatt
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
pip install pyyaml opencv-python scipy imgaug numpy==1.26.4 tqdm p_tqdm ujson scikit-learn tensorboard
pip install onnx onnxruntime onnx-simplifier

可能大家有所困惑,为什么需要的 torch 版本比较高,这个其实取决于你的 CUDA 版本,博主 Linux 主机的 CUDA 版本是 11.6,如果安装的 torch 版本过低,会导致编译的 NMS 插件无法通过,这个大家根据自己的实际情况来就行。另外需要注意的是后续的 ONNX 导出其实并不需要这个环境,这里只是为了 demo 测试以及调试梳理 LaneATT 前后处理需要

Note:这个环境博主目前只用于 demo 测试和 ONNX 导出,并不包含训练

为了不必要的错误,博主将虚拟环境中各个软件的版本都罗列出来,方便大家查看,环境如下:

Package                  Version
------------------------ -----------
absl-py                  2.1.0
certifi                  2024.7.4
charset-normalizer       3.3.2
cmake                    3.30.1
coloredlogs              15.0.1
contourpy                1.2.1
cycler                   0.12.1
dill                     0.3.8
filelock                 3.15.4
flatbuffers              24.3.25
fonttools                4.53.1
fsspec                   2024.6.1
grpcio                   1.65.4
humanfriendly            10.0
idna                     3.7
imageio                  2.34.2
imgaug                   0.4.0
Jinja2                   3.1.4
joblib                   1.4.2
kiwisolver               1.4.5
lazy_loader              0.4
lit                      18.1.8
Markdown                 3.6
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.9.1
mdurl                    0.1.2
mpmath                   1.3.0
multiprocess             0.70.16
networkx                 3.3
nms                      0.0.0
numpy                    1.26.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11        8.5.0.96
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu11        10.9.0.58
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu11       10.2.10.91
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu11     11.7.4.91
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu11         2.14.3
nvidia-nccl-cu12         2.18.1
nvidia-nvjitlink-cu12    12.6.20
nvidia-nvtx-cu11         11.7.91
nvidia-nvtx-cu12         12.1.105
onnx                     1.16.2
onnx-simplifier          0.4.36
onnxruntime              1.18.1
opencv-python            4.10.0.84
p_tqdm                   1.4.0
packaging                24.1
pathos                   0.3.2
pillow                   10.4.0
pip                      24.0
pox                      0.3.4
ppft                     1.7.6.8
protobuf                 4.25.4
Pygments                 2.18.0
pyparsing                3.1.2
python-dateutil          2.9.0.post0
PyYAML                   6.0.1
requests                 2.32.3
rich                     13.7.1
scikit-image             0.24.0
scikit-learn             1.5.1
scipy                    1.14.0
setuptools               69.5.1
shapely                  2.0.5
six                      1.16.0
sympy                    1.13.1
tensorboard              2.17.0
tensorboard-data-server  0.7.2
threadpoolctl            3.5.0
tifffile                 2024.7.24
torch                    2.0.1
torchaudio               2.0.2
torchvision              0.15.2
tqdm                     4.66.4
triton                   2.0.0
typing_extensions        4.12.2
ujson                    5.10.0
urllib3                  2.2.2
Werkzeug                 3.0.3
wheel                    0.43.0

3. Demo测试

OK,环境准备好后我们就可以开始执行 demo,具体流程可以参照:https://github.com/lucastabelini/LaneATT/README.md#3-getting-started

我们一个个来,首先是推理验证测试,教程给的推理脚本如下所示:

python main.py test --exp_name example

在这之前我们需要把 LaneATT 这个项目给 clone 下来,执行如下指令:

git clone https://github.com/lucastabelini/LaneATT.git

也可手动点击下载,点击右上角的 Code 按键,将代码下载下来。至此整个项目就已经准备好了。

接着需要把 NMS 插件编译下,方便后续 demo 的运行,指令如下:

cd LaneATT/lib/nms
python setup.py install

输出如下所示:

在这里插入图片描述

大家如果看到上述输出内容则说明 NMS 插件编译成功了

此外还要下载相关的数据集和预训练权重用于 Demo 测试和 ONNX 导出

数据集的下载可以参考:LaneATT/DATASETS.md

预训练权重的下载可以通过如下指令获取:

gdown "https://drive.google.com/uc?id=1R638ou1AMncTCRvrkQY6I-11CPwZy23T" # main experiments on TuSimple, CULane and LLAMAS (1.3 GB)
unzip laneatt_experiments.zip

值得注意的是数据集和权重都比较大,官方提供了 CULane、TuSimple 以及 LLAMAS 数据集,并且提供了分别利用这三种数据集训练的 resnet18、resnet34 以及 resnet121 三种权重。博主这里准备了 Demo 测试使用的权重和数据集,其中权重是 r18_culane 和 r34_culane,数据集是 culane 部分测试数据集,大家可以点击 here 下载,下载好后在 LaneATT 目录下进行解压,解压后的整个目录如下所示:

在这里插入图片描述

源码、数据集和模型都准备好后,执行如下指令即可进行推理:

conda activate laneatt
python main.py test --exp_name laneatt_r34_culane

你可能会遇到如下的问题:

在这里插入图片描述

这主要是因为 numpy 版本导致的一些 API 变化,按照提示我们将 np.bool 修改为 np.bool_ 即可,修改内容如下:

# lib/models/laneatt.py 323 行

# mask = ~((((lane_xs[:start] >= 0.) &
#             (lane_xs[:start] <= 1.)).cpu().numpy()[::-1].cumprod()[::-1]).astype(np.bool))

mask = ~((((lane_xs[:start] >= 0.) &
            (lane_xs[:start] <= 1.)).cpu().numpy()[::-1].cumprod()[::-1]).astype(np.bool_))

修改后再次执行,输出如下:

在这里插入图片描述

可以看到测试数据集的各个精度,说明整个程序执行成功了,不过没有一些可视化的结果看着不直观,因此这里博主简单写了一个小 demo 来推理一张图片并进行可视化,代码如下:

import cv2
import torch
import numpy as np
from lib.models.laneatt import LaneATT

def preprocess(img, dst_width=640, dst_height=360):
    img_pre = cv2.resize(img, (dst_width, dst_height))
    img_pre = (img_pre / 255.0).astype(np.float32)
    img_pre = img_pre.transpose(2, 0, 1)[None]
    img_pre = torch.from_numpy(img_pre)
    return img_pre

if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

    img = cv2.imread("datasets/culane/driver_37_30frame/05181432_0203.MP4/00210.jpg")
    img_pre = preprocess(img).to(device)

    model = LaneATT(anchors_freq_path="data/culane_anchors_freq.pt", topk_anchors=1000)
    state_dict = torch.load("experiments/laneatt_r34_culane/models/model_0015.pt")['model']
    model.load_state_dict(state_dict)
    model = model.to(device)

    model.eval()
    with torch.no_grad():
        output = model(img_pre, conf_threshold=0.5, nms_thres=50.0, nms_topk=4)
        pred = model.decode(output, as_lanes=True)[0]
        for line in pred:
            points = line.points
            points[:, 0] *= img.shape[1]
            points[:, 1] *= img.shape[0]
            points = points.round().astype(int)
            for point in points:
                cv2.circle(img, point, 3, color=(0, 255, 0), thickness=-1)
        cv2.imwrite("result.jpg", img)

执行该脚本后在当前目录下会生成 result.jpg 推理结果图片,如下图所示:

在这里插入图片描述

可以看到成功推理了,下面我们来分析 ONNX 模型的导出

4. ONNX导出初探

博主这里采用的是 vscode 进行代码的调试,其中的 launch.json 文件内容如下:

{
    // 使用 IntelliSense 了解相关属性。 
    // 悬停以查看现有属性的描述。
    // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python 调试程序: 当前文件",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "args": [
                "test",
                "--exp_name", "laneatt_r34_culane"
            ]
        }
    ]
}

要调试的文件是 main.py,在 main 函数中打个断点我们来开始调试:

在这里插入图片描述

调试会发现我们调用 runner.eval 函数来进行验证推理,我们找下模型构建的地方:

在这里插入图片描述

在 eval 函数中我们可以非常清晰的找到 build model 的地方,从调试信息来看 model 就是一个正常的 pytorch 模型,因此我们其实可以直接在这里尝试导出下 ONNX

在导出之前其实还有个问题需要解决,我们先看下 LaneATT 模型的 forward 部分:

在这里插入图片描述

从上图中我们可以看到 forward 部分有把 nms 给添加进去,我们期望导出的 ONNX 并不需要这部分,因此我们修改下 forward 部分:

# lib/models/laneatt.py 108 行

# Apply nms
# proposals_list = self.nms(reg_proposals, attention_matrix, nms_thres, nms_topk, conf_threshold)

# return proposals_list

return reg_proposals

在 forward 中我们直接把 reg_proposals 结果返回即可,nms 部分我们放在模型后处理中去做

接着我们需要在 eval 函数中新增如下导出代码:

# lib/runner.py 79 行

model.load_state_dict(self.exp.get_epoch_model(epoch))

# =====================================================================
model = model.to("cpu")
dummy_input = torch.randn(1, 3 ,360, 640)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["images"],
    output_names=["output"],
)
print(f"finished export onnx model")

import onnx
model_onnx = onnx.load("model.onnx")
onnx.checker.check_model(model_onnx)    # check onnx model

# Simplify
try:
    import onnxsim

    print(f"simplifying with onnxsim {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "Simplified ONNX model could not be validated"
except Exception as e:
    print(f"simplifier failure: {e}")

onnx.save(model_onnx, "model.dynamic.sim.onnx")
print(f"simplify done. onnx model save in model.sim.onnx")
return
# =====================================================================

再来执行如下指令:

python main.py test --exp_name laneatt_r34_culane

输出如下所示:

在这里插入图片描述

执行成功后会在当前目录下生成 model.sim.onnx 模型文件,我们一起来看下这个模型文件:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

可以看到这个模型文件总体还是比较干净的,resnet34 加上 attention,一路到底,输入输出也没有什么问题

我们再来看下动态 batch 模型的导出,简单增加下动态维度:

dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["images"],
    output_names=["output"],
    dynamic_axes=dynamic_batch
)

再次执行后生成的 ONNX 模型就是 batch 维度动态,如下所示:

在这里插入图片描述

可以看到输入输出都保证了 batch 维度动态,似乎没有什么问题,但是大家往后看这个模型的结构会发现一团糟:

在这里插入图片描述

这主要是因为 attention 中一些 shape 节点的 trace 以及 anchor 的处理导致导出的动态 batch 模型复杂度非常高,下面我们来看看如何优化这个 ONNX 模型让它尽量简洁一些

5. ONNX导出优化

这里有一个不错的 repo 供大家参考:https://github.com/Yibin122/TensorRT-LaneATT

这个 repo 重写了 LaneATT forward 部分让其更加简洁,并且提供了 TensorRT 部署代码,博主这里主要参考了该 repo,只不过进行了一些修改,下面我们一起来看下

该 repo 提供的 laneatt_to_onnx.py 导出代码如下:

import torch

from lib.models.laneatt import LaneATT


class LaneATTONNX(torch.nn.Module):
    def __init__(self, model):
        super(LaneATTONNX, self).__init__()
        # Params
        self.fmap_h = model.fmap_h  # 11
        self.fmap_w = model.fmap_w  # 20
        self.anchor_feat_channels = model.anchor_feat_channels  # 64
        self.anchors = model.anchors
        self.cut_xs = model.cut_xs
        self.cut_ys = model.cut_ys
        self.cut_zs = model.cut_zs
        self.invalid_mask = model.invalid_mask
        # Layers
        self.feature_extractor = model.feature_extractor
        self.conv1 = model.conv1
        self.cls_layer = model.cls_layer
        self.reg_layer = model.reg_layer
        self.attention_layer = model.attention_layer

        # Exporting the operator eye to ONNX opset version 11 is not supported
        attention_matrix = torch.eye(1000)
        self.non_diag_inds = torch.nonzero(attention_matrix == 0., as_tuple=False)
        self.non_diag_inds = self.non_diag_inds[:, 1] + 1000 * self.non_diag_inds[:, 0]  # 999000

    def forward(self, x):
        batch_features = self.feature_extractor(x)
        batch_features = self.conv1(batch_features)
        # batch_anchor_features = self.cut_anchor_features(batch_features)
        batch_anchor_features = batch_features[0].flatten()
        # h, w = batch_features.shape[2:4]  # 12, 20
        batch_anchor_features = batch_anchor_features[self.cut_xs + 20 * self.cut_ys + 12 * 20 * self.cut_zs].\
            view(1000, self.anchor_feat_channels, self.fmap_h, 1)
        # batch_anchor_features[self.invalid_mask] = 0
        batch_anchor_features = batch_anchor_features * torch.logical_not(self.invalid_mask)

        # Join proposals from all images into a single proposals features batch
        batch_anchor_features = batch_anchor_features.view(-1, self.anchor_feat_channels * self.fmap_h)

        # Add attention features
        softmax = torch.nn.Softmax(dim=1)
        scores = self.attention_layer(batch_anchor_features)
        attention = softmax(scores)
        attention_matrix = torch.zeros(1000 * 1000, device=x.device)
        attention_matrix[self.non_diag_inds] = attention.flatten()  # ScatterND
        attention_matrix = attention_matrix.view(1000, 1000)
        attention_features = torch.matmul(torch.transpose(batch_anchor_features, 0, 1),
                                          torch.transpose(attention_matrix, 0, 1)).transpose(0, 1)
        batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=1)

        # Predict
        cls_logits = self.cls_layer(batch_anchor_features)
        reg = self.reg_layer(batch_anchor_features)

        # Add offsets to anchors (1000, 2+2+73)
        reg_proposals = torch.cat([softmax(cls_logits), self.anchors[:, 2:4], self.anchors[:, 4:] + reg], dim=1)

        return reg_proposals


def export_onnx(onnx_file_path):
    # e.g. laneatt_r18_culane
    backbone_name = 'resnet18'
    checkpoint_file_path = 'experiments/laneatt_r18_culane/models/model_0015.pt'
    anchors_freq_path = 'culane_anchors_freq.pt'

    # Load specified checkpoint
    model = LaneATT(backbone=backbone_name, anchors_freq_path=anchors_freq_path, topk_anchors=1000)
    checkpoint = torch.load(checkpoint_file_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Export to ONNX
    onnx_model = LaneATTONNX(model)
    dummy_input = torch.randn(1, 3, 360, 640)
    torch.onnx.export(onnx_model, dummy_input, onnx_file_path, opset_version=11)


if __name__ == '__main__':
    export_onnx('./LaneATT_test.onnx')

我们修改一个地方即可:

# laneatt_to_onnx.py 69 行

# anchors_freq_path = 'culane_anchors_freq.pt'
anchors_freq_path = 'data/culane_anchors_freq.pt'

接着在终端执行下该脚本:

python laneatt_to_onnx.py

执行成功后会在当前目录下生成 LaneATT_test.onnx 模型,我们一起来看下这个模型结构:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

从上图中我们可以看到导出的 ONNX 模型是静态 batch 模型,但是输出的 batch 维度似乎被作者给干掉了,此外模型后半部分似乎简洁了一些,这主要是因为作者对之前 forward 中的 self.cut_anchor_features 函数进行了部分重写,还有作者把分类分支的 softmatx 直接在 forward 中就做了,这个就非常好

Notelaneatt_to_onnx.py 中测试使用的是 resnet18_culane 模型

那么我们还需要做以下几件事情:

  • 1. 修改代码保证输出的 batch 维度
  • 2. 修改输入输出节点名
  • 2. 利用 onnx-simplifier 简化导出模型
  • 3. 导出动态 batch 模型看是否存在问题

我们先来做前面三件事,修改后的 laneatt_to_onnx.py 代码如下:

import torch
from lib.models.laneatt import LaneATT

class LaneATTONNX(torch.nn.Module):
    def __init__(self, model):
        super(LaneATTONNX, self).__init__()
        # Params
        self.fmap_h = model.fmap_h  # 11
        self.fmap_w = model.fmap_w  # 20
        self.anchor_feat_channels = model.anchor_feat_channels  # 64
        self.anchors = model.anchors
        self.cut_xs = model.cut_xs
        self.cut_ys = model.cut_ys
        self.cut_zs = model.cut_zs
        self.invalid_mask = model.invalid_mask
        # Layers
        self.feature_extractor = model.feature_extractor
        self.conv1 = model.conv1
        self.cls_layer = model.cls_layer
        self.reg_layer = model.reg_layer
        self.attention_layer = model.attention_layer

        # Exporting the operator eye to ONNX opset version 11 is not supported
        attention_matrix = torch.eye(1000)
        self.non_diag_inds = torch.nonzero(attention_matrix == 0., as_tuple=False)
        self.non_diag_inds = self.non_diag_inds[:, 1] + 1000 * self.non_diag_inds[:, 0]  # 999000

    def forward(self, x):
        batch_features = self.feature_extractor(x)
        batch_features = self.conv1(batch_features)
        # batch_anchor_features = self.cut_anchor_features(batch_features)
        # batchx15360
        batch_anchor_features = batch_features.reshape(-1, int(batch_features.numel()))
        # h, w = batch_features.shape[2:4]  # 12, 20
        indices = self.cut_xs + 20 * self.cut_ys + 12 * 20 * self.cut_zs        
        batch_anchor_features = batch_anchor_features[:, indices].\
            view(-1, 1000, self.anchor_feat_channels, self.fmap_h, 1)        
        # batch_anchor_features[self.invalid_mask] = 0
        batch_anchor_features = batch_anchor_features * torch.logical_not(self.invalid_mask)

        # Join proposals from all images into a single proposals features batch
        # batchx1000x704
        batch_anchor_features = batch_anchor_features.view(-1, 1000, self.anchor_feat_channels * self.fmap_h)

        # Add attention features
        softmax = torch.nn.Softmax(dim=2)
        # batchx1000x999
        scores = self.attention_layer(batch_anchor_features)
        attention = softmax(scores)
        bs, _, _ = scores.shape
        attention_matrix = torch.zeros(bs, 1000 * 1000, device=x.device)
        attention_matrix[:, self.non_diag_inds] = attention.reshape(-1, int(attention.numel()))
        attention_matrix = attention_matrix.view(-1, 1000, 1000)
        attention_features = torch.matmul(torch.transpose(batch_anchor_features, 1, 2),
                                          torch.transpose(attention_matrix, 1, 2)).transpose(1, 2)
        batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=2)

        # Predict
        cls_logits = self.cls_layer(batch_anchor_features)
        reg = self.reg_layer(batch_anchor_features)

        xs, ys = map(int, self.anchors.shape)
        anchors = self.anchors[None].expand(bs, xs, ys)

        # Add offsets to anchors (1000, 2+2+73)
        reg_proposals = torch.cat([softmax(cls_logits), anchors[:, :, 2:4], anchors[:, :, 4:] + reg], dim=2)

        return reg_proposals

def export_onnx(onnx_file_path):
    # e.g. laneatt_r18_culane
    backbone_name = 'resnet18'
    checkpoint_file_path = 'experiments/laneatt_r18_culane/models/model_0015.pt'
    anchors_freq_path = 'data/culane_anchors_freq.pt'

    # Load specified checkpoint
    model = LaneATT(backbone=backbone_name, anchors_freq_path=anchors_freq_path, topk_anchors=1000)
    checkpoint = torch.load(checkpoint_file_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Export to ONNX
    onnx_model = LaneATTONNX(model)
    dummy_input = torch.randn(1, 3, 360, 640)
    torch.onnx.export(
        onnx_model, 
        dummy_input, 
        onnx_file_path, 
        input_names=["images"], 
        output_names=["output"]
    )

    import onnx
    model_onnx = onnx.load(onnx_file_path)

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "LaneATT_test.sim.onnx")
    print(f"simplify done. onnx model save in LaneATT_test.sim.onnx")   

if __name__ == '__main__':
    export_onnx('./LaneATT_test.onnx')

我们再次执行下,执行成功后会在当前目录下生成 LaneATT_test.sim.onnx,我们一起来看下修改后的 ONNX 模型结构是否符合我们的预期:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

从上图中可以看到模型结构都符合我们之前修改的期望,output 的 batch 维度添加了,输入输出节点名修改了,onnx-simplifier 简化也做了,整个模型已经足够简洁了

这里额外说一下如果大家对动态 batch 模型没有需求的话,可以直接使用这里的静态 batch 模型也能完成后续的推理部署工作

下面我们就来看看动态 batch 模型的导出看看是否存在问题,简单修改下 laneatt_to_onnx.py 代码:

dummy_input = torch.randn(1, 3, 360, 640)
dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
torch.onnx.export(
    onnx_model, 
    dummy_input, 
    onnx_file_path, 
    input_names=["images"], 
    output_names=["output"],
    dynamic_axes=dynamic_batch
)

再次执行下导出脚本代码,接着我们一起来看看导出的动态 batch 模型存在哪些问题:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

看上去似乎没有什么问题,相比于前面动态 batch 模型要简洁不少,但是我们在其中还是可以看到诸如 Shape、Gather、ConstantOfShape 等节点,这个主要是因为 shape 节点的 trace 导致的,之前杜老师的课程中有讲过,大家感兴趣的可以看看:6.3.tensorRT高级(1)-yolov5模型导出、编译到推理(无封装)

那我们修改就非常简单了,修改代码如下:

# bs, _, _ = scores.shape
bs, _, _ = map(int, scores.shape)

执行下导出脚本,再看下导出的模型结构的变化:

在这里插入图片描述

在这里插入图片描述

可以看到后半部分非常简洁和静态 batch 模型基本上没区别,似乎没有啥问题了,但是我们仔细看就会发现其实还是有点猫腻的,首先 ScatterND 这个常量节点的维度是 1x1000000,这显然不是我们期望的,我们期望的是 batchx1000000,之所以会出现这个情况主要是因为我们断开了 shape 节点的 trace,其中的 batch 维度没有跟踪

但是我们发现后续跟另外一个 tensor 做矩阵相乘最后得到的结果又是动态的 batchx704x1000,大家可能觉得没啥问题,因为矩阵相乘会做广播会将 1x1000x1000 广播成 batchx1000x1000,但是博主在后续 tensorRT 解析动态 batch 模型时出现了如下的警告:

[08/03/2024-06:53:41] [W] [TRT] Profile kMAX values are not self-consistent. IShuffleLayer /Reshape_5: reshaping failed for tensor: /Softmax_output_0 reshape would change volume 7992000 to 999000 Instruction: RESHAPE_ZERO_IS_PLACEHOLDER{8 1000 999} {1 999000}.

错误信息提示说 TensorRT 在执行 reshape 操作时,要求原始张量和目标张量的总元素数必须保持一致,原始数据的总数是 7992000,而尝试 reshape 为 999000,那这里博主设置的 batch size 的 kMax value 是 8,而且出现问题的节点名是 /Softmax_output_0,很明显就是我们前面说的 ScatterND 这个节点前的 reshape 操作导致的,因为 batch 未跟踪固定为 1 导致在 kMax batch size 时 reshape 维度不一致

此外这边还有一个问题那就是最后的 Concat 节点前的两个 tensor shape 都是动态 batch 的,为什么最后 concat 出来的结果却是静态的 1x1000x77,这个主要是因为 Concat 节点还把 anchor 的信息作为输出了,而 anchor 的 batch 维度没有 trace 变成了 1,导致最终的 output 是静态 batch 的

这个我们可以看下 forward 的代码:

xs, ys = map(int, self.anchors.shape)
anchors = self.anchors[None].expand(bs, xs, ys)

# Add offsets to anchors (1000, 2+2+73)
reg_proposals = torch.cat([softmax(cls_logits), anchors[:, :, 2:4], anchors[:, :, 4:] + reg], dim=2)

也可以从 ONNX 模型中发现这个问题:

在这里插入图片描述

所以说我们还是不能断开 shape 节点的 trace,因为诸如 attention_matrix、anchors 这些常量没有办法 trace 到 batch,导致最终的模型是静态 batch 的,兜兜转转又回去了,属实是白干一场

那我们只能 trace batch 节点,不过我们还是可以做一些小优化的,在 anchor 处理的时候我们可以提前做下 slice,修改如下所示:

# __init__
self.anchor_parts_1 = self.anchors[:, 2:4]
self.anchor_parts_2 = self.anchors[:, 4:]

# forward
anchor_expanded_1 = self.anchor_parts_1.repeat(reg.shape[0], 1, 1)
anchor_expanded_2 = self.anchor_parts_2.repeat(reg.shape[0], 1, 1)  

# Add offsets to anchors (1000, 2+2+73)
reg_proposals = torch.cat([softmax(cls_logits), anchor_expanded_1, anchor_expanded_2 + reg], dim=2)

执行下导出脚本,再看下导出的模型结构的变化:

在这里插入图片描述

似乎也没咋优化,之前的两个 slice 节点干掉了,不过新增了一个 Tile 节点

6. ONNX导出总结

经过上面的分析,我们来看下 LaneATT 模型的 ONNX 到底该如何导出呢?我们在 LaneATT 项目目录下新建一个 export.py 文件,其内容如下:

import torch
from lib.models.laneatt import LaneATT

class LaneATTONNX(torch.nn.Module):
    def __init__(self, model):
        super(LaneATTONNX, self).__init__()
        # Params
        self.fmap_h = model.fmap_h  # 11
        self.fmap_w = model.fmap_w  # 20
        self.anchor_feat_channels = model.anchor_feat_channels  # 64
        self.anchors = model.anchors
        self.cut_xs = model.cut_xs
        self.cut_ys = model.cut_ys
        self.cut_zs = model.cut_zs
        self.invalid_mask = model.invalid_mask
        # Layers
        self.feature_extractor = model.feature_extractor
        self.conv1 = model.conv1
        self.cls_layer = model.cls_layer
        self.reg_layer = model.reg_layer
        self.attention_layer = model.attention_layer

        # Exporting the operator eye to ONNX opset version 11 is not supported
        attention_matrix = torch.eye(1000)
        self.non_diag_inds = torch.nonzero(attention_matrix == 0., as_tuple=False)
        self.non_diag_inds = self.non_diag_inds[:, 1] + 1000 * self.non_diag_inds[:, 0]  # 999000

        self.anchor_parts_1 = self.anchors[:, 2:4]
        self.anchor_parts_2 = self.anchors[:, 4:]

    def forward(self, x):
        batch_features = self.feature_extractor(x)
        batch_features = self.conv1(batch_features)
        # batch_anchor_features = self.cut_anchor_features(batch_features)
        # batchx15360
        batch_anchor_features = batch_features.reshape(-1, int(batch_features.numel()))
        # h, w = batch_features.shape[2:4]  # 12, 20
        indices = self.cut_xs + 20 * self.cut_ys + 12 * 20 * self.cut_zs        
        batch_anchor_features = batch_anchor_features[:, indices].\
            view(-1, 1000, self.anchor_feat_channels, self.fmap_h, 1)        
        # batch_anchor_features[self.invalid_mask] = 0
        batch_anchor_features = batch_anchor_features * torch.logical_not(self.invalid_mask)

        # Join proposals from all images into a single proposals features batch
        # batchx1000x704
        batch_anchor_features = batch_anchor_features.view(-1, 1000, self.anchor_feat_channels * self.fmap_h)

        # Add attention features
        softmax = torch.nn.Softmax(dim=2)
        # batchx1000x999
        scores = self.attention_layer(batch_anchor_features)
        attention = softmax(scores)
        # bs, _, _ = scores.shape
        bs, _, _ =scores.shape
        attention_matrix = torch.zeros(bs, 1000 * 1000, device=x.device)
        attention_matrix[:, self.non_diag_inds] = attention.reshape(-1, int(attention.numel()))
        attention_matrix = attention_matrix.view(-1, 1000, 1000)
        attention_features = torch.matmul(torch.transpose(batch_anchor_features, 1, 2),
                                          torch.transpose(attention_matrix, 1, 2)).transpose(1, 2)
        batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=2)

        # Predict
        cls_logits = self.cls_layer(batch_anchor_features)
        reg = self.reg_layer(batch_anchor_features)

        anchor_expanded_1 = self.anchor_parts_1.repeat(reg.shape[0], 1, 1)
        anchor_expanded_2 = self.anchor_parts_2.repeat(reg.shape[0], 1, 1)  

        # Add offsets to anchors (1000, 2+2+73)
        reg_proposals = torch.cat([softmax(cls_logits), anchor_expanded_1, anchor_expanded_2 + reg], dim=2)

        return reg_proposals

def export_onnx(onnx_file_path):
    # e.g. laneatt_r18_culane
    backbone_name = 'resnet18'
    checkpoint_file_path = 'experiments/laneatt_r18_culane/models/model_0015.pt'
    anchors_freq_path = 'data/culane_anchors_freq.pt'

    # Load specified checkpoint
    model = LaneATT(backbone=backbone_name, anchors_freq_path=anchors_freq_path, topk_anchors=1000)
    checkpoint = torch.load(checkpoint_file_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Export to ONNX
    onnx_model = LaneATTONNX(model)
    
    dummy_input = torch.randn(1, 3, 360, 640)
    dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
    torch.onnx.export(
        onnx_model, 
        dummy_input, 
        onnx_file_path, 
        input_names=["images"], 
        output_names=["output"],
        dynamic_axes=dynamic_batch
    )

    import onnx
    model_onnx = onnx.load(onnx_file_path)

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "laneatt.sim.onnx")
    print(f"simplify done. onnx model save in laneatt.sim.onnx")   

if __name__ == '__main__':
    export_onnx('./laneatt.onnx')

然后在终端执行该脚本即可在当前目录生成 laneatt.sim.onnx 模型文件

这里有几点需要额外补充说明:

  • 1. 如果只需要导出静态 batch 的 ONNX 模型,将 dynamic_axes 设置为 None 即可,导出的 ONNX 模型会更加简洁
  • 2. 导出代码案例使用的是 culane 数据集的 laneatt_r18 模型,如果想导出其他的 resnet 模型需要修改 backbone_name 和 checkpoint_file_path
  • 3. 如果想导出其它数据集的模型,除了修改 backbone_name 和 checkpoint_file_path 还需要修改下 anchors_freq_path

结语

博主在这里对 LaneATT 模型进行了 ONNX 导出,主要是学习重写 forward 部分使得导出的 ONNX 模型尽可能的简洁,总的来说还是比较简单的

OK,以上就是 LaneATT 模型导出 ONNX 的全部内容了,下节我们来学习如何利用 tensorRT 推理 LaneATT,敬请期待😄

下载链接

  • 源代码、权重下载链接【提取码:lane】

参考

  • Keep your Eyes on the Lane: Real-time Attention-guided Lane Detection
  • https://github.com/lucastabelini/LaneATT
  • https://github.com/Yibin122/TensorRT-LaneATT
  • https://paperswithcode.com/task/lane-detection/latest
  • 车道线检测数据集介绍
  • 6.3.tensorRT高级(1)-yolov5模型导出、编译到推理(无封装)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1976490.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java游戏源码:象棋网络对战版

学习java朋友们&#xff0c;福利来了&#xff0c;今天小编给大家带来了一款象棋网络对战版源码。 源码搭建和讲解 源码分为客户端和服务器&#xff0c;采用java原生 java.net.Socket 实现&#xff0c;服务器主循环代码&#xff1a; import java.net.ServerSocket; import jav…

二维码生成原理及解码原理

☝☝☝二维码配图 二维码 二维码&#xff08;Quick Response Code&#xff0c;简称QR码&#xff09;是一种广泛使用的二维条形码技术&#xff0c;由日本公司Denso Wave在1994年开发。二维码能有效地存储和传递信息&#xff0c;广泛应用于商品追溯、支付、广告等多个领域。二维…

Star-CCM+负体积网格检查与出现原因

要使网格可用于有限体积计算&#xff0c;每个网格单元必须具有正体积&#xff0c;否则初始化过程将失败&#xff0c;且模拟计算无法运行。 负体积网格单元可能会以多种不同的方式出现&#xff0c;但必须修复或从网格中移除&#xff0c;才能继续执行任何后续操作。 要检查体网…

<数据集>人员摔倒识别数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;8605张 标注数量(xml文件个数)&#xff1a;8605 标注数量(txt文件个数)&#xff1a;8605 标注类别数&#xff1a;1 标注类别名称&#xff1a;[fall] 序号类别名称图片数框数1fall860512275 使用标注工具&#xf…

当前生物信息学研究面临的四大机遇和挑战(特别是最后一个,一定要足够重视)...

生物信息学是应用计算方法分析生物数据&#xff0c;如 DNA&#xff0c;RNA&#xff0c;蛋白质和代谢物。生物信息学已成为促进我们对生命科学的理解以及开发新的诊断&#xff0c;治疗和生物技术产品的重要工具。本文我们将探讨生物信息学研究的一些当前趋势和发展&#xff0c;以…

如何快速入门 PyTorch ?

PyTorch是一个机器学习框架&#xff0c;主要依靠深度神经网络&#xff0c;目前已迅速成为机器学习领域中最可靠的框架之一。 PyTorch 的大部分基础代码源于 Ronan Collobert 等人 在 2007 年发起的 Torch7 项目&#xff0c;该项目源于 Yann LeCun 和 Leon Bottou 首创的编程语…

【C++题解】1249. 搬砖问题

欢迎关注本专栏《C从零基础到信奥赛入门级&#xff08;CSP-J&#xff09;》 问题&#xff1a;1249. 搬砖问题 类型&#xff1a;嵌套穷举 题目描述&#xff1a; 36 块砖&#xff0c; 36 人搬。男搬 4 &#xff0c;女搬 3 &#xff0c;两个小儿抬一砖。 要求一次全搬完。问需…

GitHub最全中文排行榜开源项目,助你轻松发现优质资源!

文章目录 GitHub-Chinese-Top-Charts&#xff1a;中文开发者的开源项目精选项目介绍项目特点核心功能1. 热门项目榜单2. 详细项目信息 如何使用覆盖范围软件类资料类 GitHub-Chinese-Top-Charts&#xff1a;中文开发者的开源项目精选 在全球范围内&#xff0c;GitHub已经成为了…

谷歌外链:提升网站权重的秘密武器!

谷歌外链之被称为提升网站权重的秘密武器&#xff0c;主要是因为它们对网站的搜索引擎排名有着直接且显著的影响 谷歌和其他搜索引擎使用外链作为衡量网站信任度和权威性的重要指标。当一个网站获得来自其他信誉良好的源的链接时&#xff0c;这被视为信任的投票。多个高质量链…

opencv-图像仿射变换

仿射变换就是将矩形变为平行四边形&#xff0c;而透视变换可以变成任意不规则四边形。实际上&#xff0c;仿射变换是透视变换的子集&#xff0c;仿射变换是线性变换&#xff0c;而透视变换不仅仅是线性变换。 仿射变换设计图像位置角度的变化&#xff0c;是深度学习预处理中常…

力扣SQL50 患某种疾病的患者 正则表达式

Problem: 1527. 患某种疾病的患者 在SQL查询中&#xff0c;REGEXP 是用于执行正则表达式匹配的操作符。正则表达式允许使用特殊字符和模式来匹配字符串中的特定文本。具体到你的查询&#xff0c;^DIAB1|\\sDIAB1 是一个正则表达式&#xff0c;它使用了一些特殊的通配符和符号。…

Vue:vue-router使用指南

一、简介 点击查看vue-router官网 Vue Router 是 Vue.js 的官方路由。它与 Vue.js 核心深度集成&#xff0c;让用 Vue.js 构建单页应用变得轻而易举。功能包括&#xff1a; 嵌套路由映射动态路由选择模块化、基于组件的路由配置-路由参数、查询、通配符-展示由 Vue.js 的过渡系…

DNS常见面试题

DNS是什么&#xff1f; 域名使用字符串来代替 IP 地址&#xff0c;方便用户记忆&#xff0c;本质上一个名字空间系统&#xff1b;DNS 是一个树状的分布式查询系统&#xff0c;但为了提高查询效率&#xff0c;外围有多级的缓存&#xff1b;DNS 就像是我们现实世界里的电话本、查…

电路板热仿真覆铜率,功率,结温,热阻率信息计算获取方法总结

🏡《电子元器件学习目录》 目录 1,概述2,覆铜率3,功率4,器件尺寸5,结温6,热阻1,概述 电路板热仿真操作是一个复杂且细致的过程,旨在评估和优化电路板内部的热分布及温度变化,以确保电子元件的可靠性和性能。本文简述在进行电路板的热仿真时,元器件热信息的计算方法…

59.DevecoStudio项目引入不同目录的文件进行函数调用

59.DevecoStudio ArkUI项目引入不同目录的文件进行函数调用 arkUi,ets,cj文件&#xff0c;ts文件的引用 import common from ohos.app.ability.common; import stringutils from ./uint8array2string; //index.ts的当前目录 import StringUtils2 from ../http2/uint8array2st…

python全栈开发《23.字符串的find与index函数》

1.补充说明上文 python全栈开发《22.字符串的startswith和endswith函数》 endswith和startswith也可以对完整&#xff08;整体&#xff09;的字符串进行判断。 info.endswith(this is a string example!!)或info.startswith(this is a string example!!)相当于bool(info this …

鸿蒙媒体开发【拼图】拍照和图片

拼图 介绍 该示例通过ohos.multimedia.image和ohos.file.photoAccessHelper接口实现获取图片&#xff0c;以及图片裁剪分割的功能。 效果预览 使用说明&#xff1a; 使用预置相机拍照后启动应用&#xff0c;应用首页会读取设备内的图片文件并展示获取到的第一个图片&#x…

Animate软件基础:关于补间动画中的图层

Animate 文档中的每一个场景都可以包含任意数量的时间轴图层。使用图层和图层文件夹可组织动画序列的内容和分隔动画对象。在图层和文件夹中组织它们可防止它们在重叠时相互擦除、连接或分段。若要创建一次包含多个元件或文本字段的补间移动的动画&#xff0c;请将每个对象放置…

go 中 string 并发写导致的 panic

类型的一点变化 在Go语言的演化过程中&#xff0c;引入了unsafe.String来取代之前的StringHeader结构体&#xff0c;这是为了提供更安全和简洁的字符串操作方式。 旧设计 (StringHeader 结构体) StringHeader注释发生了一点变动&#xff0c;被标注了 Deprecated&#xff0c;…

谷粒商城实战笔记-103~104-全文检索-ElasticSearch-Docker安装ES和Kibana

文章目录 一&#xff0c;103-全文检索-ElasticSearch-Docker安装ES1&#xff0c;下载镜像文件2&#xff0c;Elasticsearch安装3&#xff0c;验证 二&#xff0c;104-全文检索-ElasticSearch-Docker安装Kibana1&#xff0c;下载镜像文件2&#xff0c;kibana的安装3&#xff0c;验…