基于 YOLO V8 Pose Fine-Tuning 训练 15 点人脸关键点检测模型

news2024/12/26 21:55:46

一、YOLO V8 Pose

YOLO V8 在上篇文章中进了简单的介绍,并基于YOLO V8 Fine-Tuning 训练了自定义的目标检测模型,而YOLO V8 Pose 是建立在YOLO V8基础上的关键点检测模型,本文基于 yolov8n-pose 模型实验 Fine-Tuning 训练15 点人脸关键点检测模型,并配合上篇文章训练的人脸检测模型一起使用。

上篇文章地址:

基于 YOLO V8 Fine-Tuning 训练自定义的目标检测模型

YOLO V8 的细节可以参考下面官方的介绍:

https://docs.ultralytics.com/zh/models/yolov8/#citations-and-acknowledgements

本文依旧使用 ultralytics 框架进行训练和测试,其中 ultralyticspytorch 的版本如下:

torch==1.13.1+cu116
ultralytics==8.1.37

YOLO V8 Pose 调用示例如下:

测试图像:
在这里插入图片描述

这里使用 yolov8n-pose 模型,如果模型不存在会自动下载:

from ultralytics import YOLO
# Load a model
model = YOLO('yolov8n-pose.pt')  # pretrained YOLOv8n model

results = model.predict('./img/1.png')
# Show results
results[0].show()

在这里插入图片描述

二、人脸关键点检测数据集

在计算机视觉人脸计算领域,人脸关键点检测是一个十分重要的区域,可以实现例如一些人脸矫正、表情分析、姿态分析、人脸识别、人脸美颜等方向。

人脸关键点数据集通常有 5点、15点、68点、96点、98点、106点、186点 等,例如通用 Dlib 中的 68 点检测,它将人脸关键点分为脸部关键点和轮廓关键点,脸部关键点包含眉毛、眼睛、鼻子、嘴巴共计51个关键点,轮廓关键点包含17个关键点。

在这里插入图片描述

本文基于 kaggleFacial Keypoints Detection 中的数据集进行实践,该数据集包含包括7,049幅训练图像,图像是 96 x 96像素的灰度图像,其中关键点有 15个点,注意数据集有的字段缺失,如果去除字段缺失的数据,实际训练数据只有 2,140 幅训练图像,还包括1,783张测试图片,数据集的效果如下所示:

在这里插入图片描述
可以看出,关键点包括眉毛的两端、眼睛的中心和两端、鼻子尖、嘴巴两端和上下嘴唇的中间。

下载数据集

数据集在 kaggle 的官方网址上:

https://www.kaggle.com/c/facial-keypoints-detection

下载前需要进行登录,如果没有 kaggle 账号可以注册一个。

在这里插入图片描述

下载解压后,可以看到 training.ziptest.zip 两个文件,分别对应训练集和测试集,解压后数据是以 CSV 的格式进行存放的:

在这里插入图片描述
其中 training.csv 中的字段分别表示:

序号字段含义
0left_eye_center_x左眼中心 x 点
1left_eye_center_y左眼中心 y 点
2right_eye_center_x右眼中心 x 点
3right_eye_center_y右眼中心 y 点
4left_eye_inner_corner_x左眼内端 x 点
5left_eye_inner_corner_y左眼内端 y 点
6left_eye_outer_corner_x左眼外端 x 点
7left_eye_outer_corner_y左眼外端 y 点
8right_eye_inner_corner_x右眼内端 x 点
9right_eye_inner_corner_y右眼内端 y 点
10right_eye_outer_corner_x右眼外端 x 点
11right_eye_outer_corner_y右眼外端 y 点
12left_eyebrow_inner_end_x左眉毛内端 x 点
13left_eyebrow_inner_end_y左眉毛内端 y 点
14left_eyebrow_outer_end_x左眉毛外端 x 点
15left_eyebrow_outer_end_y左眉毛外端 y 点
16right_eyebrow_inner_end_x右眉毛内端 x 点
17right_eyebrow_inner_end_y右眉毛内端 y 点
18right_eyebrow_outer_end_x右眉毛外端 x 点
19right_eyebrow_outer_end_y右眉毛外端 y 点
20nose_tip_x鼻尖中心 x 点
21nose_tip_y鼻尖中心 y 点
22mouth_left_corner_x嘴巴左端 x 点
23mouth_left_corner_y嘴巴左端 y 点
24mouth_right_corner_x嘴巴右端 x 点
25mouth_right_corner_y嘴巴右端 y 点
26mouth_center_top_lip_x上嘴唇中心 x 点
27mouth_center_top_lip_y上嘴唇中心 y 点
28mouth_center_bottom_lip_x下嘴唇中心 x 点
29mouth_center_bottom_lip_y下嘴唇中心 y 点
30Image图形像素

由于数据是存放在CSV中,可以借助 pandas 工具对数据进行解析,如果没有安装 pandas 工具,可以通过下面指令安装:

pip3 install pandas -i https://pypi.tuna.tsinghua.edu.cn/simple

下面程序通过 pandas 解析 CSV 文件,并将图片转为 numpy 数组,通过 matplotlib 可视化工具查看,其中具体的解释都写在了注释中:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def main():
    csv_path = './data/training.csv'
    # 读取 CSV 文件
    train_df = pd.read_csv(csv_path)
    # 查看数据框,并列出数据集的头部。
    train_df.info()
    # 丢弃有缺失数据的样本
    train_df = train_df.dropna()
    # 获取图片信息,并转为 numpy 结构
    x_train = train_df['Image'].apply(lambda img: np.fromstring(img, sep=' '))
    x_train = np.vstack(x_train)
    # 重新修改形状
    x_train = x_train.reshape((-1, 96, 96, 1))
    # 去除最后一列的 Image
    cols = train_df.columns[:-1]
    y_train = train_df[cols].values

    print('训练集 shape: ', x_train.shape)
    print('训练集label shape: ', y_train.shape)

    plt.figure(figsize=(10, 10))
    for p in range(2):
        data = x_train[(p * 9):(p * 9 + 9)]
        label = y_train[(p * 9):(p * 9 + 9)]
        plt.clf()
        for i in range(9):
            plt.subplot(3, 3, i + 1)
            img = data[i].reshape(96, 96, 1)
            plt.imshow(img, cmap='gray')
            # 画关键点
            l = label[i]
            # 从 1 开始,每次走 2 步,j-1,j 就是当前点的坐标
            for j in range(1, 31, 2):
                plt.plot(l[j - 1], l[j], 'ro', markersize=4)
        plt.show()

if __name__ == '__main__':
    main()

运行之后,可以看到如下效果图:

在这里插入图片描述
下面我们基于该数据集进行建模,训练一个自己的关键点检测模型。

三、数据集拆分和转换

数据集格式需要转换成 Ultralytics 官方的 YOLO 格式,主要包括以下几点的注意:

  • 每幅图像一个文本文件:数据集中的每幅图像都有一个相应的文本文件,文件名与图像文件相同,扩展名为".txt"。
  • 每个对象一行:文本文件中的每一行对应图像中的一个对象实例。
  • 每行对象信息:每行包含对象实例的以下信息
    • 对象类别索引:代表对象类别的整数(如 0 代表人,1 代表汽车等)。
    • 对象中心坐标:对象中心的 xy 坐标,归一化后介于 01 之间。
    • 对象宽度和高度:对象的宽度和高度,标准化后介于 01 之间。
    • 对象关键点坐标:对象的关键点,归一化为 01

姿势估计任务的标签格式示例:

<class-index> <x> <y> <width> <height> <px1> <py1> <px2> <py2> ... <pxn> <pyn>

官方的介绍:

https://docs.ultralytics.com/zh/datasets/pose/#ultralytics-yolo-format

这里由于数据集中仅包含人脸信息,没有其他因素影响,因此,<class-index> <x> <y> <width> <height> 我们可以固定写死为:0 0.5 0.5 1 1 ,转换和拆分的逻辑如下:

import os
import shutil
from tqdm import tqdm
import numpy as np
import pandas as pd
from PIL import Image

# training.csv 地址
csv_path = "./data/training.csv"
# 训练集的比例
training_ratio = 0.8
# 拆分后数据的位置
train_dir = "train_data"


def toRgbImg(img):
    img = np.fromstring(img, sep=' ').astype(np.uint8).reshape(96, 96)
    img = Image.fromarray(img).convert('RGB')
    return img


def split_data():
    # 训练集目录
    os.makedirs(os.path.join(train_dir, "images/train"), exist_ok=True)
    os.makedirs(os.path.join(train_dir, "labels/train"), exist_ok=True)
    # 验证集目录
    os.makedirs(os.path.join(train_dir, "images/val"), exist_ok=True)
    os.makedirs(os.path.join(train_dir, "labels/val"), exist_ok=True)
    # 读取数据
    train_df = pd.read_csv(csv_path)
    # 丢弃有缺失数据的样本
    train_df = train_df.dropna()
    # 获取图片信息,并转为 numpy 结构
    x_train = train_df['Image'].apply(toRgbImg)
    # 去除最后一列的 Image, 将y值缩放到[0,1]区间
    cols = train_df.columns[:-1]
    y_train = train_df[cols].values
    # 使用 80% 的数据训练,20% 的数据进行验证
    size = int(x_train.shape[0] * 0.8)
    x_val = x_train[size:]
    y_val = y_train[size:]
    x_train = x_train[:size]
    y_train = y_train[:size]

    trains = []
    vals = []
    # 生成训练数据
    for i, image in tqdm(enumerate(x_train), total=len(x_train)):
        label = y_train[i]
        image_file = os.path.join(train_dir, f"images/train/{i}.jpg")
        label_file = os.path.join(train_dir, f"labels/train/{i}.txt")
        image.save(image_file)
        trains.append(image_file)
        width, height = image.size
        yolo_label = ["0", "0.5", "0.5", "1", "1"]
        for i, v in enumerate(label):
            if i % 2 == 0:
                yolo_label.append(str(v / float(width)))
            else:
                yolo_label.append(str(v / float(height)))
        with open(label_file, "w", encoding="utf-8") as w:
            w.write(" ".join(yolo_label))

    # 生成验证数据
    for i, image in tqdm(enumerate(x_val), total=len(x_val)):
        label = y_val[i]
        image_file = os.path.join(train_dir, f"images/val/{i}.jpg")
        label_file = os.path.join(train_dir, f"labels/val/{i}.txt")
        image.save(image_file)
        vals.append(image_file)
        width, height = image.size
        yolo_label = ["0", "0.5", "0.5", "1", "1"]
        for i, v in enumerate(label):
            if i % 2 == 0:
                yolo_label.append(str(v / float(width)))
            else:
                yolo_label.append(str(v / float(height)))
        with open(label_file, "w", encoding="utf-8") as w:
            w.write(" ".join(yolo_label))

    with open(os.path.join(train_dir, "train.txt"), "w") as file:
        file.write("\n".join([image_file for image_file in trains]))
    print("save train.txt success!")

    with open(os.path.join(train_dir, "val.txt"), "w") as file:
        file.write("\n".join([image_file for image_file in vals]))
    print("save val.txt success!")

if __name__ == '__main__':
    split_data()

运行后,可以在 train_data 下面看到拆分后的数据:
在这里插入图片描述

四、训练

使用 ultralytics 框架训练非常简单,仅需三行代码即可完成训练,不过在训练前需要编写 YAML 配置信息,主要标记数据集的位置。

创建 face.yaml 文件,写入下面内容:

# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: D:/pyProject/yolov8/face/train_data  # dataset root dir
train: images/train  # train images (relative to 'path') 4 images
val: images/val  # val images (relative to 'path') 4 images
test:  # test images (optional)

# Keypoints
kpt_shape: [15, 2]  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)

# Classes dictionary
names:
  0: person

开始训练:

from ultralytics import YOLO

# 加载模型
model = YOLO('yolov8n-pose.pt')

# 训练
model.train(
    data='face.yaml', # 训练配置文件
    epochs=50, # 训练的周期
    imgsz=640, # 图像的大小
    device=[0], # 设备,如果是 cpu 则是 device='cpu'
    workers=0,
    lr0=0.001, # 学习率
    batch=8, # 批次大小
    amp=False # 是否启用混合精度训练
)

运行后可以看到打印的网络结构:

在这里插入图片描述

训练中:

在这里插入图片描述
训练结束后可以在 runs 目录下面看到训练的结果:

在这里插入图片描述

看下训练时 loss 的变化图:

在这里插入图片描述

三、模型预测

首先使用测试集进行测试:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from ultralytics import YOLO

def toRgbImg(img):
    img = np.fromstring(img, sep=' ').astype(np.uint8)
    img = Image.fromarray(img).convert('RGB')
    return img

def main():
    csv_path = 'data/test.csv'
    # 读取 CSV 文件
    test_df = pd.read_csv(csv_path)
    # 查看数据框,并列出数据集的头部。
    test_df.info()
    # 获取图片信息,并转为 numpy 结构
    test_df = test_df['Image'].apply(toRgbImg)
    test_df = np.vstack(test_df)
    # 重新修改形状
    test_df = test_df.reshape((-1, 96, 96, 3))
    # 加载模型
    model = YOLO('runs/pose/train/weights/best.pt')

    plt.figure(figsize=(10, 10))
    for p in range(5):
        data = test_df[(p * 9):(p * 9 + 9)]
        plt.clf()
        for i in range(9):
            plt.subplot(3, 3, i + 1)
            img = data[i]
            plt.imshow(img, cmap='gray')
            results = model.predict(img, device='cpu')
            # 画关键点
            keypoints = results[0].keypoints.xy
            for keypoint in keypoints:
                for xy in keypoint:
                    plt.plot(xy[0], xy[1], 'ro', markersize=4)
        plt.show()

if __name__ == '__main__':
    main()

在这里插入图片描述

在这里插入图片描述

可以看到对于鼻子位置的关键点有些会出现偏差。

四、结合上篇的人脸检测模型

from ultralytics import YOLO
from PIL import Image
from matplotlib import pyplot as plt


def main():
    # 加载人脸检测模型
    detection_model = YOLO('yolov8_face_detection.pt')
    # 加载人脸关键点检测模型
    point_model = YOLO('runs/pose/train/weights/best.pt')

    image = plt.imread('./img/10.jpg')
    # 预测
    results = detection_model.predict(image, device='cpu')
    boxes = results[0].boxes.xyxy
    print(boxes)
    ax = plt.gca()
    for boxe in boxes:
        x1, y1, x2, y2 = boxe[0], boxe[1], boxe[2], boxe[3]
        ax.add_patch(plt.Rectangle((x1, y1), (x2 - x1), (y2 - y1), fill=False, color='red'))
        # 截取图片
        crop = image[int(y1):int(y2), int(x1):int(x2)]
        results = point_model.predict(crop, device='cpu')
        keypoints = results[0].keypoints.xy
        for keypoint in keypoints:
            for xy in keypoint:
                plt.plot(xy[0]+ x1, xy[1]+ y1, 'ro', markersize=2)

    plt.imshow(image)
    plt.show()


if __name__ == '__main__':
    main()

测试效果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1564120.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DHCP原理重磅来袭——走过路过不要错过

目录 一.DHCP来源 &#xff08;1)手工分配缺点 (2)DHCP优点 二.DHCP设备调试 &#xff08;1&#xff09;.基本配置&#xff1a; &#xff08;2&#xff09;接口地址池 1.开启DHCP功能 2.开启DHCP接口地址池功能 3.查看IP地址分配结果 &#xff08;3&#xff09;全局地…

最新AI智能系统ChatGPT网站源码V6.3版本,GPTs、AI绘画、AI换脸、垫图混图+(SparkAi系统搭建部署教程文档)

一、前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统&#xff0c;支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;那么如何搭建部署AI创作ChatGPT&#xff1f;小编这里写一个详细图文教程吧。已支持GPT…

商品服务 - 三级分类

1.递归查询树形结构 Overridepublic List<CategoryEntity> listWithTree() {//1.查出所有分类List<CategoryEntity> all this.list();//2.组装成父子的属性结构List<CategoryEntity> level1Menus all.stream().filter(c -> c.getParentCid().equals(0L)…

OSError: Can‘t load tokenizer for ‘bert-base-chinese‘

文章目录 OSError: Cant load tokenizer for bert-base-chinese1.问题描述2.解决办法 OSError: Can’t load tokenizer for ‘bert-base-chinese’ 1.问题描述 使用from_pretrained()函数从预训练的权重中加载模型时报错&#xff1a; OSError: Can’t load tokenizer for ‘…

2024最新软件测试【测试理论+ 性能测试】面试题(内附答案)

一、测试理论 3.1 你们原来项目的测试流程是怎么样的? 我们的测试流程主要有三个阶段&#xff1a;需求了解分析、测试准备、测试执行。 1、需求了解分析阶段 我们的 SE 会把需求文档给我们自己先去了解一到两天这样&#xff0c;之后我们会有一个需求澄清会议&#xff0c; …

基于龙芯2k1000 mips架构ddr调试心得(二)

1、内存控制器概述 龙芯处理器内部集成的内存控制器的设计遵守 DDR2/3 SDRAM 的行业标准&#xff08;JESD79-2 和 JESD79-3&#xff09;。在龙芯处理器中&#xff0c;所实现的所有内存读/写操作都遵守 JESD79-2B 及 JESD79-3 的规定。龙芯处理器支持最大 4 个 CS&#xff08;由…

NoSQL(非关系型数据库)之Redis的简介与安装

一、简介 1.1 关系型数据库与非关系型数据库 1.1.1 概念 1.1.2 区别 1.2 非关系型数据库产生背景 1.3 redis 简介 1.4 redis 优点 1.5 redis 快的原因 二、安装 2.1 关闭核心防护 2.2 安装相关依赖 2.3 解压软件包并进行编译安装 2.4 设置 Redis 服务所需相关配置文…

每日一题:c语言实现n的阶乘

目录 一、要求 二、代码 三、结果 一、要求 实现n的阶乘&#xff0c;已知n&#xff01;1*2*3*…*n 二、代码 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h>int main() {//初始化变量n为要求的几阶&#xff0c;jiecheng存储结果的&#xff0c;初始化为1…

黄金票据复现

黄金票据&#xff1a; 在AS-REP里面的ticket的encpart是使用krbtgt的hash进行加密&#xff0c;如果拥有krbtgt的hash就可以给我们自己签发任意用户的TGT票据&#xff0c;这个票据称之为黄金票据 使用Mimikatz伪造Kerberos黄金票据 Mimikatz命令&#xff1a;Kerberos&#xf…

File类 --java学习笔记

在java中&#xff0c;存储数据一般有如下几种方法&#xff1a; 而它们都是内存中的数据容器它们记住的数据&#xff0c;在断电&#xff0c;或者程序终止时会丢失 这种时候就可以使用File类和Io流&#xff0c;就数据存储在文件中 File File是java.io.包下的类&#xff0c; Fi…

[Linux]基础IO(中)---理解重定向与系统调用dup2的使用、缓冲区的意义

重定向理解 在Linux下&#xff0c;当打开一个文件时&#xff0c;进程会遍历文件描述符表&#xff0c;找到当前没有被使用的 最小的一个下标&#xff0c;作为新的文件描述符。 代码验证&#xff1a; ①&#xff1a;先关闭下标为0的文件&#xff0c;在打开一个文件&#xff0c;…

基于 NGINX 的 ngx_http_geoip2 模块 来禁止国外 IP 访问网站

基于 NGINX 的 ngx_http_geoip2 模块 来禁止国外 IP 访问网站 一、安装 geoip2 扩展依赖 [rootfxkj ~]# yum install libmaxminddb-devel -y二、下载 ngx_http_geoip2_module 模块 [rootfxkj tmp]# git clone https://github.com/leev/ngx_http_geoip2_module.git三、解压模…

android 使用ollvm混淆so

使用到的工具 ndk 21.4.7075529&#xff08;android studio上下载的&#xff09;cmake 3.10.2.4988404&#xff08;android studio上下载的&#xff09;llvm-9.0.1llvm-mingw-20230130-msvcrt-x86_64.zipPython 3.11.5 环境配置 添加cmake mingw环境变量如下图: 编译 下载…

代码随想录算法训练营第四十一天|343. 整数拆分,96. 不同的二叉搜索树

343. 整数拆分 题目 给定一个正整数 n &#xff0c;将其拆分为 k 个 正整数 的和&#xff08; k > 2 &#xff09;&#xff0c;并使这些整数的乘积最大化。 返回 你可以获得的最大乘积 。 示例 输入: n 10 输出: 36 解释: 10 3 3 4, 3 3 4 36。 解题思路 dp[i] …

Python读取Excel根据每行信息生成一个PDF——并自定义添加文本,可用于制作准考证

文章目录 有点小bug的:最终代码(无换行):有换行最终代码无bug根据Excel自动生成PDF,目录结构如上 有点小bug的: # coding=utf-8 import pandas as pd from reportlab.pdfgen import canvas from reportlab.lib.pagesizes import letter from reportlab.pdfbase import pdf…

每日五道java面试题之消息中间件MQ篇(三)

目录&#xff1a; 第一题. 如何确保消息正确地发送至 RabbitMQ&#xff1f; 如何确保消息接收方消费了消息&#xff1f;第二题. 如何保证RabbitMQ消息的可靠传输&#xff1f;第三题. 为什么不应该对所有的 message 都使用持久化机制&#xff1f;第四题. 如何保证高可用的&#…

腾讯云2024年4月优惠券及最新活动入口

腾讯云是腾讯集团倾力打造的云计算品牌&#xff0c;提供全球领先的云计算、大数据、人工智能等技术产品与服务。为了吸引用户上云&#xff0c;腾讯云经常推出各种优惠活动。本文将为大家分享腾讯云优惠券及最新活动入口&#xff0c;助力大家轻松上云&#xff01; 一、优惠券领取…

IO-DAY4

使用文件IO 实现父进程向子进程发送信息&#xff0c;并总结中间可能出现的各种问题 #include<myhead.h> char* my_write(char *buf) {int wfdopen("./write.txt",O_WRONLY|O_CREAT|O_TRUNC,0666);write(wfd,buf,sizeof(buf));close(wfd);return buf; } char* …

一.基本指令(1.1)

一、操作系统&#xff1a; 1.1本质&#xff1a; 操作系统是一款进行软硬件资源管理的软件。 1.2操作系统如何管理硬件&#xff1a; 硬件接入电脑&#xff0c;操作系统装载硬件的驱动之后&#xff0c;硬件就会被纳入操作系统的管理体系。因此&#xff0c;有时一些硬件初次接入电…

HTTPS跟HTTP有区别吗?

HTTPS和HTTP的区别&#xff0c;白话一点说就是&#xff1a; 1. 安全程度&#xff1a; - HTTP&#xff1a;就像是你和朋友面对面聊天&#xff0c;说的话大家都能听见&#xff08;信息明文传输&#xff0c;容易被偷听&#xff09;。 - HTTPS&#xff1a;就像是你们俩戴着加密耳机…