diff --git a/src/main/java/com/epam/aidial/core/config/Limit.java b/src/main/java/com/epam/aidial/core/config/Limit.java index bc25d3cd1..d930d128a 100644 --- a/src/main/java/com/epam/aidial/core/config/Limit.java +++ b/src/main/java/com/epam/aidial/core/config/Limit.java @@ -6,8 +6,10 @@ public class Limit { private long minute = Long.MAX_VALUE; private long day = Long.MAX_VALUE; + private long requestHour = Long.MAX_VALUE; + private long requestDay = Long.MAX_VALUE; public boolean isPositive() { - return minute > 0 && day > 0; + return minute > 0 && day > 0 && requestDay > 0 && requestHour > 0; } } \ No newline at end of file diff --git a/src/main/java/com/epam/aidial/core/data/TokenLimitStats.java b/src/main/java/com/epam/aidial/core/data/ItemLimitStats.java similarity index 78% rename from src/main/java/com/epam/aidial/core/data/TokenLimitStats.java rename to src/main/java/com/epam/aidial/core/data/ItemLimitStats.java index 83b0bff10..8a1b5c1ce 100644 --- a/src/main/java/com/epam/aidial/core/data/TokenLimitStats.java +++ b/src/main/java/com/epam/aidial/core/data/ItemLimitStats.java @@ -3,7 +3,7 @@ import lombok.Data; @Data -public class TokenLimitStats { +public class ItemLimitStats { private long total; private long used; } diff --git a/src/main/java/com/epam/aidial/core/data/LimitStats.java b/src/main/java/com/epam/aidial/core/data/LimitStats.java index 256b670c6..7b6aeab8e 100644 --- a/src/main/java/com/epam/aidial/core/data/LimitStats.java +++ b/src/main/java/com/epam/aidial/core/data/LimitStats.java @@ -4,6 +4,8 @@ @Data public class LimitStats { - private TokenLimitStats minuteTokenStats; - private TokenLimitStats dayTokenStats; + private ItemLimitStats minuteTokenStats; + private ItemLimitStats dayTokenStats; + private ItemLimitStats hourRequestStats; + private ItemLimitStats dayRequestStats; } diff --git a/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java b/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java index 901b47e88..3ad97b8e7 
100644 --- a/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java +++ b/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java @@ -5,9 +5,9 @@ import com.epam.aidial.core.config.Key; import com.epam.aidial.core.config.Limit; import com.epam.aidial.core.config.Role; +import com.epam.aidial.core.data.ItemLimitStats; import com.epam.aidial.core.data.LimitStats; import com.epam.aidial.core.data.ResourceType; -import com.epam.aidial.core.data.TokenLimitStats; import com.epam.aidial.core.service.ResourceService; import com.epam.aidial.core.storage.BlobStorageUtil; import com.epam.aidial.core.storage.ResourceDescription; @@ -17,7 +17,6 @@ import io.vertx.core.Future; import io.vertx.core.Vertx; import lombok.RequiredArgsConstructor; -import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import java.util.List; @@ -42,15 +41,15 @@ public Future increase(ProxyContext context) { return Future.succeededFuture(); } - Deployment deployment = context.getDeployment(); TokenUsage usage = context.getTokenUsage(); if (usage == null || usage.getTotalTokens() <= 0) { return Future.succeededFuture(); } - ResourceDescription resourceDescription = getResourceDescription(deployment.getName(), context); - return vertx.executeBlocking(() -> updateLimit(resourceDescription, usage.getTotalTokens())); + String tokensPath = getPathToTokens(context.getDeployment().getName()); + ResourceDescription resourceDescription = getResourceDescription(context, tokensPath); + return vertx.executeBlocking(() -> updateTokenLimit(resourceDescription, usage.getTotalTokens())); } catch (Throwable e) { return Future.failedFuture(e); } @@ -63,25 +62,24 @@ public Future limit(ProxyContext context) { return Future.succeededFuture(RateLimitResult.SUCCESS); } Key key = context.getKey(); - Deployment deployment = context.getDeployment(); + String deploymentName = context.getDeployment().getName(); Limit limit; if (key == null) { - limit = getLimitByUser(context); + limit = getLimitByUser(context, 
deploymentName); } else { - limit = getLimitByApiKey(context, deployment.getName()); + limit = getLimitByApiKey(context, deploymentName); } if (limit == null || !limit.isPositive()) { if (limit == null) { - log.warn("Limit is not found for deployment: {}", deployment.getName()); + log.warn("Limit is not found for deployment: {}", deploymentName); } else { - log.warn("Limit must be positive for deployment: {}", deployment.getName()); + log.warn("Limit must be positive for deployment: {}", deploymentName); } return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied")); } - ResourceDescription resourceDescription = getResourceDescription(deployment.getName(), context); - return vertx.executeBlocking(() -> checkLimit(resourceDescription, limit)); + return vertx.executeBlocking(() -> checkLimit(context, limit)); } catch (Throwable e) { return Future.failedFuture(e); } @@ -96,71 +94,126 @@ public Future getLimitStats(String deploymentName, ProxyContext cont Key key = context.getKey(); Limit limit; if (key == null) { - // don't support user limits yet - return Future.succeededFuture(); + limit = getLimitByUser(context, deploymentName); } else { limit = getLimitByApiKey(context, deploymentName); } if (limit == null) { - log.warn("Limit is not found. Trace: {}. Span: {}. Key: {}. Deployment: {}", context.getTraceId(), context.getSpanId(), key.getProject(), deploymentName); + log.warn("Limit is not found. Trace: {}. Span: {}. Key: {}. User sub: {}. Deployment: {}", + context.getTraceId(), context.getSpanId(), key == null ? 
null : key.getProject(), + context.getUserSub(), deploymentName); return Future.succeededFuture(); } - ResourceDescription resourceDescription = getResourceDescription(deploymentName, context); - return vertx.executeBlocking(() -> getLimitStats(resourceDescription, limit)); + return vertx.executeBlocking(() -> getLimitStats(context, limit, deploymentName)); } catch (Throwable e) { return Future.failedFuture(e); } } - private LimitStats getLimitStats(ResourceDescription resourceDescription, Limit limit) { - String json = resourceService.getResource(resourceDescription, true); + private LimitStats getLimitStats(ProxyContext context, Limit limit, String deploymentName) { LimitStats limitStats = create(limit); - RateLimit rateLimit = ProxyUtil.convertToObject(json, RateLimit.class); + long timestamp = System.currentTimeMillis(); + collectTokenLimitStats(context, limitStats, timestamp, deploymentName); + collectRequestLimitStats(context, limitStats, timestamp, deploymentName); + return limitStats; + } + + private void collectTokenLimitStats(ProxyContext context, LimitStats limitStats, long timestamp, String deploymentName) { + String tokensPath = getPathToTokens(deploymentName); + ResourceDescription resourceDescription = getResourceDescription(context, tokensPath); + String json = resourceService.getResource(resourceDescription, true); + TokenRateLimit rateLimit = ProxyUtil.convertToObject(json, TokenRateLimit.class); if (rateLimit == null) { - return limitStats; + return; + } + rateLimit.update(timestamp, limitStats); + } + + private void collectRequestLimitStats(ProxyContext context, LimitStats limitStats, long timestamp, String deploymentName) { + String requestsPath = getPathToRequests(deploymentName); + ResourceDescription resourceDescription = getResourceDescription(context, requestsPath); + String json = resourceService.getResource(resourceDescription, true); + RequestRateLimit rateLimit = ProxyUtil.convertToObject(json, RequestRateLimit.class); + if (rateLimit 
== null) { + return; } - long timestamp = System.currentTimeMillis(); rateLimit.update(timestamp, limitStats); - return limitStats; } private LimitStats create(Limit limit) { LimitStats limitStats = new LimitStats(); - TokenLimitStats dayTokenStats = new TokenLimitStats(); + + ItemLimitStats dayTokenStats = new ItemLimitStats(); dayTokenStats.setTotal(limit.getDay()); limitStats.setDayTokenStats(dayTokenStats); - TokenLimitStats minuteTokenStats = new TokenLimitStats(); + + ItemLimitStats minuteTokenStats = new ItemLimitStats(); minuteTokenStats.setTotal(limit.getMinute()); limitStats.setMinuteTokenStats(minuteTokenStats); + + ItemLimitStats hourRequestStats = new ItemLimitStats(); + hourRequestStats.setTotal(limit.getRequestHour()); + limitStats.setHourRequestStats(hourRequestStats); + + ItemLimitStats dayRequestStats = new ItemLimitStats(); + dayRequestStats.setTotal(limit.getRequestDay()); + limitStats.setDayRequestStats(dayRequestStats); + return limitStats; } - private ResourceDescription getResourceDescription(String deploymentName, ProxyContext context) { - String path = getPath(deploymentName); + private ResourceDescription getResourceDescription(ProxyContext context, String path) { String bucketLocation = BlobStorageUtil.buildUserBucket(context); return ResourceDescription.fromEncoded(ResourceType.LIMIT, bucketLocation, bucketLocation, path); } - private RateLimitResult checkLimit(ResourceDescription resourceDescription, Limit limit) { + private RateLimitResult checkLimit(ProxyContext context, Limit limit) { + long timestamp = System.currentTimeMillis(); + RateLimitResult tokenResult = checkTokenLimit(context, limit, timestamp); + if (tokenResult.status() != HttpStatus.OK) { + return tokenResult; + } + return checkRequestLimit(context, limit, timestamp); + } + + private RateLimitResult checkTokenLimit(ProxyContext context, Limit limit, long timestamp) { + String tokensPath = getPathToTokens(context.getDeployment().getName()); + ResourceDescription 
resourceDescription = getResourceDescription(context, tokensPath); String prevValue = resourceService.getResource(resourceDescription); - RateLimit rateLimit = ProxyUtil.convertToObject(prevValue, RateLimit.class); + TokenRateLimit rateLimit = ProxyUtil.convertToObject(prevValue, TokenRateLimit.class); if (rateLimit == null) { return RateLimitResult.SUCCESS; } - long timestamp = System.currentTimeMillis(); return rateLimit.update(timestamp, limit); } - private Void updateLimit(ResourceDescription resourceDescription, long totalUsedTokens) { - resourceService.computeResource(resourceDescription, json -> updateLimit(json, totalUsedTokens)); + private RateLimitResult checkRequestLimit(ProxyContext context, Limit limit, long timestamp) { + String requestsPath = getPathToRequests(context.getDeployment().getName()); + ResourceDescription resourceDescription = getResourceDescription(context, requestsPath); + // pass array to hold rate limit result returned by the function to compute the resource + RateLimitResult[] result = new RateLimitResult[1]; + resourceService.computeResource(resourceDescription, json -> updateRequestLimit(json, timestamp, limit, result)); + return result[0]; + } + + private String updateRequestLimit(String json, long timestamp, Limit limit, RateLimitResult[] result) { + RequestRateLimit rateLimit = ProxyUtil.convertToObject(json, RequestRateLimit.class); + if (rateLimit == null) { + rateLimit = new RequestRateLimit(); + } + result[0] = rateLimit.check(timestamp, limit, 1); + return ProxyUtil.convertToString(rateLimit); + } + + private Void updateTokenLimit(ResourceDescription resourceDescription, long totalUsedTokens) { + resourceService.computeResource(resourceDescription, json -> updateTokenLimit(json, totalUsedTokens)); return null; } - @SneakyThrows - private String updateLimit(String json, long totalUsedTokens) { - RateLimit rateLimit = ProxyUtil.convertToObject(json, RateLimit.class); + private String updateTokenLimit(String json, long 
totalUsedTokens) { + TokenRateLimit rateLimit = ProxyUtil.convertToObject(json, TokenRateLimit.class); if (rateLimit == null) { - rateLimit = new RateLimit(); + rateLimit = new TokenRateLimit(); } long timestamp = System.currentTimeMillis(); rateLimit.add(timestamp, totalUsedTokens); @@ -179,9 +232,8 @@ private Limit getLimitByApiKey(ProxyContext context, String deploymentName) { return role.getLimits().get(deploymentName); } - private Limit getLimitByUser(ProxyContext context) { + private Limit getLimitByUser(ProxyContext context, String deploymentName) { List userRoles = context.getUserRoles(); - String deploymentName = context.getDeployment().getName(); Map roles = context.getConfig().getRoles(); Limit defaultUserLimit = getLimit(roles, DEFAULT_USER_ROLE, deploymentName, DEFAULT_LIMIT); if (userRoles.isEmpty()) { @@ -194,20 +246,28 @@ private Limit getLimitByUser(ProxyContext context) { if (limit == null) { limit = new Limit(); limit.setMinute(candidate.getMinute()); + limit.setRequestHour(candidate.getRequestHour()); + limit.setRequestDay(candidate.getRequestDay()); limit.setDay(candidate.getDay()); } else { limit.setMinute(Math.max(candidate.getMinute(), limit.getMinute())); limit.setDay(Math.max(candidate.getDay(), limit.getDay())); + limit.setRequestDay(Math.max(candidate.getRequestDay(), limit.getRequestDay())); + limit.setRequestHour(Math.max(candidate.getRequestHour(), limit.getRequestHour())); } } } return limit == null ? 
defaultUserLimit : limit; } - private static String getPath(String deploymentName) { + private static String getPathToTokens(String deploymentName) { return String.format("%s/tokens", deploymentName); } + private static String getPathToRequests(String deploymentName) { + return String.format("%s/requests", deploymentName); + } + private static Limit getLimit(Map roles, String userRole, String deploymentName, Limit defaultLimit) { return Optional.ofNullable(roles.get(userRole)) .map(role -> role.getLimits().get(deploymentName)) diff --git a/src/main/java/com/epam/aidial/core/limiter/RateWindow.java b/src/main/java/com/epam/aidial/core/limiter/RateWindow.java index 9520e1083..b8148019f 100644 --- a/src/main/java/com/epam/aidial/core/limiter/RateWindow.java +++ b/src/main/java/com/epam/aidial/core/limiter/RateWindow.java @@ -6,7 +6,9 @@ @Getter @Accessors(fluent = true) public enum RateWindow { - MINUTE(60L * 1000, 60), DAY(24L * 60 * 60 * 1000, 24); + MINUTE(60L * 1000, 60), + HOUR(60L * 60 * 1000, 60), + DAY(24L * 60 * 60 * 1000, 24); private final long window; private final long interval; diff --git a/src/main/java/com/epam/aidial/core/limiter/RequestRateLimit.java b/src/main/java/com/epam/aidial/core/limiter/RequestRateLimit.java new file mode 100644 index 000000000..84ae2fc1b --- /dev/null +++ b/src/main/java/com/epam/aidial/core/limiter/RequestRateLimit.java @@ -0,0 +1,39 @@ +package com.epam.aidial.core.limiter; + +import com.epam.aidial.core.config.Limit; +import com.epam.aidial.core.data.LimitStats; +import com.epam.aidial.core.util.HttpStatus; +import lombok.Data; + +@Data +public class RequestRateLimit { + private final RateBucket hour = new RateBucket(RateWindow.HOUR); + private final RateBucket day = new RateBucket(RateWindow.DAY); + + public RateLimitResult check(long timestamp, Limit limit, long count) { + long hourTotal = hour.update(timestamp); + long dayTotal = day.update(timestamp); + + boolean result = hourTotal >= limit.getRequestHour() || dayTotal 
>= limit.getRequestDay(); + if (result) { + String errorMsg = String.format(""" + Hit request rate limit: + - hour limit: %d / %d requests + - day limit: %d / %d requests + """, + hourTotal, limit.getRequestHour(), dayTotal, limit.getRequestDay()); + return new RateLimitResult(HttpStatus.TOO_MANY_REQUESTS, errorMsg); + } else { + hour.add(timestamp, count); + day.add(timestamp, count); + return RateLimitResult.SUCCESS; + } + } + + public void update(long timestamp, LimitStats limitStats) { + long hourTotal = hour.update(timestamp); + long dayTotal = day.update(timestamp); + limitStats.getDayRequestStats().setUsed(dayTotal); + limitStats.getHourRequestStats().setUsed(hourTotal); + } +} diff --git a/src/main/java/com/epam/aidial/core/limiter/RateLimit.java b/src/main/java/com/epam/aidial/core/limiter/TokenRateLimit.java similarity index 98% rename from src/main/java/com/epam/aidial/core/limiter/RateLimit.java rename to src/main/java/com/epam/aidial/core/limiter/TokenRateLimit.java index 6690f590b..f4d2760b3 100644 --- a/src/main/java/com/epam/aidial/core/limiter/RateLimit.java +++ b/src/main/java/com/epam/aidial/core/limiter/TokenRateLimit.java @@ -6,7 +6,7 @@ import lombok.Data; @Data -public class RateLimit { +public class TokenRateLimit { private final RateBucket minute = new RateBucket(RateWindow.MINUTE); private final RateBucket day = new RateBucket(RateWindow.DAY); diff --git a/src/test/java/com/epam/aidial/core/LimitApiTest.java b/src/test/java/com/epam/aidial/core/LimitApiTest.java index 603790c77..5ecf50652 100644 --- a/src/test/java/com/epam/aidial/core/LimitApiTest.java +++ b/src/test/java/com/epam/aidial/core/LimitApiTest.java @@ -17,9 +17,17 @@ public void testGetLimitStats_Success() { "dayTokenStats": { "total": %d, "used": %d + }, + "hourRequestStats": { + "total": %d, + "used": %d + }, + "dayRequestStats": { + "total": %d, + "used": %d } } - """.formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0)); + """.formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0, 
Long.MAX_VALUE, 0, Long.MAX_VALUE, 0)); } @Test diff --git a/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java b/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java index a39fcfd32..498fbd083 100644 --- a/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java +++ b/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java @@ -113,20 +113,6 @@ public void beforeEach() { rateLimiter = new RateLimiter(vertx, resourceService); } - @Test - public void testLimit_EntityNotFound() { - ApiKeyData apiKeyData = new ApiKeyData(); - apiKeyData.setOriginalKey(new Key()); - ProxyContext proxyContext = new ProxyContext(new Config(), request, apiKeyData, null, "unknown-trace-id", "span-id"); - proxyContext.setDeployment(new Application()); - - Future result = rateLimiter.limit(proxyContext); - - assertNotNull(result); - assertNotNull(result.result()); - assertEquals(HttpStatus.FORBIDDEN, result.result().status()); - } - @Test public void testLimit_ApiKeyLimitNotFound() { Key key = new Key(); @@ -280,6 +266,8 @@ public void testGetLimitStats_ApiKey() { Limit limit = new Limit(); limit.setDay(10000); limit.setMinute(100); + limit.setRequestDay(10); + limit.setRequestHour(2); role.setLimits(Map.of("model", limit)); config.setRoles(Map.of("role", role)); ApiKeyData apiKeyData = new ApiKeyData(); @@ -298,6 +286,11 @@ public void testGetLimitStats_ApiKey() { tokenUsage.setTotalTokens(90); proxyContext.setTokenUsage(tokenUsage); + Future resultFuture = rateLimiter.limit(proxyContext); + assertNotNull(resultFuture); + assertNotNull(resultFuture.result()); + assertEquals(HttpStatus.OK, resultFuture.result().status()); + Future increaseLimitFuture = rateLimiter.increase(proxyContext); assertNotNull(increaseLimitFuture); assertNull(increaseLimitFuture.cause()); @@ -311,6 +304,10 @@ public void testGetLimitStats_ApiKey() { assertEquals(90, limitStats.getDayTokenStats().getUsed()); assertEquals(100, limitStats.getMinuteTokenStats().getTotal()); 
assertEquals(90, limitStats.getMinuteTokenStats().getUsed()); + assertEquals(10, limitStats.getDayRequestStats().getTotal()); + assertEquals(1, limitStats.getDayRequestStats().getUsed()); + assertEquals(2, limitStats.getHourRequestStats().getTotal()); + assertEquals(1, limitStats.getHourRequestStats().getUsed()); increaseLimitFuture = rateLimiter.increase(proxyContext); assertNotNull(increaseLimitFuture); @@ -385,7 +382,7 @@ public void testLimit_User_LimitFound() { } @Test - public void testLimit_User_LimitNotFound() { + public void testLimit_User_DefaultLimit() { Config config = new Config(); ApiKeyData apiKeyData = new ApiKeyData(); @@ -422,6 +419,61 @@ public void testLimit_User_LimitNotFound() { assertNotNull(checkLimitFuture); assertNotNull(checkLimitFuture.result()); assertEquals(HttpStatus.OK, checkLimitFuture.result().status()); + } + + @Test + public void testLimit_User_RequestLimit() { + Config config = new Config(); + + Role role1 = new Role(); + Limit limit = new Limit(); + limit.setRequestDay(10); + limit.setRequestHour(1); + role1.setLimits(Map.of("model", limit)); + + Role role2 = new Role(); + limit = new Limit(); + limit.setRequestDay(20); + limit.setRequestHour(1); + role2.setLimits(Map.of("model", limit)); + + config.getRoles().put("role1", role1); + config.getRoles().put("role2", role2); + + ApiKeyData apiKeyData = new ApiKeyData(); + ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash"), "trace-id", "span-id"); + Model model = new Model(); + model.setName("model"); + proxyContext.setDeployment(model); + + when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> { + Callable callable = invocation.getArgument(0); + return Future.succeededFuture(callable.call()); + }); + + TokenUsage tokenUsage = new TokenUsage(); + tokenUsage.setTotalTokens(150); + proxyContext.setTokenUsage(tokenUsage); + + Future increaseLimitFuture = 
rateLimiter.increase(proxyContext); + assertNotNull(increaseLimitFuture); + assertNull(increaseLimitFuture.cause()); + + Future checkLimitFuture = rateLimiter.limit(proxyContext); + + assertNotNull(checkLimitFuture); + assertNotNull(checkLimitFuture.result()); + assertEquals(HttpStatus.OK, checkLimitFuture.result().status()); + + increaseLimitFuture = rateLimiter.increase(proxyContext); + assertNotNull(increaseLimitFuture); + assertNull(increaseLimitFuture.cause()); + + checkLimitFuture = rateLimiter.limit(proxyContext); + + assertNotNull(checkLimitFuture); + assertNotNull(checkLimitFuture.result()); + assertEquals(HttpStatus.TOO_MANY_REQUESTS, checkLimitFuture.result().status()); }