Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Get the Project name from JWT token claims #605

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Priority order:
| identityProviders.*.jwksUrl | - | Optional |Url to jwks provider. **Required** if `disabledVerifyJwt` is set to `false`. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided.
| identityProviders.*.userInfoEndpoint | - | Optional |Url to user info endpoint. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided or `disableJwtVerification` is unset. Refer to [Google example](sample/aidial.settings.json).
| identityProviders.*.rolePath | - | Yes |Path(s) to the claim user roles in JWT token or user info response, e.g. `resource_access.chatbot-ui.roles` or just `roles`. Can be single String or Array of Strings. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.projectPath | - | No |Path(s) to the claim in JWT token or user info response, e.g. `azp`, `aud` or `some.path.client` from which project name can be taken. Can be single String. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.rolesDelimiter | - | No |Delimiter to split roles into array in case when list of roles presented as single String. e.g. `"rolesDelimiter": " "`
| identityProviders.*.loggingKey | - | No |User information to search in claims of JWT token. `email` or `sub` should be sufficient in most cases. **Note**: `email` might be unavailable for some IDPs. Please check your IDP documentation in this case.
| identityProviders.*.loggingSalt | - | No |Salt to hash user information for logging.
Expand Down
9 changes: 8 additions & 1 deletion sample/aidial.settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
"azure": {
"jwksUrl": "https://login.microsoftonline.com/path/discovery/keys",
"rolePath": "groups",
"projectPath": "aud",
"issuerPattern": "^https:\\/\\/some\\.windows\\.net.+$"
},
"keycloak": {
"jwksUrl": "https://host.com/realms/your/protocol/openid-connect/certs",
"rolePath": "resource_access.your.roles",
"projectPath": "azp",
"issuerPattern": "^https:\\/\\/some-keycloak.com.+$"
},
"google": {
"rolePath": "fn:getGoogleWorkspaceGroups",
"projectPath": "aud",
"userInfoEndpoint": "https://openidconnect.googleapis.com/v1/userinfo",
"loggingKey": "email",
"loggingSalt": "salt"
Expand All @@ -29,11 +32,13 @@
"loggingKey": "email",
"issuerPattern": "^https:\\/\\/cognito-idp\\.eu-north-1\\.amazonaws\\.com.+$",
"rolePath": "roles",
"projectPath": "aud",
"jwksUrl": "https://cognito-idp.eu-north-1.amazonaws.com/eu-north-1_PWSAjo4OY/.well-known/jwks.json",
"loggingSalt": "loggingSalt"
},
"gitlab": {
"rolePath": "groups",
"projectPath": "aud",
"userInfoEndpoint": "https://gitlab.com/oauth/userinfo",
"loggingKey": "email",
"loggingSalt": "salt"
Expand All @@ -42,13 +47,15 @@
"loggingKey": "email",
"issuerPattern": "^https:\\/\\/chatbot-ui-staging\\.eu\\.auth0\\.com.+$",
"rolePath": "dial_roles",
"projectPath": "aud",
"jwksUrl": "https://<your_domain>.auth0.com/.well-known/jwks.json",
"loggingSalt": "loggingSalt"
},
"okta": {
"loggingKey": "sub",
"issuerPattern": "^https:\\/\\/<your_domain>\\.okta\\.com.*$",
"rolePath": "Groups",
"projectPath": "aud",
"jwksUrl": "https://<your_domain>.okta.com/oauth2/default/v1/keys",
"loggingSalt": "loggingSalt"
},
Expand All @@ -68,4 +75,4 @@
"key": "key",
"secret": "secret"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public class ProxyContext {
private Deployment deployment;
private String userSub;
private List<String> userRoles;
private String userProject;
private String userHash;
private TokenUsage tokenUsage;
private Route route;
Expand Down Expand Up @@ -121,6 +122,7 @@ private void initExtractedClaims(ExtractedClaims extractedClaims, Key originalKe
this.userRoles = extractedClaims.userRoles();
this.userHash = extractedClaims.userHash();
this.userSub = extractedClaims.sub();
this.userProject = extractedClaims.project();
} else {
this.userRoles = Objects.requireNonNull(originalKey, "API key must be provided if user claims are missed")
.getMergedRoles();
Expand Down Expand Up @@ -149,7 +151,7 @@ public Future<?> respond(HttpStatus status, String body) {
}

if (status != HttpStatus.OK) {
log.warn("Responding with error. Key: {}. Trace: {}. Span: {}. Status: {}. Body: {}", getProject(), traceId, spanId, status,
log.warn("Responding with error. Project: {}. Trace: {}. Span: {}. Status: {}. Body: {}", getProject(), traceId, spanId, status,
body.length() > LOG_MAX_ERROR_LENGTH ? body.substring(0, LOG_MAX_ERROR_LENGTH) : body);
}

Expand All @@ -169,6 +171,9 @@ public Future<?> respond(HttpException exception) {
}

public String getProject() {
if (userProject != null) {
return userProject;
}
return key == null ? null : key.getProject();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private void handleRequestBody(String endpoint, boolean requireEndpoint, Buffer

private void handleRequestError(String deploymentId, Throwable error) {
if (error instanceof PermissionDeniedException) {
log.error("Forbidden deployment {}. Key: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
log.error("Forbidden deployment {}. Project: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
context.respond(HttpStatus.FORBIDDEN, error.getMessage());
} else if (error instanceof ResourceNotFoundException) {
log.error("Deployment not found {}", deploymentId, error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ private Future<?> handleInterceptor(int interceptorIndex) {

private void handleRequestError(String deploymentId, Throwable error) {
if (error instanceof PermissionDeniedException) {
log.error("Forbidden deployment {}. Key: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
log.error("Forbidden deployment {}. Project: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
respond(HttpStatus.FORBIDDEN, error.getMessage());
} else if (error instanceof ResourceNotFoundException) {
log.error("Deployment not found {}", deploymentId, error);
Expand All @@ -174,7 +174,7 @@ private void handleRequestError(String deploymentId, Throwable error) {
}

private Future<?> handleRateLimitSuccess() {
log.info("Received request from client. Trace: {}. Span: {}. Key: {}. Deployment: {}. Headers: {}",
log.info("Received request from client. Trace: {}. Span: {}. Project: {}. Deployment: {}. Headers: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getRequest().headers().size());
Expand Down Expand Up @@ -208,13 +208,13 @@ private void handleRateLimitHit(String deploymentId, RateLimitResult result) {
ErrorData rateLimitError = new ErrorData();
rateLimitError.getError().setCode(String.valueOf(result.status().getCode()));
rateLimitError.getError().setMessage(result.errorMessage());
log.error("Rate limit error {}. Key: {}. User sub: {}. Deployment: {}. Trace: {}. Span: {}", result.errorMessage(),
log.error("Rate limit error {}. Project: {}. User sub: {}. Deployment: {}. Trace: {}. Span: {}", result.errorMessage(),
context.getProject(), context.getUserSub(), deploymentId, context.getTraceId(), context.getSpanId());
respond(result.status(), rateLimitError);
}

private void handleError(Throwable error) {
log.error("Can't handle request. Key: {}. User sub: {}. Trace: {}. Span: {}. Error: {}",
log.error("Can't handle request. Project: {}. User sub: {}. Trace: {}. Span: {}. Error: {}",
context.getProject(), context.getUserSub(), context.getTraceId(), context.getSpanId(), error.getMessage());
respond(HttpStatus.INTERNAL_SERVER_ERROR);
}
Expand All @@ -241,7 +241,7 @@ private void sendRequest() {
@VisibleForTesting
void handleRequestBody(Buffer requestBody) {
Deployment deployment = context.getDeployment();
log.info("Received body from client. Trace: {}. Span: {}. Key: {}. Deployment: {}. Length: {}",
log.info("Received body from client. Trace: {}. Span: {}. Project: {}. Deployment: {}. Length: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), deployment.getName(), requestBody.length());

Expand Down Expand Up @@ -272,7 +272,7 @@ void handleRequestBody(Buffer requestBody) {
*/
@VisibleForTesting
void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to origin. Trace: {}. Span: {}. Key: {}. Deployment: {}. Address: {}",
log.info("Connected to origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
proxyRequest.connection().remoteAddress());
Expand Down Expand Up @@ -314,7 +314,7 @@ void handleProxyRequest(HttpClientRequest proxyRequest) {
private void handleProxyResponse(HttpClientResponse proxyResponse) {
UpstreamRoute upstreamRoute = context.getUpstreamRoute();
Upstream currentUpstream = upstreamRoute.get();
log.info("Received header from origin. Trace: {}. Span: {}. Key: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Headers: {}",
log.info("Received header from origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Headers: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getDeployment().getEndpoint(), currentUpstream == null ? "N/A" : currentUpstream.getEndpoint(),
Expand Down Expand Up @@ -404,7 +404,7 @@ private Future<TokenUsage> collectTokenUsage(Buffer responseBody) {
if (tokenUsage == null) {
Pricing pricing = model.getPricing();
if (pricing == null || "token".equals(pricing.getUnit())) {
log.warn("Can't find token usage. Trace: {}. Span: {}. Key: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Length: {}",
log.warn("Can't find token usage. Trace: {}. Span: {}. Project: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Length: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getDeployment().getEndpoint(),
Expand Down Expand Up @@ -454,7 +454,7 @@ private void completeProxyResponse(BufferingReadStream responseStream) {

proxy.getLogStore().save(context);

log.info("Sent response to client. Trace: {}. Span: {}. Key: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Length: {}."
log.info("Sent response to client. Trace: {}. Span: {}. Project: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Length: {}."
+ " Timing: {} (body={}, connect={}, header={}, body={}). Tokens: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
Expand Down Expand Up @@ -486,7 +486,7 @@ private void handleRequestBodyError(Throwable error) {
* Called when proxy failed to connect to the origin.
*/
private void handleProxyConnectionError(Throwable error) {
log.warn("Can't connect to origin. Trace: {}. Span: {}. Key: {}. Deployment: {}. Address: {}. Error: {}",
log.warn("Can't connect to origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}. Error: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
buildUri(context), error.getMessage());
Expand All @@ -499,7 +499,7 @@ private void handleProxyConnectionError(Throwable error) {
*/
private void handleProxyResponseError(Throwable error) {
UpstreamRoute upstreamRoute = context.getUpstreamRoute();
log.warn("Proxy failed to receive response header from origin. Trace: {}. Span: {}. Key: {}. Deployment: {}. Address: {}. Error:",
log.warn("Proxy failed to receive response header from origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}. Error:",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getProxyRequest().connection().remoteAddress(),
Expand Down Expand Up @@ -557,7 +557,7 @@ private boolean canRetry(UpstreamRoute route) {
try {
route.next();
} catch (HttpException e) {
log.error("No route. Trace: {}. Span: {}. Key: {}. Deployment: {}. User sub: {}",
log.error("No route. Trace: {}. Span: {}. Project: {}. Deployment: {}. User sub: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(), context.getUserSub());
respond(e);
Expand Down Expand Up @@ -598,4 +598,4 @@ private void finalizeRequest() {
}).onFailure(error -> log.error("error occurred on invalidating per-request key", error));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import io.vertx.core.http.RequestOptions;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.io.InputStream;
import java.util.List;

Expand All @@ -43,7 +42,7 @@ public InterceptorController(Proxy proxy, ProxyContext context) {
}

public Future<?> handle() {
log.info("Received request from client. Trace: {}. Span: {}. Key: {}. Deployment: {}. Headers: {}",
log.info("Received request from client. Trace: {}. Span: {}. Project: {}. Deployment: {}. Headers: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getRequest().headers().size());
Expand All @@ -60,7 +59,7 @@ public Future<?> handle() {
}

private void handleError(Throwable error) {
log.error("Can't handle request. Key: {}. User sub: {}. Trace: {}. Span: {}. Error: {}",
log.error("Can't handle request. Project: {}. User sub: {}. Trace: {}. Span: {}. Error: {}",
context.getProject(), context.getUserSub(), context.getTraceId(), context.getSpanId(), error.getMessage());
respond(HttpStatus.INTERNAL_SERVER_ERROR);
}
Expand Down Expand Up @@ -118,7 +117,7 @@ private void handleRequestBodyError(Throwable error) {
* Called when proxy failed to connect to the origin.
*/
private void handleProxyConnectionError(Throwable error) {
log.warn("Can't connect to origin. Trace: {}. Span: {}. Key: {}. Deployment: {}. Address: {}. Error: {}",
log.warn("Can't connect to origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}. Error: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getDeployment().getEndpoint(), error.getMessage());
Expand All @@ -128,7 +127,7 @@ private void handleProxyConnectionError(Throwable error) {


void handleProxyRequest(HttpClientRequest proxyRequest) {
log.info("Connected to interceptor. Trace: {}. Span: {}. Key: {}. Deployment: {}. Address: {}",
log.info("Connected to interceptor. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
proxyRequest.connection().remoteAddress());
Expand All @@ -155,15 +154,15 @@ void handleProxyRequest(HttpClientRequest proxyRequest) {
* Called when proxy failed to receive response header from origin.
*/
private void handleProxyResponseError(Throwable error) {
log.warn("Proxy failed to receive response header from origin. Trace: {}. Span: {}. Key: {}. Deployment: {}. Address: {}. Error:",
log.warn("Proxy failed to receive response header from origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Address: {}. Error:",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getProxyRequest().connection().remoteAddress(),
error);
}

private void handleProxyResponse(HttpClientResponse proxyResponse) {
log.info("Received header from origin. Trace: {}. Span: {}. Key: {}. Deployment: {}. Endpoint: {}. Status: {}. Headers: {}",
log.info("Received header from origin. Trace: {}. Span: {}. Project: {}. Deployment: {}. Endpoint: {}. Status: {}. Headers: {}",
context.getTraceId(), context.getSpanId(),
context.getProject(), context.getDeployment().getName(),
context.getDeployment().getEndpoint(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public Future<?> getLimits(String deploymentId) {

private void handleRequestError(String deploymentId, Throwable error) {
if (error instanceof PermissionDeniedException) {
log.error("LimitController. Forbidden deployment {}. Key: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
log.error("LimitController. Forbidden deployment {}. Project: {}. User sub: {}", deploymentId, context.getProject(), context.getUserSub());
context.respond(HttpStatus.FORBIDDEN, error.getMessage());
} else if (error instanceof ResourceNotFoundException) {
log.error("LimitController. Deployment not found {}", deploymentId, error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public Future<?> handle() {
}

if (!route.hasAccess(context.getUserRoles())) {
log.error("Forbidden route {}. Trace: {}. Span: {}. Key: {}. User sub: {}.",
log.error("Forbidden route {}. Trace: {}. Span: {}. Project: {}. User sub: {}.",
route.getName(), context.getTraceId(), context.getSpanId(), context.getProject(), context.getUserSub());
context.respond(HttpStatus.FORBIDDEN, "Forbidden route");
return Future.succeededFuture();
Expand Down Expand Up @@ -205,7 +205,7 @@ private void handleRateLimitHit(RateLimitResult result) {
ErrorData rateLimitError = new ErrorData();
rateLimitError.getError().setCode(String.valueOf(result.status().getCode()));
rateLimitError.getError().setMessage(result.errorMessage());
log.error("Rate limit error {}. Key: {}. User sub: {}. Route: {}. Trace: {}. Span: {}", result.errorMessage(),
log.error("Rate limit error {}. Project: {}. User sub: {}. Route: {}. Trace: {}. Span: {}", result.errorMessage(),
context.getProject(), context.getUserSub(), context.getRoute().getName(), context.getTraceId(),
context.getSpanId());
context.respond(result.status(), rateLimitError);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
import java.util.List;
import java.util.Map;

public record ExtractedClaims(String sub, List<String> userRoles, String userHash, Map<String, List<String>> userClaims) {
public record ExtractedClaims(String sub, List<String> userRoles, String userHash, Map<String, List<String>> userClaims, String project) {
}
Loading
Loading