关键点检测——HRNet源码解析篇

news2025/1/22 15:46:49

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

HRNet源码解析篇

写在前面

Hello,大家好,我是小苏👦🏽👦🏽👦🏽

在上一节中,我已经为大家介绍HRNet的原理部分,其实说起来挺惭愧,因为原理部分介绍的还是比较简单的,我想你仅仅阅读原理部分是很难彻底弄明白HRNet的精髓。那么本节将在上一节的基础上,为大家更细致的讲解HRNet。🧨🧨🧨当然了,本节属于源码解析篇,所有会存在比较多的代码,大家也不用担心看不懂,我都对关键代码做了详细的解释,并画图进一步帮助大家理解,所以大家一定要耐心看下去喔。🎯🎯🎯

这里我还想多说一句,其实写源码解析类博客其实怪难的,因为有时候明明很好表达的内容,用文字为大家展示却有种一拳打在棉花上的无力感,哈哈哈,可能是自己水平不够叭。🧑🧑🧑自己也做过几期视频,感觉效果也还行,感兴趣的可以点击☞☞☞看看,后期可能会考虑写完博客出配套视频的方式为大家介绍知识点。enmmm,说远了,说这些我是想告诉大家,我并不期望通过这一篇文章就能让你把整个HRNet的源码都看透,这是不可能的。但是其中一些关键的地方,如果本文能给你一点启发,那我觉得此篇文章的目的就达到了。此外,大家在阅读代码时,一定不要停留在看的层面,一定要动手调试起来,这样会有不一样的收获。🍋🍋🍋

好了,不说废话了,让我们一起发车,来学学HRNet的源码叭~~~🚖🚖🚖

源码地址:HRNet源码🌱🌱🌱

关键点数据集构建

深度学习中数据才是王道,本文使用的是COCO数据集中的人体关键点检测数据集,对此数据集还不清楚的务必点击下面链接了解详情:

COCO数据集——关键点检测标注文件解析🌼🌼🌼

清楚COCO数据集的格式后,我们一起来看看是如何构建关键点数据集的?首先来说说这里的关键点数据集的构建主要干了什么?其实它就是把原始图像中对人体关键点标注过的图像记录了下来。

我们一点点的来看其是如何实现的,主要定义在CocoKeypoint类中:

data_root = args.data_path   # data_path:'D://Dataset//coco2017'
data_transform = {
        "train": transforms.Compose([
            transforms.HalfBody(0.3, person_kps_info["upper_body_ids"], person_kps_info["lower_body_ids"]),
            transforms.AffineTransform(scale=(0.65, 1.35), rotation=(-45, 45), fixed_size=fixed_size),
            transforms.RandomHorizontalFlip(0.5, person_kps_info["flip_pairs"]),
            transforms.KeypointToHeatMap(heatmap_hw=heatmap_hw, gaussian_sigma=2, keypoints_weights=kps_weights),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
            transforms.AffineTransform(scale=(1.25, 1.25), fixed_size=fixed_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    }

fixed_size = args.fixed_size  # fixed_size:[256,192]

train_dataset = CocoKeypoint(data_root, "train", transforms=data_transform["train"], fixed_size=args.fixed_size)

上述代码主要展示了传入CocoKeypoint类中的参数,为后面阅读CocoKeypoint类中代码做准备,对本篇transforms不熟悉的可以稍后阅读本文的下一节–>在线数据增强。🥗🥗🥗

下面我们就一步步的来看看CocoKeypoint类到底干了什么?在__init__函数中,先初始化了一系列变量:

anno_file = f"person_keypoints_{dataset}{years}.json"
self.img_root = os.path.join(root, f"{dataset}{years}")
self.anno_path = os.path.join(root, "annotations", anno_file)
self.fixed_size = fixed_size
self.mode = dataset
self.transforms = transforms
self.coco = COCO(self.anno_path)

我们调试来看看这些值的结果:

image-20231204135905433

这里我重点介绍一下self.coco = COCO(self.anno_path)这句代码,其传入的是self.anno_path参数,即人体关键点检测标注文件——'D://Dataset//coco2017\\annotations\\person_keypoints_train2017.json'我们跳入COCO函数内部调试一下:

首先设置一些字典变量来存储相关信息:

self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)

接着我们会打开标注文件路径并读取得到dataset:

with open(annotation_file, 'r') as f:
    dataset = json.load(f)

我们来看看dataset的值:

image-20231204141135654

dataset一共有五个字段的值,和我在COCO数据集——关键点检测标注文件解析这篇博客中介绍的是完全一致的。

接着调用createIndex方法为之前定义的字典变量赋值:

if 'annotations' in self.dataset:
    for ann in self.dataset['annotations']:
        imgToAnns[ann['image_id']].append(ann)
        anns[ann['id']] = ann

这段代码先是遍历数据集中的annotations标签,然后将其image_id作为键,标签作为值构建一个图像id到标签的字典imgToAnns,来看看其遍历一次的结果:

image-20231204143553020

接着是将标签的id作为键,标签作为值构建一个字典,同样看看遍历一次的结果:

image-20231204143840903

注意:这里的id和image_id不一样,image_id是图像的唯一标识,id是目标实例分配的唯一标识符,用于在数据集中唯一标识这个目标实例。一个图像中可能会有多个目标实列,即有多个人。

我们遍历完所有标签,看看imgToAnns和anns中有多少数据:

image-20231204144404625

image-20231204144424869

可以看到imgToAnns共有64115条数据,表示一共有64115张图像存在标注。anns共有262465条数据,表示一共有262465个标注目标实列,也就是标注了262465个人。那么为什么会存在这样的差异呢?因为不是每张图像都会有标注目标实列(一共118287张图像,有标注目标实列的有64115张),也不是每张存在标注目标实列的都只有一个标注目标实列(最少有一个,从imgToAnns数据图中可以看到,image_id为120021的图像有三个标注目标实列,说明标注了3个人,我们也可以来看看这张图像,看看是不是有3个人,如下:

image-20231204145952059)


if 'images' in self.dataset:
    for img in self.dataset['images']:
        imgs[img['id']] = img

接着这段代码是遍历数据集中的image图像,然后将id【注意:image中的id指的是image_id,而不是上文说的实列id】作为键,img图像作为值构建字典imgs,来看看遍历一次的结果:

image-20231204150549882

然后来看看遍历完所有数据imgs的结果:

image-20231204150712467

一共有118287条数据,这就是COCO训练集图片的数量。


if 'categories' in self.dataset:
    for cat in self.dataset['categories']:
        cats[cat['id']] = cat

这段代码是遍历数据集中的categories类别,然后将其id【注意:这里的id指类别的唯一标识符。在人体关键点检测中,这个id都是1,因为我们只会对人体进行标注,而person的类别标识符是1。】作为键,categories信息作为值构建cats字典,来看看遍历一次的结果:

image-20231204151430421

enmmm,dataset中只有一条数据,只能遍历一次,因为只有一个类别id,即id=1。

image-20231204151631585


if 'annotations' in self.dataset and 'categories' in self.dataset:
    for ann in self.dataset['annotations']:
        catToImgs[ann['category_id']].append(ann['image_id'])

这段代码同样遍历annotations标签,然后将 category_id【注意:这个是类别id,其为1】作为键,image_id作为值构建从类别id到图像id映射的字典 catToImgs,来看一次遍历的结果:

image-20231204152232436

然后来看看遍历完所有数据的catToImgs字典:

image-20231204152351685

一共有262465条数据,即表示有262465张图像有类别id 1(这里包括了重复的图像,比如一个图像中有3个人,那么这里就有三条数据,那么其实这里的262465表示一共标注了262465个person实列)


到这里我们的COCO(self.anno_path)函数的内容就介绍完啦,我们来看看self.coco的值,如下:

image-20231204153552424

其实其就是COCO(self.anno_path)函数中那几个字典变量。

接着我们会把imgs的key进行排序,并转成列表:

img_ids = list(sorted(self.coco.imgs.keys()))

image-20231204153830130

然后通过det = self.coco将self.coco的值赋给det,并设置一个self.valid_person_list列表用于存储有效的人体关键点信息,并设置一个obj_idx记录目标实列个数。

self.valid_person_list = []
obj_idx = 0

最后执行下面的代码:

for img_id in img_ids:
    img_info = self.coco.loadImgs(img_id)[0]
    ann_ids = det.getAnnIds(imgIds=img_id)
    anns = det.loadAnns(ann_ids)
    for ann in anns:
        # only save person class
        if ann["category_id"] != 1:
            print(ann["category_id"])

            # skip objs without keypoints annotation
            if "keypoints" in ann:
                if max(ann["keypoints"]) == 0:
                    continue

                    xmin, ymin, w, h = ann['bbox']
                    # Use only valid bounding boxes
                    if w > 0 and h > 0:
                        info = {
                            "box": [xmin, ymin, w, h],
                            "image_path": os.path.join(self.img_root, img_info["file_name"]),
                            "image_id": img_id,
                            "image_width": img_info['width'],
                            "image_height": img_info['height'],
                            "obj_origin_hw": [h, w],
                            "obj_index": obj_idx,
                            "score": ann["score"] if "score" in ann else 1.
                        }
                        if "keypoints" in ann:
                            keypoints = np.array(ann["keypoints"]).reshape([-1, 3])
                            visible = keypoints[:, 2]
                            keypoints = keypoints[:, :2]
                            info["keypoints"] = keypoints
                            info["visible"] = visible

                            self.valid_person_list.append(info)
                            obj_idx += 1

这段代码干了什么呢,我们一点点来分析:首先是遍历img_ids,第一次拿到第一个image_id=9:

image-20231204155004345

此image_id对应图像如下:

000000000009

然后执行img_info = self.coco.loadImgs(img_id)[0],loadImgs定义如下:

def loadImgs(self, ids=[]):
    if _isArrayLike(ids):
        return [self.imgs[id] for id in ids]
    elif type(ids) == int:
        return [self.imgs[ids]]

这个函数主要是根据img_id来加载图像,我们直接来看img_info的结果:

image-20231204155602708

这显示了image_id=9的图像的信息。


接着是ann_ids = det.getAnnIds(imgIds=img_id)这句代码,getAnnIds函数如下:

def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
    imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
    catIds = catIds if _isArrayLike(catIds) else [catIds]

    if len(imgIds) == len(catIds) == len(areaRng) == 0:
        anns = self.dataset['annotations']
    else:
        if not len(imgIds) == 0:
            lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
            anns = list(itertools.chain.from_iterable(lists))
        else:
            anns = self.dataset['annotations']
            anns = anns if len(catIds)  == 0 else [ann for ann in anns if ann['category_id'] in catIds]
            anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
            if not iscrowd == None:
                ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
            else:
                ids = [ann['id'] for ann in anns]
    return ids

我们注意来看一下这句:lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns],意思是遍历imgIds,然后判断imgId是否存在self.imgToAnns中,我们知道imgToAnns共有64115条数据,表示一共有64115张图像存在标注,也就是说只有存在人物的图像才存在标注,即只有图像存在人物,其imgId才会在imgToAnns中,而我们刚刚image_id=9的图像不存在人物,故lists=[]是空列表,后续返回值会是空列表。因此我们需要换一张存在人物的图像进行展示,当遍历至image_id=36时,图像出现人物,图像如下:

000000000036

此时lists值如下,为这张图像的标注信息:

image-20231204161257749

接着anns = list(itertools.chain.from_iterable(lists))是将将嵌套的列表(lists)展平成一个单层的列表。

image-20231204161444188

后面的这几句都没起作用:

anns = anns if len(catIds)  == 0 else [ann for ann in anns if ann['category_id'] in catIds]
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
if not iscrowd == None:
    ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]

然后通过ids = [ann['id'] for ann in anns]获取到标注目标实列的id,并将其返回给ann_ids,其值如下:

接着是anns = det.loadAnns(ann_ids),将ann_ids传入loadAnns方法中,其定义如下:

def loadAnns(self, ids=[]):
    if _isArrayLike(ids):
        return [self.anns[id] for id in ids]
    elif type(ids) == int:
        return [self.anns[ids]]

这个函数主要通过ann_ids来加载标注信息,返回的anns如下:

image-20231204162611755

接着是遍历anns,先检查ann[“category_id”] 是否为1并检查标注的keypoints关键点是否存在可见关键点。

for ann in anns:
    # only save person class
    if ann["category_id"] != 1:
        print(ann["category_id"])

    # skip objs without keypoints annotation
    if "keypoints" in ann:
        if max(ann["keypoints"]) == 0:
            continue

然后从标注的bbox中获取xmin, ymin, w, h,并构建info信息,注意这里score不在ann中,其最后值为1。

xmin, ymin, w, h = ann['bbox']
# Use only valid bounding boxes
if w > 0 and h > 0:
    info = {
        "box": [xmin, ymin, w, h],
        "image_path": os.path.join(self.img_root, img_info["file_name"]),
        "image_id": img_id,
        "image_width": img_info['width'],
        "image_height": img_info['height'],
        "obj_origin_hw": [h, w],
        "obj_index": obj_idx,
        "score": ann["score"] if "score" in ann else 1.
    }

然后将关键点的坐标和可见性分成两个变量表示并加入到info字典中,最后将info添加到self.valid_person_list中,并将obj_idx加1,表示多了一个目标实列 。

if "keypoints" in ann:
    keypoints = np.array(ann["keypoints"]).reshape([-1, 3])
    visible = keypoints[:, 2]
    keypoints = keypoints[:, :2]
    info["keypoints"] = keypoints
    info["visible"] = visible

self.valid_person_list.append(info)
obj_idx += 1

此循环代码结束,来看看self.valid_person_list.append的值:

image-20231204163748952

【注意:这里的图像只有一个人物,如果图像中包含多个人物原理是一样的,会遍历图像中的各个人物,并把每个人物的信息存放到valid_person_list中】

当我们遍历完所有img_ids数据时,self.valid_person_list.append就存储了所有有效的人体关键点检测的相关信息,最后一共有149813个有效数据。

image-20231204164659648


到这里,关键点检测数据集的构建部分就为大家介绍完了,这部分说难也算不上难,但我认为却是非常重要的一部分,希望大家好好消化一下。🍻🍻🍻

在线数据增强

enmmm,我想大家应该对数据增强有一定的了解叭,比如旋转、剪裁、水平翻转等等,在代码中通常使用transforms.xxx来实现,在之前的博客中,我为大家介绍过一些数据增强,如Faster RCNN中为大家介绍了水平翻转。

这一小节我准备多花点时间来写,因为这部分有的地方是比较难理解,也是蛮重要的。话不多说,我们一起来看看HRNet中使用了哪些数据增强手段:

data_transform = {
        "train": transforms.Compose([
            transforms.HalfBody(0.3, person_kps_info["upper_body_ids"], person_kps_info["lower_body_ids"]),
            transforms.AffineTransform(scale=(0.65, 1.35), rotation=(-45, 45), fixed_size=fixed_size),
            transforms.RandomHorizontalFlip(0.5, person_kps_info["flip_pairs"]),
            transforms.KeypointToHeatMap(heatmap_hw=heatmap_hw, gaussian_sigma=2, keypoints_weights=kps_weights),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
            transforms.AffineTransform(scale=(1.25, 1.25), fixed_size=fixed_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    }

其训练集和测试集采用不同的数据增强手段,测试集使用的数据增强方法是训练集的子集,所以直接来看训练集中的方法就好了,一共有以下四个:

  • transforms.HalfBody
  • transforms.AffineTransform
  • transforms.RandomHorizontalFlip
  • transforms.KeypointToHeatMap

注意: transforms.ToTensor和transforms.Normalize不属于数据增强,而是处于数据预处理,对这两个不熟悉的可以点击下面的链接了解详情:

pytorch中的transforms.ToTensor和transforms.Normalize理解🍁🍁🍁


下面我将来一个个的为大家介绍这四种数据增强方式,快来和我一起学学叭~~~🥂🥂🥂

在具体介绍每种数据增强方法之前,我先来给大家展示一下本次调试使用的图片,如下:

image-20231123192226619

其shape为(426,640,3),为COCO验证集的第一张图片。

transforms.HalfBody

HalfBody——一半的身体,大家可以猜猜这个数据增强手段干了什么?好叭,不卖关子了,一句话解释它干了什么,就是以一定的概率让人体关键点保留上半部分或者下半部分。你或许会疑惑为什么要把完整的人切分成上半部分和下半部分,这是为了模拟关键点检测中的部分遮挡情况。在实际场景中,人物可能被其他对象或者场景的遮挡,这样的情况会使得关键点检测更加具有挑战性。而我们使用HalfBody数据增强,只使用部分身体进行训练,可以实现类似遮挡的效果。这可以帮助模型学习如何处理部分遮挡的情况,提高模型在真实场景中的鲁棒性。

知道了HalfBody的原理,下面就来看看这个HalfBody类是如何实现的:

首先来看看__init__方法:

def __init__(self, p: float = 0.3, upper_body_ids=None, lower_body_ids=None):
    assert upper_body_ids is not None
    assert lower_body_ids is not None
    self.p = p
    self.upper_body_ids = upper_body_ids
    self.lower_body_ids = lower_body_ids

这个初始化方法主要定义了一个概率p,upper_body_ids和lower_body_ids。upper_body_ids和lower_body_ids是指人的身体的上部和下部的索引,0,1,2,3,4,5,6,7,8,9,10为上,其它为下:

image-20231123194345538

为了让大家直观的感受,作图如下:

image-20231123194632684

注意:上图索引是从1开始的,代码索引是从0开始的


接着来看__call__方法:

def __call__(self, image, target):
    if random.random() < self.p:
        kps = target["keypoints"]
        vis = target["visible"]
        upper_kps = []
        lower_kps = []

        # 对可见的keypoints进行归类
        for i, v in enumerate(vis):
            if v > 0.5:
                if i in self.upper_body_ids:
                    upper_kps.append(kps[i])
                else:
                    lower_kps.append(kps[i])

        # 50%的概率选择上或下半身
        if random.random() < 0.5:
            selected_kps = upper_kps
        else:
            selected_kps = lower_kps

        # 如果点数太少就不做任何处理
        if len(selected_kps) > 2:
            selected_kps = np.array(selected_kps, dtype=np.float32)
            xmin, ymin = np.min(selected_kps, axis=0).tolist()  # 寻找x,y坐标的最小值
            xmax, ymax = np.max(selected_kps, axis=0).tolist()  # 寻找x,y坐标的最大值
            w = xmax - xmin
            h = ymax - ymin
            if w > 1 and h > 1:
                # 把w和h适当放大点,要不然关键点处于边缘位置
                xmin, ymin, w, h = scale_box(xmin, ymin, w, h, (1.5, 1.5))
                target["box"] = [xmin, ymin, w, h]

    return image, target

对上述代码做相关解释,首先以一定概率p(0.3)对图像进行HalfBody操作,若满足条件,获取关键点坐标和可见性,如下:

image-20231123195501184

image-20231123195557711

接着按照上半身和下半身对17个关键点进行分类,结果如下:【注意:这里只有15个关键点,因为vis表格有两个0值,表示有两个点没有标注,被if v > 0.5过滤掉了】

image-20231123195901289

上半身一个9个关键点,下半身一共6个关键点,这就是根据upper_body_ids和lower_body_ids来划分的。【注:大家这里要是不理解一定要自己调试看看】

然后会以0.5的概率选择上半身的关键点或者下半身的关键点。如果发现选择的一半身体的关键点个数小于等于2个,则不做任何处理,返回原有的image和target。若关键点个数大于2,则执行以下代码:

if len(selected_kps) > 2:
    selected_kps = np.array(selected_kps, dtype=np.float32)
    xmin, ymin = np.min(selected_kps, axis=0).tolist()  # 寻找x,y坐标的最小值
    xmax, ymax = np.max(selected_kps, axis=0).tolist()  # 寻找x,y坐标的最大值
    w = xmax - xmin
    h = ymax - ymin
    if w > 1 and h > 1:
    # 把w和h适当放大点,要不然关键点处于边缘位置
    xmin, ymin, w, h = scale_box(xmin, ymin, w, h, (1.5, 1.5))
    target["box"] = [xmin, ymin, w, h]

我先来介绍一下这段代码主要干了什么,其实就是找到新的目标(上半身或下半身)的bbox,我画图为大家解释一下:

关键点-第 2 页.drawio

这里我们其实是可以得到一个bbox了,但是其太靠近物体边缘了,放大1.5倍,代码如下:

def scale_box(xmin: float, ymin: float, w: float, h: float, scale_ratio: Tuple[float, float]):
    """根据传入的h、w缩放因子scale_ratio,重新计算xmin,ymin,w,h"""
    s_h = h * scale_ratio[0]
    s_w = w * scale_ratio[1]
    xmin = xmin - (s_w - w) / 2.
    ymin = ymin - (s_h - h) / 2.
    return xmin, ymin, s_w, s_h

同样画个图帮大家理解,如下:

关键点-第 3 页.drawio

transforms.AffineTransform

AffineTransform——仿射变化,这个是干什么的,我来帮大家解释一下这个仿射变换干了什么,其实就是需要原始图像和目标图像中三个对应点**(代码中使用的是图像中心点,上边界中心点和右边界中心点)**,然后通过这三个点将原始图像变换倒目标图像。

enmmm,大家是不是没怎么明白,别急,我会带大家看看代码,并可视化输出结果,这样大家就能直观的感受到仿射变换到底做了什么了。

首先第一步会调整上一步骤得到的bbox的长宽比,使其符合h:w=256:192,这个是我们输入图片的尺寸,具体代码如下:

src_xmin, src_ymin, src_xmax, src_ymax = adjust_box(*target["box"], self.fixed_size)

def adjust_box(xmin: float, ymin: float, w: float, h: float, fixed_size: Tuple[float, float]):
    """通过增加w或者h的方式保证输入图片的长宽比固定"""
    xmax = xmin + w
    ymax = ymin + h

    hw_ratio = fixed_size[0] / fixed_size[1]
    if h / w > hw_ratio:
        # 需要在w方向padding
        wi = h / hw_ratio
        pad_w = (wi - w) / 2
        xmin = xmin - pad_w
        xmax = xmax + pad_w
    else:
        # 需要在h方向padding
        hi = w * hw_ratio
        pad_h = (hi - h) / 2
        ymin = ymin - pad_h
        ymax = ymax + pad_h

    return xmin, ymin, xmax, ymax

我们可以看看调整bbox后的图像,如下:

image-20231124155146491

你可能看不出bbox的长宽比,但其就是256:192,其高度h为71.507,宽度w为53.630,精度上有点差别,不相信的大家自己去调试试试喔~~~🍡🍡🍡

接着我们就寻找原图像(bbox)和目标图像的三个点:

src_w = src_xmax - src_xmin
src_h = src_ymax - src_ymin

src_center = np.array([(src_xmin + src_xmax) / 2, (src_ymin + src_ymax) / 2])
src_p2 = src_center + np.array([0, -src_h / 2])  # top middle
src_p3 = src_center + np.array([src_w / 2, 0])   # right middle

dst_center = np.array([(self.fixed_size[1] - 1) / 2, (self.fixed_size[0] - 1) / 2])
dst_p2 = np.array([(self.fixed_size[1] - 1) / 2, 0])  # top middle
dst_p3 = np.array([self.fixed_size[1] - 1, (self.fixed_size[0] - 1) / 2])

然后对bbox进行缩放和旋转,先是缩放,

if self.scale is not None:
    scale = random.uniform(*self.scale)
    src_w = src_w * scale
    src_h = src_h * scale
    src_p2 = src_center + np.array([0, -src_h / 2])  # top middle
    src_p3 = src_center + np.array([src_w / 2, 0])   # right middle

我们来看看缩放后的bbox: 【注意这里我调试时的scale取0.7115,所以bbox变小了】

image-20231124155839448

然后是旋转:【注意这里我调试时的angle取-25,所以bbox逆时针旋转了25°】

if self.rotation is not None:
    angle = random.randint(*self.rotation)  # 角度制
    angle = angle / 180 * math.pi  # 弧度制
    src_p2 = src_center + np.array([src_h / 2 * math.sin(angle), -src_h / 2 * math.cos(angle)])
    src_p3 = src_center + np.array([src_w / 2 * math.cos(angle), src_w / 2 * math.sin(angle)])

我们再来看看旋转后的结果:

image-20231124160500978

最后就是仿射变换了:

src = np.stack([src_center, src_p2, src_p3]).astype(np.float32)
dst = np.stack([dst_center, dst_p2, dst_p3]).astype(np.float32)

trans = cv2.getAffineTransform(src, dst)  # 计算正向仿射变换矩阵
dst /= 4  # 网络预测的heatmap尺寸是输入图像的1/4
reverse_trans = cv2.getAffineTransform(dst, src)  # 计算逆向仿射变换矩阵,方便后续还原


# 对图像进行仿射变换
resize_img = cv2.warpAffine(img,
                            trans,
                            tuple(self.fixed_size[::-1]),  # [w, h]
                            flags=cv2.INTER_LINEAR)

同样的我们来看看最后的resize_img长什么样,如下:【resize_img的大小是256*192喔】

image-20231124160953706

到这里我们对图像的操作就完成了,不要忘记我们还要对标签进行同样的操作喔,关键点检测的标签就是一个个点嘛,如下:

if "keypoints" in target:
    kps = target["keypoints"]
    mask = np.logical_and(kps[:, 0] != 0, kps[:, 1] != 0)
    kps[mask] = affine_points(kps[mask], trans)
    target["keypoints"] = kps
    
def affine_points(pt, t):
    ones = np.ones((pt.shape[0], 1), dtype=float)
    pt = np.concatenate([pt, ones], axis=1).T
    new_pt = np.dot(t, pt)
    return new_pt.T
           

最后再来用一个图来总结一下仿射变换都做了什么,如下:

关键点-第 4 页.drawio

transforms.RandomHorizontalFlip

RandomHorizontalFlip——水平翻转。我想对这个数据增强手段大家都比较熟悉,就是将图片左右进行翻转,其最后实现的效果如下图所示:

image-20231124211610488

和仿射变换一样,要实现水平翻转,我们不仅需要对图片进行水平翻转操作,同样需要对标签进行同步操作,我们分别来看看如何对图片和标签进行水平翻转操作的叭。

  • 图片

    对图片进行水平翻转的操作很简单啦,只需要一行代码就可以了喔,如下:

    image = np.ascontiguousarray(np.flip(image, axis=[1]))    # 水平左右翻转
    
  • 标签

    对标签进行翻转是这步的难点,我当时阅读这部分的代码时弄了很长时间才明白,其实要画一个图大家就能很容易的理解。我们先来看代码叭:【注:我没有复制所有代码过来了,挑了关键的代码】

    # Flip horizontal
    keypoints[:, 0] = width - keypoints[:, 0] - 1
    

    这句代码什么意思呢?其实就是将关键点水平翻转了一下,作图帮大家理解:

    大家会不会认为这样就结束了呢,其实很没有,我们来看看代码中还做了什么,如下:

    # matched_parts这些值对应COCO人体关键点,交换人体关键点中对称的点,这个是person_keypoints.json文件中的flip_pairs
    # Change left-right parts
    for pair in self.matched_parts:
        keypoints[pair[0], :], keypoints[pair[1], :] = \
        keypoints[pair[1], :], keypoints[pair[0], :].copy()
    
        visible[pair[0]], visible[pair[1]] = \
        visible[pair[1]], visible[pair[0]].copy()
    

    这步交换了人体关键点中堆成的关键点,为什么要这么做,我当时就是这里疑惑了好久好久,我来画个图帮大家理解一下:

    image-20231124215904452

这样你可能还看不出端倪,我在画出水平翻转后的图像,如下:

image-20231124215958175

你会发现如果单纯的将两个点对应过来,左右关系会对调,因此需要把标签进行左右互换。【大家这里如果觉得不好理解的话,可以自己动动手,画画图,相信你会有所收获】

transforms.KeypointToHeatMap

KeypointToHeatMap——将关键点映射为热力图。我们在理论部分说到,HRNet是基于热力图实现关键点检测,不清楚的可以去原理详解篇寻找寻找答案。那么其是怎么将关键点映射成热力图的呢,我们一起来看看代码是怎么实现的叭。

首先,先来看看其__init__函数:

def __init__(self,
             heatmap_hw: Tuple[int, int] = (256 // 4, 192 // 4),
             gaussian_sigma: int = 2,
             keypoints_weights=None):
    self.heatmap_hw = heatmap_hw
    self.sigma = gaussian_sigma
    self.kernel_radius = self.sigma * 3
    self.use_kps_weights = False if keypoints_weights is None else True
    self.kps_weights = keypoints_weights

    # generate gaussian kernel(not normalized)
    kernel_size = 2 * self.kernel_radius + 1
    kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    x_center = y_center = kernel_size // 2
    for x in range(kernel_size):
        for y in range(kernel_size):
            kernel[y, x] = np.exp(-((x - x_center) ** 2 + (y - y_center) ** 2) / (2 * self.sigma ** 2))
            # print(kernel)

            self.kernel = kernel

这段主要定义了存储热力图的宽度和高度、高斯标准差和关键点权重等信息,然后生成了一个大小为13*13的高斯核kernel(中间的值大,往四周扩散值越来越小),如下图所示:

image-20231127195949780

接着我们来看__call__函数:

def __call__(self, image, target):
    kps = target["keypoints"]
    num_kps = kps.shape[0]
    kps_weights = np.ones((num_kps,), dtype=np.float32)
    if "visible" in target:
        visible = target["visible"]
        kps_weights = visible

        heatmap = np.zeros((num_kps, self.heatmap_hw[0], self.heatmap_hw[1]), dtype=np.float32)
        heatmap_kps = (kps / 4 + 0.5).astype(np.int)  # round
        for kp_id in range(num_kps):
            v = kps_weights[kp_id]
            if v < 0.5:
                # 如果该点的可见度很低,则直接忽略
                continue

                x, y = heatmap_kps[kp_id]
                ul = [x - self.kernel_radius, y - self.kernel_radius]  # up-left x,y
                br = [x + self.kernel_radius, y + self.kernel_radius]  # bottom-right x,y
                # 如果以xy为中心kernel_radius为半径的辐射范围内与heatmap没交集,则忽略该点(该规则并不严格)
                if ul[0] > self.heatmap_hw[1] - 1 or \
                ul[1] > self.heatmap_hw[0] - 1 or \
                br[0] < 0 or \
                br[1] < 0:
                    # If not, just return the image as is
                    kps_weights[kp_id] = 0
                    continue

                    # Usable gaussian range
                    # 计算高斯核有效区域(高斯核坐标系)
                    g_x = (max(0, -ul[0]), min(br[0], self.heatmap_hw[1] - 1) - ul[0])
                    g_y = (max(0, -ul[1]), min(br[1], self.heatmap_hw[0] - 1) - ul[1])
                    # image range
                    # 计算heatmap中的有效区域(heatmap坐标系)
                    img_x = (max(0, ul[0]), min(br[0], self.heatmap_hw[1] - 1))
                    img_y = (max(0, ul[1]), min(br[1], self.heatmap_hw[0] - 1))

                    if kps_weights[kp_id] > 0.5:
                        # 将高斯核有效区域复制到heatmap对应区域
                        heatmap[kp_id][img_y[0]:img_y[1] + 1, img_x[0]:img_x[1] + 1] = \
                        self.kernel[g_y[0]:g_y[1] + 1, g_x[0]:g_x[1] + 1]

                        if self.use_kps_weights:
                            kps_weights = np.multiply(kps_weights, self.kps_weights)

                            plot_heatmap(image, heatmap, kps, kps_weights)

                            target["heatmap"] = torch.as_tensor(heatmap, dtype=torch.float32)
                            target["kps_weights"] = torch.as_tensor(kps_weights, dtype=torch.float32)

                            return image, target

我给大家解释一下可能难理解的地方:

heatmap_kps = (kps / 4 + 0.5).astype(np.int)

这句是将关键点的坐标映射到热力图上,因为最终的热力图相较于原图像下采样了4倍,所以要除以4,这里加上0.5是起到一个四舍五入的作用,因为后面要将坐标转为int格式。

ul = [x - self.kernel_radius, y - self.kernel_radius]  # up-left x,y
br = [x + self.kernel_radius, y + self.kernel_radius]  # bottom-right x,y

这两句是找到某个关键点对应热力图的左上角(ul)和右下角(br)的坐标,kernel_radius是高斯核的半径,如下图所示,hw坐标系表示热力图坐标,中间的⚪表示关键点在热力图上的坐标,坐标为(x,y):

# 如果以xy为中心kernel_radius为半径的辐射范围内与heatmap没交集,则忽略该点(该规则并不严格)
if ul[0] > self.heatmap_hw[1] - 1 or \
        ul[1] > self.heatmap_hw[0] - 1 or \
        br[0] < 0 or \
        br[1] < 0:
    # If not, just return the image as is
    kps_weights[kp_id] = 0
    continue

这句是看看以xy为中心kernel_radius为半径的辐射范围内(就是上图中的正方形区域内)与heatmap(就是上图的hw坐标系,当然其h=64,w=48,并不是无线延长的坐标系)有没有交集,若无交集,则将kps_weights[kp_id]置为0。

# Usable gaussian range
# 计算高斯核有效区域(高斯核坐标系)
g_x = (max(0, -ul[0]), min(br[0], self.heatmap_hw[1] - 1) - ul[0])
g_x = (max(0, -ul[1]), min(br[1], self.heatmap_hw[0] - 1) - ul[1])
# image range
# 计算heatmap中的有效区域(heatmap坐标系)
img_x = (max(0, ul[0]), min(br[0], self.heatmap_hw[1] - 1))
img_y = (max(0, ul[1]), min(br[1], self.heatmap_hw[0] - 1))

这几句分别计算高斯核有效区域和heatmap中的有效区域,为下一步将将高斯核有效区域复制到heatmap对应区域做准备:

if kps_weights[kp_id] > 0.5:
    # 将高斯核有效区域复制到heatmap对应区域
    heatmap[kp_id][img_y[0]:img_y[1] + 1, img_x[0]:img_x[1] + 1] = \
        self.kernel[g_y[0]:g_y[1] + 1, g_x[0]:g_x[1] + 1]

这几句到底实现了什么呢,其实就是把高斯核kernel复制到热力图中,至于复制到什么位置,复制多少,就看g_x、g_x、img_x和img_y了。我调试帮助大家理解一下,比如现在g_x=(0,12)、g_y=(0,12)、img_x=(25,37)和img_y=(12,24)。

g_x[0]:g_x[1]+1=0:12+1、g_y[0]:g_y[1]+1=0:12+1表示复制kernel的x方向(0,12+1)范围内的值和y方向(0,12+1)范围内,你看kernel的shape你会发现,其大小为13*13,那么这个(0,12+1)就是复制整个kernel数组**(这里刚好是整个数组,你调试的话会有不同的结果)**:

image-20231127204600376

那么把这个数组复制到哪里呢,其实就是热力图的对应区域,这是就用到了img_x=(25,37)和img_y=(12,24),将其复制到热力图w方向(25,37+1)和h方向(12,24+1)的位置,如下图所示:

image-20231127205240127


这里展示一下图片和产生热力图的结果,如下图所示:【注:由于不是同一次调试的结果,所以这里的图像和之前的有所差异】

image-20231127212959988


最后我还想说一个小点,就是kps_weights这个值,表示的是关键点的权重,如果没有指定这个参数,那么其就默认是关键点的可见性,如果指定了这个参数,其会让原来的可见性乘这个指定的参数,在HRNet中,这个kps_weights默认如下:

image-20231127205709643

小结

HRNet中的在线数据增强方式到这里就为大家介绍完啦,我觉得这部分还是非常重要的,大家可以去认真的学习一下喔,不明白的可以先调试调试,实在搞不懂欢迎评论区和我交流探讨。🥂🥂🥂

网络结构搭建

HRNet的网络结构我在原理详解篇已经为大家介绍过了,也简略的为大家展示了一些代码,但是没用具体介绍网络的详细结构。这里呢,我也不打算介绍了,因为我认为网络搭建部分真的是比较简单的,就像搭积木一样,一层一层的,只要你拿起代码对照着网络结构图调试一遍就会非常清晰了。所以这里大家一定要动起小手来喔!!!🌼🌼🌼

网络训练和预测

我们一起来看看训练阶段的代码,主要看训练一个epoch的情况就好啦,即train_one_epoch函数,首先有一个热身训练的代码:

lr_scheduler = None
if epoch == 0 and warmup is True:  # 当训练第一轮(epoch=0)时,启用warmup训练方式,可理解为热身训练
    warmup_factor = 1.0 / 1000
    warmup_iters = min(1000, len(data_loader) - 1)

    lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

关于此部分代码可以从我的这篇博客-->poly学习率策略源码详解中查看详情,对这种学习率调整策略有详细解释,链接如下:

深度学习语义分割篇——DeeplabV3原理详解+源码实战🌱🌱🌱

接着就来说说for循环遍历数据集的过程,使用的for循环如下:

for i, [images, targets] in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

我们可以调试进入log_every函数中,注意到log_every函数中有一个yield obj,yield是python中的关键字,是一个生成器,每次log_every运行到yield obj时都会暂停执行下面的代码,而是将obj返回给调用方,这样做的目的是节省内存。

这么说我觉得大家听的还是云里雾里,我画一个图解释一下代码的运行流程:

其按顺序依次执行①②③④⑤,在执行完③时,obj会传给①,得到image和target。知道了这一点,那么剩下的内容就比较简单啦,这里就不在过多叙述咯,不清楚的大家一定要调试调试喔。


接下来就是预测过程,我们会通过网络得到输出结果,其尺寸为(1,17,64,48),1表示batch为1,后面我们需要对这个输出做一些后处理操作,使其能够将预测关键点映射到原图上。关于预测过程,我也在HRNet原理详解篇进行了介绍,不清楚的一定要去看看喔🍚🍚🍚

总结

呼呼呼~~~终于写完啦,源码解析篇就到这里结束啦,整个HRNet到这里也结束咯,如果有任何不明白的地方欢迎和我一起探讨,共同进步喔。🥂🥂🥂

参考链接

HRNet论文🍁🍁🍁

HRNet网络简介🍁🍁🍁

HRNet源码🍁🍁🍁

如若文章对你有所帮助,那就🛴🛴🛴

一键三连 (1).gif

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2077249.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

linux下部署数据库总结

数据库 数据库主要分为两大类&#xff1a;关系型数据库与 NoSQL 数据库 关系型数据库&#xff0c;是建立在关系模型基础上的数据库&#xff0c;其借助于集合代数等数学概念和方法来处理数据库 中的数据主流的 MySQL、Oracle、MS SQL Server 和 DB2 都属于这类传统数据库。 NoSQ…

JVM理论篇(一)

一、类加载子系统 1.1 类加载子系统作用 类加载子系统负责从文件系统或者网络中加载Class文件&#xff0c;Class文件在文件开头有特定的文件标识。(CAFEBABE)ClassLoader只负责class文件的加载&#xff0c;至于它是否可以运行&#xff0c;则由Execution Engine 执行引擎决定。…

Spire.PDF for .NET【文档操作】演示:创建标记的 PDF 文档

带标签的 PDF&#xff08;也称为 PDF/UA&#xff09;是一种包含底层标签树&#xff08;类似于 HTML&#xff09;的 PDF&#xff0c;用于定义文档的结构。这些标签可以帮助屏幕阅读器浏览整个文档而不会丢失任何信息。本文介绍如何使用Spire.PDF for .NET在 C# 和 VB.NET 中从头…

Python中csv文件的操作3

在《Python中csv文件的操作2》中提到&#xff0c;with as语句可以自动关闭文件&#xff0c;而该语句可以和csv模块中的函数配合使用&#xff0c;达到读取和写入csv文件的目的。 1 csv文件的读取 使用csv模块中的函数读取csv文件的代码如图1所示。 图1 使用csv模块中的函数读取…

AI终于杀死了Leetcode!网友:面试神器已到位

家人们&#xff0c;今早起来 x 上一个帖子引起了奶茶的注意&#xff1a; 什么&#xff1f;奶茶以为自己没睡醒&#xff0c;揉了揉眼睛一看&#xff0c;没看错的话&#xff0c;这不就是AI结束了比赛吗。。。。 原文链接&#xff1a; https://www.reddit.com/r/leetcode/comments…

【ES6新特性】ES6新特性中Promise对象的概念,Async函数的使用以及Module语法

目录 1.Promise 对象 1.1 概念 1.2 使用 2.Async函数 2.1 同步和异步的区别 3.Mdule语法 1.Promise 对象 1.1 概念 Promise 是异步编程的一种解决方案&#xff0c;简单说就是一个容器&#xff0c;里面保存着某个未来才会结束 的事件&#xff08;通常是一个异步操作&#…

初识QT:从创建到认识

QT怎么安装这里就不说了&#xff0c;直接从使用开始 文章目录 1.QT项目的创建及介绍2.Hello QT&#xff01;2.1 图形化形式创建2.2 代码形式创建 3.对象树3.1 内存泄漏与对象树3.2 通过C类理解释放过程 4.乱码问题4.1 如何查看编码方式4.2 如何处理乱码 提示&#xff1a;QT项目…

arm 指令移位操作(11)

逻辑左移&#xff1a; 可以使寄存器也可以是 立即数 LSL &#xff1a; 字母缩写 举例&#xff1a; MOV R0&#xff0c;R1 &#xff0c;LSL #2 向左移位后&#xff0c;右面填0补充 逻辑右移&#xff1a; 可以使寄存器也可以是 立即数 LSR &#xff1a; 字母缩写 举例&…

10天速通Tkinter库——Day7:主菜单及图鉴

本篇博客我将介绍Tkinter实践项目《植物杂交实验室》中的杂交实验室主菜单、基础植物图鉴、杂交植物图鉴、杂交植物更多信息四个页面的制作。 它们作为主窗口的子页面实例&#xff0c;除了继承主窗口的基础设置&#xff08;如图标、标题、尺寸等等&#xff09;、还可以使用主窗…

《黑神话:悟空》游戏中的福建元素

《黑神话&#xff1a;悟空》作为一款深受玩家喜爱的动作角色扮演游戏&#xff0c;不仅在游戏剧情和角色设计上独具匠心&#xff0c;还巧妙地融入了丰富的中国传统文化元素&#xff0c;其中福建元素尤为突出。以下是对游戏中福建元素的详细解析&#xff1a; 一、地域文化与背景…

《机器学习》—— 支持向量机(SVM)实现二分类问题

文章目录 一、什么是支持向量机&#xff08;SVM&#xff09;1、SVM两个基本概念2、SVM的原理 二、示例&#xff1a;支持向量机&#xff08;SVM&#xff09;实现二分类问题1、先选取两个特征&#xff0c;并进行可视化2、选取所有的特征传入模型&#xff0c;并对模型进行评估 一、…

回归预测|基于北方苍鹰优化核极限学习机的数据预测Matlab程序NGO-KELM 多特征输入单输出

回归预测|基于北方苍鹰优化核极限学习机的数据预测Matlab程序NGO-KELM 多特征输入单输出 文章目录 一、基本原理1. 基本原理核极限学习机&#xff08;KELM&#xff09; 2. NGO-KELM回归预测流程1. 数据预处理2. 核极限学习机&#xff08;KELM&#xff09;模型构建3. 北方苍鹰优…

【Tomcat+MySQL+Redis源码安装三件套】

TomcatMySQLRedis源码安装三件套 Tomcat部分概念Tomcat的作用Tomcat的原理Linux运维中的应用场景具体操作示例 基本使用流程实验Tomcat安装tomcat的反向代理及负载均衡源地址hash(不能负载)tomcat负载均衡实现&#xff08;使用cookie&#xff09; memcached 操作命令 理论补充结…

大数据系统测试——大数据系统解析(上)

各位好&#xff0c;我是 道普云 欢迎关注我的主页 希望这篇文章对想提高软件测试水平的你有所帮助。 在本文中我们一起来看一下大数据系统每一个层次需要解决的技术问题和对应的一些技术需求。以此来作为学习大数据系统测试的基础。 数据收集层主要是进行数据源的分布式、…

sqli-labs靶场通关攻略 31-35

主页有sqli-labs靶场通关攻略 1-30 第三一关 less-31 闭合方式为?id1&id1 ") -- 步骤一&#xff1a;查看数据库名 http://127.0.0.1/less-31/?id1&id-1%22)%20union%20select%201,database(),3%20-- 步骤二&#xff1a;查看表名 http://127.0.0.1/less-31/?…

Redis 实现哨兵模式

目录 1 哨兵模式介绍 1.1 什么是哨兵模式 1.2 sentinel中的三个定时任务 2 配置哨兵 2.1 实验环境 2.2 实现哨兵的三条参数&#xff1a; 2.3 修改配置文件 2.3.1 MASTER 2.3.2 SLAVE 2.4 将 sentinel 进行备份 2.5 开启哨兵模式 2.6 故障模拟 3 在整个架构中可能会出现的问题 …

【FastAPI】—— 01 创建FastAPI项目

1.FastAPI框架介绍 FastAPI是⼀个现代、快速&#xff08;⾼性能&#xff09;的Web框架&#xff0c;⽤于构建API。是建⽴在Starlette和Pydantic基础上的。它基于Python3.7的类型提示&#xff08;typehints&#xff09;和异步编程&#xff08;asyncio&#xff09;能⼒&#xff0c…

软件设计原则之开闭原则

开闭原则&#xff08;Open-Closed Principle, OCP&#xff09;是软件设计中的一个重要原则&#xff0c;由伯特兰梅耶&#xff08;Bertrand Meyer&#xff09;在1988年提出。该原则强调软件实体&#xff08;如类、模块、函数等&#xff09;应该对扩展开放&#xff0c;对修改关闭…

【机器学习】 7. 梯度下降法,随机梯度下降法SGD,Mini-batch SGD

梯度下降法,随机梯度下降法SGD,Mini-batch SGD 梯度下降法凸函数(convex)和非凸函数梯度更新方向选择步长的选择 随机梯度下降SGD(Stochastic Gradient Descent)梯度下降法&#xff1a;SGD: Mini-batch SGD 梯度下降法 从一个随机点开始决定下降方向&#xff08;重要&#xff…

关于kafka的分区和消费者之间的关系

消费者和消费者组 当生产者向 Topic 写入消息的速度超过了消费者&#xff08;consumer&#xff09;的处理速度&#xff0c;导致大量的消息在 Kafka 中淤积&#xff0c;此时需要对消费者进行横向伸缩&#xff0c;用多个消费者从同一个主题读取消息&#xff0c;对消息进行分流。 …