resnet50,clip,Faiss+Flask简易图文搜索服务

news2025/1/20 1:38:39

一、实现

文件夹目录结构:

templates

        -----upload.html

faiss_app.py

前端代码:

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Search and Show Multiple Images</title>
    <style>
        #image-container {
            display: flex;
            flex-wrap: wrap;
        }
        #image-container img {
            max-width: 150px;
            margin: 10px;
        }
    </style>
</head>
<body>
    <h1>Search Images</h1>
     <!-- 上传表单 -->
     <form id="upload-form" enctype="multipart/form-data">
        <input type="file" id="file-input" name="file" accept="image/*" required>
        <input type="submit" value="Upload">
    </form>

    <!-- 搜索框 -->
    <form id="search-form">
        <input type="text" id="search-input" name="query" placeholder="Enter search term" required>
        <input type="submit" value="Search">
    </form>

    <h2>Search Results</h2>
    <!-- 显示搜索返回的多张图片 -->
    <div id="image-container"></div>

    <!-- 使用JS处理表单提交 -->
    <script>
        document.getElementById('search-form').addEventListener('submit', async function(event) {
            event.preventDefault();  // 阻止表单默认提交行为
            
            const query = document.getElementById('search-input').value;  // 获取搜索框中的输入内容

            try {
                // 发送GET请求,将搜索关键词发送到后端
                const response = await fetch(`/search?query=${encodeURIComponent(query)}`, {
                    method: 'GET',
                });

                // 确保服务器返回JSON数据
                const data = await response.json();

                // 清空图片容器
                const imageContainer = document.getElementById('image-container');
                imageContainer.innerHTML = '';

                // 遍历后端返回的图片URL数组,动态创建<img>标签并渲染
                data.image_urls.forEach(url => {
                    const imgElement = document.createElement('img');
                    imgElement.src = url;  // 设置图片的src属性为返回的URL
                    imageContainer.appendChild(imgElement);  // 将图片添加到容器中
                });
            } catch (error) {
                console.error('Error searching for images:', error);
            }
        });
        document.getElementById('upload-form').addEventListener('submit', async function(event) {
            event.preventDefault();  // 阻止表单默认提交行为
            
            const fileInput = document.getElementById('file-input');
            const formData = new FormData();
            formData.append('file', fileInput.files[0]);  // 获取用户上传的图片文件

            try {
                // 发送POST请求,将图片发送到后端
                const response = await fetch('/search_by_images', {
                    method: 'POST',
                    body: formData
                });

                // 确保服务器返回JSON数据
                const data = await response.json();

                // 清空图片容器
                const imageContainer = document.getElementById('image-container');
                imageContainer.innerHTML = '';

                // 遍历后端返回的图片URL数组,动态创建<img>标签并渲染
                data.image_urls.forEach(url => {
                    const imgElement = document.createElement('img');
                    imgElement.src = url;  // 设置图片的src属性为返回的URL
                    imageContainer.appendChild(imgElement);  // 将图片添加到容器中
                });
            } catch (error) {
                console.error('Error uploading file:', error);
            }
        });
    </script>
</body>
</html>

后端代码:

from sentence_transformers import SentenceTransformer, util
from torchvision import models, transforms
from PIL import Image
from flask import Flask, request, jsonify, current_app, render_template, send_from_directory, url_for
from werkzeug.utils import secure_filename
import faiss
import os, glob
import numpy as np
from markupsafe import escape
import shutil

#Load CLIP model
model = SentenceTransformer('clip-ViT-B-32')
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}

UPLOAD_FOLDER = 'uploads/'
IMAGES_PATH  = "C:\\Users\\cccc\\Pictures\\cls_auto_config"

def generate_clip_embeddings(images_path, model):
    image_paths = []
    # 使用 os.walk 遍历所有子目录和文件
    for root, dirs, files in os.walk(images_path):
        for file in files:
            # 获取文件的扩展名并转换为小写
            ext = os.path.splitext(file)[1].lower()
            # 判断是否是图片文件
            if ext in IMAGE_EXTENSIONS:
                image_paths.append(os.path.join(root, file)) 
    embeddings = []
    for img_path in image_paths:
        image = Image.open(img_path)
        embedding = model.encode(image)
        embeddings.append(embedding)
    
    return embeddings, image_paths

def generate_res50_embeddings(images_path):
    # Load the pretrained model
    res50_model = models.resnet50(pretrained=True)
    res50_model = res50_model.eval()

    # Define the image transformations
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image_paths = []
    # 使用 os.walk 遍历所有子目录和文件
    for root, dirs, files in os.walk(images_path):
        for file in files:
            # 获取文件的扩展名并转换为小写
            ext = os.path.splitext(file)[1].lower()
            # 判断是否是图片文件
            if ext in IMAGE_EXTENSIONS:
                image_paths.append(os.path.join(root, file)) 
    embeddings = []
    for img_path in image_paths:
        image = Image.open(img_path)
        # Apply the transformations and get the image vector
        image = transform(image).unsqueeze(0)
        image_vector = res50_model(image).detach().numpy()
        embeddings.append(image_vector[0])
    return embeddings, image_paths

def create_faiss_index(embeddings, image_paths, output_path):

    dimension = len(embeddings[0])

    # 分情况创建Faiss索引对象
    if len(image_paths) < 39 * 256:
        # 如果条目很少,直接用最普通的L2索引
        faiss_index = faiss.IndexFlatL2(dimension)
    elif len(image_paths) < 39 * 4096:
        # 如果条目少于39 × 4096,就只用PQ量化,不使用IVF
        faiss_index = faiss.index_factory(dimension, 'OPQ64_256,PQ64x8')
    else:
        # 否则就加上IVF
        faiss_index = faiss.index_factory(dimension, 'OPQ64_256,IVF4096,PQ64x8')
    res = faiss.StandardGpuResources()
    co = faiss.GpuClonerOptions()
    co.useFloat16 = True
    faiss_index = faiss.index_cpu_to_gpu(res, 0, faiss_index, co)

    #index = faiss.IndexFlatIP(dimension)
    faiss_index = faiss.IndexIDMap(faiss_index)
    
    vectors = np.array(embeddings).astype(np.float32)

    # Add vectors to the index with IDs
    faiss_index.add_with_ids(vectors, np.array(range(len(embeddings))))
    
    # Save the index
    faiss_index = faiss.index_gpu_to_cpu(faiss_index)
    faiss.write_index(faiss_index, output_path)
    print(f"Index created and saved to {output_path}")
    
    # Save image paths
    with open(output_path + '.paths', 'w') as f:
        for img_path in image_paths:
            f.write(img_path + '\n')
    
    return faiss_index

def load_faiss_index(index_path):
    faiss_index = faiss.read_index(index_path)
    with open(index_path + '.paths', 'r') as f:
        image_paths = [line.strip() for line in f]
    print(f"Index loaded from {index_path}")
    if not faiss_index.is_trained:
            raise RuntimeError(f'从[{index_path}]加载的Faiss索引未训练')
    res = faiss.StandardGpuResources()
    co = faiss.GpuClonerOptions()
    co.useFloat16 = True
    faiss_index = faiss.index_cpu_to_gpu(res, 0, faiss_index, co)
    return faiss_index, image_paths


def retrieve_similar_images(query, model, index, image_paths, top_k=3):
    
    # query preprocess:
    if query.endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
        query = Image.open(query)

    query_features = model.encode(query)
    query_features = query_features.astype(np.float32).reshape(1, -1)

    distances, indices = index.search(query_features, top_k)

    retrieved_images = [image_paths[int(idx)] for idx in indices[0]]

    return query, retrieved_images

def retrieve_res50_similar_images(query, index, image_paths, top_k=3):
    # query preprocess:
    if query.endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
        image = Image.open(query)
        # Load the pretrained model
        res50_model = models.resnet50(pretrained=True)
        res50_model = res50_model.eval()

        # Define the image transformations
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Apply the transformations and get the image vector
        image = transform(image).unsqueeze(0)
        query_features = res50_model(image).detach().numpy()
        query_features = query_features[0]
        query_features = query_features.astype(np.float32).reshape(1, -1)

        distances, indices = index.search(query_features, top_k)

        retrieved_images = [image_paths[int(idx)] for idx in indices[0]]

        return query, retrieved_images

# 检查文件扩展名是否允许
def allowed_file(filename):
    return '.' in filename and "." + filename.rsplit('.', 1)[1].lower() in IMAGE_EXTENSIONS

def search():
    query = request.args.get('query')  # 获取搜索关键词
    safe_query = escape(query)

    if not query:
        return jsonify({"error": "No search query provided"}), 400
    index, image_paths = None, []
    OUTPUT_INDEX_PATH = f"{app.config['UPLOAD_FOLDER']}/vector.index"
    if os.path.exists(OUTPUT_INDEX_PATH):
        index, image_paths = load_faiss_index(OUTPUT_INDEX_PATH)
    else:
        # embeddings, image_paths = generate_clip_embeddings(IMAGES_PATH, model)
        embeddings, image_paths = generate_res50_embeddings(IMAGES_PATH)
        index = create_faiss_index(embeddings, image_paths, OUTPUT_INDEX_PATH)
    query, retrieved_images = retrieve_similar_images(query, model, index, image_paths, top_k=5)


    image_urls = []
    for path in retrieved_images:
        base_name = os.path.basename(path)
        shutil.copy(path, os.path.join(app.config['UPLOAD_FOLDER'], base_name))
        image_urls.append(url_for('uploaded_file_path', filename=base_name))

    return jsonify({"image_urls": image_urls})


def search_by_images():
    # 检查请求中是否有文件
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']

    # 检查文件是否为空
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    print(file.filename)
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        index, image_paths = None, []
        OUTPUT_INDEX_PATH = f"{app.config['UPLOAD_FOLDER']}/images_vector.index"
        if os.path.exists(OUTPUT_INDEX_PATH):
            index, image_paths = load_faiss_index(OUTPUT_INDEX_PATH)
        else:
            embeddings, image_paths = generate_res50_embeddings(IMAGES_PATH)
            index = create_faiss_index(embeddings, image_paths, OUTPUT_INDEX_PATH)
        filepath, retrieved_images = retrieve_res50_similar_images(filepath, index, image_paths, top_k=5)

        image_urls = []
        for path in retrieved_images:
            base_name = os.path.basename(path)
            shutil.copy(path, os.path.join(app.config['UPLOAD_FOLDER'], base_name))
            image_urls.append(url_for('uploaded_file_path', filename=base_name))

        return jsonify({"image_urls": image_urls})
    else:
        return jsonify({"error": "Invalid file"}), 400


def index():
    return render_template('upload.html')

# 提供静态文件的访问路径
def uploaded_file_path(filename):
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)

if __name__ == "__main__":
    app = Flask(__name__)
    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
    if not os.path.exists(UPLOAD_FOLDER):
        os.makedirs(UPLOAD_FOLDER)
    # 主页显示上传表单
    app.route('/')(index)
    app.route('/search', methods=['GET'])(search)
    app.route('/uploads/images/<filename>')(uploaded_file_path)
    app.route('/search_by_images', methods=['POST'])(search_by_images)
    app.run(host='0.0.0.0', port=8080, debug=True)

二、实现效果

三、参考文章

1. https://towardsdatascience.com/building-an-image-similarity-search-engine-with-faiss-and-clip-2211126d08fa

2.向量数据库Faiss的搭建与使用 - 很久8899 - 博客园

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2243305.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flink监控checkpoint

Flink的web界面提供了一个选项卡来监控作业的检查点。这些统计信息在任务终止后也可用。有四个选项卡可以显示关于检查点的信息:概述(Overview)、历史(History)、摘要(Summary)和配置(Configuration)。下面依次来看这几个选项。 Overview Tab Overview选项卡列出了以…

如何用Excel批量提取文件夹内所有文件名?两种简单方法推荐

在日常办公中&#xff0c;我们有时需要将文件夹中的所有文件名整理在Excel表格中&#xff0c;方便管理和查阅。手动复制文件名既费时又易出错&#xff0c;因此本文将介绍两种利用Excel自动提取文件夹中所有文件名的方法&#xff0c;帮助你快速整理文件信息。 方法一&#xff1…

微信小程序-prettier 格式化

一.安装prettier插件 二.配置开发者工具的设置 配置如下代码在setting.json里&#xff1a; "editor.formatOnSave": true,"editor.defaultFormatter": "esbenp.prettier-vscode","prettier.documentSelectors": ["**/*.wxml"…

Debezium日常分享系列之:Debezium3版本Debezium connector for JDBC

Debezium日常分享系列之&#xff1a;Debezium3版本Debezium connector for JDBC 概述JDBC连接器的工作原理消费复杂的Debezium变更事件至少一次的传递多个任务数据和列类型映射主键处理删除模式幂等写入模式演化引用和大小写敏感性连接空闲超时数据类型映射部署Debezium JDBC连…

前端页面自适应等比例缩放 Flexible+rem方案

在移动互联网时代&#xff0c;随着智能手机和平板电脑的普及&#xff0c;前端开发者面临的一个重要挑战是如何让网页在不同尺寸和分辨率的设备上都能良好地显示。为了应对这一挑战&#xff0c;阿里巴巴的前端团队开发了 flexible.js&#xff0c;旨在提供一种简单有效的解决方案…

Argo workflow 拉取git 并使用pvc共享文件

文章目录 拉取 Git 仓库并读取文件使用 Kubernetes Persistent Volumes&#xff08;通过 volumeClaimTemplates&#xff09;以及任务之间如何共享数据 拉取 Git 仓库并读取文件 在 Argo Workflows 中&#xff0c;如果你想要一个任务拉取 Git 仓库中的文件&#xff0c;另一个任…

AWTK VSCode 实时预览插件端口冲突的解决办法

AWTK XML UI 预览插件&#xff1a;在 vscode 中实时预览 AWTK XML UI 文件&#xff0c;在 Copilot 的帮助下&#xff0c;可以大幅提高界面的开发效率。 主要特色&#xff1a; 真实的 UI 效果。可以设置主题&#xff0c;方便查看在不同主题下界面的效果。可以设置语言&#xf…

x-cmd pkg | helix - 用 Rust 打造的文本编辑器,内置 LSP 和语法高亮,兼容 Vim 用户习惯

目录 简介快速上手安装使用 功能特点竞品和相关项目进一步阅读 简介 helix 是用 Rust 开发的文本编辑器&#xff0c;以 Modal editing&#xff08;模态编辑&#xff09;为核心特性&#xff0c;类似于 Vim。它结合了经典的 Vim 模式编辑和现代开发工具的特性&#xff08;如 LSP…

资源管理功能拆解——如何高效配置和管理项目资源?

在任何一个项目中&#xff0c;资源的合理配置和高效管理是决定项目成败的关键因素。无论是人力、物资还是时间&#xff0c;每一个资源的使用都直接关系到项目的执行效果和企业的成本控制。因此&#xff0c;如何在项目管理中实现资源的高效配置和监控&#xff0c;成为了企业管理…

SpringCloud OpenFeign负载均衡远程调用 跨服务调用 连接池优化

介绍 Spring Cloud OpenFeign 是 Spring Cloud 的一部分&#xff0c;提供了一种声明式的 HTTP 客户端方式来简化服务间的通信。通过 OpenFeign&#xff0c;开发者可以像调用本地方法一样&#xff0c;轻松地调用远程服务&#xff0c;而不需要手动处理 HTTP 请求、响应和连接等底…

Android Studio开发学习(五)———LinearLayout(线性布局)

一、布局 认识了解一下Android中的布局&#xff0c;分别是: LinearLayout(线性布局)&#xff0c;RelativeLayout(相对布局)&#xff0c;TableLayout(表格布局)&#xff0c; FrameLayout(帧布局)&#xff0c;AbsoluteLayout(绝对布局)&#xff0c;GridLayout(网格布局) 等。 二、…

Android WMS概览

WMS&#xff08;WindowManagerService&#xff09;是 Android 系统的核心服务&#xff0c;负责管理应用和系统的窗口&#xff0c;包括窗口的创建、销毁、布局、层级管理、输入事件分发以及动画显示等。它通过协调 InputManager 和 SurfaceFlinger 实现触摸事件处理和窗口渲染&a…

C++ 【string】使用及函数

详解 C std::string 是 C 标准库中的一个类&#xff0c;它用于处理字符串数据。它是容器适配器&#xff08;container adapter&#xff09;&#xff0c;基于 basic_stringbuf 和 basic_ostream 类&#xff0c;提供了高效、安全的字符串操作。 以下是 std::string 的一些关键特…

基于的图的异常检测算法OddBall

OddBall异常检测算法出自2010年的论文《OddBall: Spotting Anomalies in Weighted Graphs》&#xff0c;它是一个在加权图(weighted graph)上检测异常点的算法&#xff0c;基本思路为计算每一个点的一度邻域特征&#xff0c;然后在整个图上用这些特征拟合出一个函数&#xff0c…

最新版xAI LLM 模型Grok-2 上线

xAI&#xff01;Grok-2 最新版开启公测&#xff01;”。这是我注册成功的截图&#xff0c;使用国内的邮箱就可以注册使用了&#xff01; Grok API公测与免费体验: Grok API开启公测&#xff0c;提供免费体验128k上下文支持&#xff0c;。Grok-Beta与马斯克: 马斯克庆祝特朗普当…

华为云stack网络服务流量走向

1.同VPC同子网同主机内ECS间互访流量走向 一句话通过主机内部br-int通信 2.同VPC同子网跨主机ECS间互访流量走向 3.同VPC不同子网同主机ECS间互访流量走向 去往本机的mac地址都记录在br-tun流表里 4.同VPC不同子网跨主机ECS间互访流量走向 5.对等连接流量走向&#xff08;跨V…

Vue Canvas实现区域拉框选择

canvas.vue组件 <template><div class"all" ref"divideBox"><!-- 显示图片&#xff0c;如果 imgUrl 存在则显示 --><img id"img" v-if"imgUrl" :src"imgUrl" oncontextmenu"return false" …

实例教程:BBDB为AHRS算法开发提供完善的支撑环境(上)

1. 概述 本教程将结合程序代码及CSS控制站工程&#xff0c;讲述如何基于PH47代码框架的BBDB固件版本&#xff0c;为开发自己的AHRS姿态解算算法提供完善支撑环境&#xff0c;以及数据分析手段。 BBDB固件已内置了一套姿态解算算法。对于需要进行AHRS算法开发研究的开发者&…

Linux操作系统 ----- (5.系统管理)

目录 1.总结 2.本章学习目标 3.图形界面管理 3.1.X-Window图形界面概述 3.2.X-Window的结构 3.3.X-Window的特点 3.4.UKUI图形环境 3.5.桌面 3.5.1.桌面图标 3.5.2.计算机属性 3.5.3.桌面快捷菜单 3.6.任务栏 3.6.1.开始菜单 3.6.2.显示任务视图 3.6.3.文件管理器…

hive复杂数据类型Array Map Struct 炸裂函数explode

1、Array的使用 create table tableName( ...... colName array<基本类型> ...... ) 说明&#xff1a;下标从0开始&#xff0c;越界不报错&#xff0c;以null代替 arr1.txtzhangsan 78,89,92,96 lisi 67,75,83,94 王五 23,12 新建表&#xff1a; create table arr1(n…