C# YOLOv8 TensorRT + ByteTrack Demo


Contents

Results

Notes

Project

Code

Form2.cs

YoloV8.cs

ByteTracker.cs

Download

References


Results

Notes

Environment

NVIDIA GeForce RTX 4060 Laptop GPU

CUDA 12.1 + cuDNN 8.8.1 + TensorRT 8.6.1.6

If your versions differ from the above, you will need to recompile TensorRtExtern.dll. TensorRtExtern source: TensorRT-CSharp-API/src/TensorRtExtern at TensorRtSharp2.0 · guojin-yan/TensorRT-CSharp-API · GitHub

For installing CUDA on Windows, see: Windows版 CUDA安装_win cuda安装-CSDN博客
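
The serialized .engine file is tied to the GPU model and the CUDA/TensorRT versions it was built with, so it is not shipped with the demo; the ONNX model is converted on first run instead. Below is a minimal sketch of that one-time conversion using the TensorRT-CSharp-API wrapper (the Nvinfer.OnnxToEngine call and the value 20 are taken from the Form2.cs listing below; the exact meaning of the second argument, a builder memory size, is defined by that wrapper):

using System.IO;
using TensorRtSharp.Custom;

// Build model/yolov8n.engine from model/yolov8n.onnx once, next to the ONNX file.
// This step is slow on first run and must be redone whenever the GPU, CUDA,
// or TensorRT version changes.
if (!File.Exists("model/yolov8n.engine"))
{
    Nvinfer.OnnxToEngine("model/yolov8n.onnx", 20);
}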

Project

Code

Form2.cs

using ByteTrack;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.IO;
using System.Threading;
using System.Windows.Forms;
using TensorRtSharp.Custom;

namespace yolov8_TensorRT_Demo
{
    public partial class Form2 : Form
    {
        public Form2()
        {
            InitializeComponent();
        }

        string imgFilter = "*.*|*.bmp;*.jpg;*.jpeg;*.tif;*.tiff;*.png";

        YoloV8 yoloV8;
        Mat image;

        string image_path = "";
        string model_path;

        string video_path = "";
        string videoFilter = "*.mp4|*.mp4;";
        VideoCapture vcapture;
        VideoWriter vwriter;
        bool saveDetVideo = false;
        ByteTracker tracker;


        /// <summary>
        /// Single-image inference
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void button2_Click(object sender, EventArgs e)
        {

            if (image_path == "")
            {
                return;
            }

            button2.Enabled = false;
            pictureBox2.Image = null;
            textBox1.Text = "";

            Application.DoEvents();

            image = new Mat(image_path);

            List<DetectionResult> detResults = yoloV8.Detect(image);

            // Draw the detection results
            Mat result_image = image.Clone();
            foreach (DetectionResult r in detResults)
            {
                Cv2.PutText(result_image, $"{r.Class}:{r.Confidence:P0}", new OpenCvSharp.Point(r.Rect.TopLeft.X, r.Rect.TopLeft.Y - 10), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.Rectangle(result_image, r.Rect, Scalar.Red, thickness: 2);
            }

            if (pictureBox2.Image != null)
            {
                pictureBox2.Image.Dispose();
            }
            pictureBox2.Image = new Bitmap(result_image.ToMemoryStream());
            textBox1.Text = yoloV8.DetectTime();

            button2.Enabled = true;

        }

        /// <summary>
        /// Form load: initialization
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void Form1_Load(object sender, EventArgs e)
        {
            image_path = "test/dog.jpg";
            pictureBox1.Image = new Bitmap(image_path);

            model_path = "model/yolov8n.engine";

            if (!File.Exists(model_path))
            {
                // Somewhat time-consuming; wait for the conversion to finish
                Nvinfer.OnnxToEngine("model/yolov8n.onnx", 20);
            }

            yoloV8 = new YoloV8(model_path, "model/lable.txt");

        }

        /// <summary>
        /// Select an image
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void button1_Click_1(object sender, EventArgs e)
        {
            OpenFileDialog ofd = new OpenFileDialog();
            ofd.Filter = imgFilter;
            if (ofd.ShowDialog() != DialogResult.OK) return;

            pictureBox1.Image = null;

            image_path = ofd.FileName;
            pictureBox1.Image = new Bitmap(image_path);

            textBox1.Text = "";
            pictureBox2.Image = null;
        }

        /// <summary>
        /// Select a video
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void button4_Click(object sender, EventArgs e)
        {
            OpenFileDialog ofd = new OpenFileDialog();
            ofd.Filter = videoFilter;
            ofd.InitialDirectory = Application.StartupPath + "\\test";
            if (ofd.ShowDialog() != DialogResult.OK) return;

            video_path = ofd.FileName;

            textBox1.Text = "";
            pictureBox1.Image = null;
            pictureBox2.Image = null;

            button3_Click(null, null);

        }

        /// <summary>
        /// Video inference
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void button3_Click(object sender, EventArgs e)
        {
            if (video_path == "")
            {
                return;
            }

            textBox1.Text = "开始检测";

            Application.DoEvents();

            Thread thread = new Thread(new ThreadStart(VideoDetection));

            thread.Start();
            thread.Join();

            textBox1.Text = "检测完成!";
        }

        void VideoDetection()
        {
            vcapture = new VideoCapture(video_path);
            if (!vcapture.IsOpened())
            {
                MessageBox.Show("打开视频文件失败");
                return;
            }

            tracker = new ByteTracker((int)vcapture.Fps, 200);

            Mat frame = new Mat();
            List<DetectionResult> detResults;

            // Get the video fps
            double videoFps = vcapture.Get(VideoCaptureProperties.Fps);
            // Compute the per-frame wait time (ms)
            int delay = (int)(1000 / videoFps);
            Stopwatch _stopwatch = new Stopwatch();

            if (checkBox1.Checked)
            {
                vwriter = new VideoWriter("out.mp4", FourCC.X264, vcapture.Fps, new OpenCvSharp.Size(vcapture.FrameWidth, vcapture.FrameHeight));
                saveDetVideo = true;
            }
            else
            {
                saveDetVideo = false;
            }

            while (vcapture.Read(frame))
            {
                if (frame.Empty())
                {
                    MessageBox.Show("读取失败");
                    return;
                }

                _stopwatch.Restart();

                delay = (int)(1000 / videoFps);

                detResults = yoloV8.Detect(frame);

                // Raw detection results (commented out; tracker output is drawn below instead)
                //foreach (DetectionResult r in detResults)
                //{
                //    Cv2.PutText(frame, $"{r.Class}:{r.Confidence:P0}", new OpenCvSharp.Point(r.Rect.TopLeft.X, r.Rect.TopLeft.Y - 10), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                //    Cv2.Rectangle(frame, r.Rect, Scalar.Red, thickness: 2);
                //}

                Cv2.PutText(frame, "preprocessTime:" + yoloV8.preprocessTime.ToString("F2") + "ms", new OpenCvSharp.Point(10, 30), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "inferTime:" + yoloV8.inferTime.ToString("F2") + "ms", new OpenCvSharp.Point(10, 70), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "postprocessTime:" + yoloV8.postprocessTime.ToString("F2") + "ms", new OpenCvSharp.Point(10, 110), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "totalTime:" + yoloV8.totalTime.ToString("F2") + "ms", new OpenCvSharp.Point(10, 150), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "video fps:" + videoFps.ToString("F2"), new OpenCvSharp.Point(10, 190), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "det fps:" + yoloV8.detFps.ToString("F2"), new OpenCvSharp.Point(10, 230), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);

                List<Track> track = new List<Track>();
                Track temp;
                foreach (DetectionResult r in detResults)
                {
                    RectBox _box = new RectBox(r.Rect.X, r.Rect.Y, r.Rect.Width, r.Rect.Height);
                    temp = new Track(_box, r.Confidence, ("label", r.ClassId), ("name", r.Class));
                    track.Add(temp);
                }

                var trackOutputs = tracker.Update(track);

                foreach (var t in trackOutputs)
                {
                    Rect rect = new Rect((int)t.RectBox.X, (int)t.RectBox.Y, (int)t.RectBox.Width, (int)t.RectBox.Height);
                    //string txt = $"{t["name"]}-{t.TrackId}:{t.Score:P0}";
                    string txt = $"{t["name"]}-{t.TrackId}";
                    Cv2.PutText(frame, txt, new OpenCvSharp.Point(rect.TopLeft.X, rect.TopLeft.Y - 10), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                    Cv2.Rectangle(frame, rect, Scalar.Red, thickness: 2);
                }

                if (saveDetVideo)
                {
                    vwriter.Write(frame);
                }

                Cv2.ImShow("DetectionResult", frame);

                // for test
                // delay = 1;
                delay = (int)(delay - _stopwatch.ElapsedMilliseconds);
                if (delay <= 0)
                {
                    delay = 1;
                }
                //Console.WriteLine("delay:" + delay.ToString()) ;
                if (Cv2.WaitKey(delay) == 27)
                {
                    break; // Exit the loop when ESC is pressed
                }
            }

            Cv2.DestroyAllWindows();
            vcapture.Release();
            if (saveDetVideo)
            {
                vwriter.Release();
            }

        }

    }

}
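
A note on the tracker setup used in VideoDetection(): new ByteTracker((int)vcapture.Fps, 200) passes the video frame rate and a track buffer of 200. As the ByteTracker listing further below shows, a lost track is kept for frameRate / 30.0 * trackBuffer frames before it is removed, so these two values control how long an occluded object keeps its track ID. A small worked example of that calculation:

// Lost-track lifetime, as computed in ByteTracker's constructor:
// _maxTimeLost = (int)(frameRate / 30.0 * trackBuffer)
int frameRate = 30;     // e.g. (int)vcapture.Fps for a 30 fps clip
int trackBuffer = 200;  // the second constructor argument used in this demo
int maxTimeLost = (int)(frameRate / 30.0 * trackBuffer); // 200 frames, roughly 6.7 s at 30 fps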

YoloV8.cs

using OpenCvSharp;
using OpenCvSharp.Dnn;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using TensorRtSharp.Custom;

namespace yolov8_TensorRT_Demo
{
    public class YoloV8
    {

        float[] input_tensor_data;
        float[] outputData;
        List<DetectionResult> detectionResults;

        int input_height;
        int input_width;

        Nvinfer predictor;

        public string[] class_names;
        int class_num;
        int box_num;

        float conf_threshold;
        float nms_threshold;

        float ratio_height;
        float ratio_width;

        public double preprocessTime;
        public double inferTime;
        public double postprocessTime;
        public double totalTime;
        public double detFps;

        public String DetectTime()
        {
            StringBuilder stringBuilder = new StringBuilder();
            stringBuilder.AppendLine($"Preprocess: {preprocessTime:F2}ms");
            stringBuilder.AppendLine($"Infer: {inferTime:F2}ms");
            stringBuilder.AppendLine($"Postprocess: {postprocessTime:F2}ms");
            stringBuilder.AppendLine($"Total: {totalTime:F2}ms");

            return stringBuilder.ToString();
        }

        public YoloV8(string model_path, string classer_path)
        {
            predictor = new Nvinfer(model_path);

            class_names = File.ReadAllLines(classer_path, Encoding.UTF8);
            class_num = class_names.Length;

            input_height = 640;
            input_width = 640;

            box_num = 8400;

            conf_threshold = 0.25f;
            nms_threshold = 0.5f;

            detectionResults = new List<DetectionResult>();
        }

        void Preprocess(Mat image)
        {
            // Resize the image (downscale only, aspect ratio preserved)
            int height = image.Rows;
            int width = image.Cols;
            Mat temp_image = image.Clone();
            if (height > input_height || width > input_width)
            {
                float scale = Math.Min((float)input_height / height, (float)input_width / width);
                OpenCvSharp.Size new_size = new OpenCvSharp.Size((int)(width * scale), (int)(height * scale));
                Cv2.Resize(image, temp_image, new_size);
            }
            ratio_height = (float)height / temp_image.Rows;
            ratio_width = (float)width / temp_image.Cols;
            Mat input_img = new Mat();
            Cv2.CopyMakeBorder(temp_image, input_img, 0, input_height - temp_image.Rows, 0, input_width - temp_image.Cols, BorderTypes.Constant, 0);

            // Normalize to [0, 1]
            input_img.ConvertTo(input_img, MatType.CV_32FC3, 1.0 / 255);

            input_tensor_data = Common.ExtractMat(input_img);

            input_img.Dispose();
            temp_image.Dispose();
        }

        void Postprocess(float[] outputData)
        {
            detectionResults.Clear();

            float[] data = Common.Transpose(outputData, class_num + 4, box_num);

            float[] confidenceInfo = new float[class_num];
            float[] rectData = new float[4];

            List<DetectionResult> detResults = new List<DetectionResult>();

            for (int i = 0; i < box_num; i++)
            {
                Array.Copy(data, i * (class_num + 4), rectData, 0, 4);
                Array.Copy(data, i * (class_num + 4) + 4, confidenceInfo, 0, class_num);

                float score = confidenceInfo.Max(); // highest class score

                int maxIndex = Array.IndexOf(confidenceInfo, score); // index of the highest score

                int _centerX = (int)(rectData[0] * ratio_width);
                int _centerY = (int)(rectData[1] * ratio_height);
                int _width = (int)(rectData[2] * ratio_width);
                int _height = (int)(rectData[3] * ratio_height);

                detResults.Add(new DetectionResult(
                   maxIndex,
                   class_names[maxIndex],
                   new Rect(_centerX - _width / 2, _centerY - _height / 2, _width, _height),
                   score));
            }

            //NMS
            CvDnn.NMSBoxes(detResults.Select(x => x.Rect), detResults.Select(x => x.Confidence), conf_threshold, nms_threshold, out int[] indices);
            detResults = detResults.Where((x, index) => indices.Contains(index)).ToList();

            detectionResults = detResults;
        }

        internal List<DetectionResult> Detect(Mat image)
        {

            var t1 = Cv2.GetTickCount();

            Stopwatch stopwatch = new Stopwatch();
            stopwatch.Start();

            Preprocess(image);

            preprocessTime = stopwatch.Elapsed.TotalMilliseconds;
            stopwatch.Restart();

            predictor.LoadInferenceData("images", input_tensor_data);

            predictor.infer();

            inferTime = stopwatch.Elapsed.TotalMilliseconds;
            stopwatch.Restart();

            outputData = predictor.GetInferenceResult("output0");

            Postprocess(outputData);

            postprocessTime = stopwatch.Elapsed.TotalMilliseconds;
            stopwatch.Stop();

            totalTime = preprocessTime + inferTime + postprocessTime;

            var t2 = Cv2.GetTickCount();

            detFps = 1 / ((t2 - t1) / Cv2.GetTickFrequency());

            return detectionResults;

        }

    }
}
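
YoloV8.cs also depends on a few helpers that are not reproduced in this post: the DetectionResult type and the Common.ExtractMat / Common.Transpose utilities from the demo project. The sketch below is a plausible minimal version of them, inferred from how they are called above; it is not the project's original code (in particular, the real ExtractMat may also swap BGR to RGB before flattening):

using System;
using OpenCvSharp;

// Assumed shapes of the helpers used by YoloV8.cs, for reading the listing above.
public class DetectionResult
{
    public int ClassId { get; }
    public string Class { get; }
    public Rect Rect { get; }
    public float Confidence { get; }

    public DetectionResult(int classId, string className, Rect rect, float confidence)
    {
        ClassId = classId;
        Class = className;
        Rect = rect;
        Confidence = confidence;
    }
}

public static class Common
{
    // Flatten a 640x640 CV_32FC3 Mat into planar NCHW order (3 x 640 x 640),
    // the layout expected for the "images" input tensor.
    public static float[] ExtractMat(Mat img)
    {
        int h = img.Rows, w = img.Cols;
        var data = new float[3 * h * w];
        Mat[] channels = Cv2.Split(img);
        for (int c = 0; c < 3; c++)
        {
            channels[c].GetArray(out float[] plane);
            Array.Copy(plane, 0, data, c * h * w, h * w);
            channels[c].Dispose();
        }
        return data;
    }

    // Transpose a row-major (rows x cols) buffer to (cols x rows); used to turn the
    // raw (84 x 8400) output into 8400 rows of [cx, cy, w, h, class scores...].
    public static float[] Transpose(float[] data, int rows, int cols)
    {
        var result = new float[data.Length];
        for (int i = 0; i < rows; i++)
            for (int j = 0; j < cols; j++)
                result[j * rows + i] = data[i * cols + j];
        return result;
    }
}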

ByteTracker.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ByteTrack
{
    public class ByteTracker
    {
        readonly float _trackThresh;
        readonly float _highThresh;
        readonly float _matchThresh;
        readonly int _maxTimeLost;

        int _frameId = 0;
        int _trackIdCount = 0;

        readonly List<Track> _trackedTracks = new List<Track>(100);
        readonly List<Track> _lostTracks = new List<Track>(100);
        List<Track> _removedTracks = new List<Track>(100);

        public ByteTracker(int frameRate = 30, int trackBuffer = 30, float trackThresh = 0.5f, float highThresh = 0.6f, float matchThresh = 0.8f)
        {
            _trackThresh = trackThresh;
            _highThresh = highThresh;
            _matchThresh = matchThresh;
            _maxTimeLost = (int)(frameRate / 30.0 * trackBuffer);
        }

        /// <summary>
        /// Runs one tracking step on the current frame's detections.
        /// </summary>
        /// <param name="tracks">Detections of the current frame, wrapped as Track objects.</param>
        /// <returns>The currently activated tracks.</returns>
        public IList<Track> Update(List<Track> tracks)
        {
            #region Step 1: Get detections 
            _frameId++;

            // Create new Tracks using the result of object detection
            List<Track> detTracks = new List<Track>();
            List<Track> detLowTracks = new List<Track>();

            foreach (var obj in tracks)
            {
                if (obj.Score >= _trackThresh)
                {
                    detTracks.Add(obj);
                }
                else
                {
                    detLowTracks.Add(obj);
                }
            }

            // Create lists of existing STrack
            List<Track> activeTracks = new List<Track>();
            List<Track> nonActiveTracks = new List<Track>();

            foreach (var trackedTrack in _trackedTracks)
            {
                if (!trackedTrack.IsActivated)
                {
                    nonActiveTracks.Add(trackedTrack);
                }
                else
                {
                    activeTracks.Add(trackedTrack);
                }
            }

            var trackPool = activeTracks.Union(_lostTracks).ToArray();

            // Predict current pose by KF
            foreach (var track in trackPool)
            {
                track.Predict();
            }
            #endregion

            #region Step 2: First association, with IoU 
            List<Track> currentTrackedTracks = new List<Track>();
            Track[] remainTrackedTracks;
            Track[] remainDetTracks;
            List<Track> refindTracks = new List<Track>();
            {
                var dists = CalcIouDistance(trackPool, detTracks);
                LinearAssignment(dists, trackPool.Length, detTracks.Count, _matchThresh,
                    out var matchesIdx,
                    out var unmatchTrackIdx,
                    out var unmatchDetectionIdx);

                foreach (var matchIdx in matchesIdx)
                {
                    var track = trackPool[matchIdx[0]];
                    var det = detTracks[matchIdx[1]];
                    if (track.State == TrackState.Tracked)
                    {
                        track.Update(det, _frameId);
                        currentTrackedTracks.Add(track);
                    }
                    else
                    {
                        track.ReActivate(det, _frameId);
                        refindTracks.Add(track);
                    }
                }

                remainDetTracks = unmatchDetectionIdx.Select(unmatchIdx => detTracks[unmatchIdx]).ToArray();
                remainTrackedTracks = unmatchTrackIdx
                    .Where(unmatchIdx => trackPool[unmatchIdx].State == TrackState.Tracked)
                    .Select(unmatchIdx => trackPool[unmatchIdx])
                    .ToArray();
            }
            #endregion

            #region Step 3: Second association, using low score dets 
            List<Track> currentLostTracks = new List<Track>();
            {
                var dists = CalcIouDistance(remainTrackedTracks, detLowTracks);
                LinearAssignment(dists, remainTrackedTracks.Length, detLowTracks.Count, 0.5f,
                                 out var matchesIdx,
                                 out var unmatchTrackIdx,
                                 out var unmatchDetectionIdx);

                foreach (var matchIdx in matchesIdx)
                {
                    var track = remainTrackedTracks[matchIdx[0]];
                    var det = detLowTracks[matchIdx[1]];
                    if (track.State == TrackState.Tracked)
                    {
                        track.Update(det, _frameId);
                        currentTrackedTracks.Add(track);
                    }
                    else
                    {
                        track.ReActivate(det, _frameId);
                        refindTracks.Add(track);
                    }
                }

                foreach (var unmatchTrack in unmatchTrackIdx)
                {
                    var track = remainTrackedTracks[unmatchTrack];
                    if (track.State != TrackState.Lost)
                    {
                        track.MarkAsLost();
                        currentLostTracks.Add(track);
                    }
                }
            }
            #endregion

            #region Step 4: Init new tracks 
            List<Track> currentRemovedTracks = new List<Track>();
            {
                // Deal with unconfirmed tracks, usually tracks with only one beginning frame
                var dists = CalcIouDistance(nonActiveTracks, remainDetTracks);
                LinearAssignment(dists, nonActiveTracks.Count, remainDetTracks.Length, 0.7f,
                                 out var matchesIdx,
                                 out var unmatchUnconfirmedIdx,
                                 out var unmatchDetectionIdx);

                foreach (var matchIdx in matchesIdx)
                {
                    nonActiveTracks[matchIdx[0]].Update(remainDetTracks[matchIdx[1]], _frameId);
                    currentTrackedTracks.Add(nonActiveTracks[matchIdx[0]]);
                }

                foreach (var unmatchIdx in unmatchUnconfirmedIdx)
                {
                    var track = nonActiveTracks[unmatchIdx];
                    track.MarkAsRemoved();
                    currentRemovedTracks.Add(track);
                }

                // Add new stracks
                foreach (var unmatchIdx in unmatchDetectionIdx)
                {
                    var track = remainDetTracks[unmatchIdx];
                    if (track.Score < _highThresh)
                        continue;

                    _trackIdCount++;
                    track.Activate(_frameId, _trackIdCount);
                    currentTrackedTracks.Add(track);
                }
            }
            #endregion

            #region Step 5: Update state
            foreach (var lostTrack in _lostTracks)
            {
                if (_frameId - lostTrack.FrameId > _maxTimeLost)
                {
                    lostTrack.MarkAsRemoved();
                    currentRemovedTracks.Add(lostTrack);
                }
            }

            var trackedTracks = currentTrackedTracks.Union(refindTracks).ToArray();
            var lostTracks = _lostTracks.Except(trackedTracks).Union(currentLostTracks).Except(_removedTracks).ToArray();
            _removedTracks = _removedTracks.Union(currentRemovedTracks).ToList();
            RemoveDuplicateStracks(trackedTracks, lostTracks);
            #endregion

            return _trackedTracks.Where(track => track.IsActivated).ToArray();
        }

        /// <summary>
        /// Resolves duplicates between the tracked and lost lists, keeping the longer-lived track.
        /// </summary>
        /// <param name="aTracks"></param>
        /// <param name="bTracks"></param>
        void RemoveDuplicateStracks(IList<Track> aTracks, IList<Track> bTracks)
        {
            _trackedTracks.Clear();
            _lostTracks.Clear();

            List<(int, int)> overlappingCombinations;
            var ious = CalcIouDistance(aTracks, bTracks);

            if (ious is null)
                overlappingCombinations = new List<(int, int)>();
            else
            {
                var rows = ious.GetLength(0);
                var cols = ious.GetLength(1);
                overlappingCombinations = new List<(int, int)>(rows * cols / 2);
                for (var i = 0; i < rows; i++)
                    for (var j = 0; j < cols; j++)
                        if (ious[i, j] < 0.15f)
                            overlappingCombinations.Add((i, j));
            }

            var aOverlapping = aTracks.Select(x => false).ToArray();
            var bOverlapping = bTracks.Select(x => false).ToArray();

            foreach (var (aIdx, bIdx) in overlappingCombinations)
            {
                var timep = aTracks[aIdx].FrameId - aTracks[aIdx].StartFrameId;
                var timeq = bTracks[bIdx].FrameId - bTracks[bIdx].StartFrameId;
                if (timep > timeq)
                    bOverlapping[bIdx] = true;
                else
                    aOverlapping[aIdx] = true;
            }

            for (var ai = 0; ai < aTracks.Count; ai++)
                if (!aOverlapping[ai])
                    _trackedTracks.Add(aTracks[ai]);

            for (var bi = 0; bi < bTracks.Count; bi++)
                if (!bOverlapping[bi])
                    _lostTracks.Add(bTracks[bi]);
        }

        /// <summary>
        /// 
        /// </summary>
        /// <param name="costMatrix"></param>
        /// <param name="costMatrixSize"></param>
        /// <param name="costMatrixSizeSize"></param>
        /// <param name="thresh"></param>
        /// <param name="matches"></param>
        /// <param name="aUnmatched"></param>
        /// <param name="bUnmatched"></param>
        void LinearAssignment(float[,] costMatrix, int costMatrixSize, int costMatrixSizeSize, float thresh, out IList<int[]> matches, out IList<int> aUnmatched, out IList<int> bUnmatched)
        {
            matches = new List<int[]>();
            if (costMatrix is null)
            {
                aUnmatched = Enumerable.Range(0, costMatrixSize).ToArray();
                bUnmatched = Enumerable.Range(0, costMatrixSizeSize).ToArray();
                return;
            }

            bUnmatched = new List<int>();
            aUnmatched = new List<int>();

            var (rowsol, colsol) = Lapjv.Exec(costMatrix, true, thresh);

            for (var i = 0; i < rowsol.Length; i++)
            {
                if (rowsol[i] >= 0)
                    matches.Add(new int[] { i, rowsol[i] });
                else
                    aUnmatched.Add(i);
            }

            for (var i = 0; i < colsol.Length; i++)
                if (colsol[i] < 0)
                    bUnmatched.Add(i);
        }

        /// <summary>
        /// 
        /// </summary>
        /// <param name="aRects"></param>
        /// <param name="bRects"></param>
        /// <returns></returns>
        static float[,] CalcIous(IList<RectBox> aRects, IList<RectBox> bRects)
        {
            if (aRects.Count * bRects.Count == 0) return null;

            var ious = new float[aRects.Count, bRects.Count];
            for (var bi = 0; bi < bRects.Count; bi++)
                for (var ai = 0; ai < aRects.Count; ai++)
                    ious[ai, bi] = bRects[bi].CalcIoU(aRects[ai]);

            return ious;
        }

        /// <summary>
        /// Computes the IoU cost matrix (1 - IoU) between two sets of tracks.
        /// </summary>
        /// <param name="aTracks"></param>
        /// <param name="bTracks"></param>
        /// <returns></returns>
        static float[,] CalcIouDistance(IEnumerable<Track> aTracks, IEnumerable<Track> bTracks)
        {
            var aRects = aTracks.Select(x => x.RectBox).ToArray();
            var bRects = bTracks.Select(x => x.RectBox).ToArray();

            var ious = CalcIous(aRects, bRects);
            if (ious is null) return null;

            var rows = ious.GetLength(0);
            var cols = ious.GetLength(1);
            var matrix = new float[rows, cols];
            for (var i = 0; i < rows; i++)
                for (var j = 0; j < cols; j++)
                    matrix[i, j] = 1 - ious[i, j];

            return matrix;
        }
    }
}
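
ByteTracker.cs in turn relies on Track, RectBox, TrackState and the Lapjv assignment solver from the Yolo8-ByteTrack-CSharp project (see the reference links below); they are not shown here. The core of both association steps is the IoU computed by RectBox.CalcIoU. The following is a minimal, illustrative sketch of that computation, assuming boxes are stored as (x, y, width, height); it is not the project's original type:

using System;

// Standard IoU between two axis-aligned boxes given as (x, y, width, height).
// RectBox.CalcIoU in the ByteTrack project is expected to behave like this.
public readonly struct RectBox
{
    public float X { get; }
    public float Y { get; }
    public float Width { get; }
    public float Height { get; }

    public RectBox(float x, float y, float width, float height)
    {
        X = x; Y = y; Width = width; Height = height;
    }

    public float CalcIoU(RectBox other)
    {
        float x1 = Math.Max(X, other.X);
        float y1 = Math.Max(Y, other.Y);
        float x2 = Math.Min(X + Width, other.X + other.Width);
        float y2 = Math.Min(Y + Height, other.Y + other.Height);

        float interW = Math.Max(0f, x2 - x1);
        float interH = Math.Max(0f, y2 - y1);
        float intersection = interW * interH;
        float union = Width * Height + other.Width * other.Height - intersection;

        return union <= 0f ? 0f : intersection / union;
    }
}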

Download

Source code download

References

https://github.com/devhxj/Yolo8-ByteTrack-CSharp

https://github.com/guojin-yan/TensorRT-CSharp-API
