Paper reading:Fine-Grained Head Pose Estimation Without Keypoints (CVPR2018)

news2024/9/24 19:14:16

Paper reading:Fine-Grained Head Pose Estimation Without Keypoints (CVPR2018)

一、 背景

为什么要读这篇论文,因为LZ之前要做头部姿态估计,看到一些传统的方法,都是先进行人脸检测,然后再进行关键点定位,当然现在可以一起做,anyway,得到最后的关键点位置,再使用一个通用的3D人脸模型,通过solvePnP来得到最终的头部姿态,但是不管是脑子中考虑还是最后的动手实践,得到的结论就是这种方式的头部姿态方法不robust。可以想一下:每个人的脸型不一样吧,物管肯定也有差异,3D通用模型也有很多方式,关键点定位也有偏差,这些都是不确定的,只能说当精度要求不高,并且关键点定位足够准确,且头部姿态估计的对象和3D的通用人脸模型相对匹配的情况下,这种方式才比较好,那么问题来了,算法的泛化能力呢。。。

于是乎,还是往深度学习的方法上瞅瞅,就看到了题目中的文章,简单测试了下,觉得效果可行,那么就开始阅读论文和代码吧。

二、 数据集准备

主要使用的数据是300W-LP,下载的地址为: http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3ddfa/main.htm

在这里插入图片描述

大概有2.6个G,下载可能需要一段时间,所以有的时候LZ如果确定要尝试一种方法,首先就要开始准备下载数据集,在下载数据集的时候可以在慢慢阅读下论文。

当然这些数据都是合成的,所以有些图片看起来会有点奇怪

在这里插入图片描述

三、 训练代码运行的一些问题

1. python2和python3的兼容性问题

LZ用的是python3,原始论文使用的是python2,所以会存在一些兼容性的问题,这些都比较好修改,例如把xrange替换成range这种。

2. pytorch的版本问题

因为是两三年前的代码了,pytorch可能版本比较旧,也会存在一些代码的修改

  • utils.py中
# 直接注释掉这一行
# from torch.utils.serialization import load_lua
  • 训练代码以train_hopenet.py为例吧
    error:
RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1]) and output[0] has a shape of torch.Size([]).

solution:

 # grad_seq = [torch.ones(1).cuda(gpu) for _ in range(len(loss_seq))]
 grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]

error:

IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

solution:

 # print('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
                #       % (epoch + 1, num_epochs, i + 1, len(pose_dataset) // batch_size, loss_yaw.data[0],
                #          loss_pitch.data[0], loss_roll.data[0]))
print('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
       % (epoch + 1, num_epochs, i + 1, len(pose_dataset) // batch_size, loss_yaw.item(),
          loss_pitch.item(), loss_roll.item()))

运行就没啥问题了
在这里插入图片描述但是這個後面得看一下,爲什麼loss會突然增到這麼大。。。

四、测试结果

因为这个算法的流程是要先进行人脸检测,然后在人脸检测框四周扩充一定的范围后进行头部姿态估计的,按照上述的方法,经过测试,确实效果还可以,但是如果是一整张大图,直接回归出头部姿态,这个结果就是非常不准确的了,下面我们来看下代码,看看是否有值得借鉴的信息。

五、代码部分

5.1 训练代码

我们就以train_hopenet.py为例,其他只是换了backbone,原理都是一样的,当然LZ还是小小改动了一下源码

  • 一些常规设置
def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
                        default=5, type=int)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
                        default=16, type=int)
    parser.add_argument('--lr', dest='lr', help='Base learning rate.',
                        default=0.001, type=float)
    parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='Pose_300W_LP', type=str)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.',
                        default='', type=str)
    parser.add_argument('--filename_list', dest='filename_list',
                        help='Path to text file containing relative paths for every example.',
                        default='', type=str)
    parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.',
                        default='', type=str)
    parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.',
                        default=0.001, type=float)
    parser.add_argument('--snapshot', dest='snapshot', help='Path of model snapshot.',
                        default='', type=str)

    args = parser.parse_args()
    return args

  • 主函数

5.2 Hopenet部分

class Hopenet(nn.Module):
    # Hopenet with 3 output layers for yaw, pitch and roll
    # Predicts Euler angles by binning and regression with the expected value
    def __init__(self, block, layers, num_bins):
        self.inplanes = 64
        super(Hopenet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)
        self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
        self.fc_roll = nn.Linear(512 * block.expansion, num_bins)

        # Vestigial layer from previous experiments
        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        pre_yaw = self.fc_yaw(x)
        pre_pitch = self.fc_pitch(x)
        pre_roll = self.fc_roll(x)

        return pre_yaw, pre_pitch, pre_roll

5.3 datasets部分

这里LZ就选择其中的一个数据集Pose_300W_LP来进行解释

class Pose_300W_LP(Dataset):
    # Head pose from 300W-LP dataset
    def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat', image_mode='RGB'):
        self.data_dir = data_dir
        self.transform = transform
        self.img_ext = img_ext
        self.annot_ext = annot_ext

        filename_list = get_list_from_filenames(filename_path)

        self.X_train = filename_list
        self.y_train = filename_list
        self.image_mode = image_mode
        self.length = len(filename_list)

    def __getitem__(self, index):
    	#这个比较重要的是数据处理部分
        img = Image.open(os.path.join(self.data_dir, self.X_train[index] + self.img_ext)) 
        img = img.convert(self.image_mode)
        mat_path = os.path.join(self.data_dir, self.y_train[index] + self.annot_ext)

        # Crop the face loosely
        pt2d = utils.get_pt2d_from_mat(mat_path) #这个是从mat中得到对应的68个关键点的坐标
        x_min = min(pt2d[0, :])
        y_min = min(pt2d[1, :])
        x_max = max(pt2d[0, :])
        y_max = max(pt2d[1, :])

        # k = 0.2 to 0.40
        k = np.random.random_sample() * 0.2 + 0.2
        x_min -= 0.6 * k * abs(x_max - x_min)
        y_min -= 2 * k * abs(y_max - y_min)
        x_max += 0.6 * k * abs(x_max - x_min)
        y_max += 0.6 * k * abs(y_max - y_min)
        img = img.crop((int(x_min), int(y_min), int(x_max), int(y_max)))

        # We get the pose in radians
        pose = utils.get_ypr_from_mat(mat_path)
        # And convert to degrees.
        pitch = pose[0] * 180 / np.pi
        yaw = pose[1] * 180 / np.pi
        roll = pose[2] * 180 / np.pi

        # Flip?
        rnd = np.random.random_sample()
        if rnd < 0.5:
            yaw = -yaw
            roll = -roll
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        # Blur?
        rnd = np.random.random_sample()
        if rnd < 0.05:
            img = img.filter(ImageFilter.BLUR)

        # Bin values
        bins = np.array(range(-99, 102, 3))
        binned_pose = np.digitize([yaw, pitch, roll], bins) - 1

        # Get target tensors
        labels = binned_pose
        cont_labels = torch.FloatTensor([yaw, pitch, roll])

        if self.transform is not None:
            img = self.transform(img)

        return img, labels, cont_labels, self.X_train[index]

    def __len__(self):
        # 122,450
        return self.length


数据集中的mat主要包含这几个部分:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/66953.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java基于JSP旅游网站系统的设计于实现

我国的旅游事业目前正处于一个科学技术日新月异飞速向前发展的环境中。信息技术和通信技术以令人目不暇接的速度发展&#xff0c;尤其是互联网络的广泛流行&#xff0c;使得各种服务信息已近乎透明&#xff0c;且个性突出的游客们已不再满足于死板的标准化的旅游项目&#xff0…

JavaScript-T2

JavaScript-T2 前言 本次主要讲解的知识点是&#xff1a; JavaScript自定义函数 JavaScript系统函数 JavaScript 事件 JavaScript 的常用事件 JavaScript自定义函数 函数就是为了完成程序中的某些特定功能而进行专门定义的一段程序代码 函数包括自定义函数和系统函数 使用函数…

Akka 学习(二)第一个入门程序

目录一 sbt 介绍1.1 Sbt1.2 下载安装1.3 sbt的特点1.4 Idea 配置Sbt开发工具二 构建定义2.1 指定版本2.2 build.sbt 设置三 代码实现3.1 Java版本3.2 Scala版本3.3 对比一 sbt 介绍 1.1 Sbt sbt 是为 Scala 和 Java 项目构建的。它是93.6%的 Scala 开发人员的首选构建工具&am…

2000-2021年各省GDP包括名义GDP、实际GDP、GDP平减指数(以2000年为基期)

全国31省市GDP平减指数(2000-2021年)及计算步骤 1、时间&#xff1a;2000-2021年 2、范围&#xff1a;31省 3、数据包括&#xff1a;2000-2021年各省市GDP平减指数&#xff0c;以2000年为基期&#xff0c;包括数据来源、计算方法、公式等。 4、计算步骤&#xff1a; 第一步…

物联卡采购注意要点有哪些

在这个万物互联的时代&#xff0c;针对于企业设备联网的物联卡就显得格外重要了&#xff0c;而共享单车&#xff0c;移动支付&#xff0c;智慧城市&#xff0c;自动售卖机等企业采购物联卡会面临着各种问题&#xff0c;低价陷阱&#xff0c;流量虚假&#xff0c;管理混乱&#…

sealos issue #2157 debug 思路流程记录

sealos issues#2157 debug思路流程前言分析issue剖析源码解决方案总结前言 这个项目蛮有意思的&#xff0c;sealos 是以 kubernetes 为内核的云操作系统发行版。 boss上看到 -> 沟通 -> 解决某个issue直接offer -> 舒服 本文记录解决 issue 的思路 分析issue BUG…

Linux系统常用的工具

1.1 Vscode编辑器 从官网下载 ubuntu 版本&#xff0c;官网地址&#xff1a;https://code.visualstudio.com/。下载xxx.deb的包。 或者使用指令下载&#xff1a;wget https://az764295.vo.msecnd.net/stable/6261075646f055b99068d3688932416f2346dd3b/code_1.73.1-1667967334…

基于Intel Lake-UP3平台为半导体与集成电路测试设备提供优异计算性能

为什么半导体和IC测试设备需要升级&#xff1f; 随着众多新的高性能应用的需求不断增加&#xff0c;信迈旨在为半导体集成电路测试设备领域的客户提供更好的方案。半导体和集成电路&#xff08;IC&#xff09;测试设备设计用于在一台测试机上同时对不同线路的数百个集成电路…

How Can We Know What Language Models Know?

Abstract 最近的工作通过让语言模型&#xff08;LM&#xff09;填补诸如“奥巴马是一个职业”之类的提示的空白&#xff0c;提出了一个有趣的结果&#xff0c;以检查语言模型&#xff08;LM&#xff09;中包含的知识。这些提示通常是手动创建的&#xff0c;而且很可能不是最佳…

Linux进程通信之进程信号

一、信号的概念&#xff1a; 信号机制是Linux最基本的通讯机制&#xff0c;它可以用来向一个或者多个进程发送异步事件信息&#xff0c;传送少量信息。信号是一个软件中断&#xff0c;并且是一个“软中断”&#xff08;只是告诉有这样一个信号&#xff0c;但这个信号具体如何进…

Redis6入门到实战------思维导图+章节目录

Redis学习大纲 思维导图 思维导图 Redis6入门到实战------1、NoSQL数据库简介 地址&#xff1a; Redis6入门到实战------2、Redis6概述和安装 地址&#xff1a; Redis6入门到实战------3、常用五大数据类型 地址&#xff1a; Redis6入门到实战------4、Redis6配置文件详解…

Stack Overflow 临时禁用 ChatGPT 生成内容,网友:人类和AI快打起来!

如果有一天我们查询到的「知识」真假难辨&#xff0c;那这就太可怕了。 要问最近 AI 圈哪个模型最火爆&#xff0c;你不得不把 OpenAI 推出的 ChatGPT 排在前面。自从发布以来&#xff0c;这个对话模型可谓是出尽风头&#xff0c;很多人更是对其产生了一百个新玩法&#xff0c;…

Linux系统移植四:Petalinux使用本地sstate-cache加速构建根文件系统

根文件系统简介 根文件系统 rootfs 是Linux内核启动以后挂载(mount)的第一个文件系统&#xff0c;然后从根文件系统中读取初始化脚本&#xff0c;比如rcS&#xff0c;inittab等 根文件系统和Linux内核是分开的&#xff0c;单独的Linux内核是没法正常工作的&#xff0c;必须要…

TPM零知识学习六 —— tpm模拟器安装

本文参考以下链接&#xff1a; TPM模拟器和TPM2-TSS安装_jianming21的博客-CSDN博客_tpm2-tss 可信平台模块TPM&#xff08;Trusted Platform Module&#xff09;介绍及tpm-tools安装使用_jinhuazhe2013的博客-CSDN博客_tpm模块 1. 源码下载 运行以下命令下载源码&#xff1…

设计模式--观察者模式

文章目录前言一、未使用设计模式二、观察者模式1.定义2.组成三、应用场景四、优缺点优缺前言 甲人A&#xff08;产品经理&#xff09;&#xff1a;好啊&#xff0c;你小子&#xff0c;又被我逮到了&#xff0c;很闲是吧&#x1f607;&#xff0c;需求完成了吗&#xff1f; two…

MOSFET 和 IGBT 栅极驱动器电路的基本原理学习笔记(三)同步整流器驱动

同步整流器驱动 1.栅极电荷 2.dv/dt注意事项 MOSFET 同步整流器是接地基准开关的一个特例。这些器件与传统应用所使用的 N 沟道 MOSFET 相同&#xff0c;只是它们被应用到了电源的低电压输出而非整流器二极管中。 它们通常可在非常有限的漏源极电压摆幅下工作&#xff0c;因此…

redis活跃非活跃连接数统计及client list说明

概念说明 活跃连接是指当下正在执行命令的连接&#xff0c;非活跃当然是相对的。 在redis中判断当前连接是否活跃是通过 内置的client list 命令输出中的idle来判断 client list字段说明 (kfzops) [roottest-xxx-01-vm ]# redis-cli -h r-xxxxxxxxxxxx.redis.rds.aliyuncs.…

学生身份标签的识别与风控应用

当前的互联网借贷平台&#xff0c;国家已明确规定不允许向高校学生发放贷款&#xff0c;因此对于小贷、消金等金融机构&#xff0c;在信贷产品业务的风控体系中&#xff0c;有效判断申请用户是否为高校学生是一个非常重要的问题。针对高校学生身份的识别&#xff0c;虽然有多种…

机器学习、深度学习、自然语言处理学习 NLP-RoadMap-996station GitHub鉴赏官

推荐理由&#xff1a; 机器学习、深度学习、自然语言处理学习路线图 及 AI方向学习资源、工具 NLP-RoadMap 持续更新中。以下内容有错误或者不足&#xff0c;欢迎提Issue或者联系我讨论 整理不易&#xff0c;希望点个小星星 ​支持下呀&#xff01; 前言 数理基础 编程基础 机…

RE2:Simple and Effective Text Matching with Richer Alignment Features

原文链接&#xff1a;https://aclanthology.org/P19-1465.pdf 介绍 问题 作者认为之前文本匹配模型中序列对齐部分&#xff0c;过于复杂。只有单个inter-sequence alignment层的模型&#xff0c;常会引入外部信息&#xff08;例如语法特征&#xff09;作为额外输入&#xff0c;…