diff --git a/celery_heimdall/__init__.py b/celery_heimdall/__init__.py
index 36f3195..93d4988 100644
--- a/celery_heimdall/__init__.py
+++ b/celery_heimdall/__init__.py
@@ -1,4 +1,4 @@
-__all__ = ("HeimdallTask", "AlreadyQueuedError", "RateLimit")
+__all__ = ("HeimdallTask", "AlreadyQueuedError", "RateLimit", "HeimdallConfig")
 
-from celery_heimdall.task import HeimdallTask, RateLimit
+from celery_heimdall.task import HeimdallTask, RateLimit, HeimdallConfig
 from celery_heimdall.errors import AlreadyQueuedError
diff --git a/celery_heimdall/contrib/README.md b/celery_heimdall/contrib/README.md
deleted file mode 100644
index a581ecc..0000000
--- a/celery_heimdall/contrib/README.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Contrib
-
-This directory contains optional integrations and tools.
\ No newline at end of file
diff --git a/celery_heimdall/contrib/__init__.py b/celery_heimdall/contrib/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/celery_heimdall/contrib/inspector/README.md b/celery_heimdall/contrib/inspector/README.md
deleted file mode 100644
index 84188d5..0000000
--- a/celery_heimdall/contrib/inspector/README.md
+++ /dev/null
@@ -1,25 +0,0 @@
-# Inspector
-
-**Note:** This tool is in beta, and currently only tested against SQLite as a
-data store.
-
-The Inspector is a minimal debugging tool for working with Celery queues and
-tasks. It is an optional component of celery-heimdall and not installed by
-default.
-
-It runs a monitor, which populates any SQLAlchemy-compatible database with the
-state of your Celery cluster.
-
-## Why?
-
-This tool is used to assist in debugging, generate graphs of queues for
-documentation, to verify the final state of Celery after tests, etc...
-
-Flower deprecated their graphs page, and now require you to use prometheus and
-grafana, which is overkill when you just want to see what's been running.
-
-## Installation
-
-```
-pip install celery-heimdall[inspector]
-```
\ No newline at end of file
diff --git a/celery_heimdall/contrib/inspector/__init__.py b/celery_heimdall/contrib/inspector/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/celery_heimdall/contrib/inspector/cli.py b/celery_heimdall/contrib/inspector/cli.py
deleted file mode 100644
index eb6d339..0000000
--- a/celery_heimdall/contrib/inspector/cli.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from pathlib import Path
-
-import click
-from celery import Celery
-
-from celery_heimdall.contrib.inspector.monitor import monitor
-
-
-@click.group()
-def cli():
-    """
-    heimdall-inspector provides tools for introspecting a live Celery cluster.
-    """
-
-
-@cli.command("monitor")
-@click.argument("broker_url")
-@click.option(
-    "--enable-events",
-    default=False,
-    is_flag=True,
-    help=(
-        "Sends a command-and-control message to all Celery workers to start"
-        " emitting worker events before starting the server."
-    ),
-)
-@click.option(
-    "--db",
-    default="heimdall.db",
-    type=click.Path(dir_okay=False, writable=True, path_type=Path),
-    help=("Use the provided path to store our sqlite database."),
-)
-def monitor_command(broker_url: str, enable_events: bool, db: Path):
-    """
-    Starts a monitor to watch for Celery events and records them to an SQLite
-    database.
-
-    Optionally enables event monitoring on a live cluster if --enable-events is
-    provided. Note that it will not stop events when finished.
-    """
-    if enable_events:
-        celery_app = Celery(broker=broker_url)
-        celery_app.control.enable_events()
-
-    monitor(broker=broker_url, db=db)
-
-
-if __name__ == "__main__":
-    cli()
diff --git a/celery_heimdall/contrib/inspector/models.py b/celery_heimdall/contrib/inspector/models.py
deleted file mode 100644
index 1e45a82..0000000
--- a/celery_heimdall/contrib/inspector/models.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import enum
-
-from sqlalchemy import (
-    Column,
-    TIMESTAMP,
-    Integer,
-    String,
-    func,
-    BigInteger,
-    text,
-    Enum,
-)
-from sqlalchemy.orm import declarative_base, sessionmaker
-
-Base = declarative_base()
-Session = sessionmaker()
-
-
-class WorkerStatus(enum.Enum):
-    #: Worker is answering heartbeats.
-    ALIVE = 0
-    #: We didn't get an offline event, but we're not getting heartbeats.
-    LOST = 10
-    #: We got a shutdown event for the worker.
-    OFFLINE = 20
-
-
-class TaskStatus(enum.Enum):
-    RECEIVED = 0
-    STARTED = 10
-    SUCCEEDED = 20
-    FAILED = 30
-    REJECTED = 40
-    REVOKED = 50
-    RETRIED = 60
-
-
-class TaskInstance(Base):
-    __tablename__ = "task_instance"
-
-    uuid = Column(String, primary_key=True)
-    name = Column(String)
-    status = Column(Enum(TaskStatus))
-    hostname = Column(String)
-    args = Column(String)
-    kwargs = Column(String)
-
-    runtime = Column(Integer)
-
-    received = Column(TIMESTAMP, server_default=func.now())
-    started = Column(TIMESTAMP, nullable=True)
-    failed = Column(TIMESTAMP, nullable=True)
-    rejected = Column(TIMESTAMP, nullable=True)
-    succeeded = Column(TIMESTAMP, nullable=True)
-
-    retries = Column(Integer, server_default=text("0"))
-    last_seen = Column(TIMESTAMP, server_onupdate=func.now())
-
-
-class Worker(Base):
-    __tablename__ = "worker"
-
-    #: The hostname of a worker is used as its ID.
-    id = Column(String, primary_key=True)
-
-    #: How often the worker is configured to send heartbeats.
-    frequency = Column(Integer, server_default=text("0"))
-    #: Name of the worker software
-    sw_identity = Column(String, nullable=True)
-    #: Version of the worker software.
-    sw_version = Column(String, nullable=True)
-    #: Host operating system of the worker.
-    sw_system = Column(String, nullable=True)
-
-    #: Number of currently executing tasks.
-    active = Column(BigInteger, server_default=text("0"))
-    #: Number of processed tasks.
-    processed = Column(BigInteger, server_default=text("0"))
-
-    #: Last known status of the worker.
-    status = Column(Enum(WorkerStatus))
-
-    first_seen = Column(TIMESTAMP, server_default=func.now())
-    last_seen = Column(TIMESTAMP, server_onupdate=func.now())
diff --git a/celery_heimdall/contrib/inspector/monitor.py b/celery_heimdall/contrib/inspector/monitor.py
deleted file mode 100644
index 816e30b..0000000
--- a/celery_heimdall/contrib/inspector/monitor.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import datetime
-from pathlib import Path
-
-from celery import Celery
-from sqlalchemy import create_engine, insert, func, update
-from sqlalchemy.dialects.sqlite import insert
-
-from celery_heimdall.contrib.inspector import models
-
-
-def task_received(event):
-    with models.Session() as session:
-        session.execute(
-            insert(models.TaskInstance.__table__).values(
-                uuid=event["uuid"],
-                name=event["name"],
-                status=models.TaskStatus.RECEIVED,
-                hostname=event["hostname"],
-                args=event["args"],
-                kwargs=event["kwargs"],
-                received=datetime.datetime.fromtimestamp(event["timestamp"]),
-            )
-        )
-        session.commit()
-
-
-def task_started(event):
-    with models.Session() as session:
-        session.execute(
-            update(models.TaskInstance.__table__)
-            .where(models.TaskInstance.uuid == event["uuid"])
-            .values(
-                runtime=event.get("runtime", 0),
-                status=models.TaskStatus.STARTED,
-                started=datetime.datetime.fromtimestamp(event["timestamp"]),
-                last_seen=func.now(),
-            )
-        )
-        session.commit()
-
-
-def task_succeeded(event):
-    with models.Session() as session:
-        session.execute(
-            update(models.TaskInstance.__table__)
-            .where(models.TaskInstance.uuid == event["uuid"])
-            .values(
-                runtime=event.get("runtime", 0),
-                status=models.TaskStatus.SUCCEEDED,
-                succeeded=datetime.datetime.fromtimestamp(event["timestamp"]),
-                last_seen=func.now(),
-            )
-        )
-        session.commit()
-
-
-def task_retried(event):
-    with models.Session() as session:
-        session.execute(
-            update(models.TaskInstance.__table__)
-            .where(models.TaskInstance.uuid == event["uuid"])
-            .values(
-                runtime=event.get("runtime", 0),
-                status=models.TaskStatus.RETRIED,
-                retries=models.TaskInstance.retries + 1,
-                last_seen=func.now(),
-            )
-        )
-        session.commit()
-
-
-def task_failed(event):
-    with models.Session() as session:
-        session.execute(
-            update(models.TaskInstance.__table__)
-            .where(models.TaskInstance.uuid == event["uuid"])
-            .values(
-                runtime=event.get("runtime", 0),
-                status=models.TaskStatus.FAILED,
-                failed=datetime.datetime.fromtimestamp(event["timestamp"]),
-                last_seen=func.now(),
-            )
-        )
-        session.commit()
-
-
-def task_rejected(event):
-    with models.Session() as session:
-        session.execute(
-            update(models.TaskInstance.__table__)
-            .where(models.TaskInstance.uuid == event["uuid"])
-            .values(
-                runtime=event.get("runtime", 0),
-                status=models.TaskStatus.REJECTED,
-                rejected=datetime.datetime.fromtimestamp(event["timestamp"]),
-                last_seen=func.now(),
-            )
-        )
-        session.commit()
-
-
-def worker_event(event):
-    field_mapping = {
-        "freq": models.Worker.frequency,
-        "sw_ident": models.Worker.sw_identity,
-        "sw_ver": models.Worker.sw_version,
-        "sw_sys": models.Worker.sw_system,
-        "active": models.Worker.active,
-        "processed": models.Worker.processed,
-    }
-
-    payload = {
-        "last_seen": func.now(),
-        "status": {
-            "worker-heartbeat": models.WorkerStatus.ALIVE,
-            "worker-online": models.WorkerStatus.ALIVE,
-            "worker-offline": models.WorkerStatus.OFFLINE,
-        }.get(event["type"], models.WorkerStatus.LOST),
-    }
-    for k, v in field_mapping.items():
-        if k in event:
-            payload[v] = event[k]
-
-    # FIXME: Support postgres / MySQL
-    with models.Session() as session:
-        session.execute(
-            insert(models.Worker.__table__)
-            .values({"id": event["hostname"], **payload})
-            .on_conflict_do_update(index_elements=["id"], set_=payload)
-        )
-        session.commit()
-
-
-def monitor(*, broker: str, db: Path):
-    """
-    A real-time Celery event monitor which captures events and populates a
-    supported SQLAlchemy database.
-    """
-    app = Celery(broker=broker)
-
-    engine = create_engine(f"sqlite:///{db}")
-    models.Session.configure(bind=engine)
-    models.Base.metadata.create_all(engine)
-
-    with app.connection() as connection:
-        recv = app.events.Receiver(
-            connection,
-            handlers={
-                # '*': state.event,
-                "task-started": task_started,
-                "task-rejected": task_rejected,
-                "task-failed": task_failed,
-                "task-received": task_received,
-                "task-succeeded": task_succeeded,
-                "task-retried": task_retried,
-                "worker-online": worker_event,
-                "worker-heartbeat": worker_event,
-                "worker-offline": worker_event,
-            },
-        )
-        recv.capture(limit=None, timeout=None, wakeup=True)
diff --git a/celery_heimdall/rate.py b/celery_heimdall/rate.py
new file mode 100644
index 0000000..39f8572
--- /dev/null
+++ b/celery_heimdall/rate.py
@@ -0,0 +1,91 @@
+from typing import List, Tuple
+
+from redis import Redis
+from celery import Task
+
+
+def check_rate_limits(
+    redis_client: Redis,
+    rate_limits: List[Tuple[bytes, int, int]],
+) -> float:
+    """Return seconds to wait before the limits allow a run (0.0 = run now)."""
+    pipeline = redis_client.pipeline()
+    max_delay = 0.0
+
+    # Check each rate limit
+    for rate_key, count, period in rate_limits:
+        # Get current count and timestamp
+        pipeline.get(rate_key)
+        pipeline.ttl(rate_key)
+        value, ttl = pipeline.execute()
+
+        if value is None:
+            # First execution within this period
+            pipeline.set(rate_key, 1, ex=period)
+            pipeline.execute()
+            continue
+
+        count_used = int(value)
+        if count_used >= count:
+            # Rate limit exceeded, calculate delay
+            if ttl > 0:
+                delay = float(ttl)
+                max_delay = max(max_delay, delay)
+            continue
+
+        # Increment counter
+        pipeline.incr(rate_key)
+        if ttl <= 0:
+            pipeline.expire(rate_key, period)
+        pipeline.execute()
+
+    return max_delay
+
+
+def get_rate_limits(
+    task: Task, config, args, kwargs
+) -> List[Tuple[bytes, int, int]]:
+    """
+    Convert rate limit configuration into a list of (key, count, period) tuples.
+
+    Args:
+        task: The Celery task being rate limited
+        config: Task configuration containing rate limit settings
+        args: Task arguments
+        kwargs: Task keyword arguments
+
+    Returns:
+        List of (key, count, period) tuples defining the rate limits
+    """
+    if config.rate_limit is None:
+        return []
+
+    if not isinstance(config.rate_limit, list):
+        rate_limits_config = [config.rate_limit]
+    else:
+        rate_limits_config = config.rate_limit
+
+    rate_limits = []
+    for limit_config in rate_limits_config:
+        # Get the rate limit values
+        if callable(limit_config.rate_limit):
+            count, period = limit_config.rate_limit()
+        else:
+            count, period = limit_config.rate_limit
+
+        # Get the key for this rate limit
+        if limit_config.key:
+            # Use provided key or key function
+            if callable(limit_config.key):
+                key = limit_config.key(args, kwargs)
+            else:
+                key = limit_config.key
+        else:
+            # Fall back to task's unique key
+            key = config.get_key(task, args, kwargs).decode('utf-8')
+
+        # Build the full Redis key with prefix
+        full_key = f"{config.get_rate_limit_prefix()}:{key}".encode('utf-8')
+        rate_limits.append((full_key, count, period))
+
+    return rate_limits
\ No newline at end of file
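Reviewer note: the fixed-window behaviour that `check_rate_limits` implements on top of Redis can be summarized without Redis at all. The following is a minimal, illustrative sketch (not part of this patch) that substitutes a plain dict for the Redis key/TTL pair; the `windows` name and `check` helper are hypothetical.

```python
import time

# Hypothetical in-memory stand-in for the Redis keys used above: each key
# maps to (count_used, window_expires_at).
windows: dict[bytes, tuple[int, float]] = {}


def check(key: bytes, count: int, period: int) -> float:
    """Return seconds to wait before `key` may run again (0.0 = run now)."""
    now = time.monotonic()
    used, expires_at = windows.get(key, (0, 0.0))
    if now >= expires_at:
        # A fresh window: mirrors `value is None` once the Redis TTL lapses.
        windows[key] = (1, now + period)
        return 0.0
    if used >= count:
        # Limit exceeded: wait out the remainder of the window (the TTL).
        return expires_at - now
    windows[key] = (used + 1, expires_at)
    return 0.0


assert check(b"h-rate:demo", 2, 10) == 0.0  # first call runs
assert check(b"h-rate:demo", 2, 10) == 0.0  # second call runs
assert check(b"h-rate:demo", 2, 10) > 0.0   # third call must wait
```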
diff --git a/celery_heimdall/task.py b/celery_heimdall/task.py
index faa0a90..28df5b5 100644
--- a/celery_heimdall/task.py
+++ b/celery_heimdall/task.py
@@ -2,7 +2,7 @@
 from abc import ABC
 from dataclasses import dataclass
 from enum import Enum
-from functools import cache
+from functools import cache, cached_property
 from typing import Callable
 
 import celery
@@ -13,6 +13,7 @@
 
 from . import lock
 from .errors import AlreadyQueuedError
+from .rate import get_rate_limits, check_rate_limits
 
 
 class RateLimitStrategy(Enum):
@@ -29,10 +30,13 @@ class RateLimit:
     """
     A rate limit configuration for a HeimdallTask.
     """
 
-    # The rate limit to apply to the task. Can be a tuple in the form of
-    # (times, per) or a callable that returns a tuple.
-    rate_limit: tuple | Callable
-    # The strategy to use for rate limiting.
+    #: The rate limit to apply to the task. Can be a tuple in the form of
+    #: (times, per) or a callable that returns the tuple.
+    rate_limit: tuple[int, int] | Callable[..., tuple[int, int]]
+    #: The key to use for rate limiting. If not provided, the key will be
+    #: taken from the unique key of the task.
+    key: str | None = None
+    #: The strategy to use for rate limiting.
     strategy: RateLimitStrategy = RateLimitStrategy.DEFAULT
 
@@ -53,29 +57,32 @@ class HeimdallConfig:
     """
     Configuration options for a HeimdallTask.
     """
 
-    # If True, the task will be globally unique, allowing only one instance
-    # to run or be queued at a time.
+    #: If True, the task will be globally unique, allowing only one instance
+    #: to run or be queued at a time.
     unique: bool = False
-    # If True, the lock will be acquired before the task is queued.
+    #: If True, the lock will be acquired before the task is queued.
    unique_early: bool = True
-    # If True, the lock will be acquired when the task is started.
+    #: If True, the lock will be acquired when the task is started.
     unique_late: bool = True
-    # If True, the task will raise an exception if it's already queued.
+    #: If True, the task will raise an exception if it's already queued.
     unique_raises: bool = False
-    # The amount of time to wait before allowing the task lock to expire,
-    # even if the task has not yet completed.
+    #: The amount of time to wait before allowing the task lock to expire,
+    #: even if the task has not yet completed.
     unique_expiry: int = 60 * 100
-    # If True, the task will wait for the lock to expire instead of releasing
-    # it, even if the task has already completed. This can be used to easily
-    # implement tasks that should only run once per interval.
+    #: If True, the task will wait for the lock to expire instead of releasing
+    #: it, even if the task has already completed. This can be used to easily
+    #: implement tasks that should only run once per interval.
     unique_wait_for_expiry: bool = False
-    # A user-provided unique key for the task. If not specified, a unique
-    # key will be generated from the task's arguments.
-    key: str | Callable = None
-
-    # The default prefix to use for the task lock key.
+    #: A user-provided unique key for the task. If not specified, a unique
+    #: key will be generated from the task's arguments.
+    key: str | Callable[..., str] = None
+    #: The rate limit to apply to the task. Can optionally be a list of
+    #: rate limits which will be applied in order.
+    rate_limit: RateLimit | list[RateLimit] | None = None
+
+    #: The default prefix to use for the task lock key.
     lock_prefix: str = "h-lock"
-    # The default prefix to use for the rate limit key.
+    #: The default prefix to use for the rate limit key.
     rate_limit_prefix: str = "h-rate"
 
     def get_redis(self, app: Celery) -> redis.Redis:
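The net effect of the two hunks above is that rate limits are now declared on `HeimdallConfig` instead of as loose dict keys. A minimal sketch of the new configuration surface, mirroring the tests further down in this diff (the `fetch_page` task and its `url` argument are made up for illustration):

```python
from celery import shared_task

from celery_heimdall import HeimdallConfig, HeimdallTask, RateLimit


@shared_task(
    base=HeimdallTask,
    heimdall=HeimdallConfig(
        rate_limit=[
            # At most 2 runs per 30s across every invocation of the task.
            RateLimit((2, 30), key="global"),
            # At most 1 run per 10s per unique set of task arguments.
            RateLimit((1, 10)),
        ]
    ),
)
def fetch_page(url: str):
    ...
```

Because the `key` field defaults to `None`, the second limit falls back to the task's unique argument-derived key, which is what makes the per-argument behaviour in `test_multiple_rate_limit` work.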
@@ -151,53 +158,6 @@ def get_key(self, task: celery.Task, args, kwargs) -> bytes:
         return f"{self.get_lock_prefix()}:{k}".encode("utf-8")
 
 
-class HeimdallNamespace:
-    """
-    A namespace for Heimdall configuration and utilities for a task, to keep
-    them from conflicting with other task mixins.
-    """
-
-    def __init__(self, task: "HeimdallTask"):
-        self.task = task
-        self.config = getattr(task, "heimdall", HeimdallConfig())
-        self.redis = self.config.get_redis(task.app)
-
-    def extend_lock(self, milliseconds: int):
-        """
-        Extends the expiry on the lock for the current task by the given number
-        of milliseconds.
-        """
-        if not self.config.unique:
-            raise ValueError("Task is not configured to have a unique lock")
-
-        key = self.config.get_key(
-            self.task, self.task.request.args, self.task.request.kwargs
-        )
-
-        return lock.extend(
-            self.redis,
-            key,
-            self.task.request.id.encode("utf-8"),
-            milliseconds,
-            replace=False,
-        )
-
-    def clear_lock(self) -> bool:
-        """
-        Clears the lock for the current task.
-        """
-        if not self.config.unique:
-            raise ValueError("Task is not configured to have a unique lock")
-
-        key = self.config.get_key(
-            self.task, self.task.request.args, self.task.request.kwargs
-        )
-
-        return lock.release(
-            self.redis, key, token=self.task.request.id.encode("utf-8")
-        )
-
-
 class HeimdallTask(celery.Task, ABC):
     """
     A base task for Celery that adds helpful features such as global rate
@@ -209,23 +169,35 @@ class HeimdallTask(celery.Task, ABC):
 
     abstract = True
 
-    def __init__(self, *args, **kwargs):
-        self._bifrost = None
-        super().__init__(*args, **kwargs)
-
     def __call__(self, *args, **kwargs):
-        bifrost = self.bifrost()
+        if self.h_config.rate_limit is not None:
+            rate_limits = get_rate_limits(self, self.h_config, args, kwargs)
+            if rate_limits:
+                delay = check_rate_limits(self.h_redis, rate_limits)
+                if delay > 0:
+                    # We don't want our rescheduling retry to count against
+                    # any normal retry limits the user might have set on the
+                    # task or globally.
+                    self.request.retries -= 1
+                    # Max retries needs to be set to None _before_ calling
+                    # retry(). This value will not propagate, allowing the user's
+                    # normal retry behaviour to apply on the next call.
+                    self.max_retries = None
+                    # Retrying with a future ETA is flawed in Celery. Celery
+                    # will schedule these by pulling them into worker memory
+                    # which can cause massive problems on busy queues.
+                    raise self.retry(countdown=delay)
 
         # Acquire a globally unique lock for this task at the time it's
         # executed.
-        if bifrost.config.unique and bifrost.config.unique_late:
-            key = bifrost.config.get_key(self, args, kwargs)
+        if self.h_config.unique and self.h_config.unique_late:
+            key = self.h_config.get_key(self, args, kwargs)
             token = lock.lock(
-                bifrost.redis,
+                self.h_redis,
                 key,
                 token=self.request.id.encode("utf-8"),
-                expiry=bifrost.config.unique_expiry,
+                expiry=self.h_config.unique_expiry,
             ).decode("utf-8")
 
             if not token == self.request.id:
@@ -234,24 +206,22 @@
         return self.run(*args, **kwargs)
 
     def apply_async(self, args=None, kwargs=None, task_id=None, **options):
-        bifrost = self.bifrost()
-
-        if bifrost.config.unique and bifrost.config.unique_early:
+        if self.h_config.unique and self.h_config.unique_early:
             # Acquire a globally unique lock for this task before it's queued.
             # In some cases, this function may not be called, such as by
             # send_task() or non-standard task execution.
             task_id: str = task_id or uuid()
-            key = bifrost.config.get_key(self, args, kwargs)
+            key = self.h_config.get_key(self, args, kwargs)
             token = lock.lock(
-                bifrost.redis,
+                self.h_redis,
                 key,
                 token=task_id.encode("utf-8"),
-                expiry=bifrost.config.unique_expiry,
+                expiry=self.h_config.unique_expiry,
             ).decode("utf-8")
 
             if not token == task_id:
-                if not bifrost.config.unique_raises:
+                if not self.h_config.unique_raises:
                     # If the task is not configured to raise an exception when
                     # it's already queued, we'll just return the task ID of the
                     # task that already holds the lock.
@@ -268,29 +238,72 @@ def after_return(self, status, retval, task_id, args, kwargs, einfo):
         # Handles post-task cleanup, when a task exits cleanly. This will be
         # called if a task raises an exception (stored in `einfo`), but not
         # if a worker straight up dies (say, because of running out of memory)
-        bifrost = self.bifrost()
-
         # Cleanup the unique task lock when the task finishes, unless the user
         # told us to wait for the remaining interval.
-        if bifrost.config.unique and not bifrost.config.unique_wait_for_expiry:
-            key = bifrost.config.get_key(self, args, kwargs)
+        if self.h_config.unique and not self.h_config.unique_wait_for_expiry:
+            key = self.h_config.get_key(self, args, kwargs)
             # It's not an error for our lock to have already been cleared by
             # another token, because our token may have expired.
-            lock.release(bifrost.redis, key, token=task_id.encode("utf-8"))
+            lock.release(self.h_redis, key, token=task_id.encode("utf-8"))
 
         super().after_return(status, retval, task_id, args, kwargs, einfo)
 
-    def bifrost(self) -> HeimdallNamespace:
+    @cached_property
+    def h_config(self) -> HeimdallConfig:
+        """
+        Heimdall's configuration for the current task.
+        """
+        return getattr(self, "heimdall", HeimdallConfig())
+
+    @cached_property
+    def h_redis(self) -> redis.Redis:
+        """
+        Heimdall's cached redis connection.
+        """
+        return self.h_config.get_redis(self.app)
+
+    def h_only_after(self, key: str, seconds: int) -> bool:
+        """
+        Returns True only if the key has expired in the cache or if the
+        current task already holds the key.
+
+        This can be used to gate parts of a task that should only be run
+        once per time period.
+
+        .. code-block:: python
+
+            @shared_task(bind=True, base=HeimdallTask)
+            def my_task(self, *args, **kwargs):
+                if self.h_only_after("my_task", 60 * 60):
+                    print(
+                        "It's been at least an hour since this task was"
+                        " last run."
+                    )
         """
-        Get the Heimdall namespace for the task.
 
-        This object contains the Heimdall configuration for this task, redis
-        caches, and other helpful utilities. It's designed to be used by the
-        task to interact with Heimdall features while minimizing conflicts
-        with other Task implementations.
+        token = self.request.id.encode("utf-8")
+        return lock.lock(
+            self.h_redis,
+            key.encode("utf-8"),
+            token=token,
+            expiry=seconds,
+        ) == token
+
+    def h_extend_lock(self, milliseconds: int):
         """
-        if self._bifrost is not None:
-            return self._bifrost
+        Extends the expiry on the lock for the current task by the given number
+        of milliseconds.
+        """
+        if not self.h_config.unique:
+            raise ValueError("Task is not configured to have a unique lock")
 
-        self._bifrost = HeimdallNamespace(self)
-        return self._bifrost
+        key = self.h_config.get_key(
+            self, self.request.args, self.request.kwargs
+        )
+
+        return lock.extend(
+            self.h_redis,
+            key,
+            self.request.id.encode("utf-8"),
+            milliseconds,
+            replace=False,
+        )
diff --git a/pyproject.toml b/pyproject.toml
index 2a05c95..68f8c32 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,12 +9,9 @@ homepage = "https://github.com/tktech/celery-heimdall"
 repository = "https://github.com/tktech/celery-heimdall"
 
 [tool.poetry.dependencies]
-python = ">3.7"
-celery = ">5.2.7"
+python = ">3.10"
+celery = ">5.3"
 redis = "*"
-click = {version = "^8.1.3", optional = true}
-SQLAlchemy = {version = "^1.4.40", optional = true}
-importlib-metadata = "<=4.13"
 
 [tool.poetry.dev-dependencies]
 pytest = "^7.1.2"
@@ -22,12 +19,6 @@ bumpversion = "^0.6.0"
 coverage = "^6.4.4"
 pytest-cov = "^3.0.0"
 
-[tool.poetry.extras]
-inspector = ["click", "sqlalchemy"]
-
-[tool.poetry.scripts]
-heimdall-inspector = "celery_heimdall.contrib.inspector.cli:cli"
-
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
diff --git a/tests/test_only_after.py b/tests/test_only_after.py
index 596d192..ccc3a2a 100644
--- a/tests/test_only_after.py
+++ b/tests/test_only_after.py
@@ -8,7 +8,7 @@
 
 
 @shared_task(base=HeimdallTask, bind=True)
 def task_with_block(self: HeimdallTask):
-    if self.only_after('only_after', 5):
+    if self.h_only_after('only_after', 5):
         return True
     return False
diff --git a/tests/test_rate_limited.py b/tests/test_rate_limited.py
index 3bfd221..b7b5bd0 100644
--- a/tests/test_rate_limited.py
+++ b/tests/test_rate_limited.py
@@ -1,51 +1,52 @@
 import time
 
-import celery.result
 import pytest
 from celery import shared_task
 
-from celery_heimdall import HeimdallTask, RateLimit
+from celery_heimdall import HeimdallTask, RateLimit, HeimdallConfig
+
 
 @shared_task(
     base=HeimdallTask,
-    heimdall={
-        'times': 2,
-        'per': 10
-    }
+    heimdall=HeimdallConfig(
+        rate_limit=RateLimit((2, 10))
+    )
 )
-def default_rate_limit_task():
+def tuple_rate_limit_task():
     pass
 
 
 @shared_task(
     base=HeimdallTask,
-    heimdall={
-        'rate_limit': RateLimit((2, 10))
-    }
+    heimdall=HeimdallConfig(
+        rate_limit=RateLimit(lambda *args, **kwargs: (2, 10))
+    )
 )
-def tuple_rate_limit_task():
+def callable_rate_limit_task():
     pass
 
 
 @shared_task(
     base=HeimdallTask,
-    heimdall={
-        'rate_limit': RateLimit(lambda key: (2, 10))
-    }
+    heimdall=HeimdallConfig(
+        rate_limit=[
+            RateLimit((2, 30), key="global"),
+            RateLimit((1, 10))
+        ]
+    )
 )
-def callable_rate_limit_task():
-    pass
+def multiple_rate_limit_task(key: str):
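With `HeimdallNamespace` gone, long-running tasks call the prefixed helpers directly on `self`. A hedged sketch of what that looks like in practice (the `long_import` task, its chunked loop, and the 30-minute figure are all made up for illustration):

```python
from celery import shared_task

from celery_heimdall import HeimdallConfig, HeimdallTask


@shared_task(
    bind=True,
    base=HeimdallTask,
    heimdall=HeimdallConfig(unique=True),
)
def long_import(self: HeimdallTask):
    for chunk in range(10):
        # Push the unique lock's expiry out by 30 minutes each iteration so
        # a slow run doesn't lose its lock partway through.
        self.h_extend_lock(30 * 60 * 1000)
        ...  # process the chunk
```

Note that `h_extend_lock` raises a `ValueError` if the task was not configured with `unique=True`, matching the guard in the hunk above.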
+    return key
 
 
 @pytest.mark.parametrize('func', [
-    default_rate_limit_task,
     tuple_rate_limit_task,
     callable_rate_limit_task
 ])
 def test_default_rate_limit(celery_session_worker, func):
     """
-    Ensure a unique task with no other configuration "just works".
+    Ensure that rate limiting works as expected.
     """
     start = time.time()
 
     # Immediate
@@ -78,3 +79,34 @@ def test_default_rate_limit(celery_session_worker, func):
 
     elapsed = time.time() - start
     assert 20 < elapsed < 30
+
+
+def test_multiple_rate_limit(celery_session_worker):
+    """
+    Ensure that rate limiting works as expected when multiple rate limits
+    are configured.
+    """
+    start = time.time()
+
+    # Since both task1 and task2 use distinct arguments and are using the
+    # default key, they will run immediately.
+    task1 = multiple_rate_limit_task.delay("t1")
+    task2 = multiple_rate_limit_task.delay("t2")
+    # ... but task3 will be delayed by the global rate limit.
+    task3 = multiple_rate_limit_task.delay("t3")
+    # ... and task4 repeats "t2", so the per-key and global limits delay it.
+    task4 = multiple_rate_limit_task.delay("t2")
+
+    task1.get()
+    task2.get()
+
+    elapsed = time.time() - start
+    assert elapsed < 5
+
+    task3.get()
+    elapsed = time.time() - start
+    assert 30 < elapsed < 40
+
+    task4.get()
+    elapsed = time.time() - start
+    assert 30 < elapsed < 40
\ No newline at end of file
diff --git a/tests/test_unique.py b/tests/test_unique.py
index c5f122b..d03cf1e 100644
--- a/tests/test_unique.py
+++ b/tests/test_unique.py
@@ -8,8 +8,7 @@
 import pytest
 from celery.result import AsyncResult
 
-from celery_heimdall import HeimdallTask, AlreadyQueuedError
-from celery_heimdall.task import HeimdallConfig
+from celery_heimdall import HeimdallTask, AlreadyQueuedError, HeimdallConfig
 
 
 @celery.shared_task(base=HeimdallTask, heimdall=HeimdallConfig(unique=True))
@@ -46,7 +45,7 @@
 )
 def task_with_override_config(task: HeimdallTask):
     time.sleep(2)
-    return task.bifrost().config.get_lock_prefix()
+    return task.h_config.get_lock_prefix()
 
 
 @celery.shared_task(
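For downstream users, the test changes above double as a migration guide: the old dict-style `heimdall={...}` configuration is replaced by an explicit `HeimdallConfig`. A hedged before/after sketch (the `send_email` task name is illustrative):

```python
from celery import shared_task

from celery_heimdall import HeimdallConfig, HeimdallTask, RateLimit

# Before: rate limits were passed as loose dict keys.
# @shared_task(base=HeimdallTask, heimdall={'times': 2, 'per': 10})

# After: the same limit, declared explicitly on HeimdallConfig.
@shared_task(
    base=HeimdallTask,
    heimdall=HeimdallConfig(rate_limit=RateLimit((2, 10))),
)
def send_email():
    ...
```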