从0到1制作单只鳌虾运动轨迹追踪软件

news2024/10/5 13:50:47

前言

需要准备windows10操作系统,python3.11.9,cuDNN8.9.2.26,CUDA11.8,paddleDetection2.7

流程:

  1. 准备数据集-澳洲鳌虾VOC数据集 
  2. 基于RT-DETR目标检测模型训练
  3. 导出onnx模型进行python部署
  4. 平滑滤波处理视频帧保留的物体质心坐标
  5. 基于pywebview为软件前端,falsk为软件后端制作UI
  6. 使用pyinstaller打包成exe
  7. 使用into setup生成安装包

本人代码禁止任何商业化用途,个人开发者随意。所有代码均开源

项目目录

XXX 项目总目录
    static 存放js静态文件
        plotly.js
    templates 存放html文件
        index.html
    temp 用户上传文件保存路径
    venv 虚拟环境
    main.py 主程序
    model.onnx 模型文件
    1.ico 打包的程序图标

准备数据集

点击下载澳洲鳌虾VOC数据集

下载后解压,文件目录为

data
    Annotations
        0.xml
        1.xml
        ...
    imgs
        0.jpg
        1.jpg
        ...
    lables.txt

然后使用如下的脚本把数据集划分为训练集和测试集

import os
import random
import shutil


def splitDatasets(images_dir,xmls_dir,train_dir,test_dir):

    if os.path.exists(train_dir):
        shutil.rmtree(train_dir)
        
    os.makedirs(train_dir)
    os.makedirs(train_dir+'/imgs')
    os.makedirs(train_dir+'/annotations')
        
    if os.path.exists(test_dir):
        shutil.rmtree(test_dir)
        
    os.makedirs(test_dir)
    os.makedirs(test_dir+'/imgs')
    os.makedirs(test_dir+'/annotations')
        
    images=os.listdir(images_dir)
    random.shuffle(images)

    split_index=int(0.9*len(images))

    train_images=images[:split_index]
    test_images=images[split_index:]

    with open(train_dir+'/train.txt','w') as file:
        for img in train_images:
            shutil.copy(os.path.join(images_dir,img),os.path.join(train_dir,'imgs',img))
            ann=img.replace('jpg','xml')
            shutil.copy(os.path.join(xmls_dir,ann),os.path.join(train_dir,'annotations',ann))
            line=os.path.join(train_dir,'imgs',img)+' '+os.path.join(train_dir,'annotations',ann)+'\n'
            file.write(line)

    with open(test_dir+'/test.txt','w') as file:
        for img in test_images:
            shutil.copy(os.path.join(images_dir,img),os.path.join(test_dir,'imgs',img))
            ann=img.replace('jpg','xml')
            shutil.copy(os.path.join(xmls_dir,ann),os.path.join(test_dir,'annotations',ann))
            line=os.path.join(test_dir,'imgs',img)+' '+os.path.join(test_dir,'annotations',ann)+'\n'
            file.write(line)
        
    shutil.rmtree(images_dir)
    shutil.rmtree(xmls_dir)
    
if __name__=='__main__':
    # 填写img文件夹所在绝对路径
    images_dir='/home/aistudio/work/voc/imgs'
    # 填写Annotations文件夹所在绝对路径
    xmls_dir='/home/aistudio/work/voc/Annotations'
    # 填写 训练集 的存放的绝对路径
    train_dir='/home/aistudio/work/voc/trains'
    # 填写 测试集 的存放的绝对路径
    test_dir='/home/aistudio/work/voc/tests'
    
    splitDatasets(images_dir,xmls_dir,train_dir,test_dir)

训练模型

可在aistudio云平台训练,我放好了所有的相关文件,点击进入,里面的说明很详细

也可在本地进行训练,下面来配置本地的训练环境

配置相关文件

下载paddleDetection2.7

原始目录如下

paddleDetection2.7
    .github
    .travis
    activity
    benchmark
    configs 模型配置文件
    dataset 里面有数据集下载的脚本文件
    demo
    deploy 推理的相关文件
    docs 说明文档
    industrial_tutorial
    ppdet 模型运行的核心文件
    scripts
    test_pic
    tools 模型训练入口,测试,验证,导出等脚本文件
    .gitignore
    .pre-commit-config.yaml
    .style.yapf
    .travis.yml
    LICENSE
    README_cn.md 说明文档中文版
    README_en.md 说明文档英文版
    requirements.txt 相关依赖库
    setup.py 模型编译的相关脚本

需要删除一些目录,把README_en.md改名为README.md,处理过的目录如下

paddleDetection2.7
    configs
    dataset
    deploy
    ppdet
    tools
    README.md
    requirements.txt
    setup.py

把dataset里所有东西都删除,再将划分好的数据集放到该文件下,处理好的目录如下

dataset
    voc
        trains
            annotations
            imgs
            train.txt
        tests
            annotations
            imgs
            test.txt
        labels.txt

进入tools目录,只保留如下文件,其余全删除,处理后的文件目录如下

tools
    train.py
    infer.py
    eval.py
    export_model.py

进入configs目录,只保留下面三个文件和目录,处理后的目录如下

configs
    datasets
    rtdetr
    runtime.yml

进入datasets目录,只保留voc.yml,其余文件全删除,处理后的目录如下

datasets
    voc.yml

并用如下内容覆盖voc.yml

metric: VOC
map_type: 11point
num_classes: 1

TrainDataset:
  name: VOCDataSet
  dataset_dir: dataset/voc
  anno_path: trains/train.txt
  label_list: labels.txt
  data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']

EvalDataset:
  name: VOCDataSet
  dataset_dir: dataset/voc
  anno_path: tests/test.txt
  label_list: labels.txt
  data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']

TestDataset:
  name: ImageFolder
  anno_path: dataset/labels.txt

进入rtdetr目录,只保留如下2个文件和目录,处理后的目录如下:

rtdetr
    _base_
    rtdetr_hgnetv2_x_6x_coco.yml

进入_base_目录,找到optimizer_6x.yml,修改第一行为epoch: 200,意思是训练200轮

找到rtdetr_reader.yml,根据自己的CPU和GPU调整相关参数,如果是4核CPU,worker_num可为8,batch_size根据显存调整,占用80%到90%的显存即可

安装依赖库

建议在虚拟环境中操作

!pip install -r requirements.txt
!pip install pycocotools
!pip install filterpy
!pip install flask
!pip install pyinstaller
!pip install pywebview
!pip install onnxruntime-gpu
!pip install onnxruntime
!pip install onnx
!pip install paddle2onnx
!python setup.py install

开始训练

建议命令行输入,先进入paddleDetection所在位置,再执行以下命令

python tools/train.py -c configs/rtdetr/rtdetr_hgnetv2_x_6x_coco.yml --eval --use_vdl True --vdl_log_dir vdl_log_dir/scalar

然后就是漫长的等待

导出模型

生成的模型在paddleDetection/output/best_model/model.pdparams

先进入paddleDetection所在位置,再执行以下命令

python tools/export_model.py -c configs/rtdetr/rtdetr_hgnetv2_x_6x_coco.yml -o weights=output/best_model/model.pdparams

转onnx

先进入paddleDetection所在位置,再执行以下命令,可以根据需要选择保存路径

paddle2onnx --model_dir=output_inference/rtdetr_hgnetv2_x_6x_coco/ \
            --model_filename model.pdmodel  \
            --params_filename model.pdiparams \
            --opset_version 16 \
            --save_file /home/work/infer/model.onnx

模型部署

导包

import webview
from flask import Flask, request, jsonify,render_template,stream_with_context,Response
import os
import time
import cv2
from onnxruntime import InferenceSession
import numpy as np
from werkzeug.utils import secure_filename

总览代码

class TrackShrimp():
    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

    def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])
        

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img
 
    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list
    
    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)

由于是对视频进行推理,所以首先得初始化视频打开的方法

def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

初始化onnx运行引擎,优先使用显卡,如果CUDA环境有问题,就使用CPU运行

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])

onnx引擎需要一定的输入格式,放到类的init里

    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

在提取每一帧后需要进行图像处理,resize图片为模型输入的要求,归一化

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img

在提取到视频的每一帧中的鳌虾的质心坐标后,由于每一帧的图像都不一样,输入模型后再输出的结果就不一样,会抖动,也就是噪声,我们需要滤波去噪,这里使用平滑滤波,相比卡尔曼滤波简单使用快速出结果。

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

我们需要计算鳌虾的运动总路程,用滤波后的质心坐标计算

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

滤波后的质心坐标是numpy数组,需要一定的转换再发送到前端进行渲染(matplotlib画的图太丑了,不如plotly.js)

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list

在获取每一帧图像后,送入模型。模型会输出一对numpy数组,需要进行一对的后处理,低于阈值的就抛弃,然后取阈值最高的,计算质心坐标并保存

    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

需要在一个主函数里将上述打开视频,图像预处理,送入模型,后处理连起来

    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)

前端的设计

以pywebview为平台,html和css设计前端

 

 

 

代码总览

index.html

<!DOCTYPE html>
<html>
<head>
    <title></title>
    <link rel="shortcut icon" href="#" />
    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
    <script src="../static/plotly.js"></script>
    <style>
        html,body{
            width: 100%;
            height: 100%;
            margin: 0 auto;

        }
        body{
            display: flex;
            align-items: center;
            justify-content: center;
            height: 100vh;
            background-color: rgb(6, 32, 80);
        }
        main{
            display: grid;
            grid-template-columns: 1fr 3fr;
            column-gap: 2%;
            width: 98%;
            height: 98%;
        }
        fieldset{
            border: 2px solid rgb(32, 139, 139);
            color: rgb(32, 139, 139);
            margin: 8% 0 8% 0;
        }
        #s2{
            text-align: center;
            display: flex;
            justify-content: center;
            align-items: center;
            background-color: rgba(32, 139, 139, 0.301);
            border: 2px solid rgb(32, 139, 139);
        }
        #progress-circle{
            border: 1em solid rgb(32, 139, 139);
            width: 40vh;
            height: 40vh;
            border-radius: 20vh;
            display: flex;  
            justify-content: center; 
            align-items: center;
        }
        #progress-num{
            font-size: 18vh;
            color: rgb(32, 139, 139);
        }

    </style>
</head>
<body>
    <main>
        <section id="s1">
            <form id="form" enctype="multipart/form-data">
                <fieldset>
                    <legend>选择你要检测的视频</legend>
                    <input type="file" accept=".mp4" id="video" name="vedio">
                </fieldset>
                <fieldset>
                    <legend>功能按键</legend>
                    <button onclick="submit_to()" id="submit">开始上传</button>
                    <button onclick="stopRun()">终止运行</button>
                </fieldset>
            </form>
            <script>
                async function stopRun(){
                    try{
                        const response=await fetch('/stopRun',{method:'POST'})
                        if (!response.ok) {  
                            throw new Error('Network response was not ok.');  
                        }
                        data=await response.json()
                        alert(data.data)
                    }catch(error){
                        console.log(error)
                    }
                }
                
                async function submit_to(){
                    // 防重复激发
                    const button = document.getElementById('submit');  
                    button.disabled = true;
                    try{
                        // 获取文件
                        const input=document.getElementById('video')
                        const file=input.files[0]
                        if (!file){
                            throw new Error('未选择文件')
                        }
                        if(file.type!=='video/mp4'){
                            throw new Error('请选择MP4文件')
                        }
                        // 刷新界面 
                        const s2=document.getElementById('s2')
                        Plotly.purge(s2)
                        // 初始化进度显示
                        const progressCircle=document.getElementById('progress-circle')
                        const progressNum=document.getElementById('progress-num')
                        progressCircle.style.display='flex'
                        progressNum.innerHTML='0%'
                        // 更新进度
                        let source = new EventSource("/progress")
                        source.onmessage = function(event) {
                        progressNum.innerHTML = event.data+'%'
                        }
                        // 发送请求
                        const formData=new FormData()
                        formData.append('video', file)
                        const response=await fetch('/shrimp',{method:'POST',body:formData})
                        if (!response.ok) {
                            throw new Error('Network response was not ok.');  
                        }
                        source.close()
                        const data=await response.json()
                        button.disabled=false
                        if(data.data==='任务被终止'){
                            alert(data.data)
                        }
                        else{
                            progressCircle.style.display='none'
                            $('#distance').text('总路程'+data.distance)
                            // 画图
                            var trace=[{
                                x: data.position_data.map(item=>item[0]),
                                y: data.position_data.map(item=>item[1]),
                                mode:"lines",
                                line:{
                                        color:'rgb(32, 139, 139)'
                                    }
                            }]
                            var layout = {
                                xaxis: {
                                    range: [0, 600],
                                    title: "x(像素)",
                                    titlefont: {  
                                        color: 'rgb(32, 139, 139)' // 轴标签颜色  
                                    },  
                                    linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
                                    tickfont: {  
                                        color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
                                    }
                                },
                                yaxis: {range: [0, 600],
                                    title: "y(像素)",
                                    titlefont: {  
                                        color: 'rgb(32, 139, 139)' // 轴标签颜色  
                                    },  
                                    linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
                                    tickfont: {  
                                        color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
                                    }
                                },  
                                title: "鳌虾运动轨迹",
                                titlefont:{
                                    color:'rgb(32, 139, 139)'
                                },
                                plot_bgcolor: 'rgba(0,0,0,0)',
                                paper_bgcolor:'rgba(0,0,0,0)'
                                }
                            Plotly.newPlot("s2", trace, layout,{scrollZoom: true,editable: true }) 
                        }
                    }catch(error){
                        button.disabled = false
                        if(error.message.startsWith('Failed to fetch')){}
                        else{alert(error)}
                    }
                }
            
            </script>

            <fieldset>
                <legend>输出结果</legend>
                <P id="distance">总路程:</P>
            </fieldset>
            <fieldset>
                <legend>注意事项</legend>
                <p>本程序运行将消耗大量算力和内存,最好使用高配电脑。不支持windows10以下的操作系统。在后台有任务在跑时,切勿重复上传视频,
                    等待后台跑完出图时再上传新的视频。如果选错视频并上传了,请点击'终止运行'再重新上传视频。有问题联系wx:m989783106</p>
            </fieldset>
        </section>

        <section id="s2">
            <div id="progress-circle"><p id="progress-num"></p></div>
        </section>
    </main>
</body>
</html>

plotly.js从官网下载

代码分览

总体设计是以<html>和<body>为底,<main>为主容器内使用grid2列布局,2个<section>作为内容器占据左右2个网格。

左边的<section>容纳文件上传表单,功能按钮,数据显示,使用说明

        <section id="s1">
            <form id="form" enctype="multipart/form-data">
                <fieldset>
                    <legend>选择你要检测的视频</legend>
                    <input type="file" accept=".mp4" id="video" name="vedio">
                </fieldset>
                <fieldset>
                    <legend>功能按键</legend>
                    <button onclick="submit_to()" id="submit">开始上传</button>
                    <button onclick="stopRun()">终止运行</button>
                </fieldset>
            </form>

            <fieldset>
                <legend>输出结果</legend>
                <P id="distance">总路程:</P>
            </fieldset>
            <fieldset>
                <legend>注意事项</legend>
                <p>本程序运行将消耗大量算力和内存,最好使用高配电脑。不支持windows10以下的操作系统。在后台有任务在跑时,切勿重复上传视频,
                    等待后台跑完出图时再上传新的视频。如果选错视频并上传了,请点击'终止运行'再重新上传视频。有问题联系wx:m989783106</p>
            </fieldset>
        </section>

之间用<fieldset>做了区域划分,简单又美观。

<button>均使用onclick属性进行触发

在上传前会检测用户是否选择文件,是否选择的是MP4文件

// 获取文件
const input=document.getElementById('video')
const file=input.files[0]
if (!file){
    throw new Error('未选择文件')
}
if(file.type!=='video/mp4'){
    throw new Error('请选择MP4文件')
}

 一共有3个请求:

  • 请求上传文件,将MP4上传给后端,然后后端运行模型发送质心坐标给前端渲染
  • 请求终止程序,当用户想终止后端运行模型,重新上传文件时
  • 请求获取模型处理进度,后端返回进度给前端,前端进行渲染展示

画轨迹图,前端用plotly.js将质心坐标进行渲染,同时轨迹图还有一定的交互能力。

// 画图
var trace=[{
    x: data.position_data.map(item=>item[0]),
    y: data.position_data.map(item=>item[1]),
    mode:"lines",
    line:{
            color:'rgb(32, 139, 139)'
        }
}]
var layout = {
    xaxis: {
        range: [0, 600],
        title: "x(像素)",
        titlefont: {  
            color: 'rgb(32, 139, 139)' // 轴标签颜色  
        },  
        linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
        tickfont: {  
            color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
        }
    },
    yaxis: {range: [0, 600],
        title: "y(像素)",
        titlefont: {  
            color: 'rgb(32, 139, 139)' // 轴标签颜色  
        },  
        linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
        tickfont: {  
            color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
        }
    },  
    title: "鳌虾运动轨迹",
    titlefont:{
        color:'rgb(32, 139, 139)'
    },
    plot_bgcolor: 'rgba(0,0,0,0)',
    paper_bgcolor:'rgba(0,0,0,0)'
    }
Plotly.newPlot("s2", trace, layout,{scrollZoom: true,editable: true })

其余的就是代码的排布顺序,异步执行调度,错误处理能力,系统稳定性,用户交互能力的提升,细节很多,均包含在代码中


右边的<section>容纳进度圈,轨迹图

        <section id="s2">
            <div id="progress-circle"><p id="progress-num"></p></div>
        </section>

在文件上传时,就初始化渲染进度条,然后异步请求获取进度,渲染到页面;当进度到达一定值,比如99%,就关闭获取进度的请求,同时设置进度条的display=none。当用户打断程序执行或者重新运行程序,就清理轨迹图,初始化进度条,循环往复。

后端设计

后端整体使用flask,jinjia模板,将flask与pywebview结合。把模型检测代码封装到一个类TrackShrimp,其余的就是各种请求函数。

代码总览

import webview
from flask import Flask, request, jsonify,render_template,stream_with_context,Response
import os
import time
import cv2
from onnxruntime import InferenceSession
import numpy as np
from werkzeug.utils import secure_filename

class TrackShrimp():
    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

    def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])
        

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img
 
    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list
    
    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)


app = Flask(__name__)
UPLOAD_FOLDER = './temp'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
schedule=0
run_task=False

def run_flask():
    app.run(debug=False, threaded=True,host='127.0.0.1',port=5000)

def video_process(video_path):
    return TrackShrimp(video_path,'./model.onnx').run()

# 主页面
@app.route('/',methods=['POST','GET'])
def return_main_page():
    return render_template('index.html')

# 检测视频页面
@app.route('/shrimp',methods=['POST'])
def shrimp_track():
    global run_task
    run_task=True
    file=request.files.get('video')
    filename = secure_filename(file.filename)
    video_path=os.path.join(app.config['UPLOAD_FOLDER'],filename)
    file.save(video_path)
    try:
        results=video_process(video_path)
        if results is not None:
            distance,position_data=results
            data = {
                'distance': distance,
                'position_data': position_data
            }
            run_task=False
            return jsonify(data)
        else:
            return jsonify({'data':'任务被终止'})
    except Exception as e:
        print('error:',e)
        return jsonify({'data':'任务被终止'})
    finally:
        if os.path.exists(video_path):
            os.remove(video_path)

@app.route('/stopRun',methods=['GET','POST'])
def stopRun():
    global run_task
    global schedule
    if run_task:
        run_task=False
        schedule=0
        return jsonify({'data':'正在停止任务'})
    else:
        return jsonify({'data':'当前没有任务运行'})

# 进度查询路由
@app.route('/progress',methods=['GET'])
def progress():
    @stream_with_context
    def generate():
        global run_task
        ratio = schedule
        while ratio < 95 and run_task:
            yield "data:" + str(ratio) + "\n\n"
            ratio = schedule
            time.sleep(5)
    return Response(generate(), mimetype='text/event-stream')

if __name__=='__main__':
    # 启动后端  
    # flask_thread = threading.Thread(target=run_flask)  
    # flask_thread.start()
    # time.sleep(1)
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=app,width=900,height=600)
    # webview.create_window('鳌虾轨迹侦测',url=f'http://127.0.0.1:5000',width=900,height=600)
    webview.start()

代码分览

一个onnx部署的类TrackShrimp,详细见前面。

一些常量的定义

app = Flask(__name__)
UPLOAD_FOLDER = './temp' # 文件的上传路径,后端需要该路径保留用户上传的文件
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
schedule=0 # 实时进度,初始化进度为0
run_task=False # 一个onnx模型是否在运行的标志,用于接收用户中断信号从而终止模型运行

定义一个flask的·启动函数,用于web调试,浏览器F12启动调试窗口

def run_flask():
    app.run(debug=False, threaded=True,host='127.0.0.1',port=5000)

主页面的请求函数,该页面为主要的UI

# 主页面
@app.route('/',methods=['POST','GET'])
def return_main_page():
    return render_template('index.html')

用户请求中断的请求函数

首先通过标志位(run_task)检测模型是否在跑,如果检测到模型正在运行,就把标志位设为False,然后把进度归0

@app.route('/stopRun',methods=['GET','POST'])
def stopRun():
    global run_task
    global schedule
    if run_task:
        run_task=False
        schedule=0
        return jsonify({'data':'正在停止任务'})
    else:
        return jsonify({'data':'当前没有任务运行'})

进度查询

这里设置当进度为95%时,就停止查询。

@app.route('/progress',methods=['GET'])
def progress():
    @stream_with_context
    def generate():
        global run_task
        ratio = schedule
        while ratio < 95 and run_task:
            yield "data:" + str(ratio) + "\n\n"
            ratio = schedule
            time.sleep(5)
    return Response(generate(), mimetype='text/event-stream')

一个检测的入口函数

def video_process(video_path):
    return TrackShrimp(video_path,'./model.onnx').run()

接收用户上传文件的函数

一旦用户上传文件,就设置运行标志位为True,然后将文件保存,再送入模型运行接口函数,当用户请求终止时,results为None,所以使用if else进行区分。模型结果出来后就把标志位设为False,同时将数据传到前端

@app.route('/shrimp',methods=['POST'])
def shrimp_track():
    global run_task
    run_task=True
    file=request.files.get('video')
    filename = secure_filename(file.filename)
    video_path=os.path.join(app.config['UPLOAD_FOLDER'],filename)
    file.save(video_path)
    try:
        results=video_process(video_path)
        if results is not None:
            distance,position_data=results
            data = {
                'distance': distance,
                'position_data': position_data
            }
            run_task=False
            return jsonify(data)
        else:
            return jsonify({'data':'任务被终止'})
    except Exception as e:
        print('error:',e)
        return jsonify({'data':'任务被终止'})
    finally:
        if os.path.exists(video_path):
            os.remove(video_path)

接着就是启动所有代码了,为了调试方便,我写了2份代码,一份用于调试,一份用于成品

if __name__=='__main__':
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=app,width=900,height=600)
    webview.start()
if __name__=='__main__':
    # 启动后端  
    flask_thread = threading.Thread(target=run_flask)  
    flask_thread.start()
    time.sleep(1)
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=f'http://127.0.0.1:5000',width=900,height=600)
    webview.start()

pyinstaller打包

进入项目目录,命令行输入

piinstaller -D -w main.py

找到生成的main.spec文件,按如下修改

# -*- mode: python ; coding: utf-8 -*-


a = Analysis(
    ['main.py'],
    pathex=[],
    binaries=[],
    datas=[('templates/','templates/'),('static/','static/'),('venv/Lib/site-packages/onnxruntime/capi/onnxruntime_providers_shared.dll','onnxruntime/capi/'),('venv/Lib/site-packages/onnxruntime/capi/onnxruntime_providers_cuda.dll','onnxruntime/capi/')],
    hiddenimports=[],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    noarchive=False,
    optimize=0,
)
pyz = PYZ(a.pure)

exe = EXE(
    pyz,
    a.scripts,
    [],
    exclude_binaries=True,
    name='main',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    console=False,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
    icon='1.ico'
)
coll = COLLECT(
    exe,
    a.binaries,
    a.datas,
    strip=False,
    upx=True,
    upx_exclude=[],
    name='main',
)

在项目目录下放置一个图标命名为1.ico,最好是48*48像素

然后命令行运行

pyinstaller main.spec

然后在venv中找到 onnxruntime_gpu-1.18.1.dist-info 文件夹,复制到 dist/main/_internal 中

同时在cuDNN中找到如下几个动态链接库,复制到 dist/main/_internal 中

cudnn_ops_infer64_8.dll
cudnn_cnn_infer64_8.dll
cudnn_adv_infer64_8.dll
cudnn64_8.dll
cudart64_110.dll
cublasLt64_11.dll
cublas64_11.dll
cufft64_10.dll

然后将model.onnx放到 dist/main/ ,并在该目录创建一个目录temp

最后处理的结果如下

XXX
    dist
        main
            _internal
            main.exe
            model.onnx
            temp

生成安装包

使用into setup软件,并在网站找到中文的语言包下载为 Chinese.isl 文件,放到intosetup软件安装目录的 Languages 文件夹下

接着如图所示

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

取消立即编译,先进入文件里修改一些东西

修改成下面这样 

点击编译

然后就生成了安装包,就可以在任何win10,win11电脑里用CPU跑了,如果安装的电脑 有显卡和CUDA并把CUDA添加到了环境变量,就可以用GPU跑了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1901069.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux Centos7部署Zookeeper

目录 一、下载zookeeper 二、单机部署 1、创建目录 2、解压 3、修改配置文件名 ​4、创建保存数据的文件夹 ​5、修改配置文件保存数据的地址 ​6、启动服务 7、api创建节点 一、下载zookeeper 地址&#xff1a;Index of /dist/zookeeper/zookeeper-3.5.7 (apache.org…

在5G/6G应用中实现高性能放大器的建模挑战

来源&#xff1a;Modelling Challenges for Enabling High Performance Amplifiers in 5G/6G Applications {第28届“集成电路和系统的混合设计”(Mixed Design of Integrated Circuits and Systems)国际会议论文集&#xff0c;2021年6月24日至26日&#xff0c;波兰洛迪} 本文讨…

Stream 很好,Map 很酷,但答应我别用 toMap()

文章目录 Collectors.toMap() 的常见问题替代方案1. 使用 Collectors.groupingBy()2. 处理空值3. 自定义合并逻辑 总结 &#x1f389;欢迎来到Java学习路线专栏~探索Java中的静态变量与实例变量 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&#x1f379;✨博客主页&#xff1a;IT陈寒的博…

键盘异常的检测与解决方案

今天对象用Word写文档&#xff0c;按下Ctrl的时候&#xff0c;页面不停地上下滑动&#xff0c;导致无法正常编辑文本。 重启之后&#xff0c;仍然无法解决&#xff0c;推断是键盘坏了。 但是当按下Fn或其他功能键&#xff0c;焦点移除&#xff0c;页面就不会再抖动了。 现在…

2.2.2.1 如何在vscode 中设置ROS2的 用户代码片段

1. vscode中设置C版本的ROS2用户代码片段 1) 找到vscode 下的设置选项&#xff0c;选择用户代码片段 2) 选择用户代码片段后&#xff0c;会弹出选择框&#xff0c;如下图&#xff0c;输入C,选择 cpp.json 配置好的文件 进入如下文件&#xff0c;下图为本人配置的代码片段模版文…

E1.【C语言】练习:用函数求两个整数的较大值

有关创建函数见&#xff1a; 12.【C语言】创建函数 写法 1&#xff1a;if语句 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> int max(int a, int b) {if (a > b)return a;elsereturn b; } int main() {int a 0;int b 0;scanf("%d%d", &a,…

数据可视化之智慧城市的脉动与洞察

在数字化转型的浪潮中,城市作为社会经济发展的核心单元,正经历着前所未有的变革。城市数据可视化大屏看板作为这一变革中的重要工具,不仅极大地提升了城市管理效率,还为公众提供了直观、全面的城市运行状态视图,成为智慧城市建设不可或缺的一部分。本文将深入探讨以“城市…

【MySQL04】【 redo 日志】

文章目录 一、前言二、redo 日志1. redo 日志格式2. Mini-Transaction2.1 以组的形式写入 redo 日志2.2 Mini-Transaction &#xff08;MTR&#xff09;概念 3. redo 日志写入过程3.1 redo 日志缓冲区3.3 redo 日志写入 log buffer 4. redo 日志文件4.1 redo 日志刷盘机制4.2 r…

实现桌面动态壁纸(二)

目录 前言 一、关于 WorkerW 工作区窗口 二、关于窗口关系 2.1 窗口以及窗口隶属关系 2.2 桌面管理层窗口组分简析 2.3 厘清两个概念的区别 2.4 关于设置父窗口 三、编写代码以供在 Vista 上实现 3.1 方法二&#xff1a;子类化并自绘窗口背景 四、初步分析桌面管理层…

qt 如何添加子项目

首先我们正常流程创建一个项目文件&#xff1a; 这是我已经创建好的&#xff0c;请无视红线 然后找到该项目的文件夹&#xff0c;在文件夹下创建一个文件夹&#xff0c;再到创建好的文件夹下面创建一个 .pri 文件&#xff1a; &#xff08;创建文件夹&#xff09; &#xff08…

自闭症在生活中的典型表现

自闭症&#xff0c;这个看似遥远却又悄然存在于我们周围的疾病&#xff0c;其影响深远且复杂。在日常生活中&#xff0c;自闭症患者的典型表现往往让人印象深刻&#xff0c;这些表现不仅揭示了他们内心的世界&#xff0c;也提醒我们要以更加包容和理解的心态去面对他们。 首先…

嵌入式C语言面试相关知识——关键字(不定期更新)

嵌入式C语言面试相关知识——关键字 一、博客声明二、C语言关键字1、sizeof关键字2、static关键字3、const关键字4、volatile关键字5、extern关键字 一、博客声明 又是一年一度的秋招&#xff0c;怎么能只刷笔试题目呢&#xff0c;面试题目也得看&#xff0c;想当好厂的牛马其实…

六、快速启动框架:SpringBoot3实战-个人版

六、快速启动框架&#xff1a;SpringBoot3实战 文章目录 六、快速启动框架&#xff1a;SpringBoot3实战一、SpringBoot3介绍1.1 SpringBoot3简介1.2 系统要求1.3 快速入门1.4 入门总结回顾复习 二、SpringBoot3配置文件2.1 统一配置管理概述2.2 属性配置文件使用2.3 YAML配置文…

前端面试题8

基础知识 解释一下什么是跨域问题&#xff0c;以及如何解决&#xff1f; 跨域问题是由于浏览器的同源策略限制了从一个源加载的网页脚本访问另一个源的数据。解决方法包括使用JSONP、CORS&#xff08;跨源资源共享&#xff09;、设置代理服务器等。 描述一下事件冒泡和事件捕获…

kubernetes集群部署:node节点部署和cri-docker运行时安装(四)

安装前准备 同《kubernetes集群部署&#xff1a;环境准备及master节点部署&#xff08;二&#xff09;》 安装cri-docker 在 Kubernetes 1.20 版本之前&#xff0c;Docker 是 Kubernetes 默认的容器运行时。然而&#xff0c;Kubernetes 社区决定在 Kubernetes 1.20 及以后的…

Spring中的事件监听器使用学习

一、什么是Spring中的事件监听机制&#xff1f; Spring框架中的事件监听机制是一种设计模式&#xff0c;它允许你定义和触发事件&#xff0c;同时允许其他组件监听这些事件并在事件发生时作出响应。这种机制基于观察者模式&#xff0c;提供了一种松耦合的方式来实现组件间的通信…

自动缩放 win7 远程桌面

https://mremoteng.org/download 用这个软件&#xff0c;下载 zip 版&#xff0c;不需要管理员权限 在这里找到的&#xff0c;选票最高的一个就是 https://superuser.com/questions/1030041/remote-desktop-zoom-and-full-screen-how-win10-remote-win7-2008-2003-ho

蓝桥杯开发板STM32G431RBT6高阶HAL库学习FreeRtos——认识HAL_Delay和osDelay的区别

一、修改两个任务的优先级 任务一 任务二 二、使用HAL_Delay的实验结果 结果&#xff1a; LED1亮&#xff0c;LED2不亮 三、使用osDelay的实验结果 结果&#xff1a; LED1亮&#xff0c;LED2亮 四、解释原因 vTaskDelay 与 HAL_Delay 的区别 1.vTaskDelay 作用是让任务阻…

基于RK3588的8路摄像头实时全景拼接

基于RK3588的8路摄像头实时全景拼接 输入&#xff1a;2路csi转8路mpi的ahd摄像头&#xff0c;分辨率1920 * 1080 8路拼接结果&#xff1a; 6路拼接结果&#xff1a; UI界面&#xff1a; UI节目设计原理

Python爬虫获取视频

验证电脑是否安装python 1.winr输入cmd 2.在黑窗口输入 python.exe 3.不是命令不存在就说明python环境安装完成 抓取快手视频 1.在phcharm应用中新建一个项目 3.新建一个python文件 4.选择python文件,随便起一个名字后按回车 5.安装requests pip install requests 6.寻找需要的…