From 6796c76b310e711ef24b77603bf4fdc645bcc0f7 Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Fri, 10 Jan 2025 15:17:52 +0000 Subject: [PATCH 1/4] fix: propagate headers to feature endpoints --- .../DeploymentFeatureController.java | 37 ++++++++++++++++--- .../aidial/core/server/FeaturesApiTest.java | 24 ++++++++++-- .../aidial/core/server/ResourceBaseTest.java | 7 +++- .../aidial/core/server/TestWebServer.java | 11 ++++-- 4 files changed, 66 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java index 08023d930..eef394b6e 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java @@ -1,14 +1,18 @@ package com.epam.aidial.core.server.controller; import com.epam.aidial.core.config.Deployment; +import com.epam.aidial.core.config.Model; +import com.epam.aidial.core.config.Upstream; import com.epam.aidial.core.server.Proxy; import com.epam.aidial.core.server.ProxyContext; +import com.epam.aidial.core.server.data.ApiKeyData; import com.epam.aidial.core.server.service.PermissionDeniedException; import com.epam.aidial.core.server.service.ResourceNotFoundException; import com.epam.aidial.core.server.util.ProxyUtil; import com.epam.aidial.core.server.vertx.stream.BufferingReadStream; import com.epam.aidial.core.storage.http.HttpStatus; import io.vertx.core.Future; +import io.vertx.core.MultiMap; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpClientRequest; import io.vertx.core.http.HttpClientResponse; @@ -87,18 +91,39 @@ private void handleRequestError(String deploymentId, Throwable error) { /** * Called when proxy connected to the origin. */ - private void handleProxyRequest(HttpClientRequest proxyRequest) { - log.info("Connected to origin: {}", proxyRequest.connection().remoteAddress()); + void handleProxyRequest(HttpClientRequest proxyRequest) { + log.info("Connected to origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}", + context.getTraceId(), context.getSpanId(), + context.getProject(), context.getDeployment().getName(), + proxyRequest.connection().remoteAddress()); HttpServerRequest request = context.getRequest(); context.setProxyRequest(proxyRequest); + context.setProxyConnectTimestamp(System.currentTimeMillis()); - ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers()); + Deployment deployment = context.getDeployment(); + MultiMap excludeHeaders = MultiMap.caseInsensitiveMultiMap(); + if (!deployment.isForwardAuthToken()) { + excludeHeaders.add(HttpHeaders.AUTHORIZATION, "whatever"); + } + + ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers(), excludeHeaders); + + ApiKeyData proxyApiKeyData = context.getProxyApiKeyData(); + proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey()); + + if (context.getDeployment() instanceof Model model && !model.getUpstreams().isEmpty()) { + Upstream upstream = context.getUpstreamRoute().get(); + proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_ENDPOINT, upstream.getEndpoint()); + proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_KEY, upstream.getKey()); + proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_EXTRA_DATA, upstream.getExtraData()); + } - Buffer proxyRequestBody = context.getRequestBody(); - proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(proxyRequestBody.length())); + Buffer requestBody = context.getRequestBody(); + proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(requestBody.length())); + context.getRequestHeaders().forEach(proxyRequest::putHeader); - proxyRequest.send(proxyRequestBody) + proxyRequest.send(requestBody) .onSuccess(this::handleProxyResponse) .onFailure(this::handleProxyRequestError); } diff --git a/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java b/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java index 75245ee67..6dd0623bb 100644 --- a/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java @@ -2,9 +2,14 @@ import io.vertx.core.http.HttpMethod; import lombok.SneakyThrows; +import okhttp3.Headers; + import org.junit.jupiter.api.Test; import java.net.URI; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; public class FeaturesApiTest extends ResourceBaseTest { @@ -66,9 +71,22 @@ void testUpstreamEndpoint(String inboundPath, String upstream) { void testUpstreamEndpoint(String inboundPath, String upstream, HttpMethod method) { URI uri = URI.create(upstream); try (TestWebServer server = new TestWebServer(uri.getPort())) { - server.map(method, uri.getPath(), 200, "PONG"); - Response response = send(method, inboundPath); - verify(response, 200, "PONG"); + server.map(method, uri.getPath(), request -> TestWebServer.createResponse(200, "PONG", convertHeadersToFlatArray(request.getHeaders()))); + + Response response = send(method, inboundPath, null, "", "foo", "bar"); + verify(response, 200, "PONG", "foo", "bar"); + } + } + + private static String[] convertHeadersToFlatArray(Headers headers) { + List flatHeadersList = new ArrayList<>(); + for (Map.Entry> entry : headers.toMultimap().entrySet()) { + String key = entry.getKey(); + for (String value : entry.getValue()) { + flatHeadersList.add(key); + flatHeadersList.add(value); + } } + return flatHeadersList.toArray(new String[0]); } } diff --git a/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java b/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java index 6e2e6ad1e..7419f2a14 100644 --- a/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java @@ -226,9 +226,14 @@ static void verify(Response response, int status) { assertEquals(status, response.status(), () -> "Actual response body: " + response.body()); } - static void verify(Response response, int status, String body) { + static void verify(Response response, int status, String body, String... headers) { assertEquals(status, response.status(), () -> "Actual response body: " + response.body()); assertEquals(body, response.body()); + for (int i = 0; i < headers.length; i += 2) { + String key = headers[i]; + String value = headers[i + 1]; + assertEquals(value, response.headers.get(key)); + } } static void verifyJson(Response response, int status, String body) { diff --git a/server/src/test/java/com/epam/aidial/core/server/TestWebServer.java b/server/src/test/java/com/epam/aidial/core/server/TestWebServer.java index a8fbdbc74..682d0985f 100644 --- a/server/src/test/java/com/epam/aidial/core/server/TestWebServer.java +++ b/server/src/test/java/com/epam/aidial/core/server/TestWebServer.java @@ -51,8 +51,8 @@ public void map(HttpMethod method, String path, int status) { map(method, path, status, ""); } - public void map(HttpMethod method, String path, int status, String body) { - map(method, path, createResponse(status, body)); + public void map(HttpMethod method, String path, int status, String body, String... headers) { + map(method, path, createResponse(status, body, headers)); } private MockResponse onRequest(RecordedRequest request) { @@ -67,10 +67,15 @@ private MockResponse onRequest(RecordedRequest request) { return response; } - private static MockResponse createResponse(int status, String body) { + public static MockResponse createResponse(int status, String body, String... headers) { MockResponse response = new MockResponse(); response.setResponseCode(status); response.setBody(body); + for (int i = 0; i < headers.length; i += 2) { + String key = headers[i]; + String value = headers[i + 1]; + response.setHeader(key, value); + } return response; } From 8c35face6929ac35dea1648d996518d6256c2273 Mon Sep 17 00:00:00 2001 From: Aliaksandr Stsiapanay Date: Fri, 10 Jan 2025 19:56:16 +0300 Subject: [PATCH 2/4] fix: setup proxy API key data --- .../DeploymentFeatureController.java | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java index eef394b6e..8b9829792 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java @@ -66,6 +66,24 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer return; } + ApiKeyData proxyApiKeyData = new ApiKeyData(); + setupProxyApiKeyData(proxyApiKeyData); + + proxy.getVertx().executeBlocking(() -> { + proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData); + return null; + }, false) + .onSuccess(ignore -> sendRequest(endpoint)).onFailure(this::handleError); + + } + + private void handleError(Throwable error) { + log.info("Error occurred while processing request", error); + context.respond(HttpStatus.INTERNAL_SERVER_ERROR, error.getMessage()); + } + + @SneakyThrows + private void sendRequest(String endpoint) { RequestOptions options = new RequestOptions() .setAbsoluteURI(new URL(endpoint)) .setMethod(context.getRequest().method()); @@ -75,6 +93,11 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer .onFailure(this::handleProxyConnectionError); } + private void setupProxyApiKeyData(ApiKeyData proxyApiKeyData) { + context.setProxyApiKeyData(proxyApiKeyData); + ApiKeyData.initFromContext(proxyApiKeyData, context); + } + private void handleRequestError(String deploymentId, Throwable error) { if (error instanceof PermissionDeniedException) { log.error("Forbidden deployment {}. Project: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub()); @@ -112,13 +135,6 @@ void handleProxyRequest(HttpClientRequest proxyRequest) { ApiKeyData proxyApiKeyData = context.getProxyApiKeyData(); proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey()); - if (context.getDeployment() instanceof Model model && !model.getUpstreams().isEmpty()) { - Upstream upstream = context.getUpstreamRoute().get(); - proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_ENDPOINT, upstream.getEndpoint()); - proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_KEY, upstream.getKey()); - proxyRequest.putHeader(Proxy.HEADER_UPSTREAM_EXTRA_DATA, upstream.getExtraData()); - } - Buffer requestBody = context.getRequestBody(); proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(requestBody.length())); context.getRequestHeaders().forEach(proxyRequest::putHeader); From f6972e7408cd0679267dc8e2c60176faf458f659 Mon Sep 17 00:00:00 2001 From: Aliaksandr Stsiapanay Date: Sat, 11 Jan 2025 12:16:48 +0300 Subject: [PATCH 3/4] fix: invalidate per request api key --- .../DeploymentFeatureController.java | 42 ++++++++++++++----- .../aidial/core/server/FeaturesApiTest.java | 19 +-------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java index 8b9829792..36617c3c6 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java @@ -1,8 +1,6 @@ package com.epam.aidial.core.server.controller; import com.epam.aidial.core.config.Deployment; -import com.epam.aidial.core.config.Model; -import com.epam.aidial.core.config.Upstream; import com.epam.aidial.core.server.Proxy; import com.epam.aidial.core.server.ProxyContext; import com.epam.aidial.core.server.data.ApiKeyData; @@ -58,9 +56,9 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer if (endpoint == null) { if (requireEndpoint) { - context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment"); + respond(HttpStatus.FORBIDDEN, "Forbidden deployment"); } else { - context.respond(HttpStatus.OK); + respond(HttpStatus.OK); proxy.getLogStore().save(context); } return; @@ -79,7 +77,7 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer private void handleError(Throwable error) { log.info("Error occurred while processing request", error); - context.respond(HttpStatus.INTERNAL_SERVER_ERROR, error.getMessage()); + respond(HttpStatus.INTERNAL_SERVER_ERROR, error.getMessage()); } @SneakyThrows @@ -101,13 +99,13 @@ private void setupProxyApiKeyData(ApiKeyData proxyApiKeyData) { private void handleRequestError(String deploymentId, Throwable error) { if (error instanceof PermissionDeniedException) { log.error("Forbidden deployment {}. Project: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub()); - context.respond(HttpStatus.FORBIDDEN, error.getMessage()); + respond(HttpStatus.FORBIDDEN, error.getMessage()); } else if (error instanceof ResourceNotFoundException) { log.error("Deployment not found {}", deploymentId, error); - context.respond(HttpStatus.NOT_FOUND, error.getMessage()); + respond(HttpStatus.NOT_FOUND, error.getMessage()); } else { log.error("Failed to handle deployment {}", deploymentId, error); - context.respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process deployment: " + deploymentId); + respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process deployment: " + deploymentId); } } @@ -183,7 +181,7 @@ private void handleResponse() { */ private void handleRequestBodyError(Throwable error) { log.warn("Failed to receive client body: {}", error.getMessage()); - context.respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body"); + respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body"); } /** @@ -191,7 +189,7 @@ private void handleRequestBodyError(Throwable error) { */ private void handleProxyConnectionError(Throwable error) { log.warn("Can't connect to origin: {}", error.getMessage()); - context.respond(HttpStatus.BAD_GATEWAY, "connection error to origin"); + respond(HttpStatus.BAD_GATEWAY, "connection error to origin"); } /** @@ -199,7 +197,7 @@ private void handleProxyConnectionError(Throwable error) { */ private void handleProxyRequestError(Throwable error) { log.warn("Can't send request to origin: {}", error.getMessage()); - context.respond(HttpStatus.BAD_GATEWAY, "deployment responded with error"); + respond(HttpStatus.BAD_GATEWAY, "deployment responded with error"); } /** @@ -210,4 +208,26 @@ private void handleResponseError(Throwable error) { context.getProxyRequest().reset(); // drop connection to stop origin response context.getResponse().reset(); // drop connection, so that partial client response won't seem complete } + + private void respond(HttpStatus status, String errorMessage) { + finalizeRequest(); + context.respond(status, errorMessage); + } + + private void respond(HttpStatus status) { + finalizeRequest(); + context.respond(status); + } + + private void finalizeRequest() { + ApiKeyData proxyApiKeyData = context.getProxyApiKeyData(); + if (proxyApiKeyData != null) { + proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData) + .onSuccess(invalidated -> { + if (!invalidated) { + log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey()); + } + }).onFailure(error -> log.error("error occurred on invalidating per-request key", error)); + } + } } diff --git a/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java b/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java index 6dd0623bb..ece981024 100644 --- a/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java @@ -2,14 +2,9 @@ import io.vertx.core.http.HttpMethod; import lombok.SneakyThrows; -import okhttp3.Headers; - import org.junit.jupiter.api.Test; import java.net.URI; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; public class FeaturesApiTest extends ResourceBaseTest { @@ -71,22 +66,10 @@ void testUpstreamEndpoint(String inboundPath, String upstream) { void testUpstreamEndpoint(String inboundPath, String upstream, HttpMethod method) { URI uri = URI.create(upstream); try (TestWebServer server = new TestWebServer(uri.getPort())) { - server.map(method, uri.getPath(), request -> TestWebServer.createResponse(200, "PONG", convertHeadersToFlatArray(request.getHeaders()))); + server.map(method, uri.getPath(), request -> TestWebServer.createResponse(200, "PONG", "foo", "bar")); Response response = send(method, inboundPath, null, "", "foo", "bar"); verify(response, 200, "PONG", "foo", "bar"); } } - - private static String[] convertHeadersToFlatArray(Headers headers) { - List flatHeadersList = new ArrayList<>(); - for (Map.Entry> entry : headers.toMultimap().entrySet()) { - String key = entry.getKey(); - for (String value : entry.getValue()) { - flatHeadersList.add(key); - flatHeadersList.add(value); - } - } - return flatHeadersList.toArray(new String[0]); - } } From 5e9bf2763a37e5d364cebdef048e4dfb2d8d160d Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Mon, 13 Jan 2025 11:04:40 +0000 Subject: [PATCH 4/4] feat: improved tests --- .../DeploymentFeatureController.java | 2 +- .../aidial/core/server/FeaturesApiTest.java | 36 +++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java index 36617c3c6..04579226c 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java @@ -76,7 +76,7 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer } private void handleError(Throwable error) { - log.info("Error occurred while processing request", error); + log.error("Error occurred while processing request", error); respond(HttpStatus.INTERNAL_SERVER_ERROR, error.getMessage()); } diff --git a/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java b/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java index ece981024..e1af4d89c 100644 --- a/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java @@ -2,12 +2,36 @@ import io.vertx.core.http.HttpMethod; import lombok.SneakyThrows; +import okhttp3.Headers; import org.junit.jupiter.api.Test; import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; public class FeaturesApiTest extends ResourceBaseTest { + private static String[] convertHeadersToFlatArray(Headers headers) { + return StreamSupport.stream(headers.spliterator(), false) + .flatMap(header -> Stream.of(header.getFirst(), header.getSecond())) + .toArray(String[]::new); + } + + private static Headers filterHeaders(Headers headers, Headers mask) { + Headers.Builder filteredHeaders = new Headers.Builder(); + for (Map.Entry> entry : headers.toMultimap().entrySet()) { + String key = entry.getKey(); + if (mask.names().contains(key.toLowerCase())) { + for (String value : entry.getValue()) { + filteredHeaders.add(key, value); + } + } + } + return filteredHeaders.build(); + } + @Test void testRateEndpointModel() { String inboundPath = "/v1/chat-gpt-35-turbo/rate"; @@ -64,12 +88,18 @@ void testUpstreamEndpoint(String inboundPath, String upstream) { @SneakyThrows void testUpstreamEndpoint(String inboundPath, String upstream, HttpMethod method) { + Headers requestExtraHeaders = new Headers.Builder().add("foo", "bar").build(); + String[] requestExtraHeadersArray = convertHeadersToFlatArray(requestExtraHeaders); + URI uri = URI.create(upstream); try (TestWebServer server = new TestWebServer(uri.getPort())) { - server.map(method, uri.getPath(), request -> TestWebServer.createResponse(200, "PONG", "foo", "bar")); + server.map(method, uri.getPath(), request -> { + Headers responseHeaders = filterHeaders(request.getHeaders(), requestExtraHeaders); + return TestWebServer.createResponse(200, "PONG", convertHeadersToFlatArray(responseHeaders)); + }); - Response response = send(method, inboundPath, null, "", "foo", "bar"); - verify(response, 200, "PONG", "foo", "bar"); + Response response = send(method, inboundPath, null, "PING", requestExtraHeadersArray); + verify(response, 200, "PONG", requestExtraHeadersArray); } } }