mmdetection使用自己的voc数据集训练模型实战

news2025/1/17 8:45:33

一.自己数据集整理
将labelimg格式数据集进行整理
1.1. 更换图片后缀为jpg

import os
import shutil

root_path='/media/ai-developer/img'

file=os.listdir(root_path)

for img in file:
    if img.endswith('jpeg') or img.endswith('JPG') or img.endswith('png'):
        img_path=os.path.join(root_path,img)
        name=os.path.splitext(img)[0]
        new_name=name+'.jpg'
        os.rename(img_path,os.path.join(root_path,new_name))
        print(name+'.jpg','修改成功....')

2.删除xml和jpg名称不对应的图片

import os
import shutil
imgs=[]
labels=[]

xml_path='/media/ai-developer/277f00a0-3f2b-47a3-9870-b69d65db4d511/图像/20240130结果/ann'
jpg_path='/media/ai-developer/277f00a0-3f2b-47a3-9870-b69d65db4d511/图像/20240130结果/img'


def get_file_list(path, ex):

    file_list = []
    for dir, folder, file in os.walk(path):
        for i in file:
            if os.path.splitext(i)[1] in ex:
                file_list.append(os.path.join(dir, i))
    return file_list

file_jpg = get_file_list(jpg_path, ['.jpg','.JPG','jpeg','png'])
file_xml = get_file_list(xml_path, ['.xml'])

prefix_jpg_list=[]
prefix_xml_list=[]

for b in file_jpg:
    prefix_jpg=os.path.splitext(b)[0]
    jpg_suffix = os.path.basename(prefix_jpg)
    prefix_jpg_list.append(jpg_suffix)

for b in file_xml:
    prefix_xml=os.path.splitext(b)[0]
    xml_suffix = os.path.basename(prefix_xml)
    prefix_xml_list.append(xml_suffix)

for c in prefix_jpg_list:
    if c not in prefix_xml_list:
        os.remove(os.path.join(jpg_path,c)+'.jpg')
        print(c + '.jpg 已将删除')

for d in prefix_xml_list:
    if d not in prefix_jpg_list:
        os.remove(os.path.join(xml_path,d)+'.xml')
        print(d+'.xml 已将删除')
print('over')

1.3 查看class name

# -*- coding:utf-8 -*-

from xml.dom.minidom import parse
import xml.dom.minidom
import os
import xml.etree.ElementTree as ET

xml_path = '/home/ai-developer/桌面/VOCdevkit/VOC2007/Annotations'

classCount = dict()
jpg_name_set=set()

def load_predefine_class():
    predef = open('predefined_classes.txt', 'r', encoding='utf-8')
    for c in predef:
        c = c[:-1]
        classCount[c] = 0

def parse_files(path):

    root = ET.parse(path).getroot()  # 利用ET读取xml文件

    for obj in root.iter('object'):  # 遍历所有目标框
        # print('pic_name:', xml_name)
        name = obj.find('name').text  # 获取目标框名称,即label名

        v = classCount.get(name, 0)
        classCount[name] = v + 1

def traversal_dir(xml_path):
    for p,d,f in os.walk(xml_path):
        for t in f:

            if t.endswith(".xml"):
                path = os.path.join(p, t)

                parse_files(path)
                # print(path)


def output():
    for k in classCount:
        print('%s : %d' % (k, classCount[k]))

if __name__ == '__main__':

    traversal_dir(xml_path)
    output()

1.4 创建以下目录结构
在这里插入图片描述

在这里插入图片描述

其中JPEGImgs里面是所有图片
Annotations里面是所有xml文件
dataset.py文件代码为

import os
import random

trainval_percent =0.8 # 0.8
train_percent =0.8  #0.8
xmlfilepath = 'Annotations'
txtsavepath = 'ImageSets\Main'
total_xml = os.listdir(xmlfilepath)

num = len(total_xml)
list = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list, tv)
train = random.sample(trainval, tr)

ftrainval = open('ImageSets/Main/trainval.txt', 'w')
ftest = open('ImageSets/Main/test.txt', 'w')
ftrain = open('ImageSets/Main/train.txt', 'w')
fval = open('ImageSets/Main/val.txt', 'w')

for i in list:
    name = total_xml[i][:-4] + '\n'
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
            ftest.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print('数据集划分完成')

准备好一切后,python dataset.py自动划分数据集
由此,数据集已经准备完成

二.修改mmdetection配置文件

我的环境版本
``
torch 2.0.1
mmcv 2.1.0
mmdeploy 1.3.1
mmdeploy-runtime 1.3.1
mmdeploy-runtime-gpu 1.3.1
mmdet 3.2.0
mmengine 0.10.1


### 我使用的模型为cascade-rcnn-r101
## 1.0  修改voc0712.py
vi /mmdetection-main/configs/_base_/datasets/voc0712.py

```python
# dataset settings
dataset_type = 'VOCDataset'
data_root = 'data/VOCdevkit/'

# Example to use different file client
# Method 1: simply set the data root and let the file I/O module
# automatically Infer from prefix (not support LMDB and Memcache yet)

# data_root = 's3://openmmlab/datasets/detection/segmentation/VOCdevkit/'

# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/segmentation/',
#         'data/': 's3://openmmlab/datasets/segmentation/'
#     }))
backend_args = None

train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1000, 600), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=(1000, 600), keep_ratio=True),
    # avoid bboxes being resized
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type='RepeatDataset',
        times=3,
        dataset=dict(
            type='ConcatDataset',
            # VOCDataset will add different `dataset_type` in dataset.metainfo,
            # which will get error if using ConcatDataset. Adding
            # `ignore_keys` can avoid this error.
            ignore_keys=['dataset_type'],
            datasets=[
                dict(
                    type=dataset_type,
                    data_root=data_root,
                    ann_file='VOC2007/ImageSets/Main/trainval.txt',
                    data_prefix=dict(sub_data_root='VOC2007/'),
                    filter_cfg=dict(
                        filter_empty_gt=True, min_size=32, bbox_min_size=32),
                    pipeline=train_pipeline,
                    backend_args=backend_args),
                # dict(
                #     type=dataset_type,
                #     data_root=data_root,
                #     ann_file='VOC2012/ImageSets/Main/trainval.txt',
                #     data_prefix=dict(sub_data_root='VOC2012/'),
                #     filter_cfg=dict(
                #         filter_empty_gt=True, min_size=32, bbox_min_size=32),
                #     pipeline=train_pipeline,
                #     backend_args=backend_args)
            ])))

val_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='VOC2007/ImageSets/Main/test.txt',
        data_prefix=dict(sub_data_root='VOC2007/'),
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=backend_args))
test_dataloader = val_dataloader

val_evaluator = dict(type='VOCMetric', metric='mAP', eval_mode='11points')
test_evaluator = val_evaluator

2.0 修改cascade-rcnn_r50_fpn.py

vi mmdetection-main/configs/base/models/cascade-rcnn_r50_fpn.py
修改3出位置 ,num_classes=自己对应的类别数量,

3.0 cascade-rcnn_r50_fpn_1x_coco.py文件修改

vi /mmdetection-main/configs/cascade_rcnn/cascade-rcnn_r50_fpn_1x_coco.py

_base_ = [
    '../_base_/models/cascade-rcnn_r50_fpn.py',
    # '../_base_/datasets/coco_detection.py',
    '../_base_/datasets/voc0712.py',
    '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
]

4.0 修改voc.py
vi /mmdetection-main/mmdet/datasets/voc.py
在这里插入图片描述

5.0 修改class_name.py
vi /mmdetection-main/mmdet/evaluation/functional/class_names.py
在这里插入图片描述

好了,配置文件修改完成,接下来就是开始训练

三.启动训练

单卡训练模型示例 
python3 ./tools/train.py ./configs/faster_rcnn_r50_fpn_1x.py
python tools/train.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py --work-dir work_dirs/cascade_rcnn_r50_fpn_1x_0603/

多gpu分布式训练示例

./tools/dist_train.sh configs/cascade_rcnn/cascade-rcnn_r101_fpn_1x_coco.py 2 --work-dir work_dirs/cascade_rcnn_r101_fpn_1x_0120



resume 训练
 ./tools/dist_train.sh configs/cascade_rcnn/cascade_rcnn_r101_fpn_20e_coco.py 2 --resume-from work_dirs/cascade_rcnn_r101_fpn_1x_coco0716/latest.pth --work-dir work_dirs/cascade_rcnn_r101_fpn_1x_coco0716

四.模型推理

from mmdet.apis import DetInferencer
import mmcv
import os
import time
import cv2
import matplotlib.pyplot as plt


config_file = '/home/ai-developer/data/mmdetection-main/work_dirs/cascade_rcnn_r101_fpn_1x25/cascade-rcnn_r101_fpn_1x_coco.py'
checkpoint_file = '/home/ai-developer/data/mmdetection-main/work_dirs/cascade_rcnn_r101_fpn_1x25/epoch_19.pth'

inferencer = DetInferencer(model=config_file,weights=checkpoint_file,device='cuda:0') # ,palette ='random'

imgdir = '/home/ai-developer/data/mmdetection-main/work_dirs/cascade_rcnn_r101_fpn_1x_0205/test_img'
imgs = os.listdir(imgdir)
i = 0
start = time.time()
for img in imgs:
    i = i + 1
    name=os.path.basename(img)
    # print('name:',name)
    imgpath = os.path.join(imgdir, img)  # or img = mmcv.imread(img), which will only load it once

    # print(imgpath)
    out_dir = os.path.join('./results/shebei0205', img)
    result = inferencer(imgpath,out_dir=out_dir,show=False)#,show=True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1447772.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

关于显卡、显卡驱动、cuda、cuDNN等的区别

关于显卡、显卡驱动、cuda、cuDNN等的区别 刚接触AI或机器学习框架时,经常会被这几个概念搞混,尤其是显卡驱动、cuda、cuDNN这个三个软的东西;此外,NVCC、cudatoolkit又是什么呢? 1. 显卡(GPU) 显卡就是硬件&#xff…

C# winfrom中NPOI操作EXCEL

前言 1.整个Excel表格叫做工作表:WorkBook(工作薄),包含的叫页(工作表):Sheet;行:Row;单元格Cell。 2.忘了告诉大家npoi是做什么的了,npoi 能够读…

揭秘产品迭代计划制定:从0到1打造完美迭代策略

产品迭代计划是产品团队确保他们能够交付满足客户需求的产品以及实现其业务目标的重要工具。开发一个成功的产品迭代计划需要仔细考虑产品的目标、客户需求、市场趋势和可用资源。以下是帮助您创建产品迭代计划的一些步骤:建立产品目标、收集客户反馈、分析市场趋势…

Vue3快速上手(五)ref之对象类型的响应式数据

一、ref之对象类型的响应式数据 1.1 基本语法 import { ref } from vuelet x ref(初始值)console.log(xxx --> , x.value);x为一个RefImpl对象,该对象的value属性为实际值,在script里需要操作x.value来改变数据的值,在页面里则可以直接…

计算机网络——09Web-and-HTTP

Web and HTTP 一些术语 Web页:由一些对象组成对象可以是HTML文件、JPEG图像,JAVA小程序,声音剪辑文件等Web页含有一个基本的HTML文件,该基本HTML文件又包含若干对象的引用(链接)通过URL对每个对象进行引用…

C语言每日一题(56)平衡二叉树

力扣网 110 平衡二叉树 题目描述 给定一个二叉树,判断它是否是高度平衡的二叉树。 本题中,一棵高度平衡二叉树定义为: 一个二叉树每个节点 的左右两个子树的高度差的绝对值不超过 1 。 示例 1: 输入:root [3,9,20,…

LabVIEW智能温度监控系统

LabVIEW智能环境监测系统 介绍了一个基于LabVIEW的智能环境监测系统的开发过程。该系统在实时监测和分析环境参数,如温度、湿度、气体浓度等,以提供精确的数据支持,确保环境安全与健康。通过高效的数据处理和友好的用户界面,系统…

单链表基础知识点

单链表的读取 对于单链表实现获取第i个元素的数据的操作 GetElem&#xff0c;在算法上&#xff0c;相对要麻烦一些。 获得链表第i个数据的算法思路: 声明一个结点p指向链表第一个结点&#xff0c;初始化j从1开始;当j<i时&#xff0c;就遍历链表&#xff0c;让p的指针向后移…

算法沉淀——分治算法(leetcode真题剖析)

算法沉淀——分治算法 快排思想01.颜色分类02.排序数组03.数组中的第K个最大元素04.库存管理 III 归并思想01.排序数组02.交易逆序对的总数03.计算右侧小于当前元素的个数04.翻转对 分治算法是一种解决问题的算法范式&#xff0c;其核心思想是将一个大问题分解成若干个小问题&a…

计算机二级C语言的注意事项及相应真题-4-程序修改

目录&#xff1a; 31.逐个比较p、q所指两个字符串对应位置中的字符&#xff0c;把ASCII值大或相等的字符依次存放到c所指数组中&#xff0c;形成一个新的字符串32.求矩阵&#xff08;二维数组)a[N][N]中每行的最小值&#xff0c;结果存放到数组b中33.将一个十进制整数转换成r(二…

力扣刷题54-螺旋矩阵

给你一个 m 行 n 列的矩阵 matrix &#xff0c;请按照 顺时针螺旋顺序 &#xff0c;返回矩阵中的所有元素。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,2,3],[4,5,6],[7,8,9]] 输出&#xff1a;[1,2,3,6,9,8,7,4,5]示例 2&#xff1a; 输入&#xff1a;matrix [[1,2,3,…

我让ChatGPT帮我钓妹子,它一口气撩了5000人

来自俄罗斯的一名AI开发者、社交平台TenChat的产品经理 AleksandrZhadan于1月30日在推特上发布了自己的婚讯&#xff0c;他将要与自己的女友Karina Imranovna在今年的8月结婚。令人震惊的是Aleksandr Zhadan介绍的认识女友的窍门-ChatGPT 帮他找到了另一半&#xff0c;并且通过…

springsecurity6使用

spring security 中的类 &#xff1a; AuthenticationManager : 实现类&#xff1a;ProviderManager 管理很多的 provider &#xff0c;&#xff0c;&#xff0c; 经常使用的&#xff0c;DaoAuthenticationProvider , 这个要设置一个 UserDetailService , 查找数据库&#xff…

【ES6】Promise

Promise 回调地狱 const fs require(fs);fs.readFile(./a.txt, utf-8, (err, data) > {if(err) throw err;console.log(data);fs.readFile(./b.txt, utf-8, (err, data) > {if(err) throw err;console.log(data);fs.readFile(./c.txt, utf-8, (err, data) > {if(er…

斯巴鲁Subaru EDI需求分析

斯巴鲁Subaru是日本运输集团斯巴鲁公司&#xff08;前身为富士重工&#xff09;的汽车制造部门&#xff0c;以性能而闻名&#xff0c;曾赢得 3 次世界拉力锦标赛和 10 次澳大利亚拉力锦标赛。 斯巴鲁Subaru EDI 需求分析 企业与斯巴鲁Subaru建立EDI连接&#xff0c;首先需要确…

MQTT的学习与应用

文章目录 一、什么是MQTT二、MQTT协议特点三、MQTT应用领域四、安装Mosquitto五、如何学习 MQTT 一、什么是MQTT MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;是一种轻量级的消息传输协议&#xff0c;设计用于在低带宽、不稳定的网络环境中进行高效的通信…

ITK 图像分割(一):阈值ThresholdImageFilter

效果&#xff1a; 1、itkThresholdImageFilter 该类的主要功能是通过设置低阈值、高阈值或介于高低阈值之间&#xff0c;则将图像值输出为用户指定的值。 如果图像值低于、高于或介于设置的阈值之间&#xff0c;该类就将图像值设置为用户指定的“外部”值&#xff08;默认情况…

基于JAVA的新能源电池回收系统 开源项目

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户档案模块2.2 电池品类模块2.3 回收机构模块2.4 电池订单模块2.5 客服咨询模块 三、系统设计3.1 用例设计3.2 业务流程设计3.3 E-R 图设计 四、系统展示五、核心代码5.1 增改电池类型5.2 查询电池品类5.3 查询电池回…

联合体与枚举

联合体与枚举 联合体枚举问题 联合体 联合体也是由一个或多个成员构成的数据类型,它最大的特点是只为最大的一个成员开辟空间,其他成员共用这个空间,这个东西也叫共用体!!! union Un {char c;int i; };int main() {union Un un { 0 };un.c 0x01;//先为最大的成员开辟空间un.…

STM32单片机的基本原理与应用(七)

超声波测距实验 基本原理 超声波测距实验是STM32单片机通过控制HC-SR04超声波模块&#xff0c;使其发送超声波&#xff0c;遇到物体反射回超声波来实现距离测量&#xff0c;其原理就是在发射超声波到接收超声波会有一段时间&#xff0c;而超声波在空气中传播的速度为声速&…