Java调用Pytorch实现以图搜图(附源码)

news2025/2/1 6:38:17

Java调用Pytorch实现以图搜图

设计技术栈:
1、ElasticSearch环境;
2、Python运行环境(如果事先没有pytorch模型时,可以用python脚本创建模型);

1、运行效果

在这里插入图片描述

2、创建模型(有则可以跳过)

vi script.py

import torch
import torch.nn as nn
import torchvision.models as models
 
class ImageFeatureExtractor(nn.Module):
    def __init__(self):
        super(ImageFeatureExtractor, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        #最终输出维度1024的向量,下文elastic search要设置dims为1024
        self.resnet.fc = nn.Linear(2048, 1024)
 
    def forward(self, x):
        x = self.resnet(x)
        return x
 
if __name__ == '__main__':
    model = ImageFeatureExtractor()
    model.eval()
    #根据模型随便创建一个输入
    input = torch.rand([1, 3, 224, 224])
    output = model(input)
    #以这种方式保存
    script = torch.jit.trace(model, input)
    script.save("model.pt")

2、java项目pom.xml

<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>
		<dependency>
			<groupId>org.projectlombok</groupId>
			<artifactId>lombok</artifactId>
			<scope>provided</scope>
		</dependency>
		<dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.19.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <version>1.10.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>1.10.0-0.19.0</version>
        </dependency>
        <dependency>
            <groupId>org.elasticsearch.client</groupId>
            <artifactId>elasticsearch-rest-high-level-client</artifactId>
        </dependency>
	</dependencies>

3、ES创建文档

PUT /isi
{
  "mappings": {
    "properties": {
      "vector": {
        "type": "dense_vector",
        "dims": 1024
      },
      "url" : {
        "type" : "keyword"
      },
      "user_id": {
          "type": "keyword"
      }
    }
  }
}

4、编写java代码调用模型

ORCUtil.java

package com.topprismcloud.rtm;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;

public class ORCUtil {

	private static final String INDEX = "isi";
	private static final int IMAGE_SIZE = 224;
	private static Model model; // 模型
	private static Predictor<Image, float[]> predictor; // predictor.predict(input)相当于python中model(input)
	static {
		try {
			model = Model.newInstance("model");
			// 这里的model.pt是上面代码展示的那种方式保存的
			model.load(ORCUtil.class.getClassLoader().getResourceAsStream("model.pt"));
			Transform resize = new Resize(IMAGE_SIZE);
			Transform toTensor = new ToTensor();
			Transform normalize = new Normalize(new float[] { 0.485f, 0.456f, 0.406f },
					new float[] { 0.229f, 0.224f, 0.225f });
			// Translator处理输入Image转为tensor、输出转为float[]
			Translator<Image, float[]> translator = new Translator<Image, float[]>() {
				@Override
				public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
					NDManager ndManager = ctx.getNDManager();
					System.out.println("input: " + input.getWidth() + ", " + input.getHeight());
					NDArray transform = normalize
							.transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));
					System.out.println(transform.getShape());
					NDList list = new NDList();
					list.add(transform);
					return list;
				}

				@Override
				public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
					return ndList.get(0).toFloatArray();
				}
			};
			predictor = new Predictor<>(model, translator, Device.cpu(), true);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	public static void upload() throws Exception {
		HttpHost host=new HttpHost("14.20.30.16", 9200, HttpHost.DEFAULT_SCHEME_NAME);
		RestClientBuilder builder=RestClient.builder(host);
		CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
		credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials("elastic", "123456"));
		builder.setHttpClientConfigCallback(f -> f.setDefaultCredentialsProvider(credentialsProvider));
		RestHighLevelClient client = new RestHighLevelClient( builder);
		// 批量上传请求
		BulkRequest bulkRequest = new BulkRequest(INDEX);
		File file = new File("D:\\001ENV\\nginx-1.24.0\\html\\resource\\new");
		for (File listFile : file.listFiles()) {
//			float[] vector = predictor.predict(ImageFactory.getInstance()
//					.fromInputStream(Test.class.getClassLoader().getResourceAsStream("new/" + listFile.getName())));
			
			float[] vector = predictor.predict(ImageFactory.getInstance()
					.fromInputStream(new FileInputStream(listFile)));
			// 构建文档
			Map<String, Object> jsonMap = new HashMap<>();
			jsonMap.put("url", "/resource/"+listFile.getName());
			jsonMap.put("vector", vector);
			jsonMap.put("user_id", "user123");
			IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);
			bulkRequest.add(request);
		}
		client.bulk(bulkRequest, RequestOptions.DEFAULT);
		client.close();
	}

	// 接收待搜索图片的inputstream,搜索与其相似的图片
	public static List<SearchResult> search(InputStream input) throws Throwable {
		float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(input));
		System.out.println(Arrays.toString(vector));

		// 展示k个结果
		int k = 100;
		// 连接Elasticsearch服务器
		RestHighLevelClient client = new RestHighLevelClient(
				RestClient.builder(new HttpHost("14.20.30.16", 9200, "http")));

		SearchRequest searchRequest = new SearchRequest(INDEX);
		Script script = new Script(ScriptType.INLINE, "painless", "cosineSimilarity(params.queryVector, doc['vector'])",
				Collections.singletonMap("queryVector", vector));

		FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders
				.functionScoreQuery(QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));

		SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
		searchSourceBuilder.query(functionScoreQueryBuilder).fetchSource(null, "vector") // 不返回vector字段,太多了没用还耗时
				.size(k);

		searchRequest.source(searchSourceBuilder);

		SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);

		SearchHits hits = searchResponse.getHits();

		List<SearchResult> list = new ArrayList<>();
		for (SearchHit hit : hits) {
			// 处理搜索结果
			System.out.println(hit.toString());
			SearchResult result = new SearchResult((String) hit.getSourceAsMap().get("url"), hit.getScore());
			list.add(result);
		}

		client.close();
		return list;
	}

	public static void main(String[] args) throws Throwable {
		ORCUtil.upload();
		System.out.println("hao");
	}
}

SearchController.java

package com.topprismcloud.rtm;

import java.util.List;

import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
@CrossOrigin
public class SearchController {

	@PostMapping("search")
	public ResponseEntity search(MultipartFile file) {
		try {
			List<SearchResult> list = ORCUtil.search(file.getInputStream());
			return ResponseEntity.ok(list);
		} catch (Throwable e) {
			return ResponseEntity.status(400).body(null);
		}
	}
}

SearchResult.java

package com.topprismcloud.rtm;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class SearchResult {
    private String url;
    private Float score;
}

5、前端

index.html

<!DOCTYPE html>
<html lang="zh">

<head>
    <meta charset="UTF-8">
    <title>以图搜图</title>
    <style>
        body {
            background: url("/img/bg.jpg");
            background-attachment: fixed;
            background-size: 100% 100%;
        }

        body>div {
            width: 1000px;
            margin: 50px auto;
            padding: 10px 20px;
            border: 1px solid lightgray;
            border-radius: 20px;
            box-sizing: border-box;
            background: rgba(255, 255, 255, 0.7);
        }

        .upload {
            display: inline-block;
            width: 300px;
            height: 280px;
            border: 1px dashed lightcoral;
            vertical-align: top;
        }

        .upload .cover {
            width: 200px;
            height: 200px;
            margin: 10px 50px;
            border: 1px solid black;
            box-sizing: border-box;
            text-align: center;
            line-height: 200px;
            position: relative;
        }

        .upload img {
            width: 198px;
            height: 198px;
            position: absolute;
            left: 0;
            top: 0;
        }

        .upload input {
            margin-left: 50px;
        }

        .upload button {
            width: 80px;
            height: 30px;
            margin-left: 110px;
        }

        .result-block {
            display: inline-block;
            margin-left: 40px;
            border: 1px solid lightgray;
            border-radius: 10px;
            min-height: 500px;
            width: 600px;
        }

        .result-block h1 {
            text-align: center;
            margin-top: 100px;
        }

        .result {
            padding: 10px;
            cursor: pointer;
            display: inline-block;
        }

        .result:hover {
            background: rgb(240, 240, 240);
        }

        .result p {
            width: 110px;
            overflow: hidden;
            white-space: nowrap;
            text-overflow: ellipsis;
        }

        .result img {
            width: 160px;
            height: 160px;
        }

        .result .prob {
            color: rgb(37, 147, 60)
        }
    </style>
    <script src="js/jquery-3.6.0.js"></script>
</head>

<body>
    <div>
        <div class="upload">
            <div class="cover">
                请选择图片
                <img id="image" src="" />
            </div>
            <input id="file" type="file">
        </div>
        <div class="result-block">
            <h1>请选择图片</h1>
        </div>
    </div>
    <ul id="box">

    </ul>
    <script>
        var file = $('#file')
        file.change(function () {
            let f = this.files[0]
            let index = f.name.lastIndexOf('.')
            let fileText = f.name.substring(index, f.name.length)
            let ext = fileText.toLowerCase() //文件类型
            console.log(ext)
            if (ext != '.png' && ext != '.jpg' && ext != '.jpeg') {
                alert('系统仅支持 JPG、PNG、JPEG 格式的图片,请您调整格式后重新上传')
                return
            }
            $('.result-block').empty().append($('<h1>正在识别中...</h1>'))
            $("#image").attr("src", getObjectURL(f));
            let formData = new FormData()
            formData.append('file', f)
            $.ajax({
                url: 'http://10.1.2.240:8081/search',
                method: 'post',
                data: formData,
                processData: false,
                contentType: false,
                success: res => {
                    console.log('shibie', res)
                    $('.result-block').empty()
                    for (let item of res) {
                        console.log(item)
                        let html = `<div class="result">
                                    <img src="${item.url}"/>
                                    <div style="display: inline-block;vertical-align: top">
                                        <p class="prob">得分:${item.score.toFixed(4)}</p>
                                    </div>
                                </div>`
                        $('.result-block').append($(html))
                    }

                }
            })
        });
        $('#button').click(function (e) {
            var file = $('#file')[0].files[0] //单个
            console.log(file)
        })
        function getObjectURL(file) {
            var url = null;
            if (window.createObjcectURL != undefined) {
                url = window.createOjcectURL(file);
            } else if (window.URL != undefined) {
                url = window.URL.createObjectURL(file);
            } else if (window.webkitURL != undefined) {
                url = window.webkitURL.createObjectURL(file);
            }
            return url;
        }
        function detect() {

        }
    </script>
</body>

</html>

6、打包后的源代码

以图搜图Java+html源代码

相关参考文章:Java调用Pytorch模型进行图像识别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/631418.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

经典目标检测YOLO系列(1)YOLO-V1算法及其在VOC2007数据集上的应用

经典目标检测YOLO系列(1)YOLO-V1算法及其在VOC2007数据集上的应用 1 YOLO-V1的简述 1.1 目标检测概述 ​ 目标检测有非常广泛的应用&#xff0c; 例如&#xff1a;在安防监控、手机支付中的人脸检测&#xff1b;在智慧交通&#xff0c;自动驾驶中的车辆检测&#xff1b;在智…

TCP与UDP的可靠性传输

目录 一、TCP可靠性传输1、重传机制1.1、超时重传1.2、快速重传1.3、SACK1.4、Duplicate SACK 2、滑动窗口3、流量控制3.1 滑动窗口与流量控制3.2窗口关闭 4、拥塞控制4.1拥塞窗口4.2 慢启动4.3 拥塞避免4.4 拥塞发生4.5 快速恢复 二、UDP可靠性传输1、主要策略2、重传机制2.1 …

软件测试03:软件工程和软件生命周期

软件测试03&#xff1a;软件工程和软件生命周期 软件危机 软件危机是指落后的软件生产方式无法满足迅速增长的计算机软件需求&#xff0c;从而导致软件开发与维护过程中出现一系列严重问题的现象。 软件工程 基本软件危机对于计算机发展的阻碍&#xff0c;1968年&#xff0…

史上最详细的安装Kali-linux教程(附视频教程)

之前不少人问kali怎么安装&#xff0c;今天就发一篇利用VM虚拟机安装kali的详细教程&#xff0c;每一步都截图了&#xff0c;让大家尽可能的清楚每一步的操作。 1.2 使用 VM 虚拟机安装 Kali 1.2.1 官方下载 Kali Linux 官方网址&#xff1a;http://www.Kali.org 下载方式分…

跨域 —— 反向代理配置

跨域问题在讲Node.js学习中编写接口的时候就已经讲到了&#xff0c;由后端配置解决跨域问题&#xff0c;使用cors中间件解决跨域问题以及使用JSONP解决跨域&#xff08;仅支持GET请求&#xff09;&#xff0c;具体可以看一下这篇文章的内容&#xff1a;十二、Express接口编写 —…

python面向对象操作3(速通版)

目录 一、多态和类名 1.标准多态 2.实例属性和实例方法 3.类对象和类属性 4.对象保存 二、方法 1.类方法 3.四种方法的区别 三、模块 1.导入模块 2.自动模块导入 3.模块导入的几种形式 3.1模块导入的两种方式和别名 3.2 from 模块 import 成员 4.两种方法的区别…

【运筹优化】最短路算法之A星算法 + Java代码实现

文章目录 一、A星算法简介二、A星算法思想三、A星算法 java代码四、测试 一、A星算法简介 A*算法是一种静态路网中求解最短路径最有效的直接搜索方法&#xff0c;也是解决许多搜索问题的有效算法。算法中的距离估算值与实际值越接近&#xff0c;最终搜索速度越快。 二、A星算…

day52_Spring

今日内容 零、 复习昨日 一、Spring 零、 复习昨日 一、引言 以前 public class HelleServlet extends HttpServlet{UserService service new UsrServiceImpl();void doGet(){service.findUser();} }public interface UserService{User findUser(); } public class UserServ…

Tigase-Server 8.3.0在windows11下安装

一、JDK安装&#xff1a; tigase-server要求JDK 17,请先下载JDK17, 下载地址&#xff1a;https://download.oracle.com/java/17/latest/jdk-17_windows-x64_bin.exe 配置环境变量&#xff1a;JAVA_HOME{JDK安装目录} 二、数据库安装&#xff1a;tigase-server8.3在windows下…

【算法系列 | 4】深入解析排序算法之——归并排序

序言 你只管努力&#xff0c;其他交给时间&#xff0c;时间会证明一切。 文章标记颜色说明&#xff1a; 黄色&#xff1a;重要标题红色&#xff1a;用来标记结论绿色&#xff1a;用来标记一级论点蓝色&#xff1a;用来标记二级论点 决定开一个算法专栏&#xff0c;希望能帮助大…

Chrome内核插件开发报错:Unchecked runtime.lastError:的原因及解决办法。

本篇文章主要讲解,chrome内核插件开发时报错:Unchecked runtime.lastError: Extensions using event pages or Service Workers must pass an id parameter to chrome.contextMenus.create 的原因及解决办法。 日期:2023年6月10日 作者:任聪聪 报错现象: 查看报错路径,在…

项目经理必备!这四个高效管理工具帮你实现项目管理目标

在项目管理中&#xff0c;图形工具可以帮助我们让项目信息可视化&#xff0c;让项目管理更加高效&#xff0c;对于项目经理而言&#xff0c;这些工具都是好帮手。让我们一起看看&#xff0c;项目经理常用的管理工具都有那些吧~ 1&#xff0c;甘特图 甘特图是计划和管理项目的好…

【Spring使用注解更简单的实现Bean对象的存取】

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔&#x1f93a;&#x1f93a;&#x1f93a; 目录 一、前言&#xff1a; 二、储存Bean对象和使…

天黑的时候如果下雨了,会比平常更亮一些

目录 一、最近的感受 二、自我的审视 三、如何变得强大 1.保持善良 2.不过度追求公平 3.在痛苦中找到自己的意义 4.令人振奋的生命力 四、情绪调节中的个人见解及如何处理情绪后的学习 1.运动 2.散步 3.找好朋友倾诉 五、总结 一、最近的感受 天黑的时候如果下雨了…

设计模式(十一):结构型之组合模式

设计模式系列文章 设计模式(一)&#xff1a;创建型之单例模式 设计模式(二、三)&#xff1a;创建型之工厂方法和抽象工厂模式 设计模式(四)&#xff1a;创建型之原型模式 设计模式(五)&#xff1a;创建型之建造者模式 设计模式(六)&#xff1a;结构型之代理模式 设计模式…

C语言:写一个代码,使用 试除法 打印100~200之间的素数(质数)

题目&#xff1a; 使用 试除法 打印100~200之间的素数。 素数&#xff08;质数&#xff09;&#xff1a;一个数只能被写成一和本身的积。 如&#xff1a;7只能写成1*7&#xff0c;那就是素数&#xff08;质数&#xff09;了。 思路一&#xff1a;使用试除法 总体思路&#xff…

HTML5 介绍

目录 1. HTML5介绍 1.1 介绍 1.2 内容 1.3 浏览器支持情况 2. 创建HTML5页面 2.1 <!DOCTYPE> 文档类型声明 2.2 <html>标签 2.3 <meta>标签 设置字符编码 2.4 引用样式表 2.5 引用JavaScript文件 3. 完整页面示例 4. 资料网站 1. HTML5介绍 1.1 介绍 …

带你手撕一颗红黑树

红黑树&#xff08;C&#xff09; 红黑树简述红黑树的概念红黑树的性质红黑树结点定义 一&#xff0c;红黑树的插入插入调整插入代码 二&#xff0c;红黑树的验证三&#xff0c;红黑树的删除待删除的结点只有一个子树删除结点颜色为红色删除结点颜色为黑色 删除的结点为叶子节点…

直流稳压电源与信号产生电路(模电速成)

目录 一、直流稳压电源 1、直流稳压电路 2、串联型稳压电路 3、集成稳压电路 二、信号产生电路 1、振荡电路 2、波形发生器 一、直流稳压电源 1、直流稳压电路 直流电源由 变压器、整流、滤波、稳压 四部分组成 整流&#xff1a;将交流变为直流 滤波&#xff1a;减小…

AI人工智能之科研论文搜索集锦

AI人工智能之科研论文搜索集锦 前言1. Google学术搜索2. Google搜索3. Arxiv#Example&#xff1a; 4. Github#Example&#xff1a; 5. Paperwithcode6. Connectedpapers7. OpenReview 总结 前言 如今越来越多领域都会与计算机、人工智能方面进行跨领域融合&#xff0c;一个万物…