CenterNet官方代码—目标检测模型推理部分解析与项目启动

news2024/11/15 1:33:43

CenterNet模型推理部分解析

CenterNet官方代码环境部署

CenterNet作为2019年CVPR推出的论文,论文中给出了官方代码所在的github仓库地址。https://github.com/xingyizhou/CenterNet。

在这里插入图片描述

整个代码的代码量并不是特别大,但整个项目的难点在于使用了老版本的pytorch<1.0与一些依赖的资源库从而导致了整个项目在启动和加载时或产生很多的错误。

GPU版本安装问题

在这里插入图片描述
在官方代码的Readme文件夹的下面给出了安装步骤的文件。首先文件中说明了当时使用的时python3.6版本。(不建议使用python3.6版本

自己使用的conda激活虚拟环境后使用的是3.8版本。

对出错的部分没有进行截图,因此只是简单的进行一定的陈述。

  1. 错误一:使用pip来安装CUDA的pytorch版本,我在安装的时候发现,如果使用pip来安装torch会安装的是torch+cu这种版本。而使用conda安装的话会安装一些CUDA相关的依赖

在这里插入图片描述
而在安装dcnv2的过程中需要CUDA相关库的支持。所以pytorch要使用conda来进行安装。 (最好之间安装Gpu的版本)

gpu.c相关文件也需要编译而且容易报错。

  1. DCNV2文件对torch版本的不支持,(自己当时使用git切换conda的虚拟环境安装.make.sh文件时候报错)需要使用最新的DCNV2库,卸载之后重新安装替代。

在这里插入图片描述

  1. 踩坑的地方在启动setup.py文件之后,可能会产生两个错误,第一个是因为缺少c++的编译环境,缺少支持库的错误。第二个是说c++库的版本过高不支持

网上查的原因是因为CUDA更新比较慢,而vistual Studio的更新过快。我自己将2022的版本降为2019年的启动成功。

  1. 一定要下载nvcc完全对应的版本,我自己的是12.1但当时安装习惯安装的是11.8nvcc会报错

在这里插入图片描述
总结:我个人使用的是pytorch2.3.0的版本来进行安装的,如果安装过程中出现错误,可以联系博主简单交流。

模型推理部分代码

检测部分的启动参数示例

ctdet
–demo
…/images/17790319373_bd19b24cfc_k.jpg
–load_model
…/models/ctdet_coco_dla_2x.pth

启动部分会产生一个模型下载失败的错误,将模型自己手动的下载到C盘的cache->torch…文件下面即可以启动模型成功。

自己断点调试,报错的一部分主要是网络错误,作用是加载预训练的一部分权重数据。

在这里插入图片描述

和Yolo一样也是可以检测视频文件的。

断点调试分析

在这里插入图片描述

推理过程整体的执行流程图与代码

在这里插入图片描述

def demo(opt):
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str #指定GPU对哪些环境变量可用
  opt.debug = max(opt.debug, 1) # 设置debug属性便于进行调试
  Detector = detector_factory[opt.task]
  detector = Detector(opt)

  if opt.demo == 'webcam' or \
    opt.demo[opt.demo.rfind('.') + 1:].lower() in video_ext:
    cam = cv2.VideoCapture(0 if opt.demo == 'webcam' else opt.demo)
    detector.pause = False
    while True:
        _, img = cam.read()
        cv2.imshow('input', img)
        ret = detector.run(img)
        time_str = ''
        for stat in time_stats:
          time_str = time_str + '{} {:.3f}s |'.format(stat, ret[stat])
        print(time_str)
        if cv2.waitKey(1) == 27:
            return  # esc to quit
  else:
    if os.path.isdir(opt.demo): # 如果参数中的待检测数据是文件夹
      image_names = []
      ls = os.listdir(opt.demo)
      for file_name in sorted(ls):
          ext = file_name[file_name.rfind('.') + 1:].lower()
          if ext in image_ext:
              image_names.append(os.path.join(opt.demo, file_name))
    else: # 普通文件
      image_names = [opt.demo] #获取文件名
    
    for (image_name) in image_names:
      ret = detector.run(image_name) # 通过检测网络执行推理(核心)
      time_str = ''
      for stat in time_stats:
        time_str = time_str + '{} {:.3f}s |'.format(stat, ret[stat])
      print(time_str)

核心推理函数的细化分析

在这里插入图片描述

run()函数整体的部分

  def run(self, image_or_path_or_tensor, meta=None):
    load_time, pre_time, net_time, dec_time, post_time = 0, 0, 0, 0, 0 #时间变量初始化
    merge_time, tot_time = 0, 0
    debugger = Debugger(dataset=self.opt.dataset, ipynb=(self.opt.debug==3), # Debugger便于调试与错误诊断
                        theme=self.opt.debugger_theme)
    start_time = time.time()
    pre_processed = False
    if isinstance(image_or_path_or_tensor, np.ndarray):
      image = image_or_path_or_tensor
    elif type(image_or_path_or_tensor) == type (''): 
      image = cv2.imread(image_or_path_or_tensor) # 通过opencv读取图片
    else:
      image = image_or_path_or_tensor['image'][0].numpy()
      pre_processed_images = image_or_path_or_tensor
      pre_processed = True
    
    loaded_time = time.time()
    load_time += (loaded_time - start_time)
    
    detections = [] # 检测列表
    for scale in self.scales:
      scale_start_time = time.time()
      if not pre_processed:
        images, meta = self.pre_process(image, scale, meta) # 执行图片的预处理
      else:
        # import pdb; pdb.set_trace()
        images = pre_processed_images['images'][scale][0]
        meta = pre_processed_images['meta'][scale]
        meta = {k: v.numpy()[0] for k, v in meta.items()}
      images = images.to(self.opt.device) # 将image张量移动到GPU上
      torch.cuda.synchronize() # 确保CUDA操作完成
      pre_process_time = time.time()
      pre_time += pre_process_time - scale_start_time
      
      output, dets, forward_time = self.process(images, return_time=True)

      torch.cuda.synchronize()
      net_time += forward_time - pre_process_time # 根据之前的函数计算出推理解码等一系列的时间
      decode_time = time.time()
      dec_time += decode_time - forward_time
      
      if self.opt.debug >= 2:
        self.debug(debugger, images, dets, output, scale)
      
      dets = self.post_process(dets, meta, scale) # 执行后处理的过程(得到80个类别的列表信息部分含有数据)
      torch.cuda.synchronize()
      post_process_time = time.time()
      post_time += post_process_time - decode_time # 计算后处理时间

      detections.append(dets)
    
    results = self.merge_outputs(detections)
    torch.cuda.synchronize()
    end_time = time.time()
    merge_time += end_time - post_process_time
    tot_time += end_time - start_time

    if self.opt.debug >= 1:
      self.show_results(debugger, image, results) # 调用opencv函数显示结果
    
    return {'results': results, 'tot': tot_time, 'load': load_time, # 放入所有的信息进行返回
            'pre': pre_time, 'net': net_time, 'dec': dec_time,
            'post': post_time, 'merge': merge_time}

预处理函数的执行部分

  def pre_process(self, image, scale, meta=None):
    height, width = image.shape[0:2]
    new_height = int(height * scale)
    new_width  = int(width * scale)
    if self.opt.fix_res: # 转为输入大小512 x 512
      inp_height, inp_width = self.opt.input_h, self.opt.input_w
      c = np.array([new_width / 2., new_height / 2.], dtype=np.float32)
      s = max(height, width) * 1.0
    else:
      inp_height = (new_height | self.opt.pad) + 1
      inp_width = (new_width | self.opt.pad) + 1
      c = np.array([new_width // 2, new_height // 2], dtype=np.float32)
      s = np.array([inp_width, inp_height], dtype=np.float32)

    trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])# 获取仿射变换矩阵
    resized_image = cv2.resize(image, (new_width, new_height))# 读取图片裁剪
    inp_image = cv2.warpAffine(
      resized_image, trans_input, (inp_width, inp_height),
      flags=cv2.INTER_LINEAR) # 应用之前的仿射矩阵进行仿射变换
    inp_image = ((inp_image / 255. - self.mean) / self.std).astype(np.float32) # 归一化处理

    images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, inp_width)
    if self.opt.flip_test:
      images = np.concatenate((images, images[:, :, :, ::-1]), axis=0)
    images = torch.from_numpy(images) # 转为张量的格式
    meta = {'c': c, 's': s, 
            'out_height': inp_height // self.opt.down_ratio, # opt.down_ratio:下采样倍率
            'out_width': inp_width // self.opt.down_ratio} # 储存c:中心点坐标 s:缩放因子
    return images, meta # 128x128(下采样4倍的特征图)

解码扩增部分对应的代码信息

def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
    batch, cat, height, width = heat.size() # cat:表示类别的数目

    # heat = torch.sigmoid(heat)
    # perform nms on heatmaps
    heat = _nms(heat) # maxpooling(3x3)类似nms算法
    # scores:得分最高的 K 个点的得分。
    # inds:得分最高的 K 个点的索引。
    # clses:得分最高的 K 个点对应的类别。
    # ys:得分最高的 K 个点的 y 坐标。
    # xs:得分最高的 K 个点的 x 坐标。
    scores, inds, clses, ys, xs = _topk(heat, K=K) # 在热力图 heat 上执行 Top-K 操作提取最高 K 个得分的点提取最高 K 个得分的点
    if reg is not None: # 位置微调的过程
      reg = _transpose_and_gather_feat(reg, inds) # 将回归参数 reg 按照 inds(Top-K 操作返回的索引)重新排列,以便收集与 Top-K 操作选择的热力图位置相对应的回归参数
      reg = reg.view(batch, K, 2) # 收集到的回归参数 reg 调整为形状 [batch, K, 2] 的张量,其中 batch 是批次大小,K 是 Top-K 操作选择的点的数量,2 表示每个点的两个回归参数(通常是一个点的 x 和 y 方向的偏移量)
      xs = xs.view(batch, K, 1) + reg[:, :, 0:1] # x坐标与回归偏移量的坐标相加
      ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
      xs = xs.view(batch, K, 1) + 0.5
      ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds) # 特征图的维度从 [batch, num_anchors, height, width] 转换为 [batch, K, num_anchors]
    if cat_spec_wh: # 是否每个类别有特定的宽度和高度预测
      wh = wh.view(batch, K, cat, 2) # wh 张量的形状调整为 [batch, K, num_classes, 2]
      clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long() # 创建一个索引张量,用于从 wh 中收集特定类别的宽度和高度
      wh = wh.gather(2, clses_ind).view(batch, K, 2) # 得到结果值
    else:
      wh = wh.view(batch, K, 2)
    clses  = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)
    bboxes = torch.cat([xs - wh[..., 0:1] / 2,  # torch.cat 函数将四个坐标点(左上角和右下角)拼接成一个检测框的坐标张量
                        ys - wh[..., 1:2] / 2,
                        xs + wh[..., 0:1] / 2, # xs + wh[..., 0:1] / 2 和 ys + wh[..., 1:2] / 2:计算右下角的坐标。
                        ys + wh[..., 1:2] / 2], dim=2) # xs - wh[..., 0:1] / 2 和 ys - wh[..., 1:2] / 2:计算左上角的坐标。
    detections = torch.cat([bboxes, scores, clses], dim=2)
      
    return detections # 连接之后返回检测框结果(论文中提到的扩增的过程)

后处理函数的执行部分

  def post_process(self, dets, meta, scale=1):
    dets = dets.detach().cpu().numpy() # 转为numpy格式进入cpu进行计算
    dets = dets.reshape(1, -1, dets.shape[2])
    dets = ctdet_post_process( # 尺度变换、坐标转换、非极大值抑制(NMS) 传入的meta['c']中心点 meta['s']缩放尺度
        dets.copy(), [meta['c']], [meta['s']],
        meta['out_height'], meta['out_width'], self.opt.num_classes)
    for j in range(1, self.num_classes + 1): # 后处理转换类型
      dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 5)
      dets[0][j][:, :4] /= scale
    return dets[0]

剩余的部分函数不在进行说明了,但模型推理代码的复现流程主要还是需要按照我自己绘制的流程图进行debug

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2140359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

横向移动-WMI

什么是WMI? WMI是基于 Web 的企业管理 (WBEM) 的 Windows 实现&#xff0c;WBEM 是跨设备访问管理信息的企业标准。 WBEM&#xff08;Web-Based Enterprise Management&#xff09;是一个开放标准&#xff0c;用于跨平台和跨设备的管理信息访问。WMI&#xff08;Windows Mana…

VMware Fusion虚拟机Mac版 安装Win10系统教程

Mac分享吧 文章目录 Win10安装完成&#xff0c;软件打开效果一、VMware安装Windows10虚拟机1️⃣&#xff1a;准备镜像2️⃣&#xff1a;创建虚拟机3️⃣&#xff1a;虚拟机设置4️⃣&#xff1a;安装虚拟机&#xff08;步骤和Win11安装步骤类似&#xff0c;此处相同步骤处没换…

C++从入门到起飞之——继承下篇(万字详解) 全方位剖析!

&#x1f308;个人主页&#xff1a;秋风起&#xff0c;再归来~&#x1f525;系列专栏&#xff1a;C从入门到起飞 &#x1f516;克心守己&#xff0c;律己则安 目录 1、派⽣类的默认成员函数 1.1 四个常⻅默认成员函数 1.2 实现⼀个不能被继承的类 ​编辑 2. 继承与友…

词嵌入(二):基于上下文窗口的静态词嵌入(从NNLM、CW模型谈到基于层次Softmax、负采样的Word2Vec模型)

文章目录 一、经典神经语言模型&#xff08;A Neural Probabilistic Language Model&#xff09;二、C&W模型 (Collobert and Weston, 2008)2.1 文章背景2.2 模型架构&#xff08;词向量的表示&#xff09;2.2.1 Lookup-Table Layer&#xff08;查找表&#xff09;2.2.2 TD…

STM32关于keil使用过程中遇到的问题

1.设备管理器STlink驱动确认安装完成&#xff0c;但是keil里一直识别不到&#xff0c;换下载器也没用 &#xff08;1&#xff09;问题描述 我的问题是这样产生的&#xff1a;之前用标准库开发STM32的时候&#xff0c;STLink能够正常使用&#xff0c;然后使用HAL库开发的时候出…

仓储管理系统的设计与实现SSM框架

&#x1f497;博主介绍&#x1f497;&#xff1a;✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示&#xff1a;文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

electron react离线使用monaco-editor

目录 1.搭建一个 electron-vite 项目 2.安装monaco-editor/react和monaco-editor 3.引入并做monaco-editor离线配置 4.react中使用 5.完整代码示例 6.monaco-editor离线配置官方说明 7.测试 1.搭建一个 electron-vite 项目 pnpm create quick-start/electron 参考链接…

React学习day06-异步操作、ReactRouter的概念及简单使用

13、续 &#xff08;8&#xff09;异步状态操作 1&#xff09;在子仓库中 ①创建仓库 ②解构需要的方法 ③安装axios ④封装并导出请求 ⑤在reducer中为newsList赋值 ⑥获取并导出reducer函数 2&#xff09;在入口文件index.js中&#xff0c;注入 3&#xff09;在App.js中&a…

Vue.js入门系列(二十九):深入理解编程式路由导航、路由组件缓存与路由守卫

个人名片 &#x1f393;作者简介&#xff1a;java领域优质创作者 &#x1f310;个人主页&#xff1a;码农阿豪 &#x1f4de;工作室&#xff1a;新空间代码工作室&#xff08;提供各种软件服务&#xff09; &#x1f48c;个人邮箱&#xff1a;[2435024119qq.com] &#x1f4f1…

爬虫--翻页tips

免责声明&#xff1a;本文仅做分享&#xff01; 伪线程 from DrissionPage import ChromiumPage import timepage ChromiumPage() page.get("https://you.ctrip.com/sight/taian746.html") # 初始化 第0页 index_page 0# 翻页点击函数 sleep def page_turn():page…

【Linux修行路】网络套接字编程——UDP

目录 ⛳️推荐 前言 六、Udp Server 端代码 6.1 socket——创建套接字 6.2 bind——将套接字与一个 IP 和端口号进行绑定 6.3 recvfrom——从服务器的套接字里读取数据 6.4 sendto——向指定套接字中发送数据 6.5 绑定 ip 和端口号时的注意事项 6.5.1 云服务器禁止直接…

C++复习day12

IO流 一、C语言的输入和输出 C语言中我们用到的最频繁的输入输出方式就是scanf ()与printf()。 scanf(): 从标准输入设备(键 盘)读取数据&#xff0c;并将值存放在变量中。printf(): 将指定的文字/字符串输出到标准输出设备(屏幕)。 注意宽度输出和精度输出控制。C语言借助了…

【C++】多态and多态原理

目录 一、多态的概念 二、多态的定义及实现 &#x1f31f;多态的构成条件 &#x1f31f;虚函数 &#x1f31f;虚函数的重写 &#x1f320;小贴士&#xff1a; &#x1f31f;C11 override 和 final &#x1f31f;重载、重写&#xff08;覆盖&#xff09;、重定义&#xf…

POD内的容器之间的资源共享

概述 摘要&#xff1a;本文通过实践描述并验证了pod内容器如何实现网络、文件、PID、UTC、mount的共享。 pod实战之容器内资源共享与隔离 container容器之间的共享实战 从实际场景说起&#xff1a;有2个容器nginx与wordpress分别运行了紧密耦合且需要共享资源的应用程序。我…

英语学习交流平台|基于java的英语学习交流平台系统小程序(源码+数据库+文档)

英语学习交流平台系统小程序 目录 基于java的英语学习交流平台系统小程序 一、前言 二、系统设计 三、系统功能设计 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大厂码农|毕设布道师&…

基于SpringBoot的校园社团活动管理系统设计与实现

文未可获取一份本项目的java源码和数据库参考。 一、设计&#xff08;论文&#xff09;研究背景与意义 在当今的社会&#xff0c;可以说是信息技术的发展时代&#xff0c;在社会的方方面面无不涉及到各种信息的处理。[1]信息是人们对客观世界的具体描述&#xff0c;是人们进行…

性能优化一:oracle 锁的原则

文章目录 锁的原则查看具体会话阻塞过程 锁的原则 1、只有被修改时,行才会被锁定。 2、当条语句修改了一条记录,只有这条记录上被锁定,在Oracle数据库中不存在锁升 3、当某行被修改时 &#xff0c;它将阻塞别人对它的修改。 4、当一个事务修改一行时.将在这个行上加上行锁(TX…

测试开发基础——测试用例的设计

三、测试用例的设计 1. 什么是测试用例 测试用例(Test Case)是为了实施测试而向被测试的系统提供的一组集合&#xff0c;这组集合包含&#xff1a;测试环境、操作步骤、测试数据、预期结果等要素。 设计测试用例原则一&#xff1a;测试用例中一个必需部分是对预期输出或结果进…

带你如何使用CICD持续集成与持续交付

目录 一、CICD是什么 1.1 持续集成&#xff08;Continuous Integration&#xff09; 1.2 持续部署&#xff08;Continuous Deployment&#xff09; 1.3 持续交付&#xff08;Continuous Delivery&#xff09; 二、git工具使用 2.1 git简介 2.2 git的工作流程 2.3 部署g…

【MRI基础】Partial volume 伪影

基本概念 partial volume 伪影是 MRI 中的一种常见伪影&#xff0c;当图像中的体素包含不同组织类型或结构的混合时就会出现这种伪影。这种伪影是由于成像系统的空间分辨率有限而产生的&#xff0c;导致具有不同信号强度的相邻结构在一个体素内混合在一起。 抑制MRI 中的parti…