简洁易懂的Yolov7本地训练自己的数据并onnx推理

news2024/11/24 0:54:21

YOLOV7

官方代码Yolov7

测试官方案例

1、下载下来先按照github教程下载yolov7.pt权重
2、pycharm(或其他)打开detect文件,修改权重路径和推理图片的路径,分别是

parser.add_argument('--weights', nargs='+', type=str, default='yolov7.pt', help='model.pt path(s)')
parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam

运行结果保存在runs/detect/exp文件夹下,运行结果如图:
在这里插入图片描述

创建自己的数据集

感受了yolov7之后,就是训练和运行自己的数据了。

划分数据集

其实很简单,只需要创建为如下txt文件即可!
此文件需要train.txt、val.txt和test.txt,每个文件都是数据绝对路径

注:没有根据官方那样搞文件夹的形式,有点麻烦,所以自己写了
在这里插入图片描述
以精子数据集为例,下载链接

划分代码如下:

import os
import shutil
import xml.etree.ElementTree as ET
classes = ["S", "Impurity"]

def convert(size, box):
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


def convert_annotation(in_file, out_file):
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
             float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


filepath = r'./data/Sperm/Sperm_ori/Frames from original videos/'
trainval_percent = 0.9  # 训练验证集占整个数据集的比重(划分训练集和测试验证集)
train_percent = 0.9  # 训练集占整个训练验证集的比重(划分训练集和验证集)
total_sample = os.listdir(filepath)
num = len(total_sample)
list = range(num)
tv = int(num * trainval_percent) + 1
tr = int(tv * train_percent)
tt = int(num * (1 - trainval_percent))
trainval = list[:tv]
train = list[:tr]
valid = list[tr:tv]
test = list[tv:]


print("train and val size", len(trainval))
print("train size", len(train))
print("val size", len(valid))
print("test size", len(test))

saveBasePath = r'./data/Sperm/main_txt/'  # 生成的txt文件的保存路径
ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')

xml_path = r'./data/Sperm/Sperm_ori/Subset-A/'
xmltxtBasePath = r'./data/Sperm/main_txt/label/'  # 生成的txt文件的保存路径
def mk(p_path):
    if not os.path.exists(p_path):
        os.makedirs(p_path)
train_xmltxt_path = xmltxtBasePath + 'train/'
valid_xmltxt_path = xmltxtBasePath + 'valid/'
test_xmltxt_path = xmltxtBasePath + 'test/'
all_xmltxt_path = xmltxtBasePath + 'all/'
mk(train_xmltxt_path)
mk(valid_xmltxt_path)
mk(test_xmltxt_path)
mk(all_xmltxt_path)

for i, one in enumerate(total_sample):
    one_samples = os.listdir(filepath + one)
    for name in one_samples:
        # 图像路径
        new_name = filepath + one + '/' + name
        basename = name.split('.')[0] + '.xml'
        in_file = open(xml_path + basename)

        if i in trainval:
            ftrainval.write(new_name + '\n')
            if i in train:
                ftrain.write(new_name + '\n')
                out_file = open(train_xmltxt_path + name.split('.')[0] + '.txt', 'w', encoding='utf-8')
                convert_annotation(in_file, out_file)
            else:
                fval.write(new_name + '\n')
                out_file = open(valid_xmltxt_path + name.split('.')[0] + '.txt', 'w', encoding='utf-8')
                convert_annotation(in_file, out_file)
        else:
            ftest.write(new_name + '\n')
            out_file = open(test_xmltxt_path + name.split('.')[0] + '.txt', 'w', encoding='utf-8')
            convert_annotation(in_file, out_file)
        in_file = open(xml_path + basename)
        out_file = open(all_xmltxt_path + name.split('.')[0] + '.txt', 'w', encoding='utf-8')
        convert_annotation(in_file, out_file)



ftrainval.close()
ftrain.close()
fval.close()
ftest.close()


生成标签

根据每一个图像的名称命名标签txt文件,可以都放在一起
在这里插入图片描述
每个图象包含的目标标签存储如下:
一行一个目标,五列,第1列为分类编号,第2-3列为目标中心xy,第4-5列为目标的wh,注意以上均为归一化结果
在这里插入图片描述

注意,一般都是用labelme获取的xml标注文件,需要转化为上述txt文件,转化代码集成在划分代码里,可根据需求拆分。

训练

修改配置和代码

由于自己修改了数据集文件格式,所以
dataset.py文件中img2label_paths更换如下:

def img2label_paths_my(p, img_paths):
    # Define label paths as a function of image paths
    ll = []
    for x in img_paths:
        basename = os.path.basename(x).replace('.png', '.txt')
        pp = os.path.dirname(p) + os.sep + 'label' + os.sep + 'all'  + os.sep + basename
        ll.append(pp)
        # print(pp)
    return ll

train.py修改如下:路径根据实际

parser.add_argument('--weights', type=str, default='weights/yolov7.pt', help='initial weights path')
parser.add_argument('--cfg', type=str, default='./cfg/training/yolov7.yaml', help='model.yaml path')
parser.add_argument('--data', type=str, default='data/Sperm/yaml/sperm.yaml', help='data.yaml path')
parser.add_argument('--hyp', type=str, default='data/hyp.scratch.p5.yaml', help='hyperparameters path')
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=32, help='total batch size for all GPUs')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')

之后便可以训练了!

导出onnx文件

核对export.py参数:

parser.add_argument('--weights', type=str, default='weights/best.pt', help='weights path') #pt路径
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
parser.add_argument('--dynamic-batch', action='store_true', help='dynamic batch onnx for tensorrt and onnx-runtime')
parser.add_argument('--grid', action='store_true', default=True, help='export Detect() layer grid') #包含检测
parser.add_argument('--end2end', action='store_true', default=True, help='export end2end onnx')#
parser.add_argument('--max-wh', type=int, default=640,
                    help='None for tensorrt nms, int value for onnx-runtime nms')
parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images')
parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--simplify', action='store_true', default=True, help='simplify onnx model') #
parser.add_argument('--include-nms', action='store_true', help='export end2end onnx')
parser.add_argument('--fp16', action='store_true', help='CoreML FP16 half-precision export')
parser.add_argument('--int8', action='store_true', help='CoreML INT8 quantization')

onnx文件推理

此部分集成了pt和onnx的训练方式
不仅包含自己数据集推理,也包含了coco的测试

import argparse
import time
from pathlib import Path
import cv2
import torch
import numpy as np
from numpy import random
from models.experimental import attempt_load
from utils.datasets import  LoadImages
from utils.general import check_img_size,  non_max_suppression,  \
    scale_coords, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, time_synchronized

def detect():
    source, weights, imgsz = opt.source, opt.weights, opt.img_size
    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    save_dir.mkdir(parents=True, exist_ok=True)  # make dir

    # Initialize
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    onnx_flag = 0
    if weights.split('.')[-1] == 'onnx':
        import onnxruntime
        ort_session = onnxruntime.InferenceSession(weights)
        def to_numpy(tensor):
            return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
        onnx_flag = 1
        stride = 32
    else:
        model = attempt_load(weights, map_location=device)  # load FP32 model
        stride = int(model.stride.max())  # model stride

        if half:
            model.half()  # to FP16
        # Set Dataloader
    vid_path, vid_writer = None, None

    imgsz = check_img_size(imgsz[0], s=stride)  # check img_size
    if onnx_flag:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=False, onnx_flag=True)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride)

    # Get names and colors
    if onnx_flag:
        names = ["S", "Impurity"]
        colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
    else:
        names = model.module.names if hasattr(model, 'module') else model.names
        colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
    if 'yolov7' in weights:
        names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
                 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
                 'frisbee',
                 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
                 'surfboard',
                 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
                 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
                 'cell phone',
                 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                 'teddy bear',
                 'hair drier', 'toothbrush']
        # colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)}
        colors = [[random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)]

    t0 = time.time()
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        # print(img.shape)

        # Inference
        t1 = time_synchronized()
        if onnx_flag:
            ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
            ort_outs = ort_session.run(None, ort_inputs)
            pred = ort_outs[0]
            pred = torch.from_numpy(pred)
            pred = [pred[:,[1,2,3,4,6,5]]]

        else:
            pred = model(img, augment=opt.augment)[0]
            # Apply NMS
            pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t3 = time_synchronized()


        # Process detections
        for i, det in enumerate(pred):  # detections per image
            p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
                for *xyxy, conf, cls in reversed(det):
                    # Add bbox to image
                    label = f'{names[int(cls)]} {conf:.2f}'
                    plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=1)


            # Print time (inference + NMS)
            print(f'{s}Done. ({(1E3 * (t3 - t1)):.1f}ms) Inference')

            # Save results (image with detections)
            if dataset.mode == 'image':
                cv2.imwrite(save_path, im0)
                print(f" The image with the result is saved in: {save_path}")
            else:  # 'video' or 'stream'
                if vid_path != save_path:  # new video
                    vid_path = save_path
                    if isinstance(vid_writer, cv2.VideoWriter):
                        vid_writer.release()  # release previous video writer
                    if vid_cap:  # video
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    else:  # stream
                        fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path += '.mp4'
                    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                vid_writer.write(im0)

    print(f'Done. ({time.time() - t0:.3f}s)')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # parser.add_argument('--weights', nargs='+', type=str, default='weights/yolov7.pt', help='model.pt path(s)')
    # parser.add_argument('--weights', nargs='+', type=str, default='weights/yolov7.onnx', help='model.pt path(s)')
    parser.add_argument('--weights', nargs='+', type=str, default='weights/best.onnx', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='inference/images/sperm', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--img-size', type=int, default=[640, 640], help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.35, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.25, help='IOU threshold for NMS')


    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    opt = parser.parse_args()
    print(opt)
    with torch.no_grad():
        detect()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1320167.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[LLM]nanoGPT---训练一个写唐诗的GPT

karpathy/nanoGPT: The simplest, fastest repository for training/finetuning medium-sized GPTs. (github.com) 原有模型使用的莎士比亚的戏剧数据集, 如果需要一个写唐诗机器人,需要使用唐诗的文本数据, 一个不错的唐诗,宋词数据的下载…

东南亚Shopee:东南亚领先的电商平台

Shopee是东南亚地区最著名的电商平台之一,成立于2015年。作为新加坡互联网公司Sea Group(原名Garena)的一部分,Shopee在东南亚市场拥有广泛的业务覆盖范围,包括新加坡、马来西亚、泰国、印度尼西亚、越南和菲律宾等国家…

源码编译 METIS 以及 GKlib 在Linux ubuntu上

1. GKlib 构建 $ git clone --recursive gitgithub.com:Kleenelan/GKlib.git $ cd GKlib/ $ make config ccgcc openmpset $ make $ make install源码构建了 GKlib 的 openmp 版本,以便充分使用多核的算力; make config ccgcc openmpset 的效果图&#…

QT-可拖拉绘图工具

QT-可拖拉绘图工具 一、演示效果二、关键程序三、下载链接 一、演示效果 二、关键程序 #include "diagramscene.h" #include "arrow.h"#include <QTextCursor> #include <QGraphicsSceneMouseEvent> #include <QDebug>QPen const Diagr…

Windows本地搭建开源企业管理套件Odoo并实现公网访问

文章目录 前言1. 下载安装Odoo&#xff1a;2. 实现公网访问Odoo本地系统&#xff1a;3. 固定域名访问Odoo本地系统 前言 Odoo是全球流行的开源企业管理套件&#xff0c;是一个一站式全功能ERP及电商平台。 开源性质&#xff1a;Odoo是一个开源的ERP软件&#xff0c;这意味着企…

当OneNote不同步时,你需要做些什么让其恢复在线

OneNote笔记本无法同步的原因有很多。由于OneNote使用OneDrive将笔记本存储在云中,因此可能会出现互联网连接问题,与多人联机处理笔记本时会出现延迟,以及从不同设备处理同一笔记本时会发生延迟。以下是OneNote不同步时的操作。 注意:本文中的说明适用于OneNote for Windo…

什么是工业互联网平台?

1.什么是工业互联网平台&#xff1f; 1.1 工业互联网平台的定义 工业互联网平台是一个连接设备与服务、数据与人的跨行业、跨领域的全新工业平台。工业互联网平台利用了互联网、物联网、大数据、AI等技术&#xff0c;集成各类工业设备&#xff0c;不断采集和分析数据&#xff…

百度地图添加坐标点,并返回坐标信息

1、创建地图容器 在mounted中初始化地图、鼠标绘制工具和添加鼠标监听事件 vue data中添加地图和绘制工具对象 2、添加初始化化地图方法 initMap(longitude, latitude) {let that thisthat.map new BMapGL.Map("container");// 创建地图实例if (longitude null ||…

功放诊断测试

1.切换trace显示时间模式&#xff0c;Toggle time mode 2.测seedkey 需要加载seednkey.dll 3.功能寻址和物理寻址切换

idea恢复默认出厂设置

idea恢复默认出厂设置 1、IDEA 2021 之后&#xff0c; 在顶部工具栏&#xff0c;选择 File | Manage IDE Settings | Restore Default Settings. 2、或者双击shift搜索Restore Default settings然后点击restore and restart

MySQL安装——备赛笔记——2024全国职业院校技能大赛“大数据应用开发”赛项——任务2:离线数据处理

MySQLhttps://www.mysql.com/ 将下发的ds_db01.sql数据库文件放置mysql中 12、编写Scala代码&#xff0c;使用Spark将MySQL的ds_db01库中表user_info的全量数据抽取到Hive的ods库中表user_info。字段名称、类型不变&#xff0c;同时添加静态分区&#xff0c;分区字段为etl_da…

ubuntu18.04 64 位安装笔记——备赛笔记——2024全国职业院校技能大赛“大数据应用开发”赛项——任务2:离线数据处理

进入VirtuakBox官网&#xff0c;网址链接&#xff1a;Oracle VM VirtualBoxhttps://www.virtualbox.org/ 网页连接&#xff1a;Ubuntu Virtual Machine Images for VirtualBox and VMwarehttps://www.osboxes.org/ubuntu/ 将下发的ds_db01.sql数据库文件放置mysql中 12、编写S…

代码随想录算法训练营Day5 | 242.有效的字母异位词、349.两个数组的交集、202.快乐数、1. 两数之和

LeetCode 242 有效的字母异位词 本题思路&#xff1a;我们只需要分别统计&#xff0c;字符串 s ,字符串 t 中每个字符的出现次数&#xff0c;分别用两个数组来存储&#xff0c;然后再循环遍历对比两个数组中相同位置出现的次数&#xff0c;如果有不同的则返回 false。 统计完之…

双非大数据

双非本秋招上岸总结 个人简介 学历&#xff1a;双非&#xff1b; 专业&#xff1a;软件工程&#xff1b; 求职岗位&#xff1a;大数据开发工程师&#xff1b; 状态&#xff1a;已上岸 翻车经历 学校以Java后端开发为主流&#xff0c;我从大二开始学习Java&#xff0c;直到大四…

HarmonyOS(十四)——状态管理之@State装饰器(组件内状态)

前言 在初识状态管理我们了解了状态管理的基本概念&#xff0c;以及管理组件拥有的状态有哪几种装饰器&#xff0c;今天我们就来认识一下第一种装饰器&#xff1a;State装饰器&#xff08;组件内状态&#xff09;。 概述 State装饰的变量&#xff0c;或称为状态变量&#xf…

解决 Hbuilder打包 Apk pad 无法横屏 以及 H5 直接打包 成Apk

解决 Hbuilder打包 Apk pad 无法横屏 前言云打包配置 前言 利用VUE 写了一套H5 想着 做一个APP壳 然后把 H5 直接嵌进去 客户要求 在pad 端 能够操作 然后页面风格 也需要pad 横屏展示 云打包 配置 下面是manifest.json 配置文件 {"platforms": ["iPad"…

Vue--第十天

终极实战----大事件项目 1.简介&#xff1a; 2.创建项目&#xff1a; 1.创建&#xff08;159-163&#xff09;&#xff1a; 还是对着视频操作吧 2.路由&#xff1a; 3.element Plus: 导入element Plus 后不需要再导入插件配置&#xff0c;就连组件导入也不用 4.pinia构建用…

Ubuntu Desktop 22.04 设置 ssh 超时时间

Ubuntu Desktop 22.04 使用 ssh 连接服务器时&#xff0c;发现一段时间不操作就会自动断开连接&#xff0c;解决方法如下&#xff1a; 打开 /etc/ssh/ssh_config 文件&#xff1a; sudo vim /etc/ssh/ssh_config在文件最后添加&#xff1a; # ssh 客户端会每隔 30 秒发送一个…

网络(九)三层路由、DHCP以及VRRP协议介绍

目录 一、三层路由 1. 定义 2. 交换原理 3. 操作演示 3.1 图示 3.2 LSW1新建vlan10、20、30&#xff0c;分别对应123接口均为access类型&#xff0c;接口4为trunkl类型&#xff0c;允许所有vlan通过 3.3 LSW2新建vlan10、20、30&#xff0c;配置接口1为trunk类型&…

ElasticSearch单机或集群未授权访问漏洞

漏洞处理方法&#xff1a; 1、可以使用系统防火墙 来做限制只允许ES集群和Server节点的IP来访问漏洞节点的9200端口&#xff0c;其他的全部拒绝。 2、在ES节点上设置用户密码 漏洞现象&#xff1a;直接访问9200端口不需要密码验证 修复过程 2.1 生成认证文件 必须要生成…