Skip to content

Commit

Permalink
added prompt history API
Browse files Browse the repository at this point in the history
  • Loading branch information
Amruth-Vamshi committed Nov 9, 2023
1 parent 57555d6 commit d2459c3
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 5 deletions.
38 changes: 37 additions & 1 deletion src/modules/prompt-history/prompt-history.controller.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,47 @@
import { Body, Controller, Post, Get, HttpException, HttpStatus, Param, Delete, NotFoundException } from "@nestjs/common";
import { PromptHistoryService } from "./prompt-history.service";
import { document as Document, Prisma } from "@prisma/client";
import { SearchPromptHistoryDto } from "./prompt.dto";
import { GetPromptHistoryDto, PromptHistoryResponse, SearchPromptHistoryDto } from "./prompt.dto";

@Controller("history")
export class PromptHistoryController {
constructor(private readonly promptHistoryService: PromptHistoryService) {}

@Post("find")
async findAll(
@Body() getDocumentsDto: GetPromptHistoryDto
): Promise<PromptHistoryResponse> {
try {
if(getDocumentsDto.filter.exactQuery){
const history = await this.promptHistoryService.findOneByExactQuery(getDocumentsDto.filter.exactQuery);
return {
history: [history],
pagination: {
page: 1,
totalPages: 1
}
}
}
if(
getDocumentsDto.filter &&
getDocumentsDto.filter.query &&
getDocumentsDto.filter.similarityThreshold &&
getDocumentsDto.filter.matchCount
) {
const documents = await this.promptHistoryService.getWithFilters(getDocumentsDto);
return documents
} else {
const page = getDocumentsDto.pagination.page || 1;
const perPage = getDocumentsDto.pagination.perPage || 10;
const documents = await this.promptHistoryService.findAll(page,perPage);
return documents;
}
} catch (error) {
throw new HttpException(error.message, HttpStatus.INTERNAL_SERVER_ERROR);
}
}


@Post("/searchSimilar")
async findByCriteria(
@Body() searchQueryDto: SearchPromptHistoryDto
Expand Down
5 changes: 4 additions & 1 deletion src/modules/prompt-history/prompt-history.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ import { Module } from "@nestjs/common";
import { PromptHistoryService } from "./prompt-history.service";
import { PrismaService } from "../../global-services/prisma.service";
import { ConfigService } from "@nestjs/config";
import { PromptHistoryController } from "./prompt-history.controller";
import { AiToolsService } from "../aiTools/ai-tools.service";

@Module({
providers: [PromptHistoryService, PrismaService, ConfigService],
providers: [PromptHistoryService, PrismaService, ConfigService, AiToolsService],
controllers: [PromptHistoryController]
})
export class PromptHistoryModule {}
104 changes: 101 additions & 3 deletions src/modules/prompt-history/prompt-history.service.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { BadRequestException, Injectable } from "@nestjs/common";
import {
document as Document,
prompt_history as PromptHistory,
prompt_history as Prompt_History,
} from "@prisma/client";
import { PrismaService } from "../../global-services/prisma.service";
import { ConfigService } from "@nestjs/config";
import { CreatePromptDto, SearchPromptHistoryDto } from "./prompt.dto";
import { CreatePromptDto, GetPromptHistoryDto, SearchPromptHistoryDto, PromptHistory, PromptHistoryResponse } from "./prompt.dto";
import { CustomLogger } from "../../common/logger";
import { AiToolsService } from "../aiTools/ai-tools.service";

Expand Down Expand Up @@ -75,7 +75,7 @@ export class PromptHistoryService {
// return document;
// }

async create(queryId): Promise<PromptHistory> {
async create(queryId): Promise<Prompt_History> {
try{
let query = await this.prisma.query.findFirst({
where: {
Expand Down Expand Up @@ -162,6 +162,18 @@ export class PromptHistoryService {
}
}

async findOneByExactQuery(query: string): Promise<PromptHistory | null> {
try {
const history: PromptHistory[] = await this.prisma.$queryRaw`
SELECT "createdAt", "updatedAt", id, "queryId", "responseTime", "queryInEnglish", "responseInEnglish"
FROM prompt_history where "queryInEnglish" = ${query}
`;
return history[0];
} catch {
return null;
}
}

async softDeleteRelatedToDocument(documentId) {
const affectedPromptHistories = await this.prisma.similarity_search_response.findMany({
where: {
Expand All @@ -186,4 +198,90 @@ export class PromptHistoryService {
);
return updated
}

async getWithFilters(getDocumentsDto: GetPromptHistoryDto): Promise<any> {
const page = getDocumentsDto.pagination.page || 1;
const perPage = getDocumentsDto.pagination.perPage || 10;
const embedding: any = (
await this.aiToolsService.getEmbedding(getDocumentsDto.filter.query)
)[0];
let query_embedding = `[${embedding
.map((x) => `${x}`)
.join(",")}]`
let similarity_threshold = getDocumentsDto.filter.similarityThreshold
let match_count = getDocumentsDto.filter.matchCount
let result = await this.prisma
.$queryRawUnsafe(`
WITH matched_docs AS (
SELECT
prompt_history.id as id,
1 - (prompt_history.embedding <=> '${query_embedding}') as similarity
FROM
prompt_history
WHERE
1 - (prompt_history.embedding <=> '${query_embedding}') > ${similarity_threshold}
ORDER BY
prompt_history.embedding <=> '${query_embedding}'
LIMIT ${match_count}
),
total_count AS (
SELECT COUNT(*) AS count
FROM matched_docs
)
SELECT
json_build_object(
'pagination',
json_build_object(
'page', $1,
'perPage', $2,
'totalPages', CEIL(total_count.count::numeric / $2),
'totalDocuments', total_count.count
),
'documents',
json_agg(
json_build_object(
'createdAt', doc."createdAt",
'updatedAt', doc."updatedAt",
'id', doc.id,
'queryId', doc."queryId",
'responseTime', doc."responseTime",
'queryInEnglish', doc."queryInEnglish",
'responseInEnglish', doc."responseInEnglish"
) ORDER BY matched_docs.similarity DESC
)
) AS result
FROM
matched_docs
JOIN document AS doc ON matched_docs.id = doc.id
CROSS JOIN total_count
GROUP BY total_count.count, $1, $2
OFFSET (($1 - 1) * $2)
LIMIT $2;
`,page,perPage);
return result[0]?.result || {
pagination: {
page: 1,
perPage: 10,
totalPages: 0,
totalDocument: 0
},
documents: []
};
}

async findAll(page: number, perPage: number) : Promise<PromptHistoryResponse>{
// using raw sql inorder to get embeddings.
const history:PromptHistory[] = await this.prisma.$queryRaw`
SELECT "createdAt", "updatedAt", id, "queryId", "responseTime", "queryInEnglish", "responseInEnglish"
FROM prompt_history
ORDER BY id
OFFSET ${(page - 1) * perPage}
LIMIT ${perPage}
`;
const totalDocuments = await this.prisma.document.count();
const totalPages = Math.ceil(totalDocuments / perPage);
const pagination = { page, perPage, totalPages, totalDocuments };
return { pagination, history };
}
}
42 changes: 42 additions & 0 deletions src/modules/prompt-history/prompt.dto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
IsDefined,
ValidateIf,
IsUUID,
IsOptional,
} from "class-validator";

export class CreatePromptDto {
Expand All @@ -30,6 +31,9 @@ export class CreatePromptDto {
}

export class SearchPromptHistoryDto {
@IsOptional()
exactQuery?: string;

@IsDefined({ message: "Query needs to be defined to search documents" })
query: string;

Expand All @@ -41,3 +45,41 @@ export class SearchPromptHistoryDto {
})
matchCount: number;
}

class Pagination {
@IsOptional()
@IsInt()
@Min(1)
page?: number;

@IsOptional()
@IsInt()
@Min(1)
perPage?: number;
}

export class GetPromptHistoryDto {
@IsOptional()
pagination?: Pagination

@IsOptional()
filter?: SearchPromptHistoryDto
}

export class PromptHistory {
createdAt: String;
updatedAt: String;
id: Number;
queryId:String;
responseTime: Number;
queryInEnglish: String;
responseInEnglish: String;
}

export interface PromptHistoryResponse {
history: PromptHistory[];
pagination: {
page: number;
totalPages: number;
}
}

0 comments on commit d2459c3

Please sign in to comment.