【计算机视觉】如何利用 CLIP 做简单的人脸任务?(含源代码)

news2024/11/26 13:32:50

文章目录

  • 一、数据集介绍
  • 二、源代码 + 结果
  • 三、代码逐行解读

一、数据集介绍

CELEBA 数据集(CelebFaces Attributes Dataset)是一个大规模的人脸图像数据集,旨在用于训练和评估人脸相关的计算机视觉模型。该数据集由众多名人的脸部图像组成,提供了丰富的人脸属性标注信息。

以下是 CELEBA 数据集的一些详细信息:

  1. 规模:CELEBA 数据集包含超过 20 万张名人的脸部图像样本。
  2. 图像内容:数据集中的图像涵盖了各种不同种族、年龄、性别、发型、妆容等的人脸图像,以提供更广泛的人脸表征。
  3. 标注信息:除了图像本身,CELEBA 数据集还提供了一系列的属性标注信息。这些属性包括性别、年龄、眼镜、微笑等。每个图像都有对应的二进制属性标签,用于指示该图像是否具有某个属性。
  4. 数据集组织:CELEBA 数据集的图像以 JPEG 格式存储,并使用标注文件进行关联。标注文件( list_attr_celeba.txt )包含每个图像的文件名及其相关属性标签。
  5. 应用领域:CELEBA 数据集被广泛用于人脸属性识别、人脸检测、人脸生成、人脸识别等计算机视觉任务的研究和开发。

CELEBA 数据集的丰富性和规模使其成为人脸相关算法的重要基准数据集之一。研究人员和开发者可以利用该数据集来训练和评估人脸相关的深度学习模型,推动人脸识别、人脸属性分析等领域的进展。

需要注意的是,CELEBA 数据集的具体细节和使用方式可能会有更新和改变。建议在使用数据集时查阅最新的文档和数据集发布者的说明。

CELEBA 数据集每一部分的解释和名称如下:

CELEBA 数据集由多个部分组成,每个部分包含不同的信息和用途。以下是 CELEBA 数据集的一些主要部分及其解释和名称:

  1. 图像文件夹(img_align_celeba):该部分包含了 CELEBA 数据集的人脸图像文件,以 JPEG 格式存储。图像文件夹通常包含大量的人脸图像,用于进行人脸相关任务的训练、测试和评估。
  2. 标注文件(list_attr_celeba.txt):该部分是 CELEBA 数据集的属性标注文件,它提供了每个图像的属性信息。属性标注文件是一个文本文件,包含了图像文件名及其对应的属性标签。这些属性标签描述了图像中的人脸属性,例如性别、年龄、微笑、眼镜等。
  3. 划分文件(list_eval_partition.txt):这个部分是 CELEBA 数据集的划分文件,用于将数据集划分为训练集、验证集和测试集。划分文件是一个文本文件,包含了每个图像的文件名及其所属的划分集合。
  4. 人脸边界框文件(list_bbox_celeba.txt):这个部分包含了 CELEBA 数据集中每个图像的人脸边界框信息。人脸边界框文件是一个文本文件,包含了每个图像的文件名以及对应的人脸边界框的坐标信息。
  5. 人脸关键点文件(list_landmarks_celeba.txt):这个部分包含了 CELEBA 数据集中每个图像的人脸关键点信息。人脸关键点文件是一个文本文件,包含了每个图像的文件名以及对应的人脸关键点的坐标信息。

这些部分是 CELEBA 数据集中常用的部分,用于获取图像、属性标注、划分信息以及人脸边界框和关键点信息。使用这些部分的数据,可以进行各种人脸相关任务的训练、评估和分析。

二、源代码 + 结果

import clip
import torch
import torchvision
import time

device = "cuda" if torch.cuda.is_available() else "cpu"

def model_load(model_name):
    # 加载模型
    model, preprocess = clip.load(model_name, device) #ViT-B/32 RN50x16
    return model, preprocess

def data_load(data_path):
    # 加载数据集和文字描述
    celeba = torchvision.datasets.CelebA(root = './39.AIGC/CELEBA', split = 'test', download = True)
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in celeba.attr_names]).to(device)
    return celeba, text_inputs


def test_model(start, end, celeba, text_inputs, model, preprocess):
    # 测试模型
    length = end - start + 1
    face_accuracy = 0
    face_score = 0

    for i, data in enumerate(celeba):
        face_result = 0
        if i < start:
            continue
        image, target = data
        image_input = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image_input)
            text_features = model.encode_text(text_inputs)

        image_features /= image_features.norm(dim = -1, keepdim = True)
        text_features /= text_features.norm(dim = -1, keepdim = True)

        text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
        top_score, top_label = text_probs.topk(6, dim = -1)
        for k, score in zip(top_label[0], top_score[0]):
            if k.item() < 40 and target[k.item()] == 1:
                face_result = 1
                face_score += score.item()
                print('Predict right! The predicted is {}'.format(celeba.attr_names[k.item()]))
            else:
                print('Predict flase! The predicted is {}'.format(celeba.attr_names[k.item()]))
        face_accuracy += face_result

        if i == end:
            break
    face_score = face_score / length
    face_accuracy = face_accuracy / length

    return face_score, face_accuracy


if __name__ == '__main__':
    start = 0
    end = 1000
    model_name = 'ViT-B/32'
    data_path = 'CELEBA'

    time_start = time.time()
    model, preprocess = model_load(model_name)
    celeba, text_inputs = data_load(data_path)
    face_score, face_accuracy = test_model(start, end, celeba, text_inputs, model, preprocess)
    time_end = time.time()

    print('The prediction:')
    print('face_accuracy: {:.2f} face_score: {}%'.format(face_accuracy, face_score * 100))
    print('runing time: %.4f' % (time_end - time_start))

在这里插入图片描述

三、代码逐行解读

import clip
import torch
import torchvision
import time

这段代码导入了 cliptorchtorchvisiontime 库。这些库提供了用于计算机视觉和深度学习任务的功能和工具。

  1. clip 是一个用于视觉和文本数据的深度学习模型库,可以将图像和文本进行编码和匹配。
  2. torchPyTorch 库,提供了张量操作、神经网络模型、优化器等工具。
  3. torchvisionPyTorch 的一个扩展库,提供了常用的计算机视觉数据集、模型架构和图像处理工具。
  4. timePython 标准库,提供了计时和时间相关的函数。
device = "cuda" if torch.cuda.is_available() else "cpu"

这行代码用于选择设备(device),可以是 CUDA 加速的 GPU 设备或者 CPU 设备。它使用了条件表达式(if-else)来检查系统是否有可用的 CUDA 设备。如果有可用的 CUDA 设备,将设备设置为 “cuda” ;否则,将设备设置为 “cpu”

def model_load(model_name):
    # 加载模型
    model, preprocess = clip.load(model_name, device) #ViT-B/32 RN50x16
    return model, preprocess

这个函数用于加载 CLIP 模型和预处理函数。

具体解读如下:

  • model_load 是一个函数,接受一个 model_name 参数作为输入。
  • 在函数内部,调用了 clip.load(model_name, device) 来加载 CLIP 模型和预处理函数。 model_name 指定了要加载的 CLIP 模型的名称,device 指定了要在哪个设备上加载模型(之前定义的 device 变量)。
  • clip.load() 函数返回一个模型对象和一个预处理函数对象。
  • 最后,函数将加载的模型对象和预处理函数对象作为结果返回。
def data_load(data_path):
    # 加载数据集和文字描述
    celeba = torchvision.datasets.CelebA(root = './39.AIGC/CELEBA', split = 'test', download = True)
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in celeba.attr_names]).to(device)
    return celeba, text_inputs

这个函数用于加载数据集和生成与数据集相关的文字描述。

  • data_load 是一个函数,接受一个 data_path 参数作为输入。
  • 在函数内部,调用了 torchvision.datasets.CelebA 来加载 CelebA 数据集。root 参数指定了数据集的根目录路径,split 参数指定了要加载的数据集划分(这里使用的是测试集),download 参数指定了是否下载数据集(设为 True 表示下载)。
  • 在加载 CelebA 数据集后,通过遍历 celeba.attr_names 中的每个属性名称,使用 clip.tokenize() 函数生成与属性名称相关的文字描述,并使用 torch.cat() 函数将这些描述连接起来。最终,得到的文字描述张量被转移到指定的设备上(之前定义的 device 变量)。
  • 最后,函数将加载的数据集对象和生成的文字描述张量作为结果返回。
def test_model(start, end, celeba, text_inputs, model, preprocess):
    # 测试模型
    length = end - start + 1
    face_accuracy = 0
    face_score = 0

    for i, data in enumerate(celeba):
        face_result = 0
        if i < start:
            continue
        image, target = data
        image_input = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image_input)
            text_features = model.encode_text(text_inputs)

        image_features /= image_features.norm(dim = -1, keepdim = True)
        text_features /= text_features.norm(dim = -1, keepdim = True)

        text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
        top_score, top_label = text_probs.topk(6, dim = -1)
        for k, score in zip(top_label[0], top_score[0]):
            if k.item() < 40 and target[k.item()] == 1:
                face_result = 1
                face_score += score.item()
                print('Predict right! The predicted is {}'.format(celeba.attr_names[k.item()]))
            else:
                print('Predict flase! The predicted is {}'.format(celeba.attr_names[k.item()]))
        face_accuracy += face_result

        if i == end:
            break
    face_score = face_score / length
    face_accuracy = face_accuracy / length

    return face_score, face_accuracy

这个函数用于测试模型的性能。

  1. test_model 是一个函数,接受 startendcelebatext_inputsmodelpreprocess 作为输入。
  2. 在函数内部,首先初始化一些变量,包括 length(表示要处理的图像数量)、face_accuracy(用于记录人脸识别的准确率)和 face_score(用于记录人脸识别的得分)。
  3. 然后,使用 enumerate(celeba) 遍历 CelebA 数据集,其中i表示当前迭代的索引,data 表示当前迭代的数据。
  4. 在每次迭代中,首先将 face_result 初始化为 0。然后,通过 data 获取当前图像和目标标签。
  5. 接下来,将图像输入预处理函数 preprocess 进行预处理,并通过 unsqueeze(0) 在批次维度上添加一个维度。然后将处理后的图像输入到模型中,分别使用 model.encode_image()model.encode_text() 来获取图像特征和文字特征。
  6. 对于图像特征和文字特征,进行归一化处理,将每个特征向量除以其范数,以使其长度为 1。
  7. 使用归一化后的特征计算图像特征与文字特征之间的相似度,通过矩阵乘法和 softmax 操作得到预测的文本概率分布 text_probs
  8. 接下来,使用 topk() 函数获取预测概率最高的 6 个标签,并遍历每个标签和对应的得分。
  9. 如果预测的标签索引小于 40 且目标标签中对应索引的值为 1(表示该属性为真),则将 face_result 设置为 1,并将得分累加到 face_score 中,同时打印预测正确的信息;否则,打印预测错误的信息。
  10. 最后,将 face_result 累加到 face_accuracy 中,判断是否达到了指定的结束索引 end,如果是,则终止循环。
  11. 计算平均得分和平均准确率,并将其作为结果返回。

总的来说,这个函数的作用是对模型进行测试,并计算人脸识别的平均得分和平均准确率。在测试过程中,它遍历 CelebA 数据集中的图像,计算图像与文字特征之间的相似度,并根据预测的结果评估模型的性能。

if __name__ == '__main__':
    start = 0
    end = 1000
    model_name = 'ViT-B/32'
    data_path = 'CELEBA'

    time_start = time.time()
    model, preprocess = model_load(model_name)
    celeba, text_inputs = data_load(data_path)
    face_score, face_accuracy = test_model(start, end, celeba, text_inputs, model, preprocess)
    time_end = time.time()

    print('The prediction:')
    print('face_accuracy: {:.2f} face_score: {}%'.format(face_accuracy, face_score * 100))
    print('runing time: %.4f' % (time_end - time_start))

这段代码是整个程序的入口点,它实现了整个流程的控制和输出结果。

  • if name == ‘main’:是 Python 中的条件语句,表示当该脚本被直接运行时(而不是作为模块导入时),以下的代码块将被执行。
  • 在该代码块中,首先定义了一些变量,包括 start(开始索引)、end(结束索引)、model_name(模型名称)和 data_path(数据集路径)。
  • 通过 time.time() 获取当前时间,将其记录为 time_start,以便后续计算程序的运行时间。
  • 调用 model_load(model_name) 函数加载指定名称的模型,并将返回的 modelpreprocess 赋值给 modelpreprocess 变量。
  • 调用 data_load(data_path) 函数加载数据集,并将返回的 celebatext_inputs 赋值给 celebatext_inputs 变量。
  • 调用 test_model(start, end, celeba, text_inputs, model, preprocess) 函数对模型进行测试,获取人脸识别的得分和准确率,分别赋值给 face_scoreface_accuracy 变量。
  • 通过 time.time() 获取当前时间,将其记录为 time_end,以便计算程序的运行时间。
  • 使用 print() 函数输出预测结果,包括人脸准确率、人脸得分和运行时间。

总的来说,该部分代码是整个程序的入口,它负责加载模型、加载数据集、测试模型并输出结果。通过设定的参数对模型进行测试,并打印出人脸识别的准确率、得分和程序运行时间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/529420.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL】MySQL索引--聚簇索引和非聚簇索引的区别

文章目录 前言1.聚簇索引和非聚簇索引的概念2.两者详细介绍2.1 聚簇索引2.2 非聚簇索引 3. 两者的区别3.1 数据存储方式3.2 二级索引查询 前言 1.聚簇索引和非聚簇索引的概念 数据库表的索引从数据存储方式上可以分为聚簇索引和非聚簇索引两种。“聚簇”的意思是数据行被按照…

【Java|golang】1072. 按列翻转得到最大值等行数

给定 m x n 矩阵 matrix 。 你可以从中选出任意数量的列并翻转其上的 每个 单元格。&#xff08;即翻转后&#xff0c;单元格的值从 0 变成 1&#xff0c;或者从 1 变为 0 。&#xff09; 返回 经过一些翻转后&#xff0c;行与行之间所有值都相等的最大行数 。 示例 1&#…

使用TensorFlow构建,绘制和解释人工神经网络

使用 Python 进行深度学习&#xff1a;神经网络&#xff08;完整教程&#xff09; 使用TensorFlow构建&#xff0c;绘制和解释人工神经网络 总结 在本文中&#xff0c;我将展示如何使用Python构建神经网络&#xff0c;以及如何使用可视化和创建模型预测解释器向业务解释深度学习…

【PCIE】pcie设备协议分析和crash后定位

分析RP Headerlog在协议中位置 能力集寄存器协议字段 HeaderLog字段偏移以及各字段含义 headerLog和协议的对应入截图中内容 completer id就是完成的ID&#xff0c;对应的BDF &#xff0c;如下图 b5:00.0 AECap寄存器 其中 first error pointer 含义&#xff1a; 这里有专…

对于 Git 每一次提交的时间信息,什么是作者日期和提交者日期

文章目录 什么是作者日期和提交者日期如何查看作者日期和提交者日期方法 1方法 2方法 3 修改最近一次提交的时间 什么是作者日期和提交者日期 对于 Git 的每一次提交&#xff0c;在 TortoiseGit 和 IntelliJ IDEA 都可以看到这次提交的时间。但很多人不知道的是&#xff0c;Gi…

人脸识别2:Python实现人脸识别Face Recognition(含源码)

人脸识别2&#xff1a;Python实现人脸识别Face Recognition(含源码) 目录 人脸识别2&#xff1a;Python实现人脸识别Face Recognition(含源码) 1. 前言 2. 项目安装 3. 人脸识别系统 &#xff08;1&#xff09;人脸检测和关键点检测 &#xff08;2&#xff09;人脸校准 …

【数据结构与算法】布隆(Bloom Filter)过滤器

文章目录 1、什么是布隆过滤器2、布隆过滤器的使用场景3、布隆过滤器的原理3.1 数据结构3.2 空间计算3.3 增加元素3.4 查询元素3.5 修改元素3.6 删除元素 4、Redis集成布隆过滤器4.1 版本要求4.2 安装&编译4.2.1 下载插件压缩包4.2.2 解压4.2.3 编译插件 4.3 Redis集成4.3.…

《硅谷钢铁侠:埃隆·马斯克的冒险人生》成就21世纪的史诗

《硅谷钢铁侠&#xff1a;埃隆马斯克的冒险人生》成就21世纪的史诗 阿什利万斯&#xff08;Ashlee Vance&#xff09;&#xff1a;美国商业专栏作家、资深科技记者。 文章目录 《硅谷钢铁侠&#xff1a;埃隆马斯克的冒险人生》成就21世纪的史诗马斯克的超级公司摘录感悟梦 马斯…

免费SSL:阿里云SSL证书免费申请入口及流程开启HTTPS

阿里云SSL免费证书在哪申请&#xff1f;一个阿里云账号一年可以申请20张免费SSL证书&#xff0c;很多同学找不到免费SSL的入口&#xff0c;阿小云来详细说下阿里云SSL证书免费申请入口链接以及免费SSL证书申请流程&#xff0c;有同学反馈阿里云免费SSL证书没有了&#xff1f;错…

【Python共享文件】——Python快速搭建HTTP web服务实现文件共享并公网远程访问

文章目录 1. 前言2. 视频教程3. 本地文件服务器搭建3.1 python的安装和设置3.2 cpolar的安装和注册 4. 本地文件服务器的发布4.1 Cpolar云端设置4.2 Cpolar本地设置 5. 公网访问测试6. 结语 1. 前言 数据共享作为和连接作为互联网的基础应用&#xff0c;不仅在商业和办公场景有…

[离散数学]图论

图基本概念 点相同 边相同 $$ 有向图 无向图 邻接点 &#xff1a;两个结点有一条有(无)向边相关联 邻接边:关联与同一个结点 孤立结点: 不予任何结点相邻接的结点 握手定理 度数边的两倍 有向图的 出度和入度和边数 n个节点无向完全图边数 C n 2 1 2 n ( n − 1 ) C_n^2…

14JS05——流程控制-分支

目标&#xff1a; 1、流程控制 2、顺序流程控制 3、分支流程控制if语句 4、三元表达式 5、分支流程控制switch语句 一、流程控制 在一个程序执行的过程中&#xff0c;各条代码的执行顺序对程序的结果是有直接影响的。很多时候我们要通过控制代码 的执行顺序来实现我们要完成的…

Science | 人体可以依靠饥饿感来延缓衰老

作为一种高级动物&#xff0c;人体需要六种营养物质来维持基本的生理需求&#xff1a;糖类、油脂、蛋白质、无机盐、水、和维生素。对营养物质的生理需求促使人和动物去追寻食物。而饮食对人和动物的行为和寿命又有着显著影响。 近年来&#xff0c;越来越多的研究表明&#xf…

计算机网络 二 (物理层)

物理层 概念 物理层为数据链路层屏蔽了各种传输媒体的差异&#xff0c;使数据链路层只需要考虑如何完成本层的协议和服务&#xff0c;而不必考虑网络具体的传输媒体是什么。 对于物理层有很多很多的协议&#xff0c;不过都不怎么重要&#xff0c;对于物理层我们知道物理层协议…

国产仪器仪表 1466C-V/1466D-V/1466E-V/1466G-V/1466H-V/1466L-V系列信号发生器

国产Ceyear 1466-V系列信号发生器是一款面向微波毫米波尖端测试的通用测试仪器&#xff0c;频率范围覆盖宽、射频调制带宽大、信号频谱纯度高&#xff0c;具有高准确度和大动态范围的功率输出&#xff0c;以及出色的矢量调制精度和ACPR性能&#xff0c;搭配单机双射频通道和多机…

21天学会Linux----Day1:Linux环境搭建

CSDN的uu们&#xff0c;大家好。这里是Linux的第一讲。 座右铭&#xff1a;前路坎坷&#xff0c;披荆斩棘&#xff0c;扶摇直上。 博客主页&#xff1a; 姬如祎 收录专栏&#xff1a;Linux保姆级教程 目录 1. Linux环境搭建的三种方式 2. 阿里云学生认证白嫖七个月云服务器…

十四、Zuul网关

目录 一、API网关作用&#xff1a; 二、网关主要功能&#xff1a; 2.1、统一服务入口 2.2、接口鉴权 2.3、智能路由 2.4、API接口进行统一管理 2.5、限流保护 三、 新建一个项目作为网关服务器 3.1、项目中引入Zuul网关依赖 3.2、在项目application.yml中配置网关路由…

mmdetection 中 Mask Rcnn检测结果可视化(DICE计算、PR曲线绘制等)

mmdetection中的Mask Rcnn是一个很不错的检测网络&#xff0c;既可以实现目标检测&#xff0c;也可以实现语义分割。官方也有很详细的doc指导&#xff0c;但是对新手来说并不友好&#xff0c;刚好之前笔者写的mmlab系列里面关于可视化都还没有一个详细的文档&#xff0c;也在此…

JAVA常用API - Runtime和System

文章目录 前言 大家好,我是最爱吃兽奶,今天给大家带来JAVA常用API中的Runtime类和System类 那么就让我们一起去看看吧! 一、Rubtime 1.Rubtime是什么? 2.Runtime常用方法 Runtime提供了很多方法,在这里演示两个 public static Runtime getRuntime(): 返回当前运行时环境的…

ChatSQL - 文本生成SQL【LLM】

ChatSQL将用户提供的纯文本转换为 mysql 查询&#xff0c;基于ChatGPT实现。 推荐&#xff1a;用 NSDT设计器 快速搭建可编程3D场景。 1、ChatSQL简介 我们需要从一开始就指定一些关于我们数据库的信息&#xff0c;以便 Chatgpt 了解我们的数据库。 info.json 文件可用于此过程…