实践案例丨CenterNet-Hourglass论文复现

news2025/1/11 3:00:59
摘要:本案例是CenterNet-Hourglass论文复现的体验案例,此模型是对Objects as Points 中提出的CenterNet进行结果复现。

本文分享自华为云社区《CenterNet-Hourglass (物体检测/Pytorch)》,作者:HWCloudAI。

目标检测常采用Anchor的方法来获取物体可能存在的位置,再对该位置进行分类,这样的做法耗时、低效,同时需要后处理(比如NMS)。CenterNet将目标看成一个点,即目标bounding box的中心点,整个问题转变成了关键点估计问题,其他目标属性,比如尺寸、3D位置、方向和姿态等都以估计的中心点为基准进行参数回归。

本案例是CenterNet-Hourglass论文复现的体验案例,此模型是对Objects as Points 中提出的CenterNet进行结果复现(原论文Table 2 最后一行)。本模型是以Hourglass网络架构作为backbone,以ExtremNet 作为预训练模型,在COCO数据集上进行50epochs的训练后得到的。本项目是基于原论文的官方代码进行针对ModelArts平台的修改来实现ModelArts上的训练与部署。

具体算法介绍:AI Gallery_算法_模型_云市场-华为云

注意事项:

1.本案例使用框架:PyTorch1.4.0

2.本案例使用硬件:GPU: 1*NVIDIA-V100NV32(32GB) | CPU: 8 核 64GB

3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码

4.JupyterLab的详细用法: 请参考《ModelAtrs JupyterLab使用指导》

5.碰到问题的解决办法: 请参考《ModelAtrs JupyterLab常见问题解决办法》

1.下载数据和代码

运行下面代码,进行数据和代码的下载和解压

本案例使用COCO数据集。

import os
#数据代码下载
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/CenterNet.zip
# 解压缩
os.system('unzip  CenterNet.zip -d ./')

--2021-06-25 17:50:11--  https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/CenterNet.zip
Resolving proxy-notebook.modelarts.com (proxy-notebook.modelarts.com)... 192.168.6.62
Connecting to proxy-notebook.modelarts.com (proxy-notebook.modelarts.com)|192.168.6.62|:8083... connected.
Proxy request sent, awaiting response... 200 OK
Length: 1529663572 (1.4G) [application/zip]
Saving to: ‘CenterNet.zip’
CenterNet.zip       100%[===================>] 1.42G   279MB/s    in 5.6s
2021-06-25 17:50:16 (261 MB/s) - ‘CenterNet.zip’ saved [1529663572/1529663572]
0

2.训练

2.1依赖库加载和安装

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
root_path = './CenterNet/'
os.chdir(root_path)
os.system('pip install pycocotools')
import _init_paths
import torch
import torch.utils.data
from opts import opts
from models.model import create_model, load_model, save_model
from models.data_parallel import DataParallel
from logger import Logger
from datasets.dataset_factory import get_dataset
from trains.train_factory import train_factory
from evaluation import test, prefetch_test, image_infer
USE_MODELARTS = True
INFO:root:Using MoXing-v2.0.0.rc0-19e4d3ab
INFO:root:Using OBS-Python-SDK-3.20.9.1
NMS not imported! If you need it, do 
 cd $CenterNet_ROOT/src/lib/external 
 make

2.2训练函数

def main(opt):
 torch.manual_seed(opt.seed)
 torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test
  Dataset = get_dataset(opt.dataset, opt.task)
  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  logger = Logger(opt)
 os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
 opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu') 
 print('Creating model...')
  model = create_model(opt.arch, opt.heads, opt.head_conv)
  optimizer = torch.optim.Adam(model.parameters(), opt.lr)
 start_epoch = 0
 if opt.load_model != '':
    model, optimizer, start_epoch = load_model(
      model, opt.load_model, optimizer, opt.resume, opt.lr, opt.lr_step)
  Trainer = train_factory[opt.task]
  trainer = Trainer(opt, model, optimizer)
 trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
 print('Setting up data...')
 train_loader = torch.utils.data.DataLoader(
 Dataset(opt, 'train'), 
 batch_size=opt.batch_size, 
      shuffle=True,
 num_workers=opt.num_workers,
 pin_memory=True,
 drop_last=True
 )
 print('Starting training...')
  best = 1e10
 for epoch in range(start_epoch + 1, opt.num_epochs + 1):
    mark = epoch if opt.save_all else 'last'
 log_dict_train, _ = trainer.train(epoch, train_loader)
 logger.write('epoch: {} |'.format(epoch))
 for k, v in log_dict_train.items():
 logger.scalar_summary('train_{}'.format(k), v, epoch)
 logger.write('{} {:8f} | '.format(k, v))
 save_model(os.path.join(opt.save_dir, 'model_last.pth'), 
                 epoch, model)
 logger.write('\n')
 if epoch in opt.lr_step:
 save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)), 
                 epoch, model, optimizer)
 lr = opt.lr * (0.1 ** (opt.lr_step.index(epoch) + 1))
 print('Drop LR to', lr)
 for param_group in optimizer.param_groups:
 param_group['lr'] = lr
 logger.close()

2.3开始训练

训练需要一点时间,请耐心等待

if __name__ == '__main__':
  opt = opts().parse()
 if USE_MODELARTS:
 pwd = os.getcwd()
 print('Copying dataset to work space...')
 print('Listing directory: ')
 print(os.listdir())
 if not os.path.exists(opt.save_dir):
 os.makedirs(opt.save_dir)
  main(opt)
 if USE_MODELARTS:
 print("Processing model checkpoints & service config for deployment...")
 if not opt.eval:
 infer_dir = os.path.join(opt.save_dir, 'model')
 os.makedirs(infer_dir)
 os.system(f'mv ./trained_model/* {infer_dir}')
 pretrained_pth = os.path.join(infer_dir, '*.pth')
 ckpt_dir = os.path.join(opt.save_dir, 'checkpoints')
 os.makedirs(ckpt_dir)
 os.system(f'mv {pretrained_pth} {ckpt_dir}')
 pth_files = os.path.join(opt.save_dir, '*.pth')
 infer_pth = os.path.join(ckpt_dir, f'{opt.model_deploy}.pth')
 os.system(f'mv {pth_files} {ckpt_dir}')
 os.system(f'mv {infer_pth} {infer_dir}')
 print(os.listdir(opt.save_dir))
 print("ModelArts post-training work is done!")
Fix size testing.
training chunk_sizes: [8]
The output will be saved to  ./output/exp/ctdet/default
Copying dataset to work space...
Listing directory: 
['pre-trained_weights', '.ipynb_checkpoints', 'coco_eval.py', 'train.py', 'coco', 'output', 'training_logs', 'trained_model', '_init_paths.py', '__pycache__', 'coco_classes.py', 'lib', 'evaluation.py']
heads {'hm': 80, 'wh': 2, 'reg': 2}
Creating model...
loaded ./trained_model/epoch_50_mAP_42.7.pth, epoch 50
Setting up data...
==> initializing coco 2017 train data.
loading annotations into memory...
Done (t=0.54s)
creating index...
index created!
Loaded train 5000 samples
Starting training...
/home/ma-user/anaconda3/envs/Pytorch-1.4.0/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
 warnings.warn(warning.format(ret))
ctdet/default| train: [1][0/625] |loss 1.7568 |hm_loss 1.3771 |wh_loss 1.9394 |off_loss 0.1857 |Data 0.384s (0.384s) |Net 5.019s (5.019s)
ctdet/default| train: [1][200/625] |loss 1.9275 |hm_loss 1.4429 |wh_loss 2.7269 |off_loss 0.2119 |Data 0.001s (0.003s) |Net 0.759s (0.779s)
ctdet/default| train: [1][400/625] |loss 1.9290 |hm_loss 1.4430 |wh_loss 2.7423 |off_loss 0.2118 |Data 0.001s (0.002s) |Net 0.760s (0.770s)
ctdet/default| train: [1][600/625] |loss 1.9276 |hm_loss 1.4397 |wh_loss 2.7623 |off_loss 0.2117 |Data 0.001s (0.002s) |Net 0.765s (0.767s)
Processing model checkpoints & service config for deployment...
['model', 'logs_2021-06-25-17-51', 'opt.txt', 'checkpoints']
ModelArts post-training work is done!

3.模型测试

3.1推理函数

# -*- coding: utf-8 -*-
# TODO 添加模型运行需要导入的模块
import os
import torch
import numpy as np
from PIL import Image
from io import BytesIO
from collections import OrderedDict
import cv2
import sys
sys.path.insert(0, './lib')
from opts import opts
from coco_classes import coco_class_map
from detectors.detector_factory import detector_factory
class ModelClass():
 def __init__(self, model_path):
 self.model_path = model_path # 本行代码必须保留,且无需修改
 self.opt = opts().parse()
 self.opt.num_classes = 80
 self.opt.resume = True
 self.opt.keep_res = True
 self.opt.fix_res = False
 self.opt.heads = {'hm': 80, 'wh': 2, 'reg': 2}
 self.opt.load_model = model_path
 self.opt.mean = np.array([0.40789654, 0.44719302, 0.47026115],
 dtype=np.float32).reshape(1, 1, 3)
 self.opt.std = np.array([0.28863828, 0.27408164, 0.27809835],
 dtype=np.float32).reshape(1, 1, 3)
 self.opt.batch_infer = False
 # configurable varibales:
 if 'BATCH_INFER' in os.environ:
 print('Batch inference mode!')
 self.opt.batch_infer = True
 if 'FLIP_TEST' in os.environ:
 print('Flip test!')
 self.opt.flip_test = True
 if 'MULTI_SCALE' in os.environ:
 print('Multi scale!')
 self.opt.test_scales = [0.5,0.75,1,1.25,1.5]
 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 if not torch.cuda.is_available():
 self.opt.gpus = [-1]
 self.class_map = coco_class_map()
 torch.set_grad_enabled(False)
        Detector = detector_factory[self.opt.task]
 self.detector = Detector(self.opt)
 print('load model success')
 def predict(self, file_name):
        image = Image.open(file_name).convert('RGB')
 img = np.array(image)
 img = img[:, :, ::-1]
        results = self.detector.run(img)['results']
        image = cv2.cvtColor(np.asarray(image),cv2.COLOR_RGB2BGR)
 if not self.opt.batch_infer:
 for c_id, dets in results.items():
 for det in dets:
 if det[4] > self.opt.vis_thresh:
                        scores = str(round(float(det[4]), 4))
                        classes = self.class_map[c_id]
                        image = cv2.rectangle(image,(int(det[0]),int(det[1])),(int(det[2]),int(det[3])),(0,255,0),2)
                        image = cv2.putText(image,classes+':'+scores,(int(det[0]),int(det[1])),cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,0,255),2)
 else:
 for c_id, dets in results.items():
 for det in dets:
                    scores = str(round(float(det[4]), 4))
                    classes = self.class_map[c_id]
                    image = cv2.rectangle(image,(int(det[0]),int(det[1])),(int(det[2]),int(det[3])),(0,255,0),2)
                    image = cv2.putText(image,classes+':'+scores,(int(det[0]),int(det[1])),cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,255),2)
 return image

3.2开始推理

可以自行修改预测的图像路径

if __name__ == '__main__':
 import matplotlib.pyplot as plt
 img_path = './coco/train/000000021903.jpg' 
 model_path = './output/exp/ctdet/default/model/model_last.pth' #模型的保存路径,你可以自己找一下
 # 以下代码无需修改
 my_model = ModelClass(model_path)
    result = my_model.predict(img_path)
    result = Image.fromarray(cv2.cvtColor(result,cv2.COLOR_BGR2RGB))
 plt.figure(figsize=(10,10)) #设置窗口大小
 plt.imshow(result)
 plt.show()

Fix size testing.
training chunk_sizes: [8]
The output will be saved to  ./output/exp/ctdet/default
Creating model...
loaded ./output/exp/ctdet/default/model/model_last.pth, epoch 1
load model success

点击关注,第一时间了解华为云新鲜技术~

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/63372.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【正点原子FPGA连载】第二十七章 MDIO接口读写测试实验 摘自【正点原子】DFZU2EG/4EV MPSoC 之FPGA开发指南V1.0

1)实验平台:正点原子MPSoC开发板 2)平台购买地址:https://detail.tmall.com/item.htm?id692450874670 3)全套实验源码手册视频下载地址: http://www.openedv.com/thread-340252-1-1.html 第二十七章 MDIO…

字典类型和字典函数、字典方法

字典类型 (无序&#xff0c;不能重复) 通过任意键信息查找一组数据中值信息的过程叫映射&#xff0c; Python语言中通过字典实现映射。 Python语言中的字典可以通过大括号({})建立&#xff0c;建立模式如下&#xff1a; {<键1>:<值1>,<键2>:<值2>,...,…

[附源码]Python计算机毕业设计SSM健身房管理系统(程序+LW)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

p15~p22基本链表容器和高级链表容器迭代器

STL一、自制链表容器/基本链表容器1.1 首/尾部增删节点1.2 获取首/尾部的元素1.3 清空链表7 / 判空链表 / 链表大小81.4 缺省构造0/拷贝构造10/析构函数91.5 输出流操作符重载二、迭代器原理2.1 迭代器概念2.2 迭代器的分类三、迭代器实现3.1 正向非常迭代类3.2 正向非常迭代器…

html旅游网站设计与实现——绿色古典旅游景区 HTML+CSS+JavaScript

&#x1f468;‍&#x1f393;学生HTML静态网页基础水平制作&#x1f469;‍&#x1f393;&#xff0c;页面排版干净简洁。使用HTMLCSS页面布局设计,web大学生网页设计作业源码&#xff0c;这是一个不错的旅游网页制作&#xff0c;画面精明&#xff0c;排版整洁&#xff0c;内容…

解析仓库管理系统对于企业的重要性

仓储管理的职责是有效的保存和管理仓库内的物资&#xff0c;这些物资是指仓库内所有的有形物品以及无形的资产。以前很多企业都是依靠人工方式对库房的管理&#xff0c;难免会造成一些难以解决的问题&#xff1a; 仓库种类太多&#xff0c;查看困难&#xff1b;仓库信息记录不…

Java应用程序安全框架

《从零打造项目》系列文章 工具 比MyBatis Generator更强大的代码生成器 ORM框架选型 SpringBoot项目基础设施搭建SpringBoot集成Mybatis项目实操SpringBoot集成MybatisPlus项目实操SpringBoot集成Spring Data JPA项目实操 数据库变更管理 数据库变更管理&#xff1a;Liquibase…

Word控件Spire.Doc 【图像形状】教程(11): 如何在 C# 中为 Word 中的图像设置 Transeperant 颜色

Spire.Doc for .NET是一款专门对 Word 文档进行操作的 .NET 类库。在于帮助开发人员无需安装 Microsoft Word情况下&#xff0c;轻松快捷高效地创建、编辑、转换和打印 Microsoft Word 文档。拥有近10年专业开发经验Spire系列办公文档开发工具&#xff0c;专注于创建、编辑、转…

A-Level经济题解析及练习Policy options for Common Resources

今日知识点&#xff1a;Policy options for Common Resources 例题 There is a medieval town where sheep graze on common land. As the population grows, the number of sheep grows. However, the amount of land is fixed, the grass begins to disappear from overgra…

SwiftUI 中为什么应该经常用子视图替换父视图中的大段内容?

概览 在 SwiftUI 官方教程中&#xff0c;Apple 时常提出“化整为零”的界面布局思想。简单来说&#xff0c;Apple 推荐 SwiftUI 视图的构建方式是&#xff1a;用若干自定义小视图来构成上层的功能视图。 这是为什么呢&#xff1f; 在本篇博文中&#xff0c;我们将用一个通俗…

[Java反序列化]—CommonsCollections6

先贴个图 0x01: CC 6 应该是CC1 和 URLDNS 的综合&#xff0c;有一定联系&#xff0c;审一下吧 JDK版本需低于 8u71 AnnotationInvocationHandler类的readObject()方法在8u71以后逻辑就发生了改变&#xff0c;不能再利用了&#xff0c;所以就需要找一个绕过高版本的利用链…

Cadence Virtuoso Layout 版图绘制的使用技巧及其相关快捷键

1.版图前准备操作 画好原理图&#xff0c;打好pin脚&#xff08;pin最好以全大写的形式书写&#xff0c;以防后续操作中可能出现Bug&#xff09; 查看所使用工艺库的design rule文件&#xff0c;确定栅格单位设置大小 在准备绘制的原理图界面启动layout XL/GXL 在layout界面…

JS 正则表达式常用方法

1. JS 正则表达式 2. 使用字符串方法 3. 使用 RegExp 方法 1. JS 正则表达式 JS 正则表达式语法: # JS 的正则表达式不需要使用引号包裹&#xff0c;PHP 需要使用引号包裹。修饰符是可选的&#xff0c;可写可不写/正则表达式主体/修饰符JS 中使用正则表达式的方法比较多&am…

【强化学习论文合集】九.2018AAAI人工智能大会论文(AAAI2018)

强化学习(Reinforcement Learning, RL),又称再励学习、评价学习或增强学习,是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题。 本专栏整理了近几年国际顶级会议中,涉及强化学习(Rein…

Python中的Apriori关联算法-市场购物篮分析

数据科学Apriori算法是一种数据挖掘技术&#xff0c;用于挖掘频繁项集和相关的关联规则。本模块重点介绍什么是关联规则挖掘和Apriori算法&#xff0c;以及Apriori算法的用法。 去年&#xff0c;我们为一家公司进行了短暂的咨询工作&#xff0c;该公司正在构建一个主要基于Apr…

使用DIV+CSS技术设计的非遗文化网页与实现制作(web前端网页制作课作业)

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

m基于自适应遗传优化的IEEE-6建设费用和网络损耗费用最小化电网规划算法matlab仿真

目录 1.算法描述 2.仿真效果预览 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 电力工业是当今世界各国经济的重要组成部分&#xff0c;随着世界经济的不断发展&#xff0c;电网的建设和中长期规划和经济发展之间的矛盾变得越来越突出&#xff0c;对电力系统的需求也变得越来…

微服务框架 SpringCloud微服务架构 16 SpringAMQP 16.7 DirectExchange

微服务框架 【SpringCloudRabbitMQDockerRedis搜索分布式&#xff0c;系统详解springcloud微服务技术栈课程|黑马程序员Java微服务】 SpringCloud微服务架构 文章目录微服务框架SpringCloud微服务架构16 SpringAMQP16.7 DirectExchange16.7.1 发布订阅 - DirectExchange16.7.…

基于遗传优化算法的小车障碍物避障路线规划matlab仿真

目录 1.算法描述 2.仿真效果预览 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 一种通过模拟自然进化过程搜索最优解的方法&#xff0c;对于一个最优化问题&#xff0c;该算法通过一定数量的候选解种群迭代地执行选择、交叉、变异、评价等操作使得种群向更好的解进化。 遗传算…

MyBatisPlus简述

文章目录一、MyBatisPlus入门案例与简介1.入门案例2.springboot整合mybatis的方式3.springboot整合mybatisplus步骤1.创建环境&#xff0c;上面我们已经创建过了步骤2.创建数据库及表步骤2.pom.xml补全依赖步骤3.添加MP的相关配置信息步骤4.根据数据库表创建实体类步骤5.创建Da…