【Keras+计算机视觉+Tensorflow】实现基于YOLO和Deep Sort的目标检测与跟踪实战(附源码和数据集)

news2024/12/23 12:17:38

需要源码和数据集请点赞关注收藏后评论区留言私信~~~

一、YOLO目标检测算法

        YOLO是端到端的物体检测深度卷积神经网络,YOLO可以一次性预测多个候选框,并直接在输出层回归物体位置区域和区域内物体所属类别,而Faster R-CNN仍然是采用R-CNN那种将物体位置区域框与物体分开训练的思想,只是利用RPN网络,将提取候选框的步骤放在深度卷积神经网络内部实现,YOLO最大的优势就是速度快,可满足端到端训练和实时检测要求

二、Deep Sort多目标跟踪算法 

算法原理如下图所示,在目标检测算法得到检测结果后,利用目标框来初始化卡尔曼滤波器,使用一个八维空间去刻画轨迹在某时刻的状态分别表示目标框的中心位置,纵横比,高度以及在图像坐标中对应的速度信息,计算卡尔曼滤波器提供的预测框与目标检测框之间的位置关系和外观特征关系,利用两个信息综合判断目标检测与跟踪框之间的关联程度,完成多目标的跟踪匹配

三、实战项目算法流程 

实现流程为:首先从视频中分解出图像帧,将图像输入目标检测模块,将检测到的动态目标,输入到目标跟踪模块,而将检测到的静态目标直接输出检测结果,目标跟踪模块为同一动态目标编上同样的编号并显示在目标框的左上角,连接多帧中出现的相同的动态目标,从而画出该动态目标的运动轨迹

 效果展示

目标检测与跟踪的结果如下图

 三、代码

项目结构如下 

代码中主要的模块及步骤如下

1:导入第三方库

2:主函数

3:目标检测部分YOLO

4:目标跟踪部分Deep Sort 

部分代码如下 需要全部代码请点赞关注收藏后评论区留言私信~~~

YOLO算法代码 

import os
import numpy as np
import copy
import colorsys
from timeit import default_timer as timer
from keras import backend as K
from keras.models import load_model
from keras.layers import Input
from PIL import Image, ImageFont, ImageDraw
from nets.yolo4 import yolo_body,yolo_eval
from utils.utils import letterbox_image
#--------------------------------------------#
#   使用自己训练好的模型预测需要修改2个参数
#   model_path和classes_path都需要修改!
#--------------------------------------------#
class YOLO(object):
    _defaults = {
        "model_path"        : 'model_data/yolo4_weight.h5',
        "anchors_path"      : 'model_data/yolo_anchors.txt',
        "classes_path"      : 'model_data/coco_classes.txt',
        "score"             : 0.5,
        "iou"               : 0.3,
        "max_boxes"         : 100,
        # 显存比较小可以使用416x416
        # 显存比较大可以使用608x608
        "model_image_size"  : (416, 416)
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   初始化yolo
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()
        self.boxes, self.scores, self.classes = self.generate()

    #---------------------------------------------------#
    #   获得所有的分类
    #---------------------------------------------------#
    def _get_class(self):
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names

    #---------------------------------------------------#
    #   获得所有的先验框
    #---------------------------------------------------#
    def _get_anchors(self):
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return np.array(anchors).reshape(-1, 2)

    #---------------------------------------------------#
    #   获得所有的分类
    #---------------------------------------------------#
    def generate(self):
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
        
        # 计算anchor数量
        num_anchors = len(self.anchors)
        num_classes = len(self.class_names)

        # 载入模型,如果原来的模型里已经包括了模型结构则直接载入。
        # 否则先构建模型再载入
        try:
            self.yolo_model = load_model(model_path, compile=False)
        except:
            self.yolo_model = yolo_body(Input(shape=(None,None,3)), num_anchors//3, num_classes)
            self.yolo_model.load_weights(self.model_path)
        else:
            assert self.yolo_model.layers[-1].output_shape[-1] == \
                num_anchors/len(self.yolo_model.output) * (num_classes + 5), \
                'Mismatch between model and given anchor and class sizes'

        print('{} model, anchors, and classes loaded.'.format(model_path))

        # 画框设置不同的颜色
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))

        # 打乱颜色
        np.random.seed(10101)
        np.random.shuffle(self.colors)
        np.random.seed(None)

        self.input_image_shape = K.placeholder(shape=(2, ))

        boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
                num_classes, self.input_image_shape, max_boxes = self.max_boxes,
                score_threshold = self.score, iou_threshold = self.iou)
        return boxes, scores, classes

    '''      
        函数名称:detect_image
        函数作用:目标跟踪程序(YOLO V4)
        函数输入:frame: 图像
        函数输出:
            boxs_person:行人检测框【x1, y1, w, h】
            boxs_others:其他类别检测框【x1, y1, x2, y2】
            labels_others: 其他框的类别
    
    '''
    def detect_image(self, image):
        new_image_size = (self.model_image_size[1],self.model_image_size[0])
        boxed_image = letterbox_image(image, new_image_size)
        image_data = np.array(boxed_image, dtype='float32')
        image_data /= 255.
        image_data = np.expand_dims(image_data, 0)  # Add batch dimension.
        boxs_person = []
        boxs_others = []
        labels_others = []
        # 预测结果
        out_boxes, out_scores, out_classes = self.sess.run(
            [self.boxes, self.scores, self.classes],
            feed_dict={
                self.yolo_model.input: image_data,
                self.input_image_shape: [image.size[1], image.size[0]],
                K.learning_phase(): 0
            })


        for i, c in list(enumerate(out_classes)):
            predicted_class = self.class_names[c]
            box = out_boxes[i]
            score = out_scores[i]
            top, left, bottom, right = box
            ###输入deepSort的格式如下###
            box_deepsort = [left,top,right-left,bottom-top]
            box_other = [left,top,right,bottom]

            if predicted_class == 'person':
                boxs_person.append(box_deepsort)
            else:
                boxs_others.append(box_other)
                labels_others.append(predicted_class)


        return boxs_person,boxs_others,labels_others

    def close_session(self):
        self.sess.close()

 Deep Sort算法代码

#!python3
#--coding:utf8--
from yolo import YOLO
from PIL import Image
import os
import sys
import time
import logging
import random
from random import randint
import math
import statistics
import getopt
from ctypes import *
import numpy as np
import cv2
from deep_sort import nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from tools import generate_detections as gdet
from deep_sort.detection import Detection as ddet
from collections import deque
from deep_sort import preprocessing


'''
函数名称:track_deepsort
函数作用:目标跟踪程序
函数输入:frame:图像
         boxs_person:行人检测框【x1,y1,w,h】
         boxs_others:其他类别检测框【x1,y1,x2,y2】
         labels_others:其他框的类别
         encoder:跟踪器的编码器
         tracker: 跟踪器
         pts: 运动点初始化值
         show_results:是否显示结果

函数输出:tracker 跟踪器
         pts 运动轨迹

'''

def track_deepsort(frame,boxs_person,boxs_others,labels_others,encoder,tracker,pts,show_results=True):
    nms_max_overlap = 1.0
    features = encoder(frame, boxs_person)
    detections = [Detection(bbox, 1.0, feature) for bbox, feature in zip(boxs_person, features)]
    boxes = np.array([d.tlwh for d in detections])
    scores = np.array([d.confidence for d in detections])
    indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, scores)
    detections = [detections[i] for i in indices]
    # 跟踪
    tracker.predict()
    tracker.update(detections)
    i = int(0)
    indexIDs = []
    ##########结果显示###########
    if show_results:
        for det in detections:
            bbox = det.to_tlbr()
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (255, 255, 255), 2)
        for ii in range(len(boxs_others)):
            bbox = boxs_others[ii]
            label = labels_others[ii]
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 255, 255), 2)
            cv2.putText(frame, str(label), (int(bbox[0]), int(bbox[1])), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 255, 255), 2)

        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            # boxes.append([track[0], track[1], track[2], track[3]])
            indexIDs.append(int(track.track_id))
            bbox = track.to_tlbr()

            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 255, 0), 3)
            cv2.putText(frame, str(track.track_id), (int(bbox[0]), int(bbox[1] - 50)), 0, 5e-3 * 150, (0, 255, 0), 2)


            i = i + 1
            center = (int(((bbox[0]) + (bbox[2])) / 2), int(((bbox[1]) + (bbox[3])) / 2))
            pts[track.track_id].append(center)
            # draw motion path
            for j in range(1, len(pts[track.track_id])):
                if pts[track.track_id][j - 1] is None or pts[track.track_id][j] is None:
                    continue
                thickness = int(np.sqrt(64 / float(j + 1)) * 2)
                cv2.line(frame, (pts[track.track_id][j - 1]), (pts[track.track_id][j]), (0, 255, 255), thickness)

    return tracker,pts



if __name__ == "__main__":

    yolo = YOLO()
    ####设置跟踪参数###
    max_cosine_distance = 0.5
    nn_budget = 20
    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance,
                                                       nn_budget)  # 最近邻距离度量,对于每个目标,返回到目前为止已观察到的任何样本的最近距离(欧式或余弦)。
    tracker = Tracker(metric)  # 由距离度量方法构造一个 Tracker。
    writeVideo_flag = False
    ###轨迹点定义##
    pts = [deque(maxlen=30) for _ in range(9999)]
    model_filename = './model_data/mars-small128.pb'  ###DeepSort 模型位置##
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    Obj_centre = [[] for i in range(200)]
    Obj_pre_direction = [[] for i in range(200)]
    ShowFlag = True ##是否显示结果
    ####打开摄像机###
    # 创建VideoCapture,传入0即打开系统默认摄像头
    # cap = cv2.VideoCapture(0)
    #######读取视频######################################
    video_path = 'structure.mp4'
    video_capture = cv2.VideoCapture(video_path)
    key = ''
    count = 0
    save_path = './saveimg/'

    while key != 113:  # for 'q' key
       ###读取图像###
        ret, frame = video_capture.read()
        #######目标检测########################
        frame2 = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGRA2RGBA))
        boxs_person,boxs_others,labels_others = yolo.detect_image(frame2)

        #######目标跟踪########################
        tracker, pts = track_deepsort(frame, boxs_person, boxs_others, labels_others, encoder, tracker, pts)
        #######显示检测及跟踪结果####
        cv2.namedWindow("YOLO3_Deep_SORT", 0)
        cv2.resizeWindow('YOLO3_Deep_SORT', 1024, 768)
        cv2.imshow('YOLO3_Deep_SORT', frame)
        cv2.waitKey(3)
        count  += 1
        jpg_name = os.path.join(save_path,str(count).zfill(6)+'.jpg')
        cv2.imwrite(jpg_name,frame)





四、实战效果评价 

结果显示,在目标检测环节,当人群交叉 光照突变时可能出现漏检的现象,这将导致目标跟踪环节出现跟踪错误,应该进一步地调整目标跟踪策略,使目标跟踪算法具有鲁棒性,尤其是解决人员聚集情况下的目标跟踪问题

创作不易  觉得有帮助请点赞关注收藏~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/78382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows 下Zookeeper 配置参数解读 和查看注册了哪些服务

zookeeper 配置文件解读 本地配置文件奉上: # The number of milliseconds of each tick tickTime2000 # The number of ticks that the initial # synchronization phase can take initLimit10 # The number of ticks that can pass between # sending a request and gett…

图像处理学习笔记-10-图像分割与边缘检测

图像分割的三大类方法:根据区域间灰度不连续搜寻区域之间的边界,在奇异性检测、边缘连接和边界检测介绍;以像素性质的分布进行阈值处理,在阈值处理介绍;直接搜寻区域进行分割,在基于区域的分割中介绍 奇异…

数据库、计算机网络,操作系统刷题笔记8

数据库、计算机网络,操作系统刷题笔记8 2022找工作是学历、能力和运气的超强结合体,遇到寒冬,大厂不招人,可能很多算法学生都得去找开发,测开 测开的话,你就得学数据库,sql,oracle&…

网站各个功能基本实现

1.前面已经介绍前后端的交互 2.今天实现网站功能的基本实现 也就是查询数据库。 网站类型为展示型网站。 页面如下: 点击政府公告显示: 点击机构设置显示: 后面不一一展示,主要实现六大功能的展示功能。 后续就实现管理员维…

PostgREST的安装部署(Windows和Linux环境)

下载地址:https://github.com/PostgREST/postgrest/releases 官方文档地址:Overview of Role System — PostgREST 9.0.0 documentation Windows 先下载对应系统的安装包: 下载之后解压会得到一个postgrest.exe可执行文件 创建配置文件&a…

解决 Android 开发过程中 出现 Duplicate class(包冲突)

1、现在大部分的项目都是支持 Androidx 的,所以出现 Duplicate 的时候 先把 gradle.properties 文件中添加参数,支持使用AndroidX android.useAndroidXtrue android.enableJetifiertrue 2、有些 *.jar/*.aar 不支持 AndroidX 的时候,将上面…

抽取_内插_半带滤波器_多相滤波器

文章目录半带滤波器多相抽取滤波器多相内插滤波器半带抽取器和半带内插器参考资料:Xilinx FIR Compiler v7.2 LogiCORE IP Product Guide PG149半带滤波器 半带滤波器的阶数为偶数,系数长度为奇数,且除了中间系数为0.5外,其余偶数…

mybatisplus 使用mybatis中的配置、mapper配置文件

1、在application.properties中配置mybatis的配置文件路径,例如: #指定mybatis-config.xml的位置 mybatis-plus.config-location classpath:mybatis/mybatis-config.xml 即在和application.properties同级目录下的mybatis目录中创建mybatis的配置文件m…

数制编码详解:二进制八进制十六进制的转换,原码、补码、反码、移码的定义

参考资料:《深入理解计算机网络(王达)》 文章目录一,数制1.1 基本数制1.2 不同数制之间的相互转换二,编码一,数制 1.1 基本数制 “数制”是“数据进制”的简称,也就是表示数据逢几进位的意思&a…

chatGPT的体验,是不是真智能?

目录 🏆一、前言 🏆二、安装 🏆三、普通对话 🚩1、chatGPT的ikun性 🚩2、chatGPT的日常对话 🏆四、实用能力 🏆五、代码改正 🏆六、写代码 🏆七、讲解代码 🏆…

ESXI精简thin磁盘迁移存储位置保留磁盘类型不变-无vc方式

运行2年了ESXI的SSD存储上很多VM并带多层快照,最近磁盘速度异常,迁移到新存储,都是thin磁盘;如有vCenter条件,采用“迁移”即可完美解决,既使磁盘类型thin不变又保留快照结构。如无vCenter条件的操作方式细…

Unity-iOS工程导出Xcode自动构建方法

Unity-iOS发布基本流程首先在Unity中导出Xcode工程,然后在Xcode工程中设置IOS打包的一些流程,诸如引入lib、framework或其他资源、设置签名及其他编译设置、加入编译脚本等等操作。 这些操作如果每次都是在导出Xcode后手动操作,一来浪费时间…

【C++ STL】-- 红黑树的插入实现

目录 红黑树的概念 二叉树搜索树的应用 红黑树节点的定义 红黑树结构 insert 需调整的多情况的核心思维: 需调整的多情况分类讲解: 情况一: 情况二: 情况三: 总结: 代码实现: 对于红黑树是否建立成功的检查 升序打印…

C++--类型转换--1128

1.C语言中的类型转换 分为隐式类型转化、显示强制类型转化。 隐式类型转化用于意义相近的类型,比如int,double,short都是表示数值的类型 int i1; double di; //编译、结果无问题 这里是隐式类型转换。 显示强制类型转换 显示强制类型用于意义不相近的类型&…

Redis Sentinel

高可用架构-Redis Sentinel Replication 缺点 接着之前的Redis Replication 主从复制架构,看似解决了主节点并发过大时,master节点处理繁忙的问题。将一部分读数据的请求交给从节点处理,从而将请求进行分散处理。但是该架构却存在很明显的缺…

基于LEACH的随机网络生成无线传感器网络路由协议的仿真比较(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

【深度学习】超详细的 PyTorch 学习笔记(上)

文章目录一、PyTorch环境检查二、查看张量类型三、查看张量尺寸和所占内存大小四、创建张量4.1 创建值全为1的张量4.2 创建值全为0的张量4.3 创建值全为指定值的张量4.4 通过 list 创建张量4.5 通过 ndarray 创建张量4.6 创建指定范围和间距的有序张量4.7 创建单位矩阵&#xf…

【力扣算法简单五十题】23.环形链表

给你一个链表的头节点 head ,判断链表中是否有环。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整数 pos 来表示链表尾连接到链表中的位置(索…

基于多种优化算法及神经网络的光伏系统控制(Matlab代码实现)

💥💥💥💞💞💞欢迎来到本博客❤️❤️❤️💥💥💥 🎉作者研究:🏅🏅🏅本科计算机专业,研究生电气学硕…

NNDL 实验八 网络优化与正则化(3)不同优化算法比较

文章目录7.3 不同优化算法的比较分析7.3.1 优化算法的实验设定7.3.1.1 2D可视化实验7.3.1.2 简单拟合实验7.3.1.3 与Torch API对比,验证正确性7.3.2 学习率调整7.3.2.1 AdaGrad算法7.3.2.2 RMSprop算法7.3.3 梯度估计修正7.3.3.1 动量法7.3.3.2 Adam算法7.3.4 不同优…