Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Return retry-after header in case if rate limit is exceeded #615 #617

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.io.InputStream;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

Expand Down Expand Up @@ -208,9 +209,20 @@ private void handleRateLimitHit(String deploymentId, RateLimitResult result) {
ErrorData rateLimitError = new ErrorData();
rateLimitError.getError().setCode(String.valueOf(result.status().getCode()));
rateLimitError.getError().setMessage(result.errorMessage());

log.error("Rate limit error {}. Project: {}. User sub: {}. Deployment: {}. Trace: {}. Span: {}", result.errorMessage(),
context.getProject(), context.getUserSub(), deploymentId, context.getTraceId(), context.getSpanId());
respond(result.status(), rateLimitError);

String errorMessage = ProxyUtil.convertToString(rateLimitError);
HttpException httpException;
if (result.replyAfterSeconds() >= 0) {
Map<String, String> headers = Map.of(HttpHeaders.RETRY_AFTER.toString(), Long.toString(result.replyAfterSeconds()));
httpException = new HttpException(result.status(), errorMessage, headers);
} else {
httpException = new HttpException(result.status(), errorMessage);
}

respond(httpException);
}

private void handleError(Throwable error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -205,10 +206,21 @@ private void handleRateLimitHit(RateLimitResult result) {
ErrorData rateLimitError = new ErrorData();
rateLimitError.getError().setCode(String.valueOf(result.status().getCode()));
rateLimitError.getError().setMessage(result.errorMessage());

log.error("Rate limit error {}. Project: {}. User sub: {}. Route: {}. Trace: {}. Span: {}", result.errorMessage(),
context.getProject(), context.getUserSub(), context.getRoute().getName(), context.getTraceId(),
context.getSpanId());
context.respond(result.status(), rateLimitError);

String errorMessage = ProxyUtil.convertToString(rateLimitError);
HttpException httpException;
if (result.replyAfterSeconds() >= 0) {
Map<String, String> headers = Map.of(HttpHeaders.RETRY_AFTER.toString(), Long.toString(result.replyAfterSeconds()));
httpException = new HttpException(result.status(), errorMessage, headers);
} else {
httpException = new HttpException(result.status(), errorMessage);
}

context.respond(httpException);
}

private void handleError(Throwable error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.concurrent.TimeUnit;

@Data
@NoArgsConstructor
public class RateBucket {
Expand Down Expand Up @@ -52,6 +54,22 @@ public long update(long timestamp) {
return sum;
}

/**
* Returns the number of seconds the user agent should wait before making a retry request.
*
* @param limit - requested limit
*/
long retryAfter(long limit) {
long sum = this.sum;
long replyAfter = 0;
for (long start = this.start; start < this.end && sum >= limit; start++) {
int index = index(start);
sum -= sums[index];
replyAfter += window.interval();
}
return TimeUnit.MILLISECONDS.toSeconds(replyAfter);
artsiomkorzun marked this conversation as resolved.
Show resolved Hide resolved
}

private long interval(long timestamp) {
if (timestamp < window.window()) {
throw new IllegalArgumentException("timestamp < window");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

import com.epam.aidial.core.storage.http.HttpStatus;

public record RateLimitResult(HttpStatus status, String errorMessage) {
public static final RateLimitResult SUCCESS = new RateLimitResult(HttpStatus.OK, null);
public record RateLimitResult(HttpStatus status, String errorMessage, long replyAfterSeconds) {
public static final RateLimitResult SUCCESS = new RateLimitResult(HttpStatus.OK, null, -1);
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public Future<RateLimitResult> limit(ProxyContext context, RoleBasedEntity roleB
} else {
log.warn("Limit must be positive for {}", name);
}
return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"));
return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied", -1));
}

return vertx.executeBlocking(() -> checkLimit(context, limit, roleBasedEntity), false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ public RateLimitResult check(long timestamp, Limit limit, long count) {
if (result) {
String errorMsg = String.format("Hit request rate limit. Hour limit: %d / %d requests. Day limit: %d / %d requests.",
hourTotal, limit.getRequestHour(), dayTotal, limit.getRequestDay());
return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg);
long hourRetryAfter = hour.retryAfter(limit.getRequestHour());
long dayRetryAfter = day.retryAfter(limit.getRequestDay());
long retryAfter = Math.max(hourRetryAfter, dayRetryAfter);
artsiomkorzun marked this conversation as resolved.
Show resolved Hide resolved
return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg, retryAfter);
} else {
hour.add(timestamp, count);
day.add(timestamp, count);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ public RateLimitResult update(long timestamp, Limit limit) {
if (result) {
String errorMsg = String.format("Hit token rate limit. Minute limit: %d / %d tokens. Day limit: %d / %d tokens.",
minuteTotal, limit.getMinute(), dayTotal, limit.getDay());
return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg);
long minuteRetryAfter = minute.retryAfter(limit.getMinute());
long dayRetryAfter = day.retryAfter(limit.getDay());
long retryAfter = Math.max(minuteRetryAfter, dayRetryAfter);
return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg, retryAfter);
} else {
return RateLimitResult.SUCCESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.apache.hc.client5.http.entity.mime.HttpMultipartMode;
import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.io.entity.EntityUtils;
Expand Down Expand Up @@ -122,7 +122,9 @@ void init() throws Exception {
.build();
redis.start();

client = HttpClients.createDefault();
// Create the HTTP client with automatic retries disabled, so that a response carrying
// the `retry-after` header is returned to the test instead of being retried by the client.
client = HttpClientBuilder.create().disableAutomaticRetries().build();

String overrides = """
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package com.epam.aidial.core.server.limiter;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.concurrent.ThreadLocalRandom;

import static org.junit.jupiter.api.Assertions.assertEquals;

class RateBucketTest {

private RateBucket bucket;
Expand Down Expand Up @@ -53,14 +54,62 @@ void testDayBucket() {
update(49, 0);
}

/**
 * Checks {@code retryAfter} on a {@code RateWindow.MINUTE} bucket with a limit of 30
 * while counts are added and the window moves forward.
 */
@Test
public void testRetryAfterMinute() {
bucket = new RateBucket(RateWindow.MINUTE);

// Empty bucket: the sum is below the limit, so no wait is expected.
update(0, 0);
assertEquals(0, bucket.retryAfter(30));
add(0, 10, 10);

// The sum is 10, still below the limit.
update(5, 10);
assertEquals(0, bucket.retryAfter(30));
add(5, 20, 30);

// The sum is 30 and reaches the limit: a wait of 45 seconds is expected.
// NOTE(review): presumably the time until the count added at interval 0 leaves the window;
// confirm against the interval size of RateWindow.MINUTE.
update(15, 30);
assertEquals(45, bucket.retryAfter(30));
add(15, 30, 60);

update(25, 60);
add(25, 10, 70);

// At interval 60 the sum has dropped from 70 to 60: a wait of 15 seconds is expected.
update(60, 60);
assertEquals(15, bucket.retryAfter(30));
}

/**
 * Checks {@code retryAfter} on a {@code RateWindow.DAY} bucket with a limit of 30
 * while counts are added and the window moves forward.
 */
@Test
public void testRetryAfterDay() {
bucket = new RateBucket(RateWindow.DAY);

// Empty bucket: the sum is below the limit, so no wait is expected.
update(0, 0);
assertEquals(0, bucket.retryAfter(30));
add(0, 10, 10);

// The sum is 10, still below the limit.
update(5, 10);
assertEquals(0, bucket.retryAfter(30));
add(5, 20, 30);

// The sum is 30 and reaches the limit.
update(10, 30);
// need to wait 14 hours
assertEquals(14 * 60 * 60, bucket.retryAfter(30));
add(10, 30, 60);

update(20, 60);
add(23, 10, 70);

// At interval 24 the sum has dropped from 70 to 60.
update(24, 60);
// need to wait 10 hours
assertEquals(10 * 60 * 60, bucket.retryAfter(30));
}

/**
 * Adds {@code count} to the bucket at a random point inside the given interval and
 * verifies the sum returned by the bucket.
 *
 * @param interval index of the interval, counted from the end of the first full window
 * @param count    value to add to the bucket
 * @param expected sum that {@code bucket.add} is expected to return
 */
private void add(long interval, long count, long expected) {
    RateWindow window = bucket.getWindow();
    long whole = interval * window.interval();
    // A random offset inside the interval: the result must not depend on the exact position.
    long fraction = ThreadLocalRandom.current().nextLong(0, window.interval());

    long timestamp = window.window() + whole + fraction;
    long actual = bucket.add(timestamp, count);
    assertEquals(expected, actual);
}

private void update(long interval, long expected) {
Expand All @@ -70,6 +119,6 @@ private void update(long interval, long expected) {

long timestamp = window.window() + whole + fraction;
long actual = bucket.update(timestamp);
Assertions.assertEquals(expected, actual);
assertEquals(expected, actual);
}
}
Loading