FCOS论文复现:通用物体检测算法

news2024/11/28 4:53:13
摘要:本案例代码是FCOS论文复现的体验案例,此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程

本文分享自华为云社区《通用物体检测算法 FCOS(目标检测/Pytorch)》,作者: HWCloudAI 。

FCOS:Fully Convolutional One-Stage Object Detection

本案例代码是FCOS论文复现的体验案例

此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。该算法使用MS-COCO公共数据集进行训练和评估。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程

具体的算法介绍:AI Gallery_算法_模型_云市场-华为云

注意事项:

1.本案例使用框架: PyTorch1.0.0

2.本案例使用硬件: GPU

3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码

1.数据和代码下载

import os
import moxing as mox
# 数据代码下载
mox.file.copy_parallel('obs://obs-aigallery-zc/algorithm/FCOS.zip','FCOS.zip')
# 解压缩
os.system('unzip  FCOS.zip -d ./')

2.模型训练

2.1依赖库安装及加载

"""
Basic training script for PyTorch
"""
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (no not reorder)
import os
import argparse
import torch
import shutil
src_dir = './FCOS/'
os.chdir(src_dir)
os.system('pip install -r ./pip-requirements.txt')
os.system('python -m pip install ./trained_model/model/framework-2.0-cp36-cp36m-linux_x86_64.whl')
os.system('python setup.py build develop')
from framework.utils.env import setup_environment
from framework.config import cfg
from framework.data import make_data_loader
from framework.solver import make_lr_scheduler
from framework.solver import make_optimizer
from framework.engine.inference import inference
from framework.engine.trainer import do_train
from framework.modeling.detector import build_detection_model
from framework.utils.checkpoint import DetectronCheckpointer
from framework.utils.collect_env import collect_env_info
from framework.utils.comm import synchronize, \
 get_rank, is_pytorch_1_1_0_or_later
from framework.utils.logger import setup_logger
from framework.utils.miscellaneous import mkdir

2.2训练函数

def train(cfg, local_rank, distributed, new_iteration=False):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
 if cfg.MODEL.USE_SYNCBN:
 assert is_pytorch_1_1_0_or_later(), \
 "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
 if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
 # this should be removed if we update BatchNorm stats
 broadcast_buffers=False,
 )
    arguments = {}
    arguments["iteration"] = 0
 output_dir = cfg.OUTPUT_DIR
 save_to_disk = get_rank() == 0
 checkpointer = DetectronCheckpointer(
 cfg, model, optimizer, scheduler, output_dir, save_to_disk
 )
 print(cfg.MODEL.WEIGHT)
 extra_checkpoint_data = checkpointer.load_from_file(cfg.MODEL.WEIGHT)
 print(extra_checkpoint_data)
 arguments.update(extra_checkpoint_data)
 if new_iteration:
        arguments["iteration"] = 0
 data_loader = make_data_loader(
 cfg,
 is_train=True,
 is_distributed=distributed,
 start_iter=arguments["iteration"],
 )
 do_train(
        model,
 data_loader,
        optimizer,
        scheduler,
 checkpointer,
        device,
        arguments,
 )
 return model

2.3设置参数,开始训练

def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
 parser.add_argument(
 '--train_url',
        default='./outputs',
 type=str,
 help='the path to save training outputs'
 )
 parser.add_argument(
 "--config-file",
        default="./trained_model/model/fcos_resnet_101_fpn_2x.yaml",
 metavar="FILE",
 help="path to config file",
 type=str,
 )
 parser.add_argument("--local_rank", type=int, default=0)
 parser.add_argument('--train_iterations', default=0, type=int)
 parser.add_argument('--warmup_iterations', default=500, type=int)
 parser.add_argument('--train_batch_size', default=8, type=int)
 parser.add_argument('--solver_lr', default=0.01, type=float)
 parser.add_argument('--decay_steps', default='120000,160000', type=str)
 parser.add_argument('--new_iteration',default=False, action='store_true')
 args, unknown = parser.parse_known_args()
 cfg.merge_from_file(args.config_file)
 # load the model trained on MS-COCO
 if args.train_iterations > 0:
 cfg.SOLVER.MAX_ITER = args.train_iterations
 if args.warmup_iterations > 0:
 cfg.SOLVER.WARMUP_ITERS = args.warmup_iterations
 if args.train_batch_size > 0:
 cfg.SOLVER.IMS_PER_BATCH = args.train_batch_size
 if args.solver_lr > 0:
 cfg.SOLVER.BASE_LR = args.solver_lr
 if len(args.decay_steps) > 0:
        steps = args.decay_steps.replace(' ', ',')
        steps = steps.replace(';', ',')
        steps = steps.replace(';', ',')
        steps = steps.replace(',', ',')
        steps = steps.split(',')
        steps = tuple([int(x) for x in steps])
 cfg.SOLVER.STEPS = steps
 cfg.freeze()
 num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
 args.distributed = num_gpus > 1
 if args.distributed:
 torch.cuda.set_device(args.local_rank)
 torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
 )
 synchronize()
 output_dir = args.train_url
 if output_dir:
 mkdir(output_dir)
    logger = setup_logger("framework", output_dir, get_rank())
 logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)
 logger.info("Loaded configuration file {}".format(args.config_file))
 train(cfg, args.local_rank, args.distributed, args.new_iteration)
if __name__ == "__main__":
 main()

3.模型测试

3.1预测函数

from framework.engine.predictor import Predictor
from PIL import Image,ImageDraw
import numpy as np
import matplotlib.pyplot as plt
def predict(img_path,model_path): 
 config_file = "./trained_model/model/fcos_resnet_101_fpn_2x.yaml"
 cfg.merge_from_file(config_file)
 cfg.defrost()
 cfg.MODEL.WEIGHT = model_path
 cfg.OUTPUT_DIR = None
 cfg.freeze()
    predictor = Predictor(cfg=cfg, min_image_size=800)
 src_img = Image.open(img_path)
 img = src_img.convert('RGB')
 img = np.array(img)
 img = img[:, :, ::-1]
    predictions = predictor.compute_prediction(img)
 top_predictions = predictor.select_top_predictions(predictions)
 bboxes = top_predictions.bbox.int().numpy().tolist()
 bboxes = [[x[1], x[0], x[3], x[2]] for x in bboxes]
    scores = top_predictions.get_field("scores").numpy().tolist()
    scores = [round(x, 4) for x in scores]
    labels = top_predictions.get_field("labels").numpy().tolist()
    labels = [predictor.CATEGORIES[x] for x in labels]
    draw = ImageDraw.Draw(src_img)
 for i,bbox in enumerate(bboxes):
 draw.text((bbox[1],bbox[0]),labels[i] + ':'+str(scores[i]),fill=(255,0,0))
 draw.rectangle([bbox[1],bbox[0],bbox[3],bbox[2]],fill=None,outline=(255,0,0))
 return src_img

3.2开始预测

if __name__ == "__main__":
 model_path = "./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth" # 训练得到的模型
 image_path = "./trained_model/model/demo_image.jpg" # 预测的图像
 img = predict(image_path,model_path)
 plt.figure(figsize=(10,10)) #设置窗口大小
 plt.imshow(img)
 plt.show()
2021-06-09 15:33:15,362 framework.utils.checkpoint INFO: Loading checkpoint from ./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth

点击关注,第一时间了解华为云新鲜技术~

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/41989.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UML/SysML和流浪地球的地球发动机

Lucky 2022-11-24 14:33 最近收到的公众号消息有不少是sysml内容,请问老师sysml和uml是什么关系,以后的趋势是sysml取代uml吗? UMLChina潘加宇 SysML和UML不冲突,也不存在取代的关系。 UML是信息系统的建模语言。“信息系统”…

“Signal”背后的bug与解决

背景 熟悉我的老朋友可能都知道,之前为了应对crash与anr,开源过一个“民间偏方”的库Signal,用于解决在发生crash或者anr时进行应用的重启,从而最大程度减少其坏影响。 在维护的过程中,发生过这样一件趣事&#xff0…

python合集1

我的首个python的合集啊~~ 完全给自己看啊 不喜喷了也不里你 一、一维插值 对现有数据进行拟合或插值是数学分析中常见的方式。 通过分析现有数据,得到一个连续的函数(也就是曲线);或者更密集的离散方程与已知数据互相吻合&…

HTML+CSS详细知识点(下)

🔥上一篇🔥HTMLCSS详细知识点复习(上) 文章目录五、列表和超链接1、列表标签2、CSS控制列表样式3、超链接六、表格和表单1、表格2、表单七、浮动与定位1、元素的浮动2、清除浮动3、overflow属性4、元素的定位属性5、position属性五…

[附源码]计算机毕业设计springboot安防管理平台

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

【吴恩达机器学习笔记】五、逻辑回归

✍个人博客:https://blog.csdn.net/Newin2020?spm1011.2415.3001.5343 📣专栏定位:为学习吴恩达机器学习视频的同学提供的随堂笔记。 📚专栏简介:在这个专栏,我将整理吴恩达机器学习视频的所有内容的笔记&…

【Hack The Box】linux练习-- Horizontall

HTB 学习笔记 【Hack The Box】linux练习-- Horizontall 🔥系列专栏:Hack The Box 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 📆首发时间:🌴2022年11月27日🌴 &…

Spring Cloud和Dubbo有哪些区别?

Spring Cloud Spring Cloud是⼀个微服务框架,提供了微服务领域中的很多功能组件,并且Spring Cloud是⼀个⼤⽽全的框架 Dubbo Dubbo⼀开始是⼀个RPC调⽤框架,核⼼是解决服务调⽤间的问题 对比: Dubbo则更侧重于服务调⽤&#x…

Nuxt 3.0.0正式发布,集成Element Plus、Ant Design Vue和Arco Design Vue脚手架

发布说明 Nuxt 是使用简便的 Web 框架,用于构建现代和高性能的 Web 应用,可以部署在任何运行 JavaScript 的平台上。 Nuxt 3.0 11天前正式发布了稳定版,3.0 基于 Vue 3,为 TypeScript 提供了 “一等公民” 支持,并进行…

java面试强基(13)

前文链接(61条消息) java面试强基(12)_一个风轻云淡的博客-CSDN博客https://blog.csdn.net/m0_62436868/article/details/128047427?spm1001.2014.3001.5501 何为反射?反射机制优缺点? ​ 它赋予了我们在运行时分析类以及执行类…

Jenkins部署与基础配置(1)

5 Jenkins 部署与基础配置 IP地址角色172.18.8.19jenkins-master172.18.8.29jenkins-node1172.18.8.39jenkins-node2 [rootjenkins-master ~]# tail -n1 .bashrc PS1\[\e[1;32m\][\[\e[0m\]\[\e[1;32m\]\[\e[1;33m\]\u\[\e[34m\]\h\[\e[1;31m\] \w\[\e[1;32m\]]\[\e[0m\]# [r…

ISCTF新生赛(引用传递简单社工)

猫和老鼠 反序列化题目&#xff1a; <?php //flag is in flag.php highlight_file(__FILE__); error_reporting(0);class mouse { public $v;public function __toString(){echo "Good. You caught the mouse:";include($this->v);}}class cat {public $a;p…

05 Pod:如何理解这个Kubernetes里最核心的概念?

文章目录1 为什么要有pod?2. 为什么Pod 是 Kubernetes 的核心对象&#xff1f;3. 如何用YAML描述Pod?3.1 Pod的基本组成部分3.1.1 最重要的 spec.containers 字段使用3.1.1.1为什么要定义容器启动时要执行的命令&#xff1f;4. 如何使用kubectl 操作Pod?4.1 创建pod4.2 删除…

数据结构与算法之查找算法

数据结构与算法——查找算法 本文将不断更新查找有关算法&#xff0c;由于精力有限&#xff0c;因此本博文将分多次更新&#xff0c;感谢您的关注 文章目录数据结构与算法——查找算法1. 二分法查找&#xff08;折半查找&#xff09;1.1 算法叙述1.2 实例说明2. 插值查找&#…

【ML特征工程】第 8 章 :自动化特征化器:图像特征提取和深度学习

&#x1f50e;大家好&#xff0c;我是Sonhhxg_柒&#xff0c;希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流&#x1f50e; &#x1f4dd;个人主页&#xff0d;Sonhhxg_柒的博客_CSDN博客 &#x1f4c3; &#x1f381;欢迎各位→点赞…

[2022-11-26]神经网络与深度学习第5章 - 循环神经网络(part 2)

contents循环神经网络(part 2) - 梯度爆炸实验写在开头解决方式概览梯度爆炸实验梯度打印函数思考&#xff1a;什么是范数、L2范数、为什么要打印梯度范数复现梯度爆炸现象使用梯度截断解决梯度爆炸问题思考&#xff1a;梯度截断解决梯度爆炸问题的原理&#xff1f;写在最后循环…

搭建MinIO容器

文章目录1 问题背景2 资源准备3 安装Docker服务4 关闭防火墙5 以Docker方式安装MinIO6 访问MinIO1 问题背景 玩一个前后端的项目&#xff0c;需要用到对象存储器&#xff0c;于是使用开源的MinIO。期间以Docker方式搭建遇到某些坑&#xff0c;此处仅以博客的方式记录下来 2 资源…

【通信原理课设--基于MATLAB/Simulink的2ASK数字带通传输系统建模与仿真】Simulink的使用介绍以及在本实验中的使用

目录 Simulink的简要介绍 Simulink的使用流程 进入Simulink 进入模型编辑窗口 ​ 建立一个新的文件 根据求需建立模型 对选择的模块进行参数设置 本次课程设计需要使用Simulink做ASK的仿真处理&#xff0c;那么下面就一起学习了解一下Simulink吧&#xff01; Simuli…

全球经济自由度1995-2021最新版绿色金融指数2001-2020

&#xff08;1&#xff09;全球经济自由度指数 1995-2021 1、数据来源&#xff1a;美国传统基金会 2、时间跨度&#xff1a;1995-2021 3、区域范围&#xff1a;全球 4、指标说明&#xff1a; 经济自由度指数&#xff0c;是由《华尔街日报》和美国传统基金会发布的年度报告…

爱站网关键词挖掘工具-长尾关键词挖掘站长工具

长尾词挖掘免费工具&#xff0c;为什么我们要使用长尾词挖掘免费工具&#xff0c;我们只要找准关键词就等于掌握了流量。 关键词可应用于任何平台&#xff1a;不管是网站、短视频、自媒体等&#xff01; 比如说用户A经常看体育领域的内容&#xff0c;平台就会给A打上体育领域标…