CLRerNet推理详解及部署实现(上)

news2025/1/11 2:54:31

目录

    • 前言
    • 1. 概述
    • 2. 环境配置
    • 3. Demo测试
    • 4. ONNX导出初探
    • 5. ONNX导出优化
    • 6. ONNX导出总结
    • 结语
    • 下载链接
    • 参考

前言

继续我们的车道线检测任务,之前我们分享了基于 anchor 的 LaneATT 模型以及 CVPR2022 的 SOTA 方案 CLRNet,这里我们分享 WACV2024 中的一个方案 CLRerNet,这篇文章主要分析 CLRerNet 模型的 ONNX 导出以及解决导出过程中遇到的各种问题。若有问题欢迎各位看官批评指正😄

paper:CLRerNet: Improving Confidence of Lane Detection with LaneIoU

repo:https://github.com/hirotomusiker/CLRerNet

1. 概述

CLRerNet 引入了被称为 LaneIoU 的新型 IoU,不同于 CLRNet 中的 LineIoU,LaneIoU 引入了一种可微的局部角度感知 IoU 定义,这种方法在计算 IoU 时考虑了车道线的局部角度变化,从而更准确地反映车道线之间的相似性。目前 CLRerNet 在 CULane 数据集上是 SOTA 方案,但它的模型结构与 CLRNet 相比并没有多大的变化,所以对于部署来说 CLRerNet 其实和 CLRNet 没有什么区别,这里让博主再水两篇文章吧😂

CLRerNet 整体框架如下图所示:

在这里插入图片描述

LineIoU 和 LaneIoU 对比图如下所示:

在这里插入图片描述

2. 环境配置

在开始之前我们有必要配置下环境,CLRerNet 的环境可以通过 CLRerNet/README.md 文档中安装,这里有个点需要大家注意,那就是 CLRerNet 官方和 CLRNet 一样将后处理的 NMS 部分放在了 CUDA 上实现,因此需要编译,这个在 Windows 上面折腾可能比较麻烦,博主直接在 Linux 上操作的

博主这里准备了一个可以运行 demo 和导出 ONNX 的环境,大家可以按照这个环境来,也可以自己参考文档进行相关环境配置

博主的环境安装指令如下所示:

conda create -n clrernet python=3.8
conda activate clrernet
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
pip install -U openmim==0.3.3
mim install mmcv-full==1.7.0
pip install albumentations==0.4.6 p_tqdm==1.3.3 yapf==0.40.1 mmdet==2.28.0
pip install pytest pytest-cov tensorboard
pip install onnx onnx-simplifier onnxruntime

可能大家有所困惑,为什么需要的 torch 版本比较高,这个其实取决于你的 CUDA 版本,博主 Linux 主机的 CUDA 版本是 11.6,如果安装的 torch 版本过低,会导致编译的 NMS 插件无法通过,这个大家根据自己的实际情况来就行。

Note:这个环境目前博主只用于 demo 测试和 ONNX 导出,并不包含训练

为了不必要的错误,博主将虚拟环境中各个软件的版本都罗列出来,方便大家查看,环境如下:

Package                  Version
------------------------ -----------
absl-py                  2.1.0
addict                   2.4.0
albumentations           0.4.6
cachetools               5.4.0
certifi                  2024.7.4
charset-normalizer       3.3.2
click                    8.1.7
cmake                    3.30.2
colorama                 0.4.6
coloredlogs              15.0.1
contourpy                1.1.1
coverage                 7.6.1
cycler                   0.12.1
dill                     0.3.8
exceptiongroup           1.2.2
filelock                 3.15.4
flatbuffers              23.3.3
fonttools                4.53.1
google-auth              2.33.0
google-auth-oauthlib     1.0.0
grpcio                   1.65.4
humanfriendly            10.0
idna                     3.7
imageio                  2.35.0
imgaug                   0.4.0
importlib_metadata       8.2.0
importlib_resources      6.4.0
iniconfig                2.0.0
Jinja2                   3.1.4
kiwisolver               1.4.5
lazy_loader              0.4
lit                      18.1.8
Markdown                 3.6
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.7.5
mdurl                    0.1.2
mmcv-full                1.7.0
mmdet                    2.28.0
model-index              0.1.11
mpmath                   1.3.0
multiprocess             0.70.16
networkx                 3.1
nms                      0.0.0
numpy                    1.24.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
oauthlib                 3.2.2
onnx                     1.16.2
onnx-simplifier          0.4.36
onnxruntime              1.14.1
opencv-python            4.10.0.84
openmim                  0.3.3
ordered-set              4.1.0
p_tqdm                   1.3.3
packaging                24.1
pandas                   2.0.3
pathos                   0.3.2
pillow                   10.4.0
pip                      24.2
platformdirs             4.2.2
pluggy                   1.5.0
pox                      0.3.4
ppft                     1.7.6.8
protobuf                 5.27.3
pyasn1                   0.6.0
pyasn1_modules           0.4.0
pycocotools              2.0.7
Pygments                 2.18.0
pyparsing                3.1.2
pytest                   8.3.2
pytest-cov               5.0.0
python-dateutil          2.9.0.post0
pytz                     2024.1
PyWavelets               1.4.1
PyYAML                   6.0.2
requests                 2.32.3
requests-oauthlib        2.0.0
rich                     13.7.1
rsa                      4.9
scikit-image             0.21.0
scipy                    1.10.1
setuptools               72.1.0
shapely                  2.0.5
six                      1.16.0
sympy                    1.11.1
tabulate                 0.9.0
tensorboard              2.14.0
tensorboard-data-server  0.7.2
terminaltables           3.1.10
tifffile                 2023.7.10
tomli                    2.0.1
torch                    2.0.1
torchaudio               2.0.2
torchvision              0.15.2
tqdm                     4.66.5
triton                   2.0.0
typing_extensions        4.12.2
tzdata                   2024.1
urllib3                  2.2.2
Werkzeug                 3.0.3
wheel                    0.43.0
yapf                     0.40.1
zipp                     3.20.0

3. Demo测试

OK,环境准备好后我们就可以执行 demo,具体流程可以参考:https://github.com/hirotomusiker/CLRerNet/speed-test

我们一个个来,首先是推理验证测试,推理脚本如下所示:

python demo/image_demo.py demo/demo.jpg configs/clrernet/culane/clrernet_culane_dla34.py clrernet_culane_dla34.pth --out-file=clrernet_result.png

在这之前我们需要把 CLRerNet 这个项目给 clone 下来,执行如下指令:

git clone https://github.com/hirotomusiker/CLRerNet.git

也可手动点击下载,点击右上角的 Code 按键,将代码下载下来。至此整个项目就已经准备好了。

接着我们需要把 NMS 插件编译下,方便后续 demo 的运行,指令如下:

cd CLRerNet
conda activate clrernet
cd libs/models/layers/nms/
python setup.py install

输出如下所示:

在这里插入图片描述

大家如果看到上述内容则说明 NMS 插件编译成功了

同时也可以通过 pip 查看编译的 NMS 插件,如下图所示:

在这里插入图片描述

此外还需要下载预训练权重用于 Demo 测试和 ONNX 导出

预训练权重的下载可以通过 README 提供的链接获取:

在这里插入图片描述

值得注意的是官方只提供了 backbone 为 DLA34 训练 CULane 数据集的权重,博主这里也准备了 Demo 测试使用的权重,大家可以点击 here 下载,下载好后将预训练权重放在 CLRerNet 主目录下即可

源码和模型都准备好后,执行如下指令即可进行推理:

python demo/image_demo.py demo/demo.jpg configs/clrernet/culane/clrernet_culane_dla34.py clrernet_culane_dla34.pth --out-file=clrernet_result.png

你可能会遇到如下问题:

在这里插入图片描述

错误显示 No module named libs.api,找不到 libs.api 模块,这个主要是我们没有添加环境变量,在终端执行如下指令:

export PYTHONPATH=$PYTHONPATH:/home/jarvis/Learn/project/CLRerNet

注意将路径修改为你自己的 CLRerNet 路径,接着再次执行上述脚本,输出如下:

在这里插入图片描述

同时在当前目录下还会生成 clrernet_result.png 推理后的图片,如下所示:

在这里插入图片描述

可以看到成功推理了,下面我们来分析 ONNX 模型的导出

4. ONNX导出初探

博主这里采用的是 vscode 进行代码的调试,其中的 launch.json 文件内容如下:

{
    // 使用 IntelliSense 了解相关属性。 
    // 悬停以查看现有属性的描述。
    // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python 调试程序: 当前文件",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "args": [
                "demo/demo.jpg", 
                "configs/clrernet/culane/clrernet_culane_dla34.py",
                "clrernet_culane_dla34.pth",
                "--out-file", "clrernet_result.png"
            ],
            "env": {
                "PYTHONPATH": "/home/jarvis/Learn/project/CLRerNet"
            }
        }
    ]
}

要调试的文件是 demo/image_demo.py,在 main 函数中打个断点我们来开始调试:

在这里插入图片描述

在 main 函数中我们可以非常清晰的找到 build model 的地方,从调试信息来看 model 就是一个正常的 pytorch 模型,因此我们其实可以直接在这里尝试导出下 ONNX

在 main 函数中新增如下导出代码:

# demo/image_demo.py 29 行

model = init_detector(args.config, args.checkpoint, device=args.device)

# =====================================================================
import torch
model = model.to("cpu")
dummy_input = torch.randn(1, 3 ,320, 800)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=16
)
print(f"finished export onnx model")

import onnx
model_onnx = onnx.load("model.onnx")
onnx.checker.check_model(model_onnx)    # check onnx model

# Simplify
try:
    import onnxsim

    print(f"simplifying with onnxsim {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "Simplified ONNX model could not be validated"
except Exception as e:
    print(f"simplifier failure: {e}")

onnx.save(model_onnx, "model.sim.onnx")
print(f"simplify done. onnx model save in model.sim.onnx")
return
# =====================================================================

再来执行如下指令:

python demo/image_demo.py demo/demo.jpg configs/clrernet/culane/clrernet_culane_dla34.py clrernet_culane_dla34.pth --out-file=clrernet_result.png

输出如下所示:

在这里插入图片描述

可以看到导出失败了,在 forward 的过程中出现了一个断言错误即 assert len(img_metas) == 1,我们来看下它的 forward 到底做了些什么:

在这里插入图片描述

经过我们调试发现在 forward_test 函数中它其实需要提供两个参数,一个是 img 另一个是 img_metas,其中 img_metas 中包含了 img 的一些信息,那这个就比较头疼了,还要按照它的格式去准备一个 img_metas

不过我们从 forward_test 函数中也能明显看到其实 forward 过程没有用到 img_metas,它通过 self.extract_feat 提取特征,接着送入到 self.bbox_head 拿到输出结果,因此我们完全可以自己来构建模型导出嘛,没有必要用他提供的 forward_test 函数

在 CLRerNet 目录下新建 export.py 文件,内容如下:

import torch
from mmcv import Config
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint

class CLRerNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRerNetONNX, self).__init__()
        self.model = model

    def forward(self, x):
        x = self.model.backbone(x)
        x = self.model.neck(x)
        output = self.model.bbox_head(x)
        return output
    
if __name__ == "__main__":

    cfg = Config.fromfile("configs/clrernet/culane/clrernet_culane_dla34.py")
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, "clrernet_culane_dla34.pth", map_location="cpu")
        
    model.eval()
    model = model.to("cpu")
    
    # Export to ONNX
    onnx_model = CLRerNetONNX(model)

    dummy_input = torch.randn(1, 3, 320, 800)

    torch.onnx.export(
        onnx_model, 
        dummy_input,
        "model.onnx",
        opset_version=16
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load("model.onnx")
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "model.sim.onnx")
    print(f"simplify done. onnx model save in model.sim.onnx")

执行下该脚本输出如下所示:

在这里插入图片描述

这里有个点需要大家注意,那就是 opset_version 的设置必须大于等于 16,这是因为如果设置小于 16 会出现如下的问题:

在这里插入图片描述

提示说 grid_sampler 算子在 opset version 14 不支持,请尝试下 opset version 16,我们来看下 ONNX 官网算子的支持:

在这里插入图片描述

从官网上我们可以看到 GridSample 这个节点只有在 opset16 版本之后才支持导出,因此我们这里将 opset 设置为 16 就是这个原因,具体大家可以参考:https://github.com/onnx/onnx/blob/main/docs/Operators.md

还有一个点需要大家注意,那就是 TensorRT 只有在 8.5 版本之后才开始支持 GridSample 算子,因此如果你导出的 ONNX 中包含该算子,则需要你保证 TensorRT 在 8.5 版本以上,不然在生成 engine 的时候会出现算子节点无法解析的错误,具体可以参考:https://github.com/onnx/onnx-tensorrt/blob/release/8.5-GA/docs/Changelog.md

在这里插入图片描述

接着我们一起来看下刚导出的模型文件:

在这里插入图片描述

在这里插入图片描述

可以看到这个模型文件总体还是比较干净的,DLA34 加上 ROI pooling 以及 ROI gather,一路到底,输入没什么问题,输出存在多个,我们需要分析哪些部分是我们不需要的给它干掉,下面我们一起来优化下

5. ONNX导出优化

经过我们的调试分析(省略…😄)可以知道最终只需要 head 输出的最后一个维度,因此我们修改下 export.py 导出代码,如下所示:

import torch
from mmcv import Config
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint

class CLRerNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRerNetONNX, self).__init__()
        self.model = model
        self.bakcbone = model.backbone
        self.neck     = model.neck
        self.head     = model.bbox_head

    def forward(self, x):
        x = self.bakcbone(x)
        x = self.neck(x)
        x = self.head(x)

        pred_dict     = x[-1]
        cls_logits    = pred_dict["cls_logits"]
        anchor_params = pred_dict["anchor_params"]
        lengths       = pred_dict["lengths"]
        xs            = pred_dict["xs"]
        
        output = torch.concat([cls_logits, anchor_params, lengths, xs], dim=2)

        return output
    
if __name__ == "__main__":

    cfg = Config.fromfile("configs/clrernet/culane/clrernet_culane_dla34.py")
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, "clrernet_culane_dla34.pth", map_location="cpu")
        
    model.eval()
    model = model.to("cpu")
    
    # Export to ONNX
    onnx_model = CLRerNetONNX(model)

    dummy_input = torch.randn(1, 3, 320, 800)

    torch.onnx.export(
        onnx_model, 
        dummy_input,
        "model.onnx",
        input_names=["images"],
        output_names=["output"],
        opset_version=16
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load("model.onnx")
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "model.sim.onnx")
    print(f"simplify done. onnx model save in model.sim.onnx")

再次执行下导出代码,查看下新导出的 ONNX 模型结构的变化:

在这里插入图片描述

在这里插入图片描述

可以看到导出的网络结构更加清晰了,而且输出只有一个符合我们的预期

我们再来看下动态 batch 模型的导出,简单增加下动态维度:

dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
torch.onnx.export(
    onnx_model,
    dummy_input,
    "model.onnx",
    input_names=["images"],
    output_names=["output"],
    opset_version=16,
    dynamic_axes=dynamic_batch
)

再次执行后生成的 ONNX 就是 batch 维度动态,如下所示:

在这里插入图片描述

可以看到输入输出都保证了 batch 维度动态,似乎没有什么问题,但是大家往后看会发现这个模型的复杂度还是比较高的:

在这里插入图片描述

这主要是因为 ROI pooling 以及 ROI gather 操作中一些 shape 节点的 trace 导致导出的动态 batch 模型复杂度非常高,下面我们来看看如何优化这个 ONNX 模型让它尽量简洁一些

我们学习之前 LaneATT 导出方法,重写下 head 的 forward 部分让它导出的 ONNX 尽可能的满足我们的需求,经过我们的调试分析(省略…😄)我们需要做以下几件事情:

  • 1. -1 尽量出现在 batch 维度
  • 2. cls_logits 添加 softmax
  • 3. length 维度乘以 n_strips
  • 4. 设置 opset version 17 导出完整的 LayerNormalization

修改后的 export.py 代码如下:

import torch
from mmcv import Config
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint

class CLRerNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRerNetONNX, self).__init__()
        self.model = model
        self.bakcbone = model.backbone
        self.neck     = model.neck
        self.head     = model.bbox_head

    def forward(self, x):
        x = self.bakcbone(x)
        x = self.neck(x)
        
        batch = x[0].shape[0]
        feature_pyramid = list(x[len(x) - self.head.refine_layers:])
        # 1x64x10x25+1x64x20x50+1x64x40x100
        feature_pyramid.reverse()
        
        _, sampled_xs = self.head.anchor_generator.generate_anchors(
            self.head.anchor_generator.prior_embeddings.weight,
            self.head.prior_ys,
            self.head.sample_x_indices,
            self.head.img_w,
            self.head.img_h
        )

        anchor_params = self.head.anchor_generator.prior_embeddings.weight.clone().repeat(batch, 1, 1)
        priors_on_featmap = sampled_xs.repeat(batch, 1, 1)

        predictions_list = []
        pooled_features_stages = []
        for stage in range(self.head.refine_layers):
            # 1. anchor ROI pooling
            prior_xs = priors_on_featmap
            pooled_features = self.head.pool_prior_features(feature_pyramid[stage], prior_xs)
            pooled_features_stages.append(pooled_features)

            # 2. ROI gather
            fc_features = self.head.attention(pooled_features_stages, feature_pyramid, stage)
            # fc_features = fc_features.view(self.head.num_priors, batch, -1).reshape(batch * self.head.num_priors, self.head.fc_hidden_dim)
            fc_features = fc_features.view(self.head.num_priors, -1, 64).reshape(-1, self.head.fc_hidden_dim)

            # 3. cls and reg head
            cls_features = fc_features.clone()
            reg_features = fc_features.clone()
            for cls_layer in self.head.cls_modules:
                cls_features = cls_layer(cls_features)
            for reg_layer in self.head.reg_modules:
                reg_features = reg_layer(reg_features)
            
            cls_logits = self.head.cls_layers(cls_features)
            # cls_logits = cls_logits.reshape(batch, -1, cls_logits.shape[1])
            cls_logits = cls_logits.reshape(-1, 192, 2)

            reg = self.head.reg_layers(reg_features)
            # reg = reg.reshape(batch, -1, reg.shape[1])
            reg = reg.reshape(-1, 192, 76)

            # 4. reg processing
            anchor_params += reg[:, :, :3]
            updated_anchor_xs, _ = self.head.anchor_generator.generate_anchors(
                anchor_params.view(-1, 3),
                self.head.prior_ys,
                self.head.sample_x_indices,
                self.head.img_w,
                self.head.img_h
            )
            # updated_anchor_xs = updated_anchor_xs.view(batch, self.head.num_priors, -1)
            updated_anchor_xs = updated_anchor_xs.view(-1, 192, 72)
            reg_xs = updated_anchor_xs + reg[..., 4:]

            # start_y, start_x, theta
            # some problem.
            # anchor_params[:, :, 0] = 1.0 - anchor_params[:, :, 0]
            # anchor_params_ = anchor_params.clone()
            # anchor_params_[:, :, 0] = 1.0 - anchor_params_[:, :, 0]
            # print(f"anchor_params.shape = {anchor_params_.shape}")

            softmax = torch.nn.Softmax(dim=2)
            cls_logits = softmax(cls_logits)
            reg[:, :, 3:4] = reg[:, :, 3:4] * self.head.n_strips
            predictions = torch.concat([cls_logits, anchor_params, reg[:, :, 3:4], reg_xs], dim=2)
            # predictions = torch.concat([cls_logits, anchor_params_, reg[:, :, 3:4], reg_xs], dim=2)

            predictions_list.append(predictions)

            if stage != self.head.refine_layers - 1:
                anchor_params = anchor_params.detach().clone()
                priors_on_featmap = updated_anchor_xs.detach().clone()[
                    ..., self.head.sample_x_indices
                ]
        
        return predictions_list[-1]

    
if __name__ == "__main__":

    cfg = Config.fromfile("configs/clrernet/culane/clrernet_culane_dla34.py")
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, "clrernet_culane_dla34.pth", map_location="cpu")
        
    model.eval()
    model = model.to("cpu")
    
    # Export to ONNX
    onnx_model = CLRerNetONNX(model)

    dummy_input = torch.randn(1, 3, 320, 800)

    dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
    torch.onnx.export(
        onnx_model, 
        dummy_input,
        "model.onnx",
        input_names=["images"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes=dynamic_batch
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load("model.onnx")
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "clrernet.sim.onnx")
    print(f"simplify done. onnx model save in clrernet.sim.onnx")

执行下上述导出脚本,会在当前目录下生成 clrernet.sim.onnx 模型文件,我们一起来看下导出的模型结构的变化:

在这里插入图片描述

在这里插入图片描述

可以看到 ONNX 模型的变化都符合我们的预期,不过动态 batch 的 ONNX 模型整体结构并没有啥变化,还是一样的复杂:

在这里插入图片描述

修改后还是这么复杂主要原因是博主也就是简单改了改 head 部分,并没有完全重写,大家感兴趣的可以重写下 ROI pooling 和 ROI gather 部分让它尽可能更简单

这里还有一个点需要大家注意,那就是我们这里得到的输出中的 start_y 维度并不是我们期望的起始点的数据,1-start_y 才是我们期望的起始点数据,因此我们可以考虑直接在这里把这个操作给做了,部分代码修改如下:

# start_y start_x theta
anchor_params[:, :, 0] = 1.0 - anchor_params[:, :, 0]

那博主在后续测试中发现这么做会存在一个问题,那就是 forward 中的后续操作会使用到 anchor_params 这个变量,因此你不能简单的直接去修改这个变量,这会导致后续的推理结果数据发生错误,因此我们正确的做法是先 clone,代码如下所示:

# start_y start_x theta
anchor_params_ = anchor_params.clone()
anchor_params_[:, :, 0] = 1.0 - anchor_params_[:, :, 0]

softmax = torch.nn.Softmax(dim=2)
cls_logits = softmax(cls_logits)
reg[:, :, 3:4] = reg[:, :, 3:4] * self.head.n_strips
predictions = torch.concat([cls_logits, anchor_params_, reg[:, :, 3:4], reg_xs], dim=2)

这样做推理似乎没有问题,但是随之而来的另外一个问题就是 ONNX 模型的复杂度提高了,如下所示:

在这里插入图片描述

上面的这些操作都是由于 anchor_params_ 而新增的节点,这显然不是我们期望看到的,所以博主这里还是把它放在后处理中去做吧

6. ONNX导出总结

经过上面的分析,我们来看下 CLRerNet 模型的 ONNX 到底该如何导出呢?我们在 CLRerNet 项目目录下新建一个 export.py 文件,其内容如下:

import torch
from mmcv import Config
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint

class CLRerNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRerNetONNX, self).__init__()
        self.model = model
        self.bakcbone = model.backbone
        self.neck     = model.neck
        self.head     = model.bbox_head

    def forward(self, x):
        x = self.bakcbone(x)
        x = self.neck(x)
        
        batch = x[0].shape[0]
        feature_pyramid = list(x[len(x) - self.head.refine_layers:])
        # 1x64x10x25+1x64x20x50+1x64x40x100
        feature_pyramid.reverse()
        
        _, sampled_xs = self.head.anchor_generator.generate_anchors(
            self.head.anchor_generator.prior_embeddings.weight,
            self.head.prior_ys,
            self.head.sample_x_indices,
            self.head.img_w,
            self.head.img_h
        )

        anchor_params = self.head.anchor_generator.prior_embeddings.weight.clone().repeat(batch, 1, 1)
        priors_on_featmap = sampled_xs.repeat(batch, 1, 1)

        predictions_list = []
        pooled_features_stages = []
        for stage in range(self.head.refine_layers):
            # 1. anchor ROI pooling
            prior_xs = priors_on_featmap
            pooled_features = self.head.pool_prior_features(feature_pyramid[stage], prior_xs)
            pooled_features_stages.append(pooled_features)

            # 2. ROI gather
            fc_features = self.head.attention(pooled_features_stages, feature_pyramid, stage)
            # fc_features = fc_features.view(self.head.num_priors, batch, -1).reshape(batch * self.head.num_priors, self.head.fc_hidden_dim)
            fc_features = fc_features.view(self.head.num_priors, -1, 64).reshape(-1, self.head.fc_hidden_dim)

            # 3. cls and reg head
            cls_features = fc_features.clone()
            reg_features = fc_features.clone()
            for cls_layer in self.head.cls_modules:
                cls_features = cls_layer(cls_features)
            for reg_layer in self.head.reg_modules:
                reg_features = reg_layer(reg_features)
            
            cls_logits = self.head.cls_layers(cls_features)
            # cls_logits = cls_logits.reshape(batch, -1, cls_logits.shape[1])
            cls_logits = cls_logits.reshape(-1, 192, 2)

            reg = self.head.reg_layers(reg_features)
            # reg = reg.reshape(batch, -1, reg.shape[1])
            reg = reg.reshape(-1, 192, 76)

            # 4. reg processing
            anchor_params += reg[:, :, :3]
            updated_anchor_xs, _ = self.head.anchor_generator.generate_anchors(
                anchor_params.view(-1, 3),
                self.head.prior_ys,
                self.head.sample_x_indices,
                self.head.img_w,
                self.head.img_h
            )
            # updated_anchor_xs = updated_anchor_xs.view(batch, self.head.num_priors, -1)
            updated_anchor_xs = updated_anchor_xs.view(-1, 192, 72)
            reg_xs = updated_anchor_xs + reg[..., 4:]

            # start_y, start_x, theta
            # some problem.
            # anchor_params[:, :, 0] = 1.0 - anchor_params[:, :, 0]
            # anchor_params_ = anchor_params.clone()
            # anchor_params_[:, :, 0] = 1.0 - anchor_params_[:, :, 0]
            # print(f"anchor_params.shape = {anchor_params_.shape}")

            softmax = torch.nn.Softmax(dim=2)
            cls_logits = softmax(cls_logits)
            reg[:, :, 3:4] = reg[:, :, 3:4] * self.head.n_strips
            predictions = torch.concat([cls_logits, anchor_params, reg[:, :, 3:4], reg_xs], dim=2)
            # predictions = torch.concat([cls_logits, anchor_params_, reg[:, :, 3:4], reg_xs], dim=2)

            predictions_list.append(predictions)

            if stage != self.head.refine_layers - 1:
                anchor_params = anchor_params.detach().clone()
                priors_on_featmap = updated_anchor_xs.detach().clone()[
                    ..., self.head.sample_x_indices
                ]
        
        return predictions_list[-1]

    
if __name__ == "__main__":

    cfg = Config.fromfile("configs/clrernet/culane/clrernet_culane_dla34.py")
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, "clrernet_culane_dla34.pth", map_location="cpu")
        
    model.eval()
    model = model.to("cpu")
    
    # Export to ONNX
    onnx_model = CLRerNetONNX(model)

    dummy_input = torch.randn(1, 3, 320, 800)

    dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
    torch.onnx.export(
        onnx_model, 
        dummy_input,
        "model.onnx",
        input_names=["images"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes=dynamic_batch
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load("model.onnx")
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "clrernet.sim.onnx")
    print(f"simplify done. onnx model save in clrernet.sim.onnx")

然后在终端执行该脚本即可在当前目录生成 clrernet.sim.onnx 模型文件

这里有几点需要额外补充说明:

  • 1. 如果只需要导出静态 batch 的 ONNX 模型,将 dynamic_axes 设置为 None 即可,导出的 ONNX 模型会更加简洁
  • 2. 导出的 ONNX 模型中的 start_y 维度不再是起始点坐标,1-start_y 才是,我们在后处理的时候需要特别注意
  • 3. opset_version 必须大于等于 16,如果设置的 16,则 LayerNormalization 算子会被拆分为如下结构

在这里插入图片描述

我们在韩君老师的课程中有讲过这个就是一个典型的 LayerNormalization 算子,大家感兴趣的可以看下:三. TensorRT基础入门-快速分析开源代码并导出onnx

那我们知道 ONNX 在 opset17 版本之后就开始支持 LayerNormalization 整个算子的导出了,具体可以参考:https://github.com/onnx/onnx/blob/main/docs/Operators.md

在这里插入图片描述

这里还有一个点需要大家注意,那就是 TensorRT 只有在 8.6 版本之后才开始支持 LayerNormalization 算子,因此如果你导出的 ONNX 中包含该算子,则需要你保证 TensorRT 在 8.6 版本以上,不然会出现算子节点无法解析的错误,具体可以参考:https://github.com/onnx/onnx-tensorrt/blob/release/8.6-EA/docs/Changelog.md
在这里插入图片描述

结语

博主在这里对 CLRerNet 模型进行了 ONNX 导出,主要是学习重写 head 的 forward 某些部分使得导出的 ONNX 模型尽可能的符合我们的要求,总的来说还是比较简单的

OK,以上就是 CLRerNet 模型导出 ONNX 的全部内容了,下节我们来学习如何利用 tensorRT 推理 CLRerNet,敬请期待😄

下载链接

  • 源代码、权重下载链接【提取码:lane】

参考

  • CLRerNet: Improving Confidence of Lane Detection with LaneIoU
  • https://github.com/hirotomusiker/CLRerNet
  • https://github.com/onnx/onnx/blob/main/docs/Operators.md
  • https://github.com/onnx/onnx-tensorrt/blob/release/8.5-GA/docs/Changelog.md
  • 三. TensorRT基础入门-快速分析开源代码并导出onnx
  • https://github.com/onnx/onnx-tensorrt/blob/release/8.6-EA/docs/Changelog.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2051626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何在Linux系统上使用ONLYOFFICE文档编辑PDF文件

​对Linux用户来说,得益于各类免费PDF编辑器,编辑PDF文件从来都不是无解难题。 如果您需要为PDF文件添加注释、留下批注、编辑已有文本或添加新文本框、插入图片或形状、删除某些页面或永久删除页面,您始终可以找到合适的应用,轻…

有关JavaScript的函数定义和函数的两种声明方式

1.函数 函数是一段可重复执行的代码块,它可以接收参数,并返回结果。在JavaScript中,函数用于封装可重用的代码,提高代码的可读性和可维护性。 1.1 函数的定义 函数的定义包括以下几个部分: 函数名称:用…

Hive3:表性能优化-分区与分桶

一、分区 1、概念 我们知道,一个Hive表,对应的HDFS是一个文件夹。 那么,当数据非常多的时候,存放在一个文件夹中,后期进行查询操作会影响性能。 所以,Hive引入了分区管理的方式。 本质就是,在…

Datawhale X 魔搭 AI夏令营-AIGC方向-LoRA学习笔记

LoRA(Low-Rank Adaptation)是一种用于优化大规模预训练模型的微调技术,特别适用于在资源有限的情况下,对模型进行高效且低成本的微调。LoRA的核心思想是通过低秩分解方法,仅微调模型的少数参数,从而显著减少…

关于图像亮度相关的调试总结

1、问题背景 关于图像亮度的调试,是整个ISP中非常重要的一块,它决定了图像整体的亮度、对比度、细节、以及噪声,对人眼有非常直观的感受, 之前也就具体问题,整理过几篇图像亮度模块相关的调试总结: 关于图…

标题中有多少个字符(c语言)

1.//描述 //凯刚写了一篇美妙的作文,请问这篇作文的标题中有多少个字符? //注意:标题中可能包含大、小写英文字母、数字字符、空格和换行符。统计标题字 符数时,空格和换行符不计算在内。 //输入描述: //输入文件只有一…

【12】KMP和Manacher算法

目录 一.KMP算法解决的问题 二.Manacher算法解决的问题 基本概念 优化 一.KMP算法解决的问题 暴力求解复杂度O(N*M) next数组:next[i]表示arr[0...i-1]的前缀和后缀的最长公共长度。 Y位置失败,将前缀和后缀完全匹配,将前缀的部分和后缀对…

软件测试---接口自动化

一、pythonrequests模块 (1)requests全局观 安装:pip install requests 1.发送请求 ①requests.get() 发送get请求 ②requests.post() 发送post请求 data和json的区别:取绝于你需要传递的参数的类型。 files:文件上…

大学成长之路:如何从烧锅炉的逆袭成为FPGA大厂高管

如何从烧锅炉的逆袭成为FPGA大厂Sales Director 在即将到来的开学季,很多学子从高中生成为一个大学生,走入新的征程。大学生涯是人生的一个非常重要的阶段,如何度过大学4年的时光,并学有所成,是很多大学新生和家长思考…

Spring IoCDI(下)—DI的尾声

我们之前学习了控制反转IoC,接下来就开始学习依赖注入DI的细节。 依赖注入是一个过程,是指IoC容器在创建Bean时,去提供运行时所依赖的资源,而资源指的就是对象。我们使用 Autowired 注解,完成依赖注入的操作。简单来说…

AMBA-CHI协议详解(六)

AMBA-CHI协议详解(一) AMBA-CHI协议详解(二) AMBA-CHI协议详解(三) AMBA-CHI协议详解(四) AMBA-CHI协议详解(五) AMBA-CHI协议详解(六&#xff09…

JavaSocket编程+JDBC实战技术

一、JavaSocket编程 1.1HTTP协议 后端原理 2. 特点 同步:就是两个任务执行的过程中,其中一个任务要等另一个任务完成某各阶段性工作才能继续执行,如厨师A炒番茄,将葱花放入锅中,然后需要放入番茄,但是厨…

【自动驾驶】控制算法(二)三大坐标系与车辆运动学模型

写在前面: 🌟 欢迎光临 清流君 的博客小天地,这里是我分享技术与心得的温馨角落。📝 个人主页:清流君_CSDN博客,期待与您一同探索 移动机器人 领域的无限可能。 🔍 本文系 清流君 原创之作&…

Dubbo服务自动Web化之路

本文字数:6047字 预计阅读时间:40分钟 01 故障出现 事情起源于一次故障,2023年12月14日14点26分,大量Dubbo服务报出异常,无法链接zookeeper集群: Session 0x0 for server dubboZk.xxx.com/10.x.x.x:2181, C…

【高校科研前沿】南方科技大学冯炼教授等人在遥感顶刊RSE发文:全球人类改造的基塘系统制图

1.文章简介 论文名称:Global mapping of human-transformed dike-pond systems(全球人类改造的基塘系统制图) 第一作者及单位:Yang Xu(南方科技大学环境学院) 第一通讯作者及单位:冯炼&#x…

机器学习:线性回归算法(一元和多元回归代码)

1、线性回归 1、数据准备: 描述如何获取和准备数据。 2、图像预处理: 包括图像读取。 3、将数据划分为训练集和测试集。 4、计算数据的相关系数矩阵。 5、模型训练: 详细说明如何使用线性回归算法训练模型&…

京东2025届秋招 算法开发工程师 第2批笔试

目录 1. 第一题2. 第二题3. 第三题 ⏰ 时间:2024/08/17 🔄 输入输出:ACM格式 ⏳ 时长:2h 本试卷还有选择题部分,但这部分比较简单就不再展示。 1. 第一题 村子里有一些桩子,从左到右高度依次为 1 , 1 2…

达梦数据库的系统视图v$reserved_words

达梦数据库的系统视图v$reserved_words 达梦数据库(DM Database)提供了一系列系统视图以帮助管理员和开发人员了解数据库的状态和配置。V$RESERVED_WORDS 是其中一个系统视图,它显示了数据库系统中已保留的关键字。这些关键字在SQL语句中具有…

SpringBoot自动配置--原理探究

什么是自动配置? SpringBoot自动配置是指在SpringBoot应用启动时,可以把一些配置类自动注入到Spring的IOC容器中,项目运行时可以直接使用这些配置类的属性。简单来说就是用注解来对一些常规的配置做默认配置,简化xml配置内容&…

【三维目标检测】H3DNet(三)

【版权声明】本文为博主原创文章,未经博主允许严禁转载,我们会定期进行侵权检索。 参考书籍:《人工智能点云处理及深度学习算法》 H3DNet数据和源码配置调试过程以及主干网络介绍请参考上一篇博文:【三维目标检测】H3DNet&am…