1 项目背景介绍
以图搜图是一种向量检索技术,通过上传一张图像来搜索并找到与之相关的其他图像或相关信息。以图搜图技术提供了一种更直观、更高效的信息检索方式。这种技术应用场景和价值非常广泛,经常会用在商品检索及购物、动植物识别、食品识别、知识检索等领域。以图搜图涉及到的技术点如下:
- 如何对图片数据进行向量编码
- 如何对海量的向量数据进行存储
- 如何快速对海量的向量数据进行检索
本项目基于Resnet预训练模型结合Milvus向量数据库,在水果数据集上实现了以图搜图系统,读者可以将数据集扩展到其它领域,构建满足自身业务的以图搜图系统。
2 关键技术介绍
2.1 Resnet网络
ResNet,全称为Residual Network,是深度学习领域中非常重要的卷积神经网络(Convolutional Neural Network,CNN)架构之一。它由Kaiming He等人在2015年提出,并在ImageNet图像分类比赛中取得了显著的成果,在当时获得分类任务,目标检测,图像分割第一名。ResNet的创新之处在于引入了残差连接(residual connections),允许网络在训练过程中更容易地训练深层网络。
在传统的神经网络中,随着网络层数的增加,性能可能会饱和甚至下降。这是因为梯度消失和梯度爆炸等问题会导致训练变得困难。ResNet通过引入残差块(residual block)来解决这个问题。每个残差块包括一个主要的卷积层,其输出与输入之间的差异被称为“残差”,然后将残差添加回来,得到最终的输出。这样的架构允许信息在网络中更容易地传播,即使网络变得非常深。
ResNet的经典网络结构有:ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152几种,其中,ResNet-18和ResNet-34的基本结构相同,属于相对浅层的网络,后面3种属于更深层的网络,其中RestNet50最为常用。
ResNet的优点包括:
- 训练更深的网络: 引入残差连接允许构建非常深的网络,这些网络在训练时更容易收敛。
- 避免梯度消失和爆炸: 残差连接有助于梯度在网络中更好地传播,减少了梯度消失和爆炸的问题。
- 更好的特征学习: 残差块允许网络学习残差,即学习更容易捕获到的细粒度特征。
ResNet详细介绍:ResNet
2.2 Milvus向量数据库
Milvus 是一款云原生向量数据库,它具备高可用、高性能、易拓展的特点,用于海量向量数据的实时召回。
Milvus 基于FAISS、Annoy、HNSW 等向量搜索库构建,核心是解决稠密向量相似度检索的问题。在向量检索库的基础上,Milvus 支持数据分区分片、数据持久化、增量数据摄取、标量向量混合查询、time travel 等功能,同时大幅优化了向量检索的性能,可满足任何向量检索场景的应用需求。通常,建议用户使用 Kubernetes 部署 Milvus,以获得最佳可用性和弹性。
Milvus 采用共享存储架构,存储计算完全分离,计算节点支持横向扩展。从架构上来看,Milvus 遵循数据流和控制流分离,整体分为了四个层次,分别为接入层(access layer)、协调服务(coordinator service)、执行节点(worker node)和存储层(storage)。各个层次相互独立,独立扩展和容灾。
Milvus 向量数据库能够帮助用户轻松应对海量非结构化数据(图片/视频/语音/文本)检索。单节点 Milvus 可以在秒内完成十亿级的向量搜索,分布式架构亦能满足用户的水平扩展需求。
milvus特点总结如下:
- 高性能:性能高超,可对海量数据集进行向量相似度检索。
- 高可用、高可靠:Milvus 支持在云上扩展,其容灾能力能够保证服务高可用。
- 混合查询:Milvus 支持在向量相似度检索过程中进行标量字段过滤,实现混合查询。
- 开发者友好:支持多语言、多工具的 Milvus 生态系统。
Milvus详细介绍:Milvus
3 系统代码实现
3.1 运行环境构建
conda环境准备详见:annoconda
git clone https://gitcode.net/ai-medical/image_image_search.git
cd image_image_search
pip install -r requirements.txt
3.2 数据集下载
下载地址:
第一个数据包:package01
第二个数据包:package01
在数据集目录下,存放着10个文件夹,文件夹名称为水果类型,每个文件夹包含几百到几千张此类水果的图片,如下图所示:
以apple文件夹为例,内容如下:
下载后进行解压,保存到D:/dataset/fruit目录下,查看显示如下
# ll fruit/
总用量 508
drwxr-xr-x 2 root root 36864 8月 2 16:35 apple
drwxr-xr-x 2 root root 24576 8月 2 16:36 apricot
drwxr-xr-x 2 root root 40960 8月 2 16:36 banana
drwxr-xr-x 2 root root 20480 8月 2 16:36 blueberry
drwxr-xr-x 2 root root 45056 8月 2 16:37 cherry
drwxr-xr-x 2 root root 12288 8月 2 16:37 citrus
drwxr-xr-x 2 root root 49152 8月 2 16:38 grape
drwxr-xr-x 2 root root 16384 8月 2 16:38 lemon
drwxr-xr-x 2 root root 36864 8月 2 16:39 litchi
drwxr-xr-x 2 root root 49152 8月 2 16:39 mango
3.3 预训练模型下载
'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
下载resnet50的预训练模型:resnet50,存放到D:/models目录下
3.4 代码实现
3.4.1 创建database
from pymilvus import connections, db
conn = connections.connect(host="192.168.1.156", port=19530)
database = db.create_database("image_vector_db")
db.using_database("image_vector_db")
print(db.list_database())
3.4.2 创建collection
from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus import Collection, db, connections
conn = connections.connect(host="192.168.1.156", port=19530)
db.using_database("image_vector_db")
m_id = FieldSchema(name="m_id", dtype=DataType.INT64, is_primary=True,)
embeding = FieldSchema(name="embeding", dtype=DataType.FLOAT_VECTOR, dim=2048,)
path = FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=256,)
schema = CollectionSchema(
fields=[m_id, embeding, path],
description="image to image embeding search",
enable_dynamic_field=True
)
collection_name = "fruit_vector"
collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)
3.4.3 创建index
from pymilvus import Collection, utility, connections, db
conn = connections.connect(host="192.168.1.156", port=19530)
db.using_database("image_vector_db")
index_params = {
"metric_type": "L2",
"index_type": "IVF_FLAT",
"params": {"nlist": 1024}
}
collection = Collection("fruit_vector")
collection.create_index(
field_name="embeding",
index_params=index_params
)
utility.index_building_progress("fruit_vector")
3.4.4 数据加载到milvus
from restnet_embeding import restnet_embeding
from milvus_operator import restnet_image, MilvusOperator
from PIL import Image, ImageSequence
import os
def update_image_vector(data_path, operator: MilvusOperator):
idxs, embedings, paths = [], [], []
total_count = 0
for dir_name in os.listdir(data_path):
sub_dir = os.path.join(data_path, dir_name)
for file in os.listdir(sub_dir):
image = Image.open(os.path.join(sub_dir, file)).convert('RGB')
embeding = restnet_embeding.embeding(image)
idxs.append(total_count)
embedings.append(embeding[0].detach().numpy().tolist())
paths.append(os.path.join(sub_dir, file))
total_count += 1
if total_count % 50 == 0:
data = [idxs, embedings, paths]
operator.insert_data(data)
print(f'success insert {operator.coll_name} items:{len(idxs)}')
idxs, embedings, paths = [], [], []
if len(idxs):
data = [idxs, embedings, paths]
operator.insert_data(data)
print(f'success insert {operator.coll_name} items:{len(idxs)}')
print(f'finish update {operator.coll_name} items: {total_count}')
if __name__ == '__main__':
data_dir = 'D:/dataset/fruit'
update_image_vector(data_dir, restnet_image)
3.4.5 基于Resnet预训练模型构建编码网络
加载预训练模型,去掉全连接层,是的Resnet编码输出特征维度为2048
from torchvision.models import resnet50
import torch
from torchvision import transforms
from torch import nn
class RestnetEmbeding:
pretrained_model = 'D:/models/resnet50-0676ba61.pth'
def __init__(self):
self.model = resnet50()
self.model.load_state_dict(torch.load(self.pretrained_model))
# delete fc layer
self.model.fc = nn.Sequential()
self.transform = transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])])
def embeding(self, image):
trans_image = self.transform(image)
trans_image = trans_image.unsqueeze_(0)
return self.model(trans_image)
restnet_embeding = RestnetEmbeding()
3.4.6 构建检索web
import gradio as gr
import torch
import numpy as np
import argparse
from net_helper import net_helper
from PIL import Image
from restnet_embeding import restnet_embeding
from milvus_operator import restnet_image
def image_search(image):
if image is None:
return None
image = image.convert("RGB")
# restnet编码
imput_embeding = restnet_embeding.embeding(image)
imput_embeding = imput_embeding[0].detach().cpu().numpy()
results = restnet_image.search_data(imput_embeding)
pil_images = [Image.open(result['path']) for result in results]
return pil_images
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true",
default=False, help="share gradio app")
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
app = gr.Blocks(theme='default', title="image",
css=".gradio-container, .gradio-container button {background-color: #009FCC} "
"footer {visibility: hidden}")
with app:
with gr.Tabs():
with gr.TabItem("image search"):
with gr.Row():
with gr.Column():
image = gr.inputs.Image(type="pil", source='upload')
btn = gr.Button(label="search")
with gr.Column():
with gr.Row():
output_images = [gr.outputs.Image(type="pil", label=None) for _ in range(16)]
btn.click(image_search, inputs=[image], outputs=output_images, show_progress=True)
ip_addr = net_helper.get_host_ip()
app.queue(concurrency_count=3).launch(show_api=False, share=True, server_name=ip_addr, server_port=9099)
4 总结
本项目基于Resnet预训练模型及milvus向量数据库两个关键技术,构建了以图搜图的图像检索系统;在构建过程中,对Resnet网络模型进行了改造,去掉了全连接层,经过Restnet编码后每个图片输出向量维度为2048,存入milvus向量数据库;为保证图像检索的效率,通过脚本在milvus向量数据库中构建了向量索引。此项目可作为参考,在实际开发类似的以图搜图项目中直接使用。
项目完整代码地址:code