首先在ollama中安装mofanke/dmeta-embedding-zh:latest。执行ollama run mofanke/dmeta-embedding-zh 。实现将文本转化为向量数据
接着安装pgvector(建议使用pgAdmin 4作为可视化工具,使用Navicat可能会出现表不显示的问题)
安装好需要的软件后我们开始编码操作。
1:在pom文件中加入:
<!--用于连接pgsql-->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-jdbc</artifactId>
</dependency>
<!--用于使用pgvector来操作向量数据库-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
</dependency>
<!--pdf解析-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-pdf-document-reader</artifactId>
</dependency>
<!--文档解析-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-tika-document-reader</artifactId>
</dependency>
2:在yml中配置:
spring:
  datasource:
    url: jdbc:postgresql://127.0.0.1:5432/postgres
    username: postgres
    password: password
  ai:
    vectorstore:
      pgvector:
        dimensions: 768 #不同的embeddingmodel对应的值
    ollama:
      base-url: http://127.0.0.1:11434
      chat:
        enabled: true
        options:
          model: qwen2:7b
      embedding:
        model: mofanke/dmeta-embedding-zh
3:在controller中加入:
/**
 * Parses an uploaded file with Tika, splits the extracted text into smaller chunks and
 * passes the chunks to {@code vector.add(...)}.
 *
 * @param file the uploaded file to embed
 * @return the chunks that were handed to the vector store
 */
@SneakyThrows
@PostMapping("embedding")
public List<Document> embedding(@RequestParam MultipartFile file) {
    // Wrap the upload stream as a Spring resource so the Tika reader can consume it.
    InputStreamResource upload = new InputStreamResource(file.getInputStream());
    TikaDocumentReader reader = new TikaDocumentReader(upload);
    // Split the extracted text into smaller chunks.
    TokenTextSplitter splitter = new TokenTextSplitter();
    List<Document> chunks = splitter.apply(reader.read());
    // Persist the chunks; PgVectorStore.add embeds each document before inserting it.
    vector.add(chunks);
    return chunks;
}
调用上方的接口可以将文档转为向量数据存入到pgvector中
4:请求聊天,先根据聊天内容通过pgvector获取对应的数据,并将结果丢到qwen2模型中进行数据分析并返回结果
/**
 * Builds the prompt sent to the chat model by inserting the question and the retrieved
 * context into a fixed template.
 *
 * @param message the user's question (first placeholder)
 * @param context the text retrieved from the vector store (second placeholder)
 * @return the formatted prompt
 */
private String getChatPrompt2String(String message, String context) {
    final String template = """
            请用仅用以下内容回答"%s" ,输出结果仅在以下内容中,输出内容仅以下内容,不需要其他描述词:
            %s
            """;
    return template.formatted(message, context);
}
/**
 * Answers a question with retrieval-augmented generation: looks up the stored chunks most
 * similar to the question, joins their text, and asks the chat model to answer using only
 * that text.
 *
 * @param message the user's question
 * @return the chat model's reply
 */
@GetMapping("chatToPgVector")
public String chatToPgVector(String message) {
    // Search the same store that embedding() writes to. The earlier two-argument call
    // similaritySearch(message, "test_store") is not part of the VectorStore API and pointed
    // at a different table than the one-argument vector.add(...) used when embedding.
    // NOTE(review): assumes `vector` is a Spring AI VectorStore — confirm against the field.
    List<Document> documents = vector.similaritySearch(message);
    // Concatenate the text of every hit, one chunk per line.
    String content = documents.stream()
            .map(Document::getContent)
            .collect(Collectors.joining("\n"));
    System.out.println(content);
    // Wrap question and context into the prompt and call the chat model.
    return ollamaChatModel.call(getChatPrompt2String(message, content));
    /*
     * Alternative: let QuestionAnswerAdvisor do the retrieval. At runtime it replaces the
     * {question_answer_context} placeholder in the template with the documents found in the
     * vector store, so the final query is the user's question plus the filled-in template.
     *
     * String promptWithContext = """
     *         你是一个代码程序,你需要在文本中获取信息并输出成json格式的数据,下面是上下文信息
     *         ---------------------
     *         {question_answer_context}
     *         ---------------------
     *         给定的上下文和提供的历史信息,而不是事先的知识,回复用户的意见。如果答案不在上下文中,告诉用户你不能回答这个问题。
     *         """;
     * return ChatClient.create(ollamaChatModel).prompt()
     *         .user(message)
     *         .advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults(), promptWithContext))
     *         .call().content();
     */
}
至此一个简单的rag搜索增强demo就完成了。接下来我们来看看PgVectorStore为我们做了什么
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//
package org.springframework.ai.vectorstore;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.IntStream;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.PgVectorFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
/**
 * Decompiled Spring AI vector store backed by PostgreSQL/pgvector through a JdbcTemplate.
 * Every SQL statement in this class targets the literal table name "vector_store".
 */
public class PgVectorStore implements VectorStore, InitializingBean {
private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class);
// Same value as the literal 1536 that embeddingDimensions() falls back to.
public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;
// Same value as the -1 the two-argument constructor passes for "dimensions not configured".
public static final int INVALID_EMBEDDING_DIMENSION = -1;
public static final String VECTOR_TABLE_NAME = "vector_store";
public static final String VECTOR_INDEX_NAME = "spring_ai_vector_index";
public final FilterExpressionConverter filterExpressionConverter;
private final JdbcTemplate jdbcTemplate;
private final EmbeddingModel embeddingModel;
private int dimensions;
private PgDistanceType distanceType;
private ObjectMapper objectMapper;
private boolean removeExistingVectorStoreTable;
private PgIndexType createIndexMethod;
private final boolean initializeSchema;
/** Uses cosine distance, no index, no schema initialization, and dimensions = -1. */
public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
this(jdbcTemplate, embeddingModel, -1, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
}
/** Same defaults as the two-argument constructor, with an explicit embedding dimension. */
public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
this(jdbcTemplate, embeddingModel, dimensions, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
}
/**
 * Full constructor; stores every argument and creates a fresh filter converter and ObjectMapper.
 *
 * @param dimensions embedding size used for the table column; values <= 0 make
 *        embeddingDimensions() ask the model instead
 * @param distanceType distance operator/SQL template used for searches and index creation
 * @param removeExistingVectorStoreTable whether afterPropertiesSet() drops the table first
 * @param createIndexMethod index type to create, or NONE
 * @param initializeSchema whether afterPropertiesSet() runs any DDL at all
 */
public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema) {
this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
this.objectMapper = new ObjectMapper();
this.jdbcTemplate = jdbcTemplate;
this.embeddingModel = embeddingModel;
this.dimensions = dimensions;
this.distanceType = distanceType;
this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
this.createIndexMethod = createIndexMethod;
this.initializeSchema = initializeSchema;
}
public PgDistanceType getDistanceType() {
return this.distanceType;
}
/**
 * Upserts the documents into vector_store in one JDBC batch: rows are inserted by id and, on an
 * id conflict, content, metadata and embedding are overwritten.
 */
public void add(final List<Document> documents) {
final int size = documents.size();
this.jdbcTemplate.batchUpdate("INSERT INTO vector_store (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
public void setValues(PreparedStatement ps, int i) throws SQLException {
Document document = (Document)documents.get(i);
String content = document.getContent();
String json = PgVectorStore.this.toJson(document.getMetadata());
// The embedding model is invoked here, once per document in the batch.
PGvector pGvector = new PGvector(PgVectorStore.this.toFloatArray(PgVectorStore.this.embeddingModel.embed(document)));
// Parameters 1-4 feed the INSERT, 5-7 repeat content/metadata/embedding for the UPDATE branch.
// Integer.MIN_VALUE is presumably Spring's "unknown SQL type" constant inlined by the decompiler.
StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
}
public int getBatchSize() {
return size;
}
});
}
/** Serializes the metadata map to JSON; a Jackson failure is rethrown as RuntimeException. */
private String toJson(Map<String, Object> map) {
try {
return this.objectMapper.writeValueAsString(map);
} catch (JsonProcessingException var3) {
throw new RuntimeException(var3);
}
}
/** Narrows each Double of the embedding to float, preserving order. */
private float[] toFloatArray(List<Double> embeddingDouble) {
float[] embeddingFloat = new float[embeddingDouble.size()];
int i = 0;
Double d;
// Decompiler-shaped loop: the update clause stores the element fetched in the body.
for(Iterator var4 = embeddingDouble.iterator(); var4.hasNext(); embeddingFloat[i++] = d.floatValue()) {
d = (Double)var4.next();
}
return embeddingFloat;
}
/**
 * Deletes rows one id at a time and reports whether the total number of affected rows equals
 * the number of ids passed in.
 */
public Optional<Boolean> delete(List<String> idList) {
int updateCount = 0;
int count;
for(Iterator var3 = idList.iterator(); var3.hasNext(); updateCount += count) {
String id = (String)var3.next();
count = this.jdbcTemplate.update("DELETE FROM vector_store WHERE id = ?", new Object[]{UUID.fromString(id)});
}
return Optional.of(updateCount == idList.size());
}
/**
 * Embeds the request's query and runs the distance type's SQL template against vector_store,
 * optionally narrowed by a metadata jsonpath filter, returning the mapped rows.
 */
public List<Document> similaritySearch(SearchRequest request) {
String nativeFilterExpression = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
String jsonPathFilter = "";
if (StringUtils.hasText(nativeFilterExpression)) {
jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
}
// The similarity threshold is converted to a maximum distance.
double distance = 1.0 - request.getSimilarityThreshold();
PGvector queryEmbedding = this.getQueryEmbedding(request.getQuery());
// Bound values, in template order: embedding (SELECT), embedding (WHERE), max distance, LIMIT.
return this.jdbcTemplate.query(String.format(this.getDistanceType().similaritySearchSqlTemplate, "vector_store", jsonPathFilter), new DocumentRowMapper(this.objectMapper), new Object[]{queryEmbedding, queryEmbedding, distance, request.getTopK()});
}
/** Returns the distance between the embedded query and the embedding of every row in vector_store. */
public List<Double> embeddingDistance(String query) {
return this.jdbcTemplate.query("SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM vector_store", new RowMapper<Double>() {
@Nullable
public Double mapRow(ResultSet rs, int rowNum) throws SQLException {
return rs.getDouble("distance");
}
}, new Object[]{this.getQueryEmbedding(query)});
}
/** Embeds the query text with the embedding model and wraps the result as a PGvector. */
private PGvector getQueryEmbedding(String query) {
List<Double> embedding = this.embeddingModel.embed(query);
return new PGvector(this.toFloatArray(embedding));
}
private String comparisonOperator() {
return this.getDistanceType().operator;
}
/**
 * InitializingBean callback. Does nothing unless initializeSchema is true; otherwise creates the
 * extensions, optionally drops the table, creates vector_store if absent and, when an index type
 * other than NONE is configured, creates the index.
 */
public void afterPropertiesSet() throws Exception {
if (this.initializeSchema) {
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
if (this.removeExistingVectorStoreTable) {
this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
}
// The vector column width comes from embeddingDimensions().
this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
}
}
}
/**
 * Resolves the embedding size: the configured value if positive, else the value reported by the
 * embedding model if positive, else 1536. An exception from the model is logged and ignored.
 */
int embeddingDimensions() {
if (this.dimensions > 0) {
return this.dimensions;
} else {
try {
int embeddingDimensions = this.embeddingModel.dimensions();
if (embeddingDimensions > 0) {
return embeddingDimensions;
}
} catch (Exception var2) {
logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", var2);
}
return 1536;
}
}
/**
 * Distance metrics. Each constant carries its SQL operator, the operator class used when
 * creating an index, and a search template whose %s slots are the table name and an extra filter.
 */
public static enum PgDistanceType {
EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");
public final String operator;
public final String index;
public final String similaritySearchSqlTemplate;
private PgDistanceType(String operator, String index, String sqlTemplate) {
this.operator = operator;
this.index = index;
this.similaritySearchSqlTemplate = sqlTemplate;
}
}
/** Index type placed after USING in CREATE INDEX; NONE skips index creation. */
public static enum PgIndexType {
NONE,
IVFFLAT,
HNSW;
private PgIndexType() {
}
}
/** Maps one search result row to a Document, adding the row's distance to its metadata. */
private static class DocumentRowMapper implements RowMapper<Document> {
// These constants are not referenced below; the decompiled code shows the literals instead.
private static final String COLUMN_EMBEDDING = "embedding";
private static final String COLUMN_METADATA = "metadata";
private static final String COLUMN_ID = "id";
private static final String COLUMN_CONTENT = "content";
private static final String COLUMN_DISTANCE = "distance";
private ObjectMapper objectMapper;
public DocumentRowMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
}
public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
String id = rs.getString("id");
String content = rs.getString("content");
PGobject pgMetadata = (PGobject)rs.getObject("metadata", PGobject.class);
PGobject embedding = (PGobject)rs.getObject("embedding", PGobject.class);
Float distance = rs.getFloat("distance");
Map<String, Object> metadata = this.toMap(pgMetadata);
// The computed distance is exposed to callers through the metadata map.
metadata.put("distance", distance);
Document document = new Document(id, content, metadata);
document.setEmbedding(this.toDoubleList(embedding));
return document;
}
/** Parses the stored vector text and widens each float to a Double. */
private List<Double> toDoubleList(PGobject embedding) throws SQLException {
float[] floatArray = (new PGvector(embedding.getValue())).toArray();
return IntStream.range(0, floatArray.length).mapToDouble((i) -> {
return (double)floatArray[i];
}).boxed().toList();
}
/** Parses the metadata JSON into a map; a Jackson failure is rethrown as RuntimeException. */
private Map<String, Object> toMap(PGobject pgObject) {
String source = pgObject.getValue();
try {
return (Map)this.objectMapper.readValue(source, Map.class);
} catch (JsonProcessingException var4) {
throw new RuntimeException(var4);
}
}
}
}
我们可以看到PgVectorStore实现了InitializingBean并实现了afterPropertiesSet方法。它会在属性设置完成后执行。
/**
 * InitializingBean callback. Does nothing unless initializeSchema is true; otherwise creates the
 * extensions, optionally drops the table, creates vector_store if absent and, when an index type
 * other than NONE is configured, creates the index.
 */
public void afterPropertiesSet() throws Exception {
if (this.initializeSchema) {
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
// Optionally start from a clean table.
if (this.removeExistingVectorStoreTable) {
this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
}
// The vector column width comes from embeddingDimensions().
this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
}
}
}
这里它会根据initializeSchema(在PgVectorStoreProperties中,默认为true,我们可以在yml中配置spring.ai.vectorstore.pgvector.initialize-schema: false来禁用)来判断是否帮我们建表。这里它会帮我们建一个叫vector_store的表,其中包含id(uuid),metadata(json),content(text),embedding(vector(1536))。这里的1536就是dimensions的值:从embeddingDimensions()可以看到,配置了dimensions就用配置值,否则取embedding模型返回的维度,都取不到时才回退到默认的1536。当我们用按1536维建好的表去做pgvector的数据存储时会出现 ERROR: expected 1536 dimensions, not 768这样的报错,就是表示我们ollama中的embedding模型输出的维度是768,而pgvector中embedding字段的维度是1536,它们不匹配所以无法存储。这时我们需要去pgvector中把embedding字段的维度修改为768即可(不同模型返回的dimension值不一样,可以根据报错信息自行调整;也可以在建表之前像上面yml那样配置dimensions: 768)
接下来我们看一下核心的操作方法-向数据库中插入数据
/**
 * Upserts the documents into vector_store in one JDBC batch: rows are inserted by id and, on an
 * id conflict, content, metadata and embedding are overwritten.
 */
public void add(final List<Document> documents) {
final int size = documents.size();
this.jdbcTemplate.batchUpdate("INSERT INTO vector_store (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
public void setValues(PreparedStatement ps, int i) throws SQLException {
Document document = (Document)documents.get(i);
String content = document.getContent();
String json = PgVectorStore.this.toJson(document.getMetadata());
// The embedding model is invoked here, once per document in the batch.
PGvector pGvector = new PGvector(PgVectorStore.this.toFloatArray(PgVectorStore.this.embeddingModel.embed(document)));
// Parameters 1-4 feed the INSERT, 5-7 repeat content/metadata/embedding for the UPDATE branch.
// Integer.MIN_VALUE is presumably Spring's "unknown SQL type" constant inlined by the decompiler.
StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
}
public int getBatchSize() {
return size;
}
});
}
这里因为Spring AI刚出,还不是稳定版,它在代码中直接写死了要操作的表。我们使用PgVectorStore时只能对vector_store这张表进行操作,这在实际应用场景中可能会造成一定的局限性。所以我们可以自己写一个扩展操作类来替换它。如下:
package com.lccloud.tenderdocument.vector;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.PgVectorFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import java.util.stream.IntStream;
/**
 * A table-aware variant of Spring AI's {@code PgVectorStore}: every operation takes the name of
 * the table to work on instead of being fixed to {@code vector_store}.
 *
 * <p>This class performs no schema initialization (it creates no extension, table or index). The
 * schema-related constructor arguments are stored only so that the constructors mirror
 * {@code PgVectorStore}. Each target table must already exist with the columns {@code id},
 * {@code content}, {@code metadata} and {@code embedding} that the SQL below refers to.
 *
 * <p>Table names become part of the SQL text because identifiers cannot be bound as {@code ?}
 * parameters. They are therefore validated against {@link #TABLE_NAME_PATTERN}, and anything
 * else is rejected with {@link IllegalArgumentException}.
 */
public class ExtendPgVectorStore {
    private static final Logger logger = LoggerFactory.getLogger(ExtendPgVectorStore.class);

    /** Unquoted identifier, optionally schema-qualified, e.g. {@code docs} or {@code public.docs}. */
    private static final java.util.regex.Pattern TABLE_NAME_PATTERN =
            java.util.regex.Pattern.compile("[A-Za-z_][A-Za-z0-9_$]*(\\.[A-Za-z_][A-Za-z0-9_$]*)?");

    public final FilterExpressionConverter filterExpressionConverter;
    private final JdbcTemplate jdbcTemplate;
    private final EmbeddingModel embeddingModel;
    private final int dimensions;
    private final PgVectorStore.PgDistanceType distanceType;
    private final ObjectMapper objectMapper;
    // The next three values are kept for constructor parity with PgVectorStore; nothing reads them.
    private final boolean removeExistingVectorStoreTable;
    private final PgVectorStore.PgIndexType createIndexMethod;
    private final boolean initializeSchema;

    /** Uses cosine distance, no index and dimensions = -1 (resolved from the model when needed). */
    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        this(jdbcTemplate, embeddingModel, -1, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    /** Same defaults as the two-argument constructor, with an explicit embedding dimension. */
    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
        this(jdbcTemplate, embeddingModel, dimensions, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    /**
     * Full constructor.
     *
     * @param jdbcTemplate template used for every statement
     * @param embeddingModel model used to embed documents and queries
     * @param dimensions embedding size; values <= 0 make {@link #embeddingDimensions()} ask the model
     * @param distanceType distance operator and SQL template used for searches
     * @param removeExistingVectorStoreTable stored but unused (no schema initialization here)
     * @param createIndexMethod stored but unused (no schema initialization here)
     * @param initializeSchema stored but unused (no schema initialization here)
     */
    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgVectorStore.PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgVectorStore.PgIndexType createIndexMethod, boolean initializeSchema) {
        this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
        this.objectMapper = new ObjectMapper();
        this.jdbcTemplate = jdbcTemplate;
        this.embeddingModel = embeddingModel;
        this.dimensions = dimensions;
        this.distanceType = distanceType;
        this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
        this.createIndexMethod = createIndexMethod;
        this.initializeSchema = initializeSchema;
    }

    public PgVectorStore.PgDistanceType getDistanceType() {
        return this.distanceType;
    }

    /**
     * Upserts the documents into {@code tableName} in one JDBC batch: rows are inserted by id and,
     * on an id conflict, content, metadata and embedding are overwritten.
     *
     * @param documents documents to embed and store; an empty list is a no-op
     * @param tableName target table
     * @throws IllegalArgumentException if {@code tableName} is not a valid identifier
     */
    public void add(final List<Document> documents, String tableName) {
        final String table = requireValidTableName(tableName);
        if (documents.isEmpty()) {
            return;
        }
        final String sql = "INSERT INTO " + table + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ";
        this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
            @Override
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = documents.get(i);
                String content = document.getContent();
                String json = toJson(document.getMetadata());
                // The embedding model is invoked once per document in the batch.
                PGvector vector = new PGvector(toFloatArray(embeddingModel.embed(document)));
                // Parameters 1-4 feed the INSERT; 5-7 repeat the values for the UPDATE branch.
                // Integer.MIN_VALUE is the SQL-type argument the original PgVectorStore passes.
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, vector);
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, vector);
            }

            @Override
            public int getBatchSize() {
                return documents.size();
            }
        });
    }

    /** Returns {@code tableName} unchanged if it is a safe identifier, otherwise throws. */
    private static String requireValidTableName(String tableName) {
        if (tableName == null || !TABLE_NAME_PATTERN.matcher(tableName).matches()) {
            throw new IllegalArgumentException("Invalid table name: " + tableName);
        }
        return tableName;
    }

    /** Serializes the metadata map to JSON; a Jackson failure is rethrown as RuntimeException. */
    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Narrows each Double of the embedding to float, preserving order. */
    private float[] toFloatArray(List<Double> embeddingDouble) {
        float[] embeddingFloat = new float[embeddingDouble.size()];
        int i = 0;
        for (Double value : embeddingDouble) {
            embeddingFloat[i++] = value.floatValue();
        }
        return embeddingFloat;
    }

    /**
     * Deletes rows from {@code tableName} one id at a time.
     *
     * @param idList UUID strings of the rows to delete
     * @param tableName target table
     * @return whether the total number of affected rows equals {@code idList.size()}
     * @throws IllegalArgumentException if {@code tableName} is not a valid identifier
     */
    public Optional<Boolean> delete(List<String> idList, String tableName) {
        final String sql = "DELETE FROM " + requireValidTableName(tableName) + " WHERE id = ?";
        int updateCount = 0;
        for (String id : idList) {
            updateCount += this.jdbcTemplate.update(sql, UUID.fromString(id));
        }
        return Optional.of(updateCount == idList.size());
    }

    /** Shorthand for {@link #similaritySearch(SearchRequest, String)} with {@code SearchRequest.query(query)}. */
    public List<Document> similaritySearch(String query, String tableName) {
        return this.similaritySearch(SearchRequest.query(query), tableName);
    }

    /**
     * Embeds the request's query and runs the distance type's SQL template against
     * {@code tableName}, optionally narrowed by a metadata jsonpath filter.
     *
     * @param request query text, similarity threshold, top-k and optional filter expression
     * @param tableName table to search
     * @return the matching rows as documents, with the distance added to each metadata map
     * @throws IllegalArgumentException if {@code tableName} is not a valid identifier
     */
    public List<Document> similaritySearch(SearchRequest request, String tableName) {
        String table = requireValidTableName(tableName);
        String nativeFilterExpression = request.getFilterExpression() != null
                ? this.filterExpressionConverter.convertExpression(request.getFilterExpression())
                : "";
        String jsonPathFilter = "";
        if (StringUtils.hasText(nativeFilterExpression)) {
            jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
        }
        // The similarity threshold is converted to a maximum distance.
        double distance = 1.0 - request.getSimilarityThreshold();
        PGvector queryEmbedding = this.getQueryEmbedding(request.getQuery());
        String sql = String.format(this.getDistanceType().similaritySearchSqlTemplate, table, jsonPathFilter);
        // Bound values, in template order: embedding (SELECT), embedding (WHERE), max distance, LIMIT.
        return this.jdbcTemplate.query(sql, new DocumentRowMapper(this.objectMapper),
                queryEmbedding, queryEmbedding, distance, request.getTopK());
    }

    /**
     * Returns the distance between the embedded query and the embedding of every row in
     * {@code tableName}. (Previously the table argument was ignored and {@code vector_store} was
     * always queried.)
     *
     * @throws IllegalArgumentException if {@code tableName} is not a valid identifier
     */
    public List<Double> embeddingDistance(String query, String tableName) {
        String sql = "SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM "
                + requireValidTableName(tableName);
        RowMapper<Double> distanceMapper = (rs, rowNum) -> rs.getDouble(DocumentRowMapper.COLUMN_DISTANCE);
        return this.jdbcTemplate.query(sql, distanceMapper, this.getQueryEmbedding(query));
    }

    /** Embeds the query text with the embedding model and wraps the result as a PGvector. */
    private PGvector getQueryEmbedding(String query) {
        List<Double> embedding = this.embeddingModel.embed(query);
        return new PGvector(this.toFloatArray(embedding));
    }

    private String comparisonOperator() {
        return this.getDistanceType().operator;
    }

    /**
     * Resolves the embedding size: the configured value if positive, else the value reported by
     * the embedding model if positive, else 1536. An exception from the model is logged and ignored.
     */
    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        }
        try {
            int embeddingDimensions = this.embeddingModel.dimensions();
            if (embeddingDimensions > 0) {
                return embeddingDimensions;
            }
        } catch (Exception e) {
            logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", e);
        }
        return 1536;
    }

    /**
     * Copy of {@code PgVectorStore.PgDistanceType}. This class itself uses the
     * {@code PgVectorStore} enum; the copy is retained so existing references keep compiling.
     */
    public static enum PgDistanceType {
        EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
        NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
        COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");

        public final String operator;
        public final String index;
        public final String similaritySearchSqlTemplate;

        private PgDistanceType(String operator, String index, String sqlTemplate) {
            this.operator = operator;
            this.index = index;
            this.similaritySearchSqlTemplate = sqlTemplate;
        }
    }

    /**
     * Copy of {@code PgVectorStore.PgIndexType}; retained so existing references keep compiling.
     */
    public static enum PgIndexType {
        NONE,
        IVFFLAT,
        HNSW;

        private PgIndexType() {
        }
    }

    /** Maps one search result row to a Document, adding the row's distance to its metadata. */
    private static class DocumentRowMapper implements RowMapper<Document> {
        private static final String COLUMN_EMBEDDING = "embedding";
        private static final String COLUMN_METADATA = "metadata";
        private static final String COLUMN_ID = "id";
        private static final String COLUMN_CONTENT = "content";
        private static final String COLUMN_DISTANCE = "distance";

        private final ObjectMapper objectMapper;

        public DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        @Override
        public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
            String id = rs.getString(COLUMN_ID);
            String content = rs.getString(COLUMN_CONTENT);
            PGobject pgMetadata = rs.getObject(COLUMN_METADATA, PGobject.class);
            PGobject embedding = rs.getObject(COLUMN_EMBEDDING, PGobject.class);
            Float distance = rs.getFloat(COLUMN_DISTANCE);
            Map<String, Object> metadata = this.toMap(pgMetadata);
            // The computed distance is exposed to callers through the metadata map.
            metadata.put(COLUMN_DISTANCE, distance);
            Document document = new Document(id, content, metadata);
            document.setEmbedding(this.toDoubleList(embedding));
            return document;
        }

        /** Parses the stored vector text and widens each float to a Double. */
        private List<Double> toDoubleList(PGobject embedding) throws SQLException {
            float[] floatArray = new PGvector(embedding.getValue()).toArray();
            return IntStream.range(0, floatArray.length)
                    .mapToDouble(i -> (double) floatArray[i])
                    .boxed()
                    .toList();
        }

        /** Parses the metadata JSON into a map; a Jackson failure is rethrown as RuntimeException. */
        @SuppressWarnings("unchecked") // Jackson returns a raw Map; JSON object keys are always strings.
        private Map<String, Object> toMap(PGobject pgObject) {
            String source = pgObject.getValue();
            try {
                return (Map<String, Object>) this.objectMapper.readValue(source, Map.class);
            } catch (JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }
}
当我们要使用上面这个ExtendPgVectorStore进行操作时,首先我们要排除掉原PgVectorStore的自动注入,例如在启动类上加上@SpringBootApplication(exclude = PgVectorStoreAutoConfiguration.class)。
接着我们需要注入自己的ExtendPgVectorStore类
import com.lccloud.tenderdocument.vector.ExtendPgVectorStore;
import org.springframework.ai.autoconfigure.vectorstore.pgvector.PgVectorStoreProperties;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.boot.autoconfigure.AutoConfigureAfter;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;
/**
 * Spring configuration that registers the table-aware {@link ExtendPgVectorStore} and a
 * {@link TokenTextSplitter}, reusing Spring AI's {@link PgVectorStoreProperties} for settings.
 */
@Configuration
@AutoConfigureAfter(JdbcTemplateAutoConfiguration.class)
@EnableConfigurationProperties(PgVectorStoreProperties.class)
public class PgVectorConfig {

    /**
     * Creates the vector store used for storing and searching embeddings.
     *
     * @param jdbcTemplate template the store uses to talk to PostgreSQL
     * @param embeddingModel model the store uses to embed documents and queries
     * @param properties pgvector settings bound from configuration
     * @return the store, built from the given collaborators and properties
     */
    @Bean
    public ExtendPgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, PgVectorStoreProperties properties) {
        final boolean initSchema = properties.isInitializeSchema();
        final var dims = properties.getDimensions();
        final var distance = properties.getDistanceType();
        final var dropExisting = properties.isRemoveExistingVectorStoreTable();
        final var indexType = properties.getIndexType();
        return new ExtendPgVectorStore(jdbcTemplate, embeddingModel, dims, distance, dropExisting, indexType, initSchema);
    }

    /**
     * Creates the text splitter bean.
     *
     * @return a {@link TokenTextSplitter} built with its no-argument constructor
     */
    @Bean
    public TokenTextSplitter tokenTextSplitter() {
        return new TokenTextSplitter();
    }
}
上面这里的PgVectorStoreProperties也可以换成我们自己的类方法(这里我懒得换就用pgVectorStore自带的了)。然后我们在使用的时候就可以注入ExtendPgVectorStore进行操作了。