Skip to content

Commit

Permalink
Add wait pull time for test worker
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSuperiorStanislav committed Jul 2, 2024
1 parent 02a74b1 commit 4998de5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
18 changes: 18 additions & 0 deletions sns_sqs_communicator/testing/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def pytest_addoption(parser: pytest.Parser) -> None:
"sqs_poll_worker_class",
"Path to sqs poll worker class.",
)
parser.addini(
"sns_sqs_worker_wait_before_pull",
"Time for worker to wait before pulling.",
)


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -516,6 +520,18 @@ async def sqs_poll_worker_class(
return getattr(importlib.import_module(".".join(module)), klass)


@pytest.fixture
async def sns_sqs_worker_wait_before_pull(
request: pytest.FixtureRequest,
) -> int | float:
"""Time to wait before worker will pull messages."""
return float(
str(
request.config.inicfg.get("sns_sqs_worker_wait_before_pull", 0.5),
),
)


@pytest.fixture
async def sns_sqs_worker(
sns_topic: topic_module.SNSTopic,
Expand All @@ -524,6 +540,7 @@ async def sns_sqs_worker(
sqs_queue: queue_module.SQSQueue,
dead_letter_sqs_queue: queue_module.SQSQueue,
sns_parser: type[parsers_module.ParserProtocol[typing.Any]],
sns_sqs_worker_wait_before_pull: int | float,
) -> typing.AsyncGenerator[worker.TestWorker[typing.Any], None]:
"""Set up sns sqs worker."""
yield worker.TestWorker(
Expand All @@ -533,6 +550,7 @@ async def sns_sqs_worker(
sns_topic=sns_topic,
logger=logger,
parser=sns_parser,
wait_before_pull=sns_sqs_worker_wait_before_pull,
)
await sqs_queue.receive_all()
await dead_letter_sqs_queue.receive_all()
4 changes: 4 additions & 0 deletions sns_sqs_communicator/testing/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import collections.abc
import logging
import typing
Expand All @@ -18,13 +19,15 @@ def __init__(
sns_topic: topic.SNSTopic,
parser: type[parsers.ParserProtocol[messages.MessageActionT]],
logger: logging.Logger,
wait_before_pull: int | float = 0.5,
) -> None:
self.sqs_poll_worker_class = sqs_poll_worker_class
self.queue = sqs_queue
self.dead_letter_queue = dead_letter_sqs_queue
self.sns_topic = sns_topic
self.parser = parser
self.logger = logger
self.wait_before_pull = wait_before_pull

async def publish_and_pull(
self,
Expand All @@ -35,6 +38,7 @@ async def publish_and_pull(
body=message.serialize_body(),
metadata=message.metadata,
)
await asyncio.sleep(self.wait_before_pull)
return await self.pull()

async def pull(
Expand Down

0 comments on commit 4998de5

Please sign in to comment.