diff --git a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java index 08023d93..04579226 100644 --- a/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java +++ b/server/src/main/java/com/epam/aidial/core/server/controller/DeploymentFeatureController.java @@ -3,12 +3,14 @@ import com.epam.aidial.core.config.Deployment; import com.epam.aidial.core.server.Proxy; import com.epam.aidial.core.server.ProxyContext; +import com.epam.aidial.core.server.data.ApiKeyData; import com.epam.aidial.core.server.service.PermissionDeniedException; import com.epam.aidial.core.server.service.ResourceNotFoundException; import com.epam.aidial.core.server.util.ProxyUtil; import com.epam.aidial.core.server.vertx.stream.BufferingReadStream; import com.epam.aidial.core.storage.http.HttpStatus; import io.vertx.core.Future; +import io.vertx.core.MultiMap; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpClientRequest; import io.vertx.core.http.HttpClientResponse; @@ -54,14 +56,32 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer if (endpoint == null) { if (requireEndpoint) { - context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment"); + respond(HttpStatus.FORBIDDEN, "Forbidden deployment"); } else { - context.respond(HttpStatus.OK); + respond(HttpStatus.OK); proxy.getLogStore().save(context); } return; } + ApiKeyData proxyApiKeyData = new ApiKeyData(); + setupProxyApiKeyData(proxyApiKeyData); + + proxy.getVertx().executeBlocking(() -> { + proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData); + return null; + }, false) + .onSuccess(ignore -> sendRequest(endpoint)).onFailure(this::handleError); + + } + + private void handleError(Throwable error) { + log.error("Error occurred while processing request", error); + respond(HttpStatus.INTERNAL_SERVER_ERROR, error.getMessage()); + } + + @SneakyThrows + private void sendRequest(String endpoint) { RequestOptions options = new RequestOptions() .setAbsoluteURI(new URL(endpoint)) .setMethod(context.getRequest().method()); @@ -71,34 +91,53 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer .onFailure(this::handleProxyConnectionError); } + private void setupProxyApiKeyData(ApiKeyData proxyApiKeyData) { + context.setProxyApiKeyData(proxyApiKeyData); + ApiKeyData.initFromContext(proxyApiKeyData, context); + } + private void handleRequestError(String deploymentId, Throwable error) { if (error instanceof PermissionDeniedException) { log.error("Forbidden deployment {}. Project: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub()); - context.respond(HttpStatus.FORBIDDEN, error.getMessage()); + respond(HttpStatus.FORBIDDEN, error.getMessage()); } else if (error instanceof ResourceNotFoundException) { log.error("Deployment not found {}", deploymentId, error); - context.respond(HttpStatus.NOT_FOUND, error.getMessage()); + respond(HttpStatus.NOT_FOUND, error.getMessage()); } else { log.error("Failed to handle deployment {}", deploymentId, error); - context.respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process deployment: " + deploymentId); + respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process deployment: " + deploymentId); } } /** * Called when proxy connected to the origin. */ - private void handleProxyRequest(HttpClientRequest proxyRequest) { - log.info("Connected to origin: {}", proxyRequest.connection().remoteAddress()); + void handleProxyRequest(HttpClientRequest proxyRequest) { + log.info("Connected to origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}", + context.getTraceId(), context.getSpanId(), + context.getProject(), context.getDeployment().getName(), + proxyRequest.connection().remoteAddress()); HttpServerRequest request = context.getRequest(); context.setProxyRequest(proxyRequest); + context.setProxyConnectTimestamp(System.currentTimeMillis()); - ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers()); + Deployment deployment = context.getDeployment(); + MultiMap excludeHeaders = MultiMap.caseInsensitiveMultiMap(); + if (!deployment.isForwardAuthToken()) { + excludeHeaders.add(HttpHeaders.AUTHORIZATION, "whatever"); + } + + ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers(), excludeHeaders); + + ApiKeyData proxyApiKeyData = context.getProxyApiKeyData(); + proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey()); - Buffer proxyRequestBody = context.getRequestBody(); - proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(proxyRequestBody.length())); + Buffer requestBody = context.getRequestBody(); + proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(requestBody.length())); + context.getRequestHeaders().forEach(proxyRequest::putHeader); - proxyRequest.send(proxyRequestBody) + proxyRequest.send(requestBody) .onSuccess(this::handleProxyResponse) .onFailure(this::handleProxyRequestError); } @@ -142,7 +181,7 @@ private void handleResponse() { */ private void handleRequestBodyError(Throwable error) { log.warn("Failed to receive client body: {}", error.getMessage()); - context.respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body"); + respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body"); } /** @@ -150,7 +189,7 @@ private void handleRequestBodyError(Throwable error) { */ private void handleProxyConnectionError(Throwable error) { log.warn("Can't connect to origin: {}", error.getMessage()); - context.respond(HttpStatus.BAD_GATEWAY, "connection error to origin"); + respond(HttpStatus.BAD_GATEWAY, "connection error to origin"); } /** @@ -158,7 +197,7 @@ private void handleProxyConnectionError(Throwable error) { */ private void handleProxyRequestError(Throwable error) { log.warn("Can't send request to origin: {}", error.getMessage()); - context.respond(HttpStatus.BAD_GATEWAY, "deployment responded with error"); + respond(HttpStatus.BAD_GATEWAY, "deployment responded with error"); } /** @@ -169,4 +208,26 @@ private void handleResponseError(Throwable error) { context.getProxyRequest().reset(); // drop connection to stop origin response context.getResponse().reset(); // drop connection, so that partial client response won't seem complete } + + private void respond(HttpStatus status, String errorMessage) { + finalizeRequest(); + context.respond(status, errorMessage); + } + + private void respond(HttpStatus status) { + finalizeRequest(); + context.respond(status); + } + + private void finalizeRequest() { + ApiKeyData proxyApiKeyData = context.getProxyApiKeyData(); + if (proxyApiKeyData != null) { + proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData) + .onSuccess(invalidated -> { + if (!invalidated) { + log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey()); + } + }).onFailure(error -> log.error("error occurred on invalidating per-request key", error)); + } + } } diff --git a/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java b/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java index 75245ee6..3add7d25 100644 --- a/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/FeaturesApiTest.java @@ -2,12 +2,36 @@ import io.vertx.core.http.HttpMethod; import lombok.SneakyThrows; +import okhttp3.Headers; import org.junit.jupiter.api.Test; import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; public class FeaturesApiTest extends ResourceBaseTest { + private static String[] convertHeadersToFlatArray(Headers headers) { + return StreamSupport.stream(headers.spliterator(), false) + .flatMap(header -> Stream.of(header.getFirst(), header.getSecond())) + .toArray(String[]::new); + } + + private static Headers filterHeaders(Headers headers, Headers mask) { + Headers.Builder filteredHeaders = new Headers.Builder(); + for (Map.Entry> entry : headers.toMultimap().entrySet()) { + String key = entry.getKey(); + if (mask.names().contains(key.toLowerCase())) { + for (String value : entry.getValue()) { + filteredHeaders.add(key, value); + } + } + } + return filteredHeaders.build(); + } + @Test void testRateEndpointModel() { String inboundPath = "/v1/chat-gpt-35-turbo/rate"; @@ -64,11 +88,18 @@ void testUpstreamEndpoint(String inboundPath, String upstream) { @SneakyThrows void testUpstreamEndpoint(String inboundPath, String upstream, HttpMethod method) { + Headers requestExtraHeaders = new Headers.Builder().add("foo", "bar").build(); + String[] requestExtraHeadersArray = convertHeadersToFlatArray(requestExtraHeaders); + URI uri = URI.create(upstream); try (TestWebServer server = new TestWebServer(uri.getPort())) { - server.map(method, uri.getPath(), 200, "PONG"); - Response response = send(method, inboundPath); - verify(response, 200, "PONG"); + server.map(method, uri.getPath(), request -> { + Headers responseHeaders = filterHeaders(request.getHeaders(), requestExtraHeaders); + return TestWebServer.createResponse(200, "PONG", convertHeadersToFlatArray(responseHeaders)); + }); + + Response response = send(method, inboundPath, null, "", requestExtraHeadersArray); + verify(response, 200, "PONG", requestExtraHeadersArray); } } } diff --git a/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java b/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java index d459f85b..4376e29f 100644 --- a/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/ResourceBaseTest.java @@ -229,9 +229,14 @@ static void verify(Response response, int status) { assertEquals(status, response.status(), () -> "Actual response body: " + response.body()); } - static void verify(Response response, int status, String body) { + static void verify(Response response, int status, String body, String... headers) { assertEquals(status, response.status(), () -> "Actual response body: " + response.body()); assertEquals(body, response.body()); + for (int i = 0; i < headers.length; i += 2) { + String key = headers[i]; + String value = headers[i + 1]; + assertEquals(value, response.headers.get(key)); + } } static void verifyJson(Response response, int status, String body) { diff --git a/server/src/test/java/com/epam/aidial/core/server/TestWebServer.java b/server/src/test/java/com/epam/aidial/core/server/TestWebServer.java index a8fbdbc7..682d0985 100644 --- a/server/src/test/java/com/epam/aidial/core/server/TestWebServer.java +++ b/server/src/test/java/com/epam/aidial/core/server/TestWebServer.java @@ -51,8 +51,8 @@ public void map(HttpMethod method, String path, int status) { map(method, path, status, ""); } - public void map(HttpMethod method, String path, int status, String body) { - map(method, path, createResponse(status, body)); + public void map(HttpMethod method, String path, int status, String body, String... headers) { + map(method, path, createResponse(status, body, headers)); } private MockResponse onRequest(RecordedRequest request) { @@ -67,10 +67,15 @@ private MockResponse onRequest(RecordedRequest request) { return response; } - private static MockResponse createResponse(int status, String body) { + public static MockResponse createResponse(int status, String body, String... headers) { MockResponse response = new MockResponse(); response.setResponseCode(status); response.setBody(body); + for (int i = 0; i < headers.length; i += 2) { + String key = headers[i]; + String value = headers[i + 1]; + response.setHeader(key, value); + } return response; }