Skip to content

Commit

Permalink
fix: selector with strict patterns (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
artsiomkorzun authored Feb 8, 2024
1 parent ba607b9 commit 8009deb
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 66 deletions.
123 changes: 57 additions & 66 deletions src/main/java/com/epam/aidial/core/controller/ControllerSelector.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Features;
import com.epam.aidial.core.util.UrlUtil;
import io.vertx.core.http.HttpMethod;
import lombok.experimental.UtilityClass;

import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
Expand All @@ -17,46 +16,44 @@
@UtilityClass
public class ControllerSelector {

private static final Pattern PATTERN_POST_DEPLOYMENT = Pattern.compile("/+openai/deployments/([-.@a-zA-Z0-9]+)/(completions|chat/completions|embeddings)");
private static final Pattern PATTERN_DEPLOYMENT = Pattern.compile("/+openai/deployments/([-.@a-zA-Z0-9]+)");
private static final Pattern PATTERN_DEPLOYMENTS = Pattern.compile("/+openai/deployments");
private static final Pattern PATTERN_POST_DEPLOYMENT = Pattern.compile("^/+openai/deployments/([^/]+)/(completions|chat/completions|embeddings)$");
private static final Pattern PATTERN_DEPLOYMENT = Pattern.compile("^/+openai/deployments/([^/]+)$");
private static final Pattern PATTERN_DEPLOYMENTS = Pattern.compile("^/+openai/deployments$");

private static final Pattern PATTERN_MODEL = Pattern.compile("/+openai/models/([-.@a-zA-Z0-9]+)");
private static final Pattern PATTERN_MODELS = Pattern.compile("/+openai/models");
private static final Pattern PATTERN_MODEL = Pattern.compile("^/+openai/models/([^/]+)$");
private static final Pattern PATTERN_MODELS = Pattern.compile("^/+openai/models$");

private static final Pattern PATTERN_ADDON = Pattern.compile("/+openai/addons/([-.@a-zA-Z0-9]+)");
private static final Pattern PATTERN_ADDONS = Pattern.compile("/+openai/addons");
private static final Pattern PATTERN_ADDON = Pattern.compile("^/+openai/addons/([^/]+)$");
private static final Pattern PATTERN_ADDONS = Pattern.compile("^/+openai/addons$");

private static final Pattern PATTERN_ASSISTANT = Pattern.compile("/+openai/assistants/([-.@a-zA-Z0-9]+)");
private static final Pattern PATTERN_ASSISTANTS = Pattern.compile("/+openai/assistants");
private static final Pattern PATTERN_ASSISTANT = Pattern.compile("^/+openai/assistants/([^/]+)$");
private static final Pattern PATTERN_ASSISTANTS = Pattern.compile("^/+openai/assistants$");

private static final Pattern PATTERN_APPLICATION = Pattern.compile("/+openai/applications/([-.@a-zA-Z0-9]+)");
private static final Pattern PATTERN_APPLICATIONS = Pattern.compile("/+openai/applications");
private static final Pattern PATTERN_APPLICATION = Pattern.compile("^/+openai/applications/([^/]+)$");
private static final Pattern PATTERN_APPLICATIONS = Pattern.compile("^/+openai/applications$");


private static final Pattern PATTERN_BUCKET = Pattern.compile("/v1/bucket");
private static final Pattern PATTERN_BUCKET = Pattern.compile("^/v1/bucket$");

private static final Pattern PATTERN_FILES = Pattern.compile("/v1/files/([a-zA-Z0-9]+)/(.*)");
private static final Pattern PATTERN_FILES = Pattern.compile("^/v1/files/([a-zA-Z0-9]+)/(.*)");
private static final Pattern PATTERN_FILES_METADATA = Pattern.compile("^/v1/metadata/files/([a-zA-Z0-9]+)/(.*)");

private static final Pattern PATTERN_FILES_METADATA = Pattern.compile("/v1/metadata/files/([a-zA-Z0-9]+)/(.*)");
private static final Pattern PATTERN_RESOURCE = Pattern.compile("^/v1/(conversations|prompts)/([a-zA-Z0-9]+)/(.*)");
private static final Pattern PATTERN_RESOURCE_METADATA = Pattern.compile("^/v1/metadata/(conversations|prompts)/([a-zA-Z0-9]+)/(.*)");

private static final Pattern PATTERN_RESOURCE = Pattern.compile("/v1/(conversations|prompts)/([a-zA-Z0-9]+)/(.*)");
private static final Pattern PATTERN_RESOURCE_METADATA = Pattern.compile("/v1/metadata/(conversations|prompts)/([a-zA-Z0-9]+)/(.*)");

private static final Pattern PATTERN_RATE_RESPONSE = Pattern.compile("/+v1/([-.@a-zA-Z0-9]+)/rate");
private static final Pattern PATTERN_TOKENIZE = Pattern.compile("/+v1/deployments/([-.@a-zA-Z0-9]+)/tokenize");
private static final Pattern PATTERN_TRUNCATE_PROMPT = Pattern.compile("/+v1/deployments/([-.@a-zA-Z0-9]+)/truncate_prompt");
private static final Pattern PATTERN_RATE_RESPONSE = Pattern.compile("^/+v1/([^/]+)/rate$");
private static final Pattern PATTERN_TOKENIZE = Pattern.compile("^/+v1/deployments/([^/]+)/tokenize$");
private static final Pattern PATTERN_TRUNCATE_PROMPT = Pattern.compile("^/+v1/deployments/([^/]+)/truncate_prompt$");

public Controller select(Proxy proxy, ProxyContext context) {
String path = context.getRequest().path();
String decodedPath = URLDecoder.decode(path, StandardCharsets.UTF_8);
HttpMethod method = context.getRequest().method();
Controller controller = null;

if (method == HttpMethod.GET) {
controller = selectGet(proxy, context, path, decodedPath);
controller = selectGet(proxy, context, path);
} else if (method == HttpMethod.POST) {
controller = selectPost(proxy, context, decodedPath);
controller = selectPost(proxy, context, path);
} else if (method == HttpMethod.DELETE) {
controller = selectDelete(proxy, context, path);
} else if (method == HttpMethod.PUT) {
Expand All @@ -66,69 +63,69 @@ public Controller select(Proxy proxy, ProxyContext context) {
return (controller == null) ? new RouteController(proxy, context) : controller;
}

private static Controller selectGet(Proxy proxy, ProxyContext context, String path, String decodedPath) {
private static Controller selectGet(Proxy proxy, ProxyContext context, String path) {
Matcher match;

match = match(PATTERN_DEPLOYMENT, decodedPath);
match = match(PATTERN_DEPLOYMENT, path);
if (match != null) {
DeploymentController controller = new DeploymentController(context);
String deploymentId = match.group(1);
String deploymentId = UrlUtil.decodePath(match.group(1));
return () -> controller.getDeployment(deploymentId);
}

match = match(PATTERN_DEPLOYMENTS, decodedPath);
match = match(PATTERN_DEPLOYMENTS, path);
if (match != null) {
DeploymentController controller = new DeploymentController(context);
return controller::getDeployments;
}

match = match(PATTERN_MODEL, decodedPath);
match = match(PATTERN_MODEL, path);
if (match != null) {
ModelController controller = new ModelController(context);
String modelId = match.group(1);
String modelId = UrlUtil.decodePath(match.group(1));
return () -> controller.getModel(modelId);
}

match = match(PATTERN_MODELS, decodedPath);
match = match(PATTERN_MODELS, path);
if (match != null) {
ModelController controller = new ModelController(context);
return controller::getModels;
}

match = match(PATTERN_ADDON, decodedPath);
match = match(PATTERN_ADDON, path);
if (match != null) {
AddonController controller = new AddonController(context);
String addonId = match.group(1);
String addonId = UrlUtil.decodePath(match.group(1));
return () -> controller.getAddon(addonId);
}

match = match(PATTERN_ADDONS, decodedPath);
match = match(PATTERN_ADDONS, path);
if (match != null) {
AddonController controller = new AddonController(context);
return controller::getAddons;
}

match = match(PATTERN_ASSISTANT, decodedPath);
match = match(PATTERN_ASSISTANT, path);
if (match != null) {
AssistantController controller = new AssistantController(context);
String assistantId = match.group(1);
String assistantId = UrlUtil.decodePath(match.group(1));
return () -> controller.getAssistant(assistantId);
}

match = match(PATTERN_ASSISTANTS, decodedPath);
match = match(PATTERN_ASSISTANTS, path);
if (match != null) {
AssistantController controller = new AssistantController(context);
return controller::getAssistants;
}

match = match(PATTERN_APPLICATION, decodedPath);
match = match(PATTERN_APPLICATION, path);
if (match != null) {
ApplicationController controller = new ApplicationController(context);
String application = match.group(1);
String application = UrlUtil.decodePath(match.group(1));
return () -> controller.getApplication(application);
}

match = match(PATTERN_APPLICATIONS, decodedPath);
match = match(PATTERN_APPLICATIONS, path);
if (match != null) {
ApplicationController controller = new ApplicationController(context);
return controller::getApplications;
Expand Down Expand Up @@ -168,7 +165,7 @@ private static Controller selectGet(Proxy proxy, ProxyContext context, String pa
return () -> controller.handle(resource, bucket, relativePath);
}

match = match(PATTERN_BUCKET, decodedPath);
match = match(PATTERN_BUCKET, path);
if (match != null) {
BucketController controller = new BucketController(proxy, context);
return controller::getBucket;
Expand All @@ -180,52 +177,46 @@ private static Controller selectGet(Proxy proxy, ProxyContext context, String pa
private static Controller selectPost(Proxy proxy, ProxyContext context, String path) {
Matcher match = match(PATTERN_POST_DEPLOYMENT, path);
if (match != null) {
String deploymentId = match.group(1);
String deploymentApi = match.group(2);
String deploymentId = UrlUtil.decodePath(match.group(1));
String deploymentApi = UrlUtil.decodePath(match.group(2));
DeploymentPostController controller = new DeploymentPostController(proxy, context);
return () -> controller.handle(deploymentId, deploymentApi);
}

match = match(PATTERN_RATE_RESPONSE, path);
if (match != null) {
String deploymentId = match.group(1);
String deploymentId = UrlUtil.decodePath(match.group(1));

Function<Deployment, String> getter = (model) -> {
return Optional.ofNullable(model)
.map(d -> d.getFeatures())
.map(t -> t.getRateEndpoint())
.orElse(null);
};
Function<Deployment, String> getter = (model) -> Optional.ofNullable(model)
.map(Deployment::getFeatures)
.map(Features::getRateEndpoint)
.orElse(null);

DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context);
return () -> controller.handle(deploymentId, getter, false);
}

match = match(PATTERN_TOKENIZE, path);
if (match != null) {
String deploymentId = match.group(1);
String deploymentId = UrlUtil.decodePath(match.group(1));

Function<Deployment, String> getter = (model) -> {
return Optional.ofNullable(model)
.map(d -> d.getFeatures())
.map(t -> t.getTokenizeEndpoint())
.orElse(null);
};
Function<Deployment, String> getter = (model) -> Optional.ofNullable(model)
.map(Deployment::getFeatures)
.map(Features::getTokenizeEndpoint)
.orElse(null);

DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context);
return () -> controller.handle(deploymentId, getter, true);
}

match = match(PATTERN_TRUNCATE_PROMPT, path);
if (match != null) {
String deploymentId = match.group(1);

Function<Deployment, String> getter = (model) -> {
return Optional.ofNullable(model)
.map(Deployment::getFeatures)
.map(Features::getTruncatePromptEndpoint)
.orElse(null);
};
String deploymentId = UrlUtil.decodePath(match.group(1));

Function<Deployment, String> getter = (model) -> Optional.ofNullable(model)
.map(Deployment::getFeatures)
.map(Features::getTruncatePromptEndpoint)
.orElse(null);

DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context);
return () -> controller.handle(deploymentId, getter, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,36 @@ public void testSelectRouteController() {
assertInstanceOf(RouteController.class, controller);
}

@Test
void testSelectDeploymentWithSpecialName() {
when(request.path()).thenReturn("/openai/deployments/deployment_x-y%2B%2F");
when(request.method()).thenReturn(HttpMethod.GET);
Controller controller = ControllerSelector.select(proxy, context);
assertNotNull(controller);
SerializedLambda lambda = getSerializedLambda(controller);
assertNotNull(lambda);
Object arg1 = lambda.getCapturedArg(0);
Object arg2 = lambda.getCapturedArg(1);
assertInstanceOf(DeploymentController.class, arg1);
assertEquals("deployment_x-y+/", arg2);
}

@Test
void testFailDeploymentWithSlash() {
when(request.path()).thenReturn("/openai/deployments/deployment/xy");
when(request.method()).thenReturn(HttpMethod.GET);
Controller controller = ControllerSelector.select(proxy, context);
assertInstanceOf(RouteController.class, controller);
}

@Test
void testFailDeploymentWithBadPrefix() {
when(request.path()).thenReturn("/prefix/openai/deployments/deployment");
when(request.method()).thenReturn(HttpMethod.GET);
Controller controller = ControllerSelector.select(proxy, context);
assertInstanceOf(RouteController.class, controller);
}

@Nullable
private static SerializedLambda getSerializedLambda(Serializable lambda) {
for (Class<?> cl = lambda.getClass(); cl != null; cl = cl.getSuperclass()) {
Expand Down

0 comments on commit 8009deb

Please sign in to comment.