Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement rate limits for requests #277 #282

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/main/java/com/epam/aidial/core/config/Limit.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
public class Limit {
private long minute = Long.MAX_VALUE;
private long day = Long.MAX_VALUE;
private long requestHour = Long.MAX_VALUE;
private long requestDay = Long.MAX_VALUE;

public boolean isPositive() {
return minute > 0 && day > 0;
return minute > 0 && day > 0 && requestDay > 0 && requestHour > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import lombok.Data;

@Data
public class TokenLimitStats {
public class ItemLimitStats {
private long total;
private long used;
}
6 changes: 4 additions & 2 deletions src/main/java/com/epam/aidial/core/data/LimitStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

@Data
public class LimitStats {
private TokenLimitStats minuteTokenStats;
private TokenLimitStats dayTokenStats;
private ItemLimitStats minuteTokenStats;
private ItemLimitStats dayTokenStats;
private ItemLimitStats hourRequestStats;
private ItemLimitStats dayRequestStats;
}
138 changes: 99 additions & 39 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.data.ItemLimitStats;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.data.TokenLimitStats;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.BlobStorageUtil;
import com.epam.aidial.core.storage.ResourceDescription;
Expand All @@ -17,7 +17,6 @@
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.util.List;
Expand All @@ -42,15 +41,15 @@ public Future<Void> increase(ProxyContext context) {
return Future.succeededFuture();
}

Deployment deployment = context.getDeployment();
TokenUsage usage = context.getTokenUsage();

if (usage == null || usage.getTotalTokens() <= 0) {
return Future.succeededFuture();
}

ResourceDescription resourceDescription = getResourceDescription(deployment.getName(), context);
return vertx.executeBlocking(() -> updateLimit(resourceDescription, usage.getTotalTokens()));
String tokensPath = getPathToTokens(context.getDeployment().getName());
ResourceDescription resourceDescription = getResourceDescription(context, tokensPath);
return vertx.executeBlocking(() -> updateTokenLimit(resourceDescription, usage.getTotalTokens()));
} catch (Throwable e) {
return Future.failedFuture(e);
}
Expand All @@ -63,25 +62,24 @@ public Future<RateLimitResult> limit(ProxyContext context) {
return Future.succeededFuture(RateLimitResult.SUCCESS);
}
Key key = context.getKey();
Deployment deployment = context.getDeployment();
String deploymentName = context.getDeployment().getName();
Limit limit;
if (key == null) {
limit = getLimitByUser(context);
limit = getLimitByUser(context, deploymentName);
} else {
limit = getLimitByApiKey(context, deployment.getName());
limit = getLimitByApiKey(context, deploymentName);
}

if (limit == null || !limit.isPositive()) {
if (limit == null) {
log.warn("Limit is not found for deployment: {}", deployment.getName());
log.warn("Limit is not found for deployment: {}", deploymentName);
} else {
log.warn("Limit must be positive for deployment: {}", deployment.getName());
log.warn("Limit must be positive for deployment: {}", deploymentName);
}
return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"));
}

ResourceDescription resourceDescription = getResourceDescription(deployment.getName(), context);
return vertx.executeBlocking(() -> checkLimit(resourceDescription, limit));
return vertx.executeBlocking(() -> checkLimit(context, limit));
} catch (Throwable e) {
return Future.failedFuture(e);
}
Expand All @@ -96,71 +94,126 @@ public Future<LimitStats> getLimitStats(String deploymentName, ProxyContext cont
Key key = context.getKey();
Limit limit;
if (key == null) {
// don't support user limits yet
return Future.succeededFuture();
limit = getLimitByUser(context, deploymentName);
} else {
limit = getLimitByApiKey(context, deploymentName);
}
if (limit == null) {
log.warn("Limit is not found. Trace: {}. Span: {}. Key: {}. Deployment: {}", context.getTraceId(), context.getSpanId(), key.getProject(), deploymentName);
log.warn("Limit is not found. Trace: {}. Span: {}. Key: {}. User sub: {}. Deployment: {}",
context.getTraceId(), context.getSpanId(), key == null ? null : key.getProject(),
context.getUserSub(), deploymentName);
return Future.succeededFuture();
}
ResourceDescription resourceDescription = getResourceDescription(deploymentName, context);
return vertx.executeBlocking(() -> getLimitStats(resourceDescription, limit));
return vertx.executeBlocking(() -> getLimitStats(context, limit, deploymentName));
} catch (Throwable e) {
return Future.failedFuture(e);
}
}

private LimitStats getLimitStats(ResourceDescription resourceDescription, Limit limit) {
String json = resourceService.getResource(resourceDescription, true);
private LimitStats getLimitStats(ProxyContext context, Limit limit, String deploymentName) {
LimitStats limitStats = create(limit);
RateLimit rateLimit = ProxyUtil.convertToObject(json, RateLimit.class);
long timestamp = System.currentTimeMillis();
collectTokenLimitStats(context, limitStats, timestamp, deploymentName);
collectRequestLimitStats(context, limitStats, timestamp, deploymentName);
return limitStats;
}

private void collectTokenLimitStats(ProxyContext context, LimitStats limitStats, long timestamp, String deploymentName) {
String tokensPath = getPathToTokens(deploymentName);
ResourceDescription resourceDescription = getResourceDescription(context, tokensPath);
String json = resourceService.getResource(resourceDescription, true);
TokenRateLimit rateLimit = ProxyUtil.convertToObject(json, TokenRateLimit.class);
if (rateLimit == null) {
return limitStats;
return;
}
rateLimit.update(timestamp, limitStats);
}

private void collectRequestLimitStats(ProxyContext context, LimitStats limitStats, long timestamp, String deploymentName) {
String requestsPath = getPathToRequests(deploymentName);
ResourceDescription resourceDescription = getResourceDescription(context, requestsPath);
String json = resourceService.getResource(resourceDescription, true);
RequestRateLimit rateLimit = ProxyUtil.convertToObject(json, RequestRateLimit.class);
if (rateLimit == null) {
return;
}
long timestamp = System.currentTimeMillis();
rateLimit.update(timestamp, limitStats);
return limitStats;
}

private LimitStats create(Limit limit) {
LimitStats limitStats = new LimitStats();
TokenLimitStats dayTokenStats = new TokenLimitStats();

ItemLimitStats dayTokenStats = new ItemLimitStats();
dayTokenStats.setTotal(limit.getDay());
limitStats.setDayTokenStats(dayTokenStats);
TokenLimitStats minuteTokenStats = new TokenLimitStats();

ItemLimitStats minuteTokenStats = new ItemLimitStats();
minuteTokenStats.setTotal(limit.getMinute());
limitStats.setMinuteTokenStats(minuteTokenStats);

ItemLimitStats hourRequestStats = new ItemLimitStats();
hourRequestStats.setTotal(limit.getRequestHour());
limitStats.setHourRequestStats(hourRequestStats);

ItemLimitStats dayRequestStats = new ItemLimitStats();
dayRequestStats.setTotal(limit.getRequestDay());
limitStats.setDayRequestStats(dayRequestStats);

return limitStats;
}

private ResourceDescription getResourceDescription(String deploymentName, ProxyContext context) {
String path = getPath(deploymentName);
private ResourceDescription getResourceDescription(ProxyContext context, String path) {
String bucketLocation = BlobStorageUtil.buildUserBucket(context);
return ResourceDescription.fromEncoded(ResourceType.LIMIT, bucketLocation, bucketLocation, path);
}

private RateLimitResult checkLimit(ResourceDescription resourceDescription, Limit limit) {
private RateLimitResult checkLimit(ProxyContext context, Limit limit) {
long timestamp = System.currentTimeMillis();
RateLimitResult tokenResult = checkTokenLimit(context, limit, timestamp);
if (tokenResult.status() != HttpStatus.OK) {
return tokenResult;
}
return checkRequestLimit(context, limit, timestamp);
}

private RateLimitResult checkTokenLimit(ProxyContext context, Limit limit, long timestamp) {
String tokensPath = getPathToTokens(context.getDeployment().getName());
ResourceDescription resourceDescription = getResourceDescription(context, tokensPath);
Maxim-Gadalov marked this conversation as resolved.
Show resolved Hide resolved
String prevValue = resourceService.getResource(resourceDescription);
RateLimit rateLimit = ProxyUtil.convertToObject(prevValue, RateLimit.class);
TokenRateLimit rateLimit = ProxyUtil.convertToObject(prevValue, TokenRateLimit.class);
if (rateLimit == null) {
return RateLimitResult.SUCCESS;
}
long timestamp = System.currentTimeMillis();
return rateLimit.update(timestamp, limit);
}

private Void updateLimit(ResourceDescription resourceDescription, long totalUsedTokens) {
resourceService.computeResource(resourceDescription, json -> updateLimit(json, totalUsedTokens));
private RateLimitResult checkRequestLimit(ProxyContext context, Limit limit, long timestamp) {
String tokensPath = getPathToRequests(context.getDeployment().getName());
ResourceDescription resourceDescription = getResourceDescription(context, tokensPath);
// pass array to hold rate limit result returned by the function to compute the resource
RateLimitResult[] result = new RateLimitResult[1];
Maxim-Gadalov marked this conversation as resolved.
Show resolved Hide resolved
resourceService.computeResource(resourceDescription, json -> updateRequestLimit(json, timestamp, limit, result));
return result[0];
}

private String updateRequestLimit(String json, long timestamp, Limit limit, RateLimitResult[] result) {
RequestRateLimit rateLimit = ProxyUtil.convertToObject(json, RequestRateLimit.class);
if (rateLimit == null) {
rateLimit = new RequestRateLimit();
}
result[0] = rateLimit.check(timestamp, limit, 1);
return ProxyUtil.convertToString(rateLimit);
}

private Void updateTokenLimit(ResourceDescription resourceDescription, long totalUsedTokens) {
resourceService.computeResource(resourceDescription, json -> updateTokenLimit(json, totalUsedTokens));
return null;
}

@SneakyThrows
private String updateLimit(String json, long totalUsedTokens) {
RateLimit rateLimit = ProxyUtil.convertToObject(json, RateLimit.class);
private String updateTokenLimit(String json, long totalUsedTokens) {
TokenRateLimit rateLimit = ProxyUtil.convertToObject(json, TokenRateLimit.class);
if (rateLimit == null) {
rateLimit = new RateLimit();
rateLimit = new TokenRateLimit();
}
long timestamp = System.currentTimeMillis();
rateLimit.add(timestamp, totalUsedTokens);
Expand All @@ -179,9 +232,8 @@ private Limit getLimitByApiKey(ProxyContext context, String deploymentName) {
return role.getLimits().get(deploymentName);
}

private Limit getLimitByUser(ProxyContext context) {
private Limit getLimitByUser(ProxyContext context, String deploymentName) {
List<String> userRoles = context.getUserRoles();
String deploymentName = context.getDeployment().getName();
Map<String, Role> roles = context.getConfig().getRoles();
Limit defaultUserLimit = getLimit(roles, DEFAULT_USER_ROLE, deploymentName, DEFAULT_LIMIT);
if (userRoles.isEmpty()) {
Expand All @@ -194,20 +246,28 @@ private Limit getLimitByUser(ProxyContext context) {
if (limit == null) {
limit = new Limit();
limit.setMinute(candidate.getMinute());
limit.setRequestHour(candidate.getRequestHour());
limit.setRequestDay(candidate.getRequestDay());
limit.setDay(candidate.getDay());
} else {
limit.setMinute(Math.max(candidate.getMinute(), limit.getMinute()));
limit.setDay(Math.max(candidate.getDay(), limit.getDay()));
limit.setRequestDay(Math.max(candidate.getRequestDay(), limit.getRequestDay()));
limit.setRequestHour(Math.max(candidate.getRequestHour(), limit.getRequestHour()));
}
}
}
return limit == null ? defaultUserLimit : limit;
}

private static String getPath(String deploymentName) {
private static String getPathToTokens(String deploymentName) {
return String.format("%s/tokens", deploymentName);
}

private static String getPathToRequests(String deploymentName) {
return String.format("%s/requests", deploymentName);
}

private static Limit getLimit(Map<String, Role> roles, String userRole, String deploymentName, Limit defaultLimit) {
return Optional.ofNullable(roles.get(userRole))
.map(role -> role.getLimits().get(deploymentName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
@Getter
@Accessors(fluent = true)
public enum RateWindow {
MINUTE(60L * 1000, 60), DAY(24L * 60 * 60 * 1000, 24);
MINUTE(60L * 1000, 60),
HOUR(60 * 60 * 1000, 60),
DAY(24L * 60 * 60 * 1000, 24);

private final long window;
private final long interval;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.epam.aidial.core.limiter;

import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.util.HttpStatus;
import lombok.Data;

@Data
public class RequestRateLimit {
private final RateBucket hour = new RateBucket(RateWindow.HOUR);
private final RateBucket day = new RateBucket(RateWindow.DAY);

public RateLimitResult check(long timestamp, Limit limit, long count) {
long hourTotal = hour.update(timestamp);
long dayTotal = day.update(timestamp);

boolean result = hourTotal >= limit.getRequestHour() || dayTotal >= limit.getRequestDay();
if (result) {
String errorMsg = String.format("""
Hit request rate limit:
- hour limit: %d / %d requests
- day limit: %d / %d requests
""",
hourTotal, limit.getRequestHour(), dayTotal, limit.getRequestDay());
return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg);
} else {
hour.add(timestamp, count);
day.add(timestamp, count);
return RateLimitResult.SUCCESS;
}
}

public void update(long timestamp, LimitStats limitStats) {
long hourTotal = hour.update(timestamp);
long dayTotal = day.update(timestamp);
limitStats.getDayRequestStats().setUsed(dayTotal);
limitStats.getHourRequestStats().setUsed(hourTotal);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import lombok.Data;

@Data
public class RateLimit {
public class TokenRateLimit {

private final RateBucket minute = new RateBucket(RateWindow.MINUTE);
private final RateBucket day = new RateBucket(RateWindow.DAY);
Expand Down
10 changes: 9 additions & 1 deletion src/test/java/com/epam/aidial/core/LimitApiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ public void testGetLimitStats_Success() {
"dayTokenStats": {
"total": %d,
"used": %d
},
"hourRequestStats": {
"total": %d,
"used": %d
},
"dayRequestStats": {
"total": %d,
"used": %d
}
}
""".formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0));
""".formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, Long.MAX_VALUE, 0));
}

@Test
Expand Down
Loading
Loading