Skip to content

Commit

Permalink
feat: return best matching document
Browse files Browse the repository at this point in the history
  • Loading branch information
Angular2Guy committed Nov 14, 2023
1 parent a3ce251 commit 9c370ef
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,25 @@ public Document toEntity(DocumentDto dto) {
}

public DocumentDto toDto(Document entity) {
return this.toDto(entity, false);
}

private DocumentDto toDto(Document entity, boolean noContent) {
var dto = new DocumentDto();
dto.setDocumentContent(entity.getDocumentContent());
dto.setDocumentContent(noContent ? null : entity.getDocumentContent());
dto.setDocumentName(entity.getDocumentName());
dto.setDocumentType(entity.getDocumentType());
dto.setId(entity.getId());
return dto;
}


public DocumentDto toDtoNoContent(Document entity) {
return this.toDto(entity, true);
}

public DocumentSearchDto toDto(AiResult aiResult) {
var dto = new DocumentSearchDto();
dto.setDocuments(aiResult.documents().stream().map(myDoc -> this.toDto(myDoc)).toList());
dto.setDocuments(aiResult.documents().stream().map(myDoc -> this.toDtoNoContent(myDoc)).toList());
dto.setResultStrings(aiResult.generations().stream().map(myGen -> myGen.getText()).toList());
dto.setSearchString(aiResult.searchString());
return dto;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
public class DocumentService {
private static final Logger LOGGER = LoggerFactory.getLogger(DocumentService.class);
private static final String ID = "id";
private static final String DISTANCE = "distance";
private final DocumentRepository documentRepository;
private final DocumentVsRepository documentVsRepository;
private final AiClient aiClient;
Expand Down Expand Up @@ -83,18 +84,27 @@ record TikaDocumentAndContent(org.springframework.ai.document.Document document,
public AiResult queryDocuments(String query) {
var similarDocuments = this.documentVsRepository.retrieve(query);
LOGGER.info("Documents: {}", similarDocuments.size());
Message systemMessage = this.getSystemMessage(similarDocuments,
(similarDocuments.size() <= 0 ? 2000 : Math.floorDiv(2000, similarDocuments.size())));
var mostSimilar = similarDocuments.stream()
.sorted((myDocA, myDocB) -> ((Float) myDocA.getMetadata().get(DISTANCE))
.compareTo(((Float) myDocB.getMetadata().get(DISTANCE))))
.findFirst();
var documentChunks = mostSimilar.stream()
.flatMap(mySimilar -> similarDocuments.stream()
.filter(mySimilar1 -> mySimilar1.getMetadata().get(ID).equals(mySimilar.getMetadata().get(ID))))
.toList();
Message systemMessage = this.getSystemMessage(documentChunks,
(documentChunks.size() <= 0 ? 2000 : Math.floorDiv(2000, documentChunks.size())));
UserMessage userMessage = new UserMessage(query);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
LocalDateTime start = LocalDateTime.now();
AiResponse response = aiClient.generate(prompt);
LOGGER.info("AI response time: {}ms",
ZonedDateTime.of(LocalDateTime.now(), ZoneId.systemDefault()).toInstant().toEpochMilli()
- ZonedDateTime.of(start, ZoneId.systemDefault()).toInstant().toEpochMilli());
var documents = response.getGenerations().stream().map(myGen -> myGen.getInfo().get(ID))
.filter(myId -> (myId instanceof Long)).map(myId -> this.documentRepository.findById((Long) myId))
.filter(Optional::isPresent).map(Optional::get).toList();
var documents = mostSimilar.stream().map(myGen -> myGen.getMetadata().get(ID))
.map(myId -> (myId instanceof Integer ? Integer.valueOf((Integer) myId).longValue() : (Long) myId))
.map(myId -> this.documentRepository.findById(myId)).filter(Optional::isPresent).map(Optional::get)
.toList();
return new AiResult(query, response.getGenerations(), documents);
}

Expand Down

0 comments on commit 9c370ef

Please sign in to comment.