Skip to content

Commit

Permalink
Merge branch 'development' into feat/issue-637
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay authored Jan 13, 2025
2 parents 798e75b + 6c4e348 commit bd4fbdd
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 21 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/dependency-review.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: Dependency Review

on:
pull_request_target:
types:
- opened
- synchronize

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
cancel-in-progress: true

jobs:
dependency-review:
uses: epam/ai-dial-ci/.github/workflows/[email protected]
secrets:
ACTIONS_BOT_TOKEN: ${{ secrets.ACTIONS_BOT_TOKEN }}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.server.Proxy;
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.ApiKeyData;
import com.epam.aidial.core.server.service.PermissionDeniedException;
import com.epam.aidial.core.server.service.ResourceNotFoundException;
import com.epam.aidial.core.server.util.ProxyUtil;
import com.epam.aidial.core.server.vertx.stream.BufferingReadStream;
import com.epam.aidial.core.storage.http.HttpStatus;
import io.vertx.core.Future;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;
Expand Down Expand Up @@ -54,14 +56,32 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer

if (endpoint == null) {
if (requireEndpoint) {
context.respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
respond(HttpStatus.FORBIDDEN, "Forbidden deployment");
} else {
context.respond(HttpStatus.OK);
respond(HttpStatus.OK);
proxy.getLogStore().save(context);
}
return;
}

ApiKeyData proxyApiKeyData = new ApiKeyData();
setupProxyApiKeyData(proxyApiKeyData);

proxy.getVertx().executeBlocking(() -> {
proxy.getApiKeyStore().assignPerRequestApiKey(proxyApiKeyData);
return null;
}, false)
.onSuccess(ignore -> sendRequest(endpoint)).onFailure(this::handleError);

}

private void handleError(Throwable error) {
log.error("Error occurred while processing request", error);
respond(HttpStatus.INTERNAL_SERVER_ERROR, error.getMessage());
}

@SneakyThrows
private void sendRequest(String endpoint) {
RequestOptions options = new RequestOptions()
.setAbsoluteURI(new URL(endpoint))
.setMethod(context.getRequest().method());
Expand All @@ -71,34 +91,53 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer
.onFailure(this::handleProxyConnectionError);
}

private void setupProxyApiKeyData(ApiKeyData proxyApiKeyData) {
context.setProxyApiKeyData(proxyApiKeyData);
ApiKeyData.initFromContext(proxyApiKeyData, context);
}

private void handleRequestError(String deploymentId, Throwable error) {
if (error instanceof PermissionDeniedException) {
log.error("Forbidden deployment {}. Project: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
context.respond(HttpStatus.FORBIDDEN, error.getMessage());
respond(HttpStatus.FORBIDDEN, error.getMessage());
} else if (error instanceof ResourceNotFoundException) {
log.error("Deployment not found {}", deploymentId, error);
context.respond(HttpStatus.NOT_FOUND, error.getMessage());
respond(HttpStatus.NOT_FOUND, error.getMessage());
} else {
log.error("Failed to handle deployment {}", deploymentId, error);
context.respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process deployment: " + deploymentId);
respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process deployment: " + deploymentId);
}
}

/**
* Called when proxy connected to the origin.
*/
private void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to origin: {}", proxyRequest.connection().remoteAddress());
void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
proxyRequest.connection().remoteAddress());

HttpServerRequest request = context.getRequest();
context.setProxyRequest(proxyRequest);
context.setProxyConnectTimestamp(System.currentTimeMillis());

ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers());
Deployment deployment = context.getDeployment();
MultiMap excludeHeaders = MultiMap.caseInsensitiveMultiMap();
if (!deployment.isForwardAuthToken()) {
excludeHeaders.add(HttpHeaders.AUTHORIZATION, "whatever");
}

ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers(), excludeHeaders);

ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey());

Buffer proxyRequestBody = context.getRequestBody();
proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(proxyRequestBody.length()));
Buffer requestBody = context.getRequestBody();
proxyRequest.putHeader(HttpHeaders.CONTENT_LENGTH, Integer.toString(requestBody.length()));
context.getRequestHeaders().forEach(proxyRequest::putHeader);

proxyRequest.send(proxyRequestBody)
proxyRequest.send(requestBody)
.onSuccess(this::handleProxyResponse)
.onFailure(this::handleProxyRequestError);
}
Expand Down Expand Up @@ -142,23 +181,23 @@ private void handleResponse() {
*/
private void handleRequestBodyError(Throwable error) {
log.warn("Failed to receive client body: {}", error.getMessage());
context.respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body");
respond(HttpStatus.UNPROCESSABLE_ENTITY, "Failed to receive body");
}

/**
* Called when proxy failed to connect to the origin.
*/
private void handleProxyConnectionError(Throwable error) {
log.warn("Can't connect to origin: {}", error.getMessage());
context.respond(HttpStatus.BAD_GATEWAY, "connection error to origin");
respond(HttpStatus.BAD_GATEWAY, "connection error to origin");
}

/**
* Called when proxy failed to send request to the origin.
*/
private void handleProxyRequestError(Throwable error) {
log.warn("Can't send request to origin: {}", error.getMessage());
context.respond(HttpStatus.BAD_GATEWAY, "deployment responded with error");
respond(HttpStatus.BAD_GATEWAY, "deployment responded with error");
}

/**
Expand All @@ -169,4 +208,26 @@ private void handleResponseError(Throwable error) {
context.getProxyRequest().reset(); // drop connection to stop origin response
context.getResponse().reset(); // drop connection, so that partial client response won't seem complete
}

private void respond(HttpStatus status, String errorMessage) {
finalizeRequest();
context.respond(status, errorMessage);
}

private void respond(HttpStatus status) {
finalizeRequest();
context.respond(status);
}

private void finalizeRequest() {
ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
if (proxyApiKeyData != null) {
proxy.getApiKeyStore().invalidatePerRequestApiKey(proxyApiKeyData)
.onSuccess(invalidated -> {
if (!invalidated) {
log.warn("Per request is not removed: {}", proxyApiKeyData.getPerRequestKey());
}
}).onFailure(error -> log.error("error occurred on invalidating per-request key", error));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,36 @@

import io.vertx.core.http.HttpMethod;
import lombok.SneakyThrows;
import okhttp3.Headers;
import org.junit.jupiter.api.Test;

import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class FeaturesApiTest extends ResourceBaseTest {

private static String[] convertHeadersToFlatArray(Headers headers) {
return StreamSupport.stream(headers.spliterator(), false)
.flatMap(header -> Stream.of(header.getFirst(), header.getSecond()))
.toArray(String[]::new);
}

private static Headers filterHeaders(Headers headers, Headers mask) {
Headers.Builder filteredHeaders = new Headers.Builder();
for (Map.Entry<String, List<String>> entry : headers.toMultimap().entrySet()) {
String key = entry.getKey();
if (mask.names().contains(key.toLowerCase())) {
for (String value : entry.getValue()) {
filteredHeaders.add(key, value);
}
}
}
return filteredHeaders.build();
}

@Test
void testRateEndpointModel() {
String inboundPath = "/v1/chat-gpt-35-turbo/rate";
Expand Down Expand Up @@ -64,11 +88,18 @@ void testUpstreamEndpoint(String inboundPath, String upstream) {

@SneakyThrows
void testUpstreamEndpoint(String inboundPath, String upstream, HttpMethod method) {
Headers requestExtraHeaders = new Headers.Builder().add("foo", "bar").build();
String[] requestExtraHeadersArray = convertHeadersToFlatArray(requestExtraHeaders);

URI uri = URI.create(upstream);
try (TestWebServer server = new TestWebServer(uri.getPort())) {
server.map(method, uri.getPath(), 200, "PONG");
Response response = send(method, inboundPath);
verify(response, 200, "PONG");
server.map(method, uri.getPath(), request -> {
Headers responseHeaders = filterHeaders(request.getHeaders(), requestExtraHeaders);
return TestWebServer.createResponse(200, "PONG", convertHeadersToFlatArray(responseHeaders));
});

Response response = send(method, inboundPath, null, "", requestExtraHeadersArray);
verify(response, 200, "PONG", requestExtraHeadersArray);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,14 @@ static void verify(Response response, int status) {
assertEquals(status, response.status(), () -> "Actual response body: " + response.body());
}

static void verify(Response response, int status, String body) {
static void verify(Response response, int status, String body, String... headers) {
assertEquals(status, response.status(), () -> "Actual response body: " + response.body());
assertEquals(body, response.body());
for (int i = 0; i < headers.length; i += 2) {
String key = headers[i];
String value = headers[i + 1];
assertEquals(value, response.headers.get(key));
}
}

static void verifyJson(Response response, int status, String body) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public void map(HttpMethod method, String path, int status) {
map(method, path, status, "");
}

public void map(HttpMethod method, String path, int status, String body) {
map(method, path, createResponse(status, body));
public void map(HttpMethod method, String path, int status, String body, String... headers) {
map(method, path, createResponse(status, body, headers));
}

private MockResponse onRequest(RecordedRequest request) {
Expand All @@ -67,10 +67,15 @@ private MockResponse onRequest(RecordedRequest request) {
return response;
}

private static MockResponse createResponse(int status, String body) {
public static MockResponse createResponse(int status, String body, String... headers) {
MockResponse response = new MockResponse();
response.setResponseCode(status);
response.setBody(body);
for (int i = 0; i < headers.length; i += 2) {
String key = headers[i];
String value = headers[i + 1];
response.setHeader(key, value);
}
return response;
}

Expand Down

0 comments on commit bd4fbdd

Please sign in to comment.