Skip to content

Commit

Permalink
feat: Allow interceptors to pass authZ header along with API key #478
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay committed Sep 10, 2024
1 parent 9651594 commit dc1e93e
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 10 deletions.
19 changes: 17 additions & 2 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,23 @@ private void handleRequest(HttpServerRequest request) {
respond(request, HttpStatus.UNAUTHORIZED, "At least API-KEY or Authorization header must be provided");
return;
} else if (apiKey != null && authorization != null && !apiKey.equals(extractTokenFromHeader(authorization))) {
respond(request, HttpStatus.BAD_REQUEST, "Either API-KEY or Authorization header must be provided but not both");
return;
// interceptor case
authorizationResultFuture = apiKeyStore.getApiKeyData(apiKey)
.onFailure(error -> onGettingApiKeyDataFailure(error, request))
.compose(apiKeyData -> {
if (apiKeyData == null) {
String errorMessage = "Unknown api key";
respond(request, HttpStatus.UNAUTHORIZED, errorMessage);
return Future.failedFuture(errorMessage);
}
if (apiKeyData.isInterceptor()) {
return Future.succeededFuture(new AuthorizationResult(apiKeyData, null));
} else {
String errorMessage = "Either API-KEY or Authorization header must be provided but not both";
respond(request, HttpStatus.BAD_REQUEST, errorMessage);
return Future.failedFuture(errorMessage);
}
});
} else if (apiKey != null) {
authorizationResultFuture = apiKeyStore.getApiKeyData(apiKey)
.onFailure(error -> onGettingApiKeyDataFailure(error, request))
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/com/epam/aidial/core/config/ApiKeyData.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.epam.aidial.core.data.AutoSharedData;
import com.epam.aidial.core.data.ResourceAccessType;
import com.epam.aidial.core.security.ExtractedClaims;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Data;

import java.util.ArrayList;
Expand Down Expand Up @@ -77,4 +78,9 @@ public static void initFromContext(ApiKeyData proxyApiKeyData, ProxyContext cont
proxyApiKeyData.setSpanId(context.getSpanId());
proxyApiKeyData.setSourceDeployment(context.getDeployment().getName());
}

@JsonIgnore
public boolean isInterceptor() {
return perRequestKey != null && interceptors != null && interceptorIndex >= 0 && interceptorIndex < interceptors.size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,7 @@ void handleProxyRequest(HttpClientRequest proxyRequest) {
context.setProxyRequest(proxyRequest);
context.setProxyConnectTimestamp(System.currentTimeMillis());

MultiMap excludeHeaders = MultiMap.caseInsensitiveMultiMap();
if (!context.getDeployment().isForwardAuthToken()) {
excludeHeaders.add(HttpHeaders.AUTHORIZATION, "whatever");
}

ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers(), excludeHeaders);
ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers());

ApiKeyData proxyApiKeyData = context.getProxyApiKeyData();
proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey());
Expand Down
114 changes: 112 additions & 2 deletions src/test/java/com/epam/aidial/core/ProxyTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Answers;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
Expand Down Expand Up @@ -69,7 +70,7 @@ public class ProxyTest {
@Mock
private BlobStorage storage;

@Mock
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private HttpServerRequest request;

@Mock
Expand Down Expand Up @@ -164,7 +165,50 @@ public void testHandle_MissingApiKeyAndToken() {
}

@Test
public void testHandle_BothApiKeyAndToken() {
public void testHandle_BothApiKeyAndToken_ApiKeyNotFound() {
when(request.version()).thenReturn(HttpVersion.HTTP_1_1);
when(request.method()).thenReturn(HttpMethod.GET);
MultiMap headers = mock(MultiMap.class);
when(request.headers()).thenReturn(headers);
when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(null);
when(request.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn("token");
when(headers.get(eq(HEADER_API_KEY))).thenReturn("api-key");
when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512));
when(request.path()).thenReturn("/foo");

Config config = new Config();
when(configStore.load()).thenReturn(config);
when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture());

proxy.handle(request);

verify(response).setStatusCode(UNAUTHORIZED.getCode());
}

@Test
public void testHandle_BothApiKeyAndToken_ApiKeyIsNotPerRequestKey() {
when(request.version()).thenReturn(HttpVersion.HTTP_1_1);
when(request.method()).thenReturn(HttpMethod.GET);
MultiMap headers = mock(MultiMap.class);
when(request.headers()).thenReturn(headers);
when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(null);
when(request.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn("token");
when(headers.get(eq(HEADER_API_KEY))).thenReturn("api-key");
when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512));
when(request.path()).thenReturn("/foo");

Config config = new Config();
when(configStore.load()).thenReturn(config);
ApiKeyData apiKeyData = new ApiKeyData();
when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture(apiKeyData));

proxy.handle(request);

verify(response).setStatusCode(BAD_REQUEST.getCode());
}

@Test
public void testHandle_BothApiKeyAndToken_CallerIsNotInterceptor_1() {
when(request.version()).thenReturn(HttpVersion.HTTP_1_1);
when(request.method()).thenReturn(HttpMethod.GET);
MultiMap headers = mock(MultiMap.class);
Expand All @@ -175,11 +219,77 @@ public void testHandle_BothApiKeyAndToken() {
when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512));
when(request.path()).thenReturn("/foo");

Config config = new Config();
when(configStore.load()).thenReturn(config);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setPerRequestKey("per-request_key");
when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture(apiKeyData));

proxy.handle(request);

verify(response).setStatusCode(BAD_REQUEST.getCode());
}

@Test
public void testHandle_BothApiKeyAndToken_CallerIsNotInterceptor_2() {
when(request.version()).thenReturn(HttpVersion.HTTP_1_1);
when(request.method()).thenReturn(HttpMethod.GET);
MultiMap headers = mock(MultiMap.class);
when(request.headers()).thenReturn(headers);
when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(null);
when(request.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn("token");
when(headers.get(eq(HEADER_API_KEY))).thenReturn("api-key");
when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512));
when(request.path()).thenReturn("/foo");

Config config = new Config();
when(configStore.load()).thenReturn(config);
ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setPerRequestKey("per-request_key");
apiKeyData.setInterceptors(List.of("interceptor1", "interceptor2"));
apiKeyData.setInterceptorIndex(2);
when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture(apiKeyData));

proxy.handle(request);

verify(response).setStatusCode(BAD_REQUEST.getCode());
}

@Test
public void testHandle_BothApiKeyAndToken_CallerIsInterceptor() {
when(request.version()).thenReturn(HttpVersion.HTTP_1_1);
when(request.method()).thenReturn(HttpMethod.GET);
MultiMap headers = mock(MultiMap.class);
when(request.headers()).thenReturn(headers);
when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(null);
when(request.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn("token");
when(headers.get(eq(HEADER_API_KEY))).thenReturn("api-key");
when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512));
when(request.path()).thenReturn("/foo");
when(request.uri()).thenReturn("/foo");

Config config = new Config();
Route route = new Route();
route.setMethods(Set.of(HttpMethod.GET));
route.setName("route");
route.setPaths(List.of(Pattern.compile("/foo")));
route.setResponse(new Route.Response());
LinkedHashMap<String, Route> routes = new LinkedHashMap<>();
routes.put("route", route);
config.setRoutes(routes);
when(configStore.load()).thenReturn(config);

ApiKeyData apiKeyData = new ApiKeyData();
apiKeyData.setPerRequestKey("per-request_key");
apiKeyData.setInterceptors(List.of("interceptor1", "interceptor2"));
apiKeyData.setInterceptorIndex(1);
when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture(apiKeyData));

proxy.handle(request);

verify(response).setStatusCode(OK.getCode());
}

@Test
public void testHandle_UnknownApiKey() {
when(request.version()).thenReturn(HttpVersion.HTTP_1_1);
Expand Down

0 comments on commit dc1e93e

Please sign in to comment.