Skip to content

Commit

Permalink
Feature: Integrate Milvus as the VectorDatabase
Browse files Browse the repository at this point in the history
  • Loading branch information
jinhonglin-ryan committed Dec 3, 2024
1 parent 42ab601 commit f650700
Show file tree
Hide file tree
Showing 6 changed files with 486 additions and 18 deletions.
43 changes: 29 additions & 14 deletions cognee/infrastructure/databases/vector/create_vector_engine.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Dict, TypedDict


class VectorConfig(TypedDict):
    """Connection settings for a vector-database backend.

    Consumed by ``create_vector_engine``, which reads the keys by
    subscription (e.g. ``config["vector_db_provider"]``).  ``TypedDict``
    replaces the original ``Dict`` subclass: it is a plain ``dict`` at
    runtime (fully backward compatible) but gives type checkers the
    per-key types the bare subclass annotations never provided.
    """
    vector_db_url: str       # endpoint URL (or local path for embedded engines like LanceDB/Milvus Lite)
    vector_db_port: str      # port, used by providers such as FalkorDB
    vector_db_key: str       # API key / token; may be empty for unauthenticated setups
    vector_db_provider: str  # one of: "weaviate", "qdrant", "milvus", "pgvector", "falkordb", default LanceDB


def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate":
from .weaviate_db import WeaviateAdapter
Expand All @@ -16,24 +18,37 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
return WeaviateAdapter(
config["vector_db_url"],
config["vector_db_key"],
embedding_engine = embedding_engine
embedding_engine=embedding_engine
)

elif config["vector_db_provider"] == "qdrant":
if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Missing requred Qdrant credentials!")

from .qdrant.QDrantAdapter import QDrantAdapter

return QDrantAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine
url=config["vector_db_url"],
api_key=config["vector_db_key"],
embedding_engine=embedding_engine
)

elif config['vector_db_provider'] == 'milvus':
from .milvus.MilvusAdapter import MilvusAdapter

if not config["vector_db_url"]:
raise EnvironmentError("Missing required Milvus credentials!")

return MilvusAdapter(
url=config["vector_db_url"],
api_key=config['vector_db_key'],
embedding_engine=embedding_engine
)


elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config

# Get configuration for postgres database
relational_config = get_relational_config()
db_username = relational_config.db_username
Expand All @@ -52,8 +67,8 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from .pgvector.PGVectorAdapter import PGVectorAdapter

return PGVectorAdapter(
connection_string,
config["vector_db_key"],
connection_string,
config["vector_db_key"],
embedding_engine,
)

Expand All @@ -64,16 +79,16 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

return FalkorDBAdapter(
database_url = config["vector_db_url"],
database_port = config["vector_db_port"],
embedding_engine = embedding_engine,
database_url=config["vector_db_url"],
database_port=config["vector_db_port"],
embedding_engine=embedding_engine,
)

else:
from .lancedb.LanceDBAdapter import LanceDBAdapter

return LanceDBAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine,
url=config["vector_db_url"],
api_key=config["vector_db_key"],
embedding_engine=embedding_engine,
)
245 changes: 245 additions & 0 deletions cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
import asyncio
import logging
from typing import List, Optional
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from pymilvus import MilvusClient

logger = logging.getLogger("MilvusAdapter")


class IndexSchema(DataPoint):
    """Minimal DataPoint stored per vector-index entry: a single text payload."""
    # The raw text that gets embedded and stored alongside the vector.
    text: str

    # Read by index_data_points(): names the DataPoint field(s) whose content
    # is embedded — here, only "text".
    _metadata: dict = {
        "index_fields": ["text"]
    }


class MilvusAdapter(VectorDBInterface):
    """Milvus implementation of the VectorDBInterface.

    A fresh ``MilvusClient`` is constructed from the stored ``url`` /
    ``api_key`` for every operation; only ``prune()`` explicitly closes
    its client.
    """
    name = "Milvus"
    url: str
    api_key: Optional[str]
    embedding_engine: EmbeddingEngine = None

    def __init__(self, url: str, api_key: Optional[str], embedding_engine: EmbeddingEngine):
        """Store the connection settings and the engine used to embed text."""
        self.url = url
        self.api_key = api_key

        self.embedding_engine = embedding_engine

    def get_milvus_client(self) -> MilvusClient:
        """Create a client for ``self.url``, passing the token only when an API key is set."""
        if self.api_key is not None:
            return MilvusClient(uri=self.url, token=self.api_key)
        return MilvusClient(uri=self.url)

    async def embed_data(self, data: List[str]) -> list[list[float]]:
        """Embed a batch of texts via the configured embedding engine."""
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        """Return whether ``collection_name`` exists.

        ``MilvusClient.has_collection`` is synchronous; the original wrapped
        it in an already-resolved ``asyncio.Future``, which added nothing —
        return the result directly.
        """
        client = self.get_milvus_client()
        return client.has_collection(collection_name=collection_name)

    async def create_collection(
        self,
        collection_name: str,
        payload_schema=None,
    ):
        """Create and load a collection with id/vector/text fields.

        Returns True on success or when the collection already exists.
        ``payload_schema`` is accepted for interface compatibility but unused.

        Raises:
            MilvusException: propagated after logging on any Milvus error.
        """
        from pymilvus import DataType, MilvusException
        client = self.get_milvus_client()
        if client.has_collection(collection_name=collection_name):
            logger.info(f"Collection '{collection_name}' already exists.")
            return True

        try:
            dimension = self.embedding_engine.get_vector_size()
            assert dimension > 0, "Embedding dimension must be greater than 0."

            schema = client.create_schema(
                auto_id=False,
                enable_dynamic_field=False,
            )

            # IDs are stored as their 36-character UUID string form.
            schema.add_field(
                field_name="id",
                datatype=DataType.VARCHAR,
                is_primary=True,
                max_length=36
            )

            schema.add_field(
                field_name="vector",
                datatype=DataType.FLOAT_VECTOR,
                dim=dimension
            )

            # 65535 is Milvus' maximum VARCHAR length; the original 60535
            # looked like a typo and needlessly truncated long texts.
            schema.add_field(
                field_name="text",
                datatype=DataType.VARCHAR,
                max_length=65535
            )

            index_params = client.prepare_index_params()
            index_params.add_index(
                field_name="vector",
                metric_type="COSINE"
            )

            client.create_collection(
                collection_name=collection_name,
                schema=schema,
                index_params=index_params
            )

            # Load into memory so it is immediately searchable.
            client.load_collection(collection_name)

            logger.info(f"Collection '{collection_name}' created successfully.")
            return True
        except MilvusException as e:
            logger.error(f"Error creating collection '{collection_name}': {str(e)}")
            raise e

    async def create_data_points(
        self,
        collection_name: str,
        data_points: List[DataPoint]
    ):
        """Embed and insert ``data_points``; returns the Milvus insert result.

        Raises:
            MilvusException: propagated after logging.
        """
        from pymilvus import MilvusException
        client = self.get_milvus_client()
        data_vectors = await self.embed_data(
            [data_point.get_embeddable_data() for data_point in data_points]
        )

        insert_data = [
            {
                "id": str(data_point.id),
                "vector": data_vectors[index],
                "text": data_point.text,
            }
            for index, data_point in enumerate(data_points)
        ]

        try:
            result = client.insert(
                collection_name=collection_name,
                data=insert_data
            )
            logger.info(
                f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
            )
            return result
        except MilvusException as e:
            logger.error(f"Error inserting data points into collection '{collection_name}': {str(e)}")
            raise e

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Create the backing collection for an (index, property) pair."""
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(self, index_name: str, index_property_name: str, data_points: List[DataPoint]):
        """Embed each data point's first declared index field and store it."""
        formatted_data_points = [
            IndexSchema(
                id=data_point.id,
                text=getattr(data_point, data_point._metadata["index_fields"][0]),
            )
            for data_point in data_points
        ]
        collection_name = f"{index_name}_{index_property_name}"
        await self.create_data_points(collection_name, formatted_data_points)

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        """Fetch all fields of the rows whose id is in ``data_point_ids``.

        Raises:
            MilvusException: propagated after logging.
        """
        from pymilvus import MilvusException
        client = self.get_milvus_client()
        try:
            filter_expression = f"""id in [{", ".join(f'"{point_id}"' for point_id in data_point_ids)}]"""

            # MilvusClient.query takes the boolean expression via `filter=`;
            # the original passed the unsupported `expr=` keyword.
            results = client.query(
                collection_name=collection_name,
                filter=filter_expression,
                output_fields=["*"],
            )
            return results
        except MilvusException as e:
            logger.error(f"Error retrieving data points from collection '{collection_name}': {str(e)}")
            raise e

    async def search(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = 5,
        with_vector: bool = False,
    ):
        """Cosine-similarity search; returns a list of ScoredResult.

        One of ``query_text`` / ``query_vector`` is required; the text is
        embedded only when no vector is supplied.

        Raises:
            ValueError: if neither query_text nor query_vector is given.
            MilvusException: propagated after logging.
        """
        from pymilvus import MilvusException
        if query_text is None and query_vector is None:
            raise ValueError("One of query_text or query_vector must be provided!")

        client = self.get_milvus_client()
        try:
            # Explicit None check: the original `query_vector or ...` would
            # discard a falsy (e.g. empty) vector and try to embed query_text.
            if query_vector is None:
                query_vector = (await self.embed_data([query_text]))[0]

            output_fields = ["id", "text"]
            if with_vector:
                output_fields.append("vector")

            results = client.search(
                collection_name=collection_name,
                data=[query_vector],
                anns_field="vector",
                limit=limit,
                output_fields=output_fields,
                search_params={
                    "metric_type": "COSINE",
                },
            )

            # Single query vector in, so results[0] holds its hit list.
            return [
                ScoredResult(
                    id=UUID(result["id"]),
                    score=result["distance"],
                    payload=result.get("entity", {}),
                )
                for result in results[0]
            ]
        except MilvusException as e:
            logger.error(f"Error during search in collection '{collection_name}': {str(e)}")
            raise e

    async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
        """Run one search per query text.

        Texts are embedded in a single batch, then the per-vector searches
        run concurrently (the original awaited them one at a time); results
        keep the order of ``query_texts``.
        """
        query_vectors = await self.embed_data(query_texts)
        return await asyncio.gather(*(
            self.search(collection_name, query_vector=vector, limit=limit, with_vector=with_vectors)
            for vector in query_vectors
        ))

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """Delete the rows whose id is in ``data_point_ids``.

        Raises:
            MilvusException: propagated after logging.
        """
        from pymilvus import MilvusException
        client = self.get_milvus_client()
        try:
            filter_expression = f"""id in [{", ".join(f'"{point_id}"' for point_id in data_point_ids)}]"""

            delete_result = client.delete(
                collection_name=collection_name,
                filter=filter_expression
            )

            logger.info(f"Deleted data points with IDs {data_point_ids} from collection '{collection_name}'.")
            return delete_result
        except MilvusException as e:
            logger.error(f"Error deleting data points from collection '{collection_name}': {str(e)}")
            raise e

    async def prune(self):
        """Drop every collection in the database, then close the client."""
        client = self.get_milvus_client()
        if client:
            for collection_name in client.list_collections():
                client.drop_collection(collection_name=collection_name)
            client.close()
1 change: 1 addition & 0 deletions cognee/infrastructure/databases/vector/milvus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .MilvusAdapter import MilvusAdapter
Loading

0 comments on commit f650700

Please sign in to comment.