diff --git a/README.md b/README.md
index 49d042177..ec9c6c81d 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ Static settings are used on startup and cannot be changed while application is r
|--------------------------------------------|--------------------|-------------------------------------------------------------------------------------------------------------------
| config.files | aidial.config.json | Config files with parts of the whole config.
| config.reload | 60000 | Config reload interval in milliseconds.
-| identityProviders | - | List of identity providers. **Note**. At least one identity provider must be provided.
+| identityProviders | - | Map of identity providers. **Note**. At least one identity provider must be provided.
| identityProviders.*.jwksUrl                | -                  | URL of the JWKS provider. **Required** if `disabledVerifyJwt` is set to `false`.
| identityProviders.*.rolePath               | -                  | Path to the claim containing user roles in the JWT token, e.g. `resource_access.chatbot-ui.roles` or just `roles`. **Required**.
| identityProviders.*.loggingKey             | -                  | User information to look up in the claims of the JWT token.
@@ -164,22 +164,23 @@ Dynamic settings include:
* Access Permissions
* Rate Limits

-| Parameter                        | Description |
-| -------------------------------- | ------------ |
-| routes                           | Path(s) for specific upstream routing or to respond with a configured body. |
-| applications                     | A list of deployed AI DIAL Applications and their parameters:<br>`<application_name>`: Unique application name. |
-| applications.<application_name>  | `endpoint`: AI DIAL Application API for chat completions.<br>`iconUrl`: Icon path for the AI DIAL Application on UI.<br>`description`: Brief AI DIAL Application description.<br>`displayName`: AI DIAL Application name on UI.<br>`inputAttachmentTypes`: A list of allowed MIME types for the input attachments.<br>`maxInputAttachments`: Maximum number of input attachments (default is zero when `inputAttachmentTypes` is unset, otherwise, infinity) |
-| models                           | A list of deployed models and their parameters:<br>`<model_name>`: Unique model name. |
-| models.<model_name>              | `type`: Model type—`chat` or `embedding`.<br>`iconUrl`: Icon path for the model on UI.<br>`description`: Brief model description.<br>`displayName`: Model name on UI.<br>`displayVersion`: Model version on UI.<br>`endpoint`: Model API for chat completions or embeddings.<br>`tokenizerModel`: Identifies the specific model whose tokenization algorithm exactly matches that of the referenced model. This is typically the name of the earliest-released model in a series of models sharing an identical tokenization algorithm (e.g. `gpt-3.5-turbo-0301`, `gpt-4-0314`, or `gpt-4-1106-vision-preview`). This parameter is essential for DIAL clients that reimplement tokenization algorithms on their side, instead of utilizing the `tokenizeEndpoint` provided by the model.<br>`features`: Model features.<br>`limits`: Model token limits.<br>`pricing`: Model pricing.<br>`upstreams`: Used for load-balancing—request is sent to model endpoint containing X-UPSTREAM-ENDPOINT and X-UPSTREAM-KEY headers. |
-| models.<model_name>.limits       | `maxPromptTokens`: maximum number of tokens in a completion request.<br>`maxCompletionTokens`: maximum number of tokens in a completion response.<br>`maxTotalTokens`: maximum number of tokens in completion request and response combined.<br>Typically either `maxTotalTokens` is specified or `maxPromptTokens` and `maxCompletionTokens`. |
-| models.<model_name>.pricing      | `unit`: the pricing units (currently `token` and `char_without_whitespace` are supported).<br>`prompt`: per-unit price for the completion request in USD.<br>`completion`: per-unit price for the completion response in USD. |
-| models.<model_name>.features     | `rateEndpoint`: endpoint for rate requests *(exposed by core as `/rate`)*.<br>`tokenizeEndpoint`: endpoint for requests to the model tokenizer *(exposed by core as `/tokenize`)*.<br>`truncatePromptEndpoint`: endpoint for truncating prompt requests *(exposed by core as `/truncate_prompt`)*.<br>`systemPromptSupported`: does the model support system prompt (default is `true`).<br>`toolsSupported`: does the model support tools (default is `false`).<br>`seedSupported`: does the model support `seed` request parameter (default is `false`).<br>`urlAttachmentsSupported`: does the model/application support attachments with URLs (default is `false`) |
-| models.<model_name>.upstreams    | `endpoint`: Model endpoint.<br>`key`: Your API key. |
-| keys                             | API Keys parameters:<br>`<core_key>`: Your API key. |
-| keys.<core_key>                  | `project`: Project name assigned to this key.<br>`role`: A configured role name that defines key permissions. |
-| roles                            | API key roles `<role_name>` with associated limits. Each API key has one role defined in the list of roles. Roles are associated with models, applications, assistants, and defined limits. |
-| roles.<role_name>                | `limits`: Limits for models, applications, or assistants. |
-| roles.<role_name>.limits         | `minute`: Total tokens per minute limit sent to the model, managed via floating window approach for well-distributed rate limiting.<br>`day`: Total tokens per day limit sent to the model, managed via floating window approach for balanced rate limiting. |
+| Parameter                            | Description |
+|--------------------------------------| ------------ |
+| routes                               | Path(s) for specific upstream routing or to respond with a configured body. |
+| applications                         | A list of deployed AI DIAL Applications and their parameters:<br>`<application_name>`: Unique application name. |
+| applications.<application_name>      | `endpoint`: AI DIAL Application API for chat completions.<br>`iconUrl`: Icon path for the AI DIAL Application on UI.<br>`description`: Brief AI DIAL Application description.<br>`displayName`: AI DIAL Application name on UI.<br>`inputAttachmentTypes`: A list of allowed MIME types for the input attachments.<br>`maxInputAttachments`: Maximum number of input attachments (default is zero when `inputAttachmentTypes` is unset, otherwise, infinity) |
+| models                               | A list of deployed models and their parameters:<br>`<model_name>`: Unique model name. |
+| models.<model_name>                  | `type`: Model type—`chat` or `embedding`.<br>`iconUrl`: Icon path for the model on UI.<br>`description`: Brief model description.<br>`displayName`: Model name on UI.<br>`displayVersion`: Model version on UI.<br>`endpoint`: Model API for chat completions or embeddings.<br>`tokenizerModel`: Identifies the specific model whose tokenization algorithm exactly matches that of the referenced model. This is typically the name of the earliest-released model in a series of models sharing an identical tokenization algorithm (e.g. `gpt-3.5-turbo-0301`, `gpt-4-0314`, or `gpt-4-1106-vision-preview`). This parameter is essential for DIAL clients that reimplement tokenization algorithms on their side, instead of utilizing the `tokenizeEndpoint` provided by the model.<br>`features`: Model features.<br>`limits`: Model token limits.<br>`pricing`: Model pricing.<br>`upstreams`: Used for load-balancing—request is sent to model endpoint containing X-UPSTREAM-ENDPOINT and X-UPSTREAM-KEY headers. |
+| models.<model_name>.limits           | `maxPromptTokens`: maximum number of tokens in a completion request.<br>`maxCompletionTokens`: maximum number of tokens in a completion response.<br>`maxTotalTokens`: maximum number of tokens in completion request and response combined.<br>Typically either `maxTotalTokens` is specified or `maxPromptTokens` and `maxCompletionTokens`. |
+| models.<model_name>.pricing          | `unit`: the pricing units (currently `token` and `char_without_whitespace` are supported).<br>`prompt`: per-unit price for the completion request in USD.<br>`completion`: per-unit price for the completion response in USD. |
+| models.<model_name>.features         | `rateEndpoint`: endpoint for rate requests *(exposed by core as `/rate`)*.<br>`tokenizeEndpoint`: endpoint for requests to the model tokenizer *(exposed by core as `/tokenize`)*.<br>`truncatePromptEndpoint`: endpoint for truncating prompt requests *(exposed by core as `/truncate_prompt`)*.<br>`systemPromptSupported`: does the model support system prompt (default is `true`).<br>`toolsSupported`: does the model support tools (default is `false`).<br>`seedSupported`: does the model support `seed` request parameter (default is `false`).<br>`urlAttachmentsSupported`: does the model/application support attachments with URLs (default is `false`) |
+| models.<model_name>.upstreams        | `endpoint`: Model endpoint.<br>`key`: Your API key. |
+| models.<model_name>.defaultUserLimit | Default user limit for the given model.<br>`minute`: Total tokens per minute limit sent to the model, managed via floating window approach for well-distributed rate limiting.<br>`day`: Total tokens per day limit sent to the model, managed via floating window approach for balanced rate limiting. |
+| keys                                 | API Keys parameters:<br>`<core_key>`: Your API key. |
+| keys.<core_key>                      | `project`: Project name assigned to this key.<br>`role`: A configured role name that defines key permissions. |
+| roles                                | API key roles `<role_name>` with associated limits. Each API key has one role defined in the list of roles. Roles are associated with models, applications, assistants, and defined limits. |
+| roles.<role_name>                    | `limits`: Limits for models, applications, or assistants. |
+| roles.<role_name>.limits             | `minute`: Total tokens per minute limit sent to the model, managed via floating window approach for well-distributed rate limiting.<br>`day`: Total tokens per day limit sent to the model, managed via floating window approach for balanced rate limiting. |

## License
diff --git a/sample/aidial.config.json b/sample/aidial.config.json
index a31e72a79..0bd2c3ac1 100644
--- a/sample/aidial.config.json
+++ b/sample/aidial.config.json
@@ -46,7 +46,12 @@
"endpoint": "http://localhost:7003",
"key": "modelKey3"
}
- ]
+ ],
+ "userRoles": ["role1", "role2"],
+ "defaultUserLimit": {
+ "minute": "100000",
+ "day": "10000000"
+ }
},
"embedding-ada": {
"type": "embedding",
@@ -57,7 +62,8 @@
"key": "modelKey4"
}
- ]
+ ],
+ "userRoles": ["role3"]
}
},
"keys": {
"proxyKey1": {
@@ -86,6 +92,22 @@
"search_assistant": {},
"app": {}
}
+ },
+ "role1": {
+ "limits": {
+ "chat-gpt-35-turbo": {
+ "minute": "200000",
+ "day": "10000000"
+ }
+ }
+ },
+ "role2": {
+ "limits": {
+ "chat-gpt-35-turbo": {
+ "minute": "100000",
+ "day": "20000000"
+ }
+ }
}
}
}
\ No newline at end of file
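
For the sample above, the limits resolve as follows: a user granted both `role1` and `role2` gets the per-dimension maximum for `chat-gpt-35-turbo`, i.e. 200000 tokens per minute (from `role1`) and 20000000 tokens per day (from `role2`). A user whose roles define no limit for the model (e.g. `role3`, which is referenced but not configured under `roles`) falls back to the model's `defaultUserLimit`, and a model without `defaultUserLimit` falls back to the built-in default limit.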
diff --git a/src/main/java/com/epam/aidial/core/config/Model.java b/src/main/java/com/epam/aidial/core/config/Model.java
index a95fa70f2..ccb30e279 100644
--- a/src/main/java/com/epam/aidial/core/config/Model.java
+++ b/src/main/java/com/epam/aidial/core/config/Model.java
@@ -17,4 +17,5 @@ public class Model extends Deployment {
private List<Upstream> upstreams = List.of();
// if it's set then the model name is overridden with that name in the request body to the model adapter
private String overrideName;
+ private Limit defaultUserLimit;
}
\ No newline at end of file
diff --git a/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java b/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
index 74d1d773e..e32f3fa6b 100644
--- a/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
+++ b/src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
@@ -4,6 +4,7 @@
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.config.Limit;
+import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.data.ResourceType;
@@ -20,10 +21,16 @@
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
@Slf4j
@RequiredArgsConstructor
public class RateLimiter {
+ private static final Limit DEFAULT_LIMIT = new Limit();
+
private final Vertx vertx;
private final ResourceService resourceService;
@@ -34,10 +41,7 @@ public Future<Void> increase(ProxyContext context) {
if (resourceService == null) {
return Future.succeededFuture();
}
- Key key = context.getKey();
- if (key == null) {
- return Future.succeededFuture();
- }
+
Deployment deployment = context.getDeployment();
TokenUsage usage = context.getTokenUsage();
@@ -62,8 +66,7 @@ public Future<RateLimitResult> limit(ProxyContext context) {
Deployment deployment = context.getDeployment();
Limit limit;
if (key == null) {
- // don't support user limits yet
- return Future.succeededFuture(RateLimitResult.SUCCESS);
+ limit = getLimitByUser(context);
} else {
limit = getLimitByApiKey(context, deployment.getName());
}
@@ -176,8 +179,38 @@ private Limit getLimitByApiKey(ProxyContext context, String deploymentName) {
return role.getLimits().get(deploymentName);
}

+ private Limit getLimitByUser(ProxyContext context) {
+ List<String> userRoles = context.getUserRoles();
+ Limit defaultUserLimit = getDefaultUserLimit(context.getDeployment());
+ if (userRoles.isEmpty()) {
+ return defaultUserLimit;
+ }
+ String deploymentName = context.getDeployment().getName();
+ Map<String, Role> userRoleToDeploymentLimits = context.getConfig().getRoles();
+ long minuteLimit = 0;
+ long dayLimit = 0;
+ for (String userRole : userRoles) {
+ Limit limit = Optional.ofNullable(userRoleToDeploymentLimits.get(userRole))
+ .map(role -> role.getLimits().get(deploymentName))
+ .orElse(defaultUserLimit);
+ minuteLimit = Math.max(minuteLimit, limit.getMinute());
+ dayLimit = Math.max(dayLimit, limit.getDay());
+ }
+ Limit limit = new Limit();
+ limit.setMinute(minuteLimit);
+ limit.setDay(dayLimit);
+ return limit;
+ }
+
private static String getPath(String deploymentName) {
return String.format("%s/tokens", deploymentName);
}
+
+ private static Limit getDefaultUserLimit(Deployment deployment) {
+ if (deployment instanceof Model model) {
+ return model.getDefaultUserLimit() == null ? DEFAULT_LIMIT : model.getDefaultUserLimit();
+ }
+ return DEFAULT_LIMIT;
+ }
+
}
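
To see the resolution rule from `getLimitByUser` in isolation, here is a self-contained sketch; the `Limit` and `Role` records and the class name are simplified stand-ins for the project's config classes, not the actual API:

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class UserLimitResolutionSketch {

    record Limit(long minute, long day) {}

    record Role(Map<String, Limit> limits) {}

    // Mirrors getLimitByUser: take the per-dimension maximum across the user's
    // roles; a role with no limit for the deployment contributes the model's
    // default user limit instead.
    static Limit resolve(List<String> userRoles, Map<String, Role> roles,
                         String deployment, Limit defaultUserLimit) {
        if (userRoles.isEmpty()) {
            return defaultUserLimit;
        }
        long minute = 0;
        long day = 0;
        for (String userRole : userRoles) {
            Limit limit = Optional.ofNullable(roles.get(userRole))
                    .map(role -> role.limits().get(deployment))
                    .orElse(defaultUserLimit);
            minute = Math.max(minute, limit.minute());
            day = Math.max(day, limit.day());
        }
        return new Limit(minute, day);
    }

    public static void main(String[] args) {
        // Role limits taken from sample/aidial.config.json above.
        Map<String, Role> roles = Map.of(
                "role1", new Role(Map.of("chat-gpt-35-turbo", new Limit(200_000, 10_000_000))),
                "role2", new Role(Map.of("chat-gpt-35-turbo", new Limit(100_000, 20_000_000))));
        Limit defaultUserLimit = new Limit(100_000, 10_000_000);

        // Both roles: minute comes from role1, day from role2.
        System.out.println(resolve(List.of("role1", "role2"), roles, "chat-gpt-35-turbo", defaultUserLimit));
        // Unconfigured role: falls back to the model's default user limit.
        System.out.println(resolve(List.of("role3"), roles, "chat-gpt-35-turbo", defaultUserLimit));
    }
}
```

Because the merge takes `Math.max` per dimension, granting a user an extra role can only relax the effective limit, never tighten it.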
diff --git a/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java b/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
index 469d0a4b8..c3ef7773a 100644
--- a/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
+++ b/src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
@@ -15,7 +15,6 @@
import com.epam.aidial.core.storage.BlobStorage;
import com.epam.aidial.core.token.TokenUsage;
import com.epam.aidial.core.util.HttpStatus;
-import com.epam.aidial.core.util.ProxyUtil;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServerRequest;
@@ -34,7 +33,7 @@
import redis.embedded.RedisServer;
import java.io.IOException;
-import java.util.Collections;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
@@ -128,17 +127,6 @@ public void testLimit_EntityNotFound() {
assertEquals(HttpStatus.FORBIDDEN, result.result().status());
}
- @Test
- public void testLimit_SuccessUser() {
- ProxyContext proxyContext = new ProxyContext(new Config(), request, new ApiKeyData(), new ExtractedClaims("sub", Collections.emptyList(), "hash"), "trace-id", "span-id");
-
- Future<RateLimitResult> result = rateLimiter.limit(proxyContext);
-
- assertNotNull(result);
- assertNotNull(result.result());
- assertEquals(HttpStatus.OK, result.result().status());
- }
-
@Test
public void testLimit_ApiKeyLimitNotFound() {
Key key = new Key();
@@ -340,4 +328,101 @@ public void testGetLimitStats_ApiKey() {
}
+ @Test
+ public void testLimit_User_LimitFound() {
+ Config config = new Config();
+
+ Role role1 = new Role();
+ Limit limit = new Limit();
+ limit.setDay(10000);
+ limit.setMinute(100);
+ role1.setLimits(Map.of("model", limit));
+
+ Role role2 = new Role();
+ limit = new Limit();
+ limit.setDay(20000);
+ limit.setMinute(200);
+ role2.setLimits(Map.of("model", limit));
+
+ config.getRoles().put("role1", role1);
+ config.getRoles().put("role2", role2);
+
+ ApiKeyData apiKeyData = new ApiKeyData();
+ ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash"), "trace-id", "span-id");
+ Model model = new Model();
+ model.setName("model");
+ proxyContext.setDeployment(model);
+
+ when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
+ Callable<?> callable = invocation.getArgument(0);
+ return Future.succeededFuture(callable.call());
+ });
+
+ TokenUsage tokenUsage = new TokenUsage();
+ tokenUsage.setTotalTokens(150);
+ proxyContext.setTokenUsage(tokenUsage);
+
+ Future<Void> increaseLimitFuture = rateLimiter.increase(proxyContext);
+ assertNotNull(increaseLimitFuture);
+ assertNull(increaseLimitFuture.cause());
+
+ Future<RateLimitResult> checkLimitFuture = rateLimiter.limit(proxyContext);
+
+ assertNotNull(checkLimitFuture);
+ assertNotNull(checkLimitFuture.result());
+ assertEquals(HttpStatus.OK, checkLimitFuture.result().status());
+
+ increaseLimitFuture = rateLimiter.increase(proxyContext);
+ assertNotNull(increaseLimitFuture);
+ assertNull(increaseLimitFuture.cause());
+
+ checkLimitFuture = rateLimiter.limit(proxyContext);
+
+ assertNotNull(checkLimitFuture);
+ assertNotNull(checkLimitFuture.result());
+ assertEquals(HttpStatus.TOO_MANY_REQUESTS, checkLimitFuture.result().status());
+
+ }
+
+ @Test
+ public void testLimit_User_LimitNotFound() {
+ Config config = new Config();
+
+ ApiKeyData apiKeyData = new ApiKeyData();
+ ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1"), "user-hash"), "trace-id", "span-id");
+ Model model = new Model();
+ model.setName("model");
+ proxyContext.setDeployment(model);
+
+ when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
+ Callable<?> callable = invocation.getArgument(0);
+ return Future.succeededFuture(callable.call());
+ });
+
+ TokenUsage tokenUsage = new TokenUsage();
+ tokenUsage.setTotalTokens(90);
+ proxyContext.setTokenUsage(tokenUsage);
+
+ Future<Void> increaseLimitFuture = rateLimiter.increase(proxyContext);
+ assertNotNull(increaseLimitFuture);
+ assertNull(increaseLimitFuture.cause());
+
+ Future<RateLimitResult> checkLimitFuture = rateLimiter.limit(proxyContext);
+
+ assertNotNull(checkLimitFuture);
+ assertNotNull(checkLimitFuture.result());
+ assertEquals(HttpStatus.OK, checkLimitFuture.result().status());
+
+ increaseLimitFuture = rateLimiter.increase(proxyContext);
+ assertNotNull(increaseLimitFuture);
+ assertNull(increaseLimitFuture.cause());
+
+ checkLimitFuture = rateLimiter.limit(proxyContext);
+
+ assertNotNull(checkLimitFuture);
+ assertNotNull(checkLimitFuture.result());
+ assertEquals(HttpStatus.OK, checkLimitFuture.result().status());
+
+ }
+
}