From 8009deb94dd6389ea8ace230c73a7cebf31a8fea Mon Sep 17 00:00:00 2001 From: Artsiom Korzun <72259616+artsiomkorzun@users.noreply.github.com> Date: Thu, 8 Feb 2024 14:48:41 +0100 Subject: [PATCH] fix: selector with strict patterns (#200) --- .../core/controller/ControllerSelector.java | 123 ++++++++---------- .../controller/ControllerSelectorTest.java | 30 +++++ 2 files changed, 87 insertions(+), 66 deletions(-) diff --git a/src/main/java/com/epam/aidial/core/controller/ControllerSelector.java b/src/main/java/com/epam/aidial/core/controller/ControllerSelector.java index 39a5afef5..b5656b0cf 100644 --- a/src/main/java/com/epam/aidial/core/controller/ControllerSelector.java +++ b/src/main/java/com/epam/aidial/core/controller/ControllerSelector.java @@ -4,11 +4,10 @@ import com.epam.aidial.core.ProxyContext; import com.epam.aidial.core.config.Deployment; import com.epam.aidial.core.config.Features; +import com.epam.aidial.core.util.UrlUtil; import io.vertx.core.http.HttpMethod; import lombok.experimental.UtilityClass; -import java.net.URLDecoder; -import java.nio.charset.StandardCharsets; import java.util.Optional; import java.util.function.Function; import java.util.regex.Matcher; @@ -17,46 +16,44 @@ @UtilityClass public class ControllerSelector { - private static final Pattern PATTERN_POST_DEPLOYMENT = Pattern.compile("/+openai/deployments/([-.@a-zA-Z0-9]+)/(completions|chat/completions|embeddings)"); - private static final Pattern PATTERN_DEPLOYMENT = Pattern.compile("/+openai/deployments/([-.@a-zA-Z0-9]+)"); - private static final Pattern PATTERN_DEPLOYMENTS = Pattern.compile("/+openai/deployments"); + private static final Pattern PATTERN_POST_DEPLOYMENT = Pattern.compile("^/+openai/deployments/([^/]+)/(completions|chat/completions|embeddings)$"); + private static final Pattern PATTERN_DEPLOYMENT = Pattern.compile("^/+openai/deployments/([^/]+)$"); + private static final Pattern PATTERN_DEPLOYMENTS = Pattern.compile("^/+openai/deployments$"); - private static final Pattern PATTERN_MODEL = Pattern.compile("/+openai/models/([-.@a-zA-Z0-9]+)"); - private static final Pattern PATTERN_MODELS = Pattern.compile("/+openai/models"); + private static final Pattern PATTERN_MODEL = Pattern.compile("^/+openai/models/([^/]+)$"); + private static final Pattern PATTERN_MODELS = Pattern.compile("^/+openai/models$"); - private static final Pattern PATTERN_ADDON = Pattern.compile("/+openai/addons/([-.@a-zA-Z0-9]+)"); - private static final Pattern PATTERN_ADDONS = Pattern.compile("/+openai/addons"); + private static final Pattern PATTERN_ADDON = Pattern.compile("^/+openai/addons/([^/]+)$"); + private static final Pattern PATTERN_ADDONS = Pattern.compile("^/+openai/addons$"); - private static final Pattern PATTERN_ASSISTANT = Pattern.compile("/+openai/assistants/([-.@a-zA-Z0-9]+)"); - private static final Pattern PATTERN_ASSISTANTS = Pattern.compile("/+openai/assistants"); + private static final Pattern PATTERN_ASSISTANT = Pattern.compile("^/+openai/assistants/([^/]+)$"); + private static final Pattern PATTERN_ASSISTANTS = Pattern.compile("^/+openai/assistants$"); - private static final Pattern PATTERN_APPLICATION = Pattern.compile("/+openai/applications/([-.@a-zA-Z0-9]+)"); - private static final Pattern PATTERN_APPLICATIONS = Pattern.compile("/+openai/applications"); + private static final Pattern PATTERN_APPLICATION = Pattern.compile("^/+openai/applications/([^/]+)$"); + private static final Pattern PATTERN_APPLICATIONS = Pattern.compile("^/+openai/applications$"); - private static final Pattern PATTERN_BUCKET = Pattern.compile("/v1/bucket"); + private static final Pattern PATTERN_BUCKET = Pattern.compile("^/v1/bucket$"); - private static final Pattern PATTERN_FILES = Pattern.compile("/v1/files/([a-zA-Z0-9]+)/(.*)"); + private static final Pattern PATTERN_FILES = Pattern.compile("^/v1/files/([a-zA-Z0-9]+)/(.*)"); + private static final Pattern PATTERN_FILES_METADATA = Pattern.compile("^/v1/metadata/files/([a-zA-Z0-9]+)/(.*)"); - private static final Pattern PATTERN_FILES_METADATA = Pattern.compile("/v1/metadata/files/([a-zA-Z0-9]+)/(.*)"); + private static final Pattern PATTERN_RESOURCE = Pattern.compile("^/v1/(conversations|prompts)/([a-zA-Z0-9]+)/(.*)"); + private static final Pattern PATTERN_RESOURCE_METADATA = Pattern.compile("^/v1/metadata/(conversations|prompts)/([a-zA-Z0-9]+)/(.*)"); - private static final Pattern PATTERN_RESOURCE = Pattern.compile("/v1/(conversations|prompts)/([a-zA-Z0-9]+)/(.*)"); - private static final Pattern PATTERN_RESOURCE_METADATA = Pattern.compile("/v1/metadata/(conversations|prompts)/([a-zA-Z0-9]+)/(.*)"); - - private static final Pattern PATTERN_RATE_RESPONSE = Pattern.compile("/+v1/([-.@a-zA-Z0-9]+)/rate"); - private static final Pattern PATTERN_TOKENIZE = Pattern.compile("/+v1/deployments/([-.@a-zA-Z0-9]+)/tokenize"); - private static final Pattern PATTERN_TRUNCATE_PROMPT = Pattern.compile("/+v1/deployments/([-.@a-zA-Z0-9]+)/truncate_prompt"); + private static final Pattern PATTERN_RATE_RESPONSE = Pattern.compile("^/+v1/([^/]+)/rate$"); + private static final Pattern PATTERN_TOKENIZE = Pattern.compile("^/+v1/deployments/([^/]+)/tokenize$"); + private static final Pattern PATTERN_TRUNCATE_PROMPT = Pattern.compile("^/+v1/deployments/([^/]+)/truncate_prompt$"); public Controller select(Proxy proxy, ProxyContext context) { String path = context.getRequest().path(); - String decodedPath = URLDecoder.decode(path, StandardCharsets.UTF_8); HttpMethod method = context.getRequest().method(); Controller controller = null; if (method == HttpMethod.GET) { - controller = selectGet(proxy, context, path, decodedPath); + controller = selectGet(proxy, context, path); } else if (method == HttpMethod.POST) { - controller = selectPost(proxy, context, decodedPath); + controller = selectPost(proxy, context, path); } else if (method == HttpMethod.DELETE) { controller = selectDelete(proxy, context, path); } else if (method == HttpMethod.PUT) { @@ -66,69 +63,69 @@ public Controller select(Proxy proxy, ProxyContext context) { return (controller == null) ? new RouteController(proxy, context) : controller; } - private static Controller selectGet(Proxy proxy, ProxyContext context, String path, String decodedPath) { + private static Controller selectGet(Proxy proxy, ProxyContext context, String path) { Matcher match; - match = match(PATTERN_DEPLOYMENT, decodedPath); + match = match(PATTERN_DEPLOYMENT, path); if (match != null) { DeploymentController controller = new DeploymentController(context); - String deploymentId = match.group(1); + String deploymentId = UrlUtil.decodePath(match.group(1)); return () -> controller.getDeployment(deploymentId); } - match = match(PATTERN_DEPLOYMENTS, decodedPath); + match = match(PATTERN_DEPLOYMENTS, path); if (match != null) { DeploymentController controller = new DeploymentController(context); return controller::getDeployments; } - match = match(PATTERN_MODEL, decodedPath); + match = match(PATTERN_MODEL, path); if (match != null) { ModelController controller = new ModelController(context); - String modelId = match.group(1); + String modelId = UrlUtil.decodePath(match.group(1)); return () -> controller.getModel(modelId); } - match = match(PATTERN_MODELS, decodedPath); + match = match(PATTERN_MODELS, path); if (match != null) { ModelController controller = new ModelController(context); return controller::getModels; } - match = match(PATTERN_ADDON, decodedPath); + match = match(PATTERN_ADDON, path); if (match != null) { AddonController controller = new AddonController(context); - String addonId = match.group(1); + String addonId = UrlUtil.decodePath(match.group(1)); return () -> controller.getAddon(addonId); } - match = match(PATTERN_ADDONS, decodedPath); + match = match(PATTERN_ADDONS, path); if (match != null) { AddonController controller = new AddonController(context); return controller::getAddons; } - match = match(PATTERN_ASSISTANT, decodedPath); + match = match(PATTERN_ASSISTANT, path); if (match != null) { AssistantController controller = new AssistantController(context); - String assistantId = match.group(1); + String assistantId = UrlUtil.decodePath(match.group(1)); return () -> controller.getAssistant(assistantId); } - match = match(PATTERN_ASSISTANTS, decodedPath); + match = match(PATTERN_ASSISTANTS, path); if (match != null) { AssistantController controller = new AssistantController(context); return controller::getAssistants; } - match = match(PATTERN_APPLICATION, decodedPath); + match = match(PATTERN_APPLICATION, path); if (match != null) { ApplicationController controller = new ApplicationController(context); - String application = match.group(1); + String application = UrlUtil.decodePath(match.group(1)); return () -> controller.getApplication(application); } - match = match(PATTERN_APPLICATIONS, decodedPath); + match = match(PATTERN_APPLICATIONS, path); if (match != null) { ApplicationController controller = new ApplicationController(context); return controller::getApplications; @@ -168,7 +165,7 @@ private static Controller selectGet(Proxy proxy, ProxyContext context, String pa return () -> controller.handle(resource, bucket, relativePath); } - match = match(PATTERN_BUCKET, decodedPath); + match = match(PATTERN_BUCKET, path); if (match != null) { BucketController controller = new BucketController(proxy, context); return controller::getBucket; @@ -180,22 +177,20 @@ private static Controller selectGet(Proxy proxy, ProxyContext context, String pa private static Controller selectPost(Proxy proxy, ProxyContext context, String path) { Matcher match = match(PATTERN_POST_DEPLOYMENT, path); if (match != null) { - String deploymentId = match.group(1); - String deploymentApi = match.group(2); + String deploymentId = UrlUtil.decodePath(match.group(1)); + String deploymentApi = UrlUtil.decodePath(match.group(2)); DeploymentPostController controller = new DeploymentPostController(proxy, context); return () -> controller.handle(deploymentId, deploymentApi); } match = match(PATTERN_RATE_RESPONSE, path); if (match != null) { - String deploymentId = match.group(1); + String deploymentId = UrlUtil.decodePath(match.group(1)); - Function getter = (model) -> { - return Optional.ofNullable(model) - .map(d -> d.getFeatures()) - .map(t -> t.getRateEndpoint()) - .orElse(null); - }; + Function getter = (model) -> Optional.ofNullable(model) + .map(Deployment::getFeatures) + .map(Features::getRateEndpoint) + .orElse(null); DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context); return () -> controller.handle(deploymentId, getter, false); @@ -203,14 +198,12 @@ private static Controller selectPost(Proxy proxy, ProxyContext context, String p match = match(PATTERN_TOKENIZE, path); if (match != null) { - String deploymentId = match.group(1); + String deploymentId = UrlUtil.decodePath(match.group(1)); - Function getter = (model) -> { - return Optional.ofNullable(model) - .map(d -> d.getFeatures()) - .map(t -> t.getTokenizeEndpoint()) - .orElse(null); - }; + Function getter = (model) -> Optional.ofNullable(model) + .map(Deployment::getFeatures) + .map(Features::getTokenizeEndpoint) + .orElse(null); DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context); return () -> controller.handle(deploymentId, getter, true); @@ -218,14 +211,12 @@ private static Controller selectPost(Proxy proxy, ProxyContext context, String p match = match(PATTERN_TRUNCATE_PROMPT, path); if (match != null) { - String deploymentId = match.group(1); - - Function getter = (model) -> { - return Optional.ofNullable(model) - .map(Deployment::getFeatures) - .map(Features::getTruncatePromptEndpoint) - .orElse(null); - }; + String deploymentId = UrlUtil.decodePath(match.group(1)); + + Function getter = (model) -> Optional.ofNullable(model) + .map(Deployment::getFeatures) + .map(Features::getTruncatePromptEndpoint) + .orElse(null); DeploymentFeatureController controller = new DeploymentFeatureController(proxy, context); return () -> controller.handle(deploymentId, getter, true); diff --git a/src/test/java/com/epam/aidial/core/controller/ControllerSelectorTest.java b/src/test/java/com/epam/aidial/core/controller/ControllerSelectorTest.java index a4cca7b6e..c77e2df07 100644 --- a/src/test/java/com/epam/aidial/core/controller/ControllerSelectorTest.java +++ b/src/test/java/com/epam/aidial/core/controller/ControllerSelectorTest.java @@ -381,6 +381,36 @@ public void testSelectRouteController() { assertInstanceOf(RouteController.class, controller); } + @Test + void testSelectDeploymentWithSpecialName() { + when(request.path()).thenReturn("/openai/deployments/deployment_x-y%2B%2F"); + when(request.method()).thenReturn(HttpMethod.GET); + Controller controller = ControllerSelector.select(proxy, context); + assertNotNull(controller); + SerializedLambda lambda = getSerializedLambda(controller); + assertNotNull(lambda); + Object arg1 = lambda.getCapturedArg(0); + Object arg2 = lambda.getCapturedArg(1); + assertInstanceOf(DeploymentController.class, arg1); + assertEquals("deployment_x-y+/", arg2); + } + + @Test + void testFailDeploymentWithSlash() { + when(request.path()).thenReturn("/openai/deployments/deployment/xy"); + when(request.method()).thenReturn(HttpMethod.GET); + Controller controller = ControllerSelector.select(proxy, context); + assertInstanceOf(RouteController.class, controller); + } + + @Test + void testFailDeploymentWithBadPrefix() { + when(request.path()).thenReturn("/prefix/openai/deployments/deployment"); + when(request.method()).thenReturn(HttpMethod.GET); + Controller controller = ControllerSelector.select(proxy, context); + assertInstanceOf(RouteController.class, controller); + } + @Nullable private static SerializedLambda getSerializedLambda(Serializable lambda) { for (Class cl = lambda.getClass(); cl != null; cl = cl.getSuperclass()) {