Skip to content

Commit

Permalink
feat: Implement rate limits for requests #277 (#282)
Browse files Browse the repository at this point in the history
Co-authored-by: Aliaksandr Stsiapanay <[email protected]>
  • Loading branch information
astsiapanay and astsiapanay authored Mar 14, 2024
1 parent 52ba2a7 commit b3750b5
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 61 deletions.
4 changes: 3 additions & 1 deletion src/main/java/com/epam/aidial/core/config/Limit.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
public class Limit {
private long minute = Long.MAX_VALUE;
private long day = Long.MAX_VALUE;
private long requestHour = Long.MAX_VALUE;
private long requestDay = Long.MAX_VALUE;

public boolean isPositive() {
return minute > 0 && day > 0;
return minute > 0 && day > 0 && requestDay > 0 && requestHour > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import lombok.Data;

@Data
public class TokenLimitStats {
public class ItemLimitStats {
private long total;
private long used;
}
6 changes: 4 additions & 2 deletions src/main/java/com/epam/aidial/core/data/LimitStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

@Data
public class LimitStats {
private TokenLimitStats minuteTokenStats;
private TokenLimitStats dayTokenStats;
private ItemLimitStats minuteTokenStats;
private ItemLimitStats dayTokenStats;
private ItemLimitStats hourRequestStats;
private ItemLimitStats dayRequestStats;
}
138 changes: 99 additions & 39 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.data.ItemLimitStats;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.data.TokenLimitStats;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.BlobStorageUtil;
import com.epam.aidial.core.storage.ResourceDescription;
Expand All @@ -17,7 +17,6 @@
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.util.List;
Expand All @@ -42,15 +41,15 @@ public Future<Void> increase(ProxyContext context) {
return Future.succeededFuture();
}

Deployment deployment = context.getDeployment();
TokenUsage usage = context.getTokenUsage();

if (usage == null || usage.getTotalTokens() <= 0) {
return Future.succeededFuture();
}

ResourceDescription resourceDescription = getResourceDescription(deployment.getName(), context);
return vertx.executeBlocking(() -> updateLimit(resourceDescription, usage.getTotalTokens()));
String tokensPath = getPathToTokens(context.getDeployment().getName());
ResourceDescription resourceDescription = getResourceDescription(context, tokensPath);
return vertx.executeBlocking(() -> updateTokenLimit(resourceDescription, usage.getTotalTokens()));
} catch (Throwable e) {
return Future.failedFuture(e);
}
Expand All @@ -63,25 +62,24 @@ public Future<RateLimitResult> limit(ProxyContext context) {
return Future.succeededFuture(RateLimitResult.SUCCESS);
}
Key key = context.getKey();
Deployment deployment = context.getDeployment();
String deploymentName = context.getDeployment().getName();
Limit limit;
if (key == null) {
limit = getLimitByUser(context);
limit = getLimitByUser(context, deploymentName);
} else {
limit = getLimitByApiKey(context, deployment.getName());
limit = getLimitByApiKey(context, deploymentName);
}

if (limit == null || !limit.isPositive()) {
if (limit == null) {
log.warn("Limit is not found for deployment: {}", deployment.getName());
log.warn("Limit is not found for deployment: {}", deploymentName);
} else {
log.warn("Limit must be positive for deployment: {}", deployment.getName());
log.warn("Limit must be positive for deployment: {}", deploymentName);
}
return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"));
}

ResourceDescription resourceDescription = getResourceDescription(deployment.getName(), context);
return vertx.executeBlocking(() -> checkLimit(resourceDescription, limit));
return vertx.executeBlocking(() -> checkLimit(context, limit));
} catch (Throwable e) {
return Future.failedFuture(e);
}
Expand All @@ -96,71 +94,126 @@ public Future<LimitStats> getLimitStats(String deploymentName, ProxyContext cont
Key key = context.getKey();
Limit limit;
if (key == null) {
// don't support user limits yet
return Future.succeededFuture();
limit = getLimitByUser(context, deploymentName);
} else {
limit = getLimitByApiKey(context, deploymentName);
}
if (limit == null) {
log.warn("Limit is not found. Trace: {}. Span: {}. Key: {}. Deployment: {}", context.getTraceId(), context.getSpanId(), key.getProject(), deploymentName);
log.warn("Limit is not found. Trace: {}. Span: {}. Key: {}. User sub: {}. Deployment: {}",
context.getTraceId(), context.getSpanId(), key == null ? null : key.getProject(),
context.getUserSub(), deploymentName);
return Future.succeededFuture();
}
ResourceDescription resourceDescription = getResourceDescription(deploymentName, context);
return vertx.executeBlocking(() -> getLimitStats(resourceDescription, limit));
return vertx.executeBlocking(() -> getLimitStats(context, limit, deploymentName));
} catch (Throwable e) {
return Future.failedFuture(e);
}
}

private LimitStats getLimitStats(ResourceDescription resourceDescription, Limit limit) {
String json = resourceService.getResource(resourceDescription, true);
private LimitStats getLimitStats(ProxyContext context, Limit limit, String deploymentName) {
LimitStats limitStats = create(limit);
RateLimit rateLimit = ProxyUtil.convertToObject(json, RateLimit.class);
long timestamp = System.currentTimeMillis();
collectTokenLimitStats(context, limitStats, timestamp, deploymentName);
collectRequestLimitStats(context, limitStats, timestamp, deploymentName);
return limitStats;
}

private void collectTokenLimitStats(ProxyContext context, LimitStats limitStats, long timestamp, String deploymentName) {
String tokensPath = getPathToTokens(deploymentName);
ResourceDescription resourceDescription = getResourceDescription(context, tokensPath);
String json = resourceService.getResource(resourceDescription, true);
TokenRateLimit rateLimit = ProxyUtil.convertToObject(json, TokenRateLimit.class);
if (rateLimit == null) {
return limitStats;
return;
}
rateLimit.update(timestamp, limitStats);
}

private void collectRequestLimitStats(ProxyContext context, LimitStats limitStats, long timestamp, String deploymentName) {
String requestsPath = getPathToRequests(deploymentName);
ResourceDescription resourceDescription = getResourceDescription(context, requestsPath);
String json = resourceService.getResource(resourceDescription, true);
RequestRateLimit rateLimit = ProxyUtil.convertToObject(json, RequestRateLimit.class);
if (rateLimit == null) {
return;
}
long timestamp = System.currentTimeMillis();
rateLimit.update(timestamp, limitStats);
return limitStats;
}

private LimitStats create(Limit limit) {
LimitStats limitStats = new LimitStats();
TokenLimitStats dayTokenStats = new TokenLimitStats();

ItemLimitStats dayTokenStats = new ItemLimitStats();
dayTokenStats.setTotal(limit.getDay());
limitStats.setDayTokenStats(dayTokenStats);
TokenLimitStats minuteTokenStats = new TokenLimitStats();

ItemLimitStats minuteTokenStats = new ItemLimitStats();
minuteTokenStats.setTotal(limit.getMinute());
limitStats.setMinuteTokenStats(minuteTokenStats);

ItemLimitStats hourRequestStats = new ItemLimitStats();
hourRequestStats.setTotal(limit.getRequestHour());
limitStats.setHourRequestStats(hourRequestStats);

ItemLimitStats dayRequestStats = new ItemLimitStats();
dayRequestStats.setTotal(limit.getRequestDay());
limitStats.setDayRequestStats(dayRequestStats);

return limitStats;
}

private ResourceDescription getResourceDescription(String deploymentName, ProxyContext context) {
String path = getPath(deploymentName);
private ResourceDescription getResourceDescription(ProxyContext context, String path) {
String bucketLocation = BlobStorageUtil.buildUserBucket(context);
return ResourceDescription.fromEncoded(ResourceType.LIMIT, bucketLocation, bucketLocation, path);
}

private RateLimitResult checkLimit(ResourceDescription resourceDescription, Limit limit) {
private RateLimitResult checkLimit(ProxyContext context, Limit limit) {
long timestamp = System.currentTimeMillis();
RateLimitResult tokenResult = checkTokenLimit(context, limit, timestamp);
if (tokenResult.status() != HttpStatus.OK) {
return tokenResult;
}
return checkRequestLimit(context, limit, timestamp);
}

private RateLimitResult checkTokenLimit(ProxyContext context, Limit limit, long timestamp) {
String tokensPath = getPathToTokens(context.getDeployment().getName());
ResourceDescription resourceDescription = getResourceDescription(context, tokensPath);
String prevValue = resourceService.getResource(resourceDescription);
RateLimit rateLimit = ProxyUtil.convertToObject(prevValue, RateLimit.class);
TokenRateLimit rateLimit = ProxyUtil.convertToObject(prevValue, TokenRateLimit.class);
if (rateLimit == null) {
return RateLimitResult.SUCCESS;
}
long timestamp = System.currentTimeMillis();
return rateLimit.update(timestamp, limit);
}

private Void updateLimit(ResourceDescription resourceDescription, long totalUsedTokens) {
resourceService.computeResource(resourceDescription, json -> updateLimit(json, totalUsedTokens));
/**
 * Checks the request-count limits of the current deployment and, when the request is
 * allowed, records it in the stored counters.
 *
 * <p>The stored state lives under the "requests" path of the deployment and is modified
 * through {@code resourceService.computeResource}, so the check and the counter update
 * happen inside the same compute callback.
 *
 * @param context   proxy context; supplies the deployment name and the bucket of the caller
 * @param limit     limits to enforce (the request-per-hour and request-per-day values are used)
 * @param timestamp time of the check, as passed by the caller (epoch millis in {@code checkLimit})
 * @return the result produced by {@link RequestRateLimit#check(long, Limit, long)};
 *         NOTE(review): stays {@code null} if {@code computeResource} never invokes the
 *         callback - confirm that it always does
 */
private RateLimitResult checkRequestLimit(ProxyContext context, Limit limit, long timestamp) {
    String requestsPath = getPathToRequests(context.getDeployment().getName());
    ResourceDescription resourceDescription = getResourceDescription(context, requestsPath);
    // The compute callback must return the new JSON state, so the rate limit verdict is
    // handed back through a one-element array that the lambda can write into.
    RateLimitResult[] result = new RateLimitResult[1];
    resourceService.computeResource(resourceDescription, json -> updateRequestLimit(json, timestamp, limit, result));
    return result[0];
}

/**
 * Compute callback for the stored request counters: checks one request against the limit
 * and returns the counters serialized back to a string.
 *
 * @param json      previously stored state; when it converts to {@code null} a fresh
 *                  {@link RequestRateLimit} is used instead
 * @param timestamp time of the check
 * @param limit     limits to enforce
 * @param result    one-element holder; slot 0 receives the outcome of the check
 * @return the serialized counters to store
 */
private String updateRequestLimit(String json, long timestamp, Limit limit, RateLimitResult[] result) {
    RequestRateLimit stored = ProxyUtil.convertToObject(json, RequestRateLimit.class);
    RequestRateLimit requestCounters = stored != null ? stored : new RequestRateLimit();
    // Every call accounts for exactly one request.
    result[0] = requestCounters.check(timestamp, limit, 1);
    return ProxyUtil.convertToString(requestCounters);
}

/**
 * Adds the consumed tokens to the stored token counters of the given resource.
 *
 * <p>Returns {@link Void} rather than {@code void} so the method can be used as the
 * value-returning task passed to {@code vertx.executeBlocking}.
 *
 * @param resourceDescription location of the stored token counters
 * @param totalUsedTokens     number of tokens to add
 * @return always {@code null}
 */
private Void updateTokenLimit(ResourceDescription resourceDescription, long totalUsedTokens) {
    resourceService.computeResource(
            resourceDescription,
            storedJson -> updateTokenLimit(storedJson, totalUsedTokens));
    return null;
}

@SneakyThrows
private String updateLimit(String json, long totalUsedTokens) {
RateLimit rateLimit = ProxyUtil.convertToObject(json, RateLimit.class);
private String updateTokenLimit(String json, long totalUsedTokens) {
TokenRateLimit rateLimit = ProxyUtil.convertToObject(json, TokenRateLimit.class);
if (rateLimit == null) {
rateLimit = new RateLimit();
rateLimit = new TokenRateLimit();
}
long timestamp = System.currentTimeMillis();
rateLimit.add(timestamp, totalUsedTokens);
Expand All @@ -179,9 +232,8 @@ private Limit getLimitByApiKey(ProxyContext context, String deploymentName) {
return role.getLimits().get(deploymentName);
}

private Limit getLimitByUser(ProxyContext context) {
private Limit getLimitByUser(ProxyContext context, String deploymentName) {
List<String> userRoles = context.getUserRoles();
String deploymentName = context.getDeployment().getName();
Map<String, Role> roles = context.getConfig().getRoles();
Limit defaultUserLimit = getLimit(roles, DEFAULT_USER_ROLE, deploymentName, DEFAULT_LIMIT);
if (userRoles.isEmpty()) {
Expand All @@ -194,20 +246,28 @@ private Limit getLimitByUser(ProxyContext context) {
if (limit == null) {
limit = new Limit();
limit.setMinute(candidate.getMinute());
limit.setRequestHour(candidate.getRequestHour());
limit.setRequestDay(candidate.getRequestDay());
limit.setDay(candidate.getDay());
} else {
limit.setMinute(Math.max(candidate.getMinute(), limit.getMinute()));
limit.setDay(Math.max(candidate.getDay(), limit.getDay()));
limit.setRequestDay(Math.max(candidate.getRequestDay(), limit.getRequestDay()));
limit.setRequestHour(Math.max(candidate.getRequestHour(), limit.getRequestHour()));
}
}
}
return limit == null ? defaultUserLimit : limit;
}

private static String getPath(String deploymentName) {
private static String getPathToTokens(String deploymentName) {
return String.format("%s/tokens", deploymentName);
}

/**
 * Builds the relative path under which the request counters of a deployment are stored.
 *
 * @param deploymentName name of the deployment
 * @return {@code "<deploymentName>/requests"}
 */
private static String getPathToRequests(String deploymentName) {
    return deploymentName + "/requests";
}

private static Limit getLimit(Map<String, Role> roles, String userRole, String deploymentName, Limit defaultLimit) {
return Optional.ofNullable(roles.get(userRole))
.map(role -> role.getLimits().get(deploymentName))
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/epam/aidial/core/limiter/RateWindow.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
@Getter
@Accessors(fluent = true)
public enum RateWindow {
MINUTE(60L * 1000, 60), DAY(24L * 60 * 60 * 1000, 24);
MINUTE(60L * 1000, 60),
HOUR(60 * 60 * 1000, 60),
DAY(24L * 60 * 60 * 1000, 24);

private final long window;
private final long interval;
Expand Down
39 changes: 39 additions & 0 deletions src/main/java/com/epam/aidial/core/limiter/RequestRateLimit.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.epam.aidial.core.limiter;

import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.util.HttpStatus;
import lombok.Data;

/**
 * Request counters of a single deployment, kept in two {@link RateBucket}s: one for the
 * {@link RateWindow#HOUR} window and one for the {@link RateWindow#DAY} window.
 */
@Data
public class RequestRateLimit {
    private final RateBucket hour = new RateBucket(RateWindow.HOUR);
    private final RateBucket day = new RateBucket(RateWindow.DAY);

    /**
     * Checks the current totals against the request limits and, when neither limit is hit,
     * adds {@code count} to both buckets.
     *
     * <p>The comparison uses the totals as they are before {@code count} is added, so the
     * size of {@code count} does not influence the verdict.
     *
     * @param timestamp time of the check, passed to both buckets
     * @param limit     source of the request-per-hour and request-per-day limits
     * @param count     number of requests to record when the check passes
     * @return {@link RateLimitResult#SUCCESS} when the request is allowed, otherwise a
     *         {@link HttpStatus#TOO_MANY_REQUESTS} result describing both totals and limits
     */
    public RateLimitResult check(long timestamp, Limit limit, long count) {
        // NOTE(review): update() presumably refreshes the bucket to `timestamp` and
        // returns its current total - confirm against RateBucket.
        long hourTotal = hour.update(timestamp);
        long dayTotal = day.update(timestamp);

        long hourLimit = limit.getRequestHour();
        long dayLimit = limit.getRequestDay();

        if (hourTotal < hourLimit && dayTotal < dayLimit) {
            hour.add(timestamp, count);
            day.add(timestamp, count);
            return RateLimitResult.SUCCESS;
        }

        String errorMsg = String.format("""
                Hit request rate limit:
                - hour limit: %d / %d requests
                - day limit: %d / %d requests
                """,
                hourTotal, hourLimit, dayTotal, dayLimit);
        return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg);
    }

    /**
     * Copies the current hour and day totals into the "used" fields of the request
     * statistics held by {@code limitStats}.
     *
     * @param timestamp  time at which the totals are read
     * @param limitStats statistics object to fill; its hour and day request stats must
     *                   already be set
     */
    public void update(long timestamp, LimitStats limitStats) {
        long usedThisHour = hour.update(timestamp);
        long usedToday = day.update(timestamp);
        limitStats.getDayRequestStats().setUsed(usedToday);
        limitStats.getHourRequestStats().setUsed(usedThisHour);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import lombok.Data;

@Data
public class RateLimit {
public class TokenRateLimit {

private final RateBucket minute = new RateBucket(RateWindow.MINUTE);
private final RateBucket day = new RateBucket(RateWindow.DAY);
Expand Down
10 changes: 9 additions & 1 deletion src/test/java/com/epam/aidial/core/LimitApiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ public void testGetLimitStats_Success() {
"dayTokenStats": {
"total": %d,
"used": %d
},
"hourRequestStats": {
"total": %d,
"used": %d
},
"dayRequestStats": {
"total": %d,
"used": %d
}
}
""".formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0));
""".formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0));
}

@Test
Expand Down
Loading

0 comments on commit b3750b5

Please sign in to comment.