华为开源自研AI框架昇思MindSpore应用案例:人体关键点检测模型Lite-HRNet

news2024/11/20 15:11:59

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区

在这里插入图片描述

在这里插入图片描述

一、环境准备

1.进入ModelArts官网

云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpore2.0.0-alpha版本,可以在昇思教程中进入ModelArts官网

在这里插入图片描述

选择下方CodeLab立即体验

在这里插入图片描述

等待环境搭建完成

在这里插入图片描述

2.使用CodeLab体验Notebook实例

下载NoteBook样例代码,Lite-HRNet实现人体关键点检测 ,.ipynb为样例代码

在这里插入图片描述

选择ModelArts Upload Files上传.ipynb文件

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

选择Kernel环境

在这里插入图片描述

切换至GPU环境,切换成第一个限时免费

在这里插入图片描述

进入昇思MindSpore官网,点击上方的安装

在这里插入图片描述

获取安装命令

在这里插入图片描述

回到Notebook中,在第一块代码前加入命令
在这里插入图片描述

conda update -n base -c defaults conda

在这里插入图片描述

安装MindSpore 2.0 GPU版本

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge

在这里插入图片描述

安装mindvision

pip install mindvision

在这里插入图片描述

安装下载download

pip install download

人体关键点检测模型Lite-HRNet

人体关键点检测是计算机视觉的基本任务之一,在许多应用场景诸如自动驾驶、安防等有着重要的地位。可以发现,在这些应用场景下,深度学习模型可能需要部署在IoT设备上,这些设备算力较低,存储空间有限,无法支撑太大的模型,因此轻量但不失高性能的人体关键点检测级模型将极大降低模型部署难度。Lite-HRNet便提供了一轻量级神经网络骨干,通过接上不同的后续模型可以完成不同的任务,其中便包括人体关键点检测,在配置合理的情况下,Lite-HRNet可以以大型神经网络数十分之一的参数量及计算量达到相近的性能。

模型简介

Lite-HRNet由HRNet(High-Resolution Network)改进而来,HRNet的主要思路是在前向传播过程中通过维持不同分辨率的特征,使得最后生成的高阶特征既可以保留低分辨率高阶特征中的图像语义信息,也可以保留高分辨率高阶特征中的物体位置信息,进而提高在分辨率敏感的任务如语义分割、姿态检测中的表现。Lite-HRNet是HRNet的轻量化改进,改进了HRNet中的卷积模块,将HRNet中的参数量从28.5M降低至1.1M,计算量从7.1GFLOPS降低至0.2GFLOPS,但AP75仅下降了7%。
综上,Lite-HRNet具有计算量、参数量低,精度可观的优点,有利于部署在物联网低算力设备上服务于各个应用场景。

数据准备

本案例使用COCO2017数据集作为训练、验证数据集,请首先安装Mindspore Vision套件,并确保安装的Mindspore是GPU版本,随后请在https://cocodataset.org/ 上下载好2017 Train Images、2017 Val Images以及对应的标记2017 Train/Val Annotations,并解压至当前文件夹,文件夹结构下表所示

Lite-HRNet/
    ├── imgs
    ├── src
    ├── annotations
        ├──person_keypoints_train2017.json
        └──person_keypoints_train2017.json
    ├── train2017
    └── val2017

训练、测试原始图片如下所示,图片中可能包含多个人体,且包含的人体不一定包含COCO2017中定义的17个关键点,标注中有每个人体的边框、关键点信息,以便处理图像后供模型训练。

数据预处理

src/mindspore_coco.py中定义了供mindspore模型训练、测试的COCO数据集接口,在加载训练数据集时只需指定所用数据集文件夹位置、输入图像的尺寸、目标热力图的尺寸、以及手动设置对训练图像采用的变换即可


import mindspore as ms
import mindspore.dataset as dataset
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.nn as nn
from mindspore.dataset.transforms.py_transforms import Compose

from src.configs.dataset_config import COCOConfig
from src.dataset.mindspore_coco import COCODataset

cfg = COCOConfig(root="./", output_dir="outputs/", image_size=[192, 256], heatmap_size=[48, 64])
trans = Compose([py_vision.ToTensor(),
                 py_vision.Normalize(
                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_ds = COCODataset(cfg, "../", "train2017", True, transform=trans)
train_loader = dataset.GeneratorDataset(train_ds, ["data", "target", "weight"])

在这里插入图片描述

构建网络

Lite-HRNet网络骨干大体结构如下图所示:
在这里插入图片描述

网络中存在不同的分辨率分支,网络主干上维持着较高分辨率、较少通道数的输出特征,网络分支上延展出较低分辨率、较多通道数的输出特征,且这些不同分辨率的特征之间通过上采样、下采样的卷积层进行交互、融合。Stage内的Cross Channel Weighting(CCW)则是网络实现轻量化的精髓,它将原HRNet中复杂度较高的1*1卷积以更低复杂度的Spatial Weighting等方法替代,从而实现降低网络参数、计算量的效果。CCW的结构如下图所示

在这里插入图片描述

值得注意的是,除了骨干网络,作者在论文中同时也给出了所使用的检测头即SimpleBaseline,为了简洁起见,在本次的Lite-HRNet的Mindspore实现中,检测头(代码中包括IterativeHeads和LiteTopDownSimpleHeatMap)已集成至骨干网络之后,作为整体模型的一部分,直接调用模型即可得到热力图预测输出。

损失函数

此处使用损失函数为JointMSELoss,即关节点的均方差误差损失函数,其源码如下所示,总体流程即计算每个关节点预测热力图与实际热力图的均方差,其中target是根据关节点的人工标注坐标,通过二维高斯分布生成的热力图,target_weight用于指定参与计算的关节点,若某关节点对应target_weight取值为0,则表明该关节点在输入图像中未出现,不参与计算。

"""JointMSELoss"""
import mindspore.nn as nn
import mindspore.ops as ops

class JointsMSELoss(nn.Cell):
    """Joint MSELoss"""
    def __init__(self, use_target_weight):
        """JointMSELoss"""
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def construct(self, output, target, weight):
        """construct"""
        target = target
        target_weight = weight
        batch_size = output.shape[0]
        num_joints = output.shape[1]
        spliter = ops.Split(axis=1, output_num=num_joints)
        mul = ops.Mul()
        heatmaps_pred = spliter(output.reshape((batch_size, num_joints, -1)))
        heatmaps_gt = spliter(target.reshape((batch_size, num_joints, -1)))
        loss = 0

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                heatmap_pred = mul(heatmap_pred, target_weight[:, idx])
                heatmap_gt = mul(heatmap_gt, target_weight[:, idx])
                loss += 0.5 * self.criterion(
                    heatmap_pred,
                    heatmap_gt
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

        return loss/num_joints

模型实现与训练

在实现模型时,需指定模型内部结构,在src/net_configs中已指定原论文中10种结构配置,在训练样例种取Lite_18_coco作为模型结构,此处作为案例,仅设置epoch数量为1,在实际训练中可以设置为200,并且可以加入warmup。由于mindspore的训练接口默认数据集中每条数据只有两列(即训练数据和标签),所以这里需自定义Loss Cell。值得注意的是loss在训练前后变化并不会十分大,训练好的模型的loss为0.0004左右

class CustomWithLossCell(nn.Cell):

    def __init__(self,
                 net: nn.Cell,
                 loss_fn: nn.Cell):
        super(CustomWithLossCell, self).__init__()
        self.net = net
        self._loss_fn = loss_fn

    def construct(self, img, target, weight):
        """ build network """
        heatmap_pred = self.net(img)
        return self._loss_fn(heatmap_pred,
                             target,
                             weight)
from src.configs.net_configs import get_netconfig
from mindspore.train.callback import  LossMonitor
from src.backbone import LiteHRNet

ext = get_netconfig("extra_lite_18_coco")
net = LiteHRNet(ext)
criterion = JointsMSELoss(use_target_weight=True)

train_loader = train_loader.batch(64)
optim = nn.Adam(net.trainable_params(), learning_rate=2e-3)
loss = JointsMSELoss(use_target_weight=True)
net_with_loss = CustomWithLossCell(net, loss)

model = ms.Model(network=net_with_loss, optimizer=optim)
epochs = 1
#Start Training
model.train(epochs, train_loader, callbacks=[LossMonitor(100)], dataset_sink_mode=False)

在这里插入图片描述

模型评估

模型评估过程中使用AP、AP50、AP75以及AR50、AR75作为评价指标,val2017作为评价数据集,pycocotool包中已实现根据评价函数,且src/mindspore_coco.py中的evaluate函数也实现了调用该评价函数的接口,只需提供预测关键点坐标等信息即可获得评价指标。此处载入Lite_18_coco的预训练模型进行评价。

from mindspore import load_checkpoint
from mindspore import load_param_into_net

from src.utils.utils import get_final_preds
import numpy as np

def evaluate_model(model, dataset, output_path):
    """Evaluate"""
    num_samples = len(dataset)
    all_preds = np.zeros(
        (num_samples, 17, 3),
        dtype=np.float32
        )

    all_boxes = np.zeros((num_samples, 6))
    image_path = []

    for i, data in enumerate(dataset):
        input_data, target, meta = data[0], data[1], data[3]
        input_data = ms.Tensor(input_data[0], ms.float32).reshape(1, 3, 256, 192)
        shit = model(input_data).asnumpy()
        target = target.reshape(shit.shape)
        c = meta['center'].reshape(1, 2)
        s = meta['scale'].reshape(1, 2)
        score = meta['score']
        preds, maxvals = get_final_preds(shit, c, s)
        all_preds[i:i + 1, :, 0:2] = preds[:, :, 0:2]
        all_preds[i:i + 1, :, 2:3] = maxvals
        # double check this all_boxes parts
        all_boxes[i:i + 1, 0:2] = c[:, 0:2]
        all_boxes[i:i + 1, 2:4] = s[:, 0:2]
        all_boxes[i:i + 1, 4] = np.prod(s*200, 1)
        all_boxes[i:i + 1, 5] = score
        image_path.append(meta['image'])

    dataset.evaluate(0, all_preds, output_path, all_boxes, image_path)

net_dict = load_checkpoint("./ckpt/litehrnet_18_coco_256x192.ckpt")
load_param_into_net(net, net_dict)

eval_ds = COCODataset(cfg, "./", "val2017", False, transform=trans)
evaluate_model(net, eval_ds, "./result")

在这里插入图片描述

模型推理

  1. Lite-HRNet是关键点检测模型,所以输入待推理图像应为包含单个人体的图像,作者在论文中提及在coco test 2017测试前已使用SimpleBaseline生成的目标检测Bounding Box处理图像,所以待推理图像应仅包含单个人体。
  2. 网络的输入为(1,3,256,192),所以在输入图像前应先将其变换成网络可处理的形式。
import cv2
from src.utils.utils import get_max_preds
origin_img = cv2.imread("./imgs/man.jpg")
origin_h, origin_w, _ = origin_img.shape
scale_factor = [origin_w/192, origin_h/256]

# resize to (112 112 3) and convert to tensor
img = cv2.resize(origin_img, (192, 256))
print(img.shape)
img = trans(img)
# img = np.expand_dims(img, axis=0)
img = ms.Tensor(img)
print(img.shape)

# Infer
heatmap_pred = net(img).asnumpy()
pred, _ = get_max_preds(heatmap_pred)

# Postprocess
pred = pred.reshape(pred.shape[0], -1, 2)
print(pred[0])
pre_landmark = pred[0] * 4 * scale_factor
# Draw points
for (x, y) in pre_landmark.astype(np.int32):
    cv2.circle(origin_img, (x, y), 3, (255, 255, 255), -1)

# Save image
cv2.imwrite("./imgs/man_infer.jpg", origin_img)

在这里插入图片描述

可以看到模型基本正确标注出了关键点的位置\

在这里插入图片描述

算法基本流程

  1. 获取原始数据
  2. 从数据集的标注json文件中得到各个图像bbox以及关键点坐标信息
  3. 根据bbox裁剪图像,并放缩至指定尺寸,如果是训练还可以作适当数据增强,生成指定尺寸的目标热力图
  4. 指定尺寸的输入经过网络前向传播后得到预测的关键点热力图
  5. 经过处理后取热力图中的最大值坐标作为关键点的预测坐标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2244106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一道算法期末应用题及解答

1.印刷电路板布线区划分成为n m 个方格,确定连接方格a 到方格b 的最短布线方案。 在布线时,只能沿直线或者直角布线,为避免交叉,已经布线的方格做了封锁标记,其他线路不允许穿过被封锁的方格,某…

Springboot项目搭建(1)-用户登录与注册

1.引入lombok依赖 若<dependency>中数据为红&#xff0c;则说明Maven本地仓库里未引用依赖 可在右侧“m”标识中&#xff0c;下载源代码和文档后刷新。 2.统一响应数据Result 在entity文档下创建&#xff0c;名为Result的java类 文件地址&#xff1a;org/example/enti…

用go语言后端开发速查

文章目录 一、发送请求和接收请求示例1.1 发送请求1.2 接收请求 二、发送form-data格式的数据示例 用go语言发送请求和接收请求的快速参考 一、发送请求和接收请求示例 1.1 发送请求 package mainimport ("bytes""encoding/json""fmt""ne…

【视频讲解】Python深度神经网络DNNs-K-Means(K-均值)聚类方法在MNIST等数据可视化对比分析...

全文链接&#xff1a;https://tecdat.cn/?p38289 分析师&#xff1a;Cucu Sun 近年来&#xff0c;由于诸如自动编码器等深度神经网络&#xff08;DNN&#xff09;的高表示能力&#xff0c;深度聚类方法发展迅速。其核心思想是表示学习和聚类可以相互促进&#xff1a;好的表示会…

可视化展示深度学习模型中模块的详细流程图:结合GraphvizOnline

一、在GPT中输入指令 根据以下Python模块代码&#xff0c;自动生成对应的Graphviz流程图代码&#xff0c;并保持图表简洁清晰&#xff0c;仅展示主流程&#xff1a; <模块代码>1. 以YOLOv9中ADown下采样为例&#xff1a; 根据以下Python模块代码&#xff0c;自动生成对…

强大的正则表达式——Hard

由前两篇文章《Easy》中提到过的&#xff1a; 还是先相信一下AI&#xff0c;让AI写个生成满足难度3的正则表达式的python代码&#xff0c;但还是出错了&#xff0c;还是不能什么都指望AI 了解了一下相关知识&#xff0c;CRC本质上是多项式除法&#xff0c;所以同样可以得到对应…

Xilinx 7 系列 FPGA的各引脚外围电路接法

Xilinx 7系列FPGA的外围电路接法涉及到多个方面&#xff0c;包括电源引脚、时钟输入引脚、FPGA配置引脚、JTAG调试引脚&#xff0c;以及其他辅助引脚。 本文大部分内容由ug475, Product Specification——7 Series FPGAs Packaging and Pinout《7系列FPGA的封装与引脚》整理汇…

IDM扩展添加到Edge浏览器

IDM扩展添加到Edge浏览器 一般情况下&#xff0c;当安装IDM软件后&#xff0c;该软件将会自动将IDM Integration Module浏览器扩展安装到Edge浏览器上&#xff0c;但在某些情况下&#xff0c;需要我们手动安装&#xff0c;以下为手动安装步骤 手动安装IDM扩展到Edge浏览器 打…

使用OpenUI智能生成专业级网页UI实现远程高效前端开发新手指南

文章目录 前言1. 本地部署Open UI1.1 安装Git、Python、pip1.2 安装Open UI 2. 本地访问Open UI3. 安装Cpolar内网穿透4. 实现公网访问Open UI5. 固定Open UI 公网地址 前言 今天给大家带来一篇非常实用的技术分享&#xff0c;介绍如何在Windows系统本地部署OpenUI&#xff0c…

Vue3 虚拟列表组件库 virtual-list-vue3 的使用

Vue3 虚拟列表组件库 virtual-list-vue3 的基本使用 分享个人写的一个基于 Vue3 的虚拟列表组件库&#xff0c;欢迎各位来进行使用与给予一些更好的建议&#x1f60a; 概述&#xff1a;该组件组件库用于提供虚拟化列表能力的组件&#xff0c;用于解决展示大量数据渲染时首屏渲…

数据库中库的操作

数据库中库的操作 查看数据库语法 创建数据库语法⽰例创建⼀个名为test班级号的数据库⾃定义⼀个数据库名&#xff0c;如果数据库不存则创建重新运⾏上⾯的语句观察现象查看警告信息 字符集编码和校验(排序)规则查看数据库⽀持的字符集编码查看数据库⽀持的排序规则不同的字串集…

【MySQL-3】表的约束

目录 1. 整体学习的思维导图 2. 非空约束 3. default约束 4. No Null和default约束 5. 列描述 comment 6. Zerofill 7. 主键 primary key 复合主键 8. 自增长 auto_increment 9. 唯一键 10. 外键 11. 实现综合案例 1. 整体学习的思维导图 2. 非空约束 正如该标题一…

C++设计模式行为模式———迭代器模式

文章目录 一、引言二、迭代器模式三、总结 一、引言 迭代器模式是一种行为设计模式&#xff0c; 让你能在不暴露集合底层表现形式 &#xff08;列表、 栈和树等&#xff09; 的情况下遍历集合中所有的元素。C标准库中内置了很多容器并提供了合适的迭代器&#xff0c;尽管我们不…

自存 sql常见语句和实际应用

关于连表 查询两个表 SELECT * FROM study_article JOIN study_article_review 查询的就是两个表相乘&#xff0c;结果为两个表的笛卡尔积 相这样 这种并不是我们想要的结果 通常会添加一些查询条件 SELECT * FROM study_articleJOIN study_article_review ON study_art…

为自动驾驶提供高分辨率卫星图像数据,实例级标注数据集OpenSatMap

对于交通控制、自动驾驶等任务来说&#xff0c;大规模的高分辨率与更新频率的地图至关重要。现有的地图构建方法多依赖地面采集数据&#xff0c;这种方法的精度固然较高&#xff0c;但在覆盖范围、更新频率却存在限制&#xff0c;测绘成本也相当高昂。 相比之下&#xff0c;使…

基于STM32的智能语音识别饮水机系统设计

功能描述 1、给饮水机设定称呼&#xff0c;喊出称呼&#xff0c;饮水机回答&#xff1a;我在 2、语音进行加热功能&#xff0c;说&#xff1a;请加热&#xff0c;加热片运行 3、饮水机水位检测&#xff0c;低于阈值播报“水量少&#xff0c;请换水” 4、检测饮水机水温&#xf…

百度世界2024精选公开课:基于地图智能体的导航出行AI应用创新实践

11月12日&#xff0c;“百度世界2024”在上海世博中心举行。百度创始人、董事长兼首席执行官李彦宏发表了主题为《应用来了》的演讲。 百度地图也为大家带来了干货满满、精彩纷呈的智能体公开课&#xff0c;由百度地图开放平台技术架构师江畅分享《地图智能体&#xff1a;导航…

sourceInsight常用设置和功能汇总(不断更新)(RGB、高亮、全路径、鼠标、宏、TODO高亮)

文章目录 必开配置设置背景颜色护眼的RGB值&#xff1f;sourceInsight4.0中如何设置选中某个单词以后自动高亮的功能&#xff1f;sourceinsight中输入设置显示全路径&#xff1f; 常用sourceInsight4.0中文乱码怎么解决&#xff0c;注意事项是什么&#xff1f;如何绑定鼠标中键…

[JavaWeb] 尚硅谷JavaWeb课程笔记

1 Tomcat服务器 Tomcat目录结构 bin&#xff1a;该目录下存放的是二进制可执行文件&#xff0c;如果是安装版&#xff0c;那么这个目录下会有两个exe文件&#xff1a;tomcat10.exe、tomcat10w.exe&#xff0c;前者是在控制台下启动Tomcat&#xff0c;后者是弹出GUI窗口启动To…

uniapp开发微信小程序笔记2-开发静态页面(新建页面、内置组件、设置编译模式、样式、SCSS的使用)

前言&#xff1a;本文从新建页面、认识内置组件、设置编译模式、样式、SCSS的使用来逐步形成对微信小程序开发结构的认识 一、新建页面 pages就是放页面代码的文件夹&#xff0c;点击新建页面就可以自动新增页面&#xff0c;并且可以看到pages.json里面也会自动添加该页面的路…