From c25eee1db608fc84fcd51c7e7544eb16b69d6d18 Mon Sep 17 00:00:00 2001 From: etvt <149910696+etvt@users.noreply.github.com> Date: Mon, 11 Dec 2023 17:50:09 +0100 Subject: [PATCH] Introduce proper dependency management. --- src/aitestdrive/controller/chat.py | 7 ++-- src/aitestdrive/controller/document.py | 6 +-- src/aitestdrive/di/__init__.py | 0 src/aitestdrive/di/factories.py | 5 +++ src/aitestdrive/di/singletons.py | 16 ++++++++ src/aitestdrive/persistence/qdrant.py | 51 +++++++++++++------------- src/aitestdrive/service/document.py | 27 +++++++------- 7 files changed, 67 insertions(+), 45 deletions(-) create mode 100644 src/aitestdrive/di/__init__.py create mode 100644 src/aitestdrive/di/factories.py create mode 100644 src/aitestdrive/di/singletons.py diff --git a/src/aitestdrive/controller/chat.py b/src/aitestdrive/controller/chat.py index 5ebfac9..f9fca2f 100644 --- a/src/aitestdrive/controller/chat.py +++ b/src/aitestdrive/controller/chat.py @@ -2,11 +2,11 @@ from typing import List import vertexai -from fastapi import APIRouter +from fastapi import APIRouter, Depends from vertexai.language_models import ChatModel from aitestdrive.common.models import ChatMessage, ChatRequest -from aitestdrive.service.document import document_service +from aitestdrive.service.document import DocumentService log = logging.getLogger(__name__) @@ -14,7 +14,8 @@ @api.post("/") -async def chat(request: ChatRequest) -> ChatMessage: +async def chat(request: ChatRequest, + document_service=Depends(DocumentService)) -> ChatMessage: log.debug(f"Request received: '{request}'") assert len(request.history) > 0 diff --git a/src/aitestdrive/controller/document.py b/src/aitestdrive/controller/document.py index 140dcde..b6fbd47 100644 --- a/src/aitestdrive/controller/document.py +++ b/src/aitestdrive/controller/document.py @@ -1,8 +1,8 @@ import logging -from fastapi import APIRouter +from fastapi import APIRouter, Depends -from aitestdrive.service.document import document_service +from aitestdrive.service.document import DocumentService log = logging.getLogger(__name__) @@ -10,7 +10,7 @@ @api.post("/re-vectorize-from-storage") -async def re_vectorize_documents_from_storage(): +async def re_vectorize_documents_from_storage(document_service=Depends(DocumentService)): log.info(f"Request received to re-vectorize documents from storage") await document_service.re_vectorize_documents_from_storage() log.info("Re-vectorization of documents done.") diff --git a/src/aitestdrive/di/__init__.py b/src/aitestdrive/di/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aitestdrive/di/factories.py b/src/aitestdrive/di/factories.py new file mode 100644 index 0000000..98fb92d --- /dev/null +++ b/src/aitestdrive/di/factories.py @@ -0,0 +1,5 @@ +from google.cloud import storage + + +def create_storage_client(): + return storage.Client() # cannot be created automatically by FastAPI's Depends(...) diff --git a/src/aitestdrive/di/singletons.py b/src/aitestdrive/di/singletons.py new file mode 100644 index 0000000..71d1da7 --- /dev/null +++ b/src/aitestdrive/di/singletons.py @@ -0,0 +1,16 @@ +from aitestdrive.persistence.qdrant import QdrantService + +__singletons = { + QdrantService: QdrantService() +} + + +def get(clazz): + return __singletons[clazz] + + +def depends(clazz): + async def async_dep(): + return get(clazz) + + return async_dep diff --git a/src/aitestdrive/persistence/qdrant.py b/src/aitestdrive/persistence/qdrant.py index 9a32500..a88c609 100644 --- a/src/aitestdrive/persistence/qdrant.py +++ b/src/aitestdrive/persistence/qdrant.py @@ -3,6 +3,7 @@ from qdrant_client import AsyncQdrantClient from qdrant_client.http.models import Distance, VectorParams, VectorStruct +from aitestdrive.common.async_locks import ReadWriteLock from aitestdrive.common.config import config @@ -14,39 +15,39 @@ def __init__(self): url=config.qdrant_url, api_key=config.qdrant_api_key, ) + self.__lock = ReadWriteLock() async def search(self, query_vector: VectorStruct, limit: int) -> List[dict[str, Any]]: - search_result = await self.__qdrant_client.search( - collection_name=QdrantService.__collection_name, - query_vector=query_vector, - query_filter=None, - limit=limit - ) - return [hit.payload for hit in search_result] + async with self.__lock.reader(): + search_result = await self.__qdrant_client.search( + collection_name=QdrantService.__collection_name, + query_vector=query_vector, + query_filter=None, + limit=limit + ) + return [hit.payload for hit in search_result] async def re_upload_collection(self, vector_size: int, vectors: Iterable[VectorStruct], payloads: Iterable[dict[str, Any]]): - if QdrantService.__collection_name in await self.__get_collection_names(): - await self.__qdrant_client.delete_collection(QdrantService.__collection_name) - - await self.__qdrant_client.create_collection( - collection_name=QdrantService.__collection_name, - vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), - ) - - self.__qdrant_client.upload_collection( - collection_name=QdrantService.__collection_name, - vectors=vectors, - payload=payloads, - ids=None, - batch_size=256 - ) + async with self.__lock.writer(): + if QdrantService.__collection_name in await self.__get_collection_names(): + await self.__qdrant_client.delete_collection(QdrantService.__collection_name) + + await self.__qdrant_client.create_collection( + collection_name=QdrantService.__collection_name, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), + ) + + self.__qdrant_client.upload_collection( + collection_name=QdrantService.__collection_name, + vectors=vectors, + payload=payloads, + ids=None, + batch_size=256 + ) async def __get_collection_names(self): collections = (await self.__qdrant_client.get_collections()).collections return map(lambda coll: coll.name, collections) - - -qdrant_service = QdrantService() diff --git a/src/aitestdrive/service/document.py b/src/aitestdrive/service/document.py index ecc2f03..21b4a1e 100644 --- a/src/aitestdrive/service/document.py +++ b/src/aitestdrive/service/document.py @@ -3,29 +3,32 @@ from typing import List import pdfplumber +from fastapi import Depends from google.cloud import storage from langchain.text_splitter import TokenTextSplitter from vertexai.language_models import TextEmbeddingModel -from aitestdrive.common.async_locks import ReadWriteLock from aitestdrive.common.config import config -from aitestdrive.persistence.qdrant import qdrant_service +from aitestdrive.di import singletons +from aitestdrive.di.factories import create_storage_client +from aitestdrive.persistence.qdrant import QdrantService log = logging.getLogger(__name__) class DocumentService: - def __init__(self): - self.__storage_client = storage.Client() + def __init__(self, + storage_client: storage.Client = Depends(create_storage_client), + qdrant_service: QdrantService = Depends(singletons.depends(QdrantService))): + self.__storage_client = storage_client + self.__qdrant_service = qdrant_service self.__embedding_model = TextEmbeddingModel.from_pretrained("textembedding-gecko") - self.__lock = ReadWriteLock() async def search_documents(self, query: str, limit: int = 5) -> List[str]: query_vector = (await self.__embedding_model.get_embeddings_async([query]))[0].values - async with self.__lock.reader(): - search_results = await qdrant_service.search(query_vector, limit=limit) + search_results = await self.__qdrant_service.search(query_vector, limit=limit) return [payload['text'] for payload in search_results] @@ -42,10 +45,9 @@ async def re_vectorize_documents_from_storage(self): embeddings = await self.__embedding_model.get_embeddings_async(chunks) assert len(embeddings) > 0 - async with self.__lock.writer(): - await qdrant_service.re_upload_collection(vector_size=len(embeddings[0].values), - vectors=map(lambda embedding: embedding.values, embeddings), - payloads=map(lambda chunk: {'text': chunk}, chunks)) + await self.__qdrant_service.re_upload_collection(vector_size=len(embeddings[0].values), + vectors=map(lambda embedding: embedding.values, embeddings), + payloads=map(lambda chunk: {'text': chunk}, chunks)) @staticmethod def extract_text_from_pdf(pdf_file_content): @@ -60,6 +62,3 @@ def chunk_text(text, chunk_size, chunk_overlap) -> List[str]: chunk_overlap=chunk_overlap) return [doc.page_content for doc in text_splitter.create_documents([text])] - - -document_service = DocumentService()