diff --git a/README.md b/README.md
index 49d042177..ec9c6c81d 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ Static settings are used on startup and cannot be changed while application is r
 |--------------------------------------------|--------------------|-------------------------------------------------------------------------------------------------------------------
 | config.files | aidial.config.json | Config files with parts of the whole config.
 | config.reload | 60000 | Config reload interval in milliseconds.
-| identityProviders | - | List of identity providers. **Note**. At least one identity provider must be provided.
+| identityProviders | - | Map of identity providers. **Note**. At least one identity provider must be provided.
 | identityProviders.*.jwksUrl | - | Url to jwks provider. **Required** if `disabledVerifyJwt` is set to `false`
 | identityProviders.*.rolePath | - | Path to the claim user roles in JWT token, e.g. `resource_access.chatbot-ui.roles` or just `roles`. **Required**.
 | identityProviders.*.loggingKey | - | User information to search in claims of JWT token.
@@ -164,22 +164,23 @@ Dynamic settings include:
 * Access Permissions
 * Rate Limits
 
-| Parameter | Description |
-| ------------------------------- | ------------ |
-| routes | Path(s) for specific upstream routing or to respond with a configured body. |
-| applications | A list of deployed AI DIAL Applications and their parameters:<br>`<application_name>`: Unique application name. |
-| applications.<application_name> | `endpoint`: AI DIAL Application API for chat completions.<br>`iconUrl`: Icon path for the AI DIAL Application on UI.<br>`description`: Brief AI DIAL Application description.<br>`displayName`: AI DIAL Application name on UI.<br>`inputAttachmentTypes`: A list of allowed MIME types for the input attachments.<br>`maxInputAttachments`: Maximum number of input attachments (default is zero when `inputAttachmentTypes` is unset, otherwise, infinity) |
-| models | A list of deployed models and their parameters:<br>`<model_name>`: Unique model name. |
-| models.<model_name> | `type`: Model type—`chat` or `embedding`.<br>`iconUrl`: Icon path for the model on UI.<br>`description`: Brief model description.<br>`displayName`: Model name on UI.<br>`displayVersion`: Model version on UI.<br>`endpoint`: Model API for chat completions or embeddings.<br>`tokenizerModel`: Identifies the specific model whose tokenization algorithm exactly matches that of the referenced model. This is typically the name of the earliest-released model in a series of models sharing an identical tokenization algorithm (e.g. `gpt-3.5-turbo-0301`, `gpt-4-0314`, or `gpt-4-1106-vision-preview`). This parameter is essential for DIAL clients that reimplement tokenization algorithms on their side, instead of utilizing the `tokenizeEndpoint` provided by the model.<br>`features`: Model features.<br>`limits`: Model token limits.<br>`pricing`: Model pricing.<br>`upstreams`: Used for load-balancing—request is sent to model endpoint containing X-UPSTREAM-ENDPOINT and X-UPSTREAM-KEY headers. |
-| models.<model_name>.limits | `maxPromptTokens`: maximum number of tokens in a completion request.<br>`maxCompletionTokens`: maximum number of tokens in a completion response.<br>`maxTotalTokens`: maximum number of tokens in completion request and response combined.<br>Typically either `maxTotalTokens` is specified or `maxPromptTokens` and `maxCompletionTokens`. |
-| models.<model_name>.pricing | `unit`: the pricing units (currently `token` and `char_without_whitespace` are supported).<br>`prompt`: per-unit price for the completion request in USD.<br>`completion`: per-unit price for the completion response in USD. |
-| models.<model_name>.features | `rateEndpoint`: endpoint for rate requests *(exposed by core as `/rate`)*.<br>`tokenizeEndpoint`: endpoint for requests to the model tokenizer *(exposed by core as `/tokenize`)*.<br>`truncatePromptEndpoint`: endpoint for truncating prompt requests *(exposed by core as `/truncate_prompt`)*.<br>`systemPromptSupported`: does the model support system prompt (default is `true`).<br>`toolsSupported`: does the model support tools (default is `false`).<br>`seedSupported`: does the model support `seed` request parameter (default is `false`).<br>`urlAttachmentsSupported`: does the model/application support attachments with URLs (default is `false`) |
-| models.<model_name>.upstreams | `endpoint`: Model endpoint.<br>`key`: Your API key. |
-| keys | API Keys parameters:<br>`<core_key>`: Your API key. |
-| keys.<core_key> | `project`: Project name assigned to this key.<br>`role`: A configured role name that defines key permissions. |
-| roles | API key roles `<role_name>` with associated limits. Each API key has one role defined in the list of roles. Roles are associated with models, applications, assistants, and defined limits. |
-| roles.<role_name> | `limits`: Limits for models, applications, or assistants. |
-| roles.<role_name>.limits | `minute`: Total tokens per minute limit sent to the model, managed via floating window approach for well-distributed rate limiting.<br>`day`: Total tokens per day limit sent to the model, managed via floating window approach for balanced rate limiting. |
+| Parameter | Description |
+|---------------------------------------| ------------ |
+| routes | Path(s) for specific upstream routing or to respond with a configured body. |
+| applications | A list of deployed AI DIAL Applications and their parameters:<br>`<application_name>`: Unique application name. |
+| applications.<application_name> | `endpoint`: AI DIAL Application API for chat completions.<br>`iconUrl`: Icon path for the AI DIAL Application on UI.<br>`description`: Brief AI DIAL Application description.<br>`displayName`: AI DIAL Application name on UI.<br>`inputAttachmentTypes`: A list of allowed MIME types for the input attachments.<br>`maxInputAttachments`: Maximum number of input attachments (default is zero when `inputAttachmentTypes` is unset, otherwise, infinity) |
+| models | A list of deployed models and their parameters:<br>`<model_name>`: Unique model name. |
+| models.<model_name> | `type`: Model type—`chat` or `embedding`.<br>`iconUrl`: Icon path for the model on UI.<br>`description`: Brief model description.<br>`displayName`: Model name on UI.<br>`displayVersion`: Model version on UI.<br>`endpoint`: Model API for chat completions or embeddings.<br>`tokenizerModel`: Identifies the specific model whose tokenization algorithm exactly matches that of the referenced model. This is typically the name of the earliest-released model in a series of models sharing an identical tokenization algorithm (e.g. `gpt-3.5-turbo-0301`, `gpt-4-0314`, or `gpt-4-1106-vision-preview`). This parameter is essential for DIAL clients that reimplement tokenization algorithms on their side, instead of utilizing the `tokenizeEndpoint` provided by the model.<br>`features`: Model features.<br>`limits`: Model token limits.<br>`pricing`: Model pricing.<br>`upstreams`: Used for load-balancing—request is sent to model endpoint containing X-UPSTREAM-ENDPOINT and X-UPSTREAM-KEY headers. |
+| models.<model_name>.limits | `maxPromptTokens`: maximum number of tokens in a completion request.<br>`maxCompletionTokens`: maximum number of tokens in a completion response.<br>`maxTotalTokens`: maximum number of tokens in completion request and response combined.<br>Typically either `maxTotalTokens` is specified or `maxPromptTokens` and `maxCompletionTokens`. |
+| models.<model_name>.pricing | `unit`: the pricing units (currently `token` and `char_without_whitespace` are supported).<br>`prompt`: per-unit price for the completion request in USD.<br>`completion`: per-unit price for the completion response in USD. |
+| models.<model_name>.features | `rateEndpoint`: endpoint for rate requests *(exposed by core as `/rate`)*.<br>`tokenizeEndpoint`: endpoint for requests to the model tokenizer *(exposed by core as `/tokenize`)*.<br>`truncatePromptEndpoint`: endpoint for truncating prompt requests *(exposed by core as `/truncate_prompt`)*.<br>`systemPromptSupported`: does the model support system prompt (default is `true`).<br>`toolsSupported`: does the model support tools (default is `false`).<br>`seedSupported`: does the model support `seed` request parameter (default is `false`).<br>`urlAttachmentsSupported`: does the model/application support attachments with URLs (default is `false`) |
+| models.<model_name>.upstreams | `endpoint`: Model endpoint.<br>`key`: Your API key. |
+| models.<model_name>.defaultUserLimit | Default user limit for the given model.<br>`minute`: Total tokens per minute limit sent to the model, managed via floating window approach for well-distributed rate limiting.<br>`day`: Total tokens per day limit sent to the model, managed via floating window approach for balanced rate limiting. |
+| keys | API Keys parameters:<br>`<core_key>`: Your API key. |
+| keys.<core_key> | `project`: Project name assigned to this key.<br>`role`: A configured role name that defines key permissions. |
+| roles | API key roles `<role_name>` with associated limits. Each API key has one role defined in the list of roles. Roles are associated with models, applications, assistants, and defined limits. |
+| roles.<role_name> | `limits`: Limits for models, applications, or assistants. |
+| roles.<role_name>.limits | `minute`: Total tokens per minute limit sent to the model, managed via floating window approach for well-distributed rate limiting.<br>`day`: Total tokens per day limit sent to the model, managed via floating window approach for balanced rate limiting. |
 
 ## License
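Read together, the new rows wire up per-user rate limiting: a model may declare a `defaultUserLimit`, and `roles.<role_name>.limits` supplies per-role overrides, resolved against the roles carried in the user's JWT (see `identityProviders.*.rolePath` above). A minimal sketch of just these settings, with values mirroring `sample/aidial.config.json` below:

```json
{
  "models": {
    "chat-gpt-35-turbo": {
      "type": "chat",
      "endpoint": "http://localhost:7001",
      "defaultUserLimit": {
        "minute": "100000",
        "day": "10000000"
      }
    }
  },
  "roles": {
    "role1": {
      "limits": {
        "chat-gpt-35-turbo": {
          "minute": "200000",
          "day": "10000000"
        }
      }
    }
  }
}
```

A user whose JWT carries `role1` is capped at 200000 tokens/minute for `chat-gpt-35-turbo`; a user with no matching role falls back to the model's `defaultUserLimit` of 100000 tokens/minute.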
diff --git a/sample/aidial.config.json b/sample/aidial.config.json
index a31e72a79..0bd2c3ac1 100644
--- a/sample/aidial.config.json
+++ b/sample/aidial.config.json
@@ -46,7 +46,12 @@
           "endpoint": "http://localhost:7003",
           "key": "modelKey3"
         }
-      ]
+      ],
+      "userRoles": ["role1", "role2"],
+      "defaultUserLimit": {
+        "minute": "100000",
+        "day": "10000000"
+      }
     },
     "embedding-ada": {
@@ -57,7 +62,8 @@
           "endpoint": "http://localhost:7004",
           "key": "modelKey4"
         }
-      ]
+      ],
+      "userRoles": ["role3"]
     }
   },
   "keys": {
@@ -86,6 +92,22 @@
         "search_assistant": {},
         "app": {}
       }
+    },
+    "role1": {
+      "limits": {
+        "chat-gpt-35-turbo": {
+          "minute": "200000",
+          "day": "10000000"
+        }
+      }
+    },
+    "role2": {
+      "limits": {
+        "chat-gpt-35-turbo": {
+          "minute": "100000",
+          "day": "20000000"
+        }
+      }
     }
   }
 }
\ No newline at end of file
diff --git a/src/main/java/com/epam/aidial/core/config/Model.java b/src/main/java/com/epam/aidial/core/config/Model.java
index a95fa70f2..ccb30e279 100644
--- a/src/main/java/com/epam/aidial/core/config/Model.java
+++ b/src/main/java/com/epam/aidial/core/config/Model.java
@@ -17,4 +17,5 @@ public class Model extends Deployment {
     private List<Upstream> upstreams = List.of();
     // if it's set then the model name is overridden with that name in the request body to the model adapter
     private String overrideName;
+    private Limit defaultUserLimit;
 }
\ No newline at end of file
diff --git a/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java b/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
index 74d1d773e..e32f3fa6b 100644
--- a/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
+++ b/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
@@ -4,6 +4,7 @@
 import com.epam.aidial.core.config.Deployment;
 import com.epam.aidial.core.config.Key;
 import com.epam.aidial.core.config.Limit;
+import com.epam.aidial.core.config.Model;
 import com.epam.aidial.core.config.Role;
 import com.epam.aidial.core.data.LimitStats;
 import com.epam.aidial.core.data.ResourceType;
@@ -20,10 +21,16 @@
 import lombok.SneakyThrows;
 import lombok.extern.slf4j.Slf4j;
 
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
 @Slf4j
 @RequiredArgsConstructor
 public class RateLimiter {
 
+    private static final Limit DEFAULT_LIMIT = new Limit();
+
     private final Vertx vertx;
 
     private final ResourceService resourceService;
@@ -34,10 +41,7 @@ public Future<Void> increase(ProxyContext context) {
         if (resourceService == null) {
             return Future.succeededFuture();
         }
-        Key key = context.getKey();
-        if (key == null) {
-            return Future.succeededFuture();
-        }
+
         Deployment deployment = context.getDeployment();
         TokenUsage usage = context.getTokenUsage();
 
@@ -62,8 +66,7 @@ public Future<RateLimitResult> limit(ProxyContext context) {
         Deployment deployment = context.getDeployment();
         Limit limit;
         if (key == null) {
-            // don't support user limits yet
-            return Future.succeededFuture(RateLimitResult.SUCCESS);
+            limit = getLimitByUser(context);
         } else {
             limit = getLimitByApiKey(context, deployment.getName());
         }
@@ -176,8 +179,38 @@ private Limit getLimitByApiKey(ProxyContext context, String deploymentName) {
         return role.getLimits().get(deploymentName);
     }
 
+    private Limit getLimitByUser(ProxyContext context) {
+        List<String> userRoles = context.getUserRoles();
+        Limit defaultUserLimit = getDefaultUserLimit(context.getDeployment());
+        if (userRoles.isEmpty()) {
+            return defaultUserLimit;
+        }
+        String deploymentName = context.getDeployment().getName();
+        Map<String, Role> userRoleToDeploymentLimits = context.getConfig().getRoles();
+        long minuteLimit = 0;
+        long dayLimit = 0;
+        for (String userRole : userRoles) {
+            Limit limit = Optional.ofNullable(userRoleToDeploymentLimits.get(userRole))
+                    .map(role -> role.getLimits().get(deploymentName))
+                    .orElse(defaultUserLimit);
+            minuteLimit = Math.max(minuteLimit, limit.getMinute());
+            dayLimit = Math.max(dayLimit, limit.getDay());
+        }
+        Limit limit = new Limit();
+        limit.setMinute(minuteLimit);
+        limit.setDay(dayLimit);
+        return limit;
+    }
+
     private static String getPath(String deploymentName) {
         return String.format("%s/tokens", deploymentName);
     }
 
+    private static Limit getDefaultUserLimit(Deployment deployment) {
+        if (deployment instanceof Model model) {
+            return model.getDefaultUserLimit() == null ? DEFAULT_LIMIT : model.getDefaultUserLimit();
+        }
+        return DEFAULT_LIMIT;
+    }
+
 }
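The merge rule in `getLimitByUser` is worth spelling out: for each of the user's roles the limiter looks up that role's limit for the deployment, falls back to the model's `defaultUserLimit` (or an empty `Limit`) when none is configured, and keeps the per-window maximum. A self-contained sketch of the rule (illustrative code, not DIAL's: `UserLimitMergeSketch` and its `Limit` record are simplified stand-ins, and limits are keyed by role name directly instead of via `Role`):

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class UserLimitMergeSketch {

    // Simplified stand-in for com.epam.aidial.core.config.Limit.
    record Limit(long minute, long day) {}

    // Mirrors RateLimiter.getLimitByUser: per window, the effective limit is the
    // maximum across the user's roles; a role with no limit for the deployment
    // falls back to the default. (The real code returns defaultUserLimit directly
    // when the user has no roles at all.)
    static Limit merge(List<String> userRoles, Map<String, Limit> limitsByRole, Limit defaultUserLimit) {
        long minute = 0;
        long day = 0;
        for (String role : userRoles) {
            Limit limit = Optional.ofNullable(limitsByRole.get(role)).orElse(defaultUserLimit);
            minute = Math.max(minute, limit.minute());
            day = Math.max(day, limit.day());
        }
        return new Limit(minute, day);
    }

    public static void main(String[] args) {
        // Values from sample/aidial.config.json: role1 wins on minute, role2 on day.
        Map<String, Limit> limits = Map.of(
                "role1", new Limit(200_000, 10_000_000),
                "role2", new Limit(100_000, 20_000_000));
        System.out.println(merge(List.of("role1", "role2"), limits, new Limit(0, 0)));
        // -> Limit[minute=200000, day=20000000]
    }
}
```

One consequence worth noting: because `minute` and `day` are maximized independently, the merged limit can be more permissive than any single role's pair, as in the example above, where neither role alone allows both 200000/minute and 20000000/day.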
diff --git a/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java b/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
index 469d0a4b8..c3ef7773a 100644
--- a/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
+++ b/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
@@ -15,7 +15,6 @@
 import com.epam.aidial.core.storage.BlobStorage;
 import com.epam.aidial.core.token.TokenUsage;
 import com.epam.aidial.core.util.HttpStatus;
-import com.epam.aidial.core.util.ProxyUtil;
 import io.vertx.core.Future;
 import io.vertx.core.Vertx;
 import io.vertx.core.http.HttpServerRequest;
@@ -34,7 +33,7 @@
 import redis.embedded.RedisServer;
 
 import java.io.IOException;
-import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Callable;
 
@@ -128,17 +127,6 @@ public void testLimit_EntityNotFound() {
         assertEquals(HttpStatus.FORBIDDEN, result.result().status());
     }
 
-    @Test
-    public void testLimit_SuccessUser() {
-        ProxyContext proxyContext = new ProxyContext(new Config(), request, new ApiKeyData(), new ExtractedClaims("sub", Collections.emptyList(), "hash"), "trace-id", "span-id");
-
-        Future<RateLimitResult> result = rateLimiter.limit(proxyContext);
-
-        assertNotNull(result);
-        assertNotNull(result.result());
-        assertEquals(HttpStatus.OK, result.result().status());
-    }
-
     @Test
     public void testLimit_ApiKeyLimitNotFound() {
         Key key = new Key();
@@ -340,4 +328,101 @@ public void testGetLimitStats_ApiKey() {
 
     }
 
+    @Test
+    public void testLimit_User_LimitFound() {
+        Config config = new Config();
+
+        Role role1 = new Role();
+        Limit limit = new Limit();
+        limit.setDay(10000);
+        limit.setMinute(100);
+        role1.setLimits(Map.of("model", limit));
+
+        Role role2 = new Role();
+        limit = new Limit();
+        limit.setDay(20000);
+        limit.setMinute(200);
+        role2.setLimits(Map.of("model", limit));
+
+        config.getRoles().put("role1", role1);
+        config.getRoles().put("role2", role2);
+
+        ApiKeyData apiKeyData = new ApiKeyData();
+        ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash"), "trace-id", "span-id");
+        Model model = new Model();
+        model.setName("model");
+        proxyContext.setDeployment(model);
+
+        when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
+            Callable<?> callable = invocation.getArgument(0);
+            return Future.succeededFuture(callable.call());
+        });
+
+        TokenUsage tokenUsage = new TokenUsage();
+        tokenUsage.setTotalTokens(150);
+        proxyContext.setTokenUsage(tokenUsage);
+
+        Future<Void> increaseLimitFuture = rateLimiter.increase(proxyContext);
+        assertNotNull(increaseLimitFuture);
+        assertNull(increaseLimitFuture.cause());
+
+        Future<RateLimitResult> checkLimitFuture = rateLimiter.limit(proxyContext);
+
+        assertNotNull(checkLimitFuture);
+        assertNotNull(checkLimitFuture.result());
+        assertEquals(HttpStatus.OK, checkLimitFuture.result().status());
+
+        increaseLimitFuture = rateLimiter.increase(proxyContext);
+        assertNotNull(increaseLimitFuture);
+        assertNull(increaseLimitFuture.cause());
+
+        checkLimitFuture = rateLimiter.limit(proxyContext);
+
+        assertNotNull(checkLimitFuture);
+        assertNotNull(checkLimitFuture.result());
+        assertEquals(HttpStatus.TOO_MANY_REQUESTS, checkLimitFuture.result().status());
+
+    }
+
+    @Test
+    public void testLimit_User_LimitNotFound() {
+        Config config = new Config();
+
+        ApiKeyData apiKeyData = new ApiKeyData();
+        ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1"), "user-hash"), "trace-id", "span-id");
+        Model model = new Model();
+        model.setName("model");
+        proxyContext.setDeployment(model);
+
+        when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
+            Callable<?> callable = invocation.getArgument(0);
+            return Future.succeededFuture(callable.call());
+        });
+
+        TokenUsage tokenUsage = new TokenUsage();
+        tokenUsage.setTotalTokens(90);
+        proxyContext.setTokenUsage(tokenUsage);
+
+        Future<Void> increaseLimitFuture = rateLimiter.increase(proxyContext);
+        assertNotNull(increaseLimitFuture);
+        assertNull(increaseLimitFuture.cause());
+
+        Future<RateLimitResult> checkLimitFuture = rateLimiter.limit(proxyContext);
+
+        assertNotNull(checkLimitFuture);
+        assertNotNull(checkLimitFuture.result());
+        assertEquals(HttpStatus.OK, checkLimitFuture.result().status());
+
+        increaseLimitFuture = rateLimiter.increase(proxyContext);
+        assertNotNull(increaseLimitFuture);
+        assertNull(increaseLimitFuture.cause());
+
+        checkLimitFuture = rateLimiter.limit(proxyContext);
+
+        assertNotNull(checkLimitFuture);
+        assertNotNull(checkLimitFuture.result());
+        assertEquals(HttpStatus.OK, checkLimitFuture.result().status());
+
+    }
+
 }
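The two tests above exercise role-derived limits; the `defaultUserLimit` fallback for a user with no roles is not covered here. A hedged sketch of what such a test could look like (the test name is mine, and it assumes the new `Model.defaultUserLimit` field gets a Lombok-generated `setDefaultUserLimit`, like the neighboring config fields):

```java
// Sketch only, not part of the PR: a user with no JWT roles should be throttled
// by the model's defaultUserLimit rather than by any role limit.
@Test
public void testLimit_User_DefaultUserLimit() {
    ApiKeyData apiKeyData = new ApiKeyData();
    ProxyContext proxyContext = new ProxyContext(new Config(), request, apiKeyData,
            new ExtractedClaims("sub", List.of(), "user-hash"), "trace-id", "span-id");

    Model model = new Model();
    model.setName("model");
    Limit defaultUserLimit = new Limit();
    defaultUserLimit.setMinute(100);
    defaultUserLimit.setDay(10000);
    model.setDefaultUserLimit(defaultUserLimit); // assumed Lombok setter for the new field
    proxyContext.setDeployment(model);

    when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
        Callable<?> callable = invocation.getArgument(0);
        return Future.succeededFuture(callable.call());
    });

    TokenUsage tokenUsage = new TokenUsage();
    tokenUsage.setTotalTokens(150);
    proxyContext.setTokenUsage(tokenUsage);

    // 150 tokens are recorded, which exceeds the 100 tokens/minute default limit,
    // so the next check should be rejected.
    assertNull(rateLimiter.increase(proxyContext).cause());
    Future<RateLimitResult> checkLimitFuture = rateLimiter.limit(proxyContext);
    assertNotNull(checkLimitFuture.result());
    assertEquals(HttpStatus.TOO_MANY_REQUESTS, checkLimitFuture.result().status());
}
```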