From 68ee55313726a53d9c6986ebff9020cf93f7cbc7 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 19 Nov 2024 12:54:37 +0100 Subject: [PATCH 01/10] Prep rc branch --- doc/source/configuration.rst | 2 +- neomodel/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index 511318c4..7c178c29 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -32,7 +32,7 @@ Adjust driver configuration - these options are only available for this connecti config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default config.RESOLVER = None # default config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default - config.USER_AGENT = neomodel/v5.4.0 # default + config.USER_AGENT = neomodel/v5.4.1 # default Setting the database name, if different from the default one:: diff --git a/neomodel/_version.py b/neomodel/_version.py index fc30498f..1e41bf8f 100644 --- a/neomodel/_version.py +++ b/neomodel/_version.py @@ -1 +1 @@ -__version__ = "5.4.0" +__version__ = "5.4.1" From 7a5cbddd32e4a5a65929ce04f09c6364d60df442 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Fri, 22 Nov 2024 10:15:45 +0100 Subject: [PATCH 02/10] Improved intermediate_transform method. 
New syntax to allow more complex statements when injecting intermediate transformation: * Use of DISTINCT clause * Indicate item property instead of complete item * Indicate if transformed variable should be returned by query or not --- doc/source/advanced_query_operations.rst | 6 +- neomodel/async_/match.py | 70 ++++++++++++++---------- neomodel/sync_/match.py | 70 ++++++++++++++---------- neomodel/typing.py | 13 +++++ test/async_/test_match_api.py | 12 ++-- test/sync_/test_match_api.py | 12 ++-- 6 files changed, 108 insertions(+), 75 deletions(-) create mode 100644 neomodel/typing.py diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index e602d479..a1d3aa36 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -54,7 +54,7 @@ As discussed in the note above, this is for example useful when you need to orde # This will return all Coffee nodes, with their most expensive supplier Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( - {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"] + {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] ) .annotate(supps=Last(Collect("suppliers"))) @@ -71,7 +71,7 @@ The `subquery` method allows you to perform a `Cypher subquery str: query += " WITH " query += self._ast.with_clause + returned_items: list[str] = [] if hasattr(self.node_set, "_intermediate_transforms"): for transform in self.node_set._intermediate_transforms: query += " WITH " @@ -845,25 +847,19 @@ def build_query(self) -> str: # Reset return list since we'll probably invalidate most variables self._ast.return_clause = "" self._ast.additional_return = [] - for name, source in transform["vars"].items(): - if type(source) is str: - injected_vars.append(f"{source} AS {name}") - elif isinstance(source, RelationNameResolver): - result = self.lookup_query_variable( - source.relation, return_relation=True - ) - if not 
result: - raise ValueError( - f"Unable to resolve variable name for relation {source.relation}." - ) - injected_vars.append(f"{result[0]} AS {name}") - elif isinstance(source, NodeNameResolver): - result = self.lookup_query_variable(source.node) - if not result: - raise ValueError( - f"Unable to resolve variable name for node {source.node}." - ) - injected_vars.append(f"{result[0]} AS {name}") + for name, varprops in transform["vars"].items(): + source = varprops["source"] + transformation = "DISTINCT " if varprops.get("distinct") else "" + if isinstance(source, (NodeNameResolver, RelationNameResolver)): + transformation += source.resolve(self) + else: + transformation += source + if varprops.get("source_prop"): + transformation += f".{varprops['source_prop']}" + transformation += f" AS {name}" + if varprops.get("include_in_return"): + returned_items += [name] + injected_vars.append(transformation) query += ",".join(injected_vars) if not transform["ordering"]: continue @@ -879,7 +875,6 @@ def build_query(self) -> str: ordering.append(item) query += ",".join(ordering) - returned_items: list[str] = [] if hasattr(self.node_set, "_subqueries"): for subquery, return_set in self.node_set._subqueries: outer_primary_var = self._ast.return_clause @@ -1098,6 +1093,14 @@ class RelationNameResolver: relation: str + def resolve(self, qbuilder: AsyncQueryBuilder) -> str: + result = qbuilder.lookup_query_variable(self.relation, True) + if result is None: + raise ValueError( + f"Unable to resolve variable name for relation {self.relation}" + ) + return result[0] + @dataclass class NodeNameResolver: @@ -1111,6 +1114,12 @@ class NodeNameResolver: node: str + def resolve(self, qbuilder: AsyncQueryBuilder) -> str: + result = qbuilder.lookup_query_variable(self.node) + if result is None: + raise ValueError(f"Unable to resolve variable name for node {self.node}") + return result[0] + @dataclass class BaseFunction: @@ -1123,15 +1132,15 @@ def get_internal_name(self) -> str: return 
self._internal_name def resolve_internal_name(self, qbuilder: AsyncQueryBuilder) -> str: - if isinstance(self.input_name, NodeNameResolver): - result = qbuilder.lookup_query_variable(self.input_name.node) - elif isinstance(self.input_name, RelationNameResolver): - result = qbuilder.lookup_query_variable(self.input_name.relation, True) + if isinstance(self.input_name, (NodeNameResolver, RelationNameResolver)): + self._internal_name = self.input_name.resolve(qbuilder) else: result = (str(self.input_name), None) - if result is None: - raise ValueError(f"Unknown variable {self.input_name} used in Collect()") - self._internal_name = result[0] + if result is None: + raise ValueError( + f"Unknown variable {self.input_name} used in Collect()" + ) + self._internal_name = result[0] return self._internal_name def render(self, qbuilder: AsyncQueryBuilder) -> str: @@ -1538,15 +1547,16 @@ async def subquery( return self def intermediate_transform( - self, vars: Dict[str, Any], ordering: TOptional[list] = None + self, vars: Dict[str, Transformation], ordering: TOptional[list] = None ) -> "AsyncNodeSet": if not vars: raise ValueError( "You must provide one variable at least when calling intermediate_transform()" ) - for name, source in vars.items(): + for name, props in vars.items(): + source = props["source"] if type(source) is not str and not isinstance( - source, (NodeNameResolver, RelationNameResolver) + source, (NodeNameResolver, RelationNameResolver, RawCypher) ): raise ValueError( f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver" diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 966b2601..8acc4c3d 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -12,6 +12,7 @@ from neomodel.sync_ import relationship_manager from neomodel.sync_.core import StructuredNode, db from neomodel.sync_.relationship import StructuredRel +from neomodel.typing import Transformation 
from neomodel.util import INCOMING, OUTGOING CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)") @@ -840,6 +841,7 @@ def build_query(self) -> str: query += " WITH " query += self._ast.with_clause + returned_items: list[str] = [] if hasattr(self.node_set, "_intermediate_transforms"): for transform in self.node_set._intermediate_transforms: query += " WITH " @@ -847,25 +849,19 @@ def build_query(self) -> str: # Reset return list since we'll probably invalidate most variables self._ast.return_clause = "" self._ast.additional_return = [] - for name, source in transform["vars"].items(): - if type(source) is str: - injected_vars.append(f"{source} AS {name}") - elif isinstance(source, RelationNameResolver): - result = self.lookup_query_variable( - source.relation, return_relation=True - ) - if not result: - raise ValueError( - f"Unable to resolve variable name for relation {source.relation}." - ) - injected_vars.append(f"{result[0]} AS {name}") - elif isinstance(source, NodeNameResolver): - result = self.lookup_query_variable(source.node) - if not result: - raise ValueError( - f"Unable to resolve variable name for node {source.node}." 
- ) - injected_vars.append(f"{result[0]} AS {name}") + for name, varprops in transform["vars"].items(): + source = varprops["source"] + transformation = "DISTINCT " if varprops.get("distinct") else "" + if isinstance(source, (NodeNameResolver, RelationNameResolver)): + transformation += source.resolve(self) + else: + transformation += source + if varprops.get("source_prop"): + transformation += f".{varprops['source_prop']}" + transformation += f" AS {name}" + if varprops.get("include_in_return"): + returned_items += [name] + injected_vars.append(transformation) query += ",".join(injected_vars) if not transform["ordering"]: continue @@ -881,7 +877,6 @@ def build_query(self) -> str: ordering.append(item) query += ",".join(ordering) - returned_items: list[str] = [] if hasattr(self.node_set, "_subqueries"): for subquery, return_set in self.node_set._subqueries: outer_primary_var = self._ast.return_clause @@ -1098,6 +1093,14 @@ class RelationNameResolver: relation: str + def resolve(self, qbuilder: QueryBuilder) -> str: + result = qbuilder.lookup_query_variable(self.relation, True) + if result is None: + raise ValueError( + f"Unable to resolve variable name for relation {self.relation}" + ) + return result[0] + @dataclass class NodeNameResolver: @@ -1111,6 +1114,12 @@ class NodeNameResolver: node: str + def resolve(self, qbuilder: QueryBuilder) -> str: + result = qbuilder.lookup_query_variable(self.node) + if result is None: + raise ValueError(f"Unable to resolve variable name for node {self.node}") + return result[0] + @dataclass class BaseFunction: @@ -1123,15 +1132,15 @@ def get_internal_name(self) -> str: return self._internal_name def resolve_internal_name(self, qbuilder: QueryBuilder) -> str: - if isinstance(self.input_name, NodeNameResolver): - result = qbuilder.lookup_query_variable(self.input_name.node) - elif isinstance(self.input_name, RelationNameResolver): - result = qbuilder.lookup_query_variable(self.input_name.relation, True) + if 
isinstance(self.input_name, (NodeNameResolver, RelationNameResolver)): + self._internal_name = self.input_name.resolve(qbuilder) else: result = (str(self.input_name), None) - if result is None: - raise ValueError(f"Unknown variable {self.input_name} used in Collect()") - self._internal_name = result[0] + if result is None: + raise ValueError( + f"Unknown variable {self.input_name} used in Collect()" + ) + self._internal_name = result[0] return self._internal_name def render(self, qbuilder: QueryBuilder) -> str: @@ -1536,15 +1545,16 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": return self def intermediate_transform( - self, vars: Dict[str, Any], ordering: TOptional[list] = None + self, vars: Dict[str, Transformation], ordering: TOptional[list] = None ) -> "NodeSet": if not vars: raise ValueError( "You must provide one variable at least when calling intermediate_transform()" ) - for name, source in vars.items(): + for name, props in vars.items(): + source = props["source"] if type(source) is not str and not isinstance( - source, (NodeNameResolver, RelationNameResolver) + source, (NodeNameResolver, RelationNameResolver, RawCypher) ): raise ValueError( f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver" diff --git a/neomodel/typing.py b/neomodel/typing.py new file mode 100644 index 00000000..2a6afb05 --- /dev/null +++ b/neomodel/typing.py @@ -0,0 +1,13 @@ +"""Custom types used for annotations.""" + +from typing import Any, Optional, TypedDict + +Transformation = TypedDict( + "Transformation", + { + "source": Any, + "source_prop": Optional[str], + "distinct": Optional[bool], + "include_in_return": Optional[bool], + }, +) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 39e96957..2b9c86bf 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -879,7 +879,7 @@ async def test_subquery(): result = await 
Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( - {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"] + {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] ) .annotate(supps=Last(Collect("suppliers"))), ["supps"], @@ -916,9 +916,9 @@ async def test_intermediate_transform(): await Coffee.nodes.fetch_relations("suppliers") .intermediate_transform( { - "coffee": "coffee", - "suppliers": NodeNameResolver("suppliers"), - "r": RelationNameResolver("suppliers"), + "coffee": {"source": "coffee"}, + "suppliers": {"source": NodeNameResolver("suppliers")}, + "r": {"source": RelationNameResolver("suppliers")}, }, ordering=["-r.since"], ) @@ -937,7 +937,7 @@ async def test_intermediate_transform(): ): Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( { - "test": Collect("suppliers"), + "test": {"source": Collect("suppliers")}, } ) with raises( @@ -1008,7 +1008,7 @@ async def test_mix_functions(): .subquery( Student.nodes.fetch_relations("courses") .intermediate_transform( - {"rel": RelationNameResolver("courses")}, + {"rel": {"source": RelationNameResolver("courses")}}, ordering=[ RawCypher("toInteger(split(rel.level, '.')[0])"), RawCypher("toInteger(split(rel.level, '.')[1])"), diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 78909860..61021990 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -863,7 +863,7 @@ def test_subquery(): result = Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( - {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"] + {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] ) .annotate(supps=Last(Collect("suppliers"))), ["supps"], @@ -900,9 +900,9 @@ def test_intermediate_transform(): Coffee.nodes.fetch_relations("suppliers") .intermediate_transform( { - "coffee": "coffee", - "suppliers": 
NodeNameResolver("suppliers"), - "r": RelationNameResolver("suppliers"), + "coffee": {"source": "coffee"}, + "suppliers": {"source": NodeNameResolver("suppliers")}, + "r": {"source": RelationNameResolver("suppliers")}, }, ordering=["-r.since"], ) @@ -921,7 +921,7 @@ def test_intermediate_transform(): ): Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( { - "test": Collect("suppliers"), + "test": {"source": Collect("suppliers")}, } ) with raises( @@ -992,7 +992,7 @@ def test_mix_functions(): .subquery( Student.nodes.fetch_relations("courses") .intermediate_transform( - {"rel": RelationNameResolver("courses")}, + {"rel": {"source": RelationNameResolver("courses")}}, ordering=[ RawCypher("toInteger(split(rel.level, '.')[0])"), RawCypher("toInteger(split(rel.level, '.')[1])"), From ab8f4ba47198dd8b92fdbc6b47258f1049c0b82b Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Fri, 22 Nov 2024 10:27:30 +0100 Subject: [PATCH 03/10] Removed useless code --- neomodel/async_/match.py | 7 +------ neomodel/sync_/match.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 6cd39917..acdc76ff 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1135,12 +1135,7 @@ def resolve_internal_name(self, qbuilder: AsyncQueryBuilder) -> str: if isinstance(self.input_name, (NodeNameResolver, RelationNameResolver)): self._internal_name = self.input_name.resolve(qbuilder) else: - result = (str(self.input_name), None) - if result is None: - raise ValueError( - f"Unknown variable {self.input_name} used in Collect()" - ) - self._internal_name = result[0] + self._internal_name = str(self.input_name) return self._internal_name def render(self, qbuilder: AsyncQueryBuilder) -> str: diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 8acc4c3d..d59000ca 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1135,12 +1135,7 @@ def 
resolve_internal_name(self, qbuilder: QueryBuilder) -> str: if isinstance(self.input_name, (NodeNameResolver, RelationNameResolver)): self._internal_name = self.input_name.resolve(qbuilder) else: - result = (str(self.input_name), None) - if result is None: - raise ValueError( - f"Unknown variable {self.input_name} used in Collect()" - ) - self._internal_name = result[0] + self._internal_name = str(self.input_name) return self._internal_name def render(self, qbuilder: QueryBuilder) -> str: From d1720df2709b86410acd2a150926cb64993b89b3 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Fri, 22 Nov 2024 11:32:10 +0100 Subject: [PATCH 04/10] DISTINCT must be at first place when used within a WITH call --- neomodel/async_/match.py | 15 ++++++++++----- neomodel/sync_/match.py | 15 ++++++++++----- neomodel/typing.py | 1 - test/async_/test_match_api.py | 10 ++++++++-- test/sync_/test_match_api.py | 10 ++++++++-- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index acdc76ff..a3718d46 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -843,17 +843,17 @@ def build_query(self) -> str: if hasattr(self.node_set, "_intermediate_transforms"): for transform in self.node_set._intermediate_transforms: query += " WITH " + query += "DISTINCT " if transform.get("distinct") else "" injected_vars: list = [] # Reset return list since we'll probably invalidate most variables self._ast.return_clause = "" self._ast.additional_return = [] for name, varprops in transform["vars"].items(): source = varprops["source"] - transformation = "DISTINCT " if varprops.get("distinct") else "" if isinstance(source, (NodeNameResolver, RelationNameResolver)): - transformation += source.resolve(self) + transformation = source.resolve(self) else: - transformation += source + transformation = source if varprops.get("source_prop"): transformation += f".{varprops['source_prop']}" transformation += f" AS {name}" @@ -1542,7 
+1542,10 @@ async def subquery( return self def intermediate_transform( - self, vars: Dict[str, Transformation], ordering: TOptional[list] = None + self, + vars: Dict[str, Transformation], + distinct: bool = False, + ordering: TOptional[list] = None, ) -> "AsyncNodeSet": if not vars: raise ValueError( @@ -1556,7 +1559,9 @@ def intermediate_transform( raise ValueError( f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver" ) - self._intermediate_transforms.append({"vars": vars, "ordering": ordering}) + self._intermediate_transforms.append( + {"vars": vars, "distinct": distinct, "ordering": ordering} + ) return self diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index d59000ca..cd9a7f43 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -845,17 +845,17 @@ def build_query(self) -> str: if hasattr(self.node_set, "_intermediate_transforms"): for transform in self.node_set._intermediate_transforms: query += " WITH " + query += "DISTINCT " if transform.get("distinct") else "" injected_vars: list = [] # Reset return list since we'll probably invalidate most variables self._ast.return_clause = "" self._ast.additional_return = [] for name, varprops in transform["vars"].items(): source = varprops["source"] - transformation = "DISTINCT " if varprops.get("distinct") else "" if isinstance(source, (NodeNameResolver, RelationNameResolver)): - transformation += source.resolve(self) + transformation = source.resolve(self) else: - transformation += source + transformation = source if varprops.get("source_prop"): transformation += f".{varprops['source_prop']}" transformation += f" AS {name}" @@ -1540,7 +1540,10 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": return self def intermediate_transform( - self, vars: Dict[str, Transformation], ordering: TOptional[list] = None + self, + vars: Dict[str, Transformation], + distinct: bool = False, + 
ordering: TOptional[list] = None, ) -> "NodeSet": if not vars: raise ValueError( @@ -1554,7 +1557,9 @@ def intermediate_transform( raise ValueError( f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver" ) - self._intermediate_transforms.append({"vars": vars, "ordering": ordering}) + self._intermediate_transforms.append( + {"vars": vars, "distinct": distinct, "ordering": ordering} + ) return self diff --git a/neomodel/typing.py b/neomodel/typing.py index 2a6afb05..9438bd54 100644 --- a/neomodel/typing.py +++ b/neomodel/typing.py @@ -7,7 +7,6 @@ { "source": Any, "source_prop": Optional[str], - "distinct": Optional[bool], "include_in_return": Optional[bool], }, ) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 2b9c86bf..c83d826f 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -916,10 +916,15 @@ async def test_intermediate_transform(): await Coffee.nodes.fetch_relations("suppliers") .intermediate_transform( { - "coffee": {"source": "coffee"}, + "coffee": {"source": "coffee", "include_in_return": True}, "suppliers": {"source": NodeNameResolver("suppliers")}, "r": {"source": RelationNameResolver("suppliers")}, + "cost": { + "source": NodeNameResolver("suppliers"), + "source_prop": "delivery_cost", + }, }, + distinct=True, ordering=["-r.since"], ) .annotate(oldest_supplier=Last(Collect("suppliers"))) @@ -927,7 +932,8 @@ async def test_intermediate_transform(): ) assert len(result) == 1 - assert result[0] == supplier2 + assert result[0][0] == nescafe + assert result[0][1] == supplier2 with raises( ValueError, diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 61021990..4a5684ea 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -900,10 +900,15 @@ def test_intermediate_transform(): Coffee.nodes.fetch_relations("suppliers") .intermediate_transform( { - "coffee": {"source":
"coffee"}, + "coffee": {"source": "coffee", "include_in_return": True}, "suppliers": {"source": NodeNameResolver("suppliers")}, "r": {"source": RelationNameResolver("suppliers")}, + "cost": { + "source": NodeNameResolver("suppliers"), + "source_prop": "delivery_cost", + }, }, + distinct=True, ordering=["-r.since"], ) .annotate(oldest_supplier=Last(Collect("suppliers"))) @@ -911,7 +916,8 @@ def test_intermediate_transform(): ) assert len(result) == 1 - assert result[0] == supplier2 + assert result[0][0] == nescafe + assert result[0][1] == supplier2 with raises( ValueError, From a13697f675beae140c25f649e89bbc50d0cdeb70 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 22 Nov 2024 14:35:25 +0100 Subject: [PATCH 05/10] Add doc and changelog --- Changelog | 3 +++ doc/source/advanced_query_operations.rst | 31 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/Changelog b/Changelog index 079cab72..9d62721d 100644 --- a/Changelog +++ b/Changelog @@ -1,3 +1,6 @@ +Version 5.4.1 2024-11 +* Add options for intermediate_transform : distinct, include_in_return, use a prop as source + Version 5.4.0 2024-11 * Traversal option for filtering and ordering * Insert raw Cypher for ordering diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index a1d3aa36..73c5bbd6 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -58,6 +58,37 @@ As discussed in the note above, this is for example useful when you need to orde ) .annotate(supps=Last(Collect("suppliers"))) +Options for `intermediate_transform` *variables* are: + +- `source`: `string` or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below). +- `source_prop`: `string` - optionally, a property of the source variable to use as source for the transformation. +- `include_in_return`: `bool` - whether to include the variable in the return statement. Defaults to False.
+ +Additional options for the `intermediate_transform` method are: +- `distinct`: `bool` - whether to deduplicate the results. Defaults to False. + +Here is a full example:: + + await Coffee.nodes.fetch_relations("suppliers") + .intermediate_transform( + { + "coffee": "coffee", + "suppliers": NodeNameResolver("suppliers"), + "r": RelationNameResolver("suppliers"), + "coffee": {"source": "coffee", "include_in_return": True}, # Only coffee will be returned + "suppliers": {"source": NodeNameResolver("suppliers")}, + "r": {"source": RelationNameResolver("suppliers")}, + "cost": { + "source": NodeNameResolver("suppliers"), + "source_prop": "delivery_cost", + }, + }, + distinct=True, + ordering=["-r.since"], + ) + .annotate(oldest_supplier=Last(Collect("suppliers"))) + .all() + Subqueries ---------- From 1355fa10466fde4e1d3d30b0fed015065bd78f73 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Nov 2024 10:12:16 +0100 Subject: [PATCH 06/10] Add support for parallel runtime --- neomodel/async_/core.py | 12 ++++++++++++ neomodel/async_/match.py | 36 +++++++++++++++++++++++++++++++---- neomodel/sync_/core.py | 9 +++++++++ neomodel/sync_/match.py | 28 ++++++++++++++++++++++++++- test/async_/test_match_api.py | 33 +++++++++++++++++++++++++++++++- test/sync_/test_match_api.py | 30 ++++++++++++++++++++++++++++- 6 files changed, 141 insertions(+), 7 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 5773da12..c569eadf 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -598,6 +598,18 @@ async def edition_is_enterprise(self) -> bool: edition = await self.database_edition return edition == "enterprise" + @ensure_connection + async def parallel_runtime_available(self) -> bool: + """Returns true if the database supports parallel runtime + + Returns: + bool: True if the database supports parallel runtime + """ + return ( + await self.version_is_higher_than("5.13") + and await self.edition_is_enterprise() + ) + async def 
change_neo4j_password(self, user, new_password): await self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index a3718d46..82eaf67b 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1,6 +1,7 @@ import inspect import re import string +import warnings from dataclasses import dataclass from typing import Any, Dict, List from typing import Optional as TOptional @@ -396,6 +397,7 @@ def __init__( lookup: TOptional[str] = None, additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, + use_parallel_runtime: TOptional[bool] = False, ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] @@ -409,6 +411,7 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count + self.use_parallel_runtime = use_parallel_runtime self.subgraph: Dict = {} @@ -432,6 +435,19 @@ async def build_ast(self) -> "AsyncQueryBuilder": self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): self._ast.limit = self.node_set.limit + if hasattr(self.node_set, "use_parallel_runtime"): + if ( + self.node_set.use_parallel_runtime + and not await adb.parallel_runtime_available() + ): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" + "Reverting to default runtime.", + UserWarning, + ) + self.node_set.use_parallel_runtime = False + else: + self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime return self @@ -589,9 +605,11 @@ def build_traversal_from_path( } else: existing_rhs_name = subgraph[part][ - "rel_variable_name" - if relation.get("relation_filtering") - else "variable_name" + ( + "rel_variable_name" + if relation.get("relation_filtering") + else "variable_name" + ) ] if relation["include_in_return"] and not already_present: self._additional_return(rel_ident) @@ -812,6 +830,8 @@ def lookup_query_variable( def build_query(self) -> str: query: str = "" + if self._ast.use_parallel_runtime: + query += "CYPHER runtime=parallel " if self._ast.lookup: query += self._ast.lookup @@ -973,7 +993,9 @@ async def _execute(self, lazy: bool = False, dict_output: bool = False): ] query = self.build_query() results, prop_names = await adb.cypher_query( - query, self._query_params, resolve_objects=True + query, + self._query_params, + resolve_objects=True, ) if dict_output: for item in results: @@ -1236,6 +1258,8 @@ def __init__(self, source) -> None: self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] + self.use_parallel_runtime = False + def __await__(self): return self.all().__await__() @@ -1564,6 +1588,10 @@ def intermediate_transform( ) return self + def parallel_runtime(self) -> "AsyncNodeSet": + self.use_parallel_runtime = True + return self + class AsyncTraversal(AsyncBaseSet): """ diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 6c72908a..2175fa02 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -596,6 +596,15 @@ def edition_is_enterprise(self) -> bool: edition = self.database_edition return edition == "enterprise" + @ensure_connection + def parallel_runtime_available(self) -> bool: + """Returns true if the database supports parallel runtime + + Returns: + bool: True if the database supports parallel 
runtime + """ + return self.version_is_higher_than("5.13") and self.edition_is_enterprise() + def change_neo4j_password(self, user, new_password): self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index cd9a7f43..c2c9539d 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1,6 +1,7 @@ import inspect import re import string +import warnings from dataclasses import dataclass from typing import Any, Dict, List from typing import Optional as TOptional @@ -396,6 +397,7 @@ def __init__( lookup: TOptional[str] = None, additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, + use_parallel_runtime: TOptional[bool] = False, ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] @@ -409,6 +411,7 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count + self.use_parallel_runtime = use_parallel_runtime self.subgraph: Dict = {} @@ -432,6 +435,19 @@ def build_ast(self) -> "QueryBuilder": self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): self._ast.limit = self.node_set.limit + if hasattr(self.node_set, "use_parallel_runtime"): + if ( + self.node_set.use_parallel_runtime + and not db.parallel_runtime_available() + ): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" + "Reverting to default runtime.", + UserWarning, + ) + self.node_set.use_parallel_runtime = False + else: + self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime return self @@ -814,6 +830,8 @@ def lookup_query_variable( def build_query(self) -> str: query: str = "" + if self._ast.use_parallel_runtime: + query += "CYPHER runtime=parallel " if self._ast.lookup: query += self._ast.lookup @@ -973,7 +991,9 @@ def _execute(self, lazy: bool = False, dict_output: bool = False): ] query = self.build_query() results, prop_names = db.cypher_query( - query, self._query_params, resolve_objects=True + query, + self._query_params, + resolve_objects=True, ) if dict_output: for item in results: @@ -1236,6 +1256,8 @@ def __init__(self, source) -> None: self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] + self.use_parallel_runtime = False + def __await__(self): return self.all().__await__() @@ -1562,6 +1584,10 @@ def intermediate_transform( ) return self + def parallel_runtime(self) -> "NodeSet": + self.use_parallel_runtime = True + return self + class Traversal(BaseSet): """ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index c83d826f..7df6f7d7 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1,8 +1,9 @@ import re +import warnings from datetime import datetime from test._async_compat import mark_async_test -from pytest import raises +from pytest import raises, warns from neomodel import ( INCOMING, @@ -1113,3 +1114,33 @@ async def test_async_iterator(): # assert that generator runs loop above assert counter == n + + +@mark_async_test +async def test_parallel_runtime(): + await Coffee(name="Java", price=99).save() + + node_set = AsyncNodeSet(Coffee).parallel_runtime() + + assert node_set.use_parallel_runtime + + if ( + not await adb.version_is_higher_than("5.13") + or not await adb.edition_is_enterprise() + ): + assert not await adb.parallel_runtime_available() + 
with warns( + UserWarning, + match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", + ): + qb = await AsyncQueryBuilder(node_set).build_ast() + assert not qb._ast.use_parallel_runtime + assert not qb.build_query().startswith("CYPHER runtime=parallel") + else: + assert await adb.parallel_runtime_available() + qb = await AsyncQueryBuilder(node_set).build_ast() + assert qb._ast.use_parallel_runtime + assert qb.build_query().startswith("CYPHER runtime=parallel") + + results = [node async for node in qb._execute()] + assert len(results) == 1 diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 4a5684ea..2b148601 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1,8 +1,9 @@ import re +import warnings from datetime import datetime from test._async_compat import mark_sync_test -from pytest import raises +from pytest import raises, warns from neomodel import ( INCOMING, @@ -1097,3 +1098,30 @@ def test_async_iterator(): # assert that generator runs loop above assert counter == n + + +@mark_sync_test +def test_parallel_runtime(): + Coffee(name="Java", price=99).save() + + node_set = NodeSet(Coffee).parallel_runtime() + + assert node_set.use_parallel_runtime + + if not db.version_is_higher_than("5.13") or not db.edition_is_enterprise(): + assert not db.parallel_runtime_available() + with warns( + UserWarning, + match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", + ): + qb = QueryBuilder(node_set).build_ast() + assert not qb._ast.use_parallel_runtime + assert not qb.build_query().startswith("CYPHER runtime=parallel") + else: + assert db.parallel_runtime_available() + qb = QueryBuilder(node_set).build_ast() + assert qb._ast.use_parallel_runtime + assert qb.build_query().startswith("CYPHER runtime=parallel") + + results = [node for node in qb._execute()] + assert len(results) == 1 From 990815afdabdf80a5d63c11502433325a7c46f2b Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: 
Mon, 25 Nov 2024 15:43:37 +0100 Subject: [PATCH 07/10] Use context manager instead. Add doc --- doc/source/transactions.rst | 24 +++++++++-- neomodel/async_/core.py | 23 ++++++++++- neomodel/async_/match.py | 23 ----------- neomodel/sync_/core.py | 23 ++++++++++- neomodel/sync_/match.py | 23 ----------- pyproject.toml | 1 + requirements-dev.txt | 1 + test/async_/test_match_api.py | 76 ++++++++++++++++++++++++----------- test/sync_/test_match_api.py | 76 ++++++++++++++++++++++++----------- 9 files changed, 169 insertions(+), 101 deletions(-) diff --git a/doc/source/transactions.rst b/doc/source/transactions.rst index dfa97ee6..92f5b37e 100644 --- a/doc/source/transactions.rst +++ b/doc/source/transactions.rst @@ -51,7 +51,7 @@ Explicit Transactions Neomodel also supports `explicit transactions `_ that are pre-designated as either *read* or *write*. -This is vital when using neomodel over a `Neo4J causal cluster `_ because internally, queries will be rerouted to different servers depending on their designation. @@ -168,7 +168,7 @@ Impersonation *Neo4j Enterprise feature* -Impersonation (`see Neo4j driver documentation ``) +Impersonation (`see Neo4j driver documentation `_) can be enabled via a context manager:: from neomodel import db @@ -197,4 +197,22 @@ This can be mixed with other context manager like transactions:: @db.transaction() def func2(): - ... \ No newline at end of file + ... + + +Parallel runtime +---------------- + +As of version 5.13, Neo4j *Enterprise Edition* supports parallel runtime for read transactions. 
+ +To use it, you can simply use the `parallel_read_transaction` context manager:: + + from neomodel import db + + with db.parallel_read_transaction: + # It works for both neomodel-generated and custom Cypher queries + parallel_count_1 = len(Coffee.nodes) + parallel_count_2 = db.cypher_query("MATCH (n:Coffee) RETURN count(n)") + +It is worth noting that the parallel runtime is only available for read transactions and that it is not enabled by default, because it is not always the fastest option. It is recommended to test it in your specific use case to see if it improves performance, and read the general considerations in the `Neo4j official documentation `_. + diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index c569eadf..e28895ba 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -104,6 +104,7 @@ def __init__(self): self._database_version = None self._database_edition = None self.impersonated_user = None + self._parallel_runtime = False async def set_connection(self, url: str = None, driver: AsyncDriver = None): """ @@ -239,6 +240,10 @@ def write_transaction(self): def read_transaction(self): return AsyncTransactionProxy(self, access_mode="READ") + @property + def parallel_read_transaction(self): + return AsyncTransactionProxy(self, access_mode="READ", parallel_runtime=True) + async def impersonate(self, user: str) -> "ImpersonationHandler": """All queries executed within this context manager will be executed as impersonated user @@ -454,7 +459,6 @@ async def cypher_query( :return: A tuple containing a list of results and a tuple of headers. 
""" - if self._active_transaction: # Use current session is a transaction is currently active results, meta = await self._run_cypher_query( @@ -493,6 +497,8 @@ async def _run_cypher_query( try: # Retrieve the data start = time.time() + if self._parallel_runtime: + query = "CYPHER runtime=parallel " + query response: AsyncResult = await session.run(query, params) results, meta = [list(r.values()) async for r in response], response.keys() end = time.time() @@ -1180,17 +1186,30 @@ async def install_all_labels(stdout=None): class AsyncTransactionProxy: bookmarks: Optional[Bookmarks] = None - def __init__(self, db: AsyncDatabase, access_mode=None): + def __init__( + self, db: AsyncDatabase, access_mode: str = None, parallel_runtime: bool = False + ): self.db = db self.access_mode = access_mode + self.parallel_runtime = parallel_runtime @ensure_connection async def __aenter__(self): + if self.parallel_runtime: + if not await self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False + self.db._parallel_runtime = self.parallel_runtime await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None return self async def __aexit__(self, exc_type, exc_value, traceback): + self.db._parallel_runtime = False if exc_value: await self.db.rollback() diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 82eaf67b..99d08a16 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -397,7 +397,6 @@ def __init__( lookup: TOptional[str] = None, additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, - use_parallel_runtime: TOptional[bool] = False, ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] @@ -411,7 +410,6 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count - self.use_parallel_runtime = use_parallel_runtime self.subgraph: Dict = {} @@ -435,19 +433,6 @@ async def build_ast(self) -> "AsyncQueryBuilder": self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): self._ast.limit = self.node_set.limit - if hasattr(self.node_set, "use_parallel_runtime"): - if ( - self.node_set.use_parallel_runtime - and not await adb.parallel_runtime_available() - ): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" - "Reverting to default runtime.", - UserWarning, - ) - self.node_set.use_parallel_runtime = False - else: - self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime return self @@ -830,8 +815,6 @@ def lookup_query_variable( def build_query(self) -> str: query: str = "" - if self._ast.use_parallel_runtime: - query += "CYPHER runtime=parallel " if self._ast.lookup: query += self._ast.lookup @@ -1258,8 +1241,6 @@ def __init__(self, source) -> None: self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] - self.use_parallel_runtime = False - def __await__(self): return self.all().__await__() @@ -1588,10 +1569,6 @@ def intermediate_transform( ) return self - def parallel_runtime(self) -> "AsyncNodeSet": - self.use_parallel_runtime = True - return self - class AsyncTraversal(AsyncBaseSet): """ diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 2175fa02..2b693908 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -104,6 +104,7 @@ def __init__(self): self._database_version = None self._database_edition = None self.impersonated_user = None + self._parallel_runtime = False def set_connection(self, url: str = None, driver: Driver = None): """ @@ -239,6 +240,10 @@ def write_transaction(self): def read_transaction(self): return TransactionProxy(self, access_mode="READ") + @property + def parallel_read_transaction(self): + return TransactionProxy(self, access_mode="READ", parallel_runtime=True) + def impersonate(self, user: str) -> "ImpersonationHandler": """All queries executed within this context manager will be executed as impersonated user @@ -452,7 +457,6 @@ def cypher_query( :return: A tuple containing a list of results and a tuple of headers. 
""" - if self._active_transaction: # Use current session is a transaction is currently active results, meta = self._run_cypher_query( @@ -491,6 +495,8 @@ def _run_cypher_query( try: # Retrieve the data start = time.time() + if self._parallel_runtime: + query = "CYPHER runtime=parallel " + query response: Result = session.run(query, params) results, meta = [list(r.values()) for r in response], response.keys() end = time.time() @@ -1171,17 +1177,30 @@ def install_all_labels(stdout=None): class TransactionProxy: bookmarks: Optional[Bookmarks] = None - def __init__(self, db: Database, access_mode=None): + def __init__( + self, db: Database, access_mode: str = None, parallel_runtime: bool = False + ): self.db = db self.access_mode = access_mode + self.parallel_runtime = parallel_runtime @ensure_connection def __enter__(self): + if self.parallel_runtime: + if not self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False + self.db._parallel_runtime = self.parallel_runtime self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None return self def __exit__(self, exc_type, exc_value, traceback): + self.db._parallel_runtime = False if exc_value: self.db.rollback() diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index c2c9539d..15a49cfb 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -397,7 +397,6 @@ def __init__( lookup: TOptional[str] = None, additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, - use_parallel_runtime: TOptional[bool] = False, ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] @@ -411,7 +410,6 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count - self.use_parallel_runtime = use_parallel_runtime self.subgraph: Dict = {} @@ -435,19 +433,6 @@ def build_ast(self) -> "QueryBuilder": self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): self._ast.limit = self.node_set.limit - if hasattr(self.node_set, "use_parallel_runtime"): - if ( - self.node_set.use_parallel_runtime - and not db.parallel_runtime_available() - ): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" - "Reverting to default runtime.", - UserWarning, - ) - self.node_set.use_parallel_runtime = False - else: - self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime return self @@ -830,8 +815,6 @@ def lookup_query_variable( def build_query(self) -> str: query: str = "" - if self._ast.use_parallel_runtime: - query += "CYPHER runtime=parallel " if self._ast.lookup: query += self._ast.lookup @@ -1256,8 +1239,6 @@ def __init__(self, source) -> None: self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] - self.use_parallel_runtime = False - def __await__(self): return self.all().__await__() @@ -1584,10 +1565,6 @@ def intermediate_transform( ) return self - def parallel_runtime(self) -> "NodeSet": - self.use_parallel_runtime = True - return self - class Traversal(BaseSet): """ diff --git a/pyproject.toml b/pyproject.toml index d72c546b..99335e40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dev = [ "pytest>=7.1", "pytest-asyncio", "pytest-cov>=4.0", + "pytest-mock", "pre-commit", "black", "isort", diff --git a/requirements-dev.txt b/requirements-dev.txt index 446dd8c1..ad82ba50 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,6 +5,7 @@ unasync>=0.5.0 pytest>=7.1 pytest-asyncio>=0.19.0 pytest-cov>=4.0 +pytest-mock pre-commit black isort diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 7df6f7d7..77b4b2ab 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1,9 +1,8 @@ import re -import warnings from datetime import datetime from test._async_compat import mark_async_test -from pytest import raises, warns +from pytest import raises, skip, warns from neomodel import ( INCOMING, @@ -32,7 +31,11 @@ RawCypher, RelationNameResolver, ) -from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined +from neomodel.exceptions import ( + FeatureNotSupported, + MultipleNodesReturned, + RelationshipClassNotDefined, 
+) class SupplierRel(AsyncStructuredRel): @@ -1116,31 +1119,56 @@ async def test_async_iterator(): assert counter == n -@mark_async_test -async def test_parallel_runtime(): - await Coffee(name="Java", price=99).save() - - node_set = AsyncNodeSet(Coffee).parallel_runtime() +def assert_last_query_startswith(mock_func, query) -> bool: + return mock_func.call_args_list[-1].args[0].startswith(query) - assert node_set.use_parallel_runtime +@mark_async_test +async def test_parallel_runtime(mocker): if ( not await adb.version_is_higher_than("5.13") or not await adb.edition_is_enterprise() ): - assert not await adb.parallel_runtime_available() - with warns( - UserWarning, - match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", - ): - qb = await AsyncQueryBuilder(node_set).build_ast() - assert not qb._ast.use_parallel_runtime - assert not qb.build_query().startswith("CYPHER runtime=parallel") - else: - assert await adb.parallel_runtime_available() - qb = await AsyncQueryBuilder(node_set).build_ast() - assert qb._ast.use_parallel_runtime - assert qb.build_query().startswith("CYPHER runtime=parallel") + skip("Only supported for Enterprise 5.13 and above.") + + assert await adb.parallel_runtime_available() + mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") + + # Parallel should be applied to custom Cypher query + async with adb.parallel_read_transaction: + # Mock transaction.run to access executed query + # Assert query starts with CYPHER runtime=parallel + assert adb._parallel_runtime == True + await adb.cypher_query("MATCH (n:Coffee) RETURN n") + assert assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) + # Test exiting the context sets the parallel_runtime to False + assert adb._parallel_runtime == False + + # Parallel should be applied to neomodel queries + async with adb.parallel_read_transaction: + await Coffee.nodes + assert len(mock_transaction_run.call_args_list) > 1 + assert 
assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) - results = [node async for node in qb._execute()] - assert len(results) == 1 + +@mark_async_test +async def test_parallel_runtime_conflict(mocker): + if await adb.version_is_higher_than("5.13") and await adb.edition_is_enterprise(): + skip("Test for unavailable parallel runtime.") + + assert not await adb.parallel_runtime_available() + mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") + with warns( + UserWarning, + match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", + ): + async with adb.parallel_read_transaction: + await Coffee.nodes + assert not adb._parallel_runtime + assert not assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 2b148601..16ffb532 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1,9 +1,8 @@ import re -import warnings from datetime import datetime from test._async_compat import mark_sync_test -from pytest import raises, warns +from pytest import raises, skip, warns from neomodel import ( INCOMING, @@ -21,7 +20,11 @@ db, ) from neomodel._async_compat.util import Util -from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined +from neomodel.exceptions import ( + FeatureNotSupported, + MultipleNodesReturned, + RelationshipClassNotDefined, +) from neomodel.sync_.match import ( Collect, Last, @@ -1100,28 +1103,53 @@ def test_async_iterator(): assert counter == n -@mark_sync_test -def test_parallel_runtime(): - Coffee(name="Java", price=99).save() - - node_set = NodeSet(Coffee).parallel_runtime() +def assert_last_query_startswith(mock_func, query) -> bool: + return mock_func.call_args_list[-1].args[0].startswith(query) - assert node_set.use_parallel_runtime +@mark_sync_test +def test_parallel_runtime(mocker): if not db.version_is_higher_than("5.13") or not 
db.edition_is_enterprise(): - assert not db.parallel_runtime_available() - with warns( - UserWarning, - match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", - ): - qb = QueryBuilder(node_set).build_ast() - assert not qb._ast.use_parallel_runtime - assert not qb.build_query().startswith("CYPHER runtime=parallel") - else: - assert db.parallel_runtime_available() - qb = QueryBuilder(node_set).build_ast() - assert qb._ast.use_parallel_runtime - assert qb.build_query().startswith("CYPHER runtime=parallel") + skip("Only supported for Enterprise 5.13 and above.") + + assert db.parallel_runtime_available() + mock_transaction_run = mocker.patch("neo4j.Transaction.run") + + # Parallel should be applied to custom Cypher query + with db.parallel_read_transaction: + # Mock transaction.run to access executed query + # Assert query starts with CYPHER runtime=parallel + assert db._parallel_runtime == True + db.cypher_query("MATCH (n:Coffee) RETURN n") + assert assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) + # Test exiting the context sets the parallel_runtime to False + assert db._parallel_runtime == False + + # Parallel should be applied to neomodel queries + with db.parallel_read_transaction: + Coffee.nodes + assert len(mock_transaction_run.call_args_list) > 1 + assert assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) - results = [node for node in qb._execute()] - assert len(results) == 1 + +@mark_sync_test +def test_parallel_runtime_conflict(mocker): + if db.version_is_higher_than("5.13") and db.edition_is_enterprise(): + skip("Test for unavailable parallel runtime.") + + assert not db.parallel_runtime_available() + mock_transaction_run = mocker.patch("neo4j.Transaction.run") + with warns( + UserWarning, + match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", + ): + with db.parallel_read_transaction: + Coffee.nodes + assert not db._parallel_runtime + assert not 
assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) From 53291f4c86f07518cbffec9970d25a017fc851fd Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Nov 2024 16:05:41 +0100 Subject: [PATCH 08/10] Fix tests --- test/async_/test_match_api.py | 10 +++++----- test/sync_/test_match_api.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 77b4b2ab..2dff91c0 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1132,13 +1132,13 @@ async def test_parallel_runtime(mocker): skip("Only supported for Enterprise 5.13 and above.") assert await adb.parallel_runtime_available() - mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") # Parallel should be applied to custom Cypher query async with adb.parallel_read_transaction: # Mock transaction.run to access executed query # Assert query starts with CYPHER runtime=parallel assert adb._parallel_runtime == True + mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") await adb.cypher_query("MATCH (n:Coffee) RETURN n") assert assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" @@ -1148,10 +1148,10 @@ async def test_parallel_runtime(mocker): # Parallel should be applied to neomodel queries async with adb.parallel_read_transaction: - await Coffee.nodes - assert len(mock_transaction_run.call_args_list) > 1 + mock_transaction_run_2 = mocker.patch("neo4j.AsyncTransaction.run") + await Coffee.nodes.all() assert assert_last_query_startswith( - mock_transaction_run, "CYPHER runtime=parallel" + mock_transaction_run_2, "CYPHER runtime=parallel" ) @@ -1167,7 +1167,7 @@ async def test_parallel_runtime_conflict(mocker): match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", ): async with adb.parallel_read_transaction: - await Coffee.nodes + await Coffee.nodes.all() assert not adb._parallel_runtime assert not 
assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 16ffb532..4df51866 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1113,13 +1113,13 @@ def test_parallel_runtime(mocker): skip("Only supported for Enterprise 5.13 and above.") assert db.parallel_runtime_available() - mock_transaction_run = mocker.patch("neo4j.Transaction.run") # Parallel should be applied to custom Cypher query with db.parallel_read_transaction: # Mock transaction.run to access executed query # Assert query starts with CYPHER runtime=parallel assert db._parallel_runtime == True + mock_transaction_run = mocker.patch("neo4j.Transaction.run") db.cypher_query("MATCH (n:Coffee) RETURN n") assert assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" @@ -1129,10 +1129,10 @@ def test_parallel_runtime(mocker): # Parallel should be applied to neomodel queries with db.parallel_read_transaction: - Coffee.nodes - assert len(mock_transaction_run.call_args_list) > 1 + mock_transaction_run_2 = mocker.patch("neo4j.Transaction.run") + Coffee.nodes.all() assert assert_last_query_startswith( - mock_transaction_run, "CYPHER runtime=parallel" + mock_transaction_run_2, "CYPHER runtime=parallel" ) @@ -1148,7 +1148,7 @@ def test_parallel_runtime_conflict(mocker): match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", ): with db.parallel_read_transaction: - Coffee.nodes + Coffee.nodes.all() assert not db._parallel_runtime assert not assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" From 01bfb6e405bea125def5b7575da8f826e057c191 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Nov 2024 16:10:24 +0100 Subject: [PATCH 09/10] Fixed leftover code smell --- neomodel/async_/core.py | 15 +++++++-------- neomodel/sync_/core.py | 15 +++++++-------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git 
a/neomodel/async_/core.py b/neomodel/async_/core.py index e28895ba..bfa5b8b9 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -1195,14 +1195,13 @@ def __init__( @ensure_connection async def __aenter__(self): - if self.parallel_runtime: - if not await self.db.parallel_runtime_available(): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " - "Reverting to default runtime.", - UserWarning, - ) - self.parallel_runtime = False + if self.parallel_runtime and not await self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False self.db._parallel_runtime = self.parallel_runtime await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 2b693908..75b7a10e 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -1186,14 +1186,13 @@ def __init__( @ensure_connection def __enter__(self): - if self.parallel_runtime: - if not self.db.parallel_runtime_available(): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " - "Reverting to default runtime.", - UserWarning, - ) - self.parallel_runtime = False + if self.parallel_runtime and not self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False self.db._parallel_runtime = self.parallel_runtime self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None From f9344d50b854c7e881fd46a2e70e7669be300362 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Nov 2024 16:14:16 +0100 Subject: [PATCH 10/10] Update changelog --- Changelog | 1 + 1 file changed, 1 insertion(+) diff --git a/Changelog b/Changelog index 9d62721d..91170523 100644 --- a/Changelog +++ b/Changelog @@ -1,4 +1,5 @@ Version 5.4.1 2024-11 +* Add support for Cypher parallel runtime * Add options for intermediate_transform : distinct, include_in_return, use a prop as source Version 5.4.0 2024-11