Skip to content

Commit

Permalink
Support multiple rolePath in IDP configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
akurnosau committed Nov 19, 2024
1 parent bf0ed85 commit b538f93
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Priority order:
| identityProviders | - | Yes |Map of identity providers. **Note**: At least one identity provider must be provided. Refer to [examples](sample/aidial.settings.json) to view available providers. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.jwksUrl | - | Optional |Url to jwks provider. **Required** if `disabledVerifyJwt` is set to `false`. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided.
| identityProviders.*.userInfoEndpoint | - | Optional |Url to user info endpoint. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided or `disableJwtVerification` is unset. Refer to [Google example](sample/aidial.settings.json).
| identityProviders.*.rolePath | - | Yes |Path to the claim user roles in JWT token or user info response, e.g. `resource_access.chatbot-ui.roles` or just `roles`. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.rolePath | - | Yes |Path(s) to the claim user roles in JWT token or user info response, e.g. `resource_access.chatbot-ui.roles` or just `roles`. Can be single String or Array of Strings. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.rolesDelimiter | - | No |Delimiter to split roles into array in case when list of roles presented as single String. e.g. `"rolesDelimiter": " "`
| identityProviders.*.loggingKey | - | No |User information to search in claims of JWT token. `email` or `sub` should be sufficient in most cases. **Note**: `email` might be unavailable for some IDPs. Please check your IDP documentation in this case.
| identityProviders.*.loggingSalt | - | No |Salt to hash user information for logging.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.vertx.core.http.HttpClient;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.RequestOptions;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpHeaders;
Expand All @@ -23,6 +24,7 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand All @@ -33,13 +35,11 @@
import java.util.function.Function;
import java.util.regex.Pattern;

import static java.util.Collections.EMPTY_LIST;

@Slf4j
public class IdentityProvider {

// path to the claim of user roles in JWT
private final String[] rolePath;
// path(s) to the claim of user roles in JWT
private final List<String[]> rolePaths = new ArrayList<>();

// Delimiter to split the roles if they are set as a single String
private final String rolesDelimiter;
Expand Down Expand Up @@ -114,9 +114,23 @@ public IdentityProvider(JsonObject settings, Vertx vertx, HttpClient client,
}
}

String rolePathStr = Objects.requireNonNull(settings.getString("rolePath"), "rolePath is missed");
getUserRoleFn = factory.getUserRoleFn(rolePathStr);
rolePath = rolePathStr.split("\\.");
Object rolePathObj = Objects.requireNonNull(settings.getValue("rolePath"), "rolePath is missed");
List<String> rolePathList;

if (rolePathObj instanceof String rolePathStr) {
getUserRoleFn = factory.getUserRoleFn(rolePathStr);
rolePathList = List.of(rolePathStr);
} else if (rolePathObj instanceof JsonArray rolePathArray){
getUserRoleFn = null;
rolePathList = rolePathArray.stream().map(o -> (String) o).toList();
} else {
throw new IllegalArgumentException("rolePath should be either String or Array");
}

for (String rolePath : rolePathList) {
rolePaths.add(rolePath.split("\\."));
}

rolesDelimiter = settings.getString("rolesDelimiter");

loggingKey = settings.getString("loggingKey");
Expand Down Expand Up @@ -149,32 +163,37 @@ private void evictExpiredJwks() {

@SuppressWarnings("unchecked")
private List<String> extractUserRoles(Map<String, Object> map) {
for (int i = 0; i < rolePath.length; i++) {
Object next = map.get(rolePath[i]);
if (next == null) {
return EMPTY_LIST;
}
if (i == rolePath.length - 1) {
if (next instanceof List) {
return (List<String>) next;
} else if (next instanceof String) {
if (rolesDelimiter != null) {
return Arrays.stream(((String) next)
.split(rolesDelimiter))
.filter(s -> !s.isBlank())
.toList();
}
return List.of((String) next);
List<String> result = new ArrayList<>();
for (String[] rolePath : rolePaths) {
Map<String, Object> mapPointer = map;
for (int i = 0; i < rolePath.length; i++) {
Object next = mapPointer.get(rolePath[i]);
if (next == null) {
break;
}
} else {
if (next instanceof Map) {
map = (Map<String, Object>) next;
if (i == rolePath.length - 1) {
if (next instanceof List) {
result.addAll((List<String>) next);
} else if (next instanceof String) {
if (rolesDelimiter != null) {
result.addAll(Arrays.stream(((String) next)
.split(rolesDelimiter))
.filter(s -> !s.isBlank())
.toList());
} else {
result.add((String) next);
}
}
} else {
return EMPTY_LIST;
if (next instanceof Map) {
mapPointer = (Map<String, Object>) next;
} else {
break;
}
}
}
}
return EMPTY_LIST;
return result;
}

public static DecodedJWT decodeJwtToken(String encodedToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,151 @@ public void testExtractClaims_20() throws JwkException {
});
}

@Test
public void testExtractClaims_21() throws JwkException {
settings.put("rolePath", List.of("roles", "roles2"));
settings.put("rolesDelimiter", " ");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", "r1 r2 r3")
.withClaim("roles2", "r4 r5 r6")
.sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3", "r4", "r5", "r6"), claims.userRoles());
});
}

@Test
public void testExtractClaims_22() throws JwkException {
settings.put("rolePath", List.of("roles", "roles2"));
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", "r1")
.withClaim("roles2", "r2")
.sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2"), claims.userRoles());
});
}

@Test
public void testExtractClaims_23() throws JwkException {
settings.put("rolePath", List.of("roles", "roles2"));
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", List.of("r1", "r2", "r3"))
.sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3"), claims.userRoles());
});
}

@Test
public void testExtractClaims_24() throws JwkException {
settings.put("rolePath", List.of("roles", "roles2"));
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", List.of("r1", "r2", "r3"))
.withClaim("roles2", "r4")
.sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3", "r4"), claims.userRoles());
});
}

@Test
public void testExtractClaims_25() throws JwkException {
settings.put("rolePath", List.of("p0.p1.p2.p3", "p0.p1.p2.p4"));
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
Map<String, Object> claim = Map.of("some", "val", "k1", 12, "p1",
Map.of("p2", Map.of("p3", List.of("r1", "r2"), "p4", List.of("r3", "r4"))));

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("p0", claim).sign(algorithm);

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3", "r4"), claims.userRoles());
});
}

@Test
public void testExtractClaims_FromUserInfo_01() {
settings.remove("jwksUrl");
Expand Down

0 comments on commit b538f93

Please sign in to comment.