Skip to content

Commit

Permalink
feat: use AI
Browse files Browse the repository at this point in the history
  • Loading branch information
Angular2Guy committed Jan 22, 2024
1 parent bb00713 commit 15a13b6
Showing 1 changed file with 60 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -94,62 +97,64 @@ public void searchTables(SearchDto searchDto) {
searchDto.getResultAmount());
var rowDocuments = this.documentVsRepository.retrieve(searchDto.getSearchString(), MetaData.DataType.ROW,
searchDto.getResultAmount());
LOGGER.info("Table: ");
tableDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}",
myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE)));
LOGGER.info("Column: ");
columnDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}",
myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE)));
LOGGER.info("Row: ");
rowDocuments.forEach(
myDoc -> LOGGER.info("name: {}, content: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME),
myDoc.getContent(), myDoc.getMetadata().get(MetaData.DISTANCE)));
//
// LOGGER.info("Table: ");
// tableDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}",
// myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE)));
// LOGGER.info("Column: ");
// columnDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}",
// myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE)));
// LOGGER.info("Row: ");
// rowDocuments.forEach(
// myDoc -> LOGGER.info("name: {}, content: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME),
// myDoc.getContent(), myDoc.getMetadata().get(MetaData.DISTANCE)));

/*
* final Float minRowDistance = rowDocuments.stream() .map(myDoc -> (Float)
* myDoc.getMetadata().getOrDefault(MetaData.DISTANCE,
* 1.0f)).sorted().findFirst() .orElse(1.0f); LOGGER.info("MinRowDistance: {}",
* minRowDistance); var sortedRowDocs =
* rowDocuments.stream().sorted(this.compareDistance()).toList(); var
* sortedColumnDocs =
* columnDocuments.stream().sorted(this.compareDistance()).toList(); var
* sortedTableDocs =
* tableDocuments.stream().sorted(this.compareDistance()).toList();
* SystemPromptTemplate systemPromptTemplate =
* this.activeProfile.contains("ollama") ? new
* SystemPromptTemplate(minRowDistance > 0.25 ? this.ollamaPrompt :
* this.ollamaPrompt + columnMatch) : new SystemPromptTemplate(minRowDistance >
* 0.25 ? this.systemPrompt : this.systemPrompt + columnMatch); List<Document>
* filteredColDocs = sortedColumnDocs.stream() .filter(myRowDoc ->
* sortedTableDocs.stream().limit(2) .anyMatch(myTableDoc ->
* myTableDoc.getMetadata().get(MetaData.TABLE_NAME)
* .equals(myRowDoc.getMetadata().get(MetaData.TABLE_NAME))))
* .filter(StreamHelpers .distinctByKey(myRowDoc -> ((String)
* myRowDoc.getMetadata().get(MetaData.DATANAME)))) .limit(2).toList();
* Set<String> columnNames = filteredColDocs.stream() .map(myDoc -> ((String)
* myDoc.getMetadata().get(MetaData.DATANAME))).collect(Collectors.toSet());
* List<Long> tableMetadataIds = filteredColDocs.stream() .map(myDoc -> ((Long)
* myDoc.getMetadata().get(MetaData.ID))).distinct().toList(); record
* TableNameSchema(String name, String schema) { } List<TableNameSchema>
* tableRecords =
* this.tableMetadataRepository.findAllById(tableMetadataIds).stream()
* .map(tableMetaData -> new TableNameSchema(tableMetaData.getTableName(),
* tableMetaData.getTableDdl())) .toList(); final AtomicReference<String>
* joinColumn = new AtomicReference<String>(""); final AtomicReference<String>
* joinTable = new AtomicReference<String>(""); final AtomicReference<String>
* columnValue = new AtomicReference<String>("");
* sortedRowDocs.stream().filter(myDoc -> minRowDistance <=
* 0.25).findFirst().ifPresent(myRowDoc -> { joinTable.set(((String)
* myRowDoc.getMetadata().get(MetaData.TABLE_NAME))); joinColumn.set(((String)
* myRowDoc.getMetadata().get(MetaData.DATANAME)));
* columnValue.set(myRowDoc.getContent()); }); Message systemMessage =
* systemPromptTemplate .createMessage(Map.of("columns",
* columnNames.stream().collect(Collectors.joining(",")), "schemas",
* tableRecords.stream().map(myRecord ->
* myRecord.schema()).collect(Collectors.joining(";\n\n")), "prompt",
* searchDto.getSearchString(), "joinColumn", joinColumn.get(), "joinTable",
* joinTable.get(), "columnValue", columnValue.get()));
*/
final Float minRowDistance = rowDocuments.stream()
.map(myDoc -> (Float) myDoc.getMetadata().getOrDefault(MetaData.DISTANCE, 1.0f)).sorted().findFirst()
.orElse(1.0f);
LOGGER.info("MinRowDistance: {}", minRowDistance);
var sortedRowDocs = rowDocuments.stream().sorted(this.compareDistance()).toList();
var sortedColumnDocs = columnDocuments.stream().sorted(this.compareDistance()).toList();
var sortedTableDocs = tableDocuments.stream().sorted(this.compareDistance()).toList();
SystemPromptTemplate systemPromptTemplate = this.activeProfile.contains("ollama")
? new SystemPromptTemplate(minRowDistance > 0.25 ? this.ollamaPrompt : this.ollamaPrompt + columnMatch)
: new SystemPromptTemplate(minRowDistance > 0.25 ? this.systemPrompt : this.systemPrompt + columnMatch);
List<Document> filteredColDocs = sortedColumnDocs.stream()
.filter(myRowDoc -> sortedTableDocs.stream().limit(2)
.anyMatch(myTableDoc -> myTableDoc.getMetadata().get(MetaData.TABLE_NAME)
.equals(myRowDoc.getMetadata().get(MetaData.TABLE_NAME))))
.filter(StreamHelpers
.distinctByKey(myRowDoc -> ((String) myRowDoc.getMetadata().get(MetaData.DATANAME))))
.limit(2).toList();
Set<String> columnNames = filteredColDocs.stream()
.map(myDoc -> ((String) myDoc.getMetadata().get(MetaData.DATANAME))).collect(Collectors.toSet());
List<Long> tableMetadataIds = filteredColDocs.stream()
.map(myDoc -> ((String) myDoc.getMetadata().get(MetaData.ID))).map(myId -> Long.parseLong(myId)).distinct().toList();
record TableNameSchema(String name, String schema) {
}
List<TableNameSchema> tableRecords = this.tableMetadataRepository.findAllById(tableMetadataIds).stream()
.map(tableMetaData -> new TableNameSchema(tableMetaData.getTableName(), tableMetaData.getTableDdl()))
.toList();
final AtomicReference<String> joinColumn = new AtomicReference<String>("");
final AtomicReference<String> joinTable = new AtomicReference<String>("");
final AtomicReference<String> columnValue = new AtomicReference<String>("");
sortedRowDocs.stream().filter(myDoc -> minRowDistance <= 0.25).findFirst().ifPresent(myRowDoc -> {
joinTable.set(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)));
joinColumn.set(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)));
columnValue.set(myRowDoc.getContent());
});
Message systemMessage = systemPromptTemplate
.createMessage(Map.of("columns", columnNames.stream().collect(Collectors.joining(",")), "schemas",
tableRecords.stream().map(myRecord -> myRecord.schema()).collect(Collectors.joining(";\n\n")),
"prompt", searchDto.getSearchString(), "joinColumn", joinColumn.get(), "joinTable",
joinTable.get(), "columnValue", columnValue.get()));
UserMessage userMessage = new UserMessage(searchDto.getSearchString());
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));

var chatStart = new Date();
ChatResponse response = chatClient.generate(prompt);
LOGGER.info("AI response time: {}ms", new Date().getTime() - chatStart.getTime());
LOGGER.info("AI response: {}", response.getGenerations().stream().map(myGen -> myGen.getContent()).collect(Collectors.joining(",")));
}

private Comparator<? super Document> compareDistance() {
Expand Down

0 comments on commit 15a13b6

Please sign in to comment.