Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: propagate headers to feature endpoints #635

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.server.Proxy;
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.ApiKeyData;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.service.ResourceNotFoundException;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.BufferingReadStream;
import com.epam.aidial.core.storage.http.HttpStatus;
import io.vertx.core.Future;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;
Expand Down Expand Up @@ -54,14 +56,32 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer

if (endpoint == null) {
if (requireEndpoint) {
context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
} else {
context.respond(HttpStatus.OK);
respond(HttpStatus.OK);
proxy.getLogStore().save(context);
}
return;
}

ApiKeyData proxyApiKeyData = new ApiKeyData();
setupProxyApiKeyData(proxyApiKeyData);

proxy.getVertx().executeBlocking(() -> {
proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData);
return null;
}, false)
.onSuccess(ignore -> sendRequest(endpoint)).onFailure(this::handleError);

}

private void handleError(Throwable error) {
log.error("Error occurred while processing request", error);
respond(HttpStatus.INTERNAL_SERVER_ERROR, error.getMessage());
}

@SneakyThrows
private void sendRequest(String endpoint) {
RequestOptions options = new RequestOptions()
.setAbsoluteURI(new URL(endpoint))
.setMethod(context.getRequest().method());
Expand All @@ -71,34 +91,53 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer
.onFailure(this::handleProxyConnectionError);
}

private void setupProxyApiKeyData(ApiKeyData proxyApiKeyData) {
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
}

private void handleRequestError(String deploymentId, Throwable error) {
if (error instanceof PermissionDeniedException) {
log.error("Forbidden deployment {}. Project: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
context.respond(HttpStatus.FORBIDDEN, error.getMessage());
respond(HttpStatus.FORBIDDEN, error.getMessage());
} else if (error instanceof ResourceNotFoundException) {
log.error("Deployment not found {}", deploymentId, error);
context.respond(HttpStatus.NOT_FOUND, error.getMessage());
respond(HttpStatus.NOT_FOUND, error.getMessage());
} else {
log.error("Failed to handle deployment {}", deploymentId, error);
context.respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process deployment: " + deploymentId);
respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process deployment: " + deploymentId);
}
}

/**
* Called when proxy connected to the origin.
*/
private void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to origin: {}", proxyRequest.connection().remoteAddress());
void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
proxyRequest.connection().remoteAddress());

HttpServerRequest request = context.getRequest();
context.setProxyRequest(proxyRequest);
context.setProxyConnectTimestamp(System.currentTimeMillis());

ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers());
Deployment deployment = context.getDeployment();
MultiMap excludeHeaders = MultiMap.caseInsensitiveMultiMap();
if (!deployment.isForwardAuthToken()) {
excludeHeaders.add(HttpHeaders.AUTHORIZATION, "whatever");
}

ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers(), excludeHeaders);

ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey());

Buffer proxyRequestBody = context.getRequestBody();
proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(proxyRequestBody.length()));
Buffer requestBody = context.getRequestBody();
proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(requestBody.length()));
context.getRequestHeaders().forEach(proxyRequest::putHeader);

proxyRequest.send(proxyRequestBody)
proxyRequest.send(requestBody)
.onSuccess(this::handleProxyResponse)
.onFailure(this::handleProxyRequestError);
}
Expand Down Expand Up @@ -142,23 +181,23 @@ private void handleResponse() {
*/
private void handleRequestBodyError(Throwable error) {
log.warn("Failed to receive client body: {}", error.getMessage());
context.respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body");
respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body");
}

/**
* Called when proxy failed to connect to the origin.
*/
private void handleProxyConnectionError(Throwable error) {
log.warn("Can't connect to origin: {}", error.getMessage());
context.respond(HttpStatus.BAD_GATEWAY, "connection error to origin");
respond(HttpStatus.BAD_GATEWAY, "connection error to origin");
}

/**
* Called when proxy failed to send request to the origin.
*/
private void handleProxyRequestError(Throwable error) {
log.warn("Can't send request to origin: {}", error.getMessage());
context.respond(HttpStatus.BAD_GATEWAY, "deployment responded with error");
respond(HttpStatus.BAD_GATEWAY, "deployment responded with error");
}

/**
Expand All @@ -169,4 +208,26 @@ private void handleResponseError(Throwable error) {
context.getProxyRequest().reset(); // drop connection to stop origin response
context.getResponse().reset(); // drop connection, so that partial client response won't seem complete
}

private void respond(HttpStatus status, String errorMessage) {
finalizeRequest();
context.respond(status, errorMessage);
}

private void respond(HttpStatus status) {
finalizeRequest();
context.respond(status);
}

private void finalizeRequest() {
ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
if (proxyApiKeyData != null) {
proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData)
.onSuccess(invalidated -> {
if (!invalidated) {
log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey());
}
}).onFailure(error -> log.error("error occurred on invalidating per-request key", error));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,36 @@

import io.vertx.core.http.HttpMethod;
import lombok.SneakyThrows;
import okhttp3.Headers;
import org.junit.jupiter.api.Test;

import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class FeaturesApiTest extends ResourceBaseTest {

private static String[] convertHeadersToFlatArray(Headers headers) {
return StreamSupport.stream(headers.spliterator(), false)
.flatMap(header -> Stream.of(header.getFirst(), header.getSecond()))
.toArray(String[]::new);
}

private static Headers filterHeaders(Headers headers, Headers mask) {
Headers.Builder filteredHeaders = new Headers.Builder();
for (Map.Entry<String, List<String>> entry : headers.toMultimap().entrySet()) {
String key = entry.getKey();
if (mask.names().contains(key.toLowerCase())) {
for (String value : entry.getValue()) {
filteredHeaders.add(key, value);
}
}
}
return filteredHeaders.build();
}

@Test
void testRateEndpointModel() {
String inboundPath = "/v1/chat-gpt-35-turbo/rate";
Expand Down Expand Up @@ -64,11 +88,18 @@ void testUpstreamEndpoint(String inboundPath, String upstream) {

@SneakyThrows
void testUpstreamEndpoint(String inboundPath, String upstream, HttpMethod method) {
Headers requestExtraHeaders = new Headers.Builder().add("foo", "bar").build();
String[] requestExtraHeadersArray = convertHeadersToFlatArray(requestExtraHeaders);

URI uri = URI.create(upstream);
try (TestWebServer server = new TestWebServer(uri.getPort())) {
server.map(method, uri.getPath(), 200, "PONG");
Response response = send(method, inboundPath);
verify(response, 200, "PONG");
server.map(method, uri.getPath(), request -> {
Headers responseHeaders = filterHeaders(request.getHeaders(), requestExtraHeaders);
return TestWebServer.createResponse(200, "PONG", convertHeadersToFlatArray(responseHeaders));
});

Response response = send(method, inboundPath, null, "", requestExtraHeadersArray);
verify(response, 200, "PONG", requestExtraHeadersArray);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,14 @@ static void verify(Response response, int status) {
assertEquals(status, response.status(), () -> "Actual response body: " + response.body());
}

static void verify(Response response, int status, String body) {
static void verify(Response response, int status, String body, String... headers) {
assertEquals(status, response.status(), () -> "Actual response body: " + response.body());
assertEquals(body, response.body());
for (int i = 0; i < headers.length; i += 2) {
String key = headers[i];
String value = headers[i + 1];
assertEquals(value, response.headers.get(key));
}
}

static void verifyJson(Response response, int status, String body) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public void map(HttpMethod method, String path, int status) {
map(method, path, status, "");
}

public void map(HttpMethod method, String path, int status, String body) {
map(method, path, createResponse(status, body));
public void map(HttpMethod method, String path, int status, String body, String... headers) {
map(method, path, createResponse(status, body, headers));
}

private MockResponse onRequest(RecordedRequest request) {
Expand All @@ -67,10 +67,15 @@ private MockResponse onRequest(RecordedRequest request) {
return response;
}

private static MockResponse createResponse(int status, String body) {
public static MockResponse createResponse(int status, String body, String... headers) {
MockResponse response = new MockResponse();
response.setResponseCode(status);
response.setBody(body);
for (int i = 0; i < headers.length; i += 2) {
String key = headers[i];
String value = headers[i + 1];
response.setHeader(key, value);
}
return response;
}

Expand Down
Loading