Object Detection with RetinaNet and the TensorFlow Object Detection API (Source Code Included)

2024/12/26 14:10:26

Contents

  • I. How RetinaNet Works
  • II. Implementing RetinaNet
    • 1. Introduction to tf.train.Checkpoint
    • 2. TensorFlow Source Code for RetinaNet

I. How RetinaNet Works

To be added.

II. Implementing RetinaNet

1. Introduction to tf.train.Checkpoint

To be added.

2. TensorFlow Source Code for RetinaNet

  Step 1: Install the TensorFlow 2 Object Detection API and its dependencies

# Remove any existing models directory
!rm -rf ./models/
# Clone the TensorFlow Model Garden
!git clone --depth 1 https://github.com/tensorflow/models/
# Compile the Object Detection API protocol buffers
!cd models/research/ && protoc object_detection/protos/*.proto --python_out=.

%%writefile models/research/setup.py
import os
from setuptools import find_packages
from setuptools import setup

REQUIRED_PACKAGES = [
    'tf-models-official==2.8.0',
    'tensorflow_io==0.24.0',
    'numpy==1.21.5'
]

setup(
    name='object_detection',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    include_package_data=True,
    packages=(
        [p for p in find_packages() if p.startswith('object_detection')] +
        find_packages(where=os.path.join('.', 'slim'))),
    package_dir={
        'datasets': os.path.join('slim', 'datasets'),
        'nets': os.path.join('slim', 'nets'),
        'preprocessing': os.path.join('slim', 'preprocessing'),
        'deployment': os.path.join('slim', 'deployment'),
        'scripts': os.path.join('slim', 'scripts'),
    },
    description='Tensorflow Object Detection Library',
    python_requires='>3.6',
)

# Run the setup script you just wrote
!python -m pip install models/research

  Step 2: Import packages

import matplotlib
import matplotlib.pyplot as plt

import os
import random
import io
import imageio
import glob
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display, Javascript
from IPython.display import Image as IPyImage

import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import colab_utils
from object_detection.builders import model_builder

%matplotlib inline

  Step 3: Define helper functions for loading images and plotting detections

def load_image_into_numpy_array(path):
  """Load an image from file into a numpy array.

  Puts image into numpy array to feed into tensorflow graph.
  Note that by convention we put it into a numpy array with shape
  (height, width, channels), where channels=3 for RGB.

  Args:
    path: a file path.

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3)
  """
  img_data = tf.io.gfile.GFile(path, 'rb').read()
  image = Image.open(BytesIO(img_data))
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

def plot_detections(image_np,
                    boxes,
                    classes,
                    scores,
                    category_index,
                    figsize=(12, 16),
                    image_name=None):
  """Wrapper function to visualize detections.

  Args:
    image_np: uint8 numpy array with shape (img_height, img_width, 3)
    boxes: a numpy array of shape [N, 4]
    classes: a numpy array of shape [N]. Note that class indices are 1-based,
      and match the keys in the label map.
    scores: a numpy array of shape [N] or None.  If scores=None, then
      this function assumes that the boxes to be plotted are groundtruth
      boxes and plot all boxes as black with no classes or scores.
    category_index: a dict containing category dictionaries (each holding
      category index `id` and category name `name`) keyed by category indices.
    figsize: size for the figure.
    image_name: a name for the image file.
  """
  image_np_with_annotations = image_np.copy()
  viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_annotations,
      boxes,
      classes,
      scores,
      category_index,
      use_normalized_coordinates=True,
      min_score_thresh=0.8)
  if image_name:
    plt.imsave(image_name, image_np_with_annotations)
  else:
    plt.imshow(image_np_with_annotations)
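
plot_detections passes min_score_thresh=0.8 to the visualization utility, so low-scoring boxes are not drawn. The idea can be sketched in plain Python as follows. The function name filter_by_score and the sample detections are hypothetical, and the strict "greater than" comparison is an assumption about how the threshold is applied.

```python
def filter_by_score(boxes, classes, scores, min_score_thresh=0.8):
    """Keep only the detections whose score exceeds min_score_thresh."""
    return [
        (box, cls, score)
        for box, cls, score in zip(boxes, classes, scores)
        if score > min_score_thresh  # assumption: strict comparison
    ]


# Hypothetical detections with normalized [ymin, xmin, ymax, xmax] boxes.
boxes = [[0.1, 0.1, 0.5, 0.5], [0.2, 0.3, 0.9, 0.8], [0.0, 0.0, 0.2, 0.2]]
classes = [1, 1, 1]
scores = [0.95, 0.40, 0.85]

kept = filter_by_score(boxes, classes, scores)
print(len(kept))  # 2: the detection scored 0.40 is dropped
```

This is also why Step 9 gives the ground-truth boxes a dummy score of 1.0: any score at or below the threshold would hide them.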

  Step 4: Download the training images (the training-zombie set is used as the example here)

# download the images
!wget --no-check-certificate \
    https://storage.googleapis.com/tensorflow-3-public/datasets/training-zombie.zip \
    -O ./training-zombie.zip

import zipfile
# unzip to a local directory
local_zip = './training-zombie.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('./training')
zip_ref.close()
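
The extraction above can also be written with a context manager, so the archive is closed even if extraction raises an error. This is a minimal sketch: the helper name extract_zip and the throwaway demo archive are hypothetical and only serve to make the example self-contained.

```python
import os
import tempfile
import zipfile


def extract_zip(zip_path, target_dir):
    """Extract every member of zip_path into target_dir and return the member names."""
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(target_dir)
        return zip_ref.namelist()


# Demonstration on a small archive created on the fly (not the real dataset).
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = os.path.join(tmp, 'demo.zip')
    with zipfile.ZipFile(demo_zip, 'w') as zf:
        zf.writestr('training-zombie1.jpg', b'placeholder bytes, not a real image')
    names = extract_zip(demo_zip, os.path.join(tmp, 'training'))
    extracted_ok = os.path.exists(os.path.join(tmp, 'training', names[0]))
    print(names, extracted_ok)
```

For the dataset in this step, the equivalent call would be extract_zip('./training-zombie.zip', './training').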

  Step 5: Set the training image path, build the list of training images, and display the samples

train_image_dir = './training'
train_image_name = 'training-zombie'

# Load images and visualize
train_images_np = []
for i in range(1, 6):
  image_path = os.path.join(train_image_dir, train_image_name + str(i) + '.jpg')
  train_images_np.append(load_image_into_numpy_array(image_path))

plt.rcParams['axes.grid'] = False
plt.rcParams['xtick.labelsize'] = False
plt.rcParams['ytick.labelsize'] = False
plt.rcParams['xtick.top'] = False
plt.rcParams['xtick.bottom'] = False
plt.rcParams['ytick.left'] = False
plt.rcParams['ytick.right'] = False
plt.rcParams['figure.figsize'] = [14, 7]

for idx, train_image_np in enumerate(train_images_np):
  plt.subplot(2, 3, idx+1)  # 2, 3 -> 1, 5
  plt.imshow(train_image_np)
plt.show()  # display the samples

  The sample images are displayed as follows:
[Image: sample training images]
  Step 6: Initialize the box coordinates (ground-truth boxes annotated by hand, used for training)

gt_boxes = [
        np.array([[0.27333333, 0.41500586, 0.74333333, 0.57678781]], dtype=np.float32),
        np.array([[0.29833333, 0.45955451, 0.75666667, 0.61078546]], dtype=np.float32),
        np.array([[0.40833333, 0.18288394, 0.945, 0.34818288]], dtype=np.float32),
        np.array([[0.16166667, 0.61899179, 0.8, 0.91910903]], dtype=np.float32),
        np.array([[0.28833333, 0.12543962, 0.835, 0.35052755]], dtype=np.float32),
]
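
Each box above is stored as [ymin, xmin, ymax, xmax] in normalized coordinates, that is, as fractions of the image height and width (plot_detections draws them with use_normalized_coordinates=True). A minimal sketch of how a hand-annotated pixel box could be converted into this form is shown below. The function name and the 600x800 image size are hypothetical.

```python
def pixel_box_to_normalized(ymin, xmin, ymax, xmax, img_height, img_width):
    """Convert a pixel-coordinate box to normalized [ymin, xmin, ymax, xmax]."""
    return [
        ymin / img_height,
        xmin / img_width,
        ymax / img_height,
        xmax / img_width,
    ]


# Hypothetical annotation on a 600 (height) x 800 (width) image.
normalized_box = pixel_box_to_normalized(150, 200, 450, 600,
                                         img_height=600, img_width=800)
print(normalized_box)  # [0.25, 0.25, 0.75, 0.75]
```

Normalized boxes stay valid when the model resizes the image to 640x640, which is why the annotations do not need to be rescaled.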

  Step 7: Define the label and category of the target object. Because only one kind of object is detected, the number of classes is 1

zombie_class_id = 1
num_classes = 1

category_index = {zombie_class_id: {'id': zombie_class_id, 'name': 'zombie'}}

  Step 8: Convert the training data to tensors (the data format TensorFlow works with)

label_id_offset = 1
train_image_tensors = []
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np) in zip(
    train_images_np, gt_boxes):
  train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(
      train_image_np, dtype=tf.float32), axis=0))
  gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))
  zero_indexed_groundtruth_classes = tf.convert_to_tensor(
      np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)
  gt_classes_one_hot_tensors.append(tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes))
print('Done prepping data.')
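
The class handling in this step has two parts: the 1-based class id is shifted by label_id_offset to a 0-based index, and that index is then one-hot encoded. A plain-Python sketch of the same logic follows. The helper names are hypothetical, and the three-class call is included only to show the general case.

```python
label_id_offset = 1  # class ids in the label map start at 1


def one_hot(index, depth):
    """Return a list of length depth with 1.0 at position index."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


def encode_labels(class_ids, num_classes):
    """Shift 1-based class ids to 0-based indices, then one-hot encode them."""
    return [one_hot(class_id - label_id_offset, num_classes)
            for class_id in class_ids]


print(encode_labels([1], num_classes=1))     # [[1.0]]
print(encode_labels([1, 3], num_classes=3))  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
```

With a single class, every ground-truth box therefore gets the one-hot vector [1.0].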

  Step 9: Display the prepared training images with their boxes (during preprocessing, check repeatedly that the data is correct)

dummy_scores = np.array([1.0], dtype=np.float32)  # give boxes a score of 100%

plt.figure(figsize=(30, 15))
for idx in range(5):
  plt.subplot(2, 3, idx+1)
  plot_detections(
      train_images_np[idx],
      gt_boxes[idx],
      np.ones(shape=[gt_boxes[idx].shape[0]], dtype=np.int32),
      dummy_scores, category_index)
plt.show()

  The result is displayed as follows:
[Image: training images with ground-truth boxes]
  Step 10: Download the RetinaNet model (the SSD ResNet50 V1 FPN 640x640 checkpoint)

!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!tar -xf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!mv ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint models/research/object_detection/test_data/

  Step 11: Load the model, modify it (mainly the number of object classes to detect), and initialize the weights (by running a prediction on dummy data)

tf.keras.backend.clear_session()

print('Building model and restoring weights for fine-tuning...', flush=True)
num_classes = 1
pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'

# Load pipeline config and build a detection model.
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
      model_config=model_config, is_training=True)

fake_box_predictor = tf.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.train.Checkpoint(model=fake_model)
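# Only part of the checkpoint is restored (the classification head is left
# out), so expect_partial() silences warnings about unused checkpoint values.
ckpt.restore(checkpoint_path).expect_partial()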

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')

  Step 12: Define train_step and the training loop

tf.keras.backend.set_learning_phase(True)

# Training hyperparameters
batch_size = 4
learning_rate = 0.01
num_batches = 100

# Select the variables to fine-tune from the model
trainable_variables = detection_model.trainable_variables
to_fine_tune = []
prefixes_to_train = [
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']
for var in trainable_variables:
  if any([var.name.startswith(prefix) for prefix in prefixes_to_train]):
    to_fine_tune.append(var)

# train_step.
def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
  """Get a tf.function for training step."""
  @tf.function
  def train_step_fn(image_tensors,
                    groundtruth_boxes_list,
                    groundtruth_classes_list):
    """A single training iteration.
    Args:
      image_tensors: A list of [1, height, width, 3] Tensor of type tf.float32.
        Note that the height and width can vary across images, as they are
        reshaped within this function to be 640x640.
      groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with type
        tf.float32 representing groundtruth boxes for each image in the batch.
      groundtruth_classes_list: A list of Tensors of shape [N_i, num_classes]
        with type tf.float32 representing one-hot groundtruth classes for each
        image in the batch.

    Returns:
      A scalar tensor representing the total loss for the input batch.
    """
    shapes = tf.constant(batch_size * [[640, 640, 3]], dtype=tf.int32)
    model.provide_groundtruth(
        groundtruth_boxes_list=groundtruth_boxes_list,
        groundtruth_classes_list=groundtruth_classes_list)
    with tf.GradientTape() as tape:
      preprocessed_images = tf.concat(
          [model.preprocess(image_tensor)[0]
           for image_tensor in image_tensors], axis=0)
      prediction_dict = model.predict(preprocessed_images, shapes)
      losses_dict = model.loss(prediction_dict, shapes)
      total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']
      gradients = tape.gradient(total_loss, vars_to_fine_tune)
      optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
    return total_loss

  return train_step_fn

# Optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
train_step_fn = get_model_train_step_function(
    detection_model, optimizer, to_fine_tune)

print('Start fine-tuning!', flush=True)

# Start training (the training loop)
for idx in range(num_batches):
  # Grab keys for a random subset of examples
  all_keys = list(range(len(train_images_np)))
  random.shuffle(all_keys)
  example_keys = all_keys[:batch_size]

  gt_boxes_list = [gt_box_tensors[key] for key in example_keys]
  gt_classes_list = [gt_classes_one_hot_tensors[key] for key in example_keys]
  image_tensors = [train_image_tensors[key] for key in example_keys]

  # Training step (forward pass + backwards pass)
  total_loss = train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)

  if idx % 10 == 0:
    print('batch ' + str(idx) + ' of ' + str(num_batches)
    + ', loss=' +  str(total_loss.numpy()), flush=True)

print('Done fine-tuning!')
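
Only the variables whose names start with one of the two prediction-head prefixes are fine-tuned; everything else stays frozen. The selection can be sketched without TensorFlow as follows. The function name select_by_prefix and the sample variable names are hypothetical and do not reproduce the model's real variable list.

```python
def select_by_prefix(names, prefixes):
    """Return the names that start with any of the given prefixes."""
    prefixes = tuple(prefixes)  # str.startswith accepts a tuple of prefixes
    return [name for name in names if name.startswith(prefixes)]


prefixes_to_train = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead',
]

# Hypothetical variable names, for illustration only.
variable_names = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead/kernel:0',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead/bias:0',
    'ResNet50V1_FPN/conv1/kernel:0',
]

selected = select_by_prefix(variable_names, prefixes_to_train)
print(len(selected))  # 2: the backbone variable is excluded
```

Restricting training to the heads keeps the number of updated parameters small, which suits a training set of only five images.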

  Step 13: Download test images to evaluate the model fine-tuned in the previous step

# Delete files left over from a previous run, if any
!rm -f zombie-walk-frames.zip
!rm -rf ./zombie-walk
!rm -rf ./results

# download test images
!wget --no-check-certificate \
    https://storage.googleapis.com/tensorflow-3-public/datasets/zombie-walk-frames.zip \
    -O zombie-walk-frames.zip

# unzip test images
local_zip = './zombie-walk-frames.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('./results')
zip_ref.close()

  Step 14: Load the test images into NumPy arrays

test_image_dir = './results/'
test_images_np = []

# load images into a numpy array. this will take a few minutes to complete.
for i in range(0, 237):
    image_path = os.path.join(test_image_dir, 'zombie-walk' + "{0:04}".format(i) + '.jpg')
    print(image_path)
    test_images_np.append(np.expand_dims(
      load_image_into_numpy_array(image_path), axis=0))
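
The loop above builds each file name by zero-padding the frame index to four digits with "{0:04}".format(i). A short sketch of the same naming scheme using an f-string is given below. The helper name frame_path is hypothetical, while the prefix, directory, and frame count are taken from the loop above.

```python
import os


def frame_path(directory, index, prefix='zombie-walk', ext='.jpg'):
    """Build the path of one frame, with the index zero-padded to 4 digits."""
    return os.path.join(directory, f'{prefix}{index:04d}{ext}')


paths = [frame_path('./results/', i) for i in range(0, 237)]
print(len(paths))                   # 237
print(os.path.basename(paths[0]))   # zombie-walk0000.jpg
print(os.path.basename(paths[-1]))  # zombie-walk0236.jpg
```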

  Step 15: Define the detection function

@tf.function
def detect(input_tensor):
    """Run detection on an input image.
    Args:
    input_tensor: A [1, height, width, 3] Tensor of type tf.float32.
      Note that height and width can be anything since the image will be
      immediately resized according to the needs of the model within this
      function.

    Returns:
    A dict containing 3 Tensors (`detection_boxes`, `detection_classes`,
      and `detection_scores`).
    """
    preprocessed_image, shapes = detection_model.preprocess(input_tensor)
    prediction_dict = detection_model.predict(preprocessed_image, shapes)
    
    detections = detection_model.postprocess(prediction_dict, shapes)
    
    return detections

  Step 16: Run the detection function to check the model's accuracy

label_id_offset = 1
results = {'boxes': [], 'scores': []}

i = 150
images_np = test_images_np
# input_tensor = train_image_tensors[i]
input_tensor = tf.convert_to_tensor(images_np[i], dtype=tf.float32)
detections = detect(input_tensor)

detections['detection_boxes'][0].shape
detections['detection_classes'][0].shape
plot_detections(
  images_np[i][0],
  detections['detection_boxes'][0].numpy(),
  detections['detection_classes'][0].numpy().astype(np.uint32)
  + label_id_offset,
  detections['detection_scores'][0].numpy(),
  category_index, figsize=(15, 20))

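  The test result is displayed as follows:
[Image: detection result on a test frame]
  As the result shows, the model's detections meet expectations.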
