当一个程序员决定穿上粉色裤子

news2024/11/17 13:19:34

74eeab02088485d13f3e886412104c77.png

作为一个大众眼中的“非典型程序员”,我喜欢拥抱时尚和潮流,比如我经常在演讲时穿粉色裤子,这甚至已经成为一个标志性打扮。某天又逢主题演讲日,我站在衣柜前挑选上衣的时候,忽然灵光乍现:有没有可能借助 Milvus 找到和我穿搭风格最为相似的明星呢?

5bbe8b8eec1539b413e226ad709d6e1e.jpeg

这个想法在我脑海中不停地闪现,始终没有遇到特别合适的契机进行实践。直到最近,我遇到了一个名为 Fashion AI 的项目,它主要利用微调模型对服装图片进行分割(segmentation),然后裁剪出图像中标注(label)的时尚单品,并将所有图片调整为相同的大小,最后将这些图像转化为 embedding 向量存储在开源向量数据库 Milvus 中。通过这个项目可以在 Milvus 数据库中查询并获得 3 个最相似的向量结果。随后,就可以通过上传一张自己穿着打扮的照片,最终确定与我们时尚风格最为相似的明星。

接下来,我将和大家分享这个项目具体的实现路径。

在正式开始前,可以通过这个链接 https://drive.google.com/file/d/1pBO02iLgToBSCOyMJ58zWHQf4ZRkP5AY/view 获取项目使用到的图片。此外,想要搭建本项目,还需要升级 Python 版本,通过指令pip install milvus pymilvus torch torchvision matplotli安装所需软件工具等。本项目使用了 Hugging Face 上由 Mateusz Dziemian 提供的 clothing segmenter 模型 https://huggingface.co/mattmdjaga/segformer_b2_clothes 以及 PyTorch 上由 Nvidia 提供的 ResNet50 模型 https://pytorch.org/hub/nvidia_deeplearningexamples_resnet50/对图像进行分割,将图像转化为 embedding 向量。

01.

图像分割

为了完成图像分割任务,我在 Hugging Face 上找到了以下 3 个模型:

  • Mateusz Dziemian 提供的 segformer_b2_clothes 模型

  • Valentina Feruere 提供的 YOLOS-Fashionpedia 模型

  • Patrick John Chia 提供的 Fashion-CLIP 模型

最终,我选择了 segformer 模型,因为它可以对不同的服装图片进行准确分割,并识别出 18 种“对象”类型。也就是说,这个模型可以检测到图片中的“上衣”、“连衣裙”、“左脚鞋子”、“右脚鞋子”等诸多服装类型。此外,这个模型还可以检测图片中的”脸部”、“头发”、“右腿”、“左腿”等。浏览该链接 https://huggingface.co/mattmdjaga/segformer_b2_clothes/blob/main/config.json#L30了解模型可以识别的全部 18 种对象(object)类型。

开始前,我们首先需要导入本项目中图像处理时所需的工具包,包括:

  • torch用于提取图像特征

  • 来自 transformers 的 segformer

  • 来自 torchvision 的 Resize、masks_to_boxes、crop。

import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop
  • 使用 Hugging Face 生成图像分割掩膜

图像分割方法有很多种,采用哪种方法主要取决于你使用的模型及其检测到的内容。在本项目中,我们使用的模型会返回一个 18 层的图像,每层包含一种检测对象类型,其中包含图像背景。

现在,我们先编写一个函数来生成这个 18 层图像。

get_segmentation 函数需要三个参数:特征提取器(feature extractor)、模型(model)和图像(image)。首先,这个函数会使用图像和提取器生成输入特征(input feature), 然后将模型输出转换为 logits。之后,该函数通过 PyTorch 双线性插值(Bilinear Interpolation)上采样(upsample) logits。最后,该函数仅采取每个像素中的最大预测值,以创建分割掩膜(mask)。

def get_segmentation(extractor, model, image):
    inputs = extractor(images=image, return_tensors="pt")

    outputs = model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg

upsampled_logits 中的图像如下所示:

7b5044cf1fcd1dea4e7ac477893095a8.png

pred_seg 图像如下所示。上面两张都是 Andre 3000 的照片,但其实是不同的图像:

313a0bdb8856525305453bc56629d771.png

至此,获取分割 mask 的操作就十分简单了。我们获取分割结果中所有的唯一值。根据本项目采用的模型,最多可以获取 18 个值。第一个结果代表的是图像背景,所以可以舍弃这个结果。为了生成 mask,我们提取分割像素中与对象 ID 一致的像素。

以下函数会返回 mask 和 ID,以便可以同时查看二者:

# 返回 2 个 lists masks (tensor) 和obj_ids(int)
# 来自 hugging face 的 "mattmdjaga/segformer_b2_clothes" 模型
def get_masks(segmentation):
    obj_ids = torch.unique(segmentation)
    obj_ids = obj_ids[1:]
    masks = segmentation == obj_ids[:, None, None]
    return masks, obj_ids

函数生成的图像 mask 如下所示。左图为头发 mask,右图为上衣 mask:

8eccf3a47e390d21ca28feead80c9786.png

  • 使用 Pytorch 裁剪和调整图像大小

接下来使用 get_masks 函数为图像中每个监测到的对象以及原图生成新图像。随后用 masks_to_boxes 函数将 mask 转化为边界框(bounding box)。此前,我们已经通过 torchvision.ops 导入了这个函数。

接着,创建一系列边界框并将边界框坐标系转为 crop 坐标系。边界框的形式为 (x1, x2, y1, y2)。crop 函数期望输入形式为 (top, left, height, width)

在正式裁剪图像前,我们还定义了一个图像预处理函数。将每个图像调整为 256x256 的大小,并转化为 PyTorch tensor (目前是 PIL 图像)。裁剪时,循环遍历裁剪框,并调用 crop 函数。随后我们将预处理完成的图片加入到 dictionary 中,以对应分割 ID 的主键值。函数最后会返回 dictionary。

def crop_images(masks, obj_ids, img):
    boxes = masks_to_boxes(masks)
    crop_boxes = []
    for box in boxes:
        crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
        crop_boxes.append(crop_box)

    preprocess = T.Compose([
        T.Resize(size=(256, 256)),
        T.ToTensor()
    ])

    cropped_images = {}
    for i in range(len(crop_boxes)):
        crop_box = crop_boxes[i]
        cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
        cropped_images[obj_ids[i].item()] = preprocess(cropped)
    return cropped_images

下面的示例图中 Drake 穿着鲜橙色的衣服。我们使用裁剪框框处图像中的对象(时尚单品)并为他们各自生成单独的图像:

7a4f92d48a254855368c2c423fe603f6.png

02.

将图像数据添加至向量数据库中

图像分割裁剪完成后,我们就可以将其添加至 Milvus 向量数据库中了。为了方便上手,本项目中使用了 Milvus Lite 版本,可以在 notebook 中运行 Milvus 实例。接下来,使用 PyMilvus 连接至 Milvus Lite 提供的默认服务器。

这一步骤中,还需要设置一些常量。定义向量维度、数据量、集合名称、返回的结果个数。随后,运行 ssl 函数来创建上下文,从 PyTorch 获取模型。

from milvus import default_server
from pymilvus import utility, connections
default_server.start()
connections.connect(host="127.0.0.1", port=default_server.listen_port)

DIMENSION = 2048
BATCH_SIZE = 128
COLLECTION_NAME = "fashion"
TOP_K = 3

# 如果遇到 SSL 证书 URL 错误,请在导入 resnet50 模型前运行此内容
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
  • 在向量数据库中定制 Schema 并存储元数据

先定制 Schema。Schema 用于组织向量数据库中存储的数据。id 字段就和 SQL 或者 NoSQL 数据库中的 key ID 一样。Milvus Schema 中的其他字段可以设置 int64、varchar、float 等数据类型。

在本项目中,我们是保存文件路径、明星名字、分割 ID,并将其作为元数据,后续还会考虑添加更多字段,例如边界框、mask 位置等。定义好 FieldSchema、CollectionSchema 后,就可以创建 1 个 Miluvs Collection。

Collection 创建完成后,构建索引。索引参数十分简单。选择 IVF Flat 的索引类型和 L2 相似度类型。这个索引是针对于 Collection 中的 embedding 向量字段。索引构建完成后,将 Collection 加载到内存中,以便后续操作。

from pymilvus import FieldSchema, CollectionSchema, Collection, DataType

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="name", dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="seg_id", dtype=DataType.INT64),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]

schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()
  • 从 Nvidia ResNet50 模型获取 embedding 向量

我们需要先从 PyTorch 中加载 Nvidia  ResNet50 模型,然后删除最后一层输出层,因为embedding 向量是模型的倒数第二层输出。

# 加载 embedding 模型并删除最后一层输出
embeddings_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
embeddings_model = torch.nn.Sequential(*(list(embeddings_model.children())[:-1]))
embeddings_model.eval()

以下函数负责接收向量并将数据插入 Milvus。主要有三个参数:数据、集合对象和模型(也就是本项目中使用的 embedding 模型)。为了解插入到数据库中的数据,以下代码中添加了几条打印语句。

除了打印调试数据外,我们还将 data[0] 中的所有值堆叠到一个 tensor 中,然后使用 squeeze 函数从输出中删除维度是 1 的值。随后,插入新的数据列表,其中包括原数据中的最后三条以及由 tensor 输出转化而来的数据列表,这些数据对应文件路径、名称、分割 ID、2048 维向量。

def embed_insert(data, collection, model):
    with torch.no_grad():
        print(len(data[0]))
        print(data[0][0].size())
        output = model(torch.stack(data[0])).squeeze()
        print(type(output))
        print(len(output))
        print(len(output[0]))
        print(output[0])

    collection.insert([data[1], data[2], data[3], output.tolist()])

打印的数据如下图所示:

49f616b80cf482a4cc555af6fc0813b2.png

每个数据批次的大小为 128,每条数据的大小为 3x256x256。输出是  PyTorch tensor,长度为 128,输出中的每条数据长度为 2048。打印的 tensor 是数据批次中的第一条数据。

  • 将图像数据存储到向量数据库中

还记得前文提到的特征提取器和分割模型吗?接下来轮到它们出场了。我们需要用到 segformer 预训练模型, 在循环遍历所有文件路径之后,将所有文件路径放入一个列表中。

extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
import os
image_paths = []
for celeb in os.listdir("./photos"):
    for image in os.listdir(f"./photos/{celeb}/"):
        # print(image)
        image_paths.append(f"./photos/{celeb}/{image}")

Milvus 期望输入格式为列表。在本项目中,我们使用了 4 个列表,分别对应图像、文件路径、名称和分割 ID。在 embed_insert 函数中,将图像转换为 embedding 向量。然后,循环遍历每个图像文件的文件路径,收集它们的分割 mask 并对其进行裁剪。最后,将图像及元数据添加到数据批处理中。

每 128 张图像作为一批数据,我们将其转化为向量并插入到 Milvus 中,然后清空这批数据。在循环结束时,会 flush 数据完成索引构建。注意,在配备 M1 2021 Mac 和 16GB RAM 的计算机上,运行此过程需要约 8 分钟。

from PIL import Image
data_batch = [[], [], [], []]

for path in image_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = " ".join(path_split[2].split("_"))
    segmentation = get_segmentation(extractor, model, image)
    masks, ids = get_masks(segmentation)
    cropped_images = crop_images(masks, ids, image)

    for key, image in cropped_images.items():
        data_batch[0].append(image)
        data_batch[1].append(path)
        data_batch[2].append(name)
        data_batch[3].append(key)

    if len(data_batch[0]) % BATCH_SIZE == 0:
        embed_insert(data_batch, collection, embeddings_model)
        data_batch = [[], [], [], []]

if len(data_batch[0]) != 0:
    embed_insert(data_batch, collection, embeddings_model)

collection.flush()

03.

寻找与你时尚风格最相似的明星

上述步骤都完成后,就可以开始玩转这个系统了,它可以根据你上传的图片返回前 3 个与你穿搭风格最相似的明星。

  • 将上传图像转化为向量

首先需要处理上传的图像。以下函数需要两个参数:数据和 (embedding)模型。我们使用模型将图像转化为向量、处理图像,图像转化为列表并返回图片列表。

def embed_search_images(data, model):
    with torch.no_grad():
        print(len(data[0]))
        print(data[0][0].size())
        output = model(torch.stack(data))
        print(type(output))
        print(len(output))
        print(len(output[0]))
        print(output[0])
        if len(output) > 1:
            return output.squeeze().tolist()
        Else:
     return torch.flatten(output, start_dim=1).tolist()

如下图所示,传入本函数的 data 实际上是 data[0] 对象。

4da4384d7ed4b7c231b5c2e5a074842d.png

在查询时,我们只需要向量数据,但还是可以保留其他数据字段,就像把数据插入到 Milvus 中一样。

# data_batch[0] is a list of tensors
# data_batch[1] is a list of filepaths to the images (string)
# data_batch[2] is a list of the names of the people in the images (string)
# data_batch[3] is a list of segmentation keys (int)
data_batch = [[], [], [], []]


search_paths = ["./photos/Taylor_Swift/Taylor_Swift_3.jpg", "./photos/Taylor_Swift/Taylor_Swift_8.jpg"]


for path in search_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = " ".join(path_split[2].split("_"))
    segmentation = get_segmentation(extractor, model, image)
    masks, ids = get_masks(segmentation)
    cropped_images = crop_images(masks, ids, image)
    for key, image in cropped_images.items():
        data_batch[0].append(image)
        data_batch[1].append(path)
        data_batch[2].append(name)
        data_batch[3].append(key)


embeds = embed_search_images(data_batch[0], embeddings_model)
  • 查询向量数据库

将上传图片转化为向量后,便可以开始在向量数据库中查询相似数据了。为了测试,我们添加了 time 模块记录每次查询所需的时间。本项目中测量了查询 23 个 2048 维向量数据所需的时间,如果没有这个需求,可以直接使用 search 函数。

import time
start = time.time()
res = collection.search(embeds,
    anns_field='embedding',
    param={"metric_type": "L2",
    "params": {"nprobe": 10}},
    limit=TOP_K,
    output_fields=['filepath'])
finish = time.time()
print(finish - start)

在循环后,可以看到以下生成的响应。

for index, result in enumerate(res):
    print(index)
    print(result)

cc9fd3ef0724aa7151a3bdbde0e30e9c.png

欢迎大家上手操作,期待你们的结果分享!

本文最初发布于 AI Accelerator Institute,已获得转载许可。

🌟「寻找 AIGC 时代的 CVP 实践之星」 专题活动即将启动!

Zilliz 将联合国内头部大模型厂商一同甄选应用场景, 由双方提供向量数据库与大模型顶级技术专家为用户赋能,一同打磨应用,提升落地效果,赋能业务本身。

如果你的应用也适合 CVP 框架,且正为应用落地和实际效果发愁,可直接申请参与活动,获得最专业的帮助和指导!联系邮箱为 business@zilliz.com。

本文作者

fe6c04f66c26df3c968c15ad6ce83d34.jpeg

Yujian Tang
Zilliz 开发者布道师

推荐阅读

47067f07c0a22ecfaf463375302e5285.png

e1db5fdf7aa78fdec56410b33b181dae.png

53efdcb061196e9deb582f816a80a87b.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/921554.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于jenkins自动化部署PHP环境

实验环境 操作系统 IP地址 主机名 角色 CentOS7.5 192.168.147.141 git git服务器 CentOS7.5 192.168.147.142 Jenkins git客户端 jenkins服务器 CentOS7.5 192.168.147.143 web web服务器 具体环境配置见上一篇! 准备git仓库 [rootgit ~]# su -…

如何写好公文材料

写好公文材料需要具备一定的写作技巧,同时也需要对公文的格式、语言和结构有深入的了解。以下是如何写好公文材料的建议和步骤: 1.确定公文的目的 在开始写作前,明确公文的目的。它是为了传达什么样的信息?是通知、申请、报告、建…

专题-【十字链表】

有向图的十字链表表示法:

U盘文件恢复,拯救文件,只需简单3招!

“u盘文件删掉了还能恢复吗?七夕和对象吵架了,一气之下把之前一起旅游的照片视频都删了,今天看到空空的u盘,心里真的很难受。有什么方法可以恢复u盘文件吗?” U盘在我们的日常生活中已经扮演了很重要的角色&#xff0c…

TC1016-同星4路CAN(FD),2路LIN转USB接口卡

TC1016是同星智能推出的一款多通道CAN(FD)和LIN总线接口设备,CANFD总线速率最高支持8M bps,LIN支持速率0~20K bps,产品采用高速USB2.0接口与PC连接,Windows系统免驱设计使得设备具备极佳的系统兼容性。 支…

【附安装】R语言4.3.0安装教程

软件下载 软件:R语言版本:4.3.0语言:简体中文大小:77.74M安装环境:Win7及以上版本,64位操作系统硬件要求:CPU2.0GHz 内存4G(或更高)下载通道①百度网盘丨64位下载链接:h…

android Junit4编写自测用例

10多年的android开发经验,一直以来呢,也没有使用过android自带的测试代码编写。说来也惭愧。今天也花了点时间稍微研究了下。还挺简单。接下来就简单的说一下。 新建工程 直接默认新建一个工程,就会有两个目录androidTest和test(unitTest)两…

漏洞复现 || muhttpd 任意文件读取

漏洞描述 muhttpd(mu-HTTP-deamon)是一个简单但完整的web服务器,用可移植的ANSI C编写。它支持静态页面、CGI脚本、基于MIME类型的处理程序和HTTPS,muhttpd 1.1.7之前版本存在安全漏洞。攻击者利用该漏洞读取系统任意文件。 免责…

免费制作高质量的电子期刊网站

工具介绍:FLBOOK 打开FLBOOK首页就能看见有四五本高质量的电子书刊,并且每打开一本,书的最下方就有阅读次数的统计。 FLBOOK制作电子期刊的方法也非常简单,可以根据小编的步骤开始制作或是看FLBOOK的教程,亲自动手制作…

第一讲使用IDEA创建Java工程——HelloWorld

一、前言导读 为了能够让初学者更快上手Java,不会像其他书籍或者视频一样,介绍一大堆历史背景,默认大家已经知道Java这么编程语言了。本专栏只会讲解干货,直接从HelloWord入手,慢慢由浅入深,讲个各个知识点,这些知识点也是目前工作中项目使用的,而不是讲一些老的知识点…

Tuxera NTFS2023中文版Mac读写NTFS格式硬盘访问、编辑、存储和传输文件工具

因为Mac电脑不能写入NTFS格式磁盘,但是多数用户使用的是NTFS格式的移动硬盘、u盘,因此很多NTFS for Mac软件应运而生。但是市面上很多NTFS for Mac软件很多,例如:Tuxera NTFS for Mac、Paragon NTFS for Mac等。Tuxera NTFS for M…

【分析绘图】R语言实现一些常见的绘图

微生信-在线绘图网站 线性图 library(ggplot2)x <- rnorm(100, 14, 5) # rnorm(n, mean 0, sd 1) y <- x rnorm(100, 0, 1) ggplot(data NULL, aes(x x, y y)) # 开始绘图geom_point(color "darkred") # 添加点annotate("text",x 13,…

Java面试题—2023年8月24日—YDZH

2023-08-24 10:54:28 北京 yī do zh h 答案仅供参考&#xff0c;博主仅记录发表&#xff0c;没有实际查询&#xff0c;不保证正确性。 面试题&#xff1a; 1、请你谈谈关于 Synchronized 和 lock ? 2、请简单描述一下类的加载过程?类加载器有几个种&#xff0c;分别作用是什…

微信小程序开发项目步骤【详细】

在平常是H5开发中已经不能满足我们的需求了&#xff0c;随着小程序的火热&#xff0c;越来越多的项目开发也离不开小程序的运用&#xff0c;目前常用的就是微信小程序&#xff0c;我们学完微信小程序后其他的小程序开发也是基本一样的&#xff0c;也为后面的uniapp开发做下一定…

港联证券|油价上涨对股票影响大吗?利好还是利空?

石油是现代国家国民经济的血脉&#xff0c;直接影响国民经济的发展。那么&#xff0c;油价上涨对股票影响大吗&#xff1f;利好仍是利空&#xff1f;为大家准备了相关内容&#xff0c;以供参阅。 香港港联证券有限公司&#xff08;百度一下港联证券&#xff09;成立于2021年1月…

vue 简单实验 v-model 变量和htm值双向绑定

1.代码 <script src"https://unpkg.com/vuenext" rel"external nofollow" ></script> <div id"two-way-binding"><p>{{ message }}</p><input v-model"message" /> </div> <script>…

AD域组策略开机脚本客户端不执行:解决方法

需求&#xff1a;本例实现的客户端计算机开机执行脚本&#xff0c;实现重置本地管理员的密码 1、创建组策略 2、在AD域中添加脚本 3、注意脚本的路径&#xff1a;就是打开 Show Files 目录&#xff0c;保证在客户端也能正常访问 4、本例建了2个脚本&#xff0c;一个是用来测试…

当《孤注一掷》照进现实,创邻科技Galaxybase助反诈一臂之力

“想成功&#xff0c;先发疯&#xff0c;不顾一切向钱冲&#xff1b;拼一次&#xff0c;富三代&#xff0c;拼命才能不失败。” 这看似振奋、实则让人背后发凉的口号来自于电影《孤注一掷》&#xff0c;它的背后是无数受害人血泪交织的受骗故事。 作为一部反诈题材电影&…

若依微服务版部署到IDEA

1.进入若依官网&#xff0c;找到我们要下的微服务版框架 2.点击进入gitee,获取源码&#xff0c;下载到本地 3.下载到本地后&#xff0c;用Idea打开&#xff0c;点击若依官网&#xff0c;找到在线文档&#xff0c;找到微服务版本的&#xff0c;当然你不看文档&#xff0c;直接按…

【从零学习python 】69. 网络通信及IP地址分类解析

文章目录 网络通信的概念IP地址IP地址的分类A类地址B类地址C类地址D类地址E类地址私有地址 进阶案例 网络通信的概念 简单来说&#xff0c;网络是用物理链路将各个孤立的工作站或主机相连在一起&#xff0c;组成数据链路&#xff0c;从而达到资源共享和通信的目的。 使用网络…