Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cog 533 pydantic unit tests #230

Merged
merged 20 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 76 additions & 67 deletions cognee/modules/graph/utils/get_graph_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model

def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):

def get_graph_from_model(
data_point: DataPoint, added_nodes=None, added_edges=None
):

if not added_nodes:
borisarzentar marked this conversation as resolved.
Show resolved Hide resolved
added_nodes = {}
if not added_edges:
added_edges = {}

nodes = []
edges = []

Expand All @@ -12,87 +21,87 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
for field_name, field_value in data_point:
if field_name == "_metadata":
continue

if isinstance(field_value, DataPoint):
elif isinstance(field_value, DataPoint):
excluded_properties.add(field_name)

property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue

if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
)

elif (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):
excluded_properties.add(field_name)

for item in field_value:
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
"metadata": {
"type": "list"
},
}))
added_edges[edge_key] = True
continue

data_point_properties[field_name] = field_value
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point, field_name, item, nodes, edges, added_nodes, added_edges
)
edges = [
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
for edge in edges
]
0xideas marked this conversation as resolved.
Show resolved Hide resolved
else:
data_point_properties[field_name] = field_value

SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
include_fields={
"_metadata": (dict, data_point._metadata),
},
exclude_fields = excluded_properties,
exclude_fields=excluded_properties,
)

if include_root:
nodes.append(SimpleDataPointModel(**data_point_properties))
nodes.append(SimpleDataPointModel(**data_point_properties))

return nodes, edges


def add_nodes_and_edges(
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
borisarzentar marked this conversation as resolved.
Show resolved Hide resolved
):

property_nodes, property_edges = get_graph_from_model(
field_value, dict(added_nodes), dict(added_edges)
)
0xideas marked this conversation as resolved.
Show resolved Hide resolved

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append(
(
data_point.id,
property_node.id,
field_name,
{
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S"
),
},
)
)
added_edges[str(edge_key)] = True

return (nodes, edges, added_nodes, added_edges)


def get_own_properties(property_nodes, property_edges):
own_properties = []

Expand Down
31 changes: 19 additions & 12 deletions cognee/modules/graph/utils/get_model_instance_from_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,35 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model

def merge_dicts(dict1, dict2, agg_fn):
merged_dict = {}
for key, value in dict1.items():
if key in dict2:
merged_dict[key] = agg_fn(value, dict2[key])
else:
merged_dict[key] = value

def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
node_map = {}
for key, value in dict2.items():
if key not in merged_dict:
merged_dict[key] = value
return merged_dict
0xideas marked this conversation as resolved.
Show resolved Hide resolved

for node in nodes:
node_map[node.id] = node
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list[tuple[str, str, str, dict[str, str]]], entity_id: str):
node_map = {node.id: node for node in nodes}

for edge in edges:
source_node = node_map[edge[0]]
target_node = node_map[edge[1]]
edge_label = edge[2]
edge_properties = edge[3] if len(edge) == 4 else {}
for source_node_id, target_node_id, edge_label, edge_properties in edges:
source_node = node_map[source_node_id]
target_node = node_map[target_node_id]
edge_metadata = edge_properties.get("metadata", {})
edge_type = edge_metadata.get("type")

if edge_type == "list":
0xideas marked this conversation as resolved.
Show resolved Hide resolved
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })

node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
new_model_dict = merge_dicts(source_node.model_dump(), { edge_label: [target_node] }, lambda a, b: a + b)
node_map[source_node_id] = NewModel(**new_model_dict)
else:
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })

node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
node_map[target_node_id] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })

return node_map[entity_id]
18 changes: 3 additions & 15 deletions cognee/tests/unit/interfaces/graph/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

import pytest

from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)


class CarTypeName(Enum):
Expand Down Expand Up @@ -47,8 +42,8 @@ class Person(DataPoint):
_metadata: dict = dict(index_fields=["name"])


@pytest.fixture(scope="session")
def graph_outputs():
@pytest.fixture(scope="function")
def boris():
boris = Person(
id="boris",
name="Boris",
Expand All @@ -70,11 +65,4 @@ def graph_outputs():
"expires_on": "2025-11-06",
},
)
nodes, edges = get_graph_from_model(boris)

car, person = nodes[0], nodes[1]
edge = edges[0]

parsed_person = get_model_instance_from_graph(nodes, edges, "boris")

return (car, person, edge, parsed_person)
return boris
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest

from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
count_society,
create_organization_recursive,
)


@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)

n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons

nodes, edges = get_graph_from_model(society)

assert (
len(nodes) == society_counts_total
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"

assert len(edges) == (
len(nodes) - 1
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
0xideas marked this conversation as resolved.
Show resolved Hide resolved
57 changes: 46 additions & 11 deletions cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth

EDGE_GROUND_TRUTH = (
CAR_SEDAN_EDGE = (
"car1",
"sedan",
"is_type",
{
"source_node_id": "car1",
"target_node_id": "sedan",
"relationship_name": "is_type",
},
)


BORIS_CAR_EDGE_GROUND_TRUTH = (
"boris",
"car1",
"owns_car",
Expand All @@ -12,6 +25,8 @@
},
)

CAR_TYPE_GROUND_TRUTH = {"id": "sedan"}

CAR_GROUND_TRUTH = {
"id": "car1",
"brand": "Toyota",
Expand All @@ -33,22 +48,42 @@
}


def test_extracted_person(graph_outputs):
(_, person, _, _) = graph_outputs
def test_extracted_car_type(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
car_type = nodes[0]
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)


def test_extracted_car(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
car = nodes[1]
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)


def test_extracted_person(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
person = nodes[2]
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
0xideas marked this conversation as resolved.
Show resolved Hide resolved


def test_extracted_car(graph_outputs):
(car, _, _, _) = graph_outputs
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
def test_extracted_car_sedan_edge(boris):
_, edges = get_graph_from_model(boris)
edge = edges[0]

assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
for key, ground_truth in CAR_SEDAN_EDGE[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"


def test_extracted_edge(graph_outputs):
(_, _, edge, _) = graph_outputs
def test_extracted_boris_car_edge(boris):
_, edges = get_graph_from_model(boris)
edge = edges[1]

assert (
EDGE_GROUND_TRUTH[:3] == edge[:3]
), f"{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
for key, ground_truth in EDGE_GROUND_TRUTH[3].items():
BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3]
), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
create_organization_recursive,
show_first_difference,
)


@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
nodes, edges = get_graph_from_model(society)
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
0xideas marked this conversation as resolved.
Show resolved Hide resolved

assert str(society) == (str(parsed_society)), show_first_difference(
str(society), str(parsed_society), "society", "parsed_society"
)
Loading
Loading