政安晨:【Keras机器学习示例演绎】(四)—— 利用迁移学习进行关键点检测

news2025/2/25 12:44:28

目录

数据收集

导入

定义超参数

加载数据

可视化数据

准备数据生成器

定义增强变换

创建训练和验证分割

数据生成器调查

模型构建

模型编译和训练

进行预测并将其可视化

更进一步


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:利用数据增强和迁移学习训练关键点检测器。

关键点检测包括对物体关键部位的定位。例如,人脸的关键部位包括鼻尖、眉毛、眼角等。这些部分有助于以特征丰富的方式表示底层对象。关键点检测的应用包括姿势估计、人脸检测等。

在本例中,我们将使用斯坦福 Extra 数据集,利用迁移学习建立一个关键点检测器。

本示例需要 TensorFlow 2.4 或更高版本,以及 imgaug 库,可使用以下命令安装:

!pip install -q -U imgaug

数据收集


StanfordExtra 数据集包含 12,000 张狗的图像以及关键点和分割图。该数据集由斯坦福狗数据集发展而来。

可通过以下命令下载:

!wget -q http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar

注释在 StanfordExtra 数据集中以单个 JSON 文件的形式提供,用户需要填写此表格才能访问该文件。

作者明确指示用户不要共享 JSON 文件,本示例也尊重这一愿望:您应自行获取 JSON 文件。

JSON 文件的本地版本为 stanfordextra_v12.zip。
下载文件后,我们可以解压缩。

!tar xf images.tar
!unzip -qq ~/stanfordextra_v12.zip

导入

from keras import layers
import keras

from imgaug.augmentables.kps import KeypointsOnImage
from imgaug.augmentables.kps import Keypoint
import imgaug.augmenters as iaa

from PIL import Image
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import json
import os

定义超参数

IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS = 5
NUM_KEYPOINTS = 24 * 2  # 24 pairs each having x and y coordinates

加载数据


作者还提供了一个元数据文件,其中指定了有关关键点的其他信息,如颜色信息、动物姿势名称等。我们将在 pandas 数据框中加载该文件,以提取信息用于可视化目的。

IMG_DIR = "Images"
JSON = "StanfordExtra_V12/StanfordExtra_v12.json"
KEYPOINT_DEF = (
    "https://github.com/benjiebob/StanfordExtra/raw/master/keypoint_definitions.csv"
)

# Load the ground-truth annotations.
with open(JSON) as infile:
    json_data = json.load(infile)

# Set up a dictionary, mapping all the ground-truth information
# with respect to the path of the image.
json_dict = {i["img_path"]: i for i in json_data}

json_dict 的单个条目如下所示:

'n02085782-Japanese_spaniel/n02085782_2886.jpg':
{'img_bbox': [205, 20, 116, 201],
 'img_height': 272,
 'img_path': 'n02085782-Japanese_spaniel/n02085782_2886.jpg',
 'img_width': 350,
 'is_multiple_dogs': False,
 'joints': [[108.66666666666667, 252.0, 1],
            [147.66666666666666, 229.0, 1],
            [163.5, 208.5, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
            [54.0, 244.0, 1],
            [77.33333333333333, 225.33333333333334, 1],
            [79.0, 196.5, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
            [150.66666666666666, 86.66666666666667, 1],
            [88.66666666666667, 73.0, 1],
            [116.0, 106.33333333333333, 1],
            [109.0, 123.33333333333333, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0]],
 'seg': ...}

在这个例子中,我们感兴趣的键是:

  • img_path
  • joints

关节中总共有 24 个条目。

每个条目有 3 个值:


x 坐标
y 坐标
关键点的可见性标志(1 表示可见,0 表示不可见)


我们可以看到,关节点包含多个 [0, 0, 0] 条目,表示这些关键点没有被标记。

在本例中,我们将同时考虑非可见和未标记的关键点,以便进行迷你批量学习。

# Load the metdata definition file and preview it.
keypoint_def = pd.read_csv(KEYPOINT_DEF)
keypoint_def.head()

# Extract the colours and labels.
colours = keypoint_def["Hex colour"].values.tolist()
colours = ["#" + colour for colour in colours]
labels = keypoint_def["Name"].values.tolist()


# Utility for reading an image and for getting its annotations.
def get_dog(name):
    data = json_dict[name]
    img_data = plt.imread(os.path.join(IMG_DIR, data["img_path"]))
    # If the image is RGBA convert it to RGB.
    if img_data.shape[-1] == 4:
        img_data = img_data.astype(np.uint8)
        img_data = Image.fromarray(img_data)
        img_data = np.array(img_data.convert("RGB"))
    data["img_data"] = img_data

    return data

可视化数据


现在,我们编写一个实用程序来可视化图像及其关键点。

# Parts of this code come from here:
# https://github.com/benjiebob/StanfordExtra/blob/master/demo.ipynb
def visualize_keypoints(images, keypoints):
    fig, axes = plt.subplots(nrows=len(images), ncols=2, figsize=(16, 12))
    [ax.axis("off") for ax in np.ravel(axes)]

    for (ax_orig, ax_all), image, current_keypoint in zip(axes, images, keypoints):
        ax_orig.imshow(image)
        ax_all.imshow(image)

        # If the keypoints were formed by `imgaug` then the coordinates need
        # to be iterated differently.
        if isinstance(current_keypoint, KeypointsOnImage):
            for idx, kp in enumerate(current_keypoint.keypoints):
                ax_all.scatter(
                    [kp.x],
                    [kp.y],
                    c=colours[idx],
                    marker="x",
                    s=50,
                    linewidths=5,
                )
        else:
            current_keypoint = np.array(current_keypoint)
            # Since the last entry is the visibility flag, we discard it.
            current_keypoint = current_keypoint[:, :2]
            for idx, (x, y) in enumerate(current_keypoint):
                ax_all.scatter([x], [y], c=colours[idx], marker="x", s=50, linewidths=5)

    plt.tight_layout(pad=2.0)
    plt.show()


# Select four samples randomly for visualization.
samples = list(json_dict.keys())
num_samples = 4
selected_samples = np.random.choice(samples, num_samples, replace=False)

images, keypoints = [], []

for sample in selected_samples:
    data = get_dog(sample)
    image = data["img_data"]
    keypoint = data["joints"]

    images.append(image)
    keypoints.append(keypoint)

visualize_keypoints(images, keypoints)

从图中可以看出,我们的图像大小并不均匀,这在现实世界的大多数情况下都是意料之中的。

但是,如果我们将这些图像的大小调整为统一形状(例如 (224 x 224)),它们的地面实况注释也会受到影响。

如果我们对图像进行任何几何变换(例如水平翻转),情况也是一样。幸运的是,imgaug 提供了可以处理这个问题的实用程序。

在后文中,我们将编写一个继承 keras.utils.Sequence 类的数据生成器,使用 imgaug 对成批数据进行数据增强。

准备数据生成器

class KeyPointsDataset(keras.utils.PyDataset):
    def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True, **kwargs):
        super().__init__(**kwargs)
        self.image_keys = image_keys
        self.aug = aug
        self.batch_size = batch_size
        self.train = train
        self.on_epoch_end()

    def __len__(self):
        return len(self.image_keys) // self.batch_size

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.image_keys))
        if self.train:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        image_keys_temp = [self.image_keys[k] for k in indexes]
        (images, keypoints) = self.__data_generation(image_keys_temp)

        return (images, keypoints)

    def __data_generation(self, image_keys_temp):
        batch_images = np.empty((self.batch_size, IMG_SIZE, IMG_SIZE, 3), dtype="int")
        batch_keypoints = np.empty(
            (self.batch_size, 1, 1, NUM_KEYPOINTS), dtype="float32"
        )

        for i, key in enumerate(image_keys_temp):
            data = get_dog(key)
            current_keypoint = np.array(data["joints"])[:, :2]
            kps = []

            # To apply our data augmentation pipeline, we first need to
            # form Keypoint objects with the original coordinates.
            for j in range(0, len(current_keypoint)):
                kps.append(Keypoint(x=current_keypoint[j][0], y=current_keypoint[j][1]))

            # We then project the original image and its keypoint coordinates.
            current_image = data["img_data"]
            kps_obj = KeypointsOnImage(kps, shape=current_image.shape)

            # Apply the augmentation pipeline.
            (new_image, new_kps_obj) = self.aug(image=current_image, keypoints=kps_obj)
            batch_images[i,] = new_image

            # Parse the coordinates from the new keypoint object.
            kp_temp = []
            for keypoint in new_kps_obj:
                kp_temp.append(np.nan_to_num(keypoint.x))
                kp_temp.append(np.nan_to_num(keypoint.y))

            # More on why this reshaping later.
            batch_keypoints[i,] = np.array(kp_temp).reshape(1, 1, 24 * 2)

        # Scale the coordinates to [0, 1] range.
        batch_keypoints = batch_keypoints / IMG_SIZE

        return (batch_images, batch_keypoints)

要进一步了解如何在 imgaug 中使用关键点,请查看本文。

定义增强变换

train_aug = iaa.Sequential(
    [
        iaa.Resize(IMG_SIZE, interpolation="linear"),
        iaa.Fliplr(0.3),
        # `Sometimes()` applies a function randomly to the inputs with
        # a given probability (0.3, in this case).
        iaa.Sometimes(0.3, iaa.Affine(rotate=10, scale=(0.5, 0.7))),
    ]
)

test_aug = iaa.Sequential([iaa.Resize(IMG_SIZE, interpolation="linear")])

创建训练和验证分割

np.random.shuffle(samples)
train_keys, validation_keys = (
    samples[int(len(samples) * 0.15) :],
    samples[: int(len(samples) * 0.15)],
)

数据生成器调查

train_dataset = KeyPointsDataset(
    train_keys, train_aug, workers=2, use_multiprocessing=True
)
validation_dataset = KeyPointsDataset(
    validation_keys, test_aug, train=False, workers=2, use_multiprocessing=True
)

print(f"Total batches in training set: {len(train_dataset)}")
print(f"Total batches in validation set: {len(validation_dataset)}")

sample_images, sample_keypoints = next(iter(train_dataset))
assert sample_keypoints.max() == 1.0
assert sample_keypoints.min() == 0.0

sample_keypoints = sample_keypoints[:4].reshape(-1, 24, 2) * IMG_SIZE
visualize_keypoints(sample_images[:4], sample_keypoints)

演绎展示:
 

Total batches in training set: 166
Total batches in validation set: 29

模型构建


Stanford dogs 数据集(StanfordExtra 数据集基于该数据集)是使用 ImageNet-1k 数据集构建的。因此,在 ImageNet-1k 数据集上预训练的模型很可能对这项任务有用。我们将使用在该数据集上预先训练好的 MobileNetV2 作为骨干,从图像中提取有意义的特征,然后将这些特征传递给自定义回归头用于预测坐标。

def get_model():
    # Load the pre-trained weights of MobileNetV2 and freeze the weights
    backbone = keras.applications.MobileNetV2(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )
    backbone.trainable = False

    inputs = layers.Input((IMG_SIZE, IMG_SIZE, 3))
    x = keras.applications.mobilenet_v2.preprocess_input(inputs)
    x = backbone(x)
    x = layers.Dropout(0.3)(x)
    x = layers.SeparableConv2D(
        NUM_KEYPOINTS, kernel_size=5, strides=1, activation="relu"
    )(x)
    outputs = layers.SeparableConv2D(
        NUM_KEYPOINTS, kernel_size=3, strides=1, activation="sigmoid"
    )(x)

    return keras.Model(inputs, outputs, name="keypoint_detector")

我们定制的网络是全卷积网络,因此与具有全连接密集层的相同版本网络相比,它对参数更加友好。

get_model().summary()

演绎展示:

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
 9406464/9406464 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

请注意网络的输出形状:(None, 1, 1, 48)。

这就是我们重塑坐标的原因:batch_keypoints[i, :] = np.array(kp_temp).reshape(1, 1, 24 * 2)

模型编译和训练


在本例中,我们只对网络进行 5 个历元的训练。

model = get_model()
model.compile(loss="mse", optimizer=keras.optimizers.Adam(1e-4))
model.fit(train_dataset, validation_data=validation_dataset, epochs=EPOCHS)
Epoch 1/5
 166/166 ━━━━━━━━━━━━━━━━━━━━ 84s 415ms/step - loss: 0.1110 - val_loss: 0.0959
Epoch 2/5
 166/166 ━━━━━━━━━━━━━━━━━━━━ 79s 472ms/step - loss: 0.0874 - val_loss: 0.0802
Epoch 3/5
 166/166 ━━━━━━━━━━━━━━━━━━━━ 78s 463ms/step - loss: 0.0789 - val_loss: 0.0765
Epoch 4/5
 166/166 ━━━━━━━━━━━━━━━━━━━━ 78s 467ms/step - loss: 0.0769 - val_loss: 0.0731
Epoch 5/5
 166/166 ━━━━━━━━━━━━━━━━━━━━ 77s 464ms/step - loss: 0.0753 - val_loss: 0.0712

<keras.src.callbacks.history.History at 0x7fb5c4299ae0>

进行预测并将其可视化

sample_val_images, sample_val_keypoints = next(iter(validation_dataset))
sample_val_images = sample_val_images[:4]
sample_val_keypoints = sample_val_keypoints[:4].reshape(-1, 24, 2) * IMG_SIZE
predictions = model.predict(sample_val_images).reshape(-1, 24, 2) * IMG_SIZE

# Ground-truth
visualize_keypoints(sample_val_images, sample_val_keypoints)

# Predictions
visualize_keypoints(sample_val_images, predictions)

演绎展示:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 7s 7s/step

通过更多的训练,预测结果可能会有所改善。

更进一步


尝试使用 imgaug 的其他增强变换,研究其对结果的影响。

在这里,我们从预训练网络中线性转移了特征,也就是说,我们没有对其进行微调。

我们鼓励你在这项任务中对其进行微调,看看是否能提高性能。您还可以尝试不同的架构,看看它们对最终性能有什么影响。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1609017.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

遍历取后端数据推送到地图上,实现图标点标记地图效果

遍历取后端数据推送到地图上&#xff0c;实现图标点标记地图效果 示例链接&#xff1a; 功能示例(Vue版) | Mars3D三维可视化平台 | 火星科技 踩坑注意点&#xff1a; 1. id: 1 是地图底图的id 后台也返回之后 id直接会有冲突 此时图标标记之后无法单击 相关代码&#xff1a…

liqo学习及安装,k8s,kubernetes多集群互联

先按照官方的教程在虚拟机安装学习 在开始以下教程之前&#xff0c;您应该确保您的系统上安装了以下软件&#xff1a; Docker&#xff0c;容器运行时。Kubectl&#xff0c;Kubernetes 的命令行工具。 curl -LO "https://dl.k8s.io/release/$(curl -L -s https://dl.k8s.…

Spark-Scala语言实战(17)

我带着大家一起来到Linux集群环境下&#xff0c;学习我们的spark。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&#xff0c;请留下你宝贵的点赞&#xff0c;谢谢。 Spark-Scala语言实战&#xff08;16&#x…

关于MCU核心板的一些常见问题

BGA植球与焊接&#xff08;多涂焊油&#xff09;&#xff1a; 【BGA芯片是真麻烦&#xff0c;主要是植锡珠太麻烦了&#xff0c;拆一次就得重新植】https://www.bilibili.com/video/BV1vW4y1w7oNvd_source3cc3c07b09206097d0d8b0aefdf07958 / NC电容一般有两种含义&#xff1…

js自动缩放页面,html自动缩放页面,大屏自动缩放页面,数字看板自动缩放页面,大数据看板自动缩放页面

js自动缩放页面&#xff0c;html自动缩放页面&#xff0c;大屏自动缩放页面&#xff0c;数字看板自动缩放页面&#xff0c;大数据看板自动缩放页面 由纯JS实现 html代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"…

CSS基础:display的3个常见属性值详解

你好&#xff0c;我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃-大专生&#xff0c;一枚程序媛&#xff0c;感谢关注。回复 “前端基础题”&#xff0c;可免费获得前端基础 100 题汇总&#xff0c;回复 “前端工具”&#xff0c;可获取 Web 开发工具…

13-LINUX--消息队列

一.消息队列 1.消息队列&#xff1a;消息队列为一个进程向另一个进程发送一个数据块提供了条件&#xff0c;每个数据块会包含一个类型。 2.相关函数 1>.msgget(key_t key,int msgflg) : 创建消息队列 2>. msgsnd&#xff1a;把消息添加到消息队列 3>.msgrcv &#xf…

【Golang】Gin教学-获取请求信息并返回

安装Gin初始化Gin处理所有HTTP请求获取请求的URL和Method获取请求参数根据Content-Type判断请求数据类型处理JSON数据处理表单数据处理文件返回JSON响应启动服务完整代码测试 Gin是一个用Go&#xff08;又称Golang&#xff09;编写的HTTP Web框架&#xff0c;它具有高性能和简洁…

【React】Sigma.js框架网络图-入门篇

一、介绍 Sigma.js是一个专门用于图形绘制的JavaScript库。 它使在Web页面上发布网络变得容易&#xff0c;并允许开发人员将网络探索集成到丰富的Web应用程序中。 Sigma.js提供了许多内置功能&#xff0c;例如Canvas和WebGL渲染器或鼠标和触摸支持&#xff0c;以使用户在网页上…

【数据结构】图论(图的储存方式,图的遍历算法DFS和BFS、图的遍历算法的应用、图的连通性问题)

目录 图论一、 图的基本概念和术语二、图的存储结构1. 数组(邻接矩阵)存储表示无向图的数组(邻接矩阵)存储表示有向图的数组(邻接矩阵)存储表示 邻接表存储表示有向图的十字链表存储表示无向图的邻接多重表存储表示 三、图的遍历算法图的遍历——深度优先搜索&#xff08;DFS&a…

cdp集群Hbase组件HRegionServer服务停止原因以及排查

前言&#xff1a;重启集群后某一节点HRegionServer服务停止&#xff0c;重启前所有服务均正常 去查看日志&#xff1a; 日志报错 ERROR HRegionServer Master rejected startup because clock is out of sync org.apache.hadoop.hbase.ClockOutOfSyncException: org.apache.h…

Amazon云计算AWS之[1]基础存储架构Dynamo

文章目录 Dynamo概况Dynamo架构的主要技术主要问题及解决方案Dynamo的存储节点数据均衡分布的问题一致性哈希算法改进一致性哈希算法 数据备份数据冲突问题成员资格及错误检测容错机制临时故障处理机制永久性故障处理机制 Dynamo概况 面向服务的Amazon平台基本架构为了保证其稳…

【深度学习】Vision Transformer

一、Vision Transformer Vision Transformer (ViT)将Transformer应用在了CV领域。在学习它之前&#xff0c;需要了解ResNet、LayerNorm、Multi-Head Self-Attention。 ViT的结构图如下&#xff1a; 如图所示&#xff0c;ViT主要包括Embedding、Encoder、Head三大部分。Class …

小球反弹(蓝桥杯)

文章目录 小球反弹【问题描述】答案&#xff1a;1100325199.77解题思路模拟 小球反弹 【问题描述】 有一长方形&#xff0c;长为 343720 单位长度&#xff0c;宽为 233333 单位长度。在其内部左上角顶点有一小球&#xff08;无视其体积&#xff09;&#xff0c;其初速度如图所…

Targeted influence maximization in competitive social networks

abstract 利用口碑效应的广告对于推销产品是相当有效的。在过去的十年中&#xff0c;人们对营销中的影响力最大化问题进行了深入的研究。影响力最大化问题旨在将社交网络中的一小群人识别为种子&#xff0c;最终他们将引发网络中最大的影响力传播或产品采用。在网络营销的实际场…

微信小程序日期增加时间完成订单失效倒计时(有效果图)

效果图 .wxml <view class"TimeSeond">{{second}}</view>.js Page({data: {tiem_one:,second:,//倒计时deadline:,},onLoad(){this.countdown();},countdown(){let timestamp Date.parse(new Date()) / 1000;//当前时间戳let time this.addtime(2024…

记一次中间件宕机以后持续请求导致应用OOM的排查思路(server.max-http-header-size属性配置不当的严重后果)

一、背景 最近有一次在系统并发比较高的时候&#xff0c;数据库突然发生了故障&#xff0c;导致大量请求失败&#xff0c;在数据库宕机不久&#xff0c;通过应用日志可以看到系统发生了OOM。 二、排查 初次看到这个现象的时候&#xff0c;我还是有点懵逼的&#xff0c;数据库…

k8s 部署 kube-prometheus监控

一、Prometheus监控部署 1、下载部署文件 # 使用此链接下载后解压即可 wget https://github.com/prometheus-operator/kube-prometheus/archive/refs/heads/release-0.13.zip2、根据k8s集群版本获取不同的kube-prometheus版本部署 https://github.com/prometheus-operator/k…

基于SSM的物流快递管理系统(含源码+sql+视频导入教程+文档+PPT)

&#x1f449;文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于SSM的物流快递管理系统2拥有三个角色&#xff1a; 管理员&#xff1a;用户管理、管理员管理、新闻公告管理、留言管理、取件预约管理、收件管理、货物分类管理、发件信息管理等 用户…

C#在窗体中设计滚动字幕的方法:创建特殊窗体

目录 1.涉及到的知识点 (1)Timer组件 (2)Label控件的Left属性 (3)启动和关闭Timer计时器 2. 实例 &#xff08;1&#xff09;Resources.Designer.cs设计 &#xff08;2&#xff09; Form1.Designer.cs设计 &#xff08;3&#xff09;Form1.cs设计 &#xff08;4&#…