diff --git a/Changelog b/Changelog
index 4f78f667..a93e806e 100644
--- a/Changelog
+++ b/Changelog
@@ -1,5 +1,7 @@
Vesion 5.4.2 2024-12
* Add support for Neo4j Rust driver extension : pip install neomodel[rust-driver-ext]
+* Add initial_context parameter to subqueries
+* NodeNameResolver can call self to reference top-level node
Version 5.4.1 2024-11
* Add support for Cypher parallel runtime
diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst
index 73c5bbd6..74c15683 100644
--- a/doc/source/advanced_query_operations.rst
+++ b/doc/source/advanced_query_operations.rst
@@ -60,7 +60,7 @@ As discussed in the note above, this is for example useful when you need to orde
Options for `intermediate_transform` *variables* are:
-- `source`: `string`or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below).
+- `source`: `string` or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below).
- `source_prop`: `string` - optionally, a property of the source variable to use as source for the transformation.
- `include_in_return`: `bool` - whether to include the variable in the return statement. Defaults to False.
@@ -95,7 +95,7 @@ Subqueries
The `subquery` method allows you to perform a `Cypher subquery `_ inside your query. This allows you to perform operations in isolation to the rest of your query::
from neomodel.sync_match import Collect, Last
-
+
# This will create a CALL{} subquery
# And return a variable named supps usable in the rest of your query
Coffee.nodes.filter(name="Espresso")
@@ -106,12 +106,18 @@ The `subquery` method allows you to perform a `Cypher subquery None:
+ def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None:
self.node_set = node_set
self._ast = QueryAST()
self._query_params: dict = {}
self._place_holder_registry: dict = {}
self._ident_count: int = 0
- self._subquery_context: bool = subquery_context
+ self._subquery_namespace: TOptional[str] = subquery_namespace
async def build_ast(self) -> "AsyncQueryBuilder":
if hasattr(self.node_set, "relations_to_fetch"):
@@ -563,7 +563,7 @@ def build_traversal_from_path(
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
- if self._subquery_context:
+ if self._subquery_namespace:
# Don't include label in identifier if we are in a subquery
lhs_ident = lhs_name
elif relation["include_in_return"]:
@@ -677,7 +677,10 @@ def _register_place_holder(self, key: str) -> str:
self._place_holder_registry[key] += 1
else:
self._place_holder_registry[key] = 1
- return key + "_" + str(self._place_holder_registry[key])
+ place_holder = f"{key}_{self._place_holder_registry[key]}"
+ if self._subquery_namespace:
+ place_holder = f"{self._subquery_namespace}_{place_holder}"
+ return place_holder
def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]:
is_rel_filter = "|" in prop
@@ -884,10 +887,21 @@ def build_query(self) -> str:
query += ",".join(ordering)
if hasattr(self.node_set, "_subqueries"):
- for subquery, return_set in self.node_set._subqueries:
- outer_primary_var = self._ast.return_clause
- query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
- for varname in return_set:
+ for subquery in self.node_set._subqueries:
+ query += " CALL {"
+ if subquery["initial_context"]:
+ query += " WITH "
+ context: List[str] = []
+ for var in subquery["initial_context"]:
+ if isinstance(var, (NodeNameResolver, RelationNameResolver)):
+ context.append(var.resolve(self))
+ else:
+ context.append(var)
+ query += ",".join(context)
+
+ query += f"{subquery['query']} }} "
+ self._query_params.update(subquery["query_params"])
+ for varname in subquery["return_set"]:
# We declare the returned variables as "virtual" relations of the
# root node class to make sure they will be translated by a call to
# resolve_subgraph() (otherwise, they will be lost).
@@ -898,10 +912,10 @@ def build_query(self) -> str:
"variable_name": varname,
"rel_variable_name": varname,
}
- returned_items += return_set
+ returned_items += subquery["return_set"]
query += " RETURN "
- if self._ast.return_clause and not self._subquery_context:
+ if self._ast.return_clause and not self._subquery_namespace:
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
returned_items += self._ast.additional_return
@@ -1128,6 +1142,8 @@ class NodeNameResolver:
node: str
def resolve(self, qbuilder: AsyncQueryBuilder) -> str:
+ if self.node == "self" and qbuilder._ast.return_clause:
+ return qbuilder._ast.return_clause
result = qbuilder.lookup_query_variable(self.node)
if result is None:
raise ValueError(f"Unable to resolve variable name for node {self.node}")
@@ -1246,7 +1262,7 @@ def __init__(self, source) -> None:
self.relations_to_fetch: list = []
self._extra_results: list = []
- self._subqueries: list[Tuple[str, list[str]]] = []
+ self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []
def __await__(self):
@@ -1534,7 +1550,10 @@ async def resolve_subgraph(self) -> list:
return results
async def subquery(
- self, nodeset: "AsyncNodeSet", return_set: list[str]
+ self,
+ nodeset: "AsyncNodeSet",
+ return_set: list[str],
+ initial_context: TOptional[list[str]] = None,
) -> "AsyncNodeSet":
"""Add a subquery to this node set.
@@ -1543,7 +1562,10 @@ async def subquery(
declared inside return_set variable in order to be included in the final RETURN
statement.
"""
- qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast()
+ namespace = f"sq{len(self._subqueries) + 1}"
+ qbuilder = await nodeset.query_cls(
+ nodeset, subquery_namespace=namespace
+ ).build_ast()
for var in return_set:
if (
var != qbuilder._ast.return_clause
@@ -1553,9 +1575,31 @@ async def subquery(
)
and var
not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
+ and var
+ not in [
+ varname
+ for tr in nodeset._intermediate_transforms
+ for varname, vardef in tr["vars"].items()
+ if vardef.get("include_in_return")
+ ]
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
- self._subqueries.append((qbuilder.build_query(), return_set))
+ if initial_context:
+ for var in initial_context:
+ if type(var) is not str and not isinstance(
+ var, (NodeNameResolver, RelationNameResolver, RawCypher)
+ ):
+ raise ValueError(
+ f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver"
+ )
+ self._subqueries.append(
+ {
+ "query": qbuilder.build_query(),
+ "query_params": qbuilder._query_params,
+ "return_set": return_set,
+ "initial_context": initial_context,
+ }
+ )
return self
def intermediate_transform(
diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py
index 0ae98b5f..463dfa5a 100644
--- a/neomodel/sync_/match.py
+++ b/neomodel/sync_/match.py
@@ -12,7 +12,7 @@
from neomodel.sync_ import relationship_manager
from neomodel.sync_.core import StructuredNode, db
from neomodel.sync_.relationship import StructuredRel
-from neomodel.typing import Transformation
+from neomodel.typing import Subquery, Transformation
from neomodel.util import INCOMING, OUTGOING
CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
@@ -415,13 +415,13 @@ def __init__(
class QueryBuilder:
- def __init__(self, node_set, subquery_context: bool = False) -> None:
+ def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None:
self.node_set = node_set
self._ast = QueryAST()
self._query_params: dict = {}
self._place_holder_registry: dict = {}
self._ident_count: int = 0
- self._subquery_context: bool = subquery_context
+ self._subquery_namespace: TOptional[str] = subquery_namespace
def build_ast(self) -> "QueryBuilder":
if hasattr(self.node_set, "relations_to_fetch"):
@@ -563,7 +563,7 @@ def build_traversal_from_path(
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
- if self._subquery_context:
+ if self._subquery_namespace:
# Don't include label in identifier if we are in a subquery
lhs_ident = lhs_name
elif relation["include_in_return"]:
@@ -677,7 +677,10 @@ def _register_place_holder(self, key: str) -> str:
self._place_holder_registry[key] += 1
else:
self._place_holder_registry[key] = 1
- return key + "_" + str(self._place_holder_registry[key])
+ place_holder = f"{key}_{self._place_holder_registry[key]}"
+ if self._subquery_namespace:
+ place_holder = f"{self._subquery_namespace}_{place_holder}"
+ return place_holder
def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]:
is_rel_filter = "|" in prop
@@ -884,10 +887,21 @@ def build_query(self) -> str:
query += ",".join(ordering)
if hasattr(self.node_set, "_subqueries"):
- for subquery, return_set in self.node_set._subqueries:
- outer_primary_var = self._ast.return_clause
- query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
- for varname in return_set:
+ for subquery in self.node_set._subqueries:
+ query += " CALL {"
+ if subquery["initial_context"]:
+ query += " WITH "
+ context: List[str] = []
+ for var in subquery["initial_context"]:
+ if isinstance(var, (NodeNameResolver, RelationNameResolver)):
+ context.append(var.resolve(self))
+ else:
+ context.append(var)
+ query += ",".join(context)
+
+ query += f"{subquery['query']} }} "
+ self._query_params.update(subquery["query_params"])
+ for varname in subquery["return_set"]:
# We declare the returned variables as "virtual" relations of the
# root node class to make sure they will be translated by a call to
# resolve_subgraph() (otherwise, they will be lost).
@@ -898,10 +912,10 @@ def build_query(self) -> str:
"variable_name": varname,
"rel_variable_name": varname,
}
- returned_items += return_set
+ returned_items += subquery["return_set"]
query += " RETURN "
- if self._ast.return_clause and not self._subquery_context:
+ if self._ast.return_clause and not self._subquery_namespace:
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
returned_items += self._ast.additional_return
@@ -1126,6 +1140,8 @@ class NodeNameResolver:
node: str
def resolve(self, qbuilder: QueryBuilder) -> str:
+ if self.node == "self" and qbuilder._ast.return_clause:
+ return qbuilder._ast.return_clause
result = qbuilder.lookup_query_variable(self.node)
if result is None:
raise ValueError(f"Unable to resolve variable name for node {self.node}")
@@ -1244,7 +1260,7 @@ def __init__(self, source) -> None:
self.relations_to_fetch: list = []
self._extra_results: list = []
- self._subqueries: list[Tuple[str, list[str]]] = []
+ self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []
def __await__(self):
@@ -1531,7 +1547,12 @@ def resolve_subgraph(self) -> list:
)
return results
- def subquery(self, nodeset: "NodeSet", return_set: list[str]) -> "NodeSet":
+ def subquery(
+ self,
+ nodeset: "NodeSet",
+ return_set: list[str],
+ initial_context: TOptional[list[str]] = None,
+ ) -> "NodeSet":
"""Add a subquery to this node set.
A subquery is a regular cypher query but executed within the context of a CALL
@@ -1539,7 +1560,8 @@ def subquery(self, nodeset: "NodeSet", return_set: list[str]) -> "NodeSet":
declared inside return_set variable in order to be included in the final RETURN
statement.
"""
- qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast()
+ namespace = f"sq{len(self._subqueries) + 1}"
+ qbuilder = nodeset.query_cls(nodeset, subquery_namespace=namespace).build_ast()
for var in return_set:
if (
var != qbuilder._ast.return_clause
@@ -1549,9 +1571,31 @@ def subquery(self, nodeset: "NodeSet", return_set: list[str]) -> "NodeSet":
)
and var
not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
+ and var
+ not in [
+ varname
+ for tr in nodeset._intermediate_transforms
+ for varname, vardef in tr["vars"].items()
+ if vardef.get("include_in_return")
+ ]
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
- self._subqueries.append((qbuilder.build_query(), return_set))
+ if initial_context:
+ for var in initial_context:
+ if type(var) is not str and not isinstance(
+ var, (NodeNameResolver, RelationNameResolver, RawCypher)
+ ):
+ raise ValueError(
+ f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver"
+ )
+ self._subqueries.append(
+ {
+ "query": qbuilder.build_query(),
+ "query_params": qbuilder._query_params,
+ "return_set": return_set,
+ "initial_context": initial_context,
+ }
+ )
return self
def intermediate_transform(
diff --git a/neomodel/typing.py b/neomodel/typing.py
index 9438bd54..a23f88eb 100644
--- a/neomodel/typing.py
+++ b/neomodel/typing.py
@@ -10,3 +10,14 @@
"include_in_return": Optional[bool],
},
)
+
+
+Subquery = TypedDict(
+ "Subquery",
+ {
+ "query": str,
+ "query_params": dict,
+ "return_set": list[str],
+ "initial_context": Optional[list[Any]],
+ },
+)
diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py
index 2dff91c0..70c7f351 100644
--- a/test/async_/test_match_api.py
+++ b/test/async_/test_match_api.py
@@ -2,6 +2,7 @@
from datetime import datetime
from test._async_compat import mark_async_test
+import numpy as np
from pytest import raises, skip, warns
from neomodel import (
@@ -880,15 +881,16 @@ async def test_subquery():
await nescafe.suppliers.connect(supplier2)
await nescafe.species.connect(arabica)
- result = await Coffee.nodes.subquery(
+ subquery = await Coffee.nodes.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers")
.intermediate_transform(
{"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"]
)
.annotate(supps=Last(Collect("suppliers"))),
["supps"],
+ [NodeNameResolver("self")],
)
- result = await result.all()
+ result = await subquery.all()
assert len(result) == 1
assert len(result[0]) == 2
assert result[0][0] == supplier2
@@ -904,6 +906,58 @@ async def test_subquery():
["unknown"],
)
+ result_string_context = await subquery.subquery(
+ Coffee.nodes.traverse_relations(supps2="suppliers").annotate(
+ supps2=Collect("supps")
+ ),
+ ["supps2"],
+ ["supps"],
+ )
+ result_string_context = await result_string_context.all()
+ assert len(result) == 1
+ additional_elements = [
+ item for item in result_string_context[0] if item not in result[0]
+ ]
+ assert len(additional_elements) == 1
+ assert isinstance(additional_elements[0], list)
+
+ with raises(ValueError, match=r"Wrong variable specified in initial context"):
+ result = await Coffee.nodes.subquery(
+ Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
+ supps=Collect("suppliers")
+ ),
+ ["supps"],
+ [2],
+ )
+
+
+@mark_async_test
+async def test_subquery_other_node():
+ arabica = await Species(name="Arabica").save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ await nescafe.suppliers.connect(supplier1)
+ await nescafe.suppliers.connect(supplier2)
+ await nescafe.species.connect(arabica)
+
+ result = await Coffee.nodes.subquery(
+ Supplier.nodes.filter(name="Supplier 2").intermediate_transform(
+ {
+ "cost": {
+ "source": "supplier",
+ "source_prop": "delivery_cost",
+ "include_in_return": True,
+ }
+ }
+ ),
+ ["cost"],
+ )
+ result = await result.all()
+ assert len(result) == 1
+ assert result[0][0] == 20
+
@mark_async_test
async def test_intermediate_transform():
diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py
index 4df51866..94465db2 100644
--- a/test/sync_/test_match_api.py
+++ b/test/sync_/test_match_api.py
@@ -2,6 +2,7 @@
from datetime import datetime
from test._async_compat import mark_sync_test
+import numpy as np
from pytest import raises, skip, warns
from neomodel import (
@@ -864,15 +865,16 @@ def test_subquery():
nescafe.suppliers.connect(supplier2)
nescafe.species.connect(arabica)
- result = Coffee.nodes.subquery(
+ subquery = Coffee.nodes.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers")
.intermediate_transform(
{"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"]
)
.annotate(supps=Last(Collect("suppliers"))),
["supps"],
+ [NodeNameResolver("self")],
)
- result = result.all()
+ result = subquery.all()
assert len(result) == 1
assert len(result[0]) == 2
assert result[0][0] == supplier2
@@ -888,6 +890,58 @@ def test_subquery():
["unknown"],
)
+ result_string_context = subquery.subquery(
+ Coffee.nodes.traverse_relations(supps2="suppliers").annotate(
+ supps2=Collect("supps")
+ ),
+ ["supps2"],
+ ["supps"],
+ )
+ result_string_context = result_string_context.all()
+ assert len(result) == 1
+ additional_elements = [
+ item for item in result_string_context[0] if item not in result[0]
+ ]
+ assert len(additional_elements) == 1
+ assert isinstance(additional_elements[0], list)
+
+ with raises(ValueError, match=r"Wrong variable specified in initial context"):
+ result = Coffee.nodes.subquery(
+ Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
+ supps=Collect("suppliers")
+ ),
+ ["supps"],
+ [2],
+ )
+
+
+@mark_sync_test
+def test_subquery_other_node():
+ arabica = Species(name="Arabica").save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ nescafe.suppliers.connect(supplier1)
+ nescafe.suppliers.connect(supplier2)
+ nescafe.species.connect(arabica)
+
+ result = Coffee.nodes.subquery(
+ Supplier.nodes.filter(name="Supplier 2").intermediate_transform(
+ {
+ "cost": {
+ "source": "supplier",
+ "source_prop": "delivery_cost",
+ "include_in_return": True,
+ }
+ }
+ ),
+ ["cost"],
+ )
+ result = result.all()
+ assert len(result) == 1
+ assert result[0][0] == 20
+
@mark_sync_test
def test_intermediate_transform():