Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support multiple idps #75

Merged
merged 9 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,28 @@ Static settings are used on startup and cannot be changed while application is r
* File specified in "AIDIAL_SETTINGS" environment variable.
* Default resource file: src/main/resources/aidial.settings.json.

| Setting | Default |Description
|--------------------------------------|--------------------|-
| config.files | aidial.config.json |Config files with parts of the whole config.
| config.reload | 60000 |Config reload interval in milliseconds.
| identityProvider.jwksUrl | - |Url to jwks provider.
| identityProvider.rolePath | - |Path to the claim user roles in JWT token, e.g. `resource_access.chatbot-ui.roles` or just `roles`.
| identityProvider.loggingKey | - |User information to search in claims of JWT token.
| identityProvider.loggingSalt | - |Salt to hash user information for logging.
| identityProvider.cacheSize | 10 |How many JWT tokens to cache.
| identityProvider.cacheExpiration | 10 |How long to retain JWT token in cache.
| identityProvider.cacheExpirationUnit | MINUTES |Unit of cache expiration.
| vertx.* | - |Vertx settings.
| server.* | - |Vertx HTTP server settings for incoming requests.
| client.* | - |Vertx HTTP client settings for outbound requests.
| storage.provider | - |Specifies blob storage provider. Supported providers: s3, aws-s3, azureblob, google-cloud-storage
| storage.endpoint | - |Optional. Specifies endpoint url for s3 compatible storages
| storage.identity | - |Blob storage access key
| storage.credential | - |Blob storage secret key
| storage.bucket | - |Blob storage bucket
| storage.createBucket | false |Indicates whether bucket should be created on start-up
| Setting | Default |Description
|------------------------------------------|--------------------|-
| config.files | aidial.config.json |Config files with parts of the whole config.
| config.reload | 60000 |Config reload interval in milliseconds.
| identityProviders | - |List of identity providers
| identityProviders[*].jwksUrl | - |Url to jwks provider.
Maxim-Gadalov marked this conversation as resolved.
Show resolved Hide resolved
| identityProviders[*].rolePath | - |Path to the claim user roles in JWT token, e.g. `resource_access.chatbot-ui.roles` or just `roles`.
| identityProviders[*].loggingKey | - |User information to search in claims of JWT token.
| identityProviders[*].loggingSalt | - |Salt to hash user information for logging.
| identityProviders[*].cacheSize | 10 |How many JWT tokens to cache.
| identityProviders[*].cacheExpiration | 10 |How long to retain JWT token in cache.
| identityProviders[*].cacheExpirationUnit | MINUTES |Unit of cache expiration.
| identityProviders[*].issuerPattern | - |Regexp to match the claim "iss" to identity provider
| vertx.* | - |Vertx settings.
| server.* | - |Vertx HTTP server settings for incoming requests.
| client.* | - |Vertx HTTP client settings for outbound requests.
| storage.provider | - |Specifies blob storage provider. Supported providers: s3, aws-s3, azureblob, google-cloud-storage
| storage.endpoint | - |Optional. Specifies endpoint url for s3 compatible storages
| storage.identity | - |Blob storage access key
| storage.credential | - |Blob storage secret key
| storage.bucket | - |Blob storage bucket
| storage.createBucket | false |Indicates whether bucket should be created on start-up

### Dynamic settings
Dynamic settings are stored in JSON files, specified via "config.files" static setting, and reloaded at interval, specified via "config.reload" static setting.
Expand Down
13 changes: 4 additions & 9 deletions src/main/java/com/epam/aidial/core/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.epam.aidial.core.limiter.RateLimiter;
import com.epam.aidial.core.log.GfLogStore;
import com.epam.aidial.core.log.LogStore;
import com.epam.aidial.core.security.AccessTokenValidator;
import com.epam.aidial.core.security.IdentityProvider;
import com.epam.aidial.core.storage.BlobStorage;
import com.epam.aidial.core.upstream.UpstreamBalancer;
Expand All @@ -22,6 +23,7 @@
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.HttpServerOptions;
import io.vertx.core.json.Json;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.core.metrics.MetricsOptions;
import io.vertx.micrometer.MicrometerMetricsOptions;
Expand Down Expand Up @@ -65,19 +67,12 @@ void start() throws Exception {
LogStore logStore = new GfLogStore(vertx);
RateLimiter rateLimiter = new RateLimiter();
UpstreamBalancer upstreamBalancer = new UpstreamBalancer();

IdentityProvider identityProvider = new IdentityProvider(settings("identityProvider"), vertx, jwksUrl -> {
try {
return new UrlJwkProvider(new URL(jwksUrl));
} catch (MalformedURLException e) {
throw new IllegalArgumentException(e);
}
});
AccessTokenValidator accessTokenValidator = new AccessTokenValidator(settings.getJsonArray("identityProviders", new JsonArray()), vertx);
if (storage == null) {
Storage storageConfig = Json.decodeValue(settings("storage").toBuffer(), Storage.class);
storage = new BlobStorage(storageConfig);
}
Proxy proxy = new Proxy(vertx, client, configStore, logStore, rateLimiter, upstreamBalancer, identityProvider, storage);
Proxy proxy = new Proxy(vertx, client, configStore, logStore, rateLimiter, upstreamBalancer, accessTokenValidator, storage);

server = vertx.createHttpServer(new HttpServerOptions(settings("server"))).requestHandler(proxy);
open(server, HttpServer::listen);
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/com/epam/aidial/core/Proxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.epam.aidial.core.controller.ControllerSelector;
import com.epam.aidial.core.limiter.RateLimiter;
import com.epam.aidial.core.log.LogStore;
import com.epam.aidial.core.security.AccessTokenValidator;
import com.epam.aidial.core.security.ExtractedClaims;
import com.epam.aidial.core.security.IdentityProvider;
import com.epam.aidial.core.storage.BlobStorage;
Expand Down Expand Up @@ -53,7 +54,7 @@ public class Proxy implements Handler<HttpServerRequest> {
private final LogStore logStore;
private final RateLimiter rateLimiter;
private final UpstreamBalancer upstreamBalancer;
private final IdentityProvider identityProvider;
private final AccessTokenValidator tokenValidator;
private final BlobStorage storage;

@Override
Expand Down Expand Up @@ -129,7 +130,7 @@ private void handleRequest(HttpServerRequest request) throws Exception {

request.pause();
final boolean isJwtMustBeValidated = key.getUserAuth() != UserAuth.DISABLED;
Future<ExtractedClaims> extractedClaims = identityProvider.extractClaims(authorization, isJwtMustBeValidated);
Future<ExtractedClaims> extractedClaims = tokenValidator.extractClaims(authorization, isJwtMustBeValidated);

extractedClaims.onComplete(result -> {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.epam.aidial.core.security;

import com.auth0.jwk.UrlJwkProvider;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.google.common.annotations.VisibleForTesting;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.json.JsonArray;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

public class AccessTokenValidator {

private final List<IdentityProvider> providers = new ArrayList<>();

public AccessTokenValidator(JsonArray idpConfig, Vertx vertx) {
int size = idpConfig.size();
if (size < 1) {
throw new IllegalArgumentException("At least one identity provider is required");
}
for (int i = 0; i < idpConfig.size(); i++) {
providers.add(new IdentityProvider(idpConfig.getJsonObject(i), vertx, jwksUrl -> {
try {
return new UrlJwkProvider(new URL(jwksUrl));
} catch (MalformedURLException e) {
throw new IllegalArgumentException(e);
}
}));
}
}

public Future<ExtractedClaims> extractClaims(String authHeader, boolean isJwtMustBeValidated) {
try {
if (authHeader == null) {
return isJwtMustBeValidated ? Future.failedFuture(new IllegalArgumentException("Token is missed")) : Future.succeededFuture();
}
String encodedToken = authHeader.split(" ")[1];
DecodedJWT jwt = IdentityProvider.decodeJwtToken(encodedToken);
if (providers.size() == 1) {
return providers.get(0).extractClaims(jwt, isJwtMustBeValidated);
}
for (IdentityProvider idp : providers) {
if (idp.match(jwt)) {
return idp.extractClaims(jwt, isJwtMustBeValidated);
}
}
return Future.failedFuture(new IllegalArgumentException("Unknown Identity Provider"));
} catch (Throwable e) {
return Future.failedFuture(e);
}
}

@VisibleForTesting
void setProviders(List<IdentityProvider> providers) {
this.providers.clear();
this.providers.addAll(providers);
}
}
50 changes: 26 additions & 24 deletions src/main/java/com/epam/aidial/core/security/IdentityProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.regex.Pattern;

import static java.util.Collections.EMPTY_LIST;

Expand Down Expand Up @@ -49,6 +50,8 @@ public class IdentityProvider {

private final long negativeCacheExpirationMs;

private Pattern issuerPattern;
artsiomkorzun marked this conversation as resolved.
Show resolved Hide resolved

public IdentityProvider(JsonObject settings, Vertx vertx, Function<String, JwkProvider> jwkProviderSupplier) {
if (settings == null) {
throw new IllegalArgumentException("Identity provider settings are missed");
Expand All @@ -74,6 +77,12 @@ public IdentityProvider(JsonObject settings, Vertx vertx, Function<String, JwkPr
throw new IllegalArgumentException(e);
}
obfuscateUserEmail = settings.getBoolean("obfuscateUserEmail", true);

String issuerPatternStr = settings.getString("issuerPattern");
if (issuerPatternStr != null) {
issuerPattern = Pattern.compile(issuerPatternStr);
}

long period = Math.min(negativeCacheExpirationMs, positiveCacheExpirationMs);
vertx.setPeriodic(0, period, event -> evictExpiredJwks());
}
Expand Down Expand Up @@ -114,7 +123,7 @@ private List<String> extractUserRoles(DecodedJWT token) {
return EMPTY_LIST;
}

private DecodedJWT decodeJwtToken(String encodedToken) {
public static DecodedJWT decodeJwtToken(String encodedToken) {
return JWT.decode(encodedToken);
}

Expand All @@ -132,21 +141,20 @@ private Future<JwkResult> getJwk(String kid) {
}));
}

private Future<DecodedJWT> decodeAndVerifyJwtToken(String encodedToken) {
DecodedJWT jwt = decodeJwtToken(encodedToken);
private Future<DecodedJWT> verifyJwt(DecodedJWT jwt) {
String kid = jwt.getKeyId();
Future<JwkResult> future = getJwk(kid);
return future.map(jwkResult -> verifyJwt(encodedToken, jwkResult));
return future.map(jwkResult -> verifyJwt(jwt, jwkResult));
}

private DecodedJWT verifyJwt(String encodedToken, JwkResult jwkResult) {
private DecodedJWT verifyJwt(DecodedJWT jwt, JwkResult jwkResult) {
Exception error = jwkResult.error();
if (error != null) {
throw new RuntimeException(error);
}
Jwk jwk = jwkResult.jwk();
try {
return JWT.require(Algorithm.RSA256((RSAPublicKey) jwk.getPublicKey(), null)).build().verify(encodedToken);
return JWT.require(Algorithm.RSA256((RSAPublicKey) jwk.getPublicKey(), null)).build().verify(jwt);
} catch (JwkException e) {
throw new RuntimeException(e);
}
Expand All @@ -173,28 +181,22 @@ private String extractUserHash(DecodedJWT decodedJwt) {
return keyClaim;
}

public Future<ExtractedClaims> extractClaims(String authHeader, boolean isJwtMustBeVerified) {
try {
if (authHeader == null) {
return isJwtMustBeVerified ? Future.failedFuture(new IllegalArgumentException("Token is missed")) : Future.succeededFuture();
}
// Take the 1st authorization parameter from the header value:
// Authorization: <auth-scheme> <authorization-parameters>
String encodedToken = authHeader.split(" ")[1];
return extractClaimsFromEncodedToken(encodedToken, isJwtMustBeVerified);
} catch (Throwable e) {
return Future.failedFuture(e);
Future<ExtractedClaims> extractClaims(DecodedJWT decodedJwt, boolean isJwtMustBeVerified) {
if (decodedJwt == null) {
return isJwtMustBeVerified ? Future.failedFuture(new IllegalArgumentException("decoded JWT must not be null")) : Future.succeededFuture();
}
Future<DecodedJWT> decodedJwtFuture = isJwtMustBeVerified ? verifyJwt(decodedJwt)
: Future.succeededFuture(decodedJwt);
return decodedJwtFuture.map(jwt -> new ExtractedClaims(extractUserSub(jwt), extractUserRoles(jwt),
extractUserHash(jwt)));
}

public Future<ExtractedClaims> extractClaimsFromEncodedToken(String encodedToken, boolean isJwtMustBeVerified) {
if (encodedToken == null) {
return Future.succeededFuture();
boolean match(DecodedJWT jwt) {
if (issuerPattern == null) {
return false;
}
Future<DecodedJWT> decodedJwt = isJwtMustBeVerified ? decodeAndVerifyJwtToken(encodedToken)
: Future.succeededFuture(decodeJwtToken(encodedToken));
return decodedJwt.map(jwt -> new ExtractedClaims(extractUserSub(jwt), extractUserRoles(jwt),
extractUserHash(jwt)));
String issuer = jwt.getIssuer();
return issuerPattern.matcher(issuer).matches();
}

private record JwkResult(Jwk jwk, Exception error, long expirationTime) {
Expand Down
16 changes: 12 additions & 4 deletions src/main/resources/aidial.settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,18 @@
"files": ["aidial.config.json"],
"reload": 60000
},
"identityProvider": {
"jwksUrl": "http://fakeJwksUrl:8080",
"rolePath": "roles"
},
"identityProviders": [
{
"jwksUrl": "http://fakeJwksUrl:8080",
"rolePath": "roles1",
"issuerPattern": "issuer1"
},
{
"jwksUrl": "http://fakeJwksUrl:8081",
"rolePath": "roles2",
"issuerPattern": "issuer1"
}
],
"storage": {
"provider" : "s3",
"endpoint" : "http://localhost:9000",
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/com/epam/aidial/core/FileApiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,6 @@ private static MultipartForm generateMultipartForm(String fileName, String conte

private static String generateJwtToken(String user) {
Algorithm algorithm = Algorithm.HMAC256("secret_key");
return JWT.create().withClaim("sub", user).sign(algorithm);
return JWT.create().withClaim("iss", "issuer").withClaim("sub", user).sign(algorithm);
}
}
Loading