# Yolov5-detect.py代码简化(便于移植) — simplified version of YOLOv5's detect.py, stripped down for easy porting
import argparse
import os
import sys
import time
from pathlib import Path
import cv2
import numpy as np
import torch
# Touch the CUDA context early (presumably a workaround for CUDA initialisation
# errors on some setups — TODO confirm why it was added). Guarded because an
# unconditional call raises on machines without a usable CUDA device, which made
# this module impossible to import there.
if torch.cuda.is_available():
    torch.cuda.current_device()
import torch.backends.cudnn as cudnn
# Directory that holds this script; it is put on sys.path so the
# project-local `models` / `utils` packages imported below can be resolved.
FILE = Path(__file__).resolve()
ROOT = FILE.parent
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
# From here on ROOT is expressed relative to the current working directory.
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
from models.experimental import attempt_load
from utils.datasets import LoadImages, LoadStreams
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \
increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \
strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors, plot_one_box
from utils.torch_utils import load_classifier, select_device, time_sync
from utils.augmentations import letterbox
@torch.no_grad()
def run(weights='./yolov5s.pt', source='./data/images/image2.jpg', device='cuda:0', imgsz=640,
        conf_thres=0.6, iou_thres=0.45, max_det=1000, line_thickness=3,
        hide_labels=False, hide_conf=False, half=False, view_img=True):
    """Run YOLOv5 detection on one image, draw the boxes and count detections per class.

    Simplified, script-style inference. Every setting that used to be hard-coded is
    now a keyword argument whose default equals the old constant, so ``run()`` with
    no arguments behaves as before.

    Args:
        weights: checkpoint path passed to ``attempt_load``.
        source: path of the image to run on (read with ``cv2.imread``).
        device: device string passed to ``select_device`` (e.g. ``'cuda:0'`` or ``'cpu'``).
        imgsz: requested inference size; adjusted by ``check_img_size`` to suit the
            model stride.
        conf_thres: confidence threshold for ``non_max_suppression``.
        iou_thres: IoU threshold for ``non_max_suppression``.
        max_det: maximum number of detections kept.
        line_thickness: box line thickness passed to ``plot_one_box``.
        hide_labels: draw boxes without any text label.
        hide_conf: label boxes with the class name only, without the confidence.
        half: run the model in FP16 (forced off on CPU).
        view_img: show the annotated image in an OpenCV window and block until a
            key is pressed.

    Returns:
        dict mapping every class name to the number of detections of that class
        (the same dict that is printed).

    Raises:
        FileNotFoundError: if ``source`` cannot be read as an image.
    """
    device = select_device(device)
    half &= device.type != 'cpu'  # half precision is disabled on CPU
    model = attempt_load(weights, map_location=device)
    # Use the model's real maximum stride instead of assuming 32 (larger-stride
    # models, e.g. the YOLOv5 P6 variants, would otherwise be letterboxed wrongly),
    # and make the inference size compatible with it.
    stride = int(model.stride.max())
    imgsz = check_img_size(imgsz, s=stride)
    names = model.module.names if hasattr(model, 'module') else model.names
    if half:
        model.half()

    img = cv2.imread(source)
    if img is None:  # cv2.imread returns None instead of raising on a bad path
        raise FileNotFoundError(f'Image not found: {source}')
    im0 = img.copy()  # untouched original; boxes are drawn onto this copy

    # Letterbox resize, HWC -> CHW with channel order reversed (BGR -> RGB),
    # then a normalised float tensor with a leading batch dimension.
    img = letterbox(img, new_shape=(imgsz, imgsz), stride=stride)[0]
    img = np.ascontiguousarray(img.transpose((2, 0, 1))[::-1])
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()
    img = img / 255.0
    if img.ndim == 3:
        img = img[None]

    # `names` is a list in the YOLOv5 version this script targets; newer releases
    # use an {index: name} dict, so accept both. names[c] works for either.
    class_names = names.values() if isinstance(names, dict) else names
    statistic_dic = {name: 0 for name in class_names}

    pred = model(img, augment=False, visualize=False)[0]
    pred = non_max_suppression(pred, conf_thres=conf_thres, iou_thres=iou_thres,
                               classes=None, max_det=max_det)

    for det in pred:  # one entry per image in the batch (a single image here)
        if not len(det):
            continue
        # Map boxes from the letterboxed inference size back to the original image.
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)
            statistic_dic[names[c]] += 1
            label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
            plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=line_thickness)

    print(statistic_dic)
    if view_img:
        cv2.imshow("img", im0)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    return statistic_dic
if __name__ == "__main__":
    # Script entry point: run detection with the built-in default settings.
    run()