Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creates edge embeddings collection #251

Merged
3 changes: 3 additions & 0 deletions cognee/api/v1/cognify/cognify_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.tasks.summarization import summarize_text

logger = logging.getLogger("cognify.v2")
Expand Down Expand Up @@ -94,6 +95,8 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
async for result in pipeline:
print(result)

await index_graph_edges()

send_telemetry("cognee.cognify EXECUTION COMPLETED", user.id)

await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
Expand Down
3 changes: 3 additions & 0 deletions cognee/infrastructure/databases/graph/networkx/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ async def load_graph_from_file(self, file_path: str = None):
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")

self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)

for node_id, node_data in self.graph.nodes(data=True):
node_data['id'] = node_id
else:
# Log that the file does not exist and an empty graph is initialized
logger.warning("File %s not found. Initializing an empty graph.", file_path)
Expand Down
11 changes: 11 additions & 0 deletions cognee/modules/graph/models/EdgeType.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Optional
from cognee.infrastructure.engine import DataPoint

class EdgeType(DataPoint):
__tablename__ = "edge_type"
relationship_name: str
number_of_edges: int

_metadata: Optional[dict] = {
"index_fields": ["relationship_name"],
}
70 changes: 70 additions & 0 deletions cognee/tasks/storage/index_graph_edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
from collections import Counter

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType


async def index_graph_edges():
"""
Indexes graph edges by creating and managing vector indexes for relationship types.

This function retrieves edge data from the graph engine, counts distinct relationship
types, and creates `EdgeType` pydantic objects. It ensures that vector indexes are created for
the `relationship_name` field.

Steps:
1. Initialize the vector engine and graph engine.
2. Retrieve graph edge data and count relationship types (`relationship_name`).
3. Create vector indexes for `relationship_name` if they don't exist.
4. Transform the counted relationships into `EdgeType` objects.
5. Index the transformed data points in the vector engine.

Raises:
RuntimeError: If initialization of the vector engine or graph engine fails.

Returns:
None
"""
try:
created_indexes = {}
index_points = {}

vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
except Exception as e:
logging.error("Failed to initialize engines: %s", e)
raise RuntimeError("Initialization error") from e

_, edges_data = await graph_engine.get_graph_data()

edge_types = Counter(
item.get('relationship_name')
for edge in edges_data
for item in edge if isinstance(item, dict) and 'relationship_name' in item
)

for text, count in edge_types.items():
edge = EdgeType(relationship_name=text, number_of_edges=count)
data_point_type = type(edge)

for field_name in edge._metadata["index_fields"]:
index_name = f"{data_point_type.__tablename__}.{field_name}"

if index_name not in created_indexes:
await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
created_indexes[index_name] = True

if index_name not in index_points:
index_points[index_name] = []

indexed_data_point = edge.model_copy()
indexed_data_point._metadata["index_fields"] = [field_name]
index_points[index_name].append(indexed_data_point)

for index_name, indexable_points in index_points.items():
index_name, field_name = index_name.split(".")
await vector_engine.index_data_points(index_name, field_name, indexable_points)

return None
56 changes: 56 additions & 0 deletions cognee/tests/infrastructure/databases/test_index_graph_edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
from unittest.mock import AsyncMock, patch

@pytest.mark.asyncio
async def test_index_graph_edges_success():
"""Test that index_graph_edges uses the index datapoints and creates vector index."""
mock_graph_engine = AsyncMock()
mock_graph_engine.get_graph_data.return_value = (None, [
[{"relationship_name": "rel1"}, {"relationship_name": "rel1"}],
[{"relationship_name": "rel2"}]
])

mock_vector_engine = AsyncMock()

with patch("cognee.tasks.storage.index_graph_edges.get_graph_engine", return_value=mock_graph_engine), \
patch("cognee.tasks.storage.index_graph_edges.get_vector_engine", return_value=mock_vector_engine):

from cognee.tasks.storage.index_graph_edges import index_graph_edges
await index_graph_edges()

mock_graph_engine.get_graph_data.assert_awaited_once()
assert mock_vector_engine.create_vector_index.await_count == 1
assert mock_vector_engine.index_data_points.await_count == 1


@pytest.mark.asyncio
async def test_index_graph_edges_no_relationships():
"""Test that index_graph_edges handles empty relationships correctly."""
mock_graph_engine = AsyncMock()
mock_graph_engine.get_graph_data.return_value = (None, [])

mock_vector_engine = AsyncMock()

with patch("cognee.tasks.storage.index_graph_edges.get_graph_engine", return_value=mock_graph_engine), \
patch("cognee.tasks.storage.index_graph_edges.get_vector_engine", return_value=mock_vector_engine):

from cognee.tasks.storage.index_graph_edges import index_graph_edges
await index_graph_edges()

mock_graph_engine.get_graph_data.assert_awaited_once()
mock_vector_engine.create_vector_index.assert_not_awaited()
mock_vector_engine.index_data_points.assert_not_awaited()


@pytest.mark.asyncio
async def test_index_graph_edges_initialization_error():
"""Test that index_graph_edges raises a RuntimeError if initialization fails."""
with patch("cognee.tasks.storage.index_graph_edges.get_graph_engine", side_effect=Exception("Graph engine failed")), \
patch("cognee.tasks.storage.index_graph_edges.get_vector_engine", return_value=AsyncMock()):

from cognee.tasks.storage.index_graph_edges import index_graph_edges

with pytest.raises(RuntimeError, match="Initialization error"):
await index_graph_edges()

hajdul88 marked this conversation as resolved.
Show resolved Hide resolved

2 changes: 0 additions & 2 deletions examples/python/dynamic_steps_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,6 @@ async def main(enable_steps):
print(format_triplets(results))

if __name__ == '__main__':
# Flags to enable/disable steps

rebuild_kg = True
retrieve = True
steps_to_enable = {
Expand Down
Loading