如何实现以图搜图

news2024/7/6 18:43:39

一、前言

在许多搜索引擎中,都内置了以图搜图的功能。以图搜图功能,可以极大简化搜索工作。今天要做的就是实现一个以图搜图引擎。

我们先来讨论一下以图搜图的难点,首当其冲的就是如何对比图片的相似度?怎么样的图片才叫相似?人可以一眼判断,但是计算机却不一样。图片以数字矩阵的形式存在,而相似度的比较也是比较矩阵的相似度。但是这有一些问题。

第二个问题就是大小问题,图片的大小通常是不一样的,而不同大小的矩阵也无法比较相似度。不过这个很好解决,直接修改图片尺寸即可。

第三个问题则是像素包含的信息非常有限,无法表达抽象信息。比如画风、物体、色调等。

根据上面描述,我们现在要解决两个问题:用什么信息替换像素信息、怎么计算相似度。下面一一解决。在开始前,我们先实现一个简易的以图搜图功能。

二、简易以图搜图实现

2.1 如何计算相似度

首先来讨论一下直接使用像素作为图像的表示,此时我们应该如何完成以图搜图的操作。一个非常简单的想法就是直接计算两个图片的几何距离,假如我们目标图片为target,图库中的图片为source,几何距离的计算如下:

distance = sum[(target−source)2]distance = \sqrt{sum[(target - source)^2]}distance = sum[(target−source)2]​

然后把距离最小的n个图片作为搜索结果。

这个方法看起来不可靠,但是实际使用时也会有不错的结果。如果图库图片本身不是非常复杂,比如动漫头像,那么这种方式非常简单有效,而其它情况下结果会比较不稳定。

2.2 基于几何距离的图片搜索

基于几何距离的图片搜索实现步骤如下:

  1. 把图片修改到同一尺寸,如果尺寸不同则无法计算几何距离
  2. 选定一个图片作为目标图片,即待搜索图片
  3. 遍历图库,计算几何距离,并记录到列表
  4. 对列表排序,获取几何距离最小的n张图片

这里使用蜡笔小新的图片作为图库进行搜索,下面是图片的一些示例:

部分图片有类似的风格,我们希望能根据一张图片找到类似风格的图片。实现代码如下:

import os
import cv2
import random

import numpy as np

base_path = r"G:\datasets\lbxx"
# 获取所有图片路径
files = [os.path.join(base_path, file) for file in os.listdir(base_path)]
# 选取一张图片作为目标图片
target_path = random.choice(files)
target = cv2.imread(target_path)
h, w, _ = target.shape
distances = []
# 遍历图库
for file in files:
    # 读取图片,转换成与目标图片同一尺寸
    source = cv2.imread(file)
    if not isinstance(source, np.ndarray):
        continue
    source = cv2.resize(source, (w, h))
    # 计算几何距离,并加入列表,这里没有开方
    distance = ((target - source) ** 2).sum()
    distances.append((file, distance))
# 找到相似度前5的图片,这里拿了6个,第一个是原图
distances = sorted(distances, key=lambda x: x[-1])[:6]
imgs = list(map(lambda x: cv2.imread(x[0]), distances))
result = np.hstack(imgs)
cv2.imwrite("result.jpg", result)

下面是一些比较好搜索结果,其中最左边是target,其余10张为搜索结果。

如果换成猫狗图片,下面是一些搜索结果:

2.3 存在的问题

上面的实现存在两个问题,其一是像素并不能表示图像的深层含义。搜索结果中经常会返回颜色相似的图片。第二个则是计算复杂度的问题,如果图片大小未224×224,那么图片有150528个像素,计算几何距离会比较耗时。而且在搜索时,需要遍历整个图库,当图库数量较大时,计算量将不可忍受。因此需要对上面的方法进行一些改进。

三、改进一,用特征代替像素

3.1 图像特征

在表示图片时,就是从基本的像素到手工特征再到深度学习特征。相比之下,用卷积神经网络提取的图像特征有几个有点,具体如下:

  1. 具有很强的泛化能力,提取的特征受角度、位置、亮度等的影响会像素和手工特征。
  2. 较少的维度,使用ResNet50提取224×224图片的特征时,会返回一个7×7×2048的张量,这比像素数量要少许多。
  3. 具有抽象性,相比前面两种,卷积神经网络提取的特征具有抽象性。比如关于图片中类别的信息,这是前面两种无法达到的效果。

在本文我们会使用ResNet50来提取图片特征。

3.2 Embedding的妙用

使用ResNet50提取的特征也可以被称为Embedding,也可以简单理解为图向量。Embedding近几年在人工智能领域发挥了巨大潜力,尤其在自然语言处理领域。

3.2.1 关系可视化

早期Embedding主要用于词向量,通过word2vec把单词转换成向量,然后就可以完成一些奇妙的操作。比如单词之间关系的可视化,比如下面这张图:

在图片中可视化了:mother、father、car、auto、cat、tiger六个单词,从图可以明显看出mother、father比较近;car、auto比较近;cat、tiger比较近,这些都与我们常识相符。

3.2.2 关系运算

我们希望训练良好的Embedding每一个维度都有一个具体的含义,比如第一维表示词性,第二维表示情感,其余各个维度都有具体含义。如果能达到这个效果,或者达到近似效果,那么就可以使用向量的计算来计算单词之间的关系。

比如“妈妈-女性+男性≈爸爸”,或者“国王-男性+女性≈皇后”。比如以往要搜索“物理学界的贝多芬是谁”可能得到非常奇怪的结果,但是如果把这个问题转换成“贝多芬-音乐界+物理学界≈?”,这样问题就简单多了。

3.2.3 聚类

当我们可以用Embedding表示图片和文字时,就可以使用聚类算法完成图片或文字的自动分组。在许多手机的相册中,就有自动图片归类的功能。

聚类还可以加速搜索的操作,这点会在后面详细说。

3.3 以图搜图改进

下面使用图像特征来代替像素改进以图搜图,代码如下:

import os
import cv2
import random
import numpy as np
from keras.api.keras.applications.resnet50 import ResNet50
from keras.api.keras.applications.resnet50 import preprocess_input

w, h = 224, 224
# 加载模型
encoder = ResNet50(include_top=False)


base_path = r"G:\datasets\lbxx"
files = [os.path.join(base_path, file) for file in os.listdir(base_path)]
target_path = random.choice(files)
target = cv2.resize(cv2.imread(target_path), (w, h))


# 提取图片特征
target = encoder(preprocess_input(target[None]))
distances = []
for file in files:
    source = cv2.imread(file)
    if not isinstance(source, np.ndarray):
        continue
        
    # 读取图片,提取图片特征
    source = cv2.resize(source, (w, h))
    source = encoder(preprocess_input(source[None]))
    

    distance = np.sum((target - source) ** 2)
    distances.append((file, distance))
# 找到相似度前5的图片,这里拿了6个,第一个是原图
distances = sorted(distances, key=lambda x: x[-1])[:6]
imgs = list(map(lambda x: cv2.imread(x[0]), distances))
result = np.hstack(imgs)
cv2.imwrite("result.jpg", result)

这里使用在imagenet上预训练的ResNet50作为特征提取网络,提取的关键操作如下:

  1. 加载模型
# 加载ResNet50的卷积层,舍弃全连接部分
encoder = ResNet50(include_top=False)
  1. 图片预处理
# 把图片转换成224×224,并使用ResNet50内置的预处理方法处理
target = cv2.resize(cv2.imread(target_path), (w, h))
target = preprocess_input(target[None])
  1. 提取特征
# 使用ResNet40网络提取特征
target = encoder(preprocess_input(target)

下面是改进后的搜索结果:

四、改进二,使用聚类改进搜索速度

4.1 实现原理

在前面的例子中,我们都是使用线性搜索的方式,此时需要遍历所有图片。搜索复杂度为O(n),通常可以用树结构来存储待搜索的内容,从而把复杂度降低到O(logn)。这里我们使用更简单的方法,即聚类。

首先我们要做的就是对图片的特征进行聚类,聚成c个簇,每个簇都会对应一个簇中心。簇中心可以认为是一个簇中的平均结构,同一簇中的样本相似度会比较高。

在完成聚类后,我们可以拿到target图片的向量,在c个簇中心中查找target与哪个簇最接近。然后再到当前簇中线性查找最相似的几个图片。

4.2 代码实现

代码实现分为下面几个步骤:

  1. 把图片转换成向量

这部分代码和前面基本一样,不过这次为了速度快,我们把图像特征存储到embeddings.pkl文件:

import os
import cv2
import pickle
import numpy as np
import tensorflow as tf
from keras.api.keras.applications.resnet50 import ResNet50
from keras.api.keras.applications.resnet50 import preprocess_input

w, h = 224, 224
# 加载模型
encoder = ResNet50(include_top=False)
base_path = r"G:\datasets\lbxx"
# 获取所有图片路径
files = [os.path.join(base_path, file) for file in os.listdir(base_path)]
# 将图片转换成向量
embeddings = []
for file in files:
    # 读取图片,转换成与目标图片同一尺寸
    source = cv2.imread(file)
    if not isinstance(source, np.ndarray):
        continue
    source = cv2.resize(source, (w, h))
    embedding = encoder(preprocess_input(source[None]))
    embeddings.append({
        "filepath": file,
        "embedding": tf.reshape(embedding, (-1,))
    })
with open('embeddings.pkl', 'wb') as f:
    pickle.dump(embeddings, f)
  1. 对所有向量进行聚类操作

这里可以使用sklearn完成:

from sklearn.cluster import KMeans
with open('embeddings.pkl', 'rb') as f:
    embeddings = pickle.load(f)
X = [item['embedding'] for item in embeddings]
kmeans = KMeans(n_clusters=500)
kmeans.fit(X)
preds = kmeans.predict(X)
for item, pred in zip(embeddings, preds):
    item['cluster'] = pred
joblib.dump(kmeans, 'kmeans.pkl')
with open('embeddings.pkl', 'wb') as f:
    pickle.dump(embeddings, f)

如果图片数量比较多的话,这部分操作会比较耗时。然后调用kmeans.predict方法就可以知道某个图片属于哪个簇,这个也可以事先存储。

  1. 找到输入图片最近的簇中心

在训练完成后,就可以拿到所有簇中心:

kmeans.cluster_centers_

现在要做的就是找到与输入图片最近的簇中心,这个和前面的搜索一样:

# 查找最近的簇
closet_cluster = 0
closet_distance = sys.float_info.max
for idx, center in enumerate(centers):
    distance = np.sum((target.numpy() - center) ** 2)
    if distance < closet_distance:
        closet_distance = distance
        closet_cluster = idx
  1. 在当前簇中查找图片

这个和前面也是基本一样的:

distances = []
for item in embeddings:
    if not item['cluster'] == closet_cluster:
        continue
    embedding = item['embedding']
    distance = np.sum((target - embedding) ** 2)
    distances.append((item['filepath'], distance))
# 对距离进行排序
distances = sorted(distances, key=lambda x: x[-1])[:11]
imgs = list(map(lambda x: cv2.imread(x[0]), distances))
result = np.hstack(imgs)
cv2.imwrite("result.jpg", result)

下面是一些搜索结果:

效果还是不错的,而且这次搜索速度快了许多。不过在编码上这种方式比较繁琐,为了让代码更简洁,下面引入向量数据库。

五、向量数据库

5.1 向量数据库

向量数据库和传统数据库不太一样,可以在数据库中存储向量字段,然后完成向量相似度检索。使用向量数据库可以很方便实现上面的检索功能,而且性能方面会比前面更佳。

向量数据库与传统数据库有很多相似的地方,在关系型数据库中,数据库分为连接、数据库、表、对象。在向量数据库中分别对应连接、数据库、集合、数据。集合中,可以添加embedding类型的字段,该字段可以用于向量检索。

5.2 Milvus向量数据库的使用

下面简单说一下Milvus向量数据库的使用,首先需要安装Milvus,执行下面两条执行即可:

wget https://github.com/milvus-io/milvus/releases/download/v2.2.11/milvus-standalone-docker-compose.yml -O docker-compose.yml
sudo docker-compose up -d

下载完成后,需要连接数据库,代码如下:

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
connections.connect(host='127.0.0.1', port='19530')

然后创建集合:

def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, descrition='ids', max_length=500, is_primary=True,
                    auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, description='filepath', max_length=512),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim),
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type': 'L2',
        'index_type': "IVF_FLAT",
        'params': {"nlist": 2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection


collection = create_milvus_collection('images', 2048)

其中create_milvus_collection的第二个参数是embedding的维度,这里传入图片特征的维度。然后把图片特征存储到向量数据库中,这里需要注意维度不能超过32768,但是ResNet50返回的维度超过了这个限制,为此可以用PCA降维或者采用其它方法获取图片embedding。

import pickle
from sklearn.decomposition import PCA

with open('embeddings.pkl', 'rb') as f:
    embeddings = pickle.load(f)
X = [item['embedding'] for item in embeddings]
pca = PCA(n_components=2048)
X = pca.fit_transform(X)
for item, vec in zip(embeddings, X):
    item['embedding'] = vec
with open('embeddings.pkl', 'wb') as f:
    pickle.dump(embeddings, f)
with open('pca.pkl', 'wb') as f:
    pickle.dump(pca, f)

这样就可以插入数据了,代码如下:

index_params = {
    "metric_type": "L2",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 1024}
}
with open('embeddings.pkl', 'rb') as f:
    embeddings = pickle.load(f)
base_path = r"G:\datasets\lbxx"
# 获取所有图片路径
files = [os.path.join(base_path, file) for file in os.listdir(base_path)]
for item in embeddings:
    collection.insert([
        [item['filepath']],
        [item['embedding']]
    ])

现在如果想要搜索图片,只需要下面几行代码即可:

import os
import cv2
import joblib
import random

import numpy as np
import tensorflow as tf
from PIL import Image
from keras.api.keras.applications.resnet50 import ResNet50
from keras.api.keras.applications.resnet50 import preprocess_input
from pymilvus import connections, Collection

pca = joblib.load('pca.pkl')
w, h = 224, 224
encoder = ResNet50(include_top=False)
base_path = r"G:\datasets\lbxx"
files = [os.path.join(base_path, file) for file in os.listdir(base_path)]
target_path = random.choice(files)
target = cv2.resize(cv2.imread(target_path), (w, h))
target = encoder(preprocess_input(target[None]))
target = tf.reshape(target, (1, -1))
target = pca.transform(target)

# 连接数据库,加载images集合
connections.connect(host='127.0.0.1', port='19530')
collection = Collection(name='images')
search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": 5}
collection.load()
# 在数据库中搜索
results = collection.search(
    data=[target[0]],
    anns_field='embedding',
    param=search_params,
    output_fields=['filepath'],
    limit=10,
    consistency_level="Strong"
)
collection.release()
images = []
for result in results[0]:
    entity = result.entity
    filepath = entity.get('filepath')
    image = cv2.resize(cv2.imread(filepath), (w, h))
    images.append(np.array(image))
result = np.hstack(images)
cv2.imwrite("result.jpg", result)

下面是一些搜索结果,整体来看还是非常不错的,不过由于降维的关系,搜索效果可能或略差于前面,但是整体效率要高许多。

六、总结

本文我们分享了以图搜图的功能。主要思想就是将图片转换成向量表示,然后利用相似度计算,在图库中查找与之最接近的图片。最开始使用线性搜索的方式,此时查找效率最低。而后使用聚类进行改进,把把图片分成多个簇,把查找分为查找簇和查找最近图片两个步骤,可以大大提高查找效率。

改进后代码变得比较繁琐,于是引入向量数据库,使用向量数据库完成检索功能。这样就完成了整个程序的编写。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/763469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每天一道C语言编程:排队买票

题目描述 有M个小孩到公园玩&#xff0c;门票是1元。其中N个小孩带的钱为1元&#xff0c;K个小孩带的钱为2元。售票员没有零钱&#xff0c;问这些小孩共有多少种排队方法&#xff0c;使得售票员总能找得开零钱。注意&#xff1a;两个拿一元零钱的小孩&#xff0c;他们的位置互…

Windows 10快速启动怎么关闭?

有的用户电脑在开启快速启动后&#xff0c;发现电脑的开机速度确实变快了&#xff0c;但有的用户开启快速启动后开机速度反而变慢了&#xff0c;所以想关闭快速启动。那电脑快速启动怎么关闭&#xff1f; 快速启动与休眠 快速启动与电脑的休眠功能相似&#xff0c;但又有所不同…

获取网络包的硬件时间戳

转自&#xff1a;如何获取网络包的硬件时间戳_飞行的精灵的博客-CSDN博客 在一些应用中我们需要获取网路报文进出MAC的精准的时间戳。相比较于软件时间戳&#xff0c;硬件时间戳排除了系统软件引起的延时和抖动。如下图所示意&#xff1a; 下面我们使用北京飞灵科技有限公司开…

在命令行执行命令后出现 Permission denied 的问题解决

解决在项目目录安装一个有 “bin” 配置的依赖包后&#xff0c;执行 “bin” 命令&#xff0c;出现了 Permission denied 的问题。 问题 比如有这样一个包 json2playwright &#xff0c;它的 package.json 中 “bin” 是&#xff1a; "bin": {"pince": &q…

如何让Vue项目本地运行的时候,同时支持http://localhost和http://192.168.X.X访问?

方法1&#xff1a;在package.json的"scripts":→ "dev":末尾追加 --host 0.0.0.0 方法2&#xff1a;将config\index.js的"dev":→ "host":修改为0.0.0.0

攻不下dfs不参加比赛(十八)

标题 为什么练dfs题目为什么练dfs 相信学过数据结构的朋友都知道dfs(深度优先搜索)是里面相当重要的一种搜索算法,可能直接说大家感受不到有条件的大家可以去看看一些算法比赛。这些比赛中每一届或多或少都会牵扯到dfs,可能提到dfs大家都知道但是我们为了避免眼高手低有的东…

CANoe-Symbol Mapping介绍

在CANoe的Environment菜单下有一个模块叫:Symbol Mapping。 打开后的界面为: 它的作用是: 在mapping对话框内,你可以映射系统变量、环境变量、信号、通信对象的值或分布式对象的成员以及系统变量的命名空间。当测量过程中源变量的值发生变化时,目标变量的值会自动设置。 你…

JDK安装

JDK安装 1、Windows环境下JDK的安装 1.1 下载 到 Java 的官网下载 JDK 安装包&#xff0c;下载地址&#xff1a; http://www.oracle.com/technetwork/java/javase/downloads/index.html 选择一个适合自己的 JDK 版本下载即可。 1.2 安装 通过双击软件并且点击下一步进行…

Fiddler抓包使用简介

目录 Fiddler简介 请求抓包 抓取PC端HTTPS请求 抓取移动端请求 请求查看 发送请求 Mock接口 断点调试 弱网模拟 请求重放 修改HOSTS 总结&#xff1a; Fiddler简介 Fiddler是一款免费的Windows平台的抓包工具&#xff0c;功能强大&#xff0c;使用简单。Fiddler抓…

【SCI征稿】老牌期刊2023年上涨质量高!中科院2/1区(TOP),国人发文友好!

您有一个评职称弯道超车的机会&#xff1f; 因为本期小编要推荐一本中科院2/1区&#xff08;TOP&#xff09;期刊&#xff0c;期刊质量不论是评职晋升求职毕业都是首选的好刊&#xff01;究竟怎么回事&#xff1f;且看下文&#xff1a; 期刊简介&#xff1a; 影响因子&#…

传感器信息系统中的节能收集(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

消息队列——spring和springboot整合rabbitmq

目录 spring整合rabbitmq——生产者 rabbitmq配置文件信息 倒入生产者工程的相关代码 简单工作模式 spring整合rabbitmq——消费者 spring整合rabbitmq——配置详解 SpringBoot整合RabbitMQ——生产者 SpringBoot整合RabbitMQ——消费者 spring整合rabbitmq——生产者 使…

分布式应用之存储(Ceph)

分布式应用之存储&#xff08;Ceph) 一、数据存储类型 存储类型说明典型代表块存储一对一&#xff0c;只能被一个主机挂载使用数据以块为单位进行存储硬盘文件存储一对多&#xff0c;能同时被多个主机挂载/传输使用&#xff0c;数据以文件的形式存储&#xff08;元数据和实际…

Appium+python自动化(十)- 元素定位秘籍助你打通任督二脉 - 上卷(超详解)

1、 常用定位方法讲解 对象定位是自动化测试中很关键的一步&#xff0c;也可以说是最关键的一步&#xff0c;毕竟你对象都没定位那么你想操作也不行。所以本章节的知识宏哥希望小伙伴们多动手去操作和实战&#xff0c;不要仅仅只是书本上的知识&#xff0c;毕竟这个我只能够举例…

AtcoderABC301场

A - Order Something Else A - Order Something Else 题目大意 计算 Takahashi 最少需要支付多少钱才能获得 AtCoder Drink。AtCoder Drink 可以按照原价 P 日元购买&#xff0c;也可以使用折扣券以 Q 日元的价格购买&#xff0c;但必须再额外购买 N 道菜品中的一道才能使用折…

Navicat代码片段存储位置

1、在Navicat的主界面中&#xff0c;选择“工具”——》“选项”——》文件位置&#xff0c;如下图 配置文件就是存放自动保存、代码片段等文件的位置&#xff0c;其中snippets&#xff08;片段&#xff09;就是自定义片段的存储位置了

【Android】在某个model中找不到自己的R资源的原因

背景 在某个新建的model为lib包的时候&#xff0c;我想在这个model内的activity引用R.string 等等资源&#xff0c;但是Android studio找不到。 解决 原来我之前误删了这个manifest中的 补齐包名即可。

Triton_server部署学习笔记

下载镜像 docker pill http://nvcr.io/nvidia/tritonserver:22.07-py3 docker run --gpus all -itd -p8000:8000 -p8001:8001 -p8002:8002 -v /home/ai-developer/server/docs/examples/model_repository/:/models nvcr.io/nvidia/tritonserver:22.07-py3 docker exec -it a5…

使用shell监控应用运行状态通过企业微信接收监控通知

目的&#xff1a;编写shell脚本来监控应用服务运行状态&#xff0c;若是应用异常则自动重启应用通过企业微信接收监控告警通知 知识要点&#xff1a; 使用shell脚本监控应用服务使用shell脚本自动恢复异常服务通过企业微信通知接收监控结果shell脚本使用数组知识&#xff0c;…

[黑苹果EFI]Lenovo ThinkPad T490电脑 Hackintosh 黑苹果引导文件

原文来源于黑果魏叔官网&#xff0c;转载需注明出处。&#xff08;下载请直接百度黑果魏叔&#xff09; 硬件型号驱动情况 主板Lenovo ThinkPad T490 处理器Intel Intel Core i5 8265U (Quad Core)已驱动 内存16 GB:8 GB Samsung DDR 4 2666 Mhz *2已驱动 硬盘PC SN520 NVM…