Implementing RAG with Spring AI + pgvector + Ollama

2024/11/20 15:35:10

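First, install mofanke/dmeta-embedding-zh:latest in Ollama by running ollama run mofanke/dmeta-embedding-zh. This model converts text into vector data.

Next, install pgvector (pgAdmin 4 is recommended as the GUI tool; with Navicat, tables may fail to show up).

With the required software installed, we can start coding.

1: Add the following to the pom file: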

        <!-- Connects to PostgreSQL -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-jdbc</artifactId>
        </dependency>
        <!-- Uses pgvector as the vector database -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
        </dependency>
        <!-- PDF parsing -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pdf-document-reader</artifactId>
        </dependency>
        <!-- General document parsing -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-tika-document-reader</artifactId>
        </dependency>

2: Configure the yml:

spring:
  datasource:
    url: jdbc:postgresql://127.0.0.1:5432/postgres
    username: postgres
    password: password
  ai:
    vectorstore:
      pgvector:
        dimensions: 768   # must match the output dimension of the embedding model in use
    ollama:
      base-url: http://127.0.0.1:11434
      chat:
        enabled: true
        options:
          model: qwen2:7b
      embedding:
        model: mofanke/dmeta-embedding-zh

3: Add the following to the controller:

    /**
     * Embed a file.
     *
     * @param file the file to embed
     * @return the document chunks that were stored
     */
    @SneakyThrows
    @PostMapping("embedding")
    public List<Document> embedding(@RequestParam MultipartFile file) {

        // Read the file from the IO stream
        TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(file.getInputStream()));
        // Split the text content into smaller chunks
        List<Document> splitDocuments = new TokenTextSplitter()
                .apply(tikaDocumentReader.read());
        // Store the chunks in the vector database; this automatically calls the embeddingModel to turn the text into vectors before storing them.
        vector.add(splitDocuments);
        return splitDocuments;
    }

Calling the endpoint above converts a document into vector data and stores it in pgvector.
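To make the "split into smaller chunks" step more concrete, here is a minimal sketch of fixed-size chunking with overlap. It is a simplified, character-based stand-in and not the algorithm TokenTextSplitter uses (which works on tokens); the class name ChunkSketch and the parameters chunkSize and overlap are made up for illustration.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical illustration only; Spring AI's TokenTextSplitter splits by tokens, not characters. */
final class ChunkSketch {

    /** Splits text into chunks of at most chunkSize characters, each sharing `overlap` characters with the previous one. */
    static List<String> split(String text, int chunkSize, int overlap) {
        if (chunkSize <= 0 || overlap < 0 || overlap >= chunkSize) {
            throw new IllegalArgumentException("need chunkSize > 0 and 0 <= overlap < chunkSize");
        }
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap; // how far the window advances each time
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + chunkSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break; // the last window already reached the end of the text
            }
        }
        return chunks;
    }
}
```

Each chunk would then be embedded and stored as its own row, so a smaller chunk size gives more precise matches at the cost of more rows.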

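4: Handle a chat request: first retrieve the matching data from pgvector based on the chat message, then pass the results to the qwen2 model, which analyzes them and returns the answer.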

    /**
     * Build the prompt.
     *
     * @param message the user's question
     * @param context the retrieved context
     * @return prompt
     */
    private String getChatPrompt2String(String message, String context) {
        String promptText = """
                Answer "%s" using only the content below. The output must come only from this content; do not add any other description:
                %s
                """;
        return String.format(promptText, message, context);
    }

    @GetMapping("chatToPgVector")
    public String chatToPgVector(String message) {

        // 1. Define the prompt template; question_answer_context will be replaced with the documents retrieved from the vector database.
        String promptWithContext = """
                You are a program. You need to extract information from the text and output it as JSON. The context information is below.
                ---------------------
                {question_answer_context}
                ---------------------
                Reply to the user's request based on the given context and the provided history rather than prior knowledge. If the answer is not in the context, tell the user that you cannot answer the question.
                """;
        // Query the vector store for matching documents
        List<Document> documents = vector.similaritySearch(message,"test_store");
        // Extract the text content
        String content = documents.stream()
                .map(Document::getContent)
                .collect(Collectors.joining("\n"));
        System.out.println(content);
        // Build the prompt and call the LLM
        String chatResponse = ollamaChatModel.call(getChatPrompt2String(message, content));
        return chatResponse;
   /*     return ChatClient.create(ollamaChatModel).prompt()
                .user(message)
                // 2. QuestionAnswerAdvisor replaces the `question_answer_context` placeholder in the template at runtime with the documents retrieved from the vector database. The query then equals the user's question plus the filled-in prompt template.
                .advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults(), promptWithContext))
                .call().content();*/
    }
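
The prompt assembly in chatToPgVector can be sketched in isolation. The version below is a hypothetical helper (RagPromptSketch, MAX_CONTEXT_CHARS and the English prompt wording are assumptions, not part of Spring AI or of the code above). It additionally drops blank and duplicate chunks, caps the context length, and returns an empty Optional when nothing was retrieved, so the caller could skip the model call in that case.

```java
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/** Hypothetical helper: builds the RAG prompt from the texts of the retrieved chunks. */
final class RagPromptSketch {

    // Assumed character cap so the prompt stays within the model's context window; tune per model.
    static final int MAX_CONTEXT_CHARS = 4000;

    /** Joins the non-blank, de-duplicated chunk texts with newlines and truncates to the cap. */
    static String buildContext(List<String> chunks) {
        String joined = chunks.stream()
                .map(String::strip)
                .filter(s -> !s.isEmpty())
                .distinct()
                .collect(Collectors.joining("\n"));
        return joined.length() <= MAX_CONTEXT_CHARS ? joined : joined.substring(0, MAX_CONTEXT_CHARS);
    }

    /** Returns the prompt, or Optional.empty() when no usable context was retrieved. */
    static Optional<String> buildPrompt(String question, List<String> chunks) {
        String context = buildContext(chunks);
        if (context.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of("Answer \"" + question + "\" using only the content below:\n" + context);
    }
}
```

In the controller, the list of chunk texts would come from mapping Document::getContent over the search result, and an empty Optional could be answered directly with a "cannot answer" message.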

That completes a simple RAG (retrieval-augmented generation) demo. Next, let's see what PgVectorStore does for us.

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package org.springframework.ai.vectorstore;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.IntStream;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.PgVectorFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

public class PgVectorStore implements VectorStore, InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class);
    public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;
    public static final int INVALID_EMBEDDING_DIMENSION = -1;
    public static final String VECTOR_TABLE_NAME = "vector_store";
    public static final String VECTOR_INDEX_NAME = "spring_ai_vector_index";
    public final FilterExpressionConverter filterExpressionConverter;
    private final JdbcTemplate jdbcTemplate;
    private final EmbeddingModel embeddingModel;
    private int dimensions;
    private PgDistanceType distanceType;
    private ObjectMapper objectMapper;
    private boolean removeExistingVectorStoreTable;
    private PgIndexType createIndexMethod;
    private final boolean initializeSchema;

    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        this(jdbcTemplate, embeddingModel, -1, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
        this(jdbcTemplate, embeddingModel, dimensions, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema) {
        this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
        this.objectMapper = new ObjectMapper();
        this.jdbcTemplate = jdbcTemplate;
        this.embeddingModel = embeddingModel;
        this.dimensions = dimensions;
        this.distanceType = distanceType;
        this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
        this.createIndexMethod = createIndexMethod;
        this.initializeSchema = initializeSchema;
    }

    public PgDistanceType getDistanceType() {
        return this.distanceType;
    }

    public void add(final List<Document> documents) {
        final int size = documents.size();
        this.jdbcTemplate.batchUpdate("INSERT INTO vector_store (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = (Document)documents.get(i);
                String content = document.getContent();
                String json = PgVectorStore.this.toJson(document.getMetadata());
                PGvector pGvector = new PGvector(PgVectorStore.this.toFloatArray(PgVectorStore.this.embeddingModel.embed(document)));
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        } catch (JsonProcessingException var3) {
            throw new RuntimeException(var3);
        }
    }

    private float[] toFloatArray(List<Double> embeddingDouble) {
        float[] embeddingFloat = new float[embeddingDouble.size()];
        int i = 0;

        Double d;
        for(Iterator var4 = embeddingDouble.iterator(); var4.hasNext(); embeddingFloat[i++] = d.floatValue()) {
            d = (Double)var4.next();
        }

        return embeddingFloat;
    }

    public Optional<Boolean> delete(List<String> idList) {
        int updateCount = 0;

        int count;
        for(Iterator var3 = idList.iterator(); var3.hasNext(); updateCount += count) {
            String id = (String)var3.next();
            count = this.jdbcTemplate.update("DELETE FROM vector_store WHERE id = ?", new Object[]{UUID.fromString(id)});
        }

        return Optional.of(updateCount == idList.size());
    }

    public List<Document> similaritySearch(SearchRequest request) {
        String nativeFilterExpression = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        String jsonPathFilter = "";
        if (StringUtils.hasText(nativeFilterExpression)) {
            jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
        }

        double distance = 1.0 - request.getSimilarityThreshold();
        PGvector queryEmbedding = this.getQueryEmbedding(request.getQuery());
        return this.jdbcTemplate.query(String.format(this.getDistanceType().similaritySearchSqlTemplate, "vector_store", jsonPathFilter), new DocumentRowMapper(this.objectMapper), new Object[]{queryEmbedding, queryEmbedding, distance, request.getTopK()});
    }

    public List<Double> embeddingDistance(String query) {
        return this.jdbcTemplate.query("SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM vector_store", new RowMapper<Double>() {
            @Nullable
            public Double mapRow(ResultSet rs, int rowNum) throws SQLException {
                return rs.getDouble("distance");
            }
        }, new Object[]{this.getQueryEmbedding(query)});
    }

    private PGvector getQueryEmbedding(String query) {
        List<Double> embedding = this.embeddingModel.embed(query);
        return new PGvector(this.toFloatArray(embedding));
    }

    private String comparisonOperator() {
        return this.getDistanceType().operator;
    }

    public void afterPropertiesSet() throws Exception {
        if (this.initializeSchema) {
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
            if (this.removeExistingVectorStoreTable) {
                this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
            }

            this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
            if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
                this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
            }

        }
    }

    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        } else {
            try {
                int embeddingDimensions = this.embeddingModel.dimensions();
                if (embeddingDimensions > 0) {
                    return embeddingDimensions;
                }
            } catch (Exception var2) {
                logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", var2);
            }

            return 1536;
        }
    }

    public static enum PgDistanceType {
        EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
        NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
        COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");

        public final String operator;
        public final String index;
        public final String similaritySearchSqlTemplate;

        private PgDistanceType(String operator, String index, String sqlTemplate) {
            this.operator = operator;
            this.index = index;
            this.similaritySearchSqlTemplate = sqlTemplate;
        }
    }

    public static enum PgIndexType {
        NONE,
        IVFFLAT,
        HNSW;

        private PgIndexType() {
        }
    }

    private static class DocumentRowMapper implements RowMapper<Document> {
        private static final String COLUMN_EMBEDDING = "embedding";
        private static final String COLUMN_METADATA = "metadata";
        private static final String COLUMN_ID = "id";
        private static final String COLUMN_CONTENT = "content";
        private static final String COLUMN_DISTANCE = "distance";
        private ObjectMapper objectMapper;

        public DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
            String id = rs.getString("id");
            String content = rs.getString("content");
            PGobject pgMetadata = (PGobject)rs.getObject("metadata", PGobject.class);
            PGobject embedding = (PGobject)rs.getObject("embedding", PGobject.class);
            Float distance = rs.getFloat("distance");
            Map<String, Object> metadata = this.toMap(pgMetadata);
            metadata.put("distance", distance);
            Document document = new Document(id, content, metadata);
            document.setEmbedding(this.toDoubleList(embedding));
            return document;
        }

        private List<Double> toDoubleList(PGobject embedding) throws SQLException {
            float[] floatArray = (new PGvector(embedding.getValue())).toArray();
            return IntStream.range(0, floatArray.length).mapToDouble((i) -> {
                return (double)floatArray[i];
            }).boxed().toList();
        }

        private Map<String, Object> toMap(PGobject pgObject) {
            String source = pgObject.getValue();

            try {
                return (Map)this.objectMapper.readValue(source, Map.class);
            } catch (JsonProcessingException var4) {
                throw new RuntimeException(var4);
            }
        }
    }
}

We can see that PgVectorStore implements InitializingBean and its afterPropertiesSet method, which runs after the bean's properties have been set.

 public void afterPropertiesSet() throws Exception {
        if (this.initializeSchema) {
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
            if (this.removeExistingVectorStoreTable) {
                this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
            }

            this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
            if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
                this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
            }

        }
    }

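Here it decides whether to create the table for us based on initializeSchema (defined in PgVectorStoreProperties and true by default in the version used here; it can be disabled in the yml with spring.ai.vectorstore.pgvector.initialize-schema: false. The default may differ between Spring AI versions, so check the one you use). It creates a table named vector_store with the columns id (uuid), content (text), metadata (json) and embedding (vector(1536)). The 1536 is the dimensions value: embeddingDimensions() returns the configured dimensions if it is positive, otherwise asks the embedding model, and falls back to 1536 if that fails. If we store data through pgvector into a table created with this default, we get an error such as ERROR: expected 1536 dimensions, not 768. It means that the embedding model in Ollama outputs 768-dimensional vectors while the embedding column in pgvector is declared with 1536 dimensions, so the insert fails because they do not match. The fix is to change the dimension of the embedding column in pgvector to 768 (different models return different dimensions, so adjust the value according to the error message).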
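
The dimension mismatch can also be caught in application code before the insert reaches the database. The sketch below is a hypothetical helper (EmbeddingDimensionSketch and its methods are not part of Spring AI). The generated ALTER TABLE statement is an assumption about how the column could be resized in pgvector; it should work on an empty table, while rows that already hold vectors of another dimension would presumably be rejected.

```java
/** Hypothetical helper: fail fast on a dimension mismatch and build the DDL to resize the column. */
final class EmbeddingDimensionSketch {

    /** Throws with a readable message if the embedding length differs from the column dimension. */
    static void requireDimensions(float[] embedding, int expected) {
        if (embedding.length != expected) {
            throw new IllegalArgumentException("embedding has " + embedding.length
                    + " dimensions, but the column is vector(" + expected + ")");
        }
    }

    /** Assumed DDL for changing the embedding column of the default vector_store table. */
    static String alterEmbeddingColumnSql(int dimensions) {
        if (dimensions <= 0) {
            throw new IllegalArgumentException("dimensions must be positive");
        }
        return "ALTER TABLE vector_store ALTER COLUMN embedding TYPE vector(" + dimensions + ")";
    }
}
```

For the model used in this article, alterEmbeddingColumnSql(768) yields the statement to run once in pgAdmin or through JdbcTemplate.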

Next, let's look at the core operation: inserting data into the database.

 public void add(final List<Document> documents) {
        final int size = documents.size();
        this.jdbcTemplate.batchUpdate("INSERT INTO vector_store (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = (Document)documents.get(i);
                String content = document.getContent();
                String json = PgVectorStore.this.toJson(document.getMetadata());
                PGvector pGvector = new PGvector(PgVectorStore.this.toFloatArray(PgVectorStore.this.embeddingModel.embed(document)));
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

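Because Spring AI is still new and has not reached a stable release, the table name is hard-coded in the source. With PgVectorStore we can only operate on vector_store, which can be limiting in real applications. So we can write our own extended class to replace it, as follows: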

package com.lccloud.tenderdocument.vector;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.PgVectorFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import java.util.stream.IntStream;

public class ExtendPgVectorStore  {
    private static final Logger logger = LoggerFactory.getLogger(ExtendPgVectorStore.class);
    public final FilterExpressionConverter filterExpressionConverter;
    private final JdbcTemplate jdbcTemplate;
    private final EmbeddingModel embeddingModel;
    private int dimensions;
    private PgVectorStore.PgDistanceType distanceType;
    private ObjectMapper objectMapper;
    private boolean removeExistingVectorStoreTable;
    private PgVectorStore.PgIndexType createIndexMethod;
    private final boolean initializeSchema;

    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        this(jdbcTemplate, embeddingModel, -1, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
        this(jdbcTemplate, embeddingModel, dimensions, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgVectorStore.PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgVectorStore.PgIndexType createIndexMethod, boolean initializeSchema) {
        this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
        this.objectMapper = new ObjectMapper();
        this.jdbcTemplate = jdbcTemplate;
        this.embeddingModel = embeddingModel;
        this.dimensions = dimensions;
        this.distanceType = distanceType;
        this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
        this.createIndexMethod = createIndexMethod;
        this.initializeSchema = initializeSchema;
    }

    public PgVectorStore.PgDistanceType getDistanceType() {
        return this.distanceType;
    }

    public void add(final List<Document> documents,String tableName) {
        final int size = documents.size();
        this.jdbcTemplate.batchUpdate("INSERT INTO "+ tableName+" (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = (Document)documents.get(i);
                String content = document.getContent();
                String json = ExtendPgVectorStore.this.toJson(document.getMetadata());
                PGvector pGvector = new PGvector(ExtendPgVectorStore.this.toFloatArray(ExtendPgVectorStore.this.embeddingModel.embed(document)));
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        } catch (JsonProcessingException var3) {
            throw new RuntimeException(var3);
        }
    }

    private float[] toFloatArray(List<Double> embeddingDouble) {
        float[] embeddingFloat = new float[embeddingDouble.size()];
        int i = 0;

        Double d;
        for(Iterator var4 = embeddingDouble.iterator(); var4.hasNext(); embeddingFloat[i++] = d.floatValue()) {
            d = (Double)var4.next();
        }

        return embeddingFloat;
    }

    public Optional<Boolean> delete(List<String> idList,String tableName) {
        int updateCount = 0;

        int count;
        for(Iterator var3 = idList.iterator(); var3.hasNext(); updateCount += count) {
            String id = (String)var3.next();
            count = this.jdbcTemplate.update("DELETE FROM "+tableName+" WHERE id = ?", new Object[]{UUID.fromString(id)});
        }

        return Optional.of(updateCount == idList.size());
    }

    public List<Document> similaritySearch(String query,String tableName) {
        return this.similaritySearch(SearchRequest.query(query),tableName);
    }

    public List<Document> similaritySearch(SearchRequest request,String tableName) {
        String nativeFilterExpression = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        String jsonPathFilter = "";
        if (StringUtils.hasText(nativeFilterExpression)) {
            jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
        }

        double distance = 1.0 - request.getSimilarityThreshold();
        PGvector queryEmbedding = this.getQueryEmbedding(request.getQuery());
        return this.jdbcTemplate.query(String.format(this.getDistanceType().similaritySearchSqlTemplate, tableName, jsonPathFilter), new ExtendPgVectorStore.DocumentRowMapper(this.objectMapper), new Object[]{queryEmbedding, queryEmbedding, distance, request.getTopK()});
    }

    public List<Double> embeddingDistance(String query,String tableName) {
        return this.jdbcTemplate.query("SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM " + tableName, new RowMapper<Double>() {
            @Nullable
            public Double mapRow(ResultSet rs, int rowNum) throws SQLException {
                return rs.getDouble("distance");
            }
        }, new Object[]{this.getQueryEmbedding(query)});
    }

    private PGvector getQueryEmbedding(String query) {
        List<Double> embedding = this.embeddingModel.embed(query);
        return new PGvector(this.toFloatArray(embedding));
    }

    private String comparisonOperator() {
        return this.getDistanceType().operator;
    }

/*    public void afterPropertiesSet() throws Exception {
        if (this.initializeSchema) {
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
            if (this.removeExistingVectorStoreTable) {
                this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
            }

            this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
            if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
                this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
            }

        }
    }*/

    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        } else {
            try {
                int embeddingDimensions = this.embeddingModel.dimensions();
                if (embeddingDimensions > 0) {
                    return embeddingDimensions;
                }
            } catch (Exception var2) {
                logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", var2);
            }

            return 1536;
        }
    }

    public static enum PgDistanceType {
        EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
        NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
        COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");

        public final String operator;
        public final String index;
        public final String similaritySearchSqlTemplate;

        private PgDistanceType(String operator, String index, String sqlTemplate) {
            this.operator = operator;
            this.index = index;
            this.similaritySearchSqlTemplate = sqlTemplate;
        }
    }

    public static enum PgIndexType {
        NONE,
        IVFFLAT,
        HNSW;

        private PgIndexType() {
        }
    }

    private static class DocumentRowMapper implements RowMapper<Document> {
        private static final String COLUMN_EMBEDDING = "embedding";
        private static final String COLUMN_METADATA = "metadata";
        private static final String COLUMN_ID = "id";
        private static final String COLUMN_CONTENT = "content";
        private static final String COLUMN_DISTANCE = "distance";
        private ObjectMapper objectMapper;

        public DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
            String id = rs.getString("id");
            String content = rs.getString("content");
            PGobject pgMetadata = (PGobject)rs.getObject("metadata", PGobject.class);
            PGobject embedding = (PGobject)rs.getObject("embedding", PGobject.class);
            Float distance = rs.getFloat("distance");
            Map<String, Object> metadata = this.toMap(pgMetadata);
            metadata.put("distance", distance);
            Document document = new Document(id, content, metadata);
            document.setEmbedding(this.toDoubleList(embedding));
            return document;
        }

        private List<Double> toDoubleList(PGobject embedding) throws SQLException {
            float[] floatArray = (new PGvector(embedding.getValue())).toArray();
            return IntStream.range(0, floatArray.length).mapToDouble((i) -> {
                return (double)floatArray[i];
            }).boxed().toList();
        }

        private Map<String, Object> toMap(PGobject pgObject) {
            String source = pgObject.getValue();

            try {
                return (Map)this.objectMapper.readValue(source, Map.class);
            } catch (JsonProcessingException var4) {
                throw new RuntimeException(var4);
            }
        }
    }
}
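
One caveat with ExtendPgVectorStore: tableName is concatenated directly into the SQL text, so a value that comes from user input could inject SQL. A minimal sketch of a guard follows; TableNameSketch and its methods are hypothetical names, and the pattern deliberately accepts only plain unquoted identifiers of up to 63 characters (the PostgreSQL identifier limit).

```java
import java.util.regex.Pattern;

/** Hypothetical guard: accept only plain identifiers before concatenating a table name into SQL. */
final class TableNameSketch {

    // A letter or underscore, followed by up to 62 letters, digits or underscores.
    private static final Pattern SAFE_IDENTIFIER = Pattern.compile("[A-Za-z_][A-Za-z0-9_]{0,62}");

    /** Returns the name unchanged if it is a safe identifier, otherwise throws. */
    static String requireSafeTableName(String tableName) {
        if (tableName == null || !SAFE_IDENTIFIER.matcher(tableName).matches()) {
            throw new IllegalArgumentException("unsafe table name: " + tableName);
        }
        return tableName;
    }

    /** Example: the DELETE statement built with a validated table name. */
    static String deleteSql(String tableName) {
        return "DELETE FROM " + requireSafeTableName(tableName) + " WHERE id = ?";
    }
}
```

The same check could be applied at the top of add, delete, similaritySearch and embeddingDistance before the SQL string is built.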

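To use the ExtendPgVectorStore above, we first have to exclude the injection of the original PgVectorStore, for example by excluding its auto-configuration class (presumably PgVectorStoreAutoConfiguration, in the same package as PgVectorStoreProperties) through the exclude attribute of @SpringBootApplication.

Then we need to register our own ExtendPgVectorStore as a bean: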


import com.lccloud.tenderdocument.vector.ExtendPgVectorStore;
import org.springframework.ai.autoconfigure.vectorstore.pgvector.PgVectorStoreProperties;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.boot.autoconfigure.AutoConfigureAfter;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;

@Configuration
@AutoConfigureAfter(JdbcTemplateAutoConfiguration.class)
@EnableConfigurationProperties({PgVectorStoreProperties.class})
public class PgVectorConfig {

    public PgVectorConfig() {
    }
    /**
     * Vector store used for retrieval.
     *
     * @param jdbcTemplate   JDBC access to the PostgreSQL database
     * @param embeddingModel model that turns text into vectors
     * @param properties     pgvector settings bound from spring.ai.vectorstore.pgvector
     * @return the custom ExtendPgVectorStore
     */
    @Bean
    public ExtendPgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, PgVectorStoreProperties properties) {
        boolean initializeSchema = properties.isInitializeSchema();
        return new ExtendPgVectorStore(
                jdbcTemplate,
                embeddingModel,
                properties.getDimensions(),
                properties.getDistanceType(),
                properties.isRemoveExistingVectorStoreTable(),
                properties.getIndexType(),
                initializeSchema);
    }
    /**
     * Text splitter.
     *
     * @return a TokenTextSplitter with its default settings
     */
    @Bean
    public TokenTextSplitter tokenTextSplitter() {
        return new TokenTextSplitter();
    }
}

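The PgVectorStoreProperties used above can also be swapped for a properties class of your own (I didn't bother and simply reused the one that ships with PgVectorStore). After that, you can inject ExtendPgVectorStore wherever you need it.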
