跨模态检索:基于OpenAI的Clip预训练模型构建以文搜图系统

news2024/11/20 21:19:34

目录

1 项目背景

2 关键技术

2.1 Clip模型

2.2 Milvus向量数据库

 3 系统代码实现

3.1 运行环境构建

3.2 数据集下载

3.3 预训练模型下载

3.4 代码实现

3.4.1 创建向量表和索引

 3.4.2 构建向量编码模型

3.4.3 数据向量化与加载

3.4.4 构建检索web

4 总结


1 项目背景

以文搜图是一种跨模态检索技术,即通过输入文字描述来搜索图片,它不仅应用于辅助搜索与信息检索,尤其在难以用关键词准确描述情况下发挥作用,提供了一种高效的信息检索方式。这种技术应用场景和价值非常广泛,它在辅助信息搜索、艺术、广告等领域均有重要的应用价值,为用户提供更个性化的搜索体验。以文搜图涉及到的技术点如下:

  • 如何对文本数据进行向量编码
  • 如何对海量图片数据进行向量化和存储
  • 如何映射文本向量与图片向量的关系
  • 如何快速对海量的向量数据进行检索

本项目基于OpenAI的Clip预训练模型结合Milvus向量数据库,在水果数据集上实现了以文搜图系统,读者可以将数据集扩展到其它领域,构建满足自身业务的以文搜图系统。

2 关键技术

2.1 Clip模型

CLIP全称Constrastive Language-Image Pre-training,是OpenAI推出的采用对比学习的文本-图像预训练模型。CLIP惊艳之处在于架构非常简洁且效果好到难以置信,在zero-shot文本-图像检索,zero-shot图像分类,文本→图像生成任务guidance,open-domain 检测分割等任务上均有非常惊艳的表现。

CLIP的创新之处在于,它能够将图像和文本映射到一个共享的向量空间中,从而使得模型能够理解图像和文本之间的语义关系。这种共享的向量空间使得CLIP在图像和文本之间实现了无监督的联合学习,从而可以用于各种视觉和语言任务。

CLIP的设计灵感源于一个简单的思想:让模型理解图像和文本之间的关系,不仅仅是通过监督训练,而是通过自监督的方式。CLIP通过大量的图像和文本对来训练,使得模型在向量空间中将相应的图像和文本嵌入彼此相近。


CLIP模型的特点

  • 统一的向量空间: CLIP的一个关键创新是将图像和文本都映射到同一个向量空间中。这使得模型能够直接在向量空间中计算图像和文本之间的相似性,而无需额外的中间表示。
  •  对比学习: CLIP使用对比学习的方式进行预训练。模型被要求将来自同一个样本的图像和文本嵌入映射到相近的位置,而将来自不同样本的嵌入映射到较远的位置。这使得模型能够学习到图像和文本之间的共同特征。
  •  多语言支持: CLIP的预训练模型是多语言的,这意味着它可以处理多种语言的文本,并将它们嵌入到共享的向量空间中。
  •  无监督学习: CLIP的预训练是无监督的,这意味着它不需要大量标注数据来指导训练。它从互联网上的文本和图像数据中学习,使得它在各种领域的任务上都能够表现出色。

Clip模型详细介绍:Clip模型详解

2.2 Milvus向量数据库

Milvus 是一款云原生向量数据库,它具备高可用、高性能、易拓展的特点,用于海量向量数据的实时召回。

Milvus 基于FAISS、Annoy、HNSW 等向量搜索库构建,核心是解决稠密向量相似度检索的问题。在向量检索库的基础上,Milvus 支持数据分区分片、数据持久化、增量数据摄取、标量向量混合查询、time travel 等功能,同时大幅优化了向量检索的性能,可满足任何向量检索场景的应用需求。通常,建议用户使用 Kubernetes 部署 Milvus,以获得最佳可用性和弹性。

Milvus 采用共享存储架构,​存储计算完全分离​,计算节点支持横向扩展。从架构上来看,Milvus 遵循数据流和控制流分离,整体分为了四个层次,分别为接入层(access layer)、协调服务(coordinator service)、执行节点(worker node)和存储层(storage)。各个层次相互独立,独立扩展和容灾。

 Milvus 向量数据库能够帮助用户轻松应对海量非结构化数据(图片/视频/语音/文本)检索。单节点 Milvus 可以在秒内完成十亿级的向量搜索,分布式架构亦能满足用户的水平扩展需求。

milvus特点总结如下:

  • 高性能:性能高超,可对海量数据集进行向量相似度检索。
  • 高可用、高可靠:Milvus 支持在云上扩展,其容灾能力能够保证服务高可用。
  • 混合查询:Milvus 支持在向量相似度检索过程中进行标量字段过滤,实现混合查询。
  • 开发者友好:支持多语言、多工具的 Milvus 生态系统。

Milvus详细介绍:Miluvs详解

 3 系统代码实现

3.1 运行环境构建

conda环境准备详见:annoconda

git clone https://gitcode.net/ai-medical/text_image_search.git
cd text_image_search

pip install -r requirements.txt
pip install git+https://ghproxy.com/https://github.com/openai/CLIP.git

3.2 数据集下载

下载地址:

第一个数据包:package01

第二个数据包:package01

在数据集目录下,存放着10个文件夹,文件夹名称为水果类型,每个文件夹包含几百到几千张此类水果的图片,如下图所示:

 以apple文件夹为例,内容如下:

下载后进行解压,保存到D:/dataset/fruit目录下,查看显示如下

# ll fruit/
总用量 508
drwxr-xr-x 2 root root 36864 8月   2 16:35 apple
drwxr-xr-x 2 root root 24576 8月   2 16:36 apricot
drwxr-xr-x 2 root root 40960 8月   2 16:36 banana
drwxr-xr-x 2 root root 20480 8月   2 16:36 blueberry
drwxr-xr-x 2 root root 45056 8月   2 16:37 cherry
drwxr-xr-x 2 root root 12288 8月   2 16:37 citrus
drwxr-xr-x 2 root root 49152 8月   2 16:38 grape
drwxr-xr-x 2 root root 16384 8月   2 16:38 lemon
drwxr-xr-x 2 root root 36864 8月   2 16:39 litchi
drwxr-xr-x 2 root root 49152 8月   2 16:39 mango

3.3 预训练模型下载

预训练模型包含5个resnet和4个VIT,其中ViT-L/14@336px效果最好。

"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",

下载ViT-L/14@336px的预训练模型:ViT-L-14-336px.pt,存放到D:/models目录下

3.4 代码实现

3.4.1 创建向量表和索引

from pymilvus import connections, db

conn = connections.connect(host="192.168.1.156", port=19530)
database = db.create_database("text_image_db")

db.using_database("text_image_db")
print(db.list_database())

创建collection

from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus import Collection, db, connections


conn = connections.connect(host="192.168.1.156", port=19530)
db.using_database("text_image_db")

m_id = FieldSchema(name="m_id", dtype=DataType.INT64, is_primary=True,)
embeding = FieldSchema(name="embeding", dtype=DataType.FLOAT_VECTOR, dim=768,)
path = FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=256,)
schema = CollectionSchema(
  fields=[m_id, embeding, path],
  description="text to image embeding search",
  enable_dynamic_field=True
)

collection_name = "text_image_vector"
collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)

创建index

from pymilvus import Collection, utility, connections, db

conn = connections.connect(host="192.168.1.156", port=19530)
db.using_database("text_image_db")

index_params = {
  "metric_type": "IP",
  "index_type": "IVF_FLAT",
  "params": {"nlist": 1024}
}

collection = Collection("text_image_vector")
collection.create_index(
  field_name="embeding",
  index_params=index_params
)

utility.index_building_progress("text_image_vector")

 3.4.2 构建向量编码模型

加载预训练模型,通过Clip模型对图片进行编码,编码后输出特征维度为768

from torchvision.models import resnet50
import torch
from torchvision import transforms
from torch import nn


class RestnetEmbeding:
    pretrained_model = 'D:/models/resnet50-0676ba61.pth'

    def __init__(self):
        self.model = resnet50()
        self.model.load_state_dict(torch.load(self.pretrained_model))

        # delete fc layer
        self.model.fc = nn.Sequential()
        self.transform = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                                  std=[0.26862954, 0.26130258, 0.27577711])])

    def embeding(self, image):
        trans_image = self.transform(image)
        trans_image = trans_image.unsqueeze_(0)
        return self.model(trans_image)


restnet_embeding = RestnetEmbeding()

3.4.3 数据向量化与加载

from clip_embeding import clip_embeding
from milvus_operator import text_image_vector, MilvusOperator
from PIL import Image
import os


def update_image_vector(data_path, operator: MilvusOperator):
    idxs, embedings, paths = [], [], []

    total_count = 0
    for dir_name in os.listdir(data_path):
        sub_dir = os.path.join(data_path, dir_name)
        for file in os.listdir(sub_dir):

            image = Image.open(os.path.join(sub_dir, file)).convert('RGB')
            embeding = clip_embeding.embeding_image(image)

            idxs.append(total_count)
            embedings.append(embeding[0].detach().numpy().tolist())
            paths.append(os.path.join(sub_dir, file))
            total_count += 1

            if total_count % 50 == 0:
                data = [idxs, embedings, paths]
                operator.insert_data(data)

                print(f'success insert {operator.coll_name} items:{len(idxs)}')
                idxs, embedings, paths = [], [], []

        if len(idxs):
            data = [idxs, embedings, paths]
            operator.insert_data(data)
            print(f'success insert {operator.coll_name} items:{len(idxs)}')

    print(f'finish update {operator.coll_name} items: {total_count}')


if __name__ == '__main__':
    data_dir = 'D:/dataset/fruit'
    update_image_vector(data_dir, text_image_vector)

3.4.4 构建检索web

import gradio as gr
import torch
import argparse
from net_helper import net_helper
from PIL import Image
from clip_embeding import clip_embeding
from milvus_operator import text_image_vector


def image_search(text):
    if text is None:
        return None

    # clip编码
    imput_embeding = clip_embeding.embeding_text(text)
    imput_embeding = imput_embeding[0].detach().cpu().numpy()

    results = text_image_vector.search_data(imput_embeding)
    pil_images = [Image.open(result['path']) for result in results]
    return pil_images


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true",
                        default=False, help="share gradio app")
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    app = gr.Blocks(theme='default', title="image",
                    css=".gradio-container, .gradio-container button {background-color: #009FCC} "
                        "footer {visibility: hidden}")
    with app:
        with gr.Tabs():
            with gr.TabItem("image search"):
                with gr.Row():
                    with gr.Column():
                        text = gr.TextArea(label="Text",
                                           placeholder="description",
                                           value="",)
                        btn = gr.Button(label="search")

                    with gr.Column():
                        with gr.Row():
                            output_images = [gr.outputs.Image(type="pil", label=None) for _ in range(16)]

                btn.click(image_search, inputs=[text], outputs=output_images, show_progress=True)

    ip_addr = net_helper.get_host_ip()
    app.queue(concurrency_count=3).launch(show_api=False, share=True, server_name=ip_addr, server_port=9099)

4 总结

本项目基于OpenAI的Clip预训练模型及milvus向量数据库两个关键技术,构建了以文搜图的跨模态检索系统;经过Clip模型编码后每个图片输出向量维度为768,存入milvus向量数据库;为保证图像检索的效率,通过脚本在milvus向量数据库中构建了向量索引。此项目可作为参考,在实际开发类似的信息检索项目中使用。

项目完整代码地址:code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/927379.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何数据库备份,如何将数据库备份到其他服务器

在当今的数字世界里,数据库已经成为单位和个人存储、管理和检索海量数据的关键工具。然而,随着数据量的增加,内容丢失的风险也随之增加。这就是为什么定期备份数据库变得尤为重要。本文将详细介绍如何有效备份数据库,以保护您的数…

2023高教社杯数学建模思路 - 复盘:光照强度计算的优化模型

文章目录 0 赛题思路1 问题要求2 假设约定3 符号约定4 建立模型5 模型求解6 实现代码 建模资料 0 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 1 问题要求 现在已知一个教室长为15米,宽为12米&…

五度易链最新“产业大数据服务解决方案”亮相,打造数据引擎,构建智慧产业

快来五度易链官网 点击网址【http://www.wdsk.net/】 看看我们都发布了哪些新功能!!! 自2015年布局产业大数据服务行业以来,“五度易链”作为全国产业大数据服务行业先锋企业,以“让数据引领决策,以智慧驾驭未来”为愿景,肩负“打…

说点大实话丨知名技术博主 Kirito 测评云原生网关

作者:徐靖峰 关注了阿里云云原生公众号,经常能看到 MSE-Higress 相关的推文,恰逢这次阿里云产品举办了一个 MSE-Higress 云原生网关的测评活动,借此机会体验了一把云原生网关的功能。 购买流程体验 购买网关时,页面明…

python入门篇04-循环(while与for),变量,函数基础

python目录 1. 前言1.1 上文传送 2. python基础使用2.1 while循环2.1.1 while循环的使用> 案例: 猜数字游戏(多经典...) 2.1.2 while双层循环> 案例: 输出9*9乘法表> 运行结果 2.2 for循环2.2.1 **for循环使用**> 案例: (字符串)查出有多少字符 2.2.2 方法range()的…

Leetcode每日一题:1448. 统计二叉树中好节点的数目

原题 给你一棵根为 root 的二叉树,请你返回二叉树中好节点的数目。 「好节点」X 定义为:从根到该节点 X 所经过的节点中,没有任何节点的值大于 X 的值。 示例 1: 输入:root [3,1,4,3,null,1,5] 输出:4 解…

网络映射会遇到哪些困难

网络映射通过将复杂的网络划分为更小、可管理的块,帮助 IT 管理员获得对其网络的更大控制和可见性,它有助于可视化不同的网络组件(如服务器、交换机端口和路由器)如何互连以执行其功能,通过表示网络设备的通信方式&…

腾讯云服务器价格表大全_轻量服务器_CVM云服务器报价明细

腾讯云服务器租用费用表:轻量应用服务器2核2G4M带宽112元一年,540元三年、2核4G5M带宽218元一年,2核4G5M带宽756元三年、云服务器CVM S5实例2核2G配置280.8元一年、GPU服务器GN10Xp实例145元7天,腾讯云服务器网长期更新腾讯云轻量…

无涯教程-进程 - 子进程监控

正如我们所看到的,每当我们使用fork从程序创建子进程时,都会发生以下情况- 当前进程成为父进程新进程成为子进程 如果父进程比子进程提前完成任务然后退出,会发生什么?现在谁将成为子进程的父进程?子进程的父进程是init进程,它…

业财融合背景下,全面预算管理的发展之路

随着社会经济的高速发展,单一的组织机构职能极大限制了企业发展的创新动能。业务壁垒的不断滋生造成了信息传达严重的不对等,沟通协作成本加大,业务效率降低,专业化的分工形式逐渐成为了制约企业发展的桎梏。 2016年&…

基于Python科研论文绘制学习 - task3

Seaborn seaborn 在matplotlib 的基础上进行了更高级的封装,能用更少的代码绘制配图。 1、图类型 关系型图 数据分布型图 分类数据型图 回归模型分析图 2、多子图网格型图 FacetGrid() import pandas as pd import numpy as np…

全球纳米烧结银市场年复合增长率为6.5%!

烧结银简单来讲是指经过低温烧结技术将纳米银粉&#xff08;平均粒径<0.1μm(100nm)&#xff09;印刷在承印物上&#xff0c;使之成为具有传导电流和排除积累静电荷能力的银浆&#xff0c;其由导电填料——银粉、粘合剂、溶剂及改善性能的微量添加剂组成&#xff0c;使用低熔…

云企业网CEN与转发路由器TR

云企业网CEN 云企业网CEN&#xff08;Cloud Enterprise Network&#xff09;是运行在阿里云私有全球网络上的一张高可用网络。云企业网通过转发路由器TR&#xff08;Transit Router&#xff09;帮助您在跨地域专有网络之间&#xff0c;专有网络与本地数据中心间搭建私网通信通…

苹果手机数据恢复的详细教程,果粉必看!

“照片不小心误删”、“清理内存把聊天记录删除了”、“手机重要文件丢失”……大家是否也会遇到以上的糟糕情况呢&#xff1f;“手机数据丢失”这六个字的杀伤力有多大&#xff0c;大家可想而知。 那么&#xff0c;手机删除的数据能够恢复吗&#xff1f;苹果手机数据恢复的方…

Hadoop Yarn 核心调优参数

文章目录 测试集群环境说明Yarn 核心配置参数1. 调度器选择2. ResourceManager 调度器处理线程数量设置3. 是否启用节点功能的自动检测设置4. 是否将逻辑处理器当作物理核心处理器5. 设置物理核心到虚拟核心的转换乘数6. 设置 NodeManager 使用的内存量7. 设置 NodeManager 节点…

光伏发电储能方案(小城光伏电站的智慧节能之路)

​阳光城小城光伏电站位于某省郊区,建成于5年前,总装机100兆瓦。电站采用光伏组串并联方式,将太阳能转换为电力后并网发电。但受限于光照条件,发电量经常无法有效利用,浪费严重。 为实现清洁能源的充分应用,电站管理层决定改造储能系统。经评估,采用联合电站侧、电网侧储能方式…

Java之内部类的详解

3.1 概述 3.1.1 什么是内部类 将一个类A定义在另一个类B里面&#xff0c;里面的那个类A就称为内部类&#xff0c;B则称为外部类。可以把内部类理解成寄生&#xff0c;外部类理解成宿主。 3.1.2 什么时候使用内部类 一个事物内部还有一个独立的事物&#xff0c;内部的事物脱…

Pytorch-day08-模型进阶训练技巧

PyTorch 模型进阶训练技巧 自定义损失函数 如 cross_entropy L2正则化动态调整学习率 如每十次 *0.1 典型案例&#xff1a;loss上下震荡 1、自定义损失函数 1、PyTorch已经提供了很多常用的损失函数&#xff0c;但是有些非通用的损失函数并未提供&#xff0c;比如&#xf…

4.7 为什么 TCP 每次建立连接时,初始化序列号都要不一样呢?

主要原因是为了防止历史报文被下一个相同四元组的连接接收。 如果正常通过四次挥手完成&#xff0c;TIME_WAIT状态会持续2MSL&#xff0c;历史报文会在下一个连接之前就自然消失&#xff0c;但无法保证每次连接都能通过四次挥手正常关闭。 假设客户端和服务端建立一个连接后&a…

Unity项目如何上传Gitee仓库

前言 最近Unity项目比较多&#xff0c;我都是把Unity项目上传到Gitee中去&#xff0c;GitHub的话我用的少&#xff0c;可能我还是更喜欢Gitee吧&#xff0c;毕竟Gitee仓库用起来更加方便&#xff0c;注意Unity项目上传时最佳的方式是把 Asste ProjectSetting 两个文件夹上传上…