Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-period intervals in throttling #1373

Merged
merged 1 commit into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/docs/guides/throttling.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ Throttles allows to control the rate of requests that clients can make to an API

Django Ninja’s throttling feature is pretty much based on what Django Rest Framework (DRF) uses, which you can check out [here](https://www.django-rest-framework.org/api-guide/throttling/). So, if you’ve already got custom throttling set up for DRF, there’s a good chance it’ll work with Django Ninja right out of the box. The key difference is that you need to pass initialized Throttle objects instead of classes (which should give a better performance).

You can specify a rate using the format requests/time-unit, where time-unit represents a number of units followed by an optional unit of time. If the unit is omitted, it defaults to seconds. For example, the following are equivalent and all represent "100 requests per 5 minutes":

* 100/5m
* 100/300s
* 100/300

The following units are supported:

* `s` or `sec`
* `m` or `min`
* `h` or `hour`
* `d` or `day`

## Usage

Expand Down
31 changes: 27 additions & 4 deletions ninja/throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ class SimpleRateThrottle(BaseThrottle):
cache_format = "throttle_%(scope)s_%(ident)s"
scope: Optional[str] = None
THROTTLE_RATES: Dict[str, Optional[str]] = settings.DEFAULT_THROTTLE_RATES
_PERIODS = {
"s": 1,
"m": 60,
"h": 60 * 60,
"d": 60 * 60 * 24,
"sec": 1,
"min": 60,
"hour": 60 * 60,
"day": 60 * 60 * 24,
}

def __init__(self, rate: Optional[str] = None):
self.rate: Optional[str]
Expand Down Expand Up @@ -106,10 +116,23 @@ def parse_rate(self, rate: Optional[str]) -> Tuple[Optional[int], Optional[int]]
"""
if rate is None:
return (None, None)
num, period = rate.split("/")
num_requests = int(num)
duration = {"s": 1, "m": 60, "h": 3600, "d": 86400}[period[0]]
return (num_requests, duration)

try:
count, rest = rate.split("/", 1)

for unit in self._PERIODS:
if rest.endswith(unit):
multi = int(rest[: -len(unit)]) if rest[: -len(unit)] else 1
period = self._PERIODS[unit]
break
else:
multi, period = int(rest), 1

count = int(count)
return count, multi * period

except (ValueError, IndexError):
raise ValueError(f"Invalid rate format: {rate}") from None

def allow_request(self, request: HttpRequest) -> bool:
"""
Expand Down
13 changes: 13 additions & 0 deletions tests/test_throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,22 @@ def test_rate_parser():
th = SimpleRateThrottle("1/s")
assert th.parse_rate(None) == (None, None)
assert th.parse_rate("1/s") == (1, 1)
assert th.parse_rate("1/sec") == (1, 1)
assert th.parse_rate("100/10s") == (100, 10)
assert th.parse_rate("100/10sec") == (100, 10)
assert th.parse_rate("100/10") == (100, 10)
assert th.parse_rate("5/m") == (5, 60)
assert th.parse_rate("5/min") == (5, 60)
assert th.parse_rate("500/10m") == (500, 600)
assert th.parse_rate("500/10min") == (500, 600)
assert th.parse_rate("10/h") == (10, 3600)
assert th.parse_rate("10/hour") == (10, 3600)
assert th.parse_rate("1000/2h") == (1000, 7200)
assert th.parse_rate("1000/2hour") == (1000, 7200)
assert th.parse_rate("100/d") == (100, 86400)
assert th.parse_rate("100/day") == (100, 86400)
assert th.parse_rate("10_000/7d") == (10000, 86400 * 7)
assert th.parse_rate("10_000/7day") == (10000, 86400 * 7)


def test_proxy_throttle():
Expand Down
Loading