From 55e04bcaa57114434ba354bb54dea0deceac3a02 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Wed, 2 Oct 2024 13:06:58 +0200 Subject: [PATCH] Added possibility to specify subqueries. --- neomodel/async_/match.py | 48 ++++++++++++++++++++++++++++++----- neomodel/sync_/match.py | 46 ++++++++++++++++++++++++++++----- test/async_/test_match_api.py | 36 ++++++++++++++++++++++++++ test/sync_/test_match_api.py | 36 ++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 14 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 03033415..f3390340 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -419,7 +419,9 @@ def __init__( class AsyncQueryBuilder: - def __init__(self, node_set, with_subgraph: bool = False): + def __init__( + self, node_set, with_subgraph: bool = False, subquery_context: bool = False + ): self.node_set = node_set self._ast = QueryAST() self._query_params = {} @@ -427,6 +429,7 @@ def __init__(self, node_set, with_subgraph: bool = False): self._ident_count = 0 self._node_counters = defaultdict(int) self._with_subgraph: bool = with_subgraph + self._subquery_context: bool = subquery_context async def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): @@ -442,7 +445,7 @@ async def build_ast(self): return self - async def build_source(self, source): + async def build_source(self, source) -> str: if isinstance(source, AsyncTraversal): return await self.build_traversal(source) if isinstance(source, AsyncNodeSet): @@ -548,6 +551,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name + if self._subquery_context: + # Don't include label in identifier if we are in a subquery + lhs_ident = lhs_name elif relation["include_in_return"]: self._additional_return(lhs_name) else: @@ -594,7 +600,7 @@ async def build_node(self, node): self._ast.result_class = node.__class__ return ident - def build_label(self, ident, cls): + def build_label(self, ident, cls) -> str: """ match nodes by a label """ @@ -703,7 +709,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) - def build_query(self): + def build_query(self) -> str: query: str = "" if self._ast.lookup: @@ -730,9 +736,15 @@ def build_query(self): query += " WITH " query += self._ast.with_clause - query += " RETURN " returned_items: list[str] = [] - if self._ast.return_clause: + if hasattr(self.node_set, "_subqueries"): + for subquery, return_set in self.node_set._subqueries: + outer_primary_var: str = self._ast.return_clause + query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " + returned_items += return_set + + query += " RETURN " + if self._ast.return_clause and not self._subquery_context: returned_items.append(self._ast.return_clause) if self._ast.additional_return: returned_items += self._ast.additional_return @@ -960,6 +972,7 @@ def __init__(self, source): self.relations_to_fetch: list = [] self._extra_results: dict[str] = {} + self._subqueries: list[tuple(str, list[str])] = [] def __await__(self): return self.all().__await__() @@ -1238,6 +1251,27 @@ async def resolve_subgraph(self) -> list: ) return results + async def subquery( + self, nodeset: "AsyncNodeSet", return_set: list[str] + ) -> "AsyncNodeSet": + """Add a subquery to this node set. + + A subquery is a regular cypher query but executed within the context of a CALL + statement. Such query will generally fetch additional variables which must be + declared inside return_set variable in order to be included in the final RETURN + statement. + """ + qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast() + for var in return_set: + if ( + var != qbuilder._ast.return_clause + and var not in qbuilder._ast.additional_return + and var not in nodeset._extra_results + ): + raise RuntimeError(f"Variable '{var}' is not returned by subquery.") + self._subqueries.append((qbuilder.build_query(), return_set)) + return self + class AsyncTraversal(AsyncBaseSet): """ @@ -1251,7 +1285,7 @@ class AsyncTraversal(AsyncBaseSet): :type name: :class:`str` :param definition: A relationship definition that most certainly deserves a documentation here. - :type defintion: :class:`dict` + :type definition: :class:`dict` """ def __await__(self): diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 0f8b044e..4ea10560 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -419,7 +419,9 @@ def __init__( class QueryBuilder: - def __init__(self, node_set, with_subgraph: bool = False): + def __init__( + self, node_set, with_subgraph: bool = False, subquery_context: bool = False + ): self.node_set = node_set self._ast = QueryAST() self._query_params = {} @@ -427,6 +429,7 @@ def __init__(self, node_set, with_subgraph: bool = False): self._ident_count = 0 self._node_counters = defaultdict(int) self._with_subgraph: bool = with_subgraph + self._subquery_context: bool = subquery_context def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): @@ -442,7 +445,7 @@ def build_ast(self): return self - def build_source(self, source): + def build_source(self, source) -> str: if isinstance(source, Traversal): return self.build_traversal(source) if isinstance(source, NodeSet): @@ -548,6 +551,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name + if self._subquery_context: + # Don't include label in identifier if we are in a subquery + lhs_ident = lhs_name elif relation["include_in_return"]: self._additional_return(lhs_name) else: @@ -594,7 +600,7 @@ def build_node(self, node): self._ast.result_class = node.__class__ return ident - def build_label(self, ident, cls): + def build_label(self, ident, cls) -> str: """ match nodes by a label """ @@ -703,7 +709,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) - def build_query(self): + def build_query(self) -> str: query: str = "" if self._ast.lookup: @@ -730,9 +736,15 @@ def build_query(self): query += " WITH " query += self._ast.with_clause - query += " RETURN " returned_items: list[str] = [] - if self._ast.return_clause: + if hasattr(self.node_set, "_subqueries"): + for subquery, return_set in self.node_set._subqueries: + outer_primary_var: str = self._ast.return_clause + query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " + returned_items += return_set + + query += " RETURN " + if self._ast.return_clause and not self._subquery_context: returned_items.append(self._ast.return_clause) if self._ast.additional_return: returned_items += self._ast.additional_return @@ -958,6 +970,7 @@ def __init__(self, source): self.relations_to_fetch: list = [] self._extra_results: dict[str] = {} + self._subqueries: list[tuple(str, list[str])] = [] def __await__(self): return self.all().__await__() @@ -1236,6 +1249,25 @@ def resolve_subgraph(self) -> list: ) return results + def subquery(self, nodeset: "NodeSet", return_set: list[str]) -> "NodeSet": + """Add a subquery to this node set. + + A subquery is a regular cypher query but executed within the context of a CALL + statement. Such query will generally fetch additional variables which must be + declared inside return_set variable in order to be included in the final RETURN + statement. + """ + qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast() + for var in return_set: + if ( + var != qbuilder._ast.return_clause + and var not in qbuilder._ast.additional_return + and var not in nodeset._extra_results + ): + raise RuntimeError(f"Variable '{var}' is not returned by subquery.") + self._subqueries.append((qbuilder.build_query(), return_set)) + return self + class Traversal(BaseSet): """ @@ -1249,7 +1281,7 @@ class Traversal(BaseSet): :type name: :class:`str` :param definition: A relationship definition that most certainly deserves a documentation here. - :type defintion: :class:`dict` + :type definition: :class:`dict` """ def __await__(self): diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index b3546ef9..c4ff24d9 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -716,6 +716,42 @@ async def test_resolve_subgraph_optional(): assert coffees._relations["species"] == arabica +@mark_async_test +async def test_subquery(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save() + + await nescafe.suppliers.connect(supplier1) + await nescafe.suppliers.connect(supplier2) + await nescafe.species.connect(arabica) + + result = await Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + ) + result = await result.all() + assert len(result) == 1 + assert len(result[0][0][0]) == 2 + + with raises( + RuntimeError, + match=re.escape("Variable 'unknown' is not returned by subquery."), + ): + result = await Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["unknown"], + ) + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index ab843639..09e19bc1 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -704,6 +704,42 @@ def test_resolve_subgraph_optional(): assert coffees._relations["species"] == arabica +@mark_sync_test +def test_subquery(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save() + + nescafe.suppliers.connect(supplier1) + nescafe.suppliers.connect(supplier2) + nescafe.species.connect(arabica) + + result = Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + ) + result = result.all() + assert len(result) == 1 + assert len(result[0][0][0]) == 2 + + with raises( + RuntimeError, + match=re.escape("Variable 'unknown' is not returned by subquery."), + ): + result = Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["unknown"], + ) + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create