Skip to content

Commit

Permalink
Introduce proper dependency management.
Browse files Browse the repository at this point in the history
  • Loading branch information
etvt committed Dec 11, 2023
1 parent 9616e6d commit c25eee1
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 45 deletions.
7 changes: 4 additions & 3 deletions src/aitestdrive/controller/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
from typing import List

import vertexai
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from vertexai.language_models import ChatModel

from aitestdrive.common.models import ChatMessage, ChatRequest
from aitestdrive.service.document import document_service
from aitestdrive.service.document import DocumentService

log = logging.getLogger(__name__)

api = APIRouter(prefix="/chat", tags=["Chat"])


@api.post("/")
async def chat(request: ChatRequest) -> ChatMessage:
async def chat(request: ChatRequest,
document_service=Depends(DocumentService)) -> ChatMessage:
log.debug(f"Request received: '{request}'")

assert len(request.history) > 0
Expand Down
6 changes: 3 additions & 3 deletions src/aitestdrive/controller/document.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import logging

from fastapi import APIRouter
from fastapi import APIRouter, Depends

from aitestdrive.service.document import document_service
from aitestdrive.service.document import DocumentService

log = logging.getLogger(__name__)

api = APIRouter(prefix="/documents", tags=["Documents"])


@api.post("/re-vectorize-from-storage")
async def re_vectorize_documents_from_storage(document_service=Depends(DocumentService)):
    """Trigger a full re-vectorization of the documents held in storage.

    The DocumentService is injected per-request by FastAPI via Depends(...).
    """
    # Plain string: there is nothing to interpolate, so no f-prefix needed.
    log.info("Request received to re-vectorize documents from storage")
    await document_service.re_vectorize_documents_from_storage()
    log.info("Re-vectorization of documents done.")
Empty file added src/aitestdrive/di/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions src/aitestdrive/di/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from google.cloud import storage


def create_storage_client():
    """Factory producing a google.cloud ``storage.Client``.

    FastAPI's Depends(...) cannot instantiate this type automatically,
    hence the explicit factory function.
    """
    client = storage.Client()
    return client
16 changes: 16 additions & 0 deletions src/aitestdrive/di/singletons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from aitestdrive.persistence.qdrant import QdrantService

# Application-wide singleton registry, keyed by class.
# NOTE: instances are created eagerly, at module import time.
__singletons = {
    QdrantService: QdrantService()
}


def get(clazz):
    """Return the singleton instance registered for *clazz*.

    Raises:
        KeyError: if no singleton was registered for *clazz*.
    """
    try:
        return __singletons[clazz]
    except KeyError:
        # Re-raise with a message that points at the missing registration
        # instead of a bare class repr; still a KeyError for callers.
        raise KeyError(f"no singleton registered for {clazz!r}") from None


def depends(clazz):
    """Build an async provider for FastAPI's Depends(...) that resolves
    to the singleton registered for *clazz*."""

    async def _resolve():
        return get(clazz)

    return _resolve
51 changes: 26 additions & 25 deletions src/aitestdrive/persistence/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from qdrant_client import AsyncQdrantClient
from qdrant_client.http.models import Distance, VectorParams, VectorStruct

from aitestdrive.common.async_locks import ReadWriteLock
from aitestdrive.common.config import config


Expand All @@ -14,39 +15,39 @@ def __init__(self):
url=config.qdrant_url,
api_key=config.qdrant_api_key,
)
self.__lock = ReadWriteLock()

async def search(self, query_vector: VectorStruct, limit: int) -> List[dict[str, Any]]:
    """Search the collection and return the payloads of the best hits.

    Args:
        query_vector: embedding vector to search with.
        limit: maximum number of hits to return.

    Returns:
        The payload dict of each hit, best matches first.
    """
    # Reader side of the lock: searches may run concurrently with each
    # other, but not while a re-upload (writer) is in progress.
    async with self.__lock.reader():
        search_result = await self.__qdrant_client.search(
            collection_name=QdrantService.__collection_name,
            query_vector=query_vector,
            query_filter=None,
            limit=limit
        )
    return [hit.payload for hit in search_result]

async def re_upload_collection(self,
                               vector_size: int,
                               vectors: Iterable[VectorStruct],
                               payloads: Iterable[dict[str, Any]]):
    """Replace the collection with freshly vectorized content.

    Holds the writer side of the lock so no search can observe the
    collection while it is being deleted and rebuilt.

    Args:
        vector_size: dimensionality of the embedding vectors.
        vectors: embedding vectors to upload.
        payloads: payload dict for each vector, in the same order.
    """
    async with self.__lock.writer():
        if QdrantService.__collection_name in await self.__get_collection_names():
            await self.__qdrant_client.delete_collection(QdrantService.__collection_name)

        await self.__qdrant_client.create_collection(
            collection_name=QdrantService.__collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )

        # NOTE(review): upload_collection is not awaited here — confirm it is
        # a synchronous method on AsyncQdrantClient in the pinned client version.
        self.__qdrant_client.upload_collection(
            collection_name=QdrantService.__collection_name,
            vectors=vectors,
            payload=payloads,
            ids=None,
            batch_size=256
        )

async def __get_collection_names(self):
    """Return the names of all collections currently present in Qdrant."""
    collections = (await self.__qdrant_client.get_collections()).collections
    # A concrete list (instead of a one-shot lazy map object) can be
    # iterated or tested for membership more than once.
    return [collection.name for collection in collections]


qdrant_service = QdrantService()
27 changes: 13 additions & 14 deletions src/aitestdrive/service/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,32 @@
from typing import List

import pdfplumber
from fastapi import Depends
from google.cloud import storage
from langchain.text_splitter import TokenTextSplitter
from vertexai.language_models import TextEmbeddingModel

from aitestdrive.common.async_locks import ReadWriteLock
from aitestdrive.common.config import config
from aitestdrive.persistence.qdrant import qdrant_service
from aitestdrive.di import singletons
from aitestdrive.di.factories import create_storage_client
from aitestdrive.persistence.qdrant import QdrantService

log = logging.getLogger(__name__)


class DocumentService:

def __init__(self,
             storage_client: storage.Client = Depends(create_storage_client),
             qdrant_service: QdrantService = Depends(singletons.depends(QdrantService))):
    """Dependencies are injected by FastAPI; the defaults wire in the
    production implementations via Depends(...).

    Args:
        storage_client: client for the document bucket.
        qdrant_service: vector-store service (application singleton).
    """
    self.__storage_client = storage_client
    self.__qdrant_service = qdrant_service
    self.__embedding_model = TextEmbeddingModel.from_pretrained("textembedding-gecko")

async def search_documents(self, query: str, limit: int = 5) -> List[str]:
    """Return the text chunks most similar to *query*.

    Args:
        query: free-text search query.
        limit: maximum number of chunks to return.

    Returns:
        The 'text' payload of each matching chunk, best matches first.
    """
    query_vector = (await self.__embedding_model.get_embeddings_async([query]))[0].values

    # Concurrency control lives inside QdrantService.search itself.
    search_results = await self.__qdrant_service.search(query_vector, limit=limit)

    return [payload['text'] for payload in search_results]

Expand All @@ -42,10 +45,9 @@ async def re_vectorize_documents_from_storage(self):
embeddings = await self.__embedding_model.get_embeddings_async(chunks)
assert len(embeddings) > 0

async with self.__lock.writer():
await qdrant_service.re_upload_collection(vector_size=len(embeddings[0].values),
vectors=map(lambda embedding: embedding.values, embeddings),
payloads=map(lambda chunk: {'text': chunk}, chunks))
await self.__qdrant_service.re_upload_collection(vector_size=len(embeddings[0].values),
vectors=map(lambda embedding: embedding.values, embeddings),
payloads=map(lambda chunk: {'text': chunk}, chunks))

@staticmethod
def extract_text_from_pdf(pdf_file_content):
Expand All @@ -60,6 +62,3 @@ def chunk_text(text, chunk_size, chunk_overlap) -> List[str]:
chunk_overlap=chunk_overlap)

return [doc.page_content for doc in text_splitter.create_documents([text])]


document_service = DocumentService()

0 comments on commit c25eee1

Please sign in to comment.