Skip to content

Commit

Permalink
fix: refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Angular2Guy committed Dec 18, 2024
1 parent b04d847 commit 9cfbae8
Showing 1 changed file with 39 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClient.Builder;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
Expand All @@ -37,7 +37,6 @@
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.support.rowset.SqlRowSet;
import org.springframework.stereotype.Service;
import org.springframework.ai.chat.client.ChatClient.Builder;

import ch.xxx.aidoclibchat.domain.client.ImportClient;
import ch.xxx.aidoclibchat.domain.common.MetaData;
Expand Down Expand Up @@ -99,6 +98,10 @@ Pay attention to use date('now') function to get the current date, if the questi
@Value("${spring.profiles.active:}")
private String activeProfile;

record MyTableData(String joinColumn, String joinTable, String columnValue, List<TableNameSchema> tableRecords,
TableColumnNames tableColumnNames) {
}

public TableService(ImportClient importClient, ImportService importService, Builder builder,
JdbcTemplate jdbcTemplate, TableMetadataRepository tableMetadataRepository,
DocumentVsRepository documentVsRepository) {
Expand Down Expand Up @@ -149,37 +152,48 @@ private Prompt createPrompt(SearchDto searchDto, EmbeddingContainer documentCont
List<TableNameSchema> tableRecords = this.tableMetadataRepository
.findByTableNameIn(tableColumnNames.tableNames()).stream()
.map(tableMetaData -> new TableNameSchema(tableMetaData.getTableName(), tableMetaData.getTableDdl()))
.collect(Collectors.toList());
final AtomicReference<String> joinColumn = new AtomicReference<String>("");
final AtomicReference<String> joinTable = new AtomicReference<String>("");
final AtomicReference<String> columnValue = new AtomicReference<String>("");
sortedRowDocs.stream().filter(myDoc -> minRowDistance <= MAX_ROW_DISTANCE)
.collect(Collectors.toList());
var result = sortedRowDocs.stream().filter(myDoc -> minRowDistance <= MAX_ROW_DISTANCE)
.filter(myRowDoc -> tableRecords.stream()
.filter(myRecord -> myRecord.name().equals(myRowDoc.getMetadata().get(MetaData.TABLE_NAME)))
.findFirst().isEmpty())
.findFirst().ifPresent(myRowDoc -> {
joinTable.set(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)));
joinColumn.set(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)));
tableColumnNames.columnNames().add(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)));
columnValue.set(myRowDoc.getText());
this.tableMetadataRepository
.findByTableNameIn(List.of(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME))))
.stream()
.map(myTableMetadata -> new TableNameSchema(myTableMetadata.getTableName(),
myTableMetadata.getTableDdl()))
.findFirst().ifPresent(myRecord -> tableRecords.add(myRecord));
});
var messages = this.createMessages(searchDto, minRowDistance, tableColumnNames, tableRecords, joinColumn,
joinTable, columnValue);
.findFirst().map(myRowDoc -> createTableData(tableColumnNames, tableRecords, myRowDoc))
.orElseThrow();
var messages = this.createMessages(searchDto, minRowDistance, result.tableColumnNames(), result.tableRecords(), result.joinColumn(),
result.joinTable(), result.columnValue());
Prompt prompt = new Prompt(messages);
// LOGGER.info("Prompt: {}", prompt.getContents());
return prompt;
}

private MyTableData createTableData(TableColumnNames tableColumnNames, List<TableNameSchema> tableRecords,
Document myRowDoc) {
tableColumnNames.columnNames().add(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)));
return findTable(myRowDoc).map(myRecord -> {
tableRecords.add(myRecord);
return createMyTableResult(tableColumnNames, tableRecords, myRowDoc);
}).orElse(createMyTableResult(tableColumnNames, tableRecords, myRowDoc));
}

private MyTableData createMyTableResult(TableColumnNames tableColumnNames, List<TableNameSchema> tableRecords,
Document myRowDoc) {
return new MyTableData(((String) myRowDoc.getMetadata().get(MetaData.DATANAME)),
((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)), myRowDoc.getText(), tableRecords,
tableColumnNames);
}

private Optional<TableNameSchema> findTable(Document myRowDoc) {
return this.tableMetadataRepository
.findByTableNameIn(List.of(((String) myRowDoc.getMetadata().get(MetaData.TABLE_NAME)))).stream()
.map(myTableMetadata -> new TableNameSchema(myTableMetadata.getTableName(),
myTableMetadata.getTableDdl()))
.findFirst();
}

private List<Message> createMessages(SearchDto searchDto, final Float minRowDistance,
TableColumnNames tableColumnNames, List<TableNameSchema> tableRecords,
final AtomicReference<String> joinColumn, final AtomicReference<String> joinTable,
final AtomicReference<String> columnValue) {
final String joinColumn, final String joinTable,
final String columnValue) {
SystemPromptTemplate systemPromptTemplate = this.activeProfile.contains("ollama")
? new SystemPromptTemplate(minRowDistance > MAX_ROW_DISTANCE ? String.format(this.ollamaPrompt, "")
: String.format(this.ollamaPrompt, columnMatch))
Expand All @@ -188,8 +202,8 @@ private List<Message> createMessages(SearchDto searchDto, final Float minRowDist
Message systemMessage = systemPromptTemplate.createMessage(
Map.of("columns", tableColumnNames.columnNames().stream().collect(Collectors.joining(",")), "schemas",
tableRecords.stream().map(myRecord -> myRecord.schema()).collect(Collectors.joining(";")),
"prompt", searchDto.getSearchString(), "joinColumn", joinColumn.get(), "joinTable",
joinTable.get(), "columnValue", columnValue.get()));
"prompt", searchDto.getSearchString(), "joinColumn", joinColumn, "joinTable",
joinTable, "columnValue", columnValue));
UserMessage userMessage = this.activeProfile.contains("ollama") ? new UserMessage(systemMessage.getText())
: new UserMessage(searchDto.getSearchString());
return List.of(systemMessage, userMessage);
Expand Down

0 comments on commit 9cfbae8

Please sign in to comment.