From c273813008be8a6373ceb2e73203504a7166bb94 Mon Sep 17 00:00:00 2001 From: Aliaksandr Stsiapanay Date: Fri, 13 Dec 2024 14:29:10 +0300 Subject: [PATCH] feat: Return retry-after header in case if rate limit is exceeded #615 --- .../controller/DeploymentPostController.java | 14 ++++- .../server/controller/RouteController.java | 14 ++++- .../core/server/limiter/RateBucket.java | 18 ++++++ .../core/server/limiter/RateLimitResult.java | 4 +- .../core/server/limiter/RateLimiter.java | 2 +- .../core/server/limiter/RequestRateLimit.java | 5 +- .../core/server/limiter/TokenRateLimit.java | 5 +- .../aidial/core/server/ResourceBaseTest.java | 6 +- .../core/server/limiter/RateBucketTest.java | 55 ++++++++++++++++++- 9 files changed, 111 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentPostController.java b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentPostController.java index 28ad7d6df..ba53a4afe 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentPostController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentPostController.java @@ -49,6 +49,7 @@ import java.io.InputStream; import java.math.BigDecimal; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; @@ -208,9 +209,20 @@ private void handleRateLimitHit(String deploymentId, RateLimitResult result) { ErrorData rateLimitError = new ErrorData(); rateLimitError.getError().setCode(String.valueOf(result.status().getCode())); rateLimitError.getError().setMessage(result.errorMessage()); + log.error("Rate limit error {}. Project: {}. User sub: {}. Deployment: {}. Trace: {}. Span: {}", result.errorMessage(), context.getProject(), context.getUserSub(), deploymentId, context.getTraceId(), context.getSpanId()); - respond(result.status(), rateLimitError); + + String errorMessage = ProxyUtil.convertToString(rateLimitError); + HttpException httpException; + if (result.replyAfterSeconds() >= 0) { + Map headers = Map.of(HttpHeaders.RETRY_AFTER.toString(), Long.toString(result.replyAfterSeconds())); + httpException = new HttpException(result.status(), errorMessage, headers); + } else { + httpException = new HttpException(result.status(), errorMessage); + } + + respond(httpException); } private void handleError(Throwable error) { diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/RouteController.java b/server/src/main/java/com/epam/aidial/core/server/controller/RouteController.java index 328a47050..aab2e82b5 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/RouteController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/RouteController.java @@ -28,6 +28,7 @@ import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.regex.Pattern; @@ -205,10 +206,21 @@ private void handleRateLimitHit(RateLimitResult result) { ErrorData rateLimitError = new ErrorData(); rateLimitError.getError().setCode(String.valueOf(result.status().getCode())); rateLimitError.getError().setMessage(result.errorMessage()); + log.error("Rate limit error {}. Project: {}. User sub: {}. Route: {}. Trace: {}. Span: {}", result.errorMessage(), context.getProject(), context.getUserSub(), context.getRoute().getName(), context.getTraceId(), context.getSpanId()); - context.respond(result.status(), rateLimitError); + + String errorMessage = ProxyUtil.convertToString(rateLimitError); + HttpException httpException; + if (result.replyAfterSeconds() >= 0) { + Map headers = Map.of(HttpHeaders.RETRY_AFTER.toString(), Long.toString(result.replyAfterSeconds())); + httpException = new HttpException(result.status(), errorMessage, headers); + } else { + httpException = new HttpException(result.status(), errorMessage); + } + + context.respond(httpException); } private void handleError(Throwable error) { diff --git a/server/src/main/java/com/epam/aidial/core/server/limiter/RateBucket.java b/server/src/main/java/com/epam/aidial/core/server/limiter/RateBucket.java index f71e53513..fff3b7487 100644 --- a/server/src/main/java/com/epam/aidial/core/server/limiter/RateBucket.java +++ b/server/src/main/java/com/epam/aidial/core/server/limiter/RateBucket.java @@ -3,6 +3,8 @@ import lombok.Data; import lombok.NoArgsConstructor; +import java.util.concurrent.TimeUnit; + @Data @NoArgsConstructor public class RateBucket { @@ -52,6 +54,22 @@ public long update(long timestamp) { return sum; } + /** + * Returns the number of seconds the user agent should wait before making a retry request. + * + * @param limit - requested limit + */ + long retryAfter(long limit) { + long sum = this.sum; + long replyAfter = 0; + for (long start = this.start; start < this.end && sum >= limit; start++) { + int index = index(start); + sum -= sums[index]; + replyAfter += window.interval(); + } + return TimeUnit.MILLISECONDS.toSeconds(replyAfter); + } + private long interval(long timestamp) { if (timestamp < window.window()) { throw new IllegalArgumentException("timestamp < window"); diff --git a/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimitResult.java b/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimitResult.java index c1f4b5e25..cfa577b90 100644 --- a/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimitResult.java +++ b/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimitResult.java @@ -2,6 +2,6 @@ import com.epam.aidial.core.storage.http.HttpStatus; -public record RateLimitResult(HttpStatus status, String errorMessage) { - public static final RateLimitResult SUCCESS = new RateLimitResult(HttpStatus.OK, null); +public record RateLimitResult(HttpStatus status, String errorMessage, long replyAfterSeconds) { + public static final RateLimitResult SUCCESS = new RateLimitResult(HttpStatus.OK, null, -1); } diff --git a/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimiter.java b/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimiter.java index 1286883a8..aa6145d00 100644 --- a/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimiter.java +++ b/server/src/main/java/com/epam/aidial/core/server/limiter/RateLimiter.java @@ -70,7 +70,7 @@ public Future limit(ProxyContext context, RoleBasedEntity roleB } else { log.warn("Limit must be positive for {}", name); } - return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied")); + return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied", -1)); } return vertx.executeBlocking(() -> checkLimit(context, limit, roleBasedEntity), false); diff --git a/server/src/main/java/com/epam/aidial/core/server/limiter/RequestRateLimit.java b/server/src/main/java/com/epam/aidial/core/server/limiter/RequestRateLimit.java index 8ef54afd4..ac5c6c110 100644 --- a/server/src/main/java/com/epam/aidial/core/server/limiter/RequestRateLimit.java +++ b/server/src/main/java/com/epam/aidial/core/server/limiter/RequestRateLimit.java @@ -18,7 +18,10 @@ public RateLimitResult check(long timestamp, Limit limit, long count) { if (result) { String errorMsg = String.format("Hit request rate limit. Hour limit: %d / %d requests. Day limit: %d / %d requests.", hourTotal, limit.getRequestHour(), dayTotal, limit.getRequestDay()); - return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg); + long hourRetryAfter = hour.retryAfter(limit.getRequestHour()); + long dayRetryAfter = day.retryAfter(limit.getRequestDay()); + long retryAfter = Math.max(hourRetryAfter, dayRetryAfter); + return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg, retryAfter); } else { hour.add(timestamp, count); day.add(timestamp, count); diff --git a/server/src/main/java/com/epam/aidial/core/server/limiter/TokenRateLimit.java b/server/src/main/java/com/epam/aidial/core/server/limiter/TokenRateLimit.java index bb9a760ac..688ba7664 100644 --- a/server/src/main/java/com/epam/aidial/core/server/limiter/TokenRateLimit.java +++ b/server/src/main/java/com/epam/aidial/core/server/limiter/TokenRateLimit.java @@ -24,7 +24,10 @@ public RateLimitResult update(long timestamp, Limit limit) { if (result) { String errorMsg = String.format("Hit token rate limit. Minute limit: %d / %d tokens. Day limit: %d / %d tokens.", minuteTotal, limit.getMinute(), dayTotal, limit.getDay()); - return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg); + long minuteRetryAfter = minute.retryAfter(limit.getMinute()); + long dayRetryAfter = day.retryAfter(limit.getDay()); + long retryAfter = Math.max(minuteRetryAfter, dayRetryAfter); + return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg, retryAfter); } else { return RateLimitResult.SUCCESS; } diff --git a/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java b/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java index 635f5356d..94cae7559 100644 --- a/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java @@ -23,7 +23,7 @@ import org.apache.hc.client5.http.entity.mime.HttpMultipartMode; import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; import org.apache.hc.core5.http.ContentType; import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.io.entity.EntityUtils; @@ -122,7 +122,9 @@ void init() throws Exception { .build(); redis.start(); - client = HttpClients.createDefault(); + // create HTTP client with disabled retries + // e.g. don't retry if response contains the header `retry-after` + client = HttpClientBuilder.create().disableAutomaticRetries().build(); String overrides = """ { diff --git a/server/src/test/java/com/epam/aidial/core/server/limiter/RateBucketTest.java b/server/src/test/java/com/epam/aidial/core/server/limiter/RateBucketTest.java index 36b91ec46..b4b21c4ff 100644 --- a/server/src/test/java/com/epam/aidial/core/server/limiter/RateBucketTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/limiter/RateBucketTest.java @@ -1,10 +1,11 @@ package com.epam.aidial.core.server.limiter; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.concurrent.ThreadLocalRandom; +import static org.junit.jupiter.api.Assertions.assertEquals; + class RateBucketTest { private RateBucket bucket; @@ -53,6 +54,54 @@ void testDayBucket() { update(49, 0); } + @Test + public void testRetryAfterMinute() { + bucket = new RateBucket(RateWindow.MINUTE); + + update(0, 0); + assertEquals(0, bucket.retryAfter(30)); + add(0, 10, 10); + + update(5, 10); + assertEquals(0, bucket.retryAfter(30)); + add(5, 20, 30); + + update(15, 30); + assertEquals(45, bucket.retryAfter(30)); + add(15, 30, 60); + + update(25, 60); + add(25, 10, 70); + + update(60, 60); + assertEquals(15, bucket.retryAfter(30)); + } + + @Test + public void testRetryAfterDay() { + bucket = new RateBucket(RateWindow.DAY); + + update(0, 0); + assertEquals(0, bucket.retryAfter(30)); + add(0, 10, 10); + + update(5, 10); + assertEquals(0, bucket.retryAfter(30)); + add(5, 20, 30); + + update(10, 30); + // need to wait 14 hours + assertEquals(14 * 60 * 60, bucket.retryAfter(30)); + add(10, 30, 60); + + update(20, 60); + add(23, 10, 70); + + update(24, 60); + // need to wait 10 hours + assertEquals(10 * 60 * 60, bucket.retryAfter(30)); + } + private void add(long interval, long count, long expected) { RateWindow window = bucket.getWindow(); long whole = interval * window.interval(); @@ -60,7 +109,7 @@ private void add(long interval, long count, long expected) { long timestamp = window.window() + whole + fraction; long actual = bucket.add(timestamp, count); - Assertions.assertEquals(expected, actual); + assertEquals(expected, actual); } private void update(long interval, long expected) { @@ -70,6 +119,6 @@ private void update(long interval, long expected) { long timestamp = window.window() + whole + fraction; long actual = bucket.update(timestamp); - Assertions.assertEquals(expected, actual); + assertEquals(expected, actual); } }