diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index de1c8c61..74c15683 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -117,7 +117,7 @@ Options for `subquery` calls are: .. note:: In the example above, we reference `self` to be included in the initial context. It will actually inject the outer variable corresponding to `Coffee` node. - We know this is confusing to read, but have not found a better wat to do this yet. If you have any suggestions, please let us know. + We know this is confusing to read, but have not found a better way to do this yet. If you have any suggestions, please let us know. Helpers ------- diff --git a/neomodel/_version.py b/neomodel/_version.py index 1e41bf8f..cfda0f8e 100644 --- a/neomodel/_version.py +++ b/neomodel/_version.py @@ -1 +1 @@ -__version__ = "5.4.1" +__version__ = "5.4.2" diff --git a/neomodel/typing.py b/neomodel/typing.py index f0558096..a23f88eb 100644 --- a/neomodel/typing.py +++ b/neomodel/typing.py @@ -1,6 +1,6 @@ """Custom types used for annotations.""" -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Optional, TypedDict Transformation = TypedDict( "Transformation", @@ -16,8 +16,8 @@ "Subquery", { "query": str, - "query_params": Dict, - "return_set": List[str], - "initial_context": Optional[List[Any]], + "query_params": dict, + "return_set": list[str], + "initial_context": Optional[list[Any]], }, ) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index a494ae42..70c7f351 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -2,6 +2,7 @@ from datetime import datetime from test._async_compat import mark_async_test +import numpy as np from pytest import raises, skip, warns from neomodel import ( @@ -880,7 +881,7 @@ async def test_subquery(): await nescafe.suppliers.connect(supplier2) await nescafe.species.connect(arabica) - result = await Coffee.nodes.subquery( + subquery = await Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] @@ -889,7 +890,7 @@ async def test_subquery(): ["supps"], [NodeNameResolver("self")], ) - result = await result.all() + result = await subquery.all() assert len(result) == 1 assert len(result[0]) == 2 assert result[0][0] == supplier2 @@ -905,6 +906,30 @@ async def test_subquery(): ["unknown"], ) + result_string_context = await subquery.subquery( + Coffee.nodes.traverse_relations(supps2="suppliers").annotate( + supps2=Collect("supps") + ), + ["supps2"], + ["supps"], + ) + result_string_context = await result_string_context.all() + assert len(result) == 1 + additional_elements = [ + item for item in result_string_context[0] if item not in result[0] + ] + assert len(additional_elements) == 1 + assert isinstance(additional_elements[0], list) + + with raises(ValueError, match=r"Wrong variable specified in initial context"): + result = await Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + [2], + ) + @mark_async_test async def test_subquery_other_node(): diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 0bf69b7f..94465db2 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -2,6 +2,7 @@ from datetime import datetime from test._async_compat import mark_sync_test +import numpy as np from pytest import raises, skip, warns from neomodel import ( @@ -864,7 +865,7 @@ def test_subquery(): nescafe.suppliers.connect(supplier2) nescafe.species.connect(arabica) - result = Coffee.nodes.subquery( + subquery = Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] @@ -873,7 +874,7 @@ def test_subquery(): ["supps"], [NodeNameResolver("self")], ) - result = result.all() + result = subquery.all() assert len(result) == 1 assert len(result[0]) == 2 assert result[0][0] == supplier2 @@ -889,6 +890,30 @@ def test_subquery(): ["unknown"], ) + result_string_context = subquery.subquery( + Coffee.nodes.traverse_relations(supps2="suppliers").annotate( + supps2=Collect("supps") + ), + ["supps2"], + ["supps"], + ) + result_string_context = result_string_context.all() + assert len(result) == 1 + additional_elements = [ + item for item in result_string_context[0] if item not in result[0] + ] + assert len(additional_elements) == 1 + assert isinstance(additional_elements[0], list) + + with raises(ValueError, match=r"Wrong variable specified in initial context"): + result = Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + [2], + ) + @mark_sync_test def test_subquery_other_node():