From 560d6615c3bf7ee6563049179c02f580c2e63677 Mon Sep 17 00:00:00 2001 From: Angular2guy Date: Sat, 20 Jan 2024 18:30:10 +0100 Subject: [PATCH] fix: import --- .../repository/DocumentVSRepositoryBean.java | 88 ++++++++++++++++--- .../usecase/service/TableService.java | 5 +- 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/backend/src/main/java/ch/xxx/aidoclibchat/adapter/repository/DocumentVSRepositoryBean.java b/backend/src/main/java/ch/xxx/aidoclibchat/adapter/repository/DocumentVSRepositoryBean.java index 8130af1..84c5c69 100644 --- a/backend/src/main/java/ch/xxx/aidoclibchat/adapter/repository/DocumentVSRepositoryBean.java +++ b/backend/src/main/java/ch/xxx/aidoclibchat/adapter/repository/DocumentVSRepositoryBean.java @@ -12,8 +12,13 @@ */ package ch.xxx.aidoclibchat.adapter.repository; +import java.sql.ResultSet; +import java.sql.SQLException; import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.postgresql.util.PGobject; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.vectorstore.PgVectorStore; @@ -23,19 +28,33 @@ import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Key; import org.springframework.ai.vectorstore.filter.Filter.Value; +import org.springframework.ai.vectorstore.filter.converter.FilterExpressionConverter; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; import org.springframework.stereotype.Repository; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.pgvector.PGvector; + import ch.xxx.aidoclibchat.domain.common.MetaData; import ch.xxx.aidoclibchat.domain.common.MetaData.DataType; import ch.xxx.aidoclibchat.domain.model.entity.DocumentVsRepository; @Repository public class DocumentVSRepositoryBean implements DocumentVsRepository { + private final String vectorTableName; private final VectorStore vectorStore; + private final JdbcTemplate jdbcTemplate; + private final ObjectMapper objectMapper; + private final FilterExpressionConverter filterExpressionConverter; - public DocumentVSRepositoryBean(JdbcTemplate jdbcTemplate, EmbeddingClient embeddingClient) { + public DocumentVSRepositoryBean(JdbcTemplate jdbcTemplate, EmbeddingClient embeddingClient, ObjectMapper objectMapper) { + this.jdbcTemplate = jdbcTemplate; + this.objectMapper = objectMapper; this.vectorStore = new PgVectorStore(jdbcTemplate, embeddingClient); + this.filterExpressionConverter = ((PgVectorStore) this.vectorStore).filterExpressionConverter; + this.vectorTableName = PgVectorStore.VECTOR_TABLE_NAME; } @Override @@ -67,20 +86,67 @@ public List retrieve(String query, DataType dataType) { @Override public List findAllTableDocuments() { - return this.vectorStore - .similaritySearch(SearchRequest.defaults().withSimilarityThresholdAll().withTopK(Integer.MAX_VALUE) - .withFilterExpression(new Filter.Expression(ExpressionType.OR, - new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), - new Value(DataType.COLUMN.toString())), - new Filter.Expression(ExpressionType.OR, - new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), - new Value(DataType.TABLE.toString())), - new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), - new Value(DataType.ROW.toString())))))); + String nativeFilterExpression = this.filterExpressionConverter.convertExpression(new Filter.Expression(ExpressionType.NE, + new Key(MetaData.DATATYPE), new Value(DataType.DOCUMENT.toString()))); + + String jsonPathFilter = " WHERE metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath "; + + return this.jdbcTemplate.query( + String.format("SELECT * FROM %s %s LIMIT ? ", this.vectorTableName, jsonPathFilter), + new DocumentRowMapper(this.objectMapper), 100000); } @Override public void deleteByIds(List ids) { this.vectorStore.delete(ids); } + + private static class DocumentRowMapper implements RowMapper { + + private static final String COLUMN_EMBEDDING = "embedding"; + + private static final String COLUMN_METADATA = "metadata"; + + private static final String COLUMN_ID = "id"; + + private static final String COLUMN_CONTENT = "content"; + + private ObjectMapper objectMapper; + + public DocumentRowMapper(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + } + + @Override + public Document mapRow(ResultSet rs, int rowNum) throws SQLException { + String id = rs.getString(COLUMN_ID); + String content = rs.getString(COLUMN_CONTENT); + PGobject pgMetadata = rs.getObject(COLUMN_METADATA, PGobject.class); + PGobject embedding = rs.getObject(COLUMN_EMBEDDING, PGobject.class); + + Map metadata = toMap(pgMetadata); + + Document document = new Document(id, content, metadata); + document.setEmbedding(toDoubleList(embedding)); + + return document; + } + + private List toDoubleList(PGobject embedding) throws SQLException { + float[] floatArray = new PGvector(embedding.getValue()).toArray(); + return IntStream.range(0, floatArray.length).mapToDouble(i -> floatArray[i]).boxed().toList(); + } + + private Map toMap(PGobject pgObject) { + + String source = pgObject.getValue(); + try { + return (Map) objectMapper.readValue(source, Map.class); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + } } diff --git a/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java b/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java index 7f83f99..6d9d0bf 100644 --- a/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java +++ b/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java @@ -76,8 +76,9 @@ public void searchTables(SearchDto searchDto) { columnDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE))); LOGGER.info("Row: "); - rowDocuments.forEach(myDoc -> LOGGER.info("name: {}, content: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getContent(), - myDoc.getMetadata().get(MetaData.DISTANCE))); + rowDocuments.forEach( + myDoc -> LOGGER.info("name: {}, content: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME), + myDoc.getContent(), myDoc.getMetadata().get(MetaData.DISTANCE))); } @Async