Skip to content

Commit

Permalink
feat: Add support of interceptors for models, applications #379 (#392)
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay authored Jul 17, 2024
1 parent 6c2e47b commit 670d910
Show file tree
Hide file tree
Showing 11 changed files with 363 additions and 42 deletions.
54 changes: 29 additions & 25 deletions README.md

Large diffs are not rendered by default.

29 changes: 20 additions & 9 deletions sample/aidial.config.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
{
"routes": {},
"interceptors": {
"interceptor1": {
"endpoint": "http://localhost:4088/api/v1/interceptor/handle"
},
"interceptor2": {
"endpoint": "http://localhost:4089/api/v1/interceptor/handle"
},
"interceptor3": {
"endpoint": "http://localhost:4090/api/v1/interceptor/handle"
}
},
"addons": {
"search": {
"endpoint": "http://localhost:7010/search"
},
"forecast": {
"endpoint": "http://localhost:7010/forecast",
"token": "token",
"displayName": "Forecast",
"iconUrl": "https://host/forecast.svg",
"description": "Addon that provides forecast",
Expand Down Expand Up @@ -72,7 +82,6 @@
"applications": {
"app": {
"endpoint": "http://localhost:7001/openai/deployments/10k/chat/completions",
"token": "token",
"displayName": "Forecast",
"iconUrl": "https://host/app.svg",
"description": "Addon that provides forecast",
Expand All @@ -98,7 +107,8 @@
"paramBool": true,
"paramInt": 123,
"paramFloat": 0.25
}
},
"interceptors": ["interceptor1", "interceptor2", "interceptor3"]
}
},
"models": {
Expand All @@ -112,8 +122,8 @@
},
"pricing": {
"unit": "token",
"prompt": 0.56,
"completion": 0.67
"prompt": "0.56",
"completion": "0.67"
},
"overrideName": "/some[!exotic?]/model/name",
"displayName": "GPT-3.5",
Expand Down Expand Up @@ -153,7 +163,8 @@
"paramBool": true,
"paramInt": 123,
"paramFloat": 0.25
}
},
"interceptors": ["interceptor1"]
},
"embedding-ada": {
"type": "embedding",
Expand All @@ -163,9 +174,9 @@
"endpoint": "http://localhost:7001",
"key": "modelKey4"
}
]
},
"userRoles": ["role3"]
],
"userRoles": ["role3"]
}
},
"keys": {
"proxyKey1": {
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/com/epam/aidial/core/ProxyContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ public class ProxyContext {
private long responseBodyTimestamp;
private ExtractedClaims extractedClaims;
private ApiKeyData proxyApiKeyData;
// deployment whose request triggered the interceptor chain
private String initialDeployment;
private String initialDeploymentApi;
// List of interceptors copied from the deployment config
private List<String> interceptors;

public ProxyContext(Config config, HttpServerRequest request, ApiKeyData apiKeyData, ExtractedClaims extractedClaims, String traceId, String spanId) {
this.config = config;
Expand Down Expand Up @@ -138,4 +143,24 @@ public List<String> getExecutionPath() {
public boolean getBooleanRequestQueryParam(String name) {
return Boolean.parseBoolean(request.getParam(name, "false"));
}

/**
 * Returns the interceptor list stored on this context; when none has been set here,
 * falls back to the list carried by the API key data.
 */
public List<String> getInterceptors() {
    if (interceptors != null) {
        return interceptors;
    }
    return apiKeyData.getInterceptors();
}

/**
 * Tells whether another interceptor remains to be called for the current request.
 *
 * @return {@code true} if the deployment declares interceptors (initial call), or if the
 *         list carried by the API key data has an entry after the current index
 */
public boolean hasNextInterceptor() {
    List<String> inherited = apiKeyData.getInterceptors();
    if (inherited != null) {
        // the request is already inside a chain: check that an entry follows the current one
        int candidate = apiKeyData.getInterceptorIndex() + 1;
        return candidate < inherited.size();
    }
    // initial call to the deployment: the chain starts only if the deployment lists interceptors
    return !deployment.getInterceptors().isEmpty();
}

/**
 * Returns the initial deployment stored on this context; when none has been set here,
 * falls back to the value carried by the API key data.
 */
public String getInitialDeployment() {
    if (initialDeployment != null) {
        return initialDeployment;
    }
    return apiKeyData.getInitialDeployment();
}

/**
 * Returns the initial deployment API stored on this context; when none has been set here,
 * falls back to the value carried by the API key data.
 */
public String getInitialDeploymentApi() {
    if (initialDeploymentApi != null) {
        return initialDeploymentApi;
    }
    return apiKeyData.getInitialDeploymentApi();
}
}
14 changes: 13 additions & 1 deletion src/main/java/com/epam/aidial/core/config/ApiKeyData.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,29 @@ public class ApiKeyData {
// list of attached file URLs collected from conversation history of the current request
private Set<String> attachedFiles = new HashSet<>();
private List<String> attachedFolders = new ArrayList<>();
// deployment name of the source(application/assistant/model) associated with the current request
// deployment name of the source(application/assistant/model/interceptor) associated with the current request
private String sourceDeployment;
// Execution path of the root request
private List<String> executionPath;
// List of interceptors copied from the deployment config
private List<String> interceptors;
// Index to track which interceptor is called next
private int interceptorIndex = -1;
// deployment whose request triggered the interceptor chain
private String initialDeployment;
private String initialDeploymentApi;

public ApiKeyData() {
}

public static void initFromContext(ApiKeyData proxyApiKeyData, ProxyContext context) {
ApiKeyData apiKeyData = context.getApiKeyData();
List<String> currentPath;
proxyApiKeyData.setInterceptors(context.getInterceptors());
proxyApiKeyData.setInterceptorIndex(apiKeyData.getInterceptorIndex() + 1); // move to next interceptor
proxyApiKeyData.setInitialDeployment(context.getInitialDeployment());
proxyApiKeyData.setInitialDeploymentApi(context.getInitialDeploymentApi());

if (apiKeyData.getPerRequestKey() == null) {
proxyApiKeyData.setOriginalKey(context.getKey());
proxyApiKeyData.setExtractedClaims(context.getExtractedClaims());
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/epam/aidial/core/config/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class Config {
private Map<String, Key> keys = new HashMap<>();
private Map<String, Role> roles = new HashMap<>();
private Set<Integer> retriableErrorCodes = Set.of();
private Map<String, Interceptor> interceptors = Map.of();


public Deployment selectDeployment(String deploymentId) {
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/com/epam/aidial/core/config/Deployment.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,8 @@ public abstract class Deployment {
* Default parameters are applied if a request doesn't contain them in OpenAI chat/completions API call.
*/
private Map<String, Object> defaults = Map.of();
/**
* List of interceptors to be called for the deployment
*/
private List<String> interceptors = List.of();
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ private void load(boolean fail) {
role.setName(name);
}

for (Map.Entry<String, Interceptor> entry : config.getInterceptors().entrySet()) {
String name = entry.getKey();
Interceptor interceptor = entry.getValue();
interceptor.setName(name);
}

this.config = config;
} catch (Throwable e) {
if (fail) {
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/com/epam/aidial/core/config/Interceptor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.epam.aidial.core.config;

/**
 * Deployment subtype used for interceptor entries of the configuration.
 * Declares no members of its own; everything it exposes comes from {@link Deployment}.
 */
public class Interceptor extends Deployment {
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.epam.aidial.core.config.ApiKeyData;
import com.epam.aidial.core.config.Config;
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Interceptor;
import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.config.ModelType;
import com.epam.aidial.core.config.Pricing;
Expand Down Expand Up @@ -78,7 +79,14 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
if (!StringUtils.containsIgnoreCase(contentType, Proxy.HEADER_CONTENT_TYPE_APPLICATION_JSON)) {
return respond(HttpStatus.UNSUPPORTED_MEDIA_TYPE, "Only application/json is supported");
}
// handle a special deployment `interceptor`
if ("interceptor".equals(deploymentId)) {
return handleInterceptor();
}
return handleDeployment(deploymentId, deploymentApi);
}

private Future<?> handleDeployment(String deploymentId, String deploymentApi) {
Deployment deployment = context.getConfig().selectDeployment(deploymentId);
boolean isValidDeployment = isValidDeploymentApi(deployment, deploymentApi);

Expand Down Expand Up @@ -109,15 +117,22 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
return dep;
})
.compose(dep -> {
if (dep instanceof Model) {
if (dep instanceof Model && !context.hasNextInterceptor()) {
return proxy.getRateLimiter().limit(context);
} else {
return Future.succeededFuture(RateLimitResult.SUCCESS);
}
})
.map(rateLimitResult -> {
if (rateLimitResult.status() == HttpStatus.OK) {
handleRateLimitSuccess(deploymentId);
if (context.hasNextInterceptor()) {
context.setInitialDeployment(deploymentId);
context.setInitialDeploymentApi(deploymentApi);
context.setInterceptors(context.getDeployment().getInterceptors());
handleInterceptor();
} else {
handleRateLimitSuccess(deploymentId);
}
} else {
handleRateLimitHit(deploymentId, rateLimitResult);
}
Expand All @@ -129,6 +144,28 @@ public Future<?> handle(String deploymentId, String deploymentApi) {
});
}

/**
 * Advances the interceptor chain of the current request by one step.
 *
 * <p>If the interceptor list has an entry after the index recorded in the API key data, that
 * interceptor is looked up in the config, set as the context deployment and invoked through
 * {@link InterceptorController}. Once the list is exhausted, the initial deployment recorded
 * in the API key data is called instead.
 *
 * @return the future of the interceptor call, of the initial deployment call, or of the error
 *         response when the chain is missing or names an unknown interceptor
 */
private Future<?> handleInterceptor() {
    ApiKeyData apiKeyData = context.getApiKeyData();
    List<String> interceptors = context.getInterceptors();
    if (interceptors == null) {
        // The special `interceptor` deployment was addressed by a request that carries no
        // interceptor list; fail explicitly instead of throwing a NullPointerException below.
        log.warn("Interceptor chain is not initialized for the current request");
        return respond(HttpStatus.NOT_FOUND, "Interceptor is not found");
    }

    int nextIndex = apiKeyData.getInterceptorIndex() + 1;
    if (nextIndex >= interceptors.size()) {
        // all interceptors are completed, so the initial deployment is called
        return handleDeployment(apiKeyData.getInitialDeployment(), apiKeyData.getInitialDeploymentApi());
    }

    String interceptorName = interceptors.get(nextIndex);
    Interceptor interceptor = context.getConfig().getInterceptors().get(interceptorName);
    if (interceptor == null) {
        log.warn("Interceptor is not found for the given name: {}", interceptorName);
        return respond(HttpStatus.NOT_FOUND, "Interceptor is not found");
    }
    context.setDeployment(interceptor);

    // must run after the deployment is switched and before the controller is invoked
    setupProxyApiKeyData();

    InterceptorController controller = new InterceptorController(proxy, context);
    return controller.handle();
}

private void handleRequestError(String deploymentId, Throwable error) {
if (error instanceof PermissionDeniedException) {
log.error("Forbidden deployment {}. Key: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
Expand Down Expand Up @@ -172,9 +209,6 @@ private void handleRateLimitSuccess(String deploymentId) {
.onFailure(this::handleRequestBodyError);
}

/**
* The method uses blocking calls and should not be used in the event loop thread.
*/
private void setupProxyApiKeyData() {
ApiKeyData proxyApiKeyData = new ApiKeyData();
context.setProxyApiKeyData(proxyApiKeyData);
Expand Down
Loading

0 comments on commit 670d910

Please sign in to comment.