-
Notifications
You must be signed in to change notification settings - Fork 156
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor retrievers vdms into E-RAG style.
Refine retrievers Dockerfile and requirements.txt and move `--extra-index-url` into Dockerfile for CPU Docker image. Fix issue #1004. Signed-off-by: letonghan <[email protected]>
- Loading branch information
Showing
5 changed files
with
260 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
import os | ||
import time | ||
|
||
from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceHubEmbeddings | ||
from langchain_community.vectorstores.vdms import VDMS, VDMS_Client | ||
|
||
from comps import CustomLogger, EmbedDoc, OpeaComponent, OpeaComponentRegistry, ServiceType | ||
|
||
from .config import EMBED_MODEL, TEI_EMBEDDING_ENDPOINT, VDMS_HOST, VDMS_PORT, VDMS_INDEX_NAME, VDMS_USE_CLIP, SEARCH_ENGINE, DISTANCE_STRATEGY | ||
|
||
logger = CustomLogger("vdms_retrievers")
# os.getenv returns a *string* whenever the variable is set, so values such as
# "False" or "0" would be truthy if used directly. Parse the value explicitly:
# unset/empty and the usual negative spellings disable logging, anything else
# enables it (which keeps every previously-working "enable" value working).
logflag = os.getenv("LOGFLAG", "").strip().lower() not in ("", "0", "false", "no", "off")
|
||
|
||
@OpeaComponentRegistry.register("OPEA_RETRIEVER_VDMS")
class OpeaVDMsRetriever(OpeaComponent):
    """Retriever component, derived from OpeaComponent, backed by a VDMS vector store.

    Attributes:
        embedder: Embedding backend selected by ``_initialize_embedder``.
        client: Object returned by ``VDMS_Client(VDMS_HOST, VDMS_PORT)``.
        vector_db (VDMS): Vector store wrapper that every search goes through.
    """

    def __init__(self, name: str, description: str, config: dict = None):
        """Build the embedder, the VDMS client and the vector store, then run a health check.

        A failed health check is only logged; construction still completes.
        """
        super().__init__(name, ServiceType.RETRIEVER.name.lower(), description, config)

        self.embedder = self._initialize_embedder()
        self.client = VDMS_Client(VDMS_HOST, VDMS_PORT)
        self.vector_db = self._initialize_vector_db()
        health_status = self.check_health()
        if not health_status:
            logger.error("OpeaVDMsRetriever health check failed.")

    def _initialize_embedder(self):
        """Select the embedding backend.

        Priority: CLIP (when ``VDMS_USE_CLIP`` is set), then the TEI endpoint
        (when ``TEI_EMBEDDING_ENDPOINT`` is set), otherwise the local
        ``EMBED_MODEL``.

        Returns:
            The embeddings object to hand to the VDMS vector store.
        """
        if VDMS_USE_CLIP:
            # Imported lazily so the CLIP dependency is only needed when used.
            from comps.third_parties.clip.src.clip_embedding import vCLIP

            embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 64})
        elif TEI_EMBEDDING_ENDPOINT:
            # `elif` (not `if`): otherwise the CLIP embedder chosen above is
            # unconditionally replaced by one of the two branches below.
            # create embeddings using TEI endpoint service
            if logflag:
                logger.info(f"[ init embedder ] TEI_EMBEDDING_ENDPOINT:{TEI_EMBEDDING_ENDPOINT}")
            embeddings = HuggingFaceHubEmbeddings(model=TEI_EMBEDDING_ENDPOINT)
        else:
            # create embeddings using local embedding model
            if logflag:
                logger.info(f"[ init embedder ] LOCAL_EMBEDDING_MODEL:{EMBED_MODEL}")
            embeddings = HuggingFaceBgeEmbeddings(model_name=EMBED_MODEL)
        return embeddings

    def _initialize_vector_db(self) -> VDMS:
        """Create the VDMS vector store bound to ``self.client`` and ``self.embedder``.

        When CLIP is in use, the embedding dimension reported by the embedder
        is passed explicitly; otherwise the VDMS default is left in place.
        """
        store_kwargs = {
            "client": self.client,
            "embedding": self.embedder,
            "collection_name": VDMS_INDEX_NAME,
            "distance_strategy": DISTANCE_STRATEGY,
            "engine": SEARCH_ENGINE,
        }
        if VDMS_USE_CLIP:
            store_kwargs["embedding_dimensions"] = self.embedder.get_embedding_length()
        return VDMS(**store_kwargs)

    def check_health(self) -> bool:
        """Check the health of the retriever service.

        Note: this only verifies that ``self.vector_db`` is truthy; it does not
        issue a request to the VDMS server.

        Returns:
            bool: True if the vector store object is available, False otherwise.
        """
        if logflag:
            logger.info("[ check health ] start to check health of vdms")
        try:
            if self.vector_db:
                logger.info("[ check health ] Successfully connected to VDMs!")
                return True
            # No exception exists on this path, so do not reference one.
            logger.info("[ check health ] Failed to connect to VDMs: vector store is not initialized")
            return False
        except Exception as e:
            logger.info(f"[ check health ] Failed to connect to VDMs: {e}")
            return False

    async def invoke(self, input: EmbedDoc) -> list:
        """Search the VDMs index for the most similar documents to the input query.

        Args:
            input (EmbedDoc): The query. ``search_type`` selects the strategy:
                "similarity" and "similarity_distance_threshold" search by
                ``input.embedding``; "similarity_score_threshold" and "mmr"
                search by ``input.text``. ``input.constraints`` is passed as
                the ``filter`` argument in every case.

        Returns:
            list: The retrieved documents.

        Raises:
            ValueError: If ``search_type`` is "similarity_distance_threshold"
                without a ``distance_threshold``, or if ``search_type`` is not
                one of the supported values.
        """
        if logflag:
            logger.info(input)

        if input.search_type == "similarity":
            search_res = self.vector_db.similarity_search_by_vector(
                embedding=input.embedding, k=input.k, filter=input.constraints
            )
        elif input.search_type == "similarity_distance_threshold":
            if input.distance_threshold is None:
                raise ValueError("distance_threshold must be provided for similarity_distance_threshold retriever")
            search_res = self.vector_db.similarity_search_by_vector(
                embedding=input.embedding,
                k=input.k,
                distance_threshold=input.distance_threshold,
                filter=input.constraints,
            )
        elif input.search_type == "similarity_score_threshold":
            docs_and_similarities = self.vector_db.similarity_search_with_relevance_scores(
                query=input.text, k=input.k, score_threshold=input.score_threshold, filter=input.constraints
            )
            # Callers only receive the documents; the scores are dropped.
            search_res = [doc for doc, _ in docs_and_similarities]
        elif input.search_type == "mmr":
            search_res = self.vector_db.max_marginal_relevance_search(
                query=input.text,
                k=input.k,
                fetch_k=input.fetch_k,
                lambda_mult=input.lambda_mult,
                filter=input.constraints,
            )
        else:
            # Without this branch `search_res` would be unbound and the method
            # would fail later with an opaque UnboundLocalError.
            raise ValueError(f"Unsupported search_type: {input.search_type!r}")

        if logflag:
            logger.info(f"retrieve result: {search_res}")

        return search_res
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
#!/bin/bash | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Trace every command (-x) and abort on the first failing command (-e).
set -xe

# The script is expected to be launched from a directory one level below the
# repository root; the parent of the current directory is used as the work path.
WORKPATH=$(dirname "$PWD")
LOG_PATH="${WORKPATH}/tests"
# First address reported by `hostname -I`; used to reach the containers.
ip_address=$(hostname -I | awk '{print $1}')
no_proxy="${no_proxy},${ip_address}"
|
||
# Build the retriever image (opea/retriever-vdms:comps) from $WORKPATH using
# comps/retrievers/src/Dockerfile. Proxy settings are forwarded as build args.
function build_docker_images() {
    # Quoted so a work path containing whitespace does not word-split.
    cd "$WORKPATH"
    hf_token="dummy"
    docker build --no-cache -t opea/retriever-vdms:comps \
        --build-arg https_proxy="$https_proxy" \
        --build-arg http_proxy="$http_proxy" \
        --build-arg huggingfacehub_api_token="$hf_token" \
        -f comps/retrievers/src/Dockerfile .
}
|
||
# Start the three containers the test needs, in order: the VDMS vector
# database, the TEI embedding endpoint, and the VDMS retriever microservice.
# Exports TEI_EMBEDDING_ENDPOINT and INDEX_NAME and unsets http_proxy.
function start_service() {
    # vdms
    vdms_port=55555
    docker run -d --name test-comps-retriever-vdms-vector-db \
        -p "$vdms_port:$vdms_port" intellabs/vdms:latest
    sleep 10s

    # tei endpoint
    tei_endpoint=5058
    model="BAAI/bge-base-en-v1.5"
    # HTTP_PROXY takes $http_proxy (it previously received $https_proxy).
    docker run -d --name="test-comps-retriever-vdms-tei-endpoint" \
        -p "$tei_endpoint:80" -v ./data:/data \
        -e HTTPS_PROXY="$https_proxy" -e HTTP_PROXY="$http_proxy" \
        --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 \
        --model-id "$model"
    sleep 30s

    export TEI_EMBEDDING_ENDPOINT="http://${ip_address}:${tei_endpoint}"

    export INDEX_NAME="rag-vdms"

    # vdms retriever
    # http_proxy is cleared here, so the retriever container below receives an
    # empty http_proxy value.
    unset http_proxy
    use_clip=0 #set to 1 if openai clip embedding should be used

    retriever_port=5059
    docker run -d --name="test-comps-retriever-vdms-server" -p "$retriever_port:7000" --ipc=host \
        -e VDMS_INDEX_NAME="$INDEX_NAME" -e VDMS_HOST="$ip_address" \
        -e https_proxy="$https_proxy" -e http_proxy="$http_proxy" \
        -e VDMS_PORT="$vdms_port" -e HUGGINGFACEHUB_API_TOKEN="$HUGGINGFACEHUB_API_TOKEN" \
        -e TEI_EMBEDDING_ENDPOINT="$TEI_EMBEDDING_ENDPOINT" -e VDMS_USE_CLIP="$use_clip" \
        -e RETRIEVER_COMPONENT_NAME="OPEA_RETRIEVER_VDMS" \
        opea/retriever-vdms:comps
    # Fixed wait before validation; presumably gives the retriever time to
    # finish starting up (no readiness probe is performed).
    sleep 3m
}
|
||
# POST a random embedding to the retriever's /v1/retrieval endpoint and check
# that it answers with HTTP 200 and a body containing "retrieved_docs".
# On failure the retriever container logs are appended to retriever.log and the
# script exits with status 1. The TEI container logs are always appended to
# tei.log on success.
function validate_microservice() {
    retriever_port=5059
    URL="http://${ip_address}:$retriever_port/v1/retrieval"

    # 768 random floats; presumably matches the embedding size of the model
    # served by the TEI container -- confirm if the model changes.
    test_embedding=$(python -c "import random; embedding = [random.uniform(-1, 1) for _ in range(768)]; print(embedding)")
    # Built once so both requests below send the identical payload.
    local payload="{\"text\":\"test\",\"embedding\":${test_embedding}}"

    HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -d "$payload" -H 'Content-Type: application/json' "$URL")

    echo "HTTP_STATUS = $HTTP_STATUS"

    if [ "$HTTP_STATUS" -eq 200 ]; then
        echo "[ retriever ] HTTP status is 200. Checking content..."
        # Second request: this time the body is kept and written to the log.
        local CONTENT=$(curl -s -X POST -d "$payload" -H 'Content-Type: application/json' "$URL" | tee "${LOG_PATH}/retriever.log")

        if echo "$CONTENT" | grep -q "retrieved_docs"; then
            echo "[ retriever ] Content is as expected."
        else
            echo "[ retriever ] Content does not match the expected result: $CONTENT"
            docker logs test-comps-retriever-vdms-server >> "${LOG_PATH}/retriever.log"
            exit 1
        fi
    else
        echo "[ retriever ] HTTP status is not 200. Received status was $HTTP_STATUS"
        docker logs test-comps-retriever-vdms-server >> "${LOG_PATH}/retriever.log"
        exit 1
    fi

    docker logs test-comps-retriever-vdms-tei-endpoint >> "${LOG_PATH}/tei.log"
}
|
||
# Stop and remove every container (running or not) whose name matches the
# test-comps-retriever-vdms* filter. Does nothing when none exist.
function stop_docker() {
    local container_ids
    container_ids=$(docker ps -aq --filter "name=test-comps-retriever-vdms*")
    if [[ -z "$container_ids" ]]; then
        return 0
    fi
    # Intentionally unquoted: the list may hold several IDs, one per word.
    docker stop $container_ids && docker rm $container_ids && sleep 1s
}
|
||
# Test driver: clean up leftovers, build the image, start the services,
# validate the retriever, then tear everything down and prune Docker.
function main() {
    # Remove containers left behind by a previous run.
    stop_docker

    build_docker_images
    start_service

    validate_microservice

    # Tear down and answer the prune confirmation prompt with "y".
    stop_docker
    echo y | docker system prune
}
|
||
main |