Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support multiple rolePath in IDP configuration #577

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Priority order:
| identityProviders | - | Yes |Map of identity providers. **Note**: At least one identity provider must be provided. Refer to [examples](sample/aidial.settings.json) to view available providers. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.jwksUrl | - | Optional |Url to jwks provider. **Required** if `disabledVerifyJwt` is set to `false`. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided.
| identityProviders.*.userInfoEndpoint | - | Optional |Url to user info endpoint. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided or `disableJwtVerification` is unset. Refer to [Google example](sample/aidial.settings.json).
| identityProviders.*.rolePath | - | Yes |Path to the claim user roles in JWT token or user info response, e.g. `resource_access.chatbot-ui.roles` or just `roles`. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.rolePath | - | Yes |Path(s) to the claim user roles in JWT token or user info response, e.g. `resource_access.chatbot-ui.roles` or just `roles`. Can be single String or Array of Strings. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.rolesDelimiter | - | No |Delimiter to split roles into array in case when list of roles presented as single String. e.g. `"rolesDelimiter": " "`
| identityProviders.*.loggingKey | - | No |User information to search in claims of JWT token. `email` or `sub` should be sufficient in most cases. **Note**: `email` might be unavailable for some IDPs. Please check your IDP documentation in this case.
| identityProviders.*.loggingSalt | - | No |Salt to hash user information for logging.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.vertx.core.http.HttpClient;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.RequestOptions;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpHeaders;
Expand All @@ -23,6 +24,7 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand All @@ -33,13 +35,11 @@
import java.util.function.Function;
import java.util.regex.Pattern;

import static java.util.Collections.EMPTY_LIST;

@Slf4j
public class IdentityProvider {

// path to the claim of user roles in JWT
private final String[] rolePath;
// path(s) to the claim of user roles in JWT
private final List<String[]> rolePaths = new ArrayList<>();

// Delimiter to split the roles if they are set as a single String
private final String rolesDelimiter;
Expand Down Expand Up @@ -114,9 +114,23 @@ public IdentityProvider(JsonObject settings, Vertx vertx, HttpClient client,
}
}

String rolePathStr = Objects.requireNonNull(settings.getString("rolePath"), "rolePath is missed");
getUserRoleFn = factory.getUserRoleFn(rolePathStr);
rolePath = rolePathStr.split("\\.");
Object rolePathObj = Objects.requireNonNull(settings.getValue("rolePath"), "rolePath is missed");
List<String> rolePathList;

if (rolePathObj instanceof String rolePathStr) {
getUserRoleFn = factory.getUserRoleFn(rolePathStr);
rolePathList = List.of(rolePathStr);
} else if (rolePathObj instanceof JsonArray rolePathArray) {
getUserRoleFn = null;
rolePathList = rolePathArray.stream().map(o -> (String) o).toList();
} else {
throw new IllegalArgumentException("rolePath should be either String or Array");
}

for (String rolePath : rolePathList) {
rolePaths.add(rolePath.split("\\."));
}

rolesDelimiter = settings.getString("rolesDelimiter");

loggingKey = settings.getString("loggingKey");
Expand Down Expand Up @@ -149,32 +163,37 @@ private void evictExpiredJwks() {

@SuppressWarnings("unchecked")
private List<String> extractUserRoles(Map<String, Object> map) {
for (int i = 0; i < rolePath.length; i++) {
Object next = map.get(rolePath[i]);
if (next == null) {
return EMPTY_LIST;
}
if (i == rolePath.length - 1) {
if (next instanceof List) {
return (List<String>) next;
} else if (next instanceof String) {
if (rolesDelimiter != null) {
return Arrays.stream(((String) next)
.split(rolesDelimiter))
.filter(s -> !s.isBlank())
.toList();
}
return List.of((String) next);
List<String> result = new ArrayList<>();
astsiapanay marked this conversation as resolved.
Show resolved Hide resolved
for (String[] rolePath : rolePaths) {
Map<String, Object> mapPointer = map;
for (int i = 0; i < rolePath.length; i++) {
Object next = mapPointer.get(rolePath[i]);
if (next == null) {
break;
}
} else {
if (next instanceof Map) {
map = (Map<String, Object>) next;
if (i == rolePath.length - 1) {
if (next instanceof List) {
result.addAll((List<String>) next);
} else if (next instanceof String) {
if (rolesDelimiter != null) {
result.addAll(Arrays.stream(((String) next)
.split(rolesDelimiter))
.filter(s -> !s.isBlank())
.toList());
} else {
result.add((String) next);
}
}
} else {
return EMPTY_LIST;
if (next instanceof Map) {
mapPointer = (Map<String, Object>) next;
} else {
break;
}
}
}
}
return EMPTY_LIST;
return result;
}

public static DecodedJWT decodeJwtToken(String encodedToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,151 @@ public void testExtractClaims_20() throws JwkException {
});
}

@Test
public void testExtractClaims_21() throws JwkException {
settings.put("rolePath", List.of("roles", "roles2"));
settings.put("rolesDelimiter", " ");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", "r1 r2 r3")
.withClaim("roles2", "r4 r5 r6")
.sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3", "r4", "r5", "r6"), claims.userRoles());
});
}

@Test
public void testExtractClaims_22() throws JwkException {
settings.put("rolePath", List.of("roles", "roles2"));
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", "r1")
.withClaim("roles2", "r2")
.sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2"), claims.userRoles());
});
}

@Test
public void testExtractClaims_23() throws JwkException {
settings.put("rolePath", List.of("roles", "roles2"));
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", List.of("r1", "r2", "r3"))
.sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3"), claims.userRoles());
});
}

@Test
public void testExtractClaims_24() throws JwkException {
settings.put("rolePath", List.of("roles", "roles2"));
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", List.of("r1", "r2", "r3"))
.withClaim("roles2", "r4")
.sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3", "r4"), claims.userRoles());
});
}

@Test
public void testExtractClaims_25() throws JwkException {
settings.put("rolePath", List.of("p0.p1.p2.p3", "p0.p1.p2.p4"));
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});
Map<String, Object> claim = Map.of("some", "val", "k1", 12, "p1",
Map.of("p2", Map.of("p3", List.of("r1", "r2"), "p4", List.of("r3", "r4"))));

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("p0", claim).sign(algorithm);

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3", "r4"), claims.userRoles());
});
}

@Test
public void testExtractClaims_FromUserInfo_01() {
settings.remove("jwksUrl");
Expand Down
Loading