Skip to content

Commit

Permalink
feat: Implement endpoint for getting API key spent limit information #…
Browse files Browse the repository at this point in the history
…218 (#219)

Co-authored-by: Aliaksandr Stsiapanay <[email protected]>
  • Loading branch information
astsiapanay and astsiapanay authored Feb 19, 2024
1 parent 3e1a5b2 commit b1efac4
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class ControllerSelector {
// Matches exactly the collection path "/v1/invitations".
private static final Pattern INVITATIONS = Pattern.compile("^/v1/invitations$");
// Matches "/v1/invitations/<id>"; group 1 captures the id, which may only contain ASCII letters and digits.
private static final Pattern INVITATION = Pattern.compile("^/v1/invitations/([a-zA-Z0-9]+)$");

// Matches "/v1/deployments/<id>/limits"; group 1 captures one non-empty path segment (anything except '/').
private static final Pattern DEPLOYMENT_LIMITS = Pattern.compile("^/v1/deployments/([^/]+)/limits$");

public Controller select(Proxy proxy, ProxyContext context) {
String path = context.getRequest().path();
HttpMethod method = context.getRequest().method();
Expand Down Expand Up @@ -188,6 +190,13 @@ private static Controller selectGet(Proxy proxy, ProxyContext context, String pa
return controller::getInvitations;
}

match = match(DEPLOYMENT_LIMITS, path);
if (match != null) {
String deploymentId = UrlUtil.decodePath(match.group(1));
LimitController controller = new LimitController(proxy, context);
return () -> controller.getLimits(deploymentId);
}

return null;
}

Expand Down
35 changes: 35 additions & 0 deletions src/main/java/com/epam/aidial/core/controller/LimitController.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.epam.aidial.core.controller;

import com.epam.aidial.core.Proxy;
import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.util.HttpStatus;
import io.vertx.core.Future;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class LimitController {

private final Proxy proxy;

private final ProxyContext context;

public LimitController(Proxy proxy, ProxyContext context) {
this.proxy = proxy;
this.context = context;
}

public Future<?> getLimits(String deploymentName) {
proxy.getRateLimiter().getLimitStats(deploymentName, context).onSuccess(limitStats -> {
if (limitStats == null) {
context.respond(HttpStatus.NOT_FOUND);
} else {
context.respond(HttpStatus.OK, limitStats);
}
}).onFailure(error -> {
log.error("Failed to get limit stats", error);
context.respond(HttpStatus.INTERNAL_SERVER_ERROR,
"Failed to get limit stats for deployment=%s".formatted(deploymentName));
});
return Future.succeededFuture();
}
}
9 changes: 9 additions & 0 deletions src/main/java/com/epam/aidial/core/data/LimitStats.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.epam.aidial.core.data;

import lombok.Data;

/**
 * Token limit statistics for one deployment, split into a minute part and a day part.
 * Accessors, equals/hashCode and toString are generated by Lombok's {@code @Data}.
 */
@Data
public class LimitStats {
// Statistics whose total is taken from Limit.getMinute() in RateLimiter.create.
private TokenLimitStats minuteTokenStats;
// Statistics whose total is taken from Limit.getDay() in RateLimiter.create.
private TokenLimitStats dayTokenStats;
}
9 changes: 9 additions & 0 deletions src/main/java/com/epam/aidial/core/data/TokenLimitStats.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.epam.aidial.core.data;

import lombok.Data;

/**
 * A pair of token counters: the configured limit and the amount already consumed.
 * Accessors, equals/hashCode and toString are generated by Lombok's {@code @Data}.
 */
@Data
public class TokenLimitStats {
// Configured limit value; RateLimiter.create copies it from the Limit (minute or day).
private long total;
// Consumed amount; RateLimit.update(long, LimitStats) sets it. Stays 0 when never set.
private long used;
}
8 changes: 8 additions & 0 deletions src/main/java/com/epam/aidial/core/limiter/RateLimit.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.epam.aidial.core.limiter;

import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.util.HttpStatus;
import lombok.Data;

Expand Down Expand Up @@ -32,4 +33,11 @@ public RateLimitResult update(long timestamp, Limit limit) {
return RateLimitResult.SUCCESS;
}
}

/**
 * Writes the current minute and day usage into {@code limitStats}.
 *
 * <p>Both {@code minute.update(timestamp)} and {@code day.update(timestamp)} are invoked
 * before anything is written, and their results become the {@code used} values of the
 * minute and day statistics respectively.
 *
 * @param timestamp  time passed to the minute and day counters
 *                   (callers in RateLimiter pass {@code System.currentTimeMillis()})
 * @param limitStats target object; its day and minute statistics must already be set,
 *                   otherwise this method fails with a NullPointerException
 */
public void update(long timestamp, LimitStats limitStats) {
    final long usedThisMinute = minute.update(timestamp);
    final long usedThisDay = day.update(timestamp);
    limitStats.getDayTokenStats().setUsed(usedThisDay);
    limitStats.getMinuteTokenStats().setUsed(usedThisMinute);
}
}
94 changes: 70 additions & 24 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.data.ResourceType;
import com.epam.aidial.core.data.TokenLimitStats;
import com.epam.aidial.core.service.ResourceService;
import com.epam.aidial.core.storage.BlobStorageUtil;
import com.epam.aidial.core.storage.ResourceDescription;
Expand Down Expand Up @@ -43,8 +45,8 @@ public Future<Void> increase(ProxyContext context) {
return Future.succeededFuture();
}

String path = getPath(deployment.getName());
return vertx.executeBlocking(() -> updateLimit(path, context, usage.getTotalTokens()));
ResourceDescription resourceDescription = getResourceDescription(deployment.getName(), context);
return vertx.executeBlocking(() -> updateLimit(resourceDescription, usage.getTotalTokens()));
} catch (Throwable e) {
return Future.failedFuture(e);
}
Expand All @@ -57,16 +59,15 @@ public Future<RateLimitResult> limit(ProxyContext context) {
return Future.succeededFuture(RateLimitResult.SUCCESS);
}
Key key = context.getKey();
Deployment deployment = context.getDeployment();
Limit limit;
if (key == null) {
// don't support user limits yet
return Future.succeededFuture(RateLimitResult.SUCCESS);
} else {
limit = getLimitByApiKey(context);
limit = getLimitByApiKey(context, deployment.getName());
}

Deployment deployment = context.getDeployment();

if (limit == null || !limit.isPositive()) {
if (limit == null) {
log.warn("Limit is not found for deployment: {}", deployment.getName());
Expand All @@ -76,48 +77,94 @@ public Future<RateLimitResult> limit(ProxyContext context) {
return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"));
}

String path = getPath(deployment.getName());
return vertx.executeBlocking(() -> checkLimit(path, limit, context));
ResourceDescription resourceDescription = getResourceDescription(deployment.getName(), context);
return vertx.executeBlocking(() -> checkLimit(resourceDescription, limit));
} catch (Throwable e) {
return Future.failedFuture(e);
}
}

private RateLimitResult checkLimit(String path, Limit limit, ProxyContext context) throws Exception {
RateLimit rateLimit;
/**
 * Resolves the limit statistics that apply to the caller's API key for a deployment.
 *
 * <p>The returned future completes with {@code null} when no statistics can be produced:
 * the resource service is missing, the request carries no API key, or the key's role
 * defines no limit for the deployment. Otherwise the stored usage is read through
 * {@code vertx.executeBlocking}.
 *
 * @param deploymentName deployment whose limits are requested
 * @param context        request context supplying the API key, config and trace ids
 * @return future with the statistics, {@code null}, or a failure wrapping any thrown error
 */
public Future<LimitStats> getLimitStats(String deploymentName, ProxyContext context) {
    try {
        if (resourceService == null) {
            // nothing to read from: limits are skipped when redis is not available
            return Future.succeededFuture();
        }
        final Key key = context.getKey();
        if (key == null) {
            // per-user limits are not supported yet
            return Future.succeededFuture();
        }
        final Limit limit = getLimitByApiKey(context, deploymentName);
        if (limit != null) {
            final ResourceDescription usageResource = getResourceDescription(deploymentName, context);
            return vertx.executeBlocking(() -> getLimitStats(usageResource, limit));
        }
        log.warn("Limit is not found. Trace: {}. Span: {}. Key: {}. Deployment: {}", context.getTraceId(), context.getSpanId(), key.getProject(), deploymentName);
        return Future.succeededFuture();
    } catch (Throwable e) {
        // surface synchronous errors through the future instead of throwing
        return Future.failedFuture(e);
    }
}

/**
 * Builds the statistics for {@code limit} and, if usage is stored under
 * {@code resourceDescription}, fills in the consumed amounts.
 *
 * @param resourceDescription location of the stored usage record
 * @param limit               configured limit providing the totals
 * @return statistics with totals always set; used amounts stay at their defaults
 *         when no usage record could be converted
 */
private LimitStats getLimitStats(ResourceDescription resourceDescription, Limit limit) {
    final String storedJson = resourceService.getResource(resourceDescription, true);
    final LimitStats stats = create(limit);
    final RateLimit storedUsage = ProxyUtil.convertToObject(storedJson, RateLimit.class);
    if (storedUsage != null) {
        storedUsage.update(System.currentTimeMillis(), stats);
    }
    return stats;
}

/**
 * Creates statistics whose totals mirror the given limit; used amounts are left untouched.
 *
 * @param limit configured limit; its minute and day values become the totals
 * @return new statistics with both the minute and the day part populated
 */
private LimitStats create(Limit limit) {
    TokenLimitStats perMinute = new TokenLimitStats();
    perMinute.setTotal(limit.getMinute());

    TokenLimitStats perDay = new TokenLimitStats();
    perDay.setTotal(limit.getDay());

    LimitStats stats = new LimitStats();
    stats.setDayTokenStats(perDay);
    stats.setMinuteTokenStats(perMinute);
    return stats;
}

/**
 * Builds the description of the resource that stores limit usage for a deployment.
 *
 * <p>The bucket comes from {@code BlobStorageUtil.buildUserBucket(context)} and is passed
 * for both bucket arguments of {@code ResourceDescription.fromEncoded}; the path comes
 * from {@code getPath(deploymentName)}.
 *
 * @param deploymentName deployment whose usage record is addressed
 * @param context        request context used to derive the bucket
 * @return description of the {@code ResourceType.LIMIT} resource
 */
private ResourceDescription getResourceDescription(String deploymentName, ProxyContext context) {
    String path = getPath(deploymentName);
    String bucketLocation = BlobStorageUtil.buildUserBucket(context);
    // built once and returned directly; the former intermediate local was never read
    return ResourceDescription.fromEncoded(ResourceType.LIMIT, bucketLocation, bucketLocation, path);
}

private RateLimitResult checkLimit(ResourceDescription resourceDescription, Limit limit) {
String prevValue = resourceService.getResource(resourceDescription);
if (prevValue == null) {
RateLimit rateLimit = ProxyUtil.convertToObject(prevValue, RateLimit.class);
if (rateLimit == null) {
return RateLimitResult.SUCCESS;
} else {
rateLimit = ProxyUtil.MAPPER.readValue(prevValue, RateLimit.class);
}
long timestamp = System.currentTimeMillis();
return rateLimit.update(timestamp, limit);
}

private Void updateLimit(String path, ProxyContext context, long totalUsedTokens) {
String bucketLocation = BlobStorageUtil.buildUserBucket(context);
ResourceDescription resourceDescription = ResourceDescription.fromEncoded(ResourceType.LIMIT, bucketLocation, bucketLocation, path);
private Void updateLimit(ResourceDescription resourceDescription, long totalUsedTokens) {
resourceService.computeResource(resourceDescription, json -> updateLimit(json, totalUsedTokens));
return null;
}

@SneakyThrows
private String updateLimit(String json, long totalUsedTokens) {
RateLimit rateLimit;
if (json == null) {
RateLimit rateLimit = ProxyUtil.convertToObject(json, RateLimit.class);
if (rateLimit == null) {
rateLimit = new RateLimit();
} else {
rateLimit = ProxyUtil.MAPPER.readValue(json, RateLimit.class);
}
long timestamp = System.currentTimeMillis();
rateLimit.add(timestamp, totalUsedTokens);
return ProxyUtil.MAPPER.writeValueAsString(rateLimit);
return ProxyUtil.convertToString(rateLimit);
}

private Limit getLimitByApiKey(ProxyContext context) {
private Limit getLimitByApiKey(ProxyContext context, String deploymentName) {
// API key has always one role
Role role = context.getConfig().getRoles().get(context.getKey().getRole());

Expand All @@ -126,8 +173,7 @@ private Limit getLimitByApiKey(ProxyContext context) {
return null;
}

Deployment deployment = context.getDeployment();
return role.getLimits().get(deployment.getName());
return role.getLimits().get(deploymentName);
}

private static String getPath(String deploymentName) {
Expand Down
30 changes: 30 additions & 0 deletions src/test/java/com/epam/aidial/core/LimitApiTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.epam.aidial.core;

import io.vertx.core.http.HttpMethod;
import org.junit.jupiter.api.Test;

/**
 * API-level tests for {@code GET /v1/deployments/<deployment>/limits}.
 */
public class LimitApiTest extends ResourceBaseTest {

    /** A known deployment answers 200 with minute and day totals of Long.MAX_VALUE and no usage. */
    @Test
    public void testGetLimitStats_Success() {
        Response response = send(HttpMethod.GET, "/v1/deployments/test-model-v1/limits", null, null);
        // content lines and the closing delimiter share one indent, so the literal stays unchanged
        String expectedBody = """
                {
                "minuteTokenStats": {
                "total": %d,
                "used": %d
                },
                "dayTokenStats": {
                "total": %d,
                "used": %d
                }
                }
                """.formatted(Long.MAX_VALUE, 0, Long.MAX_VALUE, 0);
        verifyJson(response, 200, expectedBody);
    }

    /** A deployment without limit statistics answers 404. */
    @Test
    public void testGetLimitStats_UnknownModel() {
        Response response = send(HttpMethod.GET, "/v1/deployments/unknown-model/limits", null, null);
        verify(response, 404);
    }
}
60 changes: 60 additions & 0 deletions src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.security.ExtractedClaims;
import com.epam.aidial.core.service.LockService;
import com.epam.aidial.core.service.ResourceService;
Expand Down Expand Up @@ -280,4 +281,63 @@ public void testLimit_ApiKeySuccess_KeyExist() {

}

/**
 * Records 90 tokens of usage twice for an API key and checks after each step that the
 * reported statistics carry the configured totals (100 per minute, 10000 per day) and
 * the accumulated usage (90, then 180).
 */
@Test
public void testGetLimitStats_ApiKey() {
    // config: role "role" limits deployment "model"
    Limit modelLimit = new Limit();
    modelLimit.setDay(10000);
    modelLimit.setMinute(100);
    Role role = new Role();
    role.setLimits(Map.of("model", modelLimit));
    Config config = new Config();
    config.setRoles(Map.of("role", role));

    // caller: API key bound to that role
    Key apiKey = new Key();
    apiKey.setRole("role");
    apiKey.setKey("key");
    apiKey.setProject("api-key");
    ApiKeyData apiKeyData = new ApiKeyData();
    apiKeyData.setOriginalKey(apiKey);

    ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, null, "trace-id", "span-id");
    Model model = new Model();
    model.setName("model");
    proxyContext.setDeployment(model);

    // run blocking tasks inline so the futures are complete on return
    when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
        Callable<?> task = invocation.getArgument(0);
        return Future.succeededFuture(task.call());
    });

    TokenUsage usage = new TokenUsage();
    usage.setTotalTokens(90);
    proxyContext.setTokenUsage(usage);

    // first increase: 90 tokens used
    Future<Void> increased = rateLimiter.increase(proxyContext);
    assertNotNull(increased);
    assertNull(increased.cause());

    Future<LimitStats> statsFuture = rateLimiter.getLimitStats(model.getName(), proxyContext);

    assertNotNull(statsFuture);
    assertNotNull(statsFuture.result());
    LimitStats stats = statsFuture.result();
    assertEquals(10000, stats.getDayTokenStats().getTotal());
    assertEquals(90, stats.getDayTokenStats().getUsed());
    assertEquals(100, stats.getMinuteTokenStats().getTotal());
    assertEquals(90, stats.getMinuteTokenStats().getUsed());

    // second increase: usage accumulates to 180
    increased = rateLimiter.increase(proxyContext);
    assertNotNull(increased);
    assertNull(increased.cause());

    statsFuture = rateLimiter.getLimitStats(model.getName(), proxyContext);

    assertNotNull(statsFuture);
    assertNotNull(statsFuture.result());
    stats = statsFuture.result();
    assertEquals(10000, stats.getDayTokenStats().getTotal());
    assertEquals(180, stats.getDayTokenStats().getUsed());
    assertEquals(100, stats.getMinuteTokenStats().getTotal());
    assertEquals(180, stats.getMinuteTokenStats().getUsed());

}

}

0 comments on commit b1efac4

Please sign in to comment.