英伟达SSD视觉算法分类代码解析

news2025/1/16 17:31:01

一、官方原代码

#!/usr/bin/env python3
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#

import sys
import argparse

from jetson_inference import imageNet
from jetson_utils import videoSource, videoOutput, cudaFont, Log

# parse the command line
parser = argparse.ArgumentParser(description="Classify a live camera stream using an image recognition DNN.", 
                                 formatter_class=argparse.RawTextHelpFormatter, 
                                 epilog=imageNet.Usage() + videoSource.Usage() + videoOutput.Usage() + Log.Usage())

parser.add_argument("input", type=str, default="", nargs='?', help="URI of the input stream")
parser.add_argument("output", type=str, default="", nargs='?', help="URI of the output stream")
parser.add_argument("--network", type=str, default="googlenet", help="pre-trained model to load (see below for options)")
parser.add_argument("--topK", type=int, default=1, help="show the topK number of class predictions (default: 1)")

try:
	args = parser.parse_known_args()[0]
except:
	print("")
	parser.print_help()
	sys.exit(0)


# load the recognition network
net = imageNet(args.network, sys.argv)

# note: to hard-code the paths to load a model, the following API can be used:
#
# net = imageNet(model="model/resnet18.onnx", labels="model/labels.txt", 
#                 input_blob="input_0", output_blob="output_0")

# create video sources & outputs
input = videoSource(args.input, argv=sys.argv)
output = videoOutput(args.output, argv=sys.argv)
font = cudaFont()

# process frames until EOS or the user exits
while True:
    # capture the next image
    img = input.Capture()

    if img is None: # timeout
        continue  

    # classify the image and get the topK predictions
    # if you only want the top class, you can simply run:
    #   class_id, confidence = net.Classify(img)
    predictions = net.Classify(img, topK=args.topK)

    # draw predicted class labels
    for n, (classID, confidence) in enumerate(predictions):
        classLabel = net.GetClassLabel(classID)
        confidence *= 100.0

        print(f"imagenet:  {confidence:05.2f}% class #{classID} ({classLabel})")

        font.OverlayText(img, text=f"{confidence:05.2f}% {classLabel}", 
                         x=5, y=5 + n * (font.GetSize() + 5),
                         color=font.White, background=font.Gray40)
                         
    # render the image
    output.Render(img)

    # update the title bar
    output.SetStatus("{:s} | Network {:.0f} FPS".format(net.GetNetworkName(), net.GetNetworkFPS()))

    # print out performance info
    net.PrintProfilerTimes()

    # exit on input/output EOS
    if not input.IsStreaming() or not output.IsStreaming():
        break

二、代码解析

代码增加中文注释

#!/usr/bin/env python3
#
# 版权所有 (c) 2020, NVIDIA CORPORATION. 保留所有权利。
#
# 特此免费授予获得此软件和相关文档文件(“软件”)副本的任何人,允许他们在不受限制的情况下处理软件,
# 包括但不限于使用、复制、修改、合并、发布、分发、再许可和/或出售软件副本,并允许提供软件的人
# 这样做,条件如下:
#
# 上述版权声明和本许可声明应包含在软件的所有副本或主要部分中。
#
# 本软件按“原样”提供,不提供任何形式的明示或暗示保证,包括但不限于适销性、
# 适用于特定目的和不侵权的保证。在任何情况下,作者或版权持有人均不对因使用本软件或其他交易,
# 或因使用本软件或其他交易而产生的任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权诉讼还是其他诉讼中。
#

import sys
import argparse

from jetson_inference import imageNet
from jetson_utils import videoSource, videoOutput, cudaFont, Log

# 解析命令行参数
parser = argparse.ArgumentParser(
    description="使用图像识别DNN对实时摄像头流进行分类。",
    formatter_class=argparse.RawTextHelpFormatter,
    epilog=imageNet.Usage() + videoSource.Usage() + videoOutput.Usage() + Log.Usage()
)

parser.add_argument("input", type=str, default="", nargs='?', help="输入流的URI")
parser.add_argument("output", type=str, default="", nargs='?', help="输出流的URI")
parser.add_argument("--network", type=str, default="googlenet", help="要加载的预训练模型(参见下方选项)")
parser.add_argument("--topK", type=int, default=1, help="显示前K个类别预测(默认:1)")

try:
    args = parser.parse_known_args()[0]
except:
    print("")
    parser.print_help()
    sys.exit(0)

# 加载识别网络
net = imageNet(args.network, sys.argv)

# 注意:要硬编码加载模型的路径,可以使用以下API:
# net = imageNet(model="model/resnet18.onnx", labels="model/labels.txt", 
#                input_blob="input_0", output_blob="output_0")

# 创建视频源和输出
input = videoSource(args.input, argv=sys.argv)
output = videoOutput(args.output, argv=sys.argv)
font = cudaFont()

# 处理帧直到输入结束或用户退出
while True:
    # 捕获下一帧图像
    img = input.Capture()

    if img is None:  # 超时
        continue  

    # 对图像进行分类并获取前K个预测
    # 如果只需要最顶层的类别,可以简单地运行:
    #   class_id, confidence = net.Classify(img)
    predictions = net.Classify(img, topK=args.topK)

    # 绘制预测的类别标签
    for n, (classID, confidence) in enumerate(predictions):
        classLabel = net.GetClassLabel(classID)
        confidence *= 100.0

        print(f"imagenet:  {confidence:05.2f}% class #{classID} ({classLabel})")

        font.OverlayText(
            img, 
            text=f"{confidence:05.2f}% {classLabel}", 
            x=5, y=5 + n * (font.GetSize() + 5),
            color=font.White, background=font.Gray40
        )
                         
    # 渲染图像
    output.Render(img)

    # 更新标题栏
    output.SetStatus("{:s} | Network {:.0f} FPS".format(net.GetNetworkName(), net.GetNetworkFPS()))

    # 打印性能信息
    net.PrintProfilerTimes()

    # 输入/输出流结束时退出
    if not input.IsStreaming() or not output.IsStreaming():
        break

这段Python代码是一个使用NVIDIA的Jetson平台进行图像分类的示例程序。代码解析如下:

头部版权声明和许可信息

这部分代码声明了版权信息和软件许可,允许免费使用、复制和分发软件。

导入模块

import sys
import argparse
from jetson_inference import imageNet
from jetson_utils import videoSource, videoOutput, cudaFont, Log
  • sys: 处理系统特定的参数和功能。
  • argparse: 解析命令行参数。
  • jetson_inferencejetson_utils模块用于加载和处理图像分类模型、视频源、视频输出、绘制字体和日志记录。

解析命令行参数

parser = argparse.ArgumentParser(description="Classify a live camera stream using an image recognition DNN.", 
                                 formatter_class=argparse.RawTextHelpFormatter, 
                                 epilog=imageNet.Usage() + videoSource.Usage() + videoOutput.Usage() + Log.Usage())

parser.add_argument("input", type=str, default="", nargs='?', help="URI of the input stream")
parser.add_argument("output", type=str, default="", nargs='?', help="URI of the output stream")
parser.add_argument("--network", type=str, default="googlenet", help="pre-trained model to load (see below for options)")
parser.add_argument("--topK", type=int, default=1, help="show the topK number of class predictions (default: 1)")

try:
	args = parser.parse_known_args()[0]
except:
	print("")
	parser.print_help()
	sys.exit(0)
  • 使用argparse模块定义和解析命令行参数,包括输入和输出流的URI、使用的预训练模型和显示前K个预测结果的数量。
  • 尝试解析命令行参数,如果解析失败,则显示帮助信息并退出程序。

加载图像分类网络

net = imageNet(args.network, sys.argv)
  • 使用imageNet类加载预训练的神经网络模型。

创建视频源和视频输出

input = videoSource(args.input, argv=sys.argv)
output = videoOutput(args.output, argv=sys.argv)
font = cudaFont()
  • 使用videoSource类创建视频输入流。
  • 使用videoOutput类创建视频输出流。
  • 使用cudaFont类创建用于绘制文本的字体。

处理视频帧

while True:
    # capture the next image
    img = input.Capture()

    if img is None: # timeout
        continue  

    # classify the image and get the topK predictions
    predictions = net.Classify(img, topK=args.topK)

    # draw predicted class labels
    for n, (classID, confidence) in enumerate(predictions):
        classLabel = net.GetClassLabel(classID)
        confidence *= 100.0

        print(f"imagenet:  {confidence:05.2f}% class #{classID} ({classLabel})")

        font.OverlayText(img, text=f"{confidence:05.2f}% {classLabel}", 
                         x=5, y=5 + n * (font.GetSize() + 5),
                         color=font.White, background=font.Gray40)
                         
    # render the image
    output.Render(img)

    # update the title bar
    output.SetStatus("{:s} | Network {:.0f} FPS".format(net.GetNetworkName(), net.GetNetworkFPS()))

    # print out performance info
    net.PrintProfilerTimes()

    # exit on input/output EOS
    if not input.IsStreaming() or not output.IsStreaming():
        break
  • 使用input.Capture()捕获下一帧图像。
  • 如果捕获失败,则继续下一次循环。
  • 使用net.Classify()对图像进行分类,并获取前K个预测结果。
  • 使用font.OverlayText()在图像上绘制预测的类别标签和置信度。
  • 使用output.Render()渲染图像。
  • 使用output.SetStatus()更新输出窗口的标题栏,显示网络名称和帧率。
  • 使用net.PrintProfilerTimes()打印性能信息。
  • 如果输入或输出流停止,则退出循环。

总体来说,这段代码实现了一个实时图像分类系统,使用Jetson平台的预训练模型对视频流中的每一帧进行分类,并在输出视频中叠加预测结果。



# 三、imageNet类有哪些对象和方法
`imageNet`类是Jetson Inference库的一部分,用于处理图像分类任务。以下是`imageNet`类的主要对象和方法:

## 对象属性
- `NetworkType`: 预定义的网络类型枚举。
- `model`: 加载的模型文件路径。
- `labels`: 加载的标签文件路径。
- `input_blob`: 输入层名称。
- `output_blob`: 输出层名称。

## 主要方法

1. **`__init__`**:
   ```python
   def __init__(self, network="googlenet", argv=[], model=None, labels=None, input_blob="input_0", output_blob="output_0"):
  • 初始化图像分类网络。
  • 参数包括网络类型、命令行参数、模型文件路径、标签文件路径、输入层名称和输出层名称。
  1. Classify:

    def Classify(self, img, topK=1):
    
    • 对给定的图像进行分类。
    • 返回预测结果的列表,每个结果包括类别ID和置信度。
  2. GetClassLabel:

    def GetClassLabel(self, classID):
    
    • 返回指定类别ID的标签。
  3. GetNetworkName:

    def GetNetworkName(self):
    
    • 返回网络的名称。
  4. GetNetworkFPS:

    def GetNetworkFPS(self):
    
    • 返回网络处理帧率(FPS)。
  5. PrintProfilerTimes:

    def PrintProfilerTimes(self):
    
    • 打印网络的性能分析信息。
  6. Usage:

    @staticmethod
    def Usage():
    
    • 返回类的用法说明,通常用于命令行帮助信息。

使用示例

以下是如何使用imageNet类的简单示例:

from jetson_inference import imageNet
from jetson_utils import loadImage

# 初始化图像分类网络
net = imageNet("googlenet")

# 加载图像
img = loadImage("example.jpg")

# 对图像进行分类
class_id, confidence = net.Classify(img)

# 获取类别标签
class_label = net.GetClassLabel(class_id)

print(f"Image is classified as {class_label} with {confidence * 100:.2f}% confidence")

这个示例展示了如何初始化一个imageNet对象,加载一张图像,并对其进行分类,最后打印分类结果和置信度。

三、使用示例

 python3 imagenet.py /dev/video0 display://0 --network=googlenet

在这里插入图片描述
在这里插入图片描述

四、训练自己的分类模型

以下是包含中文注释的SSD训练代码示例:

1. 安装TAO Toolkit

确保在具有NVIDIA GPU的系统上安装了Docker和NVIDIA Container Toolkit。

2. 拉取TAO Toolkit Docker容器

docker pull nvcr.io/nvidia/tao/tao-toolkit-tf:v3.21.11-tf1.15.5-py3

3. 准备数据

准备训练和验证数据,数据应按照Kitti或Pascal VOC格式组织,包含图像文件和对应的标注文件。

4. 创建SSD配置文件

以下是SSD配置文件的示例,并包含中文注释:

random_seed: 42  # 随机种子,用于确保实验的可重复性
dataset_config {
  data_sources: {
    label_directory_path: "/path/to/labels"  # 训练数据的标签路径
    image_directory_path: "/path/to/images"  # 训练数据的图像路径
  }
  validation_data_sources: {
    label_directory_path: "/path/to/val_labels"  # 验证数据的标签路径
    image_directory_path: "/path/to/val_images"  # 验证数据的图像路径
  }
}
model_config {
  pretrained_model_file: "/path/to/pretrained/model"  # 预训练模型文件路径
  num_layers: 18  # 模型的层数
  all_proposals: 200  # 所有提议框的数量
}
train_config {
  batch_size: 8  # 批次大小
  learning_rate: 0.001  # 学习率
  num_epochs: 80  # 训练轮数
  augmentations: {
    horizontal_flip: true  # 是否进行水平翻转数据增强
    vertical_flip: false  # 是否进行垂直翻转数据增强
  }
}

5. 运行训练

使用以下命令运行训练任务,并包含中文注释:

docker run --gpus all -v /path/to/your/data:/data -v /path/to/your/config:/config -v /path/to/your/output:/output nvcr.io/nvidia/tao/tao-toolkit-tf:v3.21.11-tf1.15.5-py3 ssd train \
  -e /config/ssd_config.yaml \  # 配置文件路径
  -r /output/experiment_dir \  # 实验输出目录
  -k $API_KEY  # TAO Toolkit的API密钥
  • --gpus all: 使用所有可用的GPU。
  • -v /path/to/your/data:/data: 将本地数据目录挂载到容器内的/data路径。
  • -v /path/to/your/config:/config: 将本地配置文件目录挂载到容器内的/config路径。
  • -v /path/to/your/output:/output: 将本地输出目录挂载到容器内的/output路径。
  • -e /config/ssd_config.yaml: 指定配置文件。
  • -r /output/experiment_dir: 指定实验输出目录。
  • -k $API_KEY: 指定TAO Toolkit的API密钥。

6. 导出模型

训练完成后,使用以下命令导出模型,并包含中文注释:

docker run --gpus all -v /path/to/your/output:/output nvcr.io/nvidia/tao/tao-toolkit-tf:v3.21.11-tf1.15.5-py3 ssd export \
  -m /output/experiment_dir/model.tlt \  # 输入的TAO模型路径
  -o /output/experiment_dir/model.etlt \  # 输出的ETLT模型路径
  -k $API_KEY  # TAO Toolkit的API密钥

总结

通过TAO Toolkit,你可以方便地对SSD目标检测模型进行训练。准备数据、配置训练参数并运行训练命令,可以帮助你快速训练自定义的目标检测模型并进行部署。详细的指南和更多高级功能可以参考TAO Toolkit的官方文档。

这样,代码和配置文件中都增加了中文注释,便于理解和使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1812757.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机毕业设计 | SSM 校园线上订餐系统 外卖购物网站(附源码)

1, 概述 1.1 项目背景 传统的外卖方式就是打电话预定,然而,在这种方式中,顾客往往通过餐厅散发的传单来获取餐厅的相关信息,通过电话来传达自己的订单信息,餐厅方面通过电话接受订单后,一般通…

UPerNet 统一感知解析:场景理解的新视角 Unified Perceptual Parsing for Scene Understanding

论文题目:统一感知解析:场景理解的新视角 Unified Perceptual Parsing for Scene Understanding 论文链接:http://arxiv.org/abs/1807.10221(ECCV 2018) 代码链接:https://github.com/CSAILVision/unifiedparsing 一、摘要 研究…

Java多线程之不可变对象(Immutable Object)模式

简介 多线程共享变量的情况下,为了保证数据一致性,往往需要对这些变量的访问进行加锁。而锁本身又会带来一些问题和开销。Immutable Object模式使得我们可以在不加锁的情况下,既保证共享变量访问的线程安全,又能避免引入锁可能带…

如何用二维码进行来访登记?这个模板帮你轻松实现!

在工厂、学校、写字楼、建筑工地等人员出入频繁的场所,使用传统的纸质登记方法容易造成数据丢失,而且信息核对过程繁琐,效率低下。 可以用二维码代替纸质登记本,访客进入时扫码就能登记身份信息,能够提高门岗访客管理…

微生信神助力:在线绘制发表级主成分分析(PCA)图

主成分分析(Principal components analysis,PCA)是一种线性降维方法。它利用正交变换对一系列可能相关的变量的观测值进行线性变换,从而投影为一系列线性不相关变量的值,这些不相关变量称为主成分(Principa…

JMH309【亲测】典藏3D魔幻端游【剑踪3DⅢ】GM工具+开区合区工具+PC客户端+配置修改教程+Win一键服务端+详细外网视频教程

资源介绍: 经典不错的一款端游 GM工具开区合区工具PC客户端配置修改教程Win一键服务端详细外网视频教程 资源截图: 下载地址

数字化医疗:揭秘物联网如何提升医院设备管理效率!

在当今数字化时代,医疗领域正迎来一场技术变革的浪潮,而基于物联网的智慧医院医疗设备管理体系正是这场变革的闪耀之星。想象一下,医院里的每一台医疗设备都能像一位精密的工匠一样,自动监测、精准诊断,甚至在发生故障…

GitLab教程(三):多人合作场景下如何pull代码和处理冲突

文章目录 1.拉取别人同步的代码到本地的流程2.push冲突发生场景情景模拟简单的解决方法 在这一章中,为了模拟多人合作的场景,我需要一个人分饰两角。 执行git clone xx远端仓库地址 xx文件夹命令,在clone代码时指定本地仓库的文件夹名&#…

33.星号三角阵(二)

上海市计算机学会竞赛平台 | YACSYACS 是由上海市计算机学会于2019年发起的活动,旨在激发青少年对学习人工智能与算法设计的热情与兴趣,提升青少年科学素养,引导青少年投身创新发现和科研实践活动。https://www.iai.sh.cn/problem/742 题目描述 给定一个整数 𝑛,输出一个…

解决:RuntimeError: “slow_conv2d_cpu“ not implemented for ‘Half‘的方法之一

1. 问题描述 今天跑实验的时候,代码报错: RuntimeError: "slow_conv2d_cpu" not implemented for Half 感觉有点莫名奇妙,经检索,发现将fp16改为fp32可以解决我的问题,但是运行速度太慢了。后来发现&…

基于WPF技术的换热站智能监控系统02--标题栏实现

1、布局划分 2、准备图片资源 3、界面UI控件 4、窗体拖动和关闭 5、运行效果 走过路过不要错过,点赞关注收藏又圈粉,共同致富,为财务自由作出贡献

理解线程安全:保护你的代码免受并发问题困扰

目录 前言 一、什么是线程安全? 二、为什么需要线程安全? 三、实现线程安全的方法 四、synchronized 使用 synchronized 关键字时,需要注意以下几点: 五、Demo讲解 前言 在现代软件开发中,尤其是在多线程编程中&…

【源码】二开版微盘交易系统/贵金属交易平台/微交易系统

二开版微盘交易系统/贵金属交易平台/微交易系统 一套二开前端UI得贵金属微交易系统,前端产品后台可任意更换 此系统框架不是以往的至尊的框架,系统完美运行,K线采用nodejs方式运行 K线结算都正常,附带教程 资源来源:https://www.…

C++ UML建模

starUML UML图转C代码 数据流图 E-R图 流程图 整体架构图 ORM关系图 参考 app.asar附件资源可免激活 JHBlog/设计模式/设计模式/1、StarUML使用简明教程.md at master SunshineBrother/JHBlog GitHub GitHub - dimon4ezzz/whitestaruml: UML modeling tool derived from …

汇编语言期末复习

目录 前言 基础知识 80x86计算机组织 80x86的寻址方式 前言 根据老师的PPT与IBM-PC汇编语言程序设计(第2版)而写,供考前突击所用。 基础知识 q 机器语言、汇编语言、高级程序语言 特性 比较 q 进位记数制与不同基数的数之间的转换 二进…

可变参数以及不可变集合

可变参数: 格式: public class ArgsDemo {public static void main(String[] args) {System.out.println(getSum(1,2,3,4,5));}//可变参数public static int getSum(int...args){int sum 0;for (int arg : args) {sum arg;}return sum;} }可变参数的…

笨蛋学算法之LeetCodeHot100_1_两数之和(Java)

package com.lsy.leetcodehot100;public class _Hot1_两数之和 {//自写方法public static int[] twoSum1(int[] nums, int target) {//定义存放返回变量的数组int[] arr new int[2];//遍历整个数组for (int i 0; i < nums.length; i) {//从第二个数开始相加判断for (int j…

RK3588 Debian11进行源码编译安装Pyqt5

RK3588 Debian11进行源码编译安装Pyqt5 参考链接 https://blog.csdn.net/qq_38184409/article/details/137047584?ops_request_misc%257B%2522request%255Fid%2522%253A%2522171808774816800222841743%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&…

SpringBoot内置数据源

回顾: 在我们之前学习在配置文件当中配置对应的数据源的时候, 我们设置的数据源其实都是Druid的数据源, 并且其配置有两种方式, 当然这两种方式都需要我们导入对应的有关 德鲁伊 的依赖才行 一种是直接在开始设置为 druid 数据源类型的一种是在对应的正常的数据库配置下, 设置…

51 USART数据收发

1.0 USART实现单个数据收发 串口启动之前需要对串口进行初始化&#xff0c;主要是设置产生波特率的定时器1&#xff0c;使用串口的工作方式还是中断的工作方式具体的配置步骤如下所示。 注&#xff1a; 1&#xff1a; 确定TMOD &#xff08;定时器模式寄存器&#xff09; 确…