diff --git a/README.md b/README.md index ff525a427..089b0ecec 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ Priority order: | identityProviders | - | Yes |Map of identity providers. **Note**: At least one identity provider must be provided. Refer to [examples](sample/aidial.settings.json) to view available providers. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers. | identityProviders.*.jwksUrl | - | Optional |Url to jwks provider. **Required** if `disabledVerifyJwt` is set to `false`. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided. | identityProviders.*.userInfoEndpoint | - | Optional |Url to user info endpoint. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided or `disableJwtVerification` is unset. Refer to [Google example](sample/aidial.settings.json). -| identityProviders.*.rolePath | - | Yes |Path to the claim user roles in JWT token or user info response, e.g. `resource_access.chatbot-ui.roles` or just `roles`. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers. +| identityProviders.*.rolePath | - | Yes |Path(s) to the claim user roles in JWT token or user info response, e.g. `resource_access.chatbot-ui.roles` or just `roles`. Can be single String or Array of Strings. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers. | identityProviders.*.rolesDelimiter | - | No |Delimiter to split roles into array in case when list of roles presented as single String. e.g. `"rolesDelimiter": " "` | identityProviders.*.loggingKey | - | No |User information to search in claims of JWT token. `email` or `sub` should be sufficient in most cases. **Note**: `email` might be unavailable for some IDPs. Please check your IDP documentation in this case. | identityProviders.*.loggingSalt | - | No |Salt to hash user information for logging. diff --git a/server/src/main/java/com/epam/aidial/core/server/security/IdentityProvider.java b/server/src/main/java/com/epam/aidial/core/server/security/IdentityProvider.java index 521d8e116..209eed08c 100644 --- a/server/src/main/java/com/epam/aidial/core/server/security/IdentityProvider.java +++ b/server/src/main/java/com/epam/aidial/core/server/security/IdentityProvider.java @@ -13,6 +13,7 @@ import io.vertx.core.http.HttpClient; import io.vertx.core.http.HttpMethod; import io.vertx.core.http.RequestOptions; +import io.vertx.core.json.JsonArray; import io.vertx.core.json.JsonObject; import lombok.extern.slf4j.Slf4j; import org.apache.http.HttpHeaders; @@ -23,6 +24,7 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.security.interfaces.RSAPublicKey; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -33,13 +35,11 @@ import java.util.function.Function; import java.util.regex.Pattern; -import static java.util.Collections.EMPTY_LIST; - @Slf4j public class IdentityProvider { - // path to the claim of user roles in JWT - private final String[] rolePath; + // path(s) to the claim of user roles in JWT + private final List rolePaths = new ArrayList<>(); // Delimiter to split the roles if they are set as a single String private final String rolesDelimiter; @@ -114,9 +114,23 @@ public IdentityProvider(JsonObject settings, Vertx vertx, HttpClient client, } } - String rolePathStr = Objects.requireNonNull(settings.getString("rolePath"), "rolePath is missed"); - getUserRoleFn = factory.getUserRoleFn(rolePathStr); - rolePath = rolePathStr.split("\\."); + Object rolePathObj = Objects.requireNonNull(settings.getValue("rolePath"), "rolePath is missed"); + List rolePathList; + + if (rolePathObj instanceof String rolePathStr) { + getUserRoleFn = factory.getUserRoleFn(rolePathStr); + rolePathList = List.of(rolePathStr); + } else if (rolePathObj instanceof JsonArray rolePathArray){ + getUserRoleFn = null; + rolePathList = rolePathArray.stream().map(o -> (String) o).toList(); + } else { + throw new IllegalArgumentException("rolePath should be either String or Array"); + } + + for (String rolePath : rolePathList) { + rolePaths.add(rolePath.split("\\.")); + } + rolesDelimiter = settings.getString("rolesDelimiter"); loggingKey = settings.getString("loggingKey"); @@ -149,32 +163,37 @@ private void evictExpiredJwks() { @SuppressWarnings("unchecked") private List extractUserRoles(Map map) { - for (int i = 0; i < rolePath.length; i++) { - Object next = map.get(rolePath[i]); - if (next == null) { - return EMPTY_LIST; - } - if (i == rolePath.length - 1) { - if (next instanceof List) { - return (List) next; - } else if (next instanceof String) { - if (rolesDelimiter != null) { - return Arrays.stream(((String) next) - .split(rolesDelimiter)) - .filter(s -> !s.isBlank()) - .toList(); - } - return List.of((String) next); + List result = new ArrayList<>(); + for (String[] rolePath : rolePaths) { + Map mapPointer = map; + for (int i = 0; i < rolePath.length; i++) { + Object next = mapPointer.get(rolePath[i]); + if (next == null) { + break; } - } else { - if (next instanceof Map) { - map = (Map) next; + if (i == rolePath.length - 1) { + if (next instanceof List) { + result.addAll((List) next); + } else if (next instanceof String) { + if (rolesDelimiter != null) { + result.addAll(Arrays.stream(((String) next) + .split(rolesDelimiter)) + .filter(s -> !s.isBlank()) + .toList()); + } else { + result.add((String) next); + } + } } else { - return EMPTY_LIST; + if (next instanceof Map) { + mapPointer = (Map) next; + } else { + break; + } } } } - return EMPTY_LIST; + return result; } public static DecodedJWT decodeJwtToken(String encodedToken) { diff --git a/server/src/test/java/com/epam/aidial/core/server/security/IdentityProviderTest.java b/server/src/test/java/com/epam/aidial/core/server/security/IdentityProviderTest.java index dcef9ddb9..09fb57944 100644 --- a/server/src/test/java/com/epam/aidial/core/server/security/IdentityProviderTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/security/IdentityProviderTest.java @@ -561,6 +561,151 @@ public void testExtractClaims_20() throws JwkException { }); } + @Test + public void testExtractClaims_21() throws JwkException { + settings.put("rolePath", List.of("roles", "roles2")); + settings.put("rolesDelimiter", " "); + IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory); + Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate()); + + String token = JWT.create().withHeader(Map.of("kid", "kid1")) + .withClaim("roles", "r1 r2 r3") + .withClaim("roles2", "r4 r5 r6") + .sign(algorithm); + Jwk jwk = mock(Jwk.class); + when(jwk.getPublicKey()).thenReturn(keyPair.getPublic()); + when(jwkProvider.get(eq("kid1"))).thenReturn(jwk); + when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> { + Callable callable = invocation.getArgument(0); + return Future.succeededFuture(callable.call()); + }); + + Future result = identityProvider.extractClaimsFromJwt(JWT.decode(token)); + + assertNotNull(result); + result.onComplete(res -> { + assertTrue(res.succeeded()); + ExtractedClaims claims = res.result(); + assertNotNull(claims); + assertEquals(List.of("r1", "r2", "r3", "r4", "r5", "r6"), claims.userRoles()); + }); + } + + @Test + public void testExtractClaims_22() throws JwkException { + settings.put("rolePath", List.of("roles", "roles2")); + IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory); + Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate()); + + String token = JWT.create().withHeader(Map.of("kid", "kid1")) + .withClaim("roles", "r1") + .withClaim("roles2", "r2") + .sign(algorithm); + Jwk jwk = mock(Jwk.class); + when(jwk.getPublicKey()).thenReturn(keyPair.getPublic()); + when(jwkProvider.get(eq("kid1"))).thenReturn(jwk); + when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> { + Callable callable = invocation.getArgument(0); + return Future.succeededFuture(callable.call()); + }); + + Future result = identityProvider.extractClaimsFromJwt(JWT.decode(token)); + + assertNotNull(result); + result.onComplete(res -> { + assertTrue(res.succeeded()); + ExtractedClaims claims = res.result(); + assertNotNull(claims); + assertEquals(List.of("r1", "r2"), claims.userRoles()); + }); + } + + @Test + public void testExtractClaims_23() throws JwkException { + settings.put("rolePath", List.of("roles", "roles2")); + IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory); + Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate()); + + String token = JWT.create().withHeader(Map.of("kid", "kid1")) + .withClaim("roles", List.of("r1", "r2", "r3")) + .sign(algorithm); + Jwk jwk = mock(Jwk.class); + when(jwk.getPublicKey()).thenReturn(keyPair.getPublic()); + when(jwkProvider.get(eq("kid1"))).thenReturn(jwk); + when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> { + Callable callable = invocation.getArgument(0); + return Future.succeededFuture(callable.call()); + }); + + Future result = identityProvider.extractClaimsFromJwt(JWT.decode(token)); + + assertNotNull(result); + result.onComplete(res -> { + assertTrue(res.succeeded()); + ExtractedClaims claims = res.result(); + assertNotNull(claims); + assertEquals(List.of("r1", "r2", "r3"), claims.userRoles()); + }); + } + + @Test + public void testExtractClaims_24() throws JwkException { + settings.put("rolePath", List.of("roles", "roles2")); + IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory); + Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate()); + + String token = JWT.create().withHeader(Map.of("kid", "kid1")) + .withClaim("roles", List.of("r1", "r2", "r3")) + .withClaim("roles2", "r4") + .sign(algorithm); + Jwk jwk = mock(Jwk.class); + when(jwk.getPublicKey()).thenReturn(keyPair.getPublic()); + when(jwkProvider.get(eq("kid1"))).thenReturn(jwk); + when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> { + Callable callable = invocation.getArgument(0); + return Future.succeededFuture(callable.call()); + }); + + Future result = identityProvider.extractClaimsFromJwt(JWT.decode(token)); + + assertNotNull(result); + result.onComplete(res -> { + assertTrue(res.succeeded()); + ExtractedClaims claims = res.result(); + assertNotNull(claims); + assertEquals(List.of("r1", "r2", "r3", "r4"), claims.userRoles()); + }); + } + + @Test + public void testExtractClaims_25() throws JwkException { + settings.put("rolePath", List.of("p0.p1.p2.p3", "p0.p1.p2.p4")); + IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory); + Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate()); + + Jwk jwk = mock(Jwk.class); + when(jwk.getPublicKey()).thenReturn(keyPair.getPublic()); + when(jwkProvider.get(eq("kid1"))).thenReturn(jwk); + when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> { + Callable callable = invocation.getArgument(0); + return Future.succeededFuture(callable.call()); + }); + Map claim = Map.of("some", "val", "k1", 12, "p1", + Map.of("p2", Map.of("p3", List.of("r1", "r2"), "p4", List.of("r3", "r4")))); + + String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("p0", claim).sign(algorithm); + + Future result = identityProvider.extractClaimsFromJwt(JWT.decode(token)); + + assertNotNull(result); + result.onComplete(res -> { + assertTrue(res.succeeded()); + ExtractedClaims claims = res.result(); + assertNotNull(claims); + assertEquals(List.of("r1", "r2", "r3", "r4"), claims.userRoles()); + }); + } + @Test public void testExtractClaims_FromUserInfo_01() { settings.remove("jwksUrl");