diff --git a/src/main/java/com/epam/aidial/core/Proxy.java b/src/main/java/com/epam/aidial/core/Proxy.java index f798a0da1..7f77b82d1 100644 --- a/src/main/java/com/epam/aidial/core/Proxy.java +++ b/src/main/java/com/epam/aidial/core/Proxy.java @@ -185,8 +185,23 @@ private void handleRequest(HttpServerRequest request) { respond(request, HttpStatus.UNAUTHORIZED, "At least API-KEY or Authorization header must be provided"); return; } else if (apiKey != null && authorization != null && !apiKey.equals(extractTokenFromHeader(authorization))) { - respond(request, HttpStatus.BAD_REQUEST, "Either API-KEY or Authorization header must be provided but not both"); - return; + // interceptor case + authorizationResultFuture = apiKeyStore.getApiKeyData(apiKey) + .onFailure(error -> onGettingApiKeyDataFailure(error, request)) + .compose(apiKeyData -> { + if (apiKeyData == null) { + String errorMessage = "Unknown api key"; + respond(request, HttpStatus.UNAUTHORIZED, errorMessage); + return Future.failedFuture(errorMessage); + } + if (apiKeyData.isInterceptor()) { + return Future.succeededFuture(new AuthorizationResult(apiKeyData, null)); + } else { + String errorMessage = "Either API-KEY or Authorization header must be provided but not both"; + respond(request, HttpStatus.BAD_REQUEST, errorMessage); + return Future.failedFuture(errorMessage); + } + }); } else if (apiKey != null) { authorizationResultFuture = apiKeyStore.getApiKeyData(apiKey) .onFailure(error -> onGettingApiKeyDataFailure(error, request)) diff --git a/src/main/java/com/epam/aidial/core/config/ApiKeyData.java b/src/main/java/com/epam/aidial/core/config/ApiKeyData.java index b2f4b332f..d798b12f5 100644 --- a/src/main/java/com/epam/aidial/core/config/ApiKeyData.java +++ b/src/main/java/com/epam/aidial/core/config/ApiKeyData.java @@ -4,6 +4,7 @@ import com.epam.aidial.core.data.AutoSharedData; import com.epam.aidial.core.data.ResourceAccessType; import com.epam.aidial.core.security.ExtractedClaims; +import com.fasterxml.jackson.annotation.JsonIgnore; import lombok.Data; import java.util.ArrayList; @@ -77,4 +78,9 @@ public static void initFromContext(ApiKeyData proxyApiKeyData, ProxyContext cont proxyApiKeyData.setSpanId(context.getSpanId()); proxyApiKeyData.setSourceDeployment(context.getDeployment().getName()); } + + @JsonIgnore + public boolean isInterceptor() { + return perRequestKey != null && interceptors != null && interceptorIndex >= 0 && interceptorIndex < interceptors.size(); + } } diff --git a/src/main/java/com/epam/aidial/core/controller/InterceptorController.java b/src/main/java/com/epam/aidial/core/controller/InterceptorController.java index e704f64f5..a037b441d 100644 --- a/src/main/java/com/epam/aidial/core/controller/InterceptorController.java +++ b/src/main/java/com/epam/aidial/core/controller/InterceptorController.java @@ -134,12 +134,7 @@ void handleProxyRequest(HttpClientRequest proxyRequest) { context.setProxyRequest(proxyRequest); context.setProxyConnectTimestamp(System.currentTimeMillis()); - MultiMap excludeHeaders = MultiMap.caseInsensitiveMultiMap(); - if (!context.getDeployment().isForwardAuthToken()) { - excludeHeaders.add(HttpHeaders.AUTHORIZATION, "whatever"); - } - - ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers(), excludeHeaders); + ProxyUtil.copyHeaders(request.headers(), proxyRequest.headers()); ApiKeyData proxyApiKeyData = context.getProxyApiKeyData(); proxyRequest.headers().add(Proxy.HEADER_API_KEY, proxyApiKeyData.getPerRequestKey()); diff --git a/src/test/java/com/epam/aidial/core/ProxyTest.java b/src/test/java/com/epam/aidial/core/ProxyTest.java index 080d7a016..a3aa3993b 100644 --- a/src/test/java/com/epam/aidial/core/ProxyTest.java +++ b/src/test/java/com/epam/aidial/core/ProxyTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -69,7 +70,7 @@ public class ProxyTest { @Mock private BlobStorage storage; - @Mock + @Mock(answer = Answers.RETURNS_DEEP_STUBS) private HttpServerRequest request; @Mock @@ -164,7 +165,50 @@ public void testHandle_MissingApiKeyAndToken() { } @Test - public void testHandle_BothApiKeyAndToken() { + public void testHandle_BothApiKeyAndToken_ApiKeyNotFound() { + when(request.version()).thenReturn(HttpVersion.HTTP_1_1); + when(request.method()).thenReturn(HttpMethod.GET); + MultiMap headers = mock(MultiMap.class); + when(request.headers()).thenReturn(headers); + when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(null); + when(request.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn("token"); + when(headers.get(eq(HEADER_API_KEY))).thenReturn("api-key"); + when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512)); + when(request.path()).thenReturn("/foo"); + + Config config = new Config(); + when(configStore.load()).thenReturn(config); + when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture()); + + proxy.handle(request); + + verify(response).setStatusCode(UNAUTHORIZED.getCode()); + } + + @Test + public void testHandle_BothApiKeyAndToken_ApiKeyIsNotPerRequestKey() { + when(request.version()).thenReturn(HttpVersion.HTTP_1_1); + when(request.method()).thenReturn(HttpMethod.GET); + MultiMap headers = mock(MultiMap.class); + when(request.headers()).thenReturn(headers); + when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(null); + when(request.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn("token"); + when(headers.get(eq(HEADER_API_KEY))).thenReturn("api-key"); + when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512)); + when(request.path()).thenReturn("/foo"); + + Config config = new Config(); + when(configStore.load()).thenReturn(config); + ApiKeyData apiKeyData = new ApiKeyData(); + when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture(apiKeyData)); + + proxy.handle(request); + + verify(response).setStatusCode(BAD_REQUEST.getCode()); + } + + @Test + public void testHandle_BothApiKeyAndToken_CallerIsNotInterceptor_1() { when(request.version()).thenReturn(HttpVersion.HTTP_1_1); when(request.method()).thenReturn(HttpMethod.GET); MultiMap headers = mock(MultiMap.class); @@ -175,11 +219,77 @@ public void testHandle_BothApiKeyAndToken() { when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512)); when(request.path()).thenReturn("/foo"); + Config config = new Config(); + when(configStore.load()).thenReturn(config); + ApiKeyData apiKeyData = new ApiKeyData(); + apiKeyData.setPerRequestKey("per-request_key"); + when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture(apiKeyData)); + proxy.handle(request); verify(response).setStatusCode(BAD_REQUEST.getCode()); } + @Test + public void testHandle_BothApiKeyAndToken_CallerIsNotInterceptor_2() { + when(request.version()).thenReturn(HttpVersion.HTTP_1_1); + when(request.method()).thenReturn(HttpMethod.GET); + MultiMap headers = mock(MultiMap.class); + when(request.headers()).thenReturn(headers); + when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(null); + when(request.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn("token"); + when(headers.get(eq(HEADER_API_KEY))).thenReturn("api-key"); + when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512)); + when(request.path()).thenReturn("/foo"); + + Config config = new Config(); + when(configStore.load()).thenReturn(config); + ApiKeyData apiKeyData = new ApiKeyData(); + apiKeyData.setPerRequestKey("per-request_key"); + apiKeyData.setInterceptors(List.of("interceptor1", "interceptor2")); + apiKeyData.setInterceptorIndex(2); + when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture(apiKeyData)); + + proxy.handle(request); + + verify(response).setStatusCode(BAD_REQUEST.getCode()); + } + + @Test + public void testHandle_BothApiKeyAndToken_CallerIsInterceptor() { + when(request.version()).thenReturn(HttpVersion.HTTP_1_1); + when(request.method()).thenReturn(HttpMethod.GET); + MultiMap headers = mock(MultiMap.class); + when(request.headers()).thenReturn(headers); + when(request.getHeader(eq(HttpHeaders.CONTENT_TYPE))).thenReturn(null); + when(request.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn("token"); + when(headers.get(eq(HEADER_API_KEY))).thenReturn("api-key"); + when(headers.get(eq(HttpHeaders.CONTENT_LENGTH))).thenReturn(Integer.toString(512)); + when(request.path()).thenReturn("/foo"); + when(request.uri()).thenReturn("/foo"); + + Config config = new Config(); + Route route = new Route(); + route.setMethods(Set.of(HttpMethod.GET)); + route.setName("route"); + route.setPaths(List.of(Pattern.compile("/foo"))); + route.setResponse(new Route.Response()); + LinkedHashMap routes = new LinkedHashMap<>(); + routes.put("route", route); + config.setRoutes(routes); + when(configStore.load()).thenReturn(config); + + ApiKeyData apiKeyData = new ApiKeyData(); + apiKeyData.setPerRequestKey("per-request_key"); + apiKeyData.setInterceptors(List.of("interceptor1", "interceptor2")); + apiKeyData.setInterceptorIndex(1); + when(apiKeyStore.getApiKeyData(anyString())).thenReturn(Future.succeededFuture(apiKeyData)); + + proxy.handle(request); + + verify(response).setStatusCode(OK.getCode()); + } + @Test public void testHandle_UnknownApiKey() { when(request.version()).thenReturn(HttpVersion.HTTP_1_1);