论文复现|Panoptic Deeplab(全景分割PyTorch)

news2024/9/20 16:47:09
摘要:这是发表于CVPR 2020的一篇论文的复现模型。

本文分享自华为云社区《Panoptic Deeplab(全景分割PyTorch)》,作者:HWCloudAI 。

这是发表于CVPR 2020的一篇论文的复现模型,B. Cheng et al, “Panoptic-DeepLab: A Simple, Strong, and Fast Baseline for Bottom-Up Panoptic Segmentation”, CVPR 2020,此模型在原论文的基础上,使用HRNet作为backbone,得到了高于原论文的精度,PQ达到了63.7%,mIoU达到了80.3%,AP达到了37.3%。该算法会载入Cityscapes上的预训练模型(HRNet),我们提供了训练代码和可用于训练的模型,用于实际场景的微调训练。训练后生成的模型可直接在ModelArts平台部署成在线服务。

具体算法介绍:AI Gallery_算法_模型_云市场-华为云

注意事项:

1.本案例使用框架:PyTorch1.4.0

2.本案例使用硬件:GPU: 1*NVIDIA-V100NV32(32GB) | CPU: 8 核 64GB

3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码

4.JupyterLab的详细用法: 请参考《ModelAtrs JupyterLab使用指导》

5.碰到问题的解决办法: 请参考《ModelAtrs JupyterLab常见问题解决办法》

1.下载数据和代码

运行下面代码,进行数据和代码的下载

本案例使用cityscapes数据集。

import os
import moxing as mox
# 数据代码下载
mox.file.copy_parallel('s3://obs-aigallery-zc/algorithm/panoptic-deeplab','./panoptic-deeplab')

2.模型训练

2.1依赖库加载

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
root_path = './panoptic-deeplab/'
os.chdir(root_path)
# 获取当前目录结构信息,以便进行代码调试
print('os.getcwd():', os.getcwd())
import time
import argparse
import time
import datetime
import math
import sys
import shutil
import moxing as mox # ModelArts上专用的moxing模块,可用于与OBS的数据交互,API文档请查看:https://github.com/huaweicloud/ModelArts-Lab/tree/master/docs/moxing_api_doc
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

2.2训练参数设置

parser = argparse.ArgumentParser(description='Panoptic Deeplab')
parser.add_argument('--training_dataset', default='/home/ma-user/work/panoptic-deeplab/', help='Training dataset directory') # 在ModelArts中创建算法时,必须进行输入路径映射配置,输入映射路径的前缀必须是/home/work/modelarts/inputs/,作用是在启动训练时,将OBS的数据拷贝到这个本地路径中供本地代码使用。
parser.add_argument('--train_url', default='./output', help='the path to save training outputs') # 在ModelArts中创建训练作业时,必须指定OBS上的一个训练输出位置,训练结束时,会将输出映射路径拷贝到该位置
parser.add_argument('--num_gpus',  default=1, type=int, help='num of GPUs to train')
parser.add_argument('--eval', default='False', help='whether to eval')
parser.add_argument('--load_weight', default='trained_model/model/model_final.pth',type=str) # obs路径 断点模型 pth文件 如果是评估 则是相对于src的路径
parser.add_argument('--iteration', default=100, type=int)
parser.add_argument('--learning_rate', default=0.001, type=float)
parser.add_argument('--ims_per_batch', default=8, type=int)
args, unknown = parser.parse_known_args() # 必须将parse_args改成parse_known_args,因为在ModelArts训练作业中运行时平台会传入一个额外的init_method的参数
# dir
fname = os.getcwd()
project_dir = os.path.join(fname, "panoptic-deeplab")
detectron2_dir = os.path.join(fname, "detectron2-0.3+cu102-cp36-cp36m-linux_x86_64.whl")
panopticapi_dir = os.path.join(fname, "panopticapi-0.1-py3-none-any.whl")
cityscapesscripts_dir = os.path.join(fname, "cityscapesScripts-2.1.7-py3-none-any.whl")
requirements_dir = os.path.join(project_dir, "requirements.txt") 
output_dir = "/home/work/modelarts/outputs/train_output" 
# config strings
evalpath = ''
MAX_ITER = 'SOLVER.MAX_ITER ' + str(args.iteration+90000)
BASE_LR = 'SOLVER.BASE_LR ' + str(args.learning_rate)
IMS_PER_BATCH = 'SOLVER.IMS_PER_BATCH ' + str(args.ims_per_batch)
SCRIPT_PATH = os.path.join(project_dir, "tools_d2/train_panoptic_deeplab.py") 
CONFIG_PATH = os.path.join(fname, "configs/config.yaml")
CONFIG_CMD = '--config-file ' + CONFIG_PATH
EVAL_CMD = ''
GPU_CMD = ''
OPTS_CMD = MAX_ITER + ' ' + BASE_LR + ' ' + IMS_PER_BATCH
RESUME_CMD = ''
#functions
def merge_cmd(scirpt_path, config_cmd, gpu_cmd, eval_cmd, resume_cmd, opts_cmd):
 return "python " + scirpt_path + " "+ config_cmd + " " + gpu_cmd + " " + eval_cmd + " " + resume_cmd + " " + OPTS_CMD
if args.eval == 'True':
 assert args.load_weight, 'load_weight empty when trying to evaluate' # 如果评估时为空,则报错
 if args.load_weight != 'trained_model/model/model_final.pth':
 #将model拷贝到本地,并获取模型路径
 modelpath, modelname = os.path.split(args.load_weight)
 mox.file.copy_parallel(args.load_weight, os.path.join(fname, modelname))
 evalpath = os.path.join(fname,modelname)
 else:
 evalpath = os.path.join(fname,'trained_model/model/model_final.pth')
    EVAL_CMD = '--eval-only MODEL.WEIGHTS ' + evalpath
else:
    GPU_CMD = '--num-gpus ' + str(args.num_gpus)
 if args.load_weight:
        RESUME_CMD = '--resume'
 if args.load_weight != 'trained_model/model/model_final.pth':
 modelpath, modelname = os.path.split(args.load_weight)
 mox.file.copy_parallel(args.load_weight, os.path.join('/cache',modelname))
 with open('/cache/last_checkpoint','w') as f: #创建last_checkpoint文件
 f.write(modelname)
 f.close()
 else:
 os.system('cp ' + os.path.join(fname, 'trained_model/model/model_final.pth') + ' /cache/model_final.pth')
 with open('/cache/last_checkpoint','w') as f: #创建last_checkpoint文件
 f.write('model_final.pth')
 f.close()
os.environ['DETECTRON2_DATASETS'] = args.training_dataset #添加数据库路径环境变量
cmd = merge_cmd(SCRIPT_PATH, CONFIG_CMD, GPU_CMD, EVAL_CMD, RESUME_CMD, OPTS_CMD)
# os.system('mkdir -p ' + args.train_url)
print('*********Train Information*********')
print('Run Command: ' + cmd)
print('Num of GPUs: ' + str(args.num_gpus))
print('Evaluation: ' + args.eval)
if args.load_weight:
 print('Load Weight: ' + args.load_weight)
else:
 print('Load Weight: None (train from scratch)')
print('Iteration: ' + str(args.iteration))
print('Learning Rate: ' + str(args.learning_rate))
print('Images Per Batch: ' + str(args.ims_per_batch))

2.3安装依赖库

安装依赖库需要几分钟,请耐心等待

def install_dependecies(r,d, p, c):
 os.system('pip uninstall pytorch> out1.txt')
 os.system('pip install  torch==1.7.0> out2.txt')
 os.system('pip install --upgrade pip')
 os.system('pip install --upgrade numpy')
 os.system('pip install torchvision==1.7.0> out3.txt')
 os.system('pip install pydot')
 os.system('pip install --upgrade pycocotools')
 os.system('pip install tensorboard')
 os.system('pip install -r ' + r + ' --ignore-installed PyYAML') 
 os.system('pip install ' + d) 
 os.system('pip install ' + p)
 os.system('pip install ' + c)
 os.system('pip install pyyaml ==5.1.0')
# 安装依赖
print('*********Installing Dependencies*********')
install_dependecies(requirements_dir,detectron2_dir, panopticapi_dir, cityscapesscripts_dir)
*********Installing Dependencies*********

2.4开始训练

print('*********Training Begin*********')
print(cmd)
start = time.time()
ret = os.system(cmd+ " >out.txt")
if ret == 0:
 print("success")
else:
 print('fail')
end_time=time.time()
print('done')
print(end_time-start)
if args.eval == 'False':
 os.system('mv /cache/model_final.pth ' + os.path.join(fname, 'output/model_final.pth')) #/cache模型移动到输出文件夹
if os.path.exists(os.path.join(fname, 'pred_results')):
 os.system('mv ' + os.path.join(fname, 'pred_results') + ' ' + args.train_url)

训练完成之后,可以在out.txt中看运行日志
在./panoptic-deeplab/output/pred_results/文件目录下,有该模型全景分割,实例分割,语义分割的评估结果

3.模型测试

3.1加载测试函数

from test import *

3.2开始预测

if __name__ == '__main__':
 img_path = r'/home/ma-user/work/panoptic-deeplab/cityscapes/leftImg8bit/val/frankfurt/frankfurt_000000_003920_leftImg8bit.png' # TODO 修改测试图片路径
 model_path = r'/home/ma-user/work/panoptic-deeplab/output/model_final.pth' # TODO 修改模型路径
 my_model = ModelClass(model_path)
    result = my_model.predict(img_path)
 print(result)

点击关注,第一时间了解华为云新鲜技术~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/31162.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

63. 不同路径 II

题目 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(在下图中标记为 “Finish”)。 现在考虑网格中有障碍物。那么从左上角到…

MySQL主/从-主/主集群安装部署

MySQL集群架构的介绍 我们在使用到MySQL数据库的时候,只是一个单机的数据库服务。在实际的生产环境中,数据量可能会非常庞大,这样单机服务的MySQL在使用的时候,性能会受到影响影响。并且单机服务的MySQL的数据安全性也会受到影响…

数字信号处理-09-串行FIR滤波器MATLAB与FPGA实现

前言 本文介绍了设计滤波器的FPGA实现步骤,并结合杜勇老师的书籍中的串行FIR滤波器部分进行一步步实现硬件设计,对书中的架构做了简单的优化,并进行了仿真验证。 FIR滤波器的FPGA实现步骤 从工程角度分析FIR滤波器的FPGA实现步骤如下&…

Vim简洁教程

Vim简洁教程Vim简介使用方法命令模式输入模式底线命令模式模式转换使用流程Vim键盘图Vim简介 在Linux系统中,Vim是一款自带的文本编辑器,因此Vim常用于Linux系统中。Vim是从 vi 发展出来的,包含代码补全、编译及错误跳转等方便编程的功能&am…

【LeetCode每日一题】——78.子集

文章目录一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【解题思路】七【题目提示】八【时间频度】九【代码实现】十【提交结果】一【题目类别】 数组 二【题目难度】 中等 三【题目编号】 78.子集 四【题目描述】 给你一个整数数组 nums &…

亚太C题详细版思路修改版(精)

今年的亚太A、B题的感觉难度不低,其难度已经可以与电工妈杯这种比赛的赛题难度相提并论了。因此,这次预计选C题的人数可能不少,这对于大家来说也是个好消息。塞翁失马焉知非福,难对于大家来说都难,只要自己放平心态&am…

计算机组成原理习题课第一章-1(唐朔飞)

计算机组成原理习题课第一章-1(唐朔飞) ✨欢迎关注🖱点赞🎀收藏⭐留言✒ 🔮本文由京与旧铺原创,csdn首发! 😘系列专栏:java学习 💻首发时间:&…

【Pygame实战】这游戏有毒,刷爆朋友圈:小编已与病毒版贪吃蛇大战了三百回合,最高分339?

导语 Hello,大家好呀!我是木木子吖~ 一个集美貌幽默风趣善良可爱并努力码代码的程序媛一枚。 听说关注我的人会一夜暴富发大财哦~ (哇哇哇 这真的爱😍😍) 所有文章完整的素材源码都在&#…

Android中JVM七大垃圾收集器【解析】

概述 GC垃圾收集器的种类 新生代:年轻代用来存放最近创建的对象老年代:主要存放应用程序中生命周期长的内存对象永久代:内存的永久保存区域(类和元数据),GC不参与回收Serial收集器:串行收集器…

web网页设计—— 中国餐饮协会(HTML+CSS)

🎀 精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

CentOS8使用阿里云yum源异常问题及解决方法

镜像下载、域名解析、时间同步请点击 阿里云开源镜像站 Linux安装git时发生如下错误 [rootraoyuuuu maven]# dnf install git Repository extras is listed more than once in the configuration Repository epel is listed more than once in the configuration CentOS-8 - B…

关于windows的文件监控管理系统(Java)

目 录 摘 要 I Abstract II 1.绪论 1 1.1课题背景 1 1.2系统开发的目的和意义 2 1.3国内外概况 3 1.4研究主要内容 3 2.windows文件监控管理系统相关技术介绍 4 2.1 API 4 2.2 API HOOK 5 2.3 Java 5 2.4 DLL 6 2.4 Windows系统的Socket编程 6 2.4.1使用WinSock API 6 2.4.2 使…

【 C++ 】IO流

目录 1、C语言的输入输出 2、流是什么 3、CIO流 3.1、C标准IO流 3.2、C文件IO流 文件操作步骤 以二进制的形式操作文件 以文本的形式操作文件 4、stringstream的介绍 1、C语言的输入输出 C语言中我们用到的最频繁的输入输出方式就是scanf()和printf()。 scanf()&#xff1a…

[前端基础] JavaScript 基础篇(下)

DOM 和 BOM DOM 指的是文档对象模型,它指的是把文档当做一个对象来对待,这个对象主要定义了处理网页内容的方法和接口。BOM 指的是浏览器对象模型,它指的是把浏览器当做一个对象来对待,这个对象主要定义了与浏览器进行交互的法和…

Node核心模块之Stream

Node.js诞生之初就是为了提高IO性能,文件操作系统和网络模块实现了流接口,Node.js中流就是处理流式数据的抽象接口。 那么应用程序为什么使用流来处理数据? 常见问题 同步读取资源文件,用户需要等待数据读取完成资源文件最终一次…

【Windows】windows10时间显示秒数

一般情况下windows10的电脑时间只显示小时和分钟,但是有的用户想要时间显示更加精细,那么windows10时间怎么显示秒呢?大家可以通过修改注册表的方式进行设置:打开注册表编辑器,定位到Advanced,右键新建DWOR…

【第十四篇】Camunda系列-多人会签【多实例】

多人会签 Multiple Instance 也叫多实例任务。 1.会签说明 多实例活动是为业务流程中的某个步骤定义重复的一种方式。在编程概念中,多实例与 for each 结构相匹配:它允许对给定集合中的每个项目按顺序或并行地执行某个步骤或甚至一个完整的子流程。 多实例是一个有额外属性…

注解(Annotation)

注解 注解也被称为元数据(MateDate),用于修饰或解释包,类,方法,属性,构造器,局部变量等数据信息和注释一样,注解不会影响程序逻辑,但是注解可以被编译或者运行&#xff…

如何定义需求优先级?

本文将围绕以下问题展开:1、什么是需求优先级排序,目的是什么?2、优先级排序的8大依据;3、需求优先级排序面临的挑战;4、一些优秀的需求优先级排序工具。 一、什么是需求优先级排序,目的是什么?…

Mybatis-plus 用法

本文主要介绍 mybatis-plus 这款插件,针对 springboot 用户。包括引入,配置,使用,以及扩展等常用的方面做一个汇总整理,尽量包含大家常用的场景内容。 关于 mybatis-plus 是什么,不多做介绍了,看…