diff --git a/src/main/java/com/epam/aidial/core/security/IdentityProvider.java b/src/main/java/com/epam/aidial/core/security/IdentityProvider.java index dc3b5b884..d0b7e5360 100644 --- a/src/main/java/com/epam/aidial/core/security/IdentityProvider.java +++ b/src/main/java/com/epam/aidial/core/security/IdentityProvider.java @@ -152,6 +152,8 @@ private List extractUserRoles(Map map) { if (i == rolePath.length - 1) { if (next instanceof List) { return (List) next; + } else if (next instanceof String) { + return List.of((String) next); } } else { if (next instanceof Map) { diff --git a/src/test/java/com/epam/aidial/core/security/IdentityProviderTest.java b/src/test/java/com/epam/aidial/core/security/IdentityProviderTest.java index 0a00aba0b..87fbadfb0 100644 --- a/src/test/java/com/epam/aidial/core/security/IdentityProviderTest.java +++ b/src/test/java/com/epam/aidial/core/security/IdentityProviderTest.java @@ -360,6 +360,51 @@ public void testExtractClaims_13() { }); } + @Test + public void testExtractClaims_14() { + settings.put("disableJwtVerification", Boolean.TRUE); + IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory); + Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate()); + + String token = JWT.create().withHeader(Map.of("kid", "kid1")) + // test with a single role as a string field + .withClaim("roles", "role") + .withClaim("email", "test@email.com") + .withClaim("id", 15) + .withClaim("title", "title") + .withClaim("access", List.of("read", "write")) + .withClaim("expire", new Date(1713355825858L)) + .withClaim("numberList", List.of("15", "17", "34")) + .withClaim("map", Map.of("a", List.of("b"))) + .withClaim("sub", "sub").sign(algorithm); + + Future result = identityProvider.extractClaimsFromJwt(JWT.decode(token)); + + verifyNoInteractions(jwkProvider); + + assertNotNull(result); + result.onComplete(res -> { + assertTrue(res.succeeded()); + ExtractedClaims claims = res.result(); + assertNotNull(claims); + assertEquals(List.of("role"), claims.userRoles()); + assertEquals("sub", claims.sub()); + assertNotNull(claims.userHash()); + Map> userClaims = claims.userClaims(); + // assert user claim + assertEquals(9, userClaims.size()); + assertEquals(List.of("sub"), userClaims.get("sub")); + assertEquals(List.of("read", "write"), userClaims.get("access")); + assertEquals(List.of("role"), userClaims.get("roles")); + assertEquals(List.of(), userClaims.get("expire")); + assertEquals(List.of("15", "17", "34"), userClaims.get("numberList")); + assertEquals(List.of(), userClaims.get("id")); + assertEquals(List.of("title"), userClaims.get("title")); + assertEquals(List.of(), userClaims.get("map")); + assertEquals(List.of("test@email.com"), userClaims.get("email")); + }); + } + @Test public void testExtractClaims_FromUserInfo_01() { settings.remove("jwksUrl");