模型实战一之YOLOv7实例分割、模型训练自己数据集

news2024/11/23 20:59:12

模型实战一之YOLOv7实例分割、模型训练自己数据集

1.环境准备

  • 下载yolov7实例分割模型:
git clone https://github.com/WongKinYiu/yolov7.git -b mask yolov7-mask

cd yolov7-mask
  • 安装环境
#查看已安装环境
conda info --envs
#查看安装了哪些包
conda list

#创建环境 
conda create -n yolov7 python=3.8
#激活
conda activate yolov7

# 安装 torch 1.8.2+cu11.1
pip install torch==1.8.2 torchvision==0.9.2 torchaudio===0.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111

#其他版本:torch+cuda10.2
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html


# 修改requirements.txt,将其中的torch和torchvision注释掉
pip install -r requirements.txt
  • 安装detectron2
    detectron是facebook发布的开源机器视觉库,安装教程参考:https://blog.csdn.net/qq_45770232/article/details/126471738
# 安装detectron2
#先安装ninja
pip install ninja

git clone https://github.com/facebookresearch/detectron2
cd detectron2
python setup.py install
cd ..

2.测试实例分割demo

  • 测试:
下载权重放在detect.py路径下:yolov7.pt ... yolov7-mask.pt

测试yolov7目标检测:

 python detect.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inference/images/horses.jpg 

在这里插入图片描述
在这里插入图片描述

  • 测试实例分割 - python
import matplotlib.pyplot as plt
import torch
import cv2
import yaml
from torchvision import transforms
import numpy as np

from utils.datasets import letterbox
from utils.general import non_max_suppression_mask_conf

from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.layers import paste_masks_in_image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
with open('data/hyp.scratch.mask.yaml') as f:
    hyp = yaml.load(f, Loader=yaml.FullLoader)
weigths = torch.load('./weights/yolov7-mask.pt')
model = weigths['model']
model = model.half().to(device)
_ = model.eval()

image = cv2.imread('inference/images/horses.jpg')  # 504x378 image
image = letterbox(image, 640, stride=64, auto=True)[0]
image_ = image.copy()
image = transforms.ToTensor()(image)
image = torch.tensor(np.array([image.numpy()]))
image = image.to(device)
image = image.half()

output = model(image)

inf_out, train_out, attn, mask_iou, bases, sem_output = output['test'], output['bbox_and_cls'], output['attn'], output['mask_iou'], output['bases'], output['sem']

bases = torch.cat([bases, sem_output], dim=1)
nb, _, height, width = image.shape
names = model.names
pooler_scale = model.pooler_scale
pooler = ROIPooler(output_size=hyp['mask_resolution'], scales=(pooler_scale,), sampling_ratio=1, pooler_type='ROIAlignV2', canonical_level=2)

output, output_mask, output_mask_score, output_ac, output_ab = non_max_suppression_mask_conf(inf_out, attn, bases, pooler, hyp, conf_thres=0.25, iou_thres=0.65, merge=False, mask_iou=None)

pred, pred_masks = output[0], output_mask[0]
base = bases[0]
bboxes = Boxes(pred[:, :4])
original_pred_masks = pred_masks.view(-1, hyp['mask_resolution'], hyp['mask_resolution'])
pred_masks = retry_if_cuda_oom(paste_masks_in_image)( original_pred_masks, bboxes, (height, width), threshold=0.5)
pred_masks_np = pred_masks.detach().cpu().numpy()
pred_cls = pred[:, 5].detach().cpu().numpy()
pred_conf = pred[:, 4].detach().cpu().numpy()
nimg = image[0].permute(1, 2, 0) * 255
nimg = nimg.cpu().numpy().astype(np.uint8)
nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)
nbboxes = bboxes.tensor.detach().cpu().numpy().astype(np.int32)
pnimg = nimg.copy()

for one_mask, bbox, cls, conf in zip(pred_masks_np, nbboxes, pred_cls, pred_conf):
    if conf < 0.25:
        continue
    color = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]

    pnimg[one_mask] = pnimg[one_mask] * 0.5 + np.array(color, dtype=np.uint8) * 0.5
    pnimg = cv2.rectangle(pnimg, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
    #label = '%s %.3f' % (names[int(cls)], conf)
    #t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0]
    #c2 = bbox[0] + t_size[0], bbox[1] - t_size[1] - 3
    #pnimg = cv2.rectangle(pnimg, (bbox[0], bbox[1]), c2, color, -1, cv2.LINE_AA)  # filled
    #pnimg = cv2.putText(pnimg, label, (bbox[0], bbox[1] - 2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA)


# coco example
# %matplotlib inline
cv2.imshow("instance_result.jpg", pnimg)
cv2.waitKey(0)
# cv2.imwrite("instance_result.jpg", pnimg)

在这里插入图片描述

3.训练自己的数据集

  • 实例分割时目标检测语义分割的结合,所以其标注文件初始为通过labelme标注的json格式,要用yolo模型进行训练,需要将其转换为yolo所需要的txt格式:
    在这里插入图片描述
  • 转换demo如下:
    参考:https://blog.csdn.net/qq_57329395/article/details/128079776
# 处理labelme多边形矩阵的标注  json转化txt
import json
import os

name2id = {'peanuthull': 0, 'kernel': 1}


def convert(img_size, box):
    dw = 1. / (img_size[0])
    dh = 1. / (img_size[1])
    x = (box[0] + box[2]) / 2.0
    y = (box[1] + box[3]) / 2.0
    w = abs(box[2] - box[0])
    h = abs(box[3] - box[1])
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


def decode_json(json_floder_path, txt_outer_path, json_name):
    #  json_floder_path='E:\\Python_package\\itesjson\\'
    # json_name='V1125.json'
    txt_name = txt_outer_path + json_name[:-5] + '.txt'
    with open(txt_name, 'w') as f:
        json_path = os.path.join(json_floder_path, json_name)  # os路径融合
        data = json.load(open(json_path, 'r', encoding='gb2312', errors='ignore'))
        img_w = data['imageWidth']  # 图片的高
        img_h = data['imageHeight']  # 图片的宽
        isshape_type = data['shapes'][0]['shape_type']
        print(isshape_type)
        # print(isshape_type)
        # print('下方判断根据这里的值可以设置为你自己的类型,我这里是polygon'多边形)
        # len(data['shapes'])
        for i in data['shapes']:
            label_name = i['label']  # 得到json中你标记的类名
            if (i['shape_type'] == 'polygon'):  # 数据类型为多边形 需要转化为矩形
                x_max = 0
                y_max = 0
                x_min = 100000
                y_min = 100000
                for lk in range(len(i['points'])):
                    x1 = float(i['points'][lk][0])
                    y1 = float(i['points'][lk][1])
                    # print(x1)
                    if x_max < x1:
                        x_max = x1
                    if y_max < y1:
                        y_max = y1
                    if y_min > y1:
                        y_min = y1
                    if x_min > x1:
                        x_min = x1
                bb = (x_min, y_max, x_max, y_min)
            if (i['shape_type'] == 'rectangle'):  # 为矩形不需要转换
                x1 = float(i['points'][0][0])
                y1 = float(i['points'][0][1])
                x2 = float(i['points'][1][0])
                y2 = float(i['points'][1][1])
                bb = (x1, y1, x2, y2)
            bbox = convert((img_w, img_h), bb)
            try:
                f.write(str(name2id[label_name]) + " " + " ".join([str(a) for a in bbox]) + '\n')
            except:
                pass


if __name__ == "__main__":
    json_floder_path = 'data_\\jsons\\'  # 存放json的文件夹的绝对路径
    txt_outer_path = 'data_\\txts\\'  # 存放txt的文件夹绝对路径
    json_names = os.listdir(json_floder_path)
    print("共有:{}个文件待转化".format(len(json_names)))
    flagcount = 0
    for json_name in json_names:
        decode_json(json_floder_path, txt_outer_path, json_name)
        flagcount += 1
        print("还剩下{}个文件未转化".format(len(json_names) - flagcount))

    # break
    print('转化全部完毕')
  • 数据集存放格式:

  • datasets:

    • images:

      • train: .jpg
      • val: .jpg
    • labels:

      • train: .txt
      • val: .txt
    • train_list.txt

    • val_list.txt

  • train_listval_list存放绝对路径,如下:
    在这里插入图片描述
    在这里插入图片描述

参考:https://blog.csdn.net/matt45m/article/details/127416919?spm=1001.2014.3001.5502

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/115097.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IMX6ULL学习笔记(15)——GPIO输出接口使用【官方SDK方式】

一、GPIO简介 i.MX6ULL 芯片的 GPIO 被分成 5 组,并且每组 GPIO 的数量不尽相同&#xff0c;例如 GPIO1 拥有 32 个引脚&#xff0c; GPIO2 拥有 22 个引脚&#xff0c; 其他 GPIO 分组的数量以及每个 GPIO 的功能请参考 《i.MX 6UltraLite Applications Processor Reference M…

市级数字政府电子政务大数据中心项目建设和运营方案

【版权声明】本资料来源网络&#xff0c;仅用于行业知识分享&#xff0c;供个人学习参考&#xff0c;不得作商业用途。【侵删致歉】如有侵权请联系小编&#xff0c;将在收到信息后第一时间进行删除&#xff01; 完整资料领取见文末&#xff0c;部分资料内容&#xff1a; 1.1 大…

【QTimeEdit | QDateEdit | QDateTimeEdit | QCalendarWidget | QLCDNumber】

【QTimeEdit | QDateEdit | QDateTimeEdit | QCalendarWidget | QLCDNumber】【1】UI界面设计【2】相关头文件【3】构造函数初始化【4】setDate | setTime | setDateTime | currentDate | currentTime | currentDateTime【5】maximumDate | maximumTime | minimumDate | minimu…

基于Java+SpringBoot+vue等疫情期间网课管理系统详细设计和实现

博主介绍&#xff1a;✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取联系&#x1f345;精彩专栏推荐订阅收藏&#x1f447;&…

使用Prometheus和Grafana监控Apache Doris

目录1. 介绍2. Prometheus和Grafana的简单使用3. 配置Prometheus的prometheus.yml4. 下载Doris Dashboard模板1. 介绍 Doris的FE和BE通过http端口metrics路径&#xff0c;将监控数据暴露出来&#xff0c;以key-value的文本形式对外展现&#xff0c;每个key还可能有不同的Label…

Nacos注册中心

【Spring Cloud Alibaba】 1. Spring Cloud Alibaba Spring Cloud Alibaba 致力于提供微服务开发的一站式解决方案。此项目包含开发分布式应用微服务的必需组件&#xff0c;方便开发者通过 Spring Cloud 编程模型轻松使用这些组件来开发分布式应用服务。 依托 Spring Cloud …

微信小程序开发—入门到跑路(五)

文章目录1. 今日目标2. 使用 npm2.1 小程序对 npm 的支持和限制问题2.2 了解什么是 vant Weapp2.3 安装 Vant 组件库问题2.4 使用 Vant 组件问题2.5 定义和使用 CSS 变量问题2.6 使用 CSS 变量定制 Vant 的主题样式问题2.7 什么是小程序 API 的 Promise 化2.8 安装并构建 minip…

运输层协议概述(计算机网络-运输层)

目录 运输层协议的位置 运输层为相互通信的应用进程提供了逻辑通信 应用进程之间的通信 客户-服务器通信模式 互联网的运输层协议 UDP 与 TCP 运输层的复用与分用 运输层端口的概念 端口在进程之间的通信中所起的作用 端口号 运输层协议的位置 从通信和信息处理的角度…

2022年终总结(脚踏实地,仰望星空)

2022年终总结 回忆录 2022年焦虑和快乐是这一年中最大的两种情绪了。焦虑主要是因为心里的三块石头&#xff0c;从年初就开始悬着。第一块石头&#xff0c;科研论文录用&#xff0c;第二个石头&#xff0c;拿到国奖&#xff0c;第三个石头是拿到满意的offer。目前只剩下最后一…

网络实验之EtherChannel技术实践

一、EtherChannel简介 EtherChannel简单来说就是将多个物理端口绑定为一个逻辑端口&#xff0c;通过多个端口绑定&#xff0c;能充分利用现有端口来增加带宽。构成etherchannel的端口必须配置成相同的特性&#xff0c;如双工模式、速度、同为FE或GE端口、native VLAN,、VLAN ra…

C++11标准模板(STL)- 算法(std::inner_product)

定义于头文件 <algorithm> 算法库提供大量用途的函数&#xff08;例如查找、排序、计数、操作&#xff09;&#xff0c;它们在元素范围上操作。注意范围定义为 [first, last) &#xff0c;其中 last 指代要查询或修改的最后元素的后一个元素。 计算两个范围的元素的内积…

十七、Docker Compose容器编排第二篇

在上一篇中我们讲解了容器编排是什么、能干什么、怎么安装、使用步骤&#xff0c;如果没有看的大家可以先看下&#xff1a;https://blog.csdn.net/u011837804/article/details/128335166&#xff0c;然后继续看这一篇&#xff0c;好了&#xff0c;我们继续。 1、Docker Compons…

gl-Camera

我的服务原文访问&#xff1a;Camera 1.创建摄像机的坐标系&#xff0c;&#xff08;创建原理&#xff0c;两条直线求其法向量&#xff09; Z轴:在世界坐标中指向摄像机的向量&#xff08;D&#xff09; X轴&#xff1a;随便找一个向上量和Z向量求出的法向量就是X轴&#xf…

PostgreSQL数据库TableAM——Table scan callbacks

TableAM Table scan TableAM提供了如下4个接口用于实现表数据的扫描功能。scan_begin函数的形参nkeys不为零&#xff0c;则扫描结果需要根据scan keys先进行过滤&#xff1b;pscan如果不为null&#xff0c;说明该结构体已经由parallelscan_initialize初始化过了(仅仅在table_b…

初识Docker:(5)Docker自定义镜像

初识Docker&#xff1a;&#xff08;5&#xff09;Docker自定义镜像镜像结构Dockerfile语法什么是Dockerfile构建Java项目案例1&#xff1a;基于ubuntu镜像构建一个新镜像&#xff0c;运行一个java项目案例2&#xff1a;基于java:8-alpine镜像&#xff0c;将一个java项目构建为…

Java+JSP机房课表管理系统(含源码+论文+答辩PPT等)

项目功能简介: 该项目采用技术CSSJavaScriptMySQLServlet、MySQL数据库、项目含有源码、配套开发软件、软件安装教程、项目发布教程等 项目功能介绍&#xff1a; 系统管理&#xff1a;包含用户的注册&#xff0c;管理&#xff0c;信息修改 课程管理&#xff1a;包含课程录入、维…

IT大侦“碳”:VxRail的可持续法宝

环境Environmental      社会责任Social Responsibility      企业治理Corporate Governance      随着碳达峰、碳中和的逐步推进,越来越多的“大厂”或各行业的明星企业都开始重视自己的ESG报告,已然成为了商界新风尚。      可持续发展战略也与前沿技术密切相…

matlab神经网络求解最优化,matlab神经网络训练数据

1、神经网络的准确率是怎么计算的&#xff1f; 其实神经网络的准确率的标准是自己定义的。 我把你的例子赋予某种意义讲解&#xff1a; 1&#xff0c;期望输出[1 0 0 1]&#xff0c;每个元素代表一个属性是否存在。像着4个元素分别表示&#xff1a;是否肺炎&#xff0c;是否肝…

哈希知识点

目录对比map/set1. unordered系列关联式容器1.1 unordered_map2. 底层结构2.1 哈希概念2.2 哈希冲突2.3 哈希函数2.4 哈希冲突解决2.4.1 闭散列线性探测和二次探测扩容&#xff08;负载因子&#xff09;闭散列实现的hash2.4.2 开散列概念开散列思考实现模拟实现模板参数列表的改…

Java项目:springboot农业物资管理系统

作者主页&#xff1a;源码空间站2022 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 农业物资管理系统&#xff0c;管理员可以对角色进行配置&#xff0c;分配用户角色&#xff1b; 主要功能包含&#xff1a;登录、注册、修改密码…