向量检索:基于ResNet预训练模型构建以图搜图系统

news2024/10/2 3:27:00

1 项目背景介绍

以图搜图是一种向量检索技术,通过上传一张图像来搜索并找到与之相关的其他图像或相关信息。以图搜图技术提供了一种更直观、更高效的信息检索方式。这种技术应用场景和价值非常广泛,经常会用在商品检索及购物、动植物识别、食品识别、知识检索等领域。以图搜图涉及到的技术点如下:

  • 如何对图片数据进行向量编码
  • 如何对海量的向量数据进行存储
  • 如何快速对海量的向量数据进行检索

本项目基于Resnet预训练模型结合Milvus向量数据库,在水果数据集上实现了以图搜图系统,读者可以将数据集扩展到其它领域,构建满足自身业务的以图搜图系统。

2 关键技术介绍

2.1 Resnet网络

ResNet,全称为Residual Network,是深度学习领域中非常重要的卷积神经网络(Convolutional Neural Network,CNN)架构之一。它由Kaiming He等人在2015年提出,并在ImageNet图像分类比赛中取得了显著的成果,在当时获得分类任务,目标检测,图像分割第一名。ResNet的创新之处在于引入了残差连接(residual connections),允许网络在训练过程中更容易地训练深层网络。

在传统的神经网络中,随着网络层数的增加,性能可能会饱和甚至下降。这是因为梯度消失和梯度爆炸等问题会导致训练变得困难。ResNet通过引入残差块(residual block)来解决这个问题。每个残差块包括一个主要的卷积层,其输出与输入之间的差异被称为“残差”,然后将残差添加回来,得到最终的输出。这样的架构允许信息在网络中更容易地传播,即使网络变得非常深。

ResNet的经典网络结构有:ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152几种,其中,ResNet-18和ResNet-34的基本结构相同,属于相对浅层的网络,后面3种属于更深层的网络,其中RestNet50最为常用。

 ResNet的优点包括:

  • 训练更深的网络: 引入残差连接允许构建非常深的网络,这些网络在训练时更容易收敛。
  • 避免梯度消失和爆炸: 残差连接有助于梯度在网络中更好地传播,减少了梯度消失和爆炸的问题。
  • 更好的特征学习: 残差块允许网络学习残差,即学习更容易捕获到的细粒度特征。

ResNet详细介绍:ResNet

2.2 Milvus向量数据库

Milvus 是一款云原生向量数据库,它具备高可用、高性能、易拓展的特点,用于海量向量数据的实时召回。

Milvus 基于FAISS、Annoy、HNSW 等向量搜索库构建,核心是解决稠密向量相似度检索的问题。在向量检索库的基础上,Milvus 支持数据分区分片、数据持久化、增量数据摄取、标量向量混合查询、time travel 等功能,同时大幅优化了向量检索的性能,可满足任何向量检索场景的应用需求。通常,建议用户使用 Kubernetes 部署 Milvus,以获得最佳可用性和弹性。

Milvus 采用共享存储架构,​存储计算完全分离​,计算节点支持横向扩展。从架构上来看,Milvus 遵循数据流和控制流分离,整体分为了四个层次,分别为接入层(access layer)、协调服务(coordinator service)、执行节点(worker node)和存储层(storage)。各个层次相互独立,独立扩展和容灾。

 Milvus 向量数据库能够帮助用户轻松应对海量非结构化数据(图片/视频/语音/文本)检索。单节点 Milvus 可以在秒内完成十亿级的向量搜索,分布式架构亦能满足用户的水平扩展需求。

milvus特点总结如下:

  • 高性能:性能高超,可对海量数据集进行向量相似度检索。
  • 高可用、高可靠:Milvus 支持在云上扩展,其容灾能力能够保证服务高可用。
  • 混合查询:Milvus 支持在向量相似度检索过程中进行标量字段过滤,实现混合查询。
  • 开发者友好:支持多语言、多工具的 Milvus 生态系统。

Milvus详细介绍:Milvus

3 系统代码实现

3.1 运行环境构建

conda环境准备详见:annoconda

git clone https://gitcode.net/ai-medical/image_image_search.git
cd image_image_search

pip install -r requirements.txt

3.2 数据集下载

下载地址:

第一个数据包:package01

第二个数据包:package01

在数据集目录下,存放着10个文件夹,文件夹名称为水果类型,每个文件夹包含几百到几千张此类水果的图片,如下图所示:

 以apple文件夹为例,内容如下:

下载后进行解压,保存到D:/dataset/fruit目录下,查看显示如下

# ll fruit/
总用量 508
drwxr-xr-x 2 root root 36864 8月   2 16:35 apple
drwxr-xr-x 2 root root 24576 8月   2 16:36 apricot
drwxr-xr-x 2 root root 40960 8月   2 16:36 banana
drwxr-xr-x 2 root root 20480 8月   2 16:36 blueberry
drwxr-xr-x 2 root root 45056 8月   2 16:37 cherry
drwxr-xr-x 2 root root 12288 8月   2 16:37 citrus
drwxr-xr-x 2 root root 49152 8月   2 16:38 grape
drwxr-xr-x 2 root root 16384 8月   2 16:38 lemon
drwxr-xr-x 2 root root 36864 8月   2 16:39 litchi
drwxr-xr-x 2 root root 49152 8月   2 16:39 mango

3.3 预训练模型下载

 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',

下载resnet50的预训练模型:resnet50,存放到D:/models目录下

3.4 代码实现

3.4.1 创建database

from pymilvus import connections, db

conn = connections.connect(host="192.168.1.156", port=19530)
database = db.create_database("image_vector_db")

db.using_database("image_vector_db")
print(db.list_database())

3.4.2 创建collection

from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus import Collection, db, connections


conn = connections.connect(host="192.168.1.156", port=19530)
db.using_database("image_vector_db")

m_id = FieldSchema(name="m_id", dtype=DataType.INT64, is_primary=True,)
embeding = FieldSchema(name="embeding", dtype=DataType.FLOAT_VECTOR, dim=2048,)
path = FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=256,)
schema = CollectionSchema(
  fields=[m_id, embeding, path],
  description="image to image embeding search",
  enable_dynamic_field=True
)

collection_name = "fruit_vector"
collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)

 

3.4.3 创建index

from pymilvus import Collection, utility, connections, db

conn = connections.connect(host="192.168.1.156", port=19530)
db.using_database("image_vector_db")

index_params = {
  "metric_type": "L2",
  "index_type": "IVF_FLAT",
  "params": {"nlist": 1024}
}

collection = Collection("fruit_vector")
collection.create_index(
  field_name="embeding",
  index_params=index_params
)

utility.index_building_progress("fruit_vector")

3.4.4 数据加载到milvus

from restnet_embeding import restnet_embeding
from milvus_operator import restnet_image, MilvusOperator
from PIL import Image, ImageSequence
import os


def update_image_vector(data_path, operator: MilvusOperator):
    idxs, embedings, paths = [], [], []

    total_count = 0
    for dir_name in os.listdir(data_path):
        sub_dir = os.path.join(data_path, dir_name)
        for file in os.listdir(sub_dir):

            image = Image.open(os.path.join(sub_dir, file)).convert('RGB')
            embeding = restnet_embeding.embeding(image)

            idxs.append(total_count)
            embedings.append(embeding[0].detach().numpy().tolist())
            paths.append(os.path.join(sub_dir, file))
            total_count += 1

            if total_count % 50 == 0:
                data = [idxs, embedings, paths]
                operator.insert_data(data)

                print(f'success insert {operator.coll_name} items:{len(idxs)}')
                idxs, embedings, paths = [], [], []

        if len(idxs):
            data = [idxs, embedings, paths]
            operator.insert_data(data)
            print(f'success insert {operator.coll_name} items:{len(idxs)}')

    print(f'finish update {operator.coll_name} items: {total_count}')


if __name__ == '__main__':
    data_dir = 'D:/dataset/fruit'
    update_image_vector(data_dir, restnet_image)

3.4.5 基于Resnet预训练模型构建编码网络

加载预训练模型,去掉全连接层,是的Resnet编码输出特征维度为2048

from torchvision.models import resnet50
import torch
from torchvision import transforms
from torch import nn


class RestnetEmbeding:
    pretrained_model = 'D:/models/resnet50-0676ba61.pth'

    def __init__(self):
        self.model = resnet50()
        self.model.load_state_dict(torch.load(self.pretrained_model))

        # delete fc layer
        self.model.fc = nn.Sequential()
        self.transform = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                                  std=[0.26862954, 0.26130258, 0.27577711])])

    def embeding(self, image):
        trans_image = self.transform(image)
        trans_image = trans_image.unsqueeze_(0)
        return self.model(trans_image)


restnet_embeding = RestnetEmbeding()

 3.4.6 构建检索web

import gradio as gr
import torch
import numpy as np
import argparse
from net_helper import net_helper
from PIL import Image
from restnet_embeding import restnet_embeding
from milvus_operator import restnet_image


def image_search(image):
    if image is None:
        return None

    image = image.convert("RGB")

    # restnet编码
    imput_embeding = restnet_embeding.embeding(image)
    imput_embeding = imput_embeding[0].detach().cpu().numpy()

    results = restnet_image.search_data(imput_embeding)
    pil_images = [Image.open(result['path']) for result in results]
    return pil_images


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true",
                        default=False, help="share gradio app")
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    app = gr.Blocks(theme='default', title="image",
                    css=".gradio-container, .gradio-container button {background-color: #009FCC} "
                        "footer {visibility: hidden}")
    with app:
        with gr.Tabs():
            with gr.TabItem("image search"):
                with gr.Row():
                    with gr.Column():
                        image = gr.inputs.Image(type="pil", source='upload')
                        btn = gr.Button(label="search")

                    with gr.Column():
                        with gr.Row():
                            output_images = [gr.outputs.Image(type="pil", label=None) for _ in range(16)]

                btn.click(image_search, inputs=[image], outputs=output_images, show_progress=True)

    ip_addr = net_helper.get_host_ip()
    app.queue(concurrency_count=3).launch(show_api=False, share=True, server_name=ip_addr, server_port=9099)

4 总结

本项目基于Resnet预训练模型及milvus向量数据库两个关键技术,构建了以图搜图的图像检索系统;在构建过程中,对Resnet网络模型进行了改造,去掉了全连接层,经过Restnet编码后每个图片输出向量维度为2048,存入milvus向量数据库;为保证图像检索的效率,通过脚本在milvus向量数据库中构建了向量索引。此项目可作为参考,在实际开发类似的以图搜图项目中直接使用。

项目完整代码地址:code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/919372.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode48. 旋转图像(java)

旋转图像 题目描述旋转技巧上期经典算法 题目描述 难度 - 中等 原题链接 - 旋转图像 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像…

StreamPark

1、StreamPark的标语 一个神奇的框架,让流处理更简单 2、StreamPark的前世今生 早期用名streamx,加入apache孵化器之后更名为StreamPark 3、StreamPark可以为你提供什么 降低学习成本、开发门槛,让开发者只用关心核心的业务 简单来说&#xf…

Pytorch学习:torchvison.transforms常用包(ToTensor、Resize、Compose和RandomCrop)

transforms常用包 1. torchvision.transforms.ToTensor2. torchvision.transforms.Resize3. torchvision.transforms.Compose4. torchvision.transforms.Normalize5. torchvision.transforms.RandomCrop 1. torchvision.transforms.ToTensor 将PIL Image或ndarray转换为张量并…

基于React实现无限滚动的日历详细教程,附源码【手写日历教程第二篇】

前言 最常见的日历大部分都是滚动去加载更多的月份,而不是让用户手动点击按钮切换日历月份。滚动加载的交互方式对于用户而言是更加丝滑和舒适的,没有明显的操作割裂感。 那么现在需要做一个这样的无限滚动的日历,前端开发者应该如何去思考…

【数据结构】实现栈和队列

目录 一、栈1.栈的概念及结构(1)栈的概念(2)栈的结构 2.栈的实现(1)类型和函数的声明(2)初始化栈(3)销毁(4)入栈(5&#x…

在 Redis 中处理键值 | Navicat

Redis 是一个键值存储系统,允许我们将值与键相关联起来。与关系型数据库不同的是, 在Redis 中,不需要使用数据操作语言 (DML) 和查询语法,那么我们如何进行数据的写入、读取、更新和删除操作呢?…

shell 11(shell重定向输入输出)

一、标准输入输出 标准输入介绍 从键盘读取用户输入的数据,然后再把数据拿到Shell程序中使用; 标准输出介绍 Shell程序产生的数据,这些数据一般都是呈现到显示器上供用户浏览查看

go学习一之go的初体验

go语言学习笔记 一、golang初体验: 1.简单体验案例: package main{ //把这个test.go归属到main import "fmt" //引入一个包 func main(){//输出hellofmt.Println("hello world")} }2.从案例学到的知识点: (1) go文件的后缀是.…

【集合学习HashMap】HashMap集合详细分析

HashMap集合详细分析 一、HashMap简介 HashMap 主要用来存放键值对(key-value的形式),它基于哈希表的 Map 接口实现,是常用的 Java 集合之一,是非线程安全的。 HashMap 可以存储 null 的 key 和 value,但 …

nginx基本介绍(安装、常用命令、反向代理)

文章目录 引言一、nginx是什么二、nginx的下载和安装1. 下载2. windows下安装3. 运行4. 外部服务器无法访问问题 三、nginx的常用命令四、nginx.config五、FileZilla1. 什么是FileZilla2. FileZilla的下载和安装 六、反向代理1. 什么是nginx的反向代理2. 反向代理工作流程3. 如…

2023-8-23 连通块中点的数量

题目链接&#xff1a;连通块中点的数量 #include <iostream>using namespace std;const int N 100010;int n, m; int p[N], Size[N], idx;int find(int x) {if(p[x] ! x) p[x] find(p[x]);return p[x]; }int main() {cin >> n >> m;for(int i 1; i <…

七、任务优先级和Tick

1、任务与中断的优先级 (1)相同优先级任务轮流执行。 (2)高优先级任务打断低优先级任务。 (3)中断可以打断所有优先级的任务。 2、任务优先级 (1)优先级的取值范围是&#xff1a;0~(configMAX_PRIORITIES – 1)&#xff0c;数值越大优先级越高。 (2)FreeRTOS会确保最高优…

API 网关基础

目录 一、网关概述二、网关提供的功能三、常见网关系统3.1 Netflix Zuul3.2 Spring Cloud Gateway3.3 Kong3.4 APISIX3.5 Shenyu 一、网关概述 API网关是一个服务器&#xff0c;是系统的唯一入口。 从面向对象设计的角度看&#xff0c;它与外观模式类似。API网关封装了系统内部…

小白带你学习linux的LVS集群(三十六)

一、集群概述 1、负载均衡技术类型 四层负载均衡器 也称为 4 层交换机&#xff0c;主要通过分析 IP 层及 TCP/UDP 层的流量实现基于 IP 加端口的负载均衡&#xff0c;如常见的 LVS、F5 等&#xff1b; 七层负载均衡器 也称为 7 层交换机&#xff0c;位于 OSI 的最高层&#…

机器人力控入门——牛顿欧拉法动力学建模

建立机器人的动力学模型是完成力控的基础&#xff0c;常用的动力学模型建模法有拉格朗日法和牛顿-欧拉法&#xff0c;其中牛顿-欧拉采用递推形式&#xff0c;计算更为简便&#xff0c;使用也更为广泛。本文就来介绍下牛顿-欧拉的动力学建模方法&#xff0c; PS&#xff0c;网上…

C++--动态规划两个数组的dp问题

1.最长公共子序列 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 给定两个字符串 text1 和 text2&#xff0c;返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 &#xff0c;返回 0 。 一个字符串的 子序列 是指这样一个新的字符串…

day 37 | ● 1049. 最后一块石头的重量 II ● 494. 目标和 ● 474.一和零

1049. 最后一块石头的重量 II 与前一道分割等和子集的思路差不多&#xff0c;都是01背包问题。因为是采用滚动数组的形式&#xff0c;所以必须要倒序遍历才可以。 dp[i]代表着在i的限制下最大的承重。所以另一半就是all - dp【all / 2】 func lastStoneWeightII(stones []int…

Fabric.js 元素选中状态的事件与样式

本文简介 带尬猴&#xff01; 你是否在使用 Fabric.js 时希望能在选中元素后自定义元素样式或选框&#xff08;控制角和辅助线&#xff09;的样式&#xff1f; 如果是的话&#xff0c;可以放心往下读。 本文将手把脚和你一起过一遍 Fabric.js 在对象元素选中后常用的样式设置…

git 把项目托管到 码云出现的错误集合

分享一下我git项目时碰见的错误 1、error: could not lock config file D:/orcad/Cadence/SPB_Data/.gitconfig: No suchfile or directory 在下载git后设置用户名、邮箱时会出现的错误 需要去修改环境变量&#xff0c;这个之前写好了&#xff0c;可以跳转看看 Git配置error:…

计算机竞赛 基于Django与深度学习的股票预测系统

文章目录 0 前言1 课题背景2 实现效果3 Django框架4 数据整理5 模型准备和训练6 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; **基于Django与深度学习的股票预测系统 ** 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff…