Skip to content

Commit

Permalink
refactor: start using TaskGroups
Browse files Browse the repository at this point in the history
  • Loading branch information
hartungstenio committed Apr 5, 2024
1 parent 251055e commit 7e54e1f
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
"Topic :: System :: Distributed Computing",
]
requires-python = ">=3.8"
dependencies = ["aiobotocore>=2.0.0,<3.0.0"]
dependencies = ["aiobotocore>=2.0.0,<3.0.0", "taskgroup ; python_version < '3.11'"]

[project.urls]
Download = "https://github.com/olist/olist-loafer/releases"
Expand Down
9 changes: 7 additions & 2 deletions src/loafer/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
PY312 = sys.version_info >= (3, 12)

if PY311:
from asyncio import to_thread
from asyncio import TaskGroup, to_thread
else:
import asyncio
import functools

from taskgroup import TaskGroup

async def to_thread(func, /, *args, **kwargs):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
return await loop.run_in_executor(
None, functools.partial(func, *args, **kwargs)
)


if PY312:
Expand All @@ -23,4 +27,5 @@ async def to_thread(func, /, *args, **kwargs):
__all__ = [
"to_thread",
"iscoroutinefunction",
"TaskGroup",
]
57 changes: 34 additions & 23 deletions src/loafer/dispatchers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import logging
import sys
from contextlib import suppress
from functools import partial
from typing import Any, Optional, Sequence

from .compat import TaskGroup
from .exceptions import DeleteMessage
from .routes import Route

Expand Down Expand Up @@ -53,12 +54,13 @@ async def _process_message(self, message: Any, route: Route) -> bool:
return confirmation

async def _fetch_messages(
self, processing_queue: asyncio.Queue, forever: bool = True
self,
processing_queue: asyncio.Queue,
tg: TaskGroup,
forever: bool = True,
) -> None:
routes = [route for route in self.routes]
tasks = [
asyncio.create_task(route.provider.fetch_messages()) for route in routes
]
tasks = [tg.create_task(route.provider.fetch_messages()) for route in routes]

while routes or tasks:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
Expand All @@ -67,13 +69,16 @@ async def _fetch_messages(
new_tasks = []
for task, route in zip(tasks, routes):
if task.done():
if task.exception():
raise task.exception()

for message in task.result():
await processing_queue.put((message, route))

if forever:
new_routes.append(route)
new_tasks.append(
asyncio.create_task(route.provider.fetch_messages())
tg.create_task(route.provider.fetch_messages())
)
else:
new_routes.append(route)
Expand All @@ -82,29 +87,35 @@ async def _fetch_messages(
routes = new_routes
tasks = new_tasks

async def _consume_messages(self, processing_queue: asyncio.Queue) -> None:
with suppress(asyncio.CancelledError):
while True:
message, route = await processing_queue.get()
asyncio.create_task(
self._process_message(message, route)
).add_done_callback(lambda _: processing_queue.task_done())
def _mark_task_done(self, queue: asyncio.Queue, *args, **kwargs):
queue.task_done()

async def _consume_messages(
self, processing_queue: asyncio.Queue, tg: TaskGroup
) -> None:
mark_task_done = partial(self._mark_task_done, processing_queue)

while True:
message, route = await processing_queue.get()

task = tg.create_task(self._process_message(message, route))
task.add_done_callback(mark_task_done)

async def dispatch_providers(self, forever: bool = True) -> None:
processing_queue = asyncio.Queue(self.max_concurrency)
provider_task = asyncio.create_task(
self._fetch_messages(processing_queue, forever)
)
consumer_task = asyncio.create_task(self._consume_messages(processing_queue))

async def join():
await provider_task
await processing_queue.join()
consumer_task.cancel()
async with TaskGroup() as tg:
provider_task = tg.create_task(
self._fetch_messages(processing_queue, tg, forever)
)
consumer_task = tg.create_task(self._consume_messages(processing_queue, tg))

joining_task = asyncio.create_task(join())
async def join():
await provider_task
await processing_queue.join()
consumer_task.cancel()

await asyncio.gather(provider_task, consumer_task, joining_task)
tg.create_task(join())

def stop(self) -> None:
for route in self.routes:
Expand Down
8 changes: 7 additions & 1 deletion tests/test_dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
from unittest import mock

import pytest
from loafer.compat import PY311
from loafer.dispatchers import LoaferDispatcher
from loafer.exceptions import DeleteMessage
from loafer.routes import Route

if not PY311:
from exceptiongroup import ExceptionGroup


def create_mock_route(messages):
provider = mock.AsyncMock(
Expand Down Expand Up @@ -128,9 +132,11 @@ async def test_dispatch_providers_with_error(route):
route.provider.fetch_messages.side_effect = ValueError
dispatcher = LoaferDispatcher([route])

with pytest.raises(ValueError):
with pytest.raises(ExceptionGroup) as exc_info:
await dispatcher.dispatch_providers(forever=False)

assert exc_info.value.subgroup(ValueError) is not None


def test_dispatcher_stop(route):
route.stop = mock.Mock()
Expand Down

0 comments on commit 7e54e1f

Please sign in to comment.