Skip to content

Commit

Permalink
Added possibility to specify subqueries.
Browse files Browse the repository at this point in the history
  • Loading branch information
tonioo committed Oct 2, 2024
1 parent cad5936 commit 55e04bc
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 14 deletions.
48 changes: 41 additions & 7 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,17 @@ def __init__(


class AsyncQueryBuilder:
def __init__(self, node_set, with_subgraph: bool = False):
def __init__(
self, node_set, with_subgraph: bool = False, subquery_context: bool = False
):
self.node_set = node_set
self._ast = QueryAST()
self._query_params = {}
self._place_holder_registry = {}
self._ident_count = 0
self._node_counters = defaultdict(int)
self._with_subgraph: bool = with_subgraph
self._subquery_context: bool = subquery_context

async def build_ast(self):
if hasattr(self.node_set, "relations_to_fetch"):
Expand All @@ -442,7 +445,7 @@ async def build_ast(self):

return self

async def build_source(self, source):
async def build_source(self, source) -> str:
if isinstance(source, AsyncTraversal):
return await self.build_traversal(source)
if isinstance(source, AsyncNodeSet):
Expand Down Expand Up @@ -548,6 +551,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
if self._subquery_context:
# Don't include label in identifier if we are in a subquery
lhs_ident = lhs_name
elif relation["include_in_return"]:
self._additional_return(lhs_name)
else:
Expand Down Expand Up @@ -594,7 +600,7 @@ async def build_node(self, node):
self._ast.result_class = node.__class__
return ident

def build_label(self, ident, cls):
def build_label(self, ident, cls) -> str:
"""
match nodes by a label
"""
Expand Down Expand Up @@ -703,7 +709,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):

self._ast.where.append(" AND ".join(stmts))

def build_query(self):
def build_query(self) -> str:
query: str = ""

if self._ast.lookup:
Expand All @@ -730,9 +736,15 @@ def build_query(self):
query += " WITH "
query += self._ast.with_clause

query += " RETURN "
returned_items: list[str] = []
if self._ast.return_clause:
if hasattr(self.node_set, "_subqueries"):
for subquery, return_set in self.node_set._subqueries:
outer_primary_var: str = self._ast.return_clause
query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
returned_items += return_set

query += " RETURN "
if self._ast.return_clause and not self._subquery_context:
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
returned_items += self._ast.additional_return
Expand Down Expand Up @@ -960,6 +972,7 @@ def __init__(self, source):

self.relations_to_fetch: list = []
self._extra_results: dict[str] = {}
self._subqueries: list[tuple(str, list[str])] = []

def __await__(self):
return self.all().__await__()
Expand Down Expand Up @@ -1238,6 +1251,27 @@ async def resolve_subgraph(self) -> list:
)
return results

async def subquery(
self, nodeset: "AsyncNodeSet", return_set: list[str]
) -> "AsyncNodeSet":
"""Add a subquery to this node set.
A subquery is a regular cypher query but executed within the context of a CALL
statement. Such query will generally fetch additional variables which must be
declared inside return_set variable in order to be included in the final RETURN
statement.
"""
qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast()
for var in return_set:
if (
var != qbuilder._ast.return_clause
and var not in qbuilder._ast.additional_return
and var not in nodeset._extra_results
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
self._subqueries.append((qbuilder.build_query(), return_set))
return self


class AsyncTraversal(AsyncBaseSet):
"""
Expand All @@ -1251,7 +1285,7 @@ class AsyncTraversal(AsyncBaseSet):
:type name: :class:`str`
:param definition: A relationship definition that most certainly deserves
a documentation here.
:type defintion: :class:`dict`
:type definition: :class:`dict`
"""

def __await__(self):
Expand Down
46 changes: 39 additions & 7 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,17 @@ def __init__(


class QueryBuilder:
def __init__(self, node_set, with_subgraph: bool = False):
def __init__(
self, node_set, with_subgraph: bool = False, subquery_context: bool = False
):
self.node_set = node_set
self._ast = QueryAST()
self._query_params = {}
self._place_holder_registry = {}
self._ident_count = 0
self._node_counters = defaultdict(int)
self._with_subgraph: bool = with_subgraph
self._subquery_context: bool = subquery_context

def build_ast(self):
if hasattr(self.node_set, "relations_to_fetch"):
Expand All @@ -442,7 +445,7 @@ def build_ast(self):

return self

def build_source(self, source):
def build_source(self, source) -> str:
if isinstance(source, Traversal):
return self.build_traversal(source)
if isinstance(source, NodeSet):
Expand Down Expand Up @@ -548,6 +551,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
if self._subquery_context:
# Don't include label in identifier if we are in a subquery
lhs_ident = lhs_name
elif relation["include_in_return"]:
self._additional_return(lhs_name)
else:
Expand Down Expand Up @@ -594,7 +600,7 @@ def build_node(self, node):
self._ast.result_class = node.__class__
return ident

def build_label(self, ident, cls):
def build_label(self, ident, cls) -> str:
"""
match nodes by a label
"""
Expand Down Expand Up @@ -703,7 +709,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):

self._ast.where.append(" AND ".join(stmts))

def build_query(self):
def build_query(self) -> str:
query: str = ""

if self._ast.lookup:
Expand All @@ -730,9 +736,15 @@ def build_query(self):
query += " WITH "
query += self._ast.with_clause

query += " RETURN "
returned_items: list[str] = []
if self._ast.return_clause:
if hasattr(self.node_set, "_subqueries"):
for subquery, return_set in self.node_set._subqueries:
outer_primary_var: str = self._ast.return_clause
query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
returned_items += return_set

query += " RETURN "
if self._ast.return_clause and not self._subquery_context:
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
returned_items += self._ast.additional_return
Expand Down Expand Up @@ -958,6 +970,7 @@ def __init__(self, source):

self.relations_to_fetch: list = []
self._extra_results: dict[str] = {}
self._subqueries: list[tuple(str, list[str])] = []

def __await__(self):
return self.all().__await__()
Expand Down Expand Up @@ -1236,6 +1249,25 @@ def resolve_subgraph(self) -> list:
)
return results

def subquery(self, nodeset: "NodeSet", return_set: list[str]) -> "NodeSet":
"""Add a subquery to this node set.
A subquery is a regular cypher query but executed within the context of a CALL
statement. Such query will generally fetch additional variables which must be
declared inside return_set variable in order to be included in the final RETURN
statement.
"""
qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast()
for var in return_set:
if (
var != qbuilder._ast.return_clause
and var not in qbuilder._ast.additional_return
and var not in nodeset._extra_results
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
self._subqueries.append((qbuilder.build_query(), return_set))
return self


class Traversal(BaseSet):
"""
Expand All @@ -1249,7 +1281,7 @@ class Traversal(BaseSet):
:type name: :class:`str`
:param definition: A relationship definition that most certainly deserves
a documentation here.
:type defintion: :class:`dict`
:type definition: :class:`dict`
"""

def __await__(self):
Expand Down
36 changes: 36 additions & 0 deletions test/async_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,42 @@ async def test_resolve_subgraph_optional():
assert coffees._relations["species"] == arabica


@mark_async_test
async def test_subquery():
# Clean DB before we start anything...
await adb.cypher_query("MATCH (n) DETACH DELETE n")

arabica = await Species(name="Arabica").save()
nescafe = await Coffee(name="Nescafe", price=99).save()
supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()

await nescafe.suppliers.connect(supplier1)
await nescafe.suppliers.connect(supplier2)
await nescafe.species.connect(arabica)

result = await Coffee.nodes.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
supps=Collect("suppliers")
),
["supps"],
)
result = await result.all()
assert len(result) == 1
assert len(result[0][0][0]) == 2

with raises(
RuntimeError,
match=re.escape("Variable 'unknown' is not returned by subquery."),
):
result = await Coffee.nodes.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
supps=Collect("suppliers")
),
["unknown"],
)


@mark_async_test
async def test_issue_795():
jim = await PersonX(name="Jim", age=3).save() # Create
Expand Down
36 changes: 36 additions & 0 deletions test/sync_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,42 @@ def test_resolve_subgraph_optional():
assert coffees._relations["species"] == arabica


@mark_sync_test
def test_subquery():
# Clean DB before we start anything...
db.cypher_query("MATCH (n) DETACH DELETE n")

arabica = Species(name="Arabica").save()
nescafe = Coffee(name="Nescafe", price=99).save()
supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()

nescafe.suppliers.connect(supplier1)
nescafe.suppliers.connect(supplier2)
nescafe.species.connect(arabica)

result = Coffee.nodes.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
supps=Collect("suppliers")
),
["supps"],
)
result = result.all()
assert len(result) == 1
assert len(result[0][0][0]) == 2

with raises(
RuntimeError,
match=re.escape("Variable 'unknown' is not returned by subquery."),
):
result = Coffee.nodes.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
supps=Collect("suppliers")
),
["unknown"],
)


@mark_sync_test
def test_issue_795():
jim = PersonX(name="Jim", age=3).save() # Create
Expand Down

0 comments on commit 55e04bc

Please sign in to comment.