【实验练习】基于自注意力机制Vision Transformer模型实现人脸朝向识别 (Python实现) 内容原创

news2024/11/26 4:48:49

题目

人脸识别是一个复杂的模式识别问题,人脸识别是人脸应用研究中非常重要的一步。由于人脸形状不规则、光线和背景条件多样,导致人脸检测精度受限。实际应用中,大量图像和视频源中人脸的位置、朝向、朝向角度都不是固定的,极大化的增加了人脸识别的难度。目前研究中,大多数研究是希望人脸识别过程中去除人脸水平旋转对识别过程的不良影响。但实际应用时往往比较复杂。

现给出采集到的一组人脸朝向不同角度时的图片,详见Image文件夹。图像来自不同的10个人,每人5个图像,人脸朝向分别为:向左、左前方、前方、右前方和向右。请选择本课程学习的任意一种算法进行训练,能够对任意给出的人脸图像进行朝向预测和识别。

# -*- coding: utf-8 -*- #
"""
@Project    :Exp
@File       :FaceVitRun.py
@Author     :ZAY
@Time       :2023/5/30 15:44
@Annotation : " "
"""

import os
import glob
import torch
import torch.nn as nn
import datetime
import numpy as np
from timm.models.layers import to_2tuple
from torchvision import transforms
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from Exp.Plot import plotloss,plotShow,plotROC
from Net.VitNet import ViT
from DataLoad import  Mydataset, Mydatasetpro
from EarlyStop import EarlyStopping
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from sklearn.metrics import accuracy_score,auc,roc_curve,precision_recall_curve,f1_score, precision_score, recall_score
from matplotlib import pyplot as plt


LR = 0.0001 # 0.0001
EPOCH = 100
BATCH_SIZE = 10
Test_Batch_Size = 6

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def judgeType(y_predicted):
    if y_predicted == 0:
        pre = "left"
    elif y_predicted == 1:
        pre = "lf"
    elif y_predicted == 2:
        pre = "front"
    elif y_predicted == 3:
        pre = "rf"
    else:
        pre = "right"
    return pre

class ConfusionMatrix(object):

    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))  # 初始化混淆矩阵,元素都为0
        self.num_classes = num_classes  # 类别数量,本例数据集类别为5
        self.labels = labels  # 类别标签

    def update(self, preds, labels):
        for p, t in zip(preds, labels):  # pred为预测结果,labels为真实标签
            self.matrix[p, t] += 1  # 根据预测结果和真实标签的值统计数量,在混淆矩阵相应位置+1

    def summary(self):  # 计算指标函数
        # calculate accuracy
        sum_TP = 0
        n = np.sum(self.matrix)
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]  # 混淆矩阵对角线的元素之和,也就是分类正确的数量
        acc = sum_TP / n  # 总体准确率
        print("the model accuracy is ", acc)

        # kappa
        sum_po = 0
        sum_pe = 0
        for i in range(len(self.matrix[0])):
            sum_po += self.matrix[i][i]
            row = np.sum(self.matrix[i, :])
            col = np.sum(self.matrix[:, i])
            sum_pe += row * col
        po = sum_po / n
        pe = sum_pe / (n * n)
        # print(po, pe)
        kappa = round((po - pe) / (1 - pe), 3)
        # print("the model kappa is ", kappa)

        return str(acc)

    def plot(self):  # 绘制混淆矩阵
        matrix = self.matrix
        print("matrix: ",matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        # 设置x轴坐标label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 设置y轴坐标label
        plt.yticks(range(self.num_classes), self.labels)
        # 显示colorbar
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix (acc=' + self.summary() + ')')

        # 在图中标注数量/概率信息
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 注意这里的matrix[y, x]不是matrix[x, y]
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.savefig(".//Result//matrix.png")
        plt.show()


if __name__ == "__main__":

    global id_to_species

    store_path = './/model//transformer.pt'
    txt_path = './/Result//Vit.txt'

    # 使用glob方法来获取数据图片的所有路径
    all_imgs_path = glob.glob(r"./Data/*/*.bmp")  # 数据文件夹路径

    # for var in all_imgs_path:
    #     print(var)

    # 利用自定义类Mydataset创建对象face_dataset
    face_dataset = Mydataset(all_imgs_path)
    print("文件夹中图片总个数:",len(face_dataset))  # 返回文件夹中图片总个数
    # face_dataloader = torch.utils.data.DataLoader(face_dataset, batch_size = 8)  # 每次迭代时返回8个数据
    # 为每张图片制作对应标签
    species = ['left', 'lf', 'front', 'rf', 'right']
    species_to_id = dict((c, i) for i, c in enumerate(species))
    id_to_species = dict((v, k) for k, v in species_to_id.items())
    print("id_to_species",id_to_species)

    # 对所有图片路径进行迭代
    all_labels = []
    for img in all_imgs_path:
        # 区分出每个img,应该属于什么类别
        for i, c in enumerate(species):
            if c in img:
                all_labels.append(i)
    print("all_labels",all_labels)

    # 对数据进行转换处理
    transform = transforms.Compose([
        transforms.Resize((420, 420)),  # 做的第一步转换
        transforms.ToTensor()  # 第二步转换,作用:第一转换成Tensor,第二将图片取值范围转换成0-1之间,第三会将channel置前
    ])

    face_dataset = Mydatasetpro(all_imgs_path, all_labels, transform)
    face_dataloader = DataLoader(
        face_dataset,
        batch_size = BATCH_SIZE,
        shuffle = True
    )

    imgs_batch, labels_batch = next(iter(face_dataloader))
    print("imgs_batch.shape",imgs_batch.shape) # torch.Size([4, 3, 420, 420])


    # plt.figure(figsize = (12, 8))
    # for i, (img, label) in enumerate(zip(imgs_batch[:6], labels_batch[:6])):
    #     img = img.permute(1, 2, 0).numpy() # (H,W,C)
    #     plt.subplot(2, 3, i + 1) # subplot(numRows, numCols, plotNum) numRows 行 numCols 列
    #     plt.title(id_to_species.get(label.item()))
    #     plt.imshow(img)
    # plt.show()  # 展示图片

    # 划分数据集和测试集
    index = np.random.permutation(len(all_imgs_path))
    print("index",index)
    # 打乱顺序
    all_imgs_path = np.array(all_imgs_path)[index]
    all_labels = np.array(all_labels)[index]

    for i in range(len(all_imgs_path)):
        print("第{}张图片存储路径为:{}, 朝向为:{}, 标签为:{}".format(i + 1, all_imgs_path[i],judgeType(all_labels[i]),all_labels[i]))
    # 80%做训练集
    # c = int(len(all_imgs_path) * 0.8)
    # print("训练集和验证集数量:", c)
    c = 40
    v = 4
    t = 6

    c_imgs = all_imgs_path[:c]
    c_labels = all_labels[:c]
    v_imgs = all_imgs_path[c:c+v]
    v_labels = all_labels[c:c+v]
    t_imgs = all_imgs_path[c+v:]
    t_labels = all_labels[c+v:]

    test_face_dataset = Mydatasetpro(t_imgs, t_labels, transform)
    global test_face_dataloader
    test_face_dataloader = DataLoader(
        test_face_dataset,
        batch_size = t,
        shuffle = False # 在每个epoch开始的时候是否进行数据的重新排序,默认false
    )

    # train_imgs = all_imgs_path[:c]
    # train_labels = all_labels[:c]
    # test_imgs = all_imgs_path[c:]
    # test_labels = all_labels[c:]

    # print(test_imgs)
    # print(test_imgs.shape)
    # print(test_labels)
    # print(test_labels.shape)

    cal_ds = Mydatasetpro(c_imgs, c_labels, transform)  # TrainSet TensorData
    val_ds = Mydatasetpro(v_imgs, v_labels, transform)  # TestSet TensorData
    test_ds = Mydatasetpro(t_imgs, t_labels, transform)

    # print(next(iter(face_dataloader)))
    # 改进 psize 30-60 减少必要的epoch heads 12-6 减少参数和必要的epoch
    modeltrian(image_size = 420, ncls=5, psize=60, depth=3, heads=12, dim = 2048, mlp_dim=1024, path=store_path, data_train = cal_ds, data_test = val_ds)  # depth=6, heads=10, 12, 14

    # modeltest(image_size = 420, ncls = 5, psize = 60, depth=3, heads=6, dim = 2048, mlp_dim=1024, path=store_path, data_test = test_ds, txt_path = txt_path)

    acc, precis, reca, F1, roc_auc = model4AUCtest(image_size = 420, ncls=5, psize=60, depth=3, heads=12, dim = 2048, mlp_dim=1024, path=store_path, data_test = test_ds, txt_path = txt_path)  # depth=6, heads=10, 12, 14
    # print("acc:{}, precis:{}, recall:{}, F1:{}, auc:{}".format(acc, precis, reca, F1, roc_auc))

   

测试结果

第1张图片朝向为:front, 第1张图片预测为:front
第2张图片朝向为:front, 第2张图片预测为:front
第3张图片朝向为:rf, 第3张图片预测为:rf
第4张图片朝向为:rf, 第4张图片预测为:rf
第5张图片朝向为:lf, 第5张图片预测为:lf
第6张图片朝向为:left, 第6张图片预测为:left
acc:1.0, precis:1.0, recall:1.0, F1:1.0, auc:1.0
DATE:_2023-06-04_13-23-15, TEST:Avg_acc= 1.0000, train_time = 0:00:54.309789, test_time = 0:00:00.274266
当前模型参数量: 38.990857 M

 

完整代码及数据集请私信 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/616170.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

内网渗透—隧道技术

内网渗透—隧道技术 1. 隧道技术介绍1.1. 内网—隧道技术1.2. 常见的隧道协议1.3. 前置条件1.4. 判断内网的连通性 2. 网络层隧道技术2.1. ICMP隧道技术2.1.1. 常见工具2.1.2. Pingtunnel基础演示2.1.2.1. 下载服务端2.1.2.2. 下载客户端2.1.2.3. 设置CS连接2.1.2.4. 连接测试 …

功能上新|对比分析、Batches数量、函数释义Tips

本篇是继功能上新|内存篇、GPU篇之后,为大家展示更多关于提升浏览UWA GOT Online Overview报告体验的优化项,包括Overview报告的对比分析、Batches数量、函数释义Tips等。这些功能可以让你更快上手对报告的理解,亦或者更好地融入在…

.Net Core 6 WebApi 项目搭建(一、简单搭建)

前言 对于后端开发者最耻辱的是什么,是只会增删改查,只会CV,只会业务代码。没错,我就是被钉在耻辱柱上的一员,3年开发经验,不会搭建框架,只会写业务代码,丢人丢人啊,所以…

【XR】One More Thing:Vision Pro ,7年磨一剑,2023WWDC苹果发布Vision MR

One More Thing:Vision Pro ,7年磨一剑,2023WWDC苹果发布Vision MR 1. 苹果MR Vision Pro:1. 专利布局:苹果表示在开发过程中申请了5000多项专利。2. 专属感知计算芯片3. 显示屏系统方面4. 续航方面5. Vision MR 的新框…

MATLAB安装配置MinGW-w64 C++编译器

文章目录 前言一、Mingw安装1、安装教程2、验证 二、MATLAB安装配置MinGW总结 #pic_center 前言 只是为方便学习,不做其他用途 一、Mingw安装 在网上找到的安装一直报错:The file has been downloaded incorrectly 1、安装教程 建议参考博客Mingw快捷安…

C++内存序、屏障和原子操作

文章目录 一、原子类型二、原子操作函数三、内存序1&#xff09;happens-before和synchronizes-with语义2&#xff09;内存序模式 四、标准库函数五、栅栏&#xff08;Barrier&#xff09; 一、原子类型 标准原子类型的备选名和与其相关的 std::atomic<> 特化类&#xf…

探索低代码的新形态(D2C、ChatGPT)

前言 低代码平台的出现&#xff0c;是互联网快速发展的背景下&#xff0c;满足产品快速迭代的实际需求。现在国内外都已经拥有非常多优秀的开源项目&#xff08;如&#xff1a;lowcode-engine&#xff09;和成熟的商业产品&#xff08;如&#xff1a;Mendix 、PowerPlatform&a…

Orillusion次时代 WebGPU 引擎

Orillusion 次时代 WebGPU 引擎 官网: https://www.orillusion.com/ 教程: https://www.orillusion.com/guide/ Orillusion 引擎是一款完全支持 WebGPU 标准的轻量级渲染引擎。基于最新的 Web 图形API标准&#xff0c;我们做了大量的探索和尝试&#xff0c;实现了很多曾经在 We…

python接口自动化(三)--如何设计接口测试用例(详解)

在开始接口测试之前&#xff0c;我们来想一下&#xff0c;如何进行接口测试的准备工作。或者说&#xff0c;接口测试的流程是什么&#xff1f;有些人就很好奇&#xff0c;接口测试要流程干嘛&#xff1f;不就是拿着接口文档直接利用接口 测试工具测试嘛。其实&#xff0c;如果…

【正点原子STM32连载】 第二十八章 低功耗实验摘自【正点原子】STM32F103 战舰开发指南V1.2

1&#xff09;实验平台&#xff1a;正点原子stm32f103战舰开发板V4 2&#xff09;平台购买地址&#xff1a;https://detail.tmall.com/item.htm?id609294757420 3&#xff09;全套实验源码手册视频下载地址&#xff1a; http://www.openedv.com/thread-340252-1-1.html 第二十…

计算机中数据的表示:定点数、浮点数

文章目录 1 概述2 定点数2.1 表示方法2.2 取值范围2.3 运算方法 3 浮点数3.1 表示方法3.2 运算方法 4 扩展4.1 等比数列前 n 项和公式 1 概述 #mermaid-svg-EXDrkn8G91FsDdps {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#merm…

给 a 标签设置 display:inline-block 之后 a 整体下沉

今天给一个a设置宽高&#xff0c;前提是添加了display:inline-block&#xff1b;然后发下a没有与父元素div顶部对齐&#xff0c;反而下沉了。试了好多办法都没成功&#xff0c;然后在网上找的教程。 原因 1、问题就是出在了display:inline-block;语句上&#xff0c;行内块元…

使用Python进行接口性能测试:从入门到高级

前言&#xff1a; 在今天的网络世界中&#xff0c;接口性能测试越来越重要。良好的接口性能可以确保我们的应用程序可以在各种网络条件下&#xff0c;保持流畅、稳定和高效。Python&#xff0c;作为一种广泛使用的编程语言&#xff0c;为进行接口性能测试提供了强大而灵活的工…

Redis:数据类型

一、Redis字符串(String) 1、String类型 String字符串&#xff1a;string类型是redis最基本、最简单的数据类型&#xff0c;一个key对应一个value。 String类型的二进制是安全的&#xff0c;可以包含任何数据&#xff0c;但是每一个value最大时512M 2、String命令 设置和获…

《人月神话》译文修订明细(6)-读者可以对照修改

《人月神话》译文修订明细&#xff08;1&#xff09;-读者可以对照修改 《人月神话》译文修订明细&#xff08;2&#xff09;-读者可以对照修改 《人月神话》译文修订明细&#xff08;3&#xff09;-读者可以对照修改 《人月神话》译文修订明细&#xff08;4&#xff09;-读…

前端面试题整理14

目录 1.什么是同步&#xff1f;什么是异步&#xff1f; 2.localStorage、sessionStorage和cookie的区别&#xff1f; 3.Vue中key的作用是什么&#xff1f; 4.支付流程是什么&#xff1f; 5.Vuex的模块化是如何做的&#xff1f; 6.Vite和webpack的不同&#xff1f;Vite的优…

BS LIS系统仪器数据采集方法

BS LIS系统仪器数据采集方法 BS LIS系统对检验仪器的数据采集主要通过串行口通讯、USB端口通讯、TCP/IP通讯、定时监控数据库和手工录入等几种方法。串行口通讯最为普遍&#xff0c;采用RS-232C标准&#xff0c;一般的仪器都支持此标准。定时监控数据库对仪器管理机上已有的检…

【Vue】Element Plus和Element UI中插槽使用

文章目录 前言一、两者的区别二、组件库三、具体讲解总结 前言 今天和大家讲一下Element Plus和Element UI这两个组件库中表格的插槽使用方法&#xff0c;一般情况下vue2使用Element UI这个组件库&#xff0c;表格组件的插槽的话一般都是使用v-slot&#xff0c;而vue3使用Elem…

如何进行有效的移动应用测试?10个步骤带你一战成神

移动应用的市场日益壮大&#xff0c;而随着这个市场的发展&#xff0c;如何有效地测试移动应用也成为了一个重要的问题。本文将为你提供一些关于如何进行有效的移动应用测试的建议&#xff0c;并提供一些实际测试例子。 1. 理解你的用户和使用场景 在进行移动应用测试之前&…

rror updating database. Cause: java.sql.SQLSyntaxErrorException解决方案

错误描述&#xff1a; ### Error updating database. Cause: java.sql.SQLSyntaxErrorException: You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near CONDITION 1 这里是因为字段名…