Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cog 533 pydantic unit tests #230

Merged
merged 20 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 82 additions & 67 deletions cognee/modules/graph/utils/get_graph_from_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from datetime import datetime, timezone

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model

def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):

def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):

if not added_nodes:
borisarzentar marked this conversation as resolved.
Show resolved Hide resolved
added_nodes = {}
if not added_edges:
added_edges = {}

nodes = []
edges = []

Expand All @@ -12,87 +20,94 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
for field_name, field_value in data_point:
if field_name == "_metadata":
continue

if isinstance(field_value, DataPoint):
elif isinstance(field_value, DataPoint):
excluded_properties.add(field_name)

property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue

if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
)

elif (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):
excluded_properties.add(field_name)

for item in field_value:
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
"metadata": {
"type": "list"
},
}))
added_edges[edge_key] = True
continue

data_point_properties[field_name] = field_value
n_edges_before = len(edges)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point, field_name, item, nodes, edges, added_nodes, added_edges
)
edges = edges[:n_edges_before] + [
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
for edge in edges[n_edges_before:]
]
else:
data_point_properties[field_name] = field_value

SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
include_fields={
"_metadata": (dict, data_point._metadata),
},
exclude_fields = excluded_properties,
exclude_fields=excluded_properties,
)

if include_root:
nodes.append(SimpleDataPointModel(**data_point_properties))
nodes.append(SimpleDataPointModel(**data_point_properties))

return nodes, edges


def add_nodes_and_edges(
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
borisarzentar marked this conversation as resolved.
Show resolved Hide resolved
):

property_nodes, property_edges = get_graph_from_model(
field_value, dict(added_nodes), dict(added_edges)
)
0xideas marked this conversation as resolved.
Show resolved Hide resolved

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append(
(
data_point.id,
property_node.id,
field_name,
{
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S"
),
},
)
)
added_edges[str(edge_key)] = True

return (nodes, edges, added_nodes, added_edges)


def get_own_properties(property_nodes, property_edges):
own_properties = []

Expand Down
44 changes: 28 additions & 16 deletions cognee/modules/graph/utils/get_model_instance_from_graph.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
from typing import Callable

from pydantic_core import PydanticUndefined

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model


def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
node_map = {}
def get_model_instance_from_graph(
nodes: list[DataPoint],
edges: list[tuple[str, str, str, dict[str, str]]],
entity_id: str,
):
node_map = {node.id: node for node in nodes}

for node in nodes:
node_map[node.id] = node

for edge in edges:
source_node = node_map[edge[0]]
target_node = node_map[edge[1]]
edge_label = edge[2]
edge_properties = edge[3] if len(edge) == 4 else {}
for source_node_id, target_node_id, edge_label, edge_properties in edges:
source_node = node_map[source_node_id]
target_node = node_map[target_node_id]
edge_metadata = edge_properties.get("metadata", {})
edge_type = edge_metadata.get("type")
edge_type = edge_metadata.get("type", "default")

if edge_type == "list":
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })

node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
NewModel = copy_model(
type(source_node),
{edge_label: (list[type(target_node)], PydanticUndefined)},
)
source_node_dict = source_node.model_dump()
source_node_edge_label_values = source_node_dict.get(edge_label, [])
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]

node_map[source_node_id] = NewModel(**source_node_dict)
else:
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
NewModel = copy_model(
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
)

node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
node_map[target_node_id] = NewModel(
**source_node.model_dump(), **{edge_label: target_node}
)

return node_map[entity_id]
4 changes: 3 additions & 1 deletion cognee/modules/storage/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list
**include_fields
}

return create_model(model.__name__, **final_fields)
model = create_model(model.__name__, **final_fields)
model.model_rebuild()
return model

def get_own_properties(data_point: DataPoint):
properties = {}
Expand Down
18 changes: 3 additions & 15 deletions cognee/tests/unit/interfaces/graph/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

import pytest

from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)


class CarTypeName(Enum):
Expand Down Expand Up @@ -47,8 +42,8 @@ class Person(DataPoint):
_metadata: dict = dict(index_fields=["name"])


@pytest.fixture(scope="session")
def graph_outputs():
@pytest.fixture(scope="function")
def boris():
boris = Person(
id="boris",
name="Boris",
Expand All @@ -70,11 +65,4 @@ def graph_outputs():
"expires_on": "2025-11-06",
},
)
nodes, edges = get_graph_from_model(boris)

car, person = nodes[0], nodes[1]
edge = edges[0]

parsed_person = get_model_instance_from_graph(nodes, edges, "boris")

return (car, person, edge, parsed_person)
return boris
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import warnings

import pytest

from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
count_society,
create_organization_recursive,
)


@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
import sys

if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)

n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons

nodes, edges = get_graph_from_model(society)

assert (
len(nodes) == society_counts_total
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"

assert len(edges) == (
len(nodes) - 1
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)
Loading
Loading