Skip to content

Commit

Permalink
Merge pull request #800 from neo4j-contrib/798-async-iterating-over-n…
Browse files Browse the repository at this point in the history
…odes-is-not-working

Fix async node iterator
  • Loading branch information
mariusconjeaud authored May 28, 2024
2 parents f7dcdb5 + 4db58a0 commit 5f58cc4
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 28 deletions.
23 changes: 15 additions & 8 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,11 @@ async def _execute(self, lazy=False):
# It seems that certain calls are only supposed to be focusing to the first
# result item returned (?)
if results and len(results[0]) == 1:
return [n[0] for n in results]
return results
for n in results:
yield n[0]
else:
for result in results:
yield result


class AsyncBaseSet:
Expand All @@ -806,12 +809,15 @@ async def all(self, lazy=False):
:rtype: list
"""
ast = await self.query_cls(self).build_ast()
return await ast._execute(lazy)
results = [
node async for node in ast._execute(lazy)
] # Collect all nodes asynchronously
return results

async def __aiter__(self):
ast = await self.query_cls(self).build_ast()
async for i in await ast._execute():
yield i
async for item in ast._execute():
yield item

async def get_len(self):
ast = await self.query_cls(self).build_ast()
Expand Down Expand Up @@ -862,8 +868,8 @@ async def get_item(self, key):
self.limit = 1

ast = await self.query_cls(self).build_ast()
_items = ast._execute()
return _items[0]
_first_item = [node async for node in ast._execute()][0]
return _first_item

Check warning on line 872 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L871-L872

Added lines #L871 - L872 were not covered by tests

return None

Expand Down Expand Up @@ -911,7 +917,8 @@ async def _get(self, limit=None, lazy=False, **kwargs):
if limit:
self.limit = limit
ast = await self.query_cls(self).build_ast()
return await ast._execute(lazy)
results = [node async for node in ast._execute(lazy)]
return results

async def get(self, lazy=False, **kwargs):
"""
Expand Down
23 changes: 15 additions & 8 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,11 @@ def _execute(self, lazy=False):
# It seems that certain calls are only supposed to be focusing to the first
# result item returned (?)
if results and len(results[0]) == 1:
return [n[0] for n in results]
return results
for n in results:
yield n[0]
else:
for result in results:
yield result


class BaseSet:
Expand All @@ -802,12 +805,15 @@ def all(self, lazy=False):
:rtype: list
"""
ast = self.query_cls(self).build_ast()
return ast._execute(lazy)
results = [
node for node in ast._execute(lazy)
] # Collect all nodes asynchronously
return results

def __iter__(self):
ast = self.query_cls(self).build_ast()
for i in ast._execute():
yield i
for item in ast._execute():
yield item

def __len__(self):
ast = self.query_cls(self).build_ast()
Expand Down Expand Up @@ -858,8 +864,8 @@ def __getitem__(self, key):
self.limit = 1

ast = self.query_cls(self).build_ast()
_items = ast._execute()
return _items[0]
_first_item = [node for node in ast._execute()][0]
return _first_item

return None

Expand Down Expand Up @@ -907,7 +913,8 @@ def _get(self, limit=None, lazy=False, **kwargs):
if limit:
self.limit = limit
ast = self.query_cls(self).build_ast()
return ast._execute(lazy)
results = [node for node in ast._execute(lazy)]
return results

def get(self, lazy=False, **kwargs):
"""
Expand Down
45 changes: 39 additions & 6 deletions test/async_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def test_filter_exclude_via_labels():
node_set = AsyncNodeSet(Coffee)
qb = await AsyncQueryBuilder(node_set).build_ast()

results = await qb._execute()
results = [node async for node in qb._execute()]

assert "(coffee:Coffee)" in qb._ast.match
assert qb._ast.result_class
Expand All @@ -76,7 +76,7 @@ async def test_filter_exclude_via_labels():
node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java")
qb = await AsyncQueryBuilder(node_set).build_ast()

results = await qb._execute()
results = [node async for node in qb._execute()]
assert "(coffee:Coffee)" in qb._ast.match
assert "NOT" in qb._ast.where[0]
assert len(results) == 1
Expand All @@ -91,15 +91,15 @@ async def test_simple_has_via_label():

ns = AsyncNodeSet(Coffee).has(suppliers=True)
qb = await AsyncQueryBuilder(ns).build_ast()
results = await qb._execute()
results = [node async for node in qb._execute()]
assert "COFFEE SUPPLIERS" in qb._ast.where[0]
assert len(results) == 1
assert results[0].name == "Nescafe"

await Coffee(name="nespresso", price=99).save()
ns = AsyncNodeSet(Coffee).has(suppliers=False)
qb = await AsyncQueryBuilder(ns).build_ast()
results = await qb._execute()
results = [node async for node in qb._execute()]
assert len(results) > 0
assert "NOT" in qb._ast.where[0]

Expand Down Expand Up @@ -129,7 +129,7 @@ async def test_simple_traverse_with_filter():
)

_ast = await qb.build_ast()
results = await _ast._execute()
results = [node async for node in qb._execute()]

assert qb._ast.lookup
assert qb._ast.match
Expand All @@ -148,7 +148,7 @@ async def test_double_traverse():
ns = AsyncNodeSet(AsyncNodeSet(source=nescafe).suppliers.match()).coffees.match()
qb = await AsyncQueryBuilder(ns).build_ast()

results = await qb._execute()
results = [node async for node in qb._execute()]
assert len(results) == 2
assert results[0].name == "Decafe"
assert results[1].name == "Nescafe plus"
Expand Down Expand Up @@ -589,3 +589,36 @@ async def test_in_filter_with_array_property():
assert arabica not in await Species.nodes.filter(
tags__in=no_match
), "Species found by tags with not match tags given"


@mark_async_test
async def test_async_iterator():
n = 10
if AsyncUtil.is_async_code:
for c in await Coffee.nodes:
await c.delete()

for i in range(n):
await Coffee(name=f"xxx_{i}", price=i).save()

nodes = await Coffee.nodes
# assert that nodes was created
assert isinstance(nodes, list)
assert all(isinstance(i, Coffee) for i in nodes)
assert len(nodes) == n

counter = 0
async for node in Coffee.nodes:
assert isinstance(node, Coffee)
counter += 1

# assert that generator runs loop above
assert counter == n

counter = 0
for node in await Coffee.nodes:
assert isinstance(node, Coffee)
counter += 1

# assert that generator runs loop above
assert counter == n
45 changes: 39 additions & 6 deletions test/sync_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_filter_exclude_via_labels():
node_set = NodeSet(Coffee)
qb = QueryBuilder(node_set).build_ast()

results = qb._execute()
results = [node for node in qb._execute()]

assert "(coffee:Coffee)" in qb._ast.match
assert qb._ast.result_class
Expand All @@ -69,7 +69,7 @@ def test_filter_exclude_via_labels():
node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java")
qb = QueryBuilder(node_set).build_ast()

results = qb._execute()
results = [node for node in qb._execute()]
assert "(coffee:Coffee)" in qb._ast.match
assert "NOT" in qb._ast.where[0]
assert len(results) == 1
Expand All @@ -84,15 +84,15 @@ def test_simple_has_via_label():

ns = NodeSet(Coffee).has(suppliers=True)
qb = QueryBuilder(ns).build_ast()
results = qb._execute()
results = [node for node in qb._execute()]
assert "COFFEE SUPPLIERS" in qb._ast.where[0]
assert len(results) == 1
assert results[0].name == "Nescafe"

Coffee(name="nespresso", price=99).save()
ns = NodeSet(Coffee).has(suppliers=False)
qb = QueryBuilder(ns).build_ast()
results = qb._execute()
results = [node for node in qb._execute()]
assert len(results) > 0
assert "NOT" in qb._ast.where[0]

Expand Down Expand Up @@ -120,7 +120,7 @@ def test_simple_traverse_with_filter():
qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()))

_ast = qb.build_ast()
results = _ast._execute()
results = [node for node in qb._execute()]

assert qb._ast.lookup
assert qb._ast.match
Expand All @@ -139,7 +139,7 @@ def test_double_traverse():
ns = NodeSet(NodeSet(source=nescafe).suppliers.match()).coffees.match()
qb = QueryBuilder(ns).build_ast()

results = qb._execute()
results = [node for node in qb._execute()]
assert len(results) == 2
assert results[0].name == "Decafe"
assert results[1].name == "Nescafe plus"
Expand Down Expand Up @@ -578,3 +578,36 @@ def test_in_filter_with_array_property():
assert arabica not in Species.nodes.filter(
tags__in=no_match
), "Species found by tags with not match tags given"


@mark_sync_test
def test_async_iterator():
n = 10
if Util.is_async_code:
for c in Coffee.nodes:
c.delete()

for i in range(n):
Coffee(name=f"xxx_{i}", price=i).save()

nodes = Coffee.nodes
# assert that nodes was created
assert isinstance(nodes, list)
assert all(isinstance(i, Coffee) for i in nodes)
assert len(nodes) == n

counter = 0
for node in Coffee.nodes:
assert isinstance(node, Coffee)
counter += 1

# assert that generator runs loop above
assert counter == n

counter = 0
for node in Coffee.nodes:
assert isinstance(node, Coffee)
counter += 1

# assert that generator runs loop above
assert counter == n

0 comments on commit 5f58cc4

Please sign in to comment.