Skip to content

Commit

Permalink
Various improvements about subqueries.
Browse files Browse the repository at this point in the history
  • Loading branch information
tonioo committed Dec 10, 2024
1 parent 60f84d1 commit 820556f
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 37 deletions.
14 changes: 10 additions & 4 deletions doc/source/advanced_query_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ As discussed in the note above, this is for example useful when you need to orde

Options for `intermediate_transform` *variables* are:

- `source`: `string`or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below).
- `source`: `string` or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below).
- `source_prop`: `string` - optionally, a property of the source variable to use as source for the transformation.
- `include_in_return`: `bool` - whether to include the variable in the return statement. Defaults to False.

Expand Down Expand Up @@ -95,7 +95,7 @@ Subqueries
The `subquery` method allows you to perform a `Cypher subquery <https://neo4j.com/docs/cypher-manual/current/subqueries/call-subquery/>`_ inside your query. This allows you to perform operations in isolation to the rest of your query::

from neomodel.sync_match import Collect, Last

# This will create a CALL{} subquery
# And return a variable named supps usable in the rest of your query
Coffee.nodes.filter(name="Espresso")
Expand All @@ -106,12 +106,18 @@ The `subquery` method allows you to perform a `Cypher subquery <https://neo4j.co
)
.annotate(supps=Last(Collect("suppliers"))),
["supps"],
[NodeNameResolver("self")]
)

Options for `subquery` calls are:

- `return_set`: list of `string` - the subquery variables that should be included in the outer query result
- `initial_context`: optional list of `string` or `Resolver` - the outer query variables that will be injected at the begining of the subquery

.. note::
Notice the subquery starts with Coffee.nodes ; neomodel will use this to know it needs to inject the source "coffee" variable generated by the outer query into the subquery. This means only Espresso coffee nodes will be considered in the subquery.
In the example above, we reference `self` to be included in the initial context. It will actually inject the outer variable corresponding to `Coffee` node.

We know this is confusing to read, but have not found a better wat to do this yet. If you have any suggestions, please let us know.
We know this is confusing to read, but have not found a better wat to do this yet. If you have any suggestions, please let us know.

Helpers
-------
Expand Down
75 changes: 59 additions & 16 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
import re
import string
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Optional as TOptional
Expand All @@ -13,7 +12,7 @@
from neomodel.exceptions import MultipleNodesReturned
from neomodel.match_q import Q, QBase
from neomodel.properties import AliasProperty, ArrayProperty, Property
from neomodel.typing import Transformation
from neomodel.typing import Subquery, Transformation
from neomodel.util import INCOMING, OUTGOING

CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
Expand Down Expand Up @@ -414,13 +413,13 @@ def __init__(


class AsyncQueryBuilder:
def __init__(self, node_set, subquery_context: bool = False) -> None:
def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None:
self.node_set = node_set
self._ast = QueryAST()
self._query_params: Dict = {}
self._place_holder_registry: Dict = {}
self._ident_count: int = 0
self._subquery_context: bool = subquery_context
self._subquery_namespace: TOptional[str] = subquery_namespace

async def build_ast(self) -> "AsyncQueryBuilder":
if hasattr(self.node_set, "relations_to_fetch"):
Expand Down Expand Up @@ -558,7 +557,7 @@ def build_traversal_from_path(
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
if self._subquery_context:
if self._subquery_namespace:
# Don't include label in identifier if we are in a subquery
lhs_ident = lhs_name
elif relation["include_in_return"]:
Expand Down Expand Up @@ -672,7 +671,10 @@ def _register_place_holder(self, key: str) -> str:
self._place_holder_registry[key] += 1
else:
self._place_holder_registry[key] = 1
return key + "_" + str(self._place_holder_registry[key])
place_holder = f"{key}_{self._place_holder_registry[key]}"
if self._subquery_namespace:
place_holder = f"{self._subquery_namespace}_{place_holder}"
return place_holder

def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]:
is_rel_filter = "|" in prop
Expand Down Expand Up @@ -879,10 +881,21 @@ def build_query(self) -> str:
query += ",".join(ordering)

if hasattr(self.node_set, "_subqueries"):
for subquery, return_set in self.node_set._subqueries:
outer_primary_var = self._ast.return_clause
query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
for varname in return_set:
for subquery in self.node_set._subqueries:
query += " CALL {"
if subquery["initial_context"]:
query += " WITH "
context: List[str] = []
for var in subquery["initial_context"]:
if isinstance(var, (NodeNameResolver, RelationNameResolver)):
context.append(var.resolve(self))
else:
context.append(var)

Check warning on line 893 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L893

Added line #L893 was not covered by tests
query += ",".join(context)

query += f"{subquery['query']} }} "
self._query_params.update(subquery["query_params"])
for varname in subquery["return_set"]:
# We declare the returned variables as "virtual" relations of the
# root node class to make sure they will be translated by a call to
# resolve_subgraph() (otherwise, they will be lost).
Expand All @@ -893,10 +906,10 @@ def build_query(self) -> str:
"variable_name": varname,
"rel_variable_name": varname,
}
returned_items += return_set
returned_items += subquery["return_set"]

query += " RETURN "
if self._ast.return_clause and not self._subquery_context:
if self._ast.return_clause and not self._subquery_namespace:
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
returned_items += self._ast.additional_return
Expand Down Expand Up @@ -1120,6 +1133,8 @@ class NodeNameResolver:
node: str

def resolve(self, qbuilder: AsyncQueryBuilder) -> str:
if self.node == "self" and qbuilder._ast.return_clause:
return qbuilder._ast.return_clause
result = qbuilder.lookup_query_variable(self.node)
if result is None:
raise ValueError(f"Unable to resolve variable name for node {self.node}")
Expand Down Expand Up @@ -1238,7 +1253,7 @@ def __init__(self, source) -> None:

self.relations_to_fetch: List = []
self._extra_results: List = []
self._subqueries: list[Tuple[str, list[str]]] = []
self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []

def __await__(self):
Expand Down Expand Up @@ -1525,7 +1540,10 @@ async def resolve_subgraph(self) -> list:
return results

async def subquery(
self, nodeset: "AsyncNodeSet", return_set: List[str]
self,
nodeset: "AsyncNodeSet",
return_set: List[str],
initial_context: TOptional[List[str]] = None,
) -> "AsyncNodeSet":
"""Add a subquery to this node set.
Expand All @@ -1534,16 +1552,41 @@ async def subquery(
declared inside return_set variable in order to be included in the final RETURN
statement.
"""
qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast()
namespace = f"sq{len(self._subqueries) + 1}"
qbuilder = await nodeset.query_cls(
nodeset, subquery_namespace=namespace
).build_ast()
for var in return_set:
if (
var != qbuilder._ast.return_clause
and var not in qbuilder._ast.additional_return
and var
not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
and var
not in [
varname
for tr in nodeset._intermediate_transforms
for varname, vardef in tr["vars"].items()
if vardef.get("include_in_return")
]
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
self._subqueries.append((qbuilder.build_query(), return_set))
if initial_context:
for var in initial_context:
if type(var) is not str and not isinstance(
var, (NodeNameResolver, RelationNameResolver, RawCypher)
):
raise ValueError(

Check warning on line 1579 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1579

Added line #L1579 was not covered by tests
f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver"
)
self._subqueries.append(
{
"query": qbuilder.build_query(),
"query_params": qbuilder._query_params,
"return_set": return_set,
"initial_context": initial_context,
}
)
return self

def intermediate_transform(
Expand Down
75 changes: 59 additions & 16 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
import re
import string
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Optional as TOptional
Expand All @@ -13,7 +12,7 @@
from neomodel.sync_ import relationship_manager
from neomodel.sync_.core import StructuredNode, db
from neomodel.sync_.relationship import StructuredRel
from neomodel.typing import Transformation
from neomodel.typing import Subquery, Transformation
from neomodel.util import INCOMING, OUTGOING

CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
Expand Down Expand Up @@ -414,13 +413,13 @@ def __init__(


class QueryBuilder:
def __init__(self, node_set, subquery_context: bool = False) -> None:
def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None:
self.node_set = node_set
self._ast = QueryAST()
self._query_params: Dict = {}
self._place_holder_registry: Dict = {}
self._ident_count: int = 0
self._subquery_context: bool = subquery_context
self._subquery_namespace: TOptional[str] = subquery_namespace

def build_ast(self) -> "QueryBuilder":
if hasattr(self.node_set, "relations_to_fetch"):
Expand Down Expand Up @@ -558,7 +557,7 @@ def build_traversal_from_path(
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
if self._subquery_context:
if self._subquery_namespace:
# Don't include label in identifier if we are in a subquery
lhs_ident = lhs_name
elif relation["include_in_return"]:
Expand Down Expand Up @@ -672,7 +671,10 @@ def _register_place_holder(self, key: str) -> str:
self._place_holder_registry[key] += 1
else:
self._place_holder_registry[key] = 1
return key + "_" + str(self._place_holder_registry[key])
place_holder = f"{key}_{self._place_holder_registry[key]}"
if self._subquery_namespace:
place_holder = f"{self._subquery_namespace}_{place_holder}"
return place_holder

def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]:
is_rel_filter = "|" in prop
Expand Down Expand Up @@ -879,10 +881,21 @@ def build_query(self) -> str:
query += ",".join(ordering)

if hasattr(self.node_set, "_subqueries"):
for subquery, return_set in self.node_set._subqueries:
outer_primary_var = self._ast.return_clause
query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
for varname in return_set:
for subquery in self.node_set._subqueries:
query += " CALL {"
if subquery["initial_context"]:
query += " WITH "
context: List[str] = []
for var in subquery["initial_context"]:
if isinstance(var, (NodeNameResolver, RelationNameResolver)):
context.append(var.resolve(self))
else:
context.append(var)

Check warning on line 893 in neomodel/sync_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/sync_/match.py#L893

Added line #L893 was not covered by tests
query += ",".join(context)

query += f"{subquery['query']} }} "
self._query_params.update(subquery["query_params"])
for varname in subquery["return_set"]:
# We declare the returned variables as "virtual" relations of the
# root node class to make sure they will be translated by a call to
# resolve_subgraph() (otherwise, they will be lost).
Expand All @@ -893,10 +906,10 @@ def build_query(self) -> str:
"variable_name": varname,
"rel_variable_name": varname,
}
returned_items += return_set
returned_items += subquery["return_set"]

query += " RETURN "
if self._ast.return_clause and not self._subquery_context:
if self._ast.return_clause and not self._subquery_namespace:
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
returned_items += self._ast.additional_return
Expand Down Expand Up @@ -1118,6 +1131,8 @@ class NodeNameResolver:
node: str

def resolve(self, qbuilder: QueryBuilder) -> str:
if self.node == "self" and qbuilder._ast.return_clause:
return qbuilder._ast.return_clause
result = qbuilder.lookup_query_variable(self.node)
if result is None:
raise ValueError(f"Unable to resolve variable name for node {self.node}")
Expand Down Expand Up @@ -1236,7 +1251,7 @@ def __init__(self, source) -> None:

self.relations_to_fetch: List = []
self._extra_results: List = []
self._subqueries: list[Tuple[str, list[str]]] = []
self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []

def __await__(self):
Expand Down Expand Up @@ -1522,24 +1537,52 @@ def resolve_subgraph(self) -> list:
)
return results

def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet":
def subquery(
self,
nodeset: "NodeSet",
return_set: List[str],
initial_context: TOptional[List[str]] = None,
) -> "NodeSet":
"""Add a subquery to this node set.
A subquery is a regular cypher query but executed within the context of a CALL
statement. Such query will generally fetch additional variables which must be
declared inside return_set variable in order to be included in the final RETURN
statement.
"""
qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast()
namespace = f"sq{len(self._subqueries) + 1}"
qbuilder = nodeset.query_cls(nodeset, subquery_namespace=namespace).build_ast()
for var in return_set:
if (
var != qbuilder._ast.return_clause
and var not in qbuilder._ast.additional_return
and var
not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
and var
not in [
varname
for tr in nodeset._intermediate_transforms
for varname, vardef in tr["vars"].items()
if vardef.get("include_in_return")
]
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
self._subqueries.append((qbuilder.build_query(), return_set))
if initial_context:
for var in initial_context:
if type(var) is not str and not isinstance(
var, (NodeNameResolver, RelationNameResolver, RawCypher)
):
raise ValueError(

Check warning on line 1575 in neomodel/sync_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/sync_/match.py#L1575

Added line #L1575 was not covered by tests
f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver"
)
self._subqueries.append(
{
"query": qbuilder.build_query(),
"query_params": qbuilder._query_params,
"return_set": return_set,
"initial_context": initial_context,
}
)
return self

def intermediate_transform(
Expand Down
13 changes: 12 additions & 1 deletion neomodel/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Custom types used for annotations."""

from typing import Any, Optional, TypedDict
from typing import Any, Dict, List, Optional, TypedDict

Transformation = TypedDict(
"Transformation",
Expand All @@ -10,3 +10,14 @@
"include_in_return": Optional[bool],
},
)


Subquery = TypedDict(
"Subquery",
{
"query": str,
"query_params": Dict,
"return_set": List[str],
"initial_context": Optional[List[Any]],
},
)
Loading

0 comments on commit 820556f

Please sign in to comment.