diff --git a/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java b/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java index 0281911..95d6d0a 100644 --- a/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java +++ b/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java @@ -19,7 +19,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -27,6 +26,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.ChatClient.Builder; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -37,7 +37,6 @@ import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.support.rowset.SqlRowSet; import org.springframework.stereotype.Service; -import org.springframework.ai.chat.client.ChatClient.Builder; import ch.xxx.aidoclibchat.domain.client.ImportClient; import ch.xxx.aidoclibchat.domain.common.MetaData; @@ -99,6 +98,10 @@ Pay attention to use date('now') function to get the current date, if the questi @Value("${spring.profiles.active:}") private String activeProfile; + record MyTableData(String joinColumn, String joinTable, String columnValue, List tableRecords, + TableColumnNames tableColumnNames) { + } + public TableService(ImportClient importClient, ImportService importService, Builder builder, JdbcTemplate jdbcTemplate, TableMetadataRepository tableMetadataRepository, DocumentVsRepository documentVsRepository) { @@ -149,37 +152,48 @@ private Prompt createPrompt(SearchDto searchDto, EmbeddingContainer documentCont List tableRecords = this.tableMetadataRepository .findByTableNameIn(tableColumnNames.tableNames()).stream() .map(tableMetaData -> new TableNameSchema(tableMetaData.getTableName(), tableMetaData.getTableDdl())) - .collect(Collectors.toList()); - final AtomicReference joinColumn = new AtomicReference(""); - final AtomicReference joinTable = new AtomicReference(""); - final AtomicReference columnValue = new AtomicReference(""); - sortedRowDocs.stream().filter(myDoc -> minRowDistance <= MAX_ROW_DISTANCE) + .collect(Collectors.toList()); + var result = sortedRowDocs.stream().filter(myDoc -> minRowDistance <= MAX_ROW_DISTANCE) .filter(myRowDoc -> tableRecords.stream() .filter(myRecord -> myRecord.name().equals(myRowDoc.getMetadata().get(MetaData.TABLE_NAME))) .findFirst().isEmpty()) - .findFirst().ifPresent(myRowDoc -> { - joinTable.set(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME))); - joinColumn.set(((String) myRowDoc.getMetadata().get(MetaData.DATANAME))); - tableColumnNames.columnNames().add(((String) myRowDoc.getMetadata().get(MetaData.DATANAME))); - columnValue.set(myRowDoc.getText()); - this.tableMetadataRepository - .findByTableNameIn(List.of(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)))) - .stream() - .map(myTableMetadata -> new TableNameSchema(myTableMetadata.getTableName(), - myTableMetadata.getTableDdl())) - .findFirst().ifPresent(myRecord -> tableRecords.add(myRecord)); - }); - var messages = this.createMessages(searchDto, minRowDistance, tableColumnNames, tableRecords, joinColumn, - joinTable, columnValue); + .findFirst().map(myRowDoc -> createTableData(tableColumnNames, tableRecords, myRowDoc)) + .orElseThrow(); + var messages = this.createMessages(searchDto, minRowDistance, result.tableColumnNames(), result.tableRecords(), result.joinColumn(), + result.joinTable(), result.columnValue()); Prompt prompt = new Prompt(messages); // LOGGER.info("Prompt: {}", prompt.getContents()); return prompt; } + private MyTableData createTableData(TableColumnNames tableColumnNames, List tableRecords, + Document myRowDoc) { + tableColumnNames.columnNames().add(((String) myRowDoc.getMetadata().get(MetaData.DATANAME))); + return findTable(myRowDoc).map(myRecord -> { + tableRecords.add(myRecord); + return createMyTableResult(tableColumnNames, tableRecords, myRowDoc); + }).orElse(createMyTableResult(tableColumnNames, tableRecords, myRowDoc)); + } + + private MyTableData createMyTableResult(TableColumnNames tableColumnNames, List tableRecords, + Document myRowDoc) { + return new MyTableData(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)), + ((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)), myRowDoc.getText(), tableRecords, + tableColumnNames); + } + + private Optional findTable(Document myRowDoc) { + return this.tableMetadataRepository + .findByTableNameIn(List.of(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)))).stream() + .map(myTableMetadata -> new TableNameSchema(myTableMetadata.getTableName(), + myTableMetadata.getTableDdl())) + .findFirst(); + } + private List createMessages(SearchDto searchDto, final Float minRowDistance, TableColumnNames tableColumnNames, List tableRecords, - final AtomicReference joinColumn, final AtomicReference joinTable, - final AtomicReference columnValue) { + final String joinColumn, final String joinTable, + final String columnValue) { SystemPromptTemplate systemPromptTemplate = this.activeProfile.contains("ollama") ? new SystemPromptTemplate(minRowDistance > MAX_ROW_DISTANCE ? String.format(this.ollamaPrompt, "") : String.format(this.ollamaPrompt, columnMatch)) @@ -188,8 +202,8 @@ private List createMessages(SearchDto searchDto, final Float minRowDist Message systemMessage = systemPromptTemplate.createMessage( Map.of("columns", tableColumnNames.columnNames().stream().collect(Collectors.joining(",")), "schemas", tableRecords.stream().map(myRecord -> myRecord.schema()).collect(Collectors.joining(";")), - "prompt", searchDto.getSearchString(), "joinColumn", joinColumn.get(), "joinTable", - joinTable.get(), "columnValue", columnValue.get())); + "prompt", searchDto.getSearchString(), "joinColumn", joinColumn, "joinTable", + joinTable, "columnValue", columnValue)); UserMessage userMessage = this.activeProfile.contains("ollama") ? new UserMessage(systemMessage.getText()) : new UserMessage(searchDto.getSearchString()); return List.of(systemMessage, userMessage);