以图搜图,涉及两大功能:1、提取图像特征向量。2、相似向量检索。
第一个功能我通过编写pytorch模型并在java端借助djl调用实现,第二个功能通过elasticsearch7.6.2的dense_vector、cosineSimilarity实现。
一、准备模型
创建demo.py,输入代码,借助resnet提取图像特征
"""Export an image feature extractor to TorchScript (``model.pt``) for DJL.

A ResNet-50 backbone whose classifier is replaced by a linear projection maps
an RGB image tensor of shape (batch, 3, 224, 224) to a ``FEATURE_DIM``-long
vector. The Elasticsearch ``dense_vector`` mapping must use the same ``dims``
(1024), and every vector that is indexed or queried must come from the same
exported ``model.pt`` file.
"""
import torch
import torch.nn as nn
import torchvision.models as models

# Length of the feature vector; the Elasticsearch mapping's "dims" must match.
FEATURE_DIM = 1024
# Side length of the square example input used for tracing.
IMAGE_SIZE = 224
# Fixed seed so re-running the export reproduces the same projection layer.
SEED = 0


def _pretrained_resnet50():
    """Return an ImageNet-pretrained ResNet-50 on both old and new torchvision."""
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is None:
        # Older torchvision only understands the boolean flag.
        return models.resnet50(pretrained=True)
    # Newer torchvision deprecates ``pretrained=True``; IMAGENET1K_V1 is the
    # checkpoint that flag selected.
    return models.resnet50(weights=weights_enum.IMAGENET1K_V1)


class ImageFeatureExtractor(nn.Module):
    """ResNet-50 whose final fc layer outputs ``feature_dim`` features."""

    def __init__(self, feature_dim=FEATURE_DIM):
        super().__init__()
        self.resnet = _pretrained_resnet50()
        # Replace the 1000-class ImageNet head with a projection of the pooled
        # features. This layer is randomly initialised and never trained, so
        # the exported file (not this code) defines the embedding space.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, feature_dim)

    def forward(self, x):
        """Map an image batch to a (batch, feature_dim) feature tensor."""
        return self.resnet(x)


if __name__ == "__main__":
    torch.manual_seed(SEED)
    model = ImageFeatureExtractor()
    model.eval()
    # Any tensor of the right shape works: tracing only records the ops run.
    example_input = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE)
    # Inference-only export: no autograd graph is needed while tracing.
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)
    traced.save("model.pt")
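将保存好的model.pt文件放入Java项目的resources目录中,之后在Java中引入Deep Java Library(DJL)来加载调用。注意:导出模型时使用的PyTorch版本应与Java端DJL所加载的libtorch版本(下文依赖中为1.10.0)兼容,用较新版本PyTorch导出的模型可能无法被旧版libtorch加载。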
二、创建Java项目
创建项目,引入djl和elasticsearch的依赖
<!-- DJL PyTorch engine: lets Java load the TorchScript model.pt exported by demo.py. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.19.0</version>
</dependency>
<!-- Native libtorch binaries (CPU build), needed only at runtime.
     NOTE(review): this libtorch (1.10.0) has to read the TorchScript file, so the PyTorch used to run demo.py
     should not be much newer than it; confirm the two versions are compatible.
     NOTE(review): DJL's docs show this artifact with a platform <classifier> (e.g. win-x86_64); confirm whether one is required here. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>1.10.0</version>
<scope>runtime</scope>
</dependency>
<!-- PyTorch JNI bridge; its version appears to be "libtorch version"-"DJL version" and should agree with the two entries above. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>1.10.0-0.19.0</version>
</dependency>
<!-- Elasticsearch high-level REST client, kept at the same version (7.6.2) as the Elasticsearch server used here. -->
<dependency>
<groupId>org.elasticsearch.client</groupId>
<artifactId>elasticsearch-rest-high-level-client</artifactId>
<version>7.6.2</version>
</dependency>
然后随便从网上下载一些图片,比如猫、狗各5张,放到项目的"resources/随便"路径下,稍后要提取它们的特征向量并上传至Elasticsearch。
三、es创建索引
需要先在Elasticsearch中创建一个新索引isi(img search img),即localhost:9200/isi。其中vector字段的dims必须与模型输出的特征维度(1024)一致:
PUT /isi
{
"mappings": {
"properties": {
"vector": {
"type": "dense_vector",
"dims": 1024
},
"url" : {
"type" : "keyword"
},
"user_id": {
"type": "keyword"
}
}
}
}
完成下文的上传操作后,可以这样测试搜索(params中的queryVector是任选一张图像提取出的1024维特征向量):
POST /isi/_search
{
"query": {
"function_score": {
"query": {
"match_all": {}
},
"script_score": {
"script": {
"source": "cosineSimilarity(params.queryVector, 'vector') + 1.0",
"params": {
"queryVector": [-0.21950562, 0.0979692, 0.30605257, -0.04246464, 0.3086218, 0.2133326, -0.13531154, 0.16382562, 0.2505685, 0.35654455, 0.50346404, -0.2031727, -0.4501943, 0.23117387, 0.39451313, 0.044487886, -0.11032343, 0.47252116, 0.24667346, -0.2052311, -0.10872754, 0.22328046, 0.13366169, -0.5555884, 0.23139203, 0.024292288, 0.3071902, 0.23381571, -0.14484097, -0.80570614, 0.096950606, -0.034106746, 0.3221968, 0.35980088, -0.24408965, 0.10010342, 0.34878045, 0.25403115, 0.8813986, -0.23978959, -0.101492174, -0.34241566, -0.258092, 0.38593173, 0.24993907, -0.6891467, 0.5723483, -0.4987241, -0.46613082, 0.07435644, -0.32876882, 0.1923833, 0.41619772, 0.006919967, -0.35519657, -0.2463252, -0.07216969, -0.10412077, -0.3964988, -0.43174505, 0.6576338, -0.09753291, 0.058324523, -0.366405, -0.08003934, -0.41232625, -0.59834087, 0.35432702, -0.33971205, -0.695481, -0.38738084, -0.08746443, 0.37581405, 0.5092232, 0.26168102, 0.33873072, 0.3769325, 0.5525994, -0.018578911, -0.16984223, 0.24996795, -0.33088574, 0.22646378, 0.28422385, -0.4198824, 0.23480973, -0.17118981, 0.21741581, -0.21377188, -0.21778852, 0.052144438, 0.6118544, 0.29015478, 0.38010067, -0.2526567, 0.31930774, 0.39023396, -0.18484715, 0.25706646, 0.5203727, 0.29022205, -0.21464778, 0.47365767, 0.87486994, 0.44488958, 0.35022217, 0.27183002, -0.10723972, 0.22553404, -0.15306596, 0.22945945, -0.3670853, 0.19239302, -0.44274554, -0.57217056, 0.03906954, 0.7685065, -0.9496267, 0.4024507, -0.13379225, 0.011526011, -0.09900194, -0.16814715, 0.13582, -0.30829066, -0.13147047, 0.2822718, -0.1368126, -0.027293338, 0.49085197, -0.5331921, 0.49128994, 0.07755277, 0.0067159105, -0.42123628, 0.4408007, 0.58957845, -0.611145, -0.20726573, -0.14454971, 0.66820395, -0.007127879, 0.39164197, 0.69164586, 0.0024765078, -0.002838524, -0.6342908, 0.082881235, 0.298888, 0.5127087, -0.1300959, 0.16589926, 0.14517388, 0.2331577, -0.79271543, 0.20724288, -0.08834252, 0.2896087, 0.45591825, 0.0028544534, -0.51650685, 0.40780628, 
-0.2327805, 0.36442646, -0.6027139, -0.13944842, 0.24956803, -0.024805166, 0.24770494, -0.45614043, 0.051073316, 0.16322246, 0.28946966, -0.74529195, -0.4670576, -0.0660522, 0.2550549, -0.08014119, 0.06633573, -0.24153815, 0.36240074, 0.6568509, 0.1186171, 0.24172828, -0.47089025, -0.03256646, 0.14894187, 0.37634763, -0.24351446, 0.16164464, -0.06065242, 0.5091901, 0.5561973, -0.12233099, 0.13057524, -0.6718906, 0.35523456, 0.07152326, -0.20505619, -0.72078145, 0.044493422, -0.7440514, 0.032849424, 0.20111637, 0.43541732, 0.57285374, 0.11018203, -0.28251386, 0.2966433, -0.23974887, 0.068521775, 0.108346015, 0.30997896, 0.1311228, 0.22840464, -0.027650226, 0.73658687, -0.42126977, -0.10911406, -0.03391467, 0.035138264, -0.37075225, -0.5644764, 0.18124935, -0.16772854, 0.19432716, -0.22632165, 0.3900388, -0.69183433, 0.3003796, -0.3586413, 0.53591347, 0.6152693, -0.0038650539, -0.21756373, -0.31812942, -0.5402912, 0.2289098, -0.26077187, -0.2690417, -0.23407489, -0.3686389, -0.40951043, 0.18680233, 0.09277787, -0.013292864, -0.10835711, -0.53350085, 0.14608094, -0.15405877, -0.4699281, 0.10309839, 0.270473, -0.53506243, 0.001998501, -0.24999668, -0.1813581, -0.005529306, -0.56031805, -0.048346244, 0.16996299, -0.29963455, 0.20901474, -0.30934745, -0.21063489, 0.113431595, 0.032042094, -0.41378844, 0.17604393, -0.4631637, -0.5003293, -0.55831015, -0.15353276, -0.46837738, -0.77764505, 0.4621635, 0.28037566, 0.044108637, -0.11912877, -0.082292914, 0.6248694, -0.25019014, 0.2364985, 0.37894827, -0.1412379, -0.4572027, 0.08775234, -0.1661222, -0.09539573, 0.4805072, -0.44165152, 0.092903554, 0.28206846, 0.19792132, -0.05991053, -0.1682668, -0.5983673, 0.29130918, 0.17447682, -0.13979116, -0.5393585, 0.16808756, 0.7923586, 0.09991479, -0.22254673, -0.04165893, -0.2054404, -0.23928185, -0.36898398, -0.25248212, 0.49941728, 0.41065186, -0.23073834, -0.23773289, -0.1929644, -0.3499782, 0.1223462, -0.06807637, 0.175438, 0.14037246, 0.04750026, 0.31662372, 0.9243611, 
0.12812611, -0.069829896, -0.29231697, -0.14734478, 0.1471743, -0.22162598, 0.37801683, 0.078523755, 0.022035534, 0.64812034, -0.23139268, -0.24911498, -0.33378544, -0.07344267, 0.11432794, -0.059079893, 0.31315947, -0.48160297, 0.045891166, 0.09479678, 0.4726333, -0.039052956, -0.2787302, -0.11797555, 0.25318092, -0.27943715, -0.22414759, -0.5546054, -0.106284395, 0.028235137, -0.3618798, -0.3404342, 0.31008774, -0.062293224, 0.053846367, 0.057488017, 0.09902769, 0.70674986, 0.37460673, 0.032727424, -0.4898543, -0.40673503, -0.19604088, 0.4853623, -0.20603043, -0.067109674, -0.53158236, -0.09847969, -0.14446007, -0.15639015, 0.13992839, -0.11348952, 0.15039599, 0.18439567, -0.20131497, 0.20396306, -0.3851034, -0.42031956, 0.2908249, 0.38035524, 0.15540914, 0.030931382, 0.33113614, 0.2741093, 0.18215193, -0.017978923, -0.0023425352, 0.27586395, -0.48403275, 0.023819037, -0.31364787, 0.21789135, -0.3554474, 0.0189421, 0.85861605, 0.15560226, 0.014263891, -0.16498215, -0.39344954, 0.5993788, 0.2708438, -0.29564035, 0.029406447, -0.43017676, 0.057052203, -0.04981024, -0.27520806, -0.51138026, 0.5926964, -0.40741482, 0.08575866, -0.196537, 0.3554017, -0.14750098, 0.051478647, -0.44164056, 0.13783655, 0.697562, 0.069351025, -0.14384158, 0.16349174, 0.36234668, -0.29366237, -0.28419775, -0.22636151, -0.122755915, -0.08138535, -0.7832234, 0.11401084, 0.43588766, 0.3236169, 0.14211948, 0.11028929, 0.2964773, -0.35562044, -0.3229665, -0.12864646, -0.3883256, 0.18198651, -0.45499957, 0.6918359, 0.1301559, -0.19978304, -0.1603161, -0.21330322, 0.07482076, 0.19179785, 0.45639312, 0.010576941, -0.3680949, -0.67871827, 0.14103784, 0.26047683, 0.64846706, -0.6718977, -0.5179457, 0.5580428, -0.48272127, 0.09030259, -0.47150746, 0.534373, 0.20664622, 0.5013874, -0.20477112, 0.22002026, -0.042670928, -0.047632568, -0.14199638, -0.36322978, 0.14286354, 0.35466686, 0.31752202, -0.3477305, -0.0045454763, -0.066675276, -0.2702982, 0.21498637, 0.08594364, 0.23323308, -0.6374196, 
0.36372712, -0.30689493, -0.15897107, -0.16212063, -0.3653109, -0.40084177, 0.050653785, 0.13755074, -0.11666774, -0.2285859, -0.037478417, 0.20095918, 0.3487094, -0.08746929, 0.61240536, -0.26032016, 0.4687963, 0.14449233, 0.47144917, 0.039736982, 0.079499036, 0.19588508, -0.017659996, 0.5632192, -0.18003824, -0.040397547, -0.36472237, -0.025306623, 0.437257, -0.086474344, -0.7352421, 0.13874608, -0.110756345, -0.048157282, 0.19240974, 0.080513775, 0.59602517, -0.34077823, 0.35610113, -0.25455856, -0.36457375, 0.37937617, -0.400827, 0.054261968, 0.2879811, -0.11387855, 0.5244568, -0.107315816, 0.27161226, -0.1021186, -0.18614362, 0.2978657, 0.5370135, 0.36572018, -0.15257043, 0.20518257, 0.6419209, -0.3316342, -0.2725687, 0.014353298, 0.26933295, 0.26414502, 0.12794185, -0.21693292, 0.20204909, -0.21943296, -0.2593051, 0.13386369, -0.58061814, -0.63560903, 0.36805475, 0.016099993, -0.30377084, -0.06283789, -0.2682347, 0.056822743, -0.40932798, 0.36230356, -0.0491954, -0.029950788, -0.014880693, 0.20313863, 0.2214678, 0.2732247, -0.20896281, 0.017126573, -0.122736566, 0.06616182, 0.024874818, -0.7285719, -0.08402412, 0.28992975, 0.6755032, -0.071360886, 0.08794832, 0.09367153, 0.2191656, -0.28197074, -0.060419604, 0.008227522, -0.77392685, -0.37148783, 0.028109314, 0.6432903, 0.107398115, -0.031399835, -0.17004932, 0.34670952, -0.2593519, 0.3303603, 0.05096604, -0.59591603, -0.28604108, -0.37413853, -0.50223345, -0.06456756, 0.30097467, -0.344833, -0.44159552, 0.1301419, 0.4851787, -0.48891386, 0.29695496, 0.04624813, 0.44581613, -0.342841, 0.030060228, 0.41584003, -0.06151448, -0.6391304, 0.24350402, 0.2373283, -0.21993239, -0.15819496, 0.33998314, -0.104932904, 0.53482264, 0.18774611, -0.30718842, 0.09050197, 0.06582601, -0.8714315, 0.2866774, 0.10656398, 0.108911455, 0.12436204, 0.6072432, -0.263783, 0.3477571, -0.21450447, -0.20965956, 0.2725455, -0.15962526, -0.023959063, -0.16272986, 0.37898353, 0.1893706, -0.35078412, 0.018863793, 0.19243363, 0.7553659, 
0.3897343, 0.16990745, -0.12922706, 0.32337534, -0.07977969, 0.09969508, 0.12787843, 0.14316258, -0.38789797, -0.18665363, 0.41474488, -0.04381171, 0.47398177, 0.20612329, -0.13801742, 0.48971528, 0.15693656, 0.10821125, -0.10725921, -0.20428485, -0.84528387, 0.18022658, 0.50938493, -0.32568434, 0.30802926, 0.33309558, 0.1919713, 0.20726888, -0.16194591, 0.17306438, -0.15405764, -0.57394016, 0.6925947, -0.31852844, -0.07849608, -0.5219136, -0.12416126, -0.20998093, 0.6197391, -0.11049731, -0.07111119, 0.22960934, -0.15123159, 0.22498849, -0.07220747, -0.28159276, -0.16107027, -0.010942766, 0.5636157, 0.4077794, 0.39782813, 0.456499, 0.22233048, 0.56309587, 0.26227084, -0.18100007, 0.06122207, 0.27089763, 0.17011975, -0.42344883, -0.063430965, 0.070528686, -0.046008512, -0.29042992, -0.07066448, -0.2578915, -0.27239347, 0.2880362, -0.056104008, -0.40367386, -0.091103815, 0.46031728, -0.36084417, -0.01598189, 0.19975084, 0.01695741, 0.3267317, 0.22532314, -0.55215025, -0.098993674, 0.36677533, 0.44303438, 0.3397658, -0.42336193, 0.002683131, 0.17797257, 0.6305417, 0.54148203, 0.17323923, 0.11428201, -0.07747766, -0.11240339, 0.11639454, 0.05241075, 0.035248175, -0.57705295, 0.45263726, -0.35879546, -0.7651455, -0.03033166, 0.47368425, -0.02433325, -0.15444314, -0.27954623, 0.30544212, 0.19804852, -0.66339266, -0.018637381, -0.3836641, 0.10387643, -0.23915236, 0.097831056, -0.18519881, 0.42123106, -0.0021492783, -0.4928366, 0.051339585, -0.50189865, -0.0325974, 0.03475754, 0.24877562, -0.50540763, 0.14656179, -0.033425312, -0.2698435, 0.1414198, 0.015859405, 0.4277053, -0.040847912, 0.032052774, 0.39479595, -0.0018053818, -0.37721512, -0.027027369, 0.44188333, 0.18346275, 0.6159405, -0.0010263352, 0.120682925, -0.5515572, 0.4246414, 0.37855124, 0.31135443, 0.255429, 0.010005429, -0.8138245, -0.26479146, -0.34098482, 0.14558652, 0.63190436, 0.1779253, 0.43572387, 0.6876498, 0.06974258, 0.007930072, -0.09172004, 0.18957798, -0.16211304, 0.18704513, -0.25963065, 
-0.26715553, -0.22632961, -0.3099424, 0.3464097, -0.12967771, 0.16652606, 0.2921636, 0.09758349, 0.2582998, 0.11978268, 0.42495492, 0.02736637, -0.32260302, -0.3379873, -0.23938976, -0.19942743, 0.30798694, 0.25228044, -0.033107795, 0.09772943, 0.38394168, 0.7219979, -0.5064522, -0.21723904, -0.2033075, 0.020857109, 0.13053142, -0.38791847, 0.4991684, 0.20062184, 0.49477854, -0.26213312, -0.61973774, -0.074013926, -0.12128413, -0.56617993, 0.13392372, 0.73387975, 0.5033897, 0.33373255, -0.06803796, 0.5550287, 0.26606622, -0.35267583, -0.23695293, 0.26170373, -0.12340009, 0.80251247, -0.70798254, -0.028666062, 0.6997679, -0.05996991, 0.06898104, -0.14557816, 0.054661553, 0.5187798, 0.41702572, 0.792891, -0.17265478, -0.06679568, -0.331478, 0.0694997, 0.4253223, -0.2783028, 0.23903547, -0.58266413, 0.09287575, 0.045140624, -0.10417832, 0.08257238, 0.48208177, -0.24164109, 0.81102467, -0.40342188, 0.65527093, -0.12488523, 0.078327045, -0.5329088, 0.37736076, 0.2925939, 0.20142855, 0.21402623, -0.21197478, -0.31154165, 0.45887777, 0.205758, 0.12233909, -0.26103553, 1.1294454, -0.7648704, 0.32436037, -0.06368509, 0.57072765, 0.9322751, -0.29020756, 0.44769418, -0.839836, 0.07865648, -0.2559945, 0.4581841, -0.017776983, 0.18255703, 0.2528128, -0.41778934, -0.071126916, -0.041809052, -0.53156054, 0.16023451, -0.2608511, 0.0725673, -0.15921246, -0.03191948, -0.366381, 0.53149635, -0.2550226, -0.022553788, -0.36375383, 0.40580854, -0.076502666, -0.04272891, -0.28619775, -0.26721123, 0.56044143, -0.040593743, -0.28715926, -0.0043915436, 0.11840753, -0.35239887, -0.30920973, -0.14502974, -0.36411104, 0.44530326, 0.43969297, -0.23792548, 0.30757633, 0.26880985, 0.18359815, -0.5675205, 0.19222523, 0.22303401, 0.21661428, 0.22141027, -0.10556421, 0.11646886, 0.17539617, 0.96604997, 0.055217523, -0.7456562, -0.106842384, -0.19286019, 0.17667075, 0.92509866, 0.57278365, -0.024029609, -0.8224203, -0.12689532, 0.079639494, -0.06534128, 0.7061269, -0.09063532, 0.5011331, -0.5051317, 
-0.054662913, -0.26086497, -0.53341925, 0.9624672, 0.08449669, -0.21910548, 0.36410314, -0.24794322, 0.16658492, 0.7944018, -0.058724128, -0.22618303, 0.5062074, -0.516353, 0.69395834, -0.23764399, -0.13169304, 0.51044196, -0.042955525, -0.42410484, -0.4293069, 0.13401544, 0.80136365, 0.30296534, -0.06788176, 0.16880289, 0.27950272, -0.37403736, 0.11813866, -0.41821468, 0.0033562258, -0.53348655, -0.22950119, 0.3889678, 0.10558852, -0.25912097, -0.03190498, 0.028149713, 0.36284888, -0.63619995, 0.8380439, 0.6589971, 0.6046954, -0.2093836, 0.08808039, 0.48332697, -0.010615652, -0.40519536, 0.011716956, 0.096273005, -0.27340046, -0.19237258, -0.2970637, -0.44011658, 0.17786184, 0.0071578454, 0.23985118, -0.040508576]
}
}
}
}
},
"_source": ["url"],
"size": 100
}
四、调用pytorch模型代码
创建Test类,copy一下我的,感兴趣可以去djl的官网学习更多内容。写完后,就可以获取“随便”文件夹中的图像的特征向量,上传到es里了。
package org.gwen;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.gwen.entity.SearchResult;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;
/**
 * Holds the shared DJL model and predictor that turn an image into a feature vector,
 * plus the Elasticsearch index name the vectors are stored in.
 */
public class Test {
    /** Elasticsearch index holding the image vectors. */
    private static final String INDEX = "isi";
    /** Square side length images are resized to; matches the 224x224 input demo.py traced with. */
    private static final int IMAGE_SIZE = 224;
    /** TorchScript model (model.pt), loaded once from the classpath. */
    private static Model model;
    /**
     * Image to feature vector; predictor.predict(image) is the Java counterpart of model(input) in Python.
     * NOTE(review): this single Predictor is shared by every caller (e.g. concurrent HTTP requests);
     * confirm DJL's thread-safety guarantees for Predictor before relying on that.
     */
    private static Predictor<Image, float[]> predictor;

    static {
        try {
            model = Model.newInstance("model");
            // model.pt is the file produced by torch.jit.trace(...).save("model.pt") in demo.py.
            InputStream modelStream = Test.class.getClassLoader().getResourceAsStream("model.pt");
            if (modelStream == null) {
                throw new IllegalStateException("model.pt not found on the classpath; put it in src/main/resources");
            }
            try (InputStream in = modelStream) {
                model.load(in);
            }

            Transform resize = new Resize(IMAGE_SIZE);
            Transform toTensor = new ToTensor();
            // ImageNet per-channel mean / std, the normalisation the pretrained ResNet was trained with.
            Transform normalize = new Normalize(
                    new float[]{0.485f, 0.456f, 0.406f},
                    new float[]{0.229f, 0.224f, 0.225f});

            // Input: Image -> resized, tensorised, normalised NDArray. Output: first tensor -> float[].
            Translator<Image, float[]> translator = new Translator<Image, float[]>() {
                @Override
                public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
                    NDArray pixels = input.toNDArray(ctx.getNDManager());
                    NDArray tensor = normalize.transform(toTensor.transform(resize.transform(pixels)));
                    return new NDList(tensor);
                }

                @Override
                public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
                    return ndList.get(0).toFloatArray();
                }
            };
            predictor = new Predictor<>(model, translator, Device.cpu(), true);
        } catch (Exception e) {
            // Surface load failures immediately instead of leaving predictor null,
            // which would otherwise only show up later as a NullPointerException.
            throw new ExceptionInInitializerError(e);
        }
    }
}
五、es上传和搜索
上传:遍历每张图片,获取每张图片的特征,上传到es
搜索:获取输入图像的特征,创建SearchRequest在es中通过painless脚本进行余弦相似度对比检索。
首先创建SearchResult类表示es搜索的结果,包括图像url和相关度评分score
/**
 * One hit returned by Test.search: the image location stored in Elasticsearch plus the hit's score.
 * Lombok's @Data generates getters/setters, equals/hashCode and toString;
 * @AllArgsConstructor generates the (url, score) constructor used when building results.
 */
@Data
@AllArgsConstructor
public class SearchResult {
// Value of the document's "url" field in Elasticsearch (the image path recorded at upload time).
private String url;
// Score reported for this hit by the search.
private Float score;
}
然后在Test里实现upload和search
public static void upload() throws Exception {
RestHighLevelClient client = new RestHighLevelClient(
RestClient.builder(new HttpHost("localhost", 9200, "http")));
//批量上传请求
BulkRequest bulkRequest = new BulkRequest(INDEX);
File file = new File("C:\\Users\\Administrator\\IdeaProjects\\img_search_img\\src\\main\\resources\\随便");
for (File listFile : file.listFiles()) {
float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(Test.class.getClassLoader().getResourceAsStream("随便/" + listFile.getName())));
// 构建文档
Map<String, Object> jsonMap = new HashMap<>();
jsonMap.put("url", listFile.getAbsolutePath());
jsonMap.put("vector", vector);
jsonMap.put("user_id", "user123");
IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);
bulkRequest.add(request);
}
client.bulk(bulkRequest, RequestOptions.DEFAULT);
client.close();
}
/**
 * Returns up to 100 indexed images most similar to the given image, best first.
 *
 * @param input stream of the query image (not closed by this method)
 * @return hits with their stored url and cosine similarity in [-1, 1]
 * @throws Throwable if inference or the Elasticsearch query fails
 */
public static List<SearchResult> search(InputStream input) throws Throwable {
    return search(input, 100);
}

/**
 * Returns up to {@code k} indexed images most similar to the given image, best first,
 * ranked by cosine similarity between feature vectors.
 *
 * @param input stream of the query image (not closed by this method)
 * @param k     maximum number of hits to return; must be positive
 * @return hits with their stored url and cosine similarity in [-1, 1]
 * @throws IllegalArgumentException if {@code k <= 0}
 * @throws Throwable                if inference or the Elasticsearch query fails
 */
public static List<SearchResult> search(InputStream input, int k) throws Throwable {
    if (k <= 0) {
        throw new IllegalArgumentException("k must be positive: " + k);
    }
    float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(input));

    // cosineSimilarity lies in [-1, 1], but Elasticsearch 7 rejects negative function_score
    // scores, so shift by +1.0 here and undo the shift when reading the hits.
    // Since 7.6 the vector field is referenced by name; doc['vector'] is deprecated.
    Script script = new Script(
            ScriptType.INLINE,
            "painless",
            "cosineSimilarity(params.queryVector, 'vector') + 1.0",
            Collections.singletonMap("queryVector", vector));
    FunctionScoreQueryBuilder query = QueryBuilders.functionScoreQuery(
            QueryBuilders.matchAllQuery(),
            ScoreFunctionBuilders.scriptFunction(script));
    SearchSourceBuilder source = new SearchSourceBuilder()
            .query(query)
            // Only "url" is needed; skipping the 1024-float vector keeps responses small.
            .fetchSource(new String[]{"url"}, null)
            .size(k);
    SearchRequest searchRequest = new SearchRequest(INDEX).source(source);

    // try-with-resources closes the client even when the search throws.
    try (RestHighLevelClient client = new RestHighLevelClient(
            RestClient.builder(new HttpHost("localhost", 9200, "http")))) {
        SearchResponse response = client.search(searchRequest, RequestOptions.DEFAULT);
        List<SearchResult> results = new ArrayList<>();
        for (SearchHit hit : response.getHits()) {
            String url = (String) hit.getSourceAsMap().get("url");
            // match_all scores 1.0 and boost_mode defaults to multiply, so the hit score is cos + 1.
            results.add(new SearchResult(url, hit.getScore() - 1.0f));
        }
        return results;
    }
}
六、测试
/**
 * HTTP entry point for search-by-image: accepts an uploaded image and returns similar indexed images.
 */
@RestController
@CrossOrigin
public class SearchController {

    /**
     * POST /search with a multipart field named "file".
     *
     * @param file the uploaded query image
     * @return 200 with the search results, 400 if no (or an empty) file was uploaded,
     *         500 if feature extraction or the Elasticsearch query failed
     */
    @PostMapping("search")
    public ResponseEntity<List<SearchResult>> search(MultipartFile file) {
        if (file == null || file.isEmpty()) {
            return ResponseEntity.badRequest().build();
        }
        // try-with-resources closes the upload stream once the search is done.
        try (java.io.InputStream input = file.getInputStream()) {
            return ResponseEntity.ok(Test.search(input));
        } catch (Throwable e) {
            // Server-side failure: record it instead of swallowing it, and report 500 rather than 400.
            e.printStackTrace();
            return ResponseEntity.status(500).build();
        }
    }
}
页面:
<!DOCTYPE html>
<html lang="en">
<!-- Search-by-image demo page: choose an image on the left; similar images and their scores are listed on the right. -->
<head>
<meta charset="UTF-8">
<title>Title</title>
<style>
/* Full-page background image that stays fixed while scrolling. */
body {
background: url("/img/bg.jpg");
background-attachment: fixed;
background-size: 100% 100%;
}
/* Centered translucent card that holds the upload panel and the result panel. */
body > div {
width: 1000px;
margin: 50px auto;
padding: 10px 20px;
border: 1px solid lightgray;
border-radius: 20px;
box-sizing: border-box;
background: rgba(255, 255, 255, 0.7);
}
/* Left panel: image preview plus file input. */
.upload {
display: inline-block;
width: 300px;
height: 280px;
border: 1px dashed lightcoral;
vertical-align: top;
}
/* Square preview frame; the placeholder text is vertically centred via line-height. */
.upload .cover {
width: 200px;
height: 200px;
margin: 10px 50px;
border: 1px solid black;
box-sizing: border-box;
text-align: center;
line-height: 200px;
position: relative;
}
/* Preview image drawn on top of the placeholder text inside the frame. */
.upload img {
width: 198px;
height: 198px;
position: absolute;
left:0;
top: 0;
}
.upload input {
margin-left: 50px;
}
/* NOTE(review): the markup below contains no <button>, so this rule currently matches nothing. */
.upload button {
width: 80px;
height: 30px;
margin-left: 110px;
}
/* Right panel where the search results (or status headings) are rendered. */
.result-block {
display: inline-block;
margin-left: 40px;
border: 1px solid lightgray;
border-radius: 10px;
min-height: 500px;
width: 600px;
}
.result-block h1 {
text-align: center;
margin-top: 100px;
}
/* One result tile: thumbnail plus score. */
.result {
padding: 10px;
cursor: pointer;
display: inline-block;
}
.result:hover {
background: rgb(240, 240, 240);
}
/* Truncate long text in a tile with an ellipsis. */
.result p {
width: 110px;
overflow: hidden;
white-space: nowrap;
text-overflow: ellipsis;
}
.result img {
width: 160px;
height: 160px;
}
/* Score text colour. */
.result .prob {
color: rgb(37, 147, 60)
}
</style>
<script src="js/jquery-3.6.0.js"></script>
</head>
<body>
<div>
<div class="upload">
<div class="cover">
请选择图片
<!-- Preview image; its src is set by script after a file is chosen. -->
<img id="image" src=""/>
</div>
<input id="file" type="file">
</div>
<div class="result-block">
<h1>请选择图片</h1>
</div>
</div>
<!-- NOTE(review): #box is not referenced by the page script; looks unused. -->
<ul id="box">
</ul>
<script>
    // File extensions accepted for upload (compared case-insensitively).
    var ALLOWED_EXTS = ['.png', '.jpg', '.jpeg']
    // blob: URL of the current preview, revoked when a new image is chosen.
    var previewUrl = null

    $('#file').change(function () {
        let f = this.files[0]
        if (!f) {
            // The user cancelled the file dialog.
            return
        }
        let ext = f.name.substring(f.name.lastIndexOf('.')).toLowerCase()
        if (ALLOWED_EXTS.indexOf(ext) < 0) {
            alert('系统仅支持 JPG、PNG、JPEG 格式的图片,请您调整格式后重新上传')
            return
        }

        let resultBlock = $('.result-block')
        resultBlock.empty().append($('<h1>正在识别中...</h1>'))

        let urlApi = window.URL || window.webkitURL
        if (previewUrl && urlApi) {
            urlApi.revokeObjectURL(previewUrl)
        }
        previewUrl = getObjectURL(f)
        $('#image').attr('src', previewUrl)

        let formData = new FormData()
        formData.append('file', f)
        $.ajax({
            url: 'http://localhost:8080/search',
            method: 'post',
            data: formData,
            processData: false,
            contentType: false,
            success: res => {
                resultBlock.empty()
                if (!res || res.length === 0) {
                    resultBlock.append($('<h1>').text('没有找到相似图片'))
                    return
                }
                for (let item of res) {
                    // Build nodes with attr()/text() instead of an HTML string so the
                    // server-provided path cannot inject markup.
                    // NOTE(review): browsers usually refuse file:/// images on an http page;
                    // serving the images over HTTP would be more reliable.
                    let img = $('<img>').attr('src', 'file:///' + item.url)
                    let score = $('<p class="prob">').text('得分:' + item.score.toFixed(4))
                    let info = $('<div style="display: inline-block;vertical-align: top">').append(score)
                    resultBlock.append($('<div class="result">').append(img, info))
                }
            },
            error: () => {
                // Without this the panel would stay on "正在识别中..." forever.
                resultBlock.empty().append($('<h1>').text('识别失败,请稍后重试'))
            }
        })
    })

    // Return a temporary blob: URL for previewing the chosen file, or null if unsupported.
    function getObjectURL(file) {
        var urlApi = window.URL || window.webkitURL
        return urlApi && urlApi.createObjectURL ? urlApi.createObjectURL(file) : null
    }

    // Placeholder kept from the original page; currently not called.
    function detect() {
    }
</script>
</body>
</html>