Skip to content

Commit

Permalink
feat: add metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
Angular2Guy committed Jan 21, 2024
1 parent b2e82f0 commit c0050dd
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public enum DocumentType {
// Keys used for the entries of a document's metadata map.
/** Metadata key for the identifier of the record a document was built from. */
public static final String ID = "id";
/** Metadata key for the data type of a document (stored as a string). */
public static final String DATATYPE = "datatype";
/** Metadata key for the name of the data item a document describes. */
public static final String DATANAME = "dataname";
/** Metadata key for the name of the table a document belongs to. */
public static final String TABLE_NAME = "tablename";
/** Metadata key for the distance value attached to a retrieved document. */
public static final String DISTANCE = "distance";
/** Metadata key for a referenced column; presumably the counterpart of REFERENCE_TABLE - confirm against usage. */
public static final String REFERENCE_COLUMN = "referenceColumn";
/** Metadata key for the name of a referenced table. */
public static final String REFERENCE_TABLE = "referenceTable";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@
*/
package ch.xxx.aidoclibchat.usecase.service;

import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;

Expand All @@ -46,13 +51,26 @@ public class TableService {
private final ImportService importService;
private final DocumentVsRepository documentVsRepository;
private final ChatClient chatClient;
// Prompt text for answering questions from a DOCUMENTS section; contains the
// placeholder {documents}.
// NOTE(review): systemPrompt is declared a second time directly below. This
// looks like the pre-change text shown by the diff view - confirm which
// declaration is the current one.
private final String systemPrompt = "You're assisting with questions about documents in a catalog.\n"
+ "Use the information from the DOCUMENTS section to provide accurate answers.\n"
+ "If unsure, simply state that you don't know.\n" + "\n" + "DOCUMENTS:\n" + "{documents}";
// Prompt text that tells the model to act as a Postgres expert: build a query
// for the input question, limit it to at most 5 results unless told otherwise,
// select only the needed columns and double-quote column names.
// Contains the placeholders {columns} and {schemas}; presumably they are
// filled in through a prompt template - confirm against the caller.
// NOTE(review): the text recommends date('now') for the current date. Prompts
// written for Postgres commonly name CURRENT_DATE instead - confirm that this
// wording is intended.
private final String systemPrompt = "You are a Postgres expert. Given an input question, create a "
+ "syntactically correct Postgres query to run, then look at the results "
+ "of the query and return the answer to the input question.\n"
+ "Unless the user specifies in the question a specific number of "
+ "examples to obtain, query for at most 5 results using the LIMIT clause "
+ "as per Postgres. You can order the results to return the most " + "informative data in the database.\n"
+ "Never query for all columns from a table. You must query only the "
+ "columns that are needed to answer the question. Wrap each column name "
+ "in double quotes (\") to denote them as delimited identifiers.\n"
+ "Pay attention to use only the column names you can see in the tables "
+ "below. Be careful to not query for columns that do not exist. Also, "
+ "pay attention to which column is in which table.\n"
+ "Pay attention to use date('now') function to get the current date, "
+ "if the question involves \"today\".\n\n" + "\n" + "Include these columns in the query: {columns}\n"
+ "Only use the following tables:\n\n" + "{schemas}\n";

// Prompt text for answering questions from a DOCUMENTS section with the
// question embedded; contains the placeholders {prompt} and {documents}.
// NOTE(review): ollamaPrompt is declared a second time directly below. This
// looks like the pre-change text shown by the diff view - confirm which
// declaration is the current one.
private final String ollamaPrompt = "You're assisting with questions about documents in a catalog.\n"
+ "Use the information from the DOCUMENTS section to provide accurate answers.\n"
+ "If unsure, simply state that you don't know.\n \n" + " {prompt} \n \n" + "DOCUMENTS:\n" + "{documents}";
// systemPrompt followed by a "Question: {prompt}" line, so the question is
// carried inside the same prompt text.
private final String ollamaPrompt = systemPrompt + "Question: {prompt}\n";
// Prompt fragment naming a column to join on; contains the placeholders
// {joinColumn}, {joinTable} and {columnValue}.
private final String columnMatch = "Join this column: {joinColumn}\n of this table: {joinTable}\n where the column has this value: {columnValue}\n";
// Value of the spring.profiles.active property; the trailing ':' in the
// placeholder makes it default to an empty string when the property is unset.
@Value("${spring.profiles.active:}")
private String activeProfile;

public TableService(ImportClient importClient, ImportService importService, ChatClient chatClient,
DocumentVsRepository documentVsRepository) {
Expand All @@ -69,16 +87,40 @@ public void searchTables(SearchDto searchDto) {
searchDto.getResultAmount());
var rowDocuments = this.documentVsRepository.retrieve(searchDto.getSearchString(), MetaData.DataType.ROW,
searchDto.getResultAmount());
LOGGER.info("Table: ");
tableDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}",
myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE)));
LOGGER.info("Column: ");

LOGGER.info("Table: "); tableDocuments.forEach(myDoc ->
LOGGER.info("name: {}, distance: {}",
myDoc.getMetadata().get(MetaData.DATANAME),
myDoc.getMetadata().get(MetaData.DISTANCE))); LOGGER.info("Column: ");
columnDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}",
myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE)));
LOGGER.info("Row: ");
rowDocuments.forEach(
myDoc -> LOGGER.info("name: {}, content: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME),
myDoc.getContent(), myDoc.getMetadata().get(MetaData.DISTANCE)));
myDoc.getMetadata().get(MetaData.DATANAME),
myDoc.getMetadata().get(MetaData.DISTANCE))); LOGGER.info("Row: ");
rowDocuments.forEach( myDoc ->
LOGGER.info("name: {}, content: {}, distance: {}",
myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getContent(),
myDoc.getMetadata().get(MetaData.DISTANCE)));

/*
Float minRowDistance = rowDocuments.stream()
.map(myDoc -> (Float) myDoc.getMetadata().getOrDefault(MetaData.DISTANCE, 1.0f)).sorted().findFirst()
.orElse(1.0f);
LOGGER.info("MinRowDistance: {}", minRowDistance);
var sortedRowDocs = rowDocuments.stream().sorted(this.compareDistance()).toList();
var sortedColumnDocs = columnDocuments.stream().sorted(this.compareDistance()).toList();
var sortedTableDocs = tableDocuments.stream().sorted(this.compareDistance()).toList();
SystemPromptTemplate systemPromptTemplate = this.activeProfile.contains("ollama")
? new SystemPromptTemplate(minRowDistance > 0.25 ? this.ollamaPrompt : this.ollamaPrompt + columnMatch)
: new SystemPromptTemplate(minRowDistance > 0.25 ? this.systemPrompt : this.systemPrompt + columnMatch);
//sortedColumnDocs.stream().filter(myColDoc -> myColDoc.getMetadata().get(MetaData.))
Message systemMessage = systemPromptTemplate.createMessage(Map.of("columns", "documentStr", "schemas", "prompt",
"prompt", "prompt", "joinColumn", "prompt", "joinTable", "prompt", "columnValue", "prompt"));
*/
}

/**
 * Builds a comparator that orders documents ascending by the value stored
 * under {@code MetaData.DISTANCE} in their metadata.
 *
 * <p>The metadata value is cast to {@link Float}. Applying the comparator to a
 * document without that entry throws a {@link NullPointerException}, and an
 * entry of another type throws a {@link ClassCastException}.
 *
 * @return comparator ordering documents by their distance metadata value
 */
private Comparator<? super Document> compareDistance() {
	return Comparator
			.comparing((Document candidate) -> (Float) candidate.getMetadata().get(MetaData.DISTANCE));
}

@Async
Expand Down Expand Up @@ -133,6 +175,7 @@ private Document map(Work work) {
result.getMetadata().put(MetaData.ID, work.getId());
result.getMetadata().put(MetaData.DATATYPE, MetaData.DataType.ROW.toString());
result.getMetadata().put(MetaData.DATANAME, "style");
result.getMetadata().put(MetaData.TABLE_NAME, "museum_hours");
return result;
}

Expand All @@ -141,6 +184,7 @@ private Document map(Subject subject) {
result.getMetadata().put(MetaData.ID, subject.getWorkId());
result.getMetadata().put(MetaData.DATATYPE, MetaData.DataType.ROW.toString());
result.getMetadata().put(MetaData.DATANAME, "subject");
result.getMetadata().put(MetaData.TABLE_NAME, "subject");
return result;
}

Expand All @@ -149,6 +193,7 @@ private Document map(ColumnMetadata columnMetadata) {
result.getMetadata().put(MetaData.ID, columnMetadata.getId().toString());
result.getMetadata().put(MetaData.DATATYPE, MetaData.DataType.COLUMN.toString());
result.getMetadata().put(MetaData.DATANAME, columnMetadata.getColumnName());
result.getMetadata().put(MetaData.TABLE_NAME, columnMetadata.getTableMetadata().getTableName());
result.getMetadata().put(MetaData.PRIMARY_KEY, columnMetadata.isColumnPrimaryKey());
Optional.ofNullable(columnMetadata.getReferenceTableName()).stream().filter(myStr -> !myStr.isBlank())
.findFirst().ifPresent(myStr -> result.getMetadata().put(MetaData.REFERENCE_TABLE, myStr));
Expand All @@ -162,6 +207,7 @@ private Document map(TableMetadata tableMetadata) {
result.getMetadata().put(MetaData.ID, tableMetadata.getId().toString());
result.getMetadata().put(MetaData.DATATYPE, MetaData.DataType.TABLE.toString());
result.getMetadata().put(MetaData.DATANAME, tableMetadata.getTableName());
result.getMetadata().put(MetaData.TABLE_NAME, tableMetadata.getTableName());
result.getMetadata().put(MetaData.PRIMARY_KEY, false);
result.getMetadata().put(MetaData.TABLE_DDL, tableMetadata.getTableDdl());
return result;
Expand Down

0 comments on commit c0050dd

Please sign in to comment.