【以图搜图代码实现2】--faiss工具实现犬类以图搜图

news2024/11/6 3:13:13

第一篇:【以图搜图代码实现】–犬类以图搜图示例 使用保存成h5文件,使用向量积来度量相似性,实现了以图搜图,说明了可以优化的点。
第二篇:【使用resnet18训练自己的数据集】 准对模型问题进行了优化,取得了显著性的效果。
本篇继续第一篇中所说的优化方向,使用faiss实现以图搜图。

1.faiss使用介绍

Faiss的全称是Facebook AI Similarity Search,是FaceBook针对大规模相似度检索问题开发的一个工具,底层是使用C++代码实现的,提供了python的接口,号称对10亿量级的索引可以做到毫秒级检索。

使用faiss的基本步骤
1、数据转换:把原始数据转换为"float32"数据类型的向量。
2、index构建:用 faiss 构建index
3、数据添加:将向量add到创建的index中
4、通过创建的index进行检索

1.创建索引

import faiss

def create_index(datas_embedding):
    # 构建索引,L2代表构建的index采用的相似度度量方法为L2范数
    # 必须传入一个向量的维度,创建一个空的索引
    index = faiss.IndexFlatL2(datas_embedding.shape[1])  
    # 把向量数据加入索引
    index.add(datas_embedding)   
    return index

2.保存索引

def faiss_index_save(faiss_index, save_file_location):
    faiss.write_index(faiss_index, save_file_location)

3.加载索引

def faiss_index_load(faiss_index_save_file_location):
    index = faiss.read_index(faiss_index_save_file_location)
    return index

4.向索引中添加向量

def index_data_add(faiss_index, img_path):
    # 获得索引向量的数量
    print(faiss_index.ntotal)
    img_embedding = extract_image_features(img_path)
    faiss_index.add(img_embedding)
    print(faiss_index.ntotal)

5.删除索引中的向量

def index_data_delete(faiss_index):
    print(faiss_index.ntotal)
    # remove, 指定要删除的向量id,是一个np的array
    faiss_index.remove_ids(np.array([0]))
    print(faiss_index.ntotal)

可以看出使用Faiss工具更加的灵活,可以向索引中添加和删除向量。

2.faiss实现以图搜图

本篇代码有部分是在前两篇的基础之上的,这里使用11类犬类数据集微调之后的resnet18进行特征提取。
第一篇:【以图搜图代码实现】–犬类以图搜图示例
第二篇:【使用resnet18训练自己的数据集】

数据集准备和下载可以去看第二篇文章。

1.模型加载

为了更好的适配,对第一篇中的resnet18的初始化方法进行了修改,如下:

@Project :ImageRec
@File    :resnet18.py
@IDE     :PyCharm
@Author  :菜菜2024
@Date    :2024/9/30
'''
from PIL import Image
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision import models

class ResNet18:
    def __init__(self,
                 out_feature = 11,
                 model_path='E:\\xxx\\ImageRec\\weights\\resnet18.pth'):
        self.trans = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
        print("-----------loading resnet18------------")
        self.model = models.resnet18()
        num_feats = self.model.fc.in_features
        self.model.fc = nn.Linear(num_feats, out_feature)
        self.model.load_state_dict(torch.load(model_path))
        self.model.eval()


    def extract_image_features(self, img_path):

        image = Image.open(img_path).convert('RGB')
        image_tensor = self.trans(image).unsqueeze(0)
        with torch.no_grad():
            features = self.model(image_tensor)
        return features

其中out_feature 根据自己的数据集的类别个数进行更改,我这里的犬类是11种。model_path是训练好的保存的权重文件【训练过程可以去看第二篇】

2.文件名映射

在第一篇:【以图搜图代码实现】–犬类以图搜图示例 中使用的是保存成h5文件,索引是没有要求是整数的,这里faiss要求是整数,搞了一个映射方法,同时也是为了在后面可视化的时候,能根据索引再解码得到对应的文件路径。

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :Imgmap.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/29 18:02 
'''
import os
import uuid
import numpy as np


def getImgMap(img_path):
	# 为类别生成一个映射文件
    subnames = [f.split('\\')[-1] for f in os.listdir(img_path)]
    element_mapping = {}
 
    for i in range(len(subnames)):
        unique_id = str(i+2024)
        element_mapping[unique_id] = subnames[i]

    return element_mapping

def valueGetKey(mapping, target_value):


    for key, value in mapping.items():
        if value == target_value:
            # print(f"值 '{target_value}' 对应的键是: {key}")
            break
    return key


def nameMap(imgnames, img_path='E:\\xxx\\datas\\pet_dog\\train'):
    '''
    getImagVector函数得到的image_ids在保存为h5文件时进行了编码
    现在faiss工具中index需要是int类型的,这里进行映射转化
    :param img_path: 数据集目录,来得到类别映射
    :param imgnames: 需要映射的图片名称,解码之后是“中华田园犬_0”格式
    这里传参是列表
    :return:
    '''
    element_mapping = getImgMap(img_path)
    decode_names = [imgname.decode('utf-8') for imgname in imgnames]

    name_ids=[]
    for decode_name in decode_names:
        cla_name = decode_name.split("_")[0]
        img_name = decode_name.split("_")[-1]
        key = valueGetKey(element_mapping, cla_name)
        name_id = key+img_name
        name_ids.append(name_id)

    name_ids=np.array(name_ids).astype('int32')

    return name_ids




if __name__ == "__main__":

    database = 'E:\\xxx\\datas\\pet_dog\\train'
    element_mapping = getImgMap(database)
    print(element_mapping)
    print(element_mapping.get("2024"))

映射文件:

{‘2024’: ‘中华田园犬’, ‘2025’: ‘吉娃娃’, ‘2026’: ‘哈士奇’, ‘2027’: ‘德牧’, ‘2028’: ‘拉布拉多’, ‘2029’: ‘杜宾’, ‘2030’: ‘柴犬’, ‘2031’: ‘法国斗牛’, ‘2032’: ‘萨摩耶’, ‘2033’: ‘藏獒’, ‘2034’: ‘金毛’}
nameMap函数是将之前编码的图像名称进行解码,然后重新编码,编码成20240,20301,分别表示的中华田园犬文件夹下的0.jpg, 柴犬下面的1.jpg。这都是为了可视化的时候进行追溯,得到文件路径。
在这里插入图片描述

3.以图搜图实现

定义了一个类ImageRetrival,使用faiss实现创建索引,保存索引,加载索引和图像检索功能

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :faiss_index.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30 15:04 
'''


import os
import faiss
from utils.split_data import array_norm
from utils.Imgmap import nameMap, getImgMap
from model import ResNet18
from save_feature import getImagVectors
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
# 设置全局字体为支持中文的字体
rc('font', family='SimHei')  # 黑体

class ImageRetrival:
    def __init__(self, model_path,
                 index_dim=None):
        self.index_dim = index_dim
        self.index = faiss.IndexFlatL2(self.index_dim)
        self.model_path = model_path

    def build_index(self, image_files):
        # image_vectors图片特征,image_ids对应的标签
        image_vectors, image_ids = getImagVectors(image_files)

        # image_ids 在之前保存为h5文件时进行了编码,这里进行映射
        name_ids = nameMap(image_ids)

        index = faiss.IndexIDMap(self.index)
        index.add_with_ids(image_vectors, name_ids)
        return index

    def save_index(self, index, index_path):
        faiss.write_index(index, index_path)

    def load_index(self, index_path):
        return faiss.read_index(index_path)

    def image_topK_search(self, index, input_image, topK=None):

        resnet18 = ResNet18(out_feature=11,
                            model_path=self.model_path)
        queryVec = resnet18.extract_image_features(input_image)


        dist, ind = index.search(queryVec, topK)
        dist, ind = dist.flatten(), ind.flatten()
        res = array_norm(dist, ind)
        return res

4.运行调用

if __name__=="__main__":

    model_path='E:\\xxx\\Pycharm_files\\ImageRec\\weights\\resnet18.pth'
    # 1.创建索引
    imageRetrival = ImageRetrival(model_path=model_path,
                                  index_dim=11)
    image_files = 'E:\\xxx\\datas\\pet_dog\\train'
    save_index = "./weights/dog.index"
    index = imageRetrival.build_index(image_files)

    # # 2.保存索引
    imageRetrival.save_index(index, save_index)

    # 3.加载索引
    index_load = imageRetrival.load_index(save_index)
    #
    # # 4.相似度匹配
    input_image = './data/pic/德牧.jpg'
    out = imageRetrival.image_topK_search(index_load, input_image, topK=3)
    print(out)
    showFaissRes(image_files, input_image, out)

运行时选择性注销其中的某一步骤。
最后是可视化实现showFaissRes

5.可视化实现


def showFaissRes(image_files, input_image, faissRes):
    '''
    对faiss得到的结果进行可视化
    :param image_files: 图片数据库
    :param input_image: 查询图片路径
    :param faissRes: 返回的topk跟距离最近的结果[(ind, score), (ind, score)]
    :return:
    '''
    scores = []
    imgs = []
    info = []

    # 1.得到图片名称的映射
    element_mapping = getImgMap(image_files)
    imgs.append(mpimg.imread(input_image))
    info.append(input_image.split("/")[-1])

    for i in range(len(faissRes)):
        score = faissRes[i][1]
        ind = str(faissRes[i][0])
        scores.append(score)

        # 根据索引构建原本的图像路径ind格式:20276,前四个是类别表示
        claName = element_mapping.get(ind[:4])
        imgName = ind[4:]+".jpg"
        imgpath = image_files +"\\"+ claName+ "\\"+imgName
        imgs.append(mpimg.imread(imgpath))

        info.append(claName+"_"+ imgName+"_"+ str(score))
        print("图片名称是: " + claName+ imgName + " 对应得分是: %f" %score)

    num = int((len(faissRes) + 1) // 2)+1
    fig, axs = plt.subplots(nrows=num, ncols=num, figsize=(10, 10))

    # 确保即使只有一个子图,也可以进行索引
    if not isinstance(axs, np.ndarray):
        axs = np.array([[axs]])

    # 显示图像
    flat_index = 0
    for i in range(num):
        for j in range(num):
            if flat_index < len(imgs):
                img = imgs[flat_index]
                axs[i, j].imshow(img, cmap='gray')
                axs[i, j].axis('off')
                axs[i, j].set_title(info[flat_index])
                flat_index += 1
            else:
                axs[i, j].set_visible(False)

    plt.tight_layout()
    plt.show()

3.效果对比

第一篇:【以图搜图代码实现】–犬类以图搜图示例 预训练的resnet18

第二篇:【使用resnet18训练自己的数据集】 微调的resnet18
在这里插入图片描述

本章 Faiss实现: 分数不重要,本篇对分数进行了归一化。
在这里插入图片描述
准确性更高了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2180575.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

汽修行业的知识库搭建:赋能在线教育与知识付费

随着汽修行业的蓬勃发展&#xff0c;其业务范围和技术要求日益多元化。为了应对这一趋势&#xff0c;许多汽修公司开始探索线上教育模式&#xff0c;通过开设汽修知识课程&#xff0c;实现知识的有偿分享与传播。这一转变不仅拓宽了企业的盈利渠道&#xff0c;也为广大汽修爱好…

深度学习之贝叶斯分类器

贝叶斯分类器 1 图解极大似然估计 极大似然估计的原理&#xff0c;用一张图片来说明&#xff0c;如下图所示&#xff1a; ​ 例&#xff1a;有两个外形完全相同的箱子&#xff0c;1号箱有99只白球&#xff0c;1只黑球&#xff1b;2号箱有1只白球&#xff0c;99只黑球。在一次…

【Spark 实战】基于spark3.4.2+iceberg1.6.1搭建本地调试环境

基于spark3.4.2iceberg1.6.1搭建本地调试环境 文章目录 基于spark3.4.2iceberg1.6.1搭建本地调试环境环境准备使用maven构建sparksql编辑SparkSQL简单任务附录A iceberg术语参考 环境准备 IntelliJ IDEA 2024.1.2 (Ultimate Edition)JDK 1.8Spark 3.4.2Iceberg 1.6.1 使用mave…

C++----类和对象(一)

一.类的定义 1.类定义的格式 • class为定义类的关键字&#xff0c;ST为类的名字&#xff0c;{}中为类的主体&#xff0c;注意类定义结束时后面分号不能省 略。类体中内容称为类的成员&#xff1a;类中的变量称为类的属性或成员变量; 类中的函数称为类的方法或 者成员函数。 …

NAND Flash虚拟层设计概述

NAND Flash虚拟层的建立需要对NAND Flash虚拟层进行初始化&#xff0c;根据相应的NAND Flash的物理结构参数建立逻辑结构&#xff0c;并建立索引表来管理逻辑虚拟层与物理虚拟层之间的联系&#xff1b;而在NAND Flash虚拟层运行过程中需要对NAND Flash虚拟层进行相应的垃圾回收…

【AI驱动TDSQL-C Serverless数据库技术实战】 AI电商数据分析系统——探索Text2SQL下AI驱动代码进行实际业务

目录 一、Text2SQL简介二、基于TDSQL-C Serverless的Text2SQL实战2.1、程序流程图2.2、实践流程2.2.1、配置TDSQL-C2.2.2、部署LLAMA模型2.2.3、本地依赖安装2.2.4、应用构建 2.3、运行效果 三、Text2SQL下的AI驱动 Text2SQL 是一种将自然语言查询转换为 SQL 查询的技术&#x…

NVIDIA H200 Tensor Core GPU

增强 AI 和 HPC 工作负载。 文章目录 前言一、通过更大、更快的内存实现更高的性能二、通过高性能 LLM三、增强高性能计算四、Reduce Energy and TCO 降低能耗和 TCO五、通过 H200 NVL 为主流企业服务器释放 AI 加速前言 The GPU for Generative AI and HPC 用于生成式 AI 和 …

香港科技大学新作:速度场如何在复杂城市场景规划中大显身手

导读&#xff1a; 本篇文章提出了一种局部地图表示方法&#xff08;即速度场&#xff09;来解决无法为所有场景设计通用规划规则的问题。此外&#xff0c;本文开发了一种高效的迭代轨迹优化器&#xff0c;其与速度场无缝兼容&#xff0c;实现了训练和推理过程。实验结果表明&am…

Linux操作系统中Redis

1、什么是Redis Redis&#xff08;Remote Dictionary Server &#xff09;&#xff0c;即远程字典服务&#xff0c;是一个开源的使用ANSIC语言编写、支持网络、可基于内存亦可持久化的日志型、Key-Value数据库&#xff0c;并提供多种语言的API。 可以理解成一个大容量的map。…

《向量数据库指南》——Milvus 和 Fivetran 如何为 AI 构建基础

哈哈,说起 Milvus 和 Fivetran 如何为 AI 构建基础,这可真是个有意思的话题!来,让我这个向量数据库领域的“老司机”给你详细讲解一番,保证让你听得津津有味,还能学到不少干货! Milvus 和 Fivetran:AI 搜索解决方案的黄金搭档 在当今这个数据爆炸的时代,AI 已经成为…

《软件工程概论》作业一:新冠疫情下软件产品设计(小区电梯实体按钮的软件替代方案)

课程说明&#xff1a;《软件工程概论》为浙江科技学院2018级软件工程专业在大二下学期开设的必修课。课程使用《软件工程导论&#xff08;第6版&#xff09;》&#xff08;张海藩等编著&#xff0c;清华大学出版社&#xff09;作为教材。以《软件设计文档国家标准GBT8567-2006》…

net core mvc 数据绑定 《2》 bind fromquery,FromRoute,fromform等,自定义模型绑定器

mvc core 模型绑定 控制绑定名称 》》》Bind 属性可以用来指定 模型应该 绑定的前缀 public class MyController : Controller {[HttpPost]public ActionResult Create([Bind(Prefix "MyModel")] Ilist<MyModel> model){// 模型绑定将尝试从请求的表单数据中…

Vue2实现主内容滚动到指定位置时,侧边导航栏也跟随选中变化

需求背景&#xff1a; PC端项目需要实现一个有侧边导航栏&#xff0c;可点击跳转至对应内容区域&#xff0c;类似锚点导航&#xff0c; 同时主内容区域上下滚动时&#xff0c;可实现左侧导航栏选中样式能实时跟随变动的效果。 了解了一下&#xff0c;Element Plus 组件库 和 …

从源码中学习动态代理模式

动态代理模式 动态代理是 Java 反射&#xff08;Reflection&#xff09;API 提供的一种强大机制&#xff0c;它允许在运行时创建对象的代理实例&#xff0c;而不需要在编译时静态地创建。 Java 提供了两种主要的方式来实现动态代理&#xff1a; 基于接口的动态代理&#xff1a…

2024/9/29周报

文章目录 摘要Abstract污水处理工艺流程整体介绍粗格栅细格栅曝气沉砂池提升泵房峰谷平策略 初沉池&#xff08;一级处理&#xff09;工作原理运行管理 氧化沟生化池&#xff08;二级处理&#xff09;二沉池工作原理运行参数 高效沉淀池功能与特点工作原理 深度处理&#xff08…

[BUUCTF从零单排] Web方向 03.Web入门篇之sql注入-1(手工注入详解)

这是作者新开的一个专栏《BUUCTF从零单排》&#xff0c;旨在从零学习CTF知识&#xff0c;方便更多初学者了解各种类型的安全题目&#xff0c;后续分享一定程度会对不同类型的题目进行总结&#xff0c;并结合CTF书籍和真实案例实践&#xff0c;希望对您有所帮助。当然&#xff0…

html+css+js实现dialog对话框

实现效果 HTML部分 <span class"text">点击打开 Dialog</span><!-- 警告框 --><div class"alert"><div class"header"><i>X</i> </div><div class"content">确认关闭</di…

Python 实现 YouTube 视频自动上传

文章目录 前言申请 Google API 秘钥启用 API创建项目凭证配置 API下载生成的凭据文件 youtube-upload 工具使用安装配置秘钥使用 其它问题程序尚未完成 Google 验证流程 个人简介 前言 youtube-upload 库 Python 中一个用于实现 YouTube 视频自动上传的实用工具。以下是关于如…

【 微信机器人+ AI 搭建】

摘要&#xff1a; 各种大模型已经出来好久了&#xff0c;各类app也已经玩腻了&#xff0c;接下来&#xff0c;就在考虑&#xff0c;怎么让大模型&#xff0c;利益最大化。 本人没有显著的家世&#xff0c;没有富婆包养&#xff0c;只能自己抽点时间&#xff0c;研究下技术&…

Java使用BeanUtils.copyProperties实现对象的拷贝

1、BeanUtils.copyProperties() 方法的使用 BeanUtils.copyProperties 方法是 Java 中 Spring 框架提供的一个非常实用的工具方法&#xff0c;它用于将一个 JavaBean 对象的属性值拷贝到另一个 JavaBean 对象中。这个方法主要用于简化对象之间的数据转换过程&#xff0c;尤其是…