diff --git a/adapter/config/default_config.go b/adapter/config/default_config.go index c72370914b..b40c05a784 100644 --- a/adapter/config/default_config.go +++ b/adapter/config/default_config.go @@ -228,7 +228,7 @@ var defaultConfig = &Config{ DropConsoleTestHeaders: true, }, APIKeyConfig: apiKeyConfig{ - InternalAPIKeyHeader: "Choreo-API-Key", + InternalAPIKeyHeader: "choreo-api-key", OAuthAgentURL: "https://localhost:9443", }, PATConfig: patConfig{ diff --git a/enforcer-parent/enforcer/src/main/java/org/wso2/choreo/connect/enforcer/security/jwt/ChoreoAPIKeyAuthenticator.java b/enforcer-parent/enforcer/src/main/java/org/wso2/choreo/connect/enforcer/security/jwt/ChoreoAPIKeyAuthenticator.java index 22a2737799..628aedb471 100644 --- a/enforcer-parent/enforcer/src/main/java/org/wso2/choreo/connect/enforcer/security/jwt/ChoreoAPIKeyAuthenticator.java +++ b/enforcer-parent/enforcer/src/main/java/org/wso2/choreo/connect/enforcer/security/jwt/ChoreoAPIKeyAuthenticator.java @@ -25,9 +25,7 @@ import org.wso2.choreo.connect.enforcer.commons.model.AuthenticationContext; import org.wso2.choreo.connect.enforcer.commons.model.RequestContext; import org.wso2.choreo.connect.enforcer.config.ConfigHolder; -import org.wso2.choreo.connect.enforcer.constants.APIConstants; import org.wso2.choreo.connect.enforcer.exception.APISecurityException; -import org.wso2.choreo.connect.enforcer.security.jwt.validator.JWTConstants; import java.util.Base64; import java.util.Map; @@ -65,6 +63,18 @@ public boolean canAuthenticate(RequestContext requestContext) { @Override public AuthenticationContext authenticate(RequestContext requestContext) throws APISecurityException { + return super.authenticate(requestContext); + } + + private String getAPIKeyFromRequest(RequestContext requestContext) { + Map headers = requestContext.getHeaders(); + return headers.get(ConfigHolder.getInstance().getConfig().getApiKeyConfig() + .getApiKeyInternalHeader().toLowerCase()); + } + + @Override + protected String retrieveTokenFromRequestCtx(RequestContext requestContext) { + String apiKeyHeaderValue = getAPIKeyFromRequest(requestContext); // Skipping the prefix(`chk_`) and checksum. String apiKeyData = apiKeyHeaderValue.substring(4, apiKeyHeaderValue.length() - 6); @@ -73,16 +83,7 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws // Convert data into JSON. JSONObject jsonObject = (JSONObject) JSONValue.parse(decodedKeyData); // Extracting the jwt token. - String jwtToken = jsonObject.getAsString(APIKeyConstants.API_KEY_JSON_KEY); - // Add the JWT as the Authorization header to authenticate the request. - requestContext.getHeaders().put(APIConstants.AUTHORIZATION_HEADER_DEFAULT, - JWTConstants.BEARER + " " + jwtToken); - return super.authenticate(requestContext); - } - - private String getAPIKeyFromRequest(RequestContext requestContext) { - Map headers = requestContext.getHeaders(); - return headers.get(ConfigHolder.getInstance().getConfig().getApiKeyConfig().getApiKeyInternalHeader()); + return jsonObject.getAsString(APIKeyConstants.API_KEY_JSON_KEY); } @Override diff --git a/enforcer-parent/enforcer/src/main/java/org/wso2/choreo/connect/enforcer/security/jwt/JWTAuthenticator.java b/enforcer-parent/enforcer/src/main/java/org/wso2/choreo/connect/enforcer/security/jwt/JWTAuthenticator.java index 6682c40f74..b798dd6af6 100644 --- a/enforcer-parent/enforcer/src/main/java/org/wso2/choreo/connect/enforcer/security/jwt/JWTAuthenticator.java +++ b/enforcer-parent/enforcer/src/main/java/org/wso2/choreo/connect/enforcer/security/jwt/JWTAuthenticator.java @@ -171,30 +171,8 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws Utils.setTag(jwtAuthenticatorInfoSpan, APIConstants.LOG_TRACE_ID, ThreadContext.get(APIConstants.LOG_TRACE_ID)); } - String authHeaderVal = retrieveAuthHeaderValue(requestContext); - if (authHeaderVal == null - && requestContext.getMatchedAPI().getApiType().equalsIgnoreCase(APIConstants.ApiType.WEB_SOCKET)) { - String tokenValue = extractJWTInWSProtocolHeader(requestContext); - if (StringUtils.isNotEmpty(tokenValue)) { - authHeaderVal = JWTConstants.BEARER + " " + tokenValue; - } - } - - if (authHeaderVal == null || !authHeaderVal.toLowerCase().contains(JWTConstants.BEARER)) { - throw new APISecurityException(APIConstants.StatusCodes.UNAUTHENTICATED.getCode(), - APISecurityConstants.API_AUTH_MISSING_CREDENTIALS, "Missing Credentials"); - } - String[] splitToken = authHeaderVal.split("\\s"); - String token = authHeaderVal; - // Extract the token when it is sent as bearer token. i.e Authorization: Bearer - if (splitToken.length > 1) { - token = splitToken[1]; - } - // Handle PAT logic - if (isPATEnabled && token.startsWith(APIKeyConstants.PAT_PREFIX)) { - token = exchangeJWTForPAT(requestContext, token); - } + String token = retrieveTokenFromRequestCtx(requestContext); String context = requestContext.getMatchedAPI().getBasePath(); String name = requestContext.getMatchedAPI().getName(); String version = requestContext.getMatchedAPI().getVersion(); @@ -266,7 +244,7 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws ThreadContext.get(APIConstants.LOG_TRACE_ID)); } // if the token is self contained, validation subscription from `subscribedApis` claim - JSONObject api = validateSubscriptionFromClaim(name, version, claims, splitToken, + JSONObject api = validateSubscriptionFromClaim(name, version, claims, token, apiKeyValidationInfoDTO, true); if (api == null) { if (log.isDebugEnabled()) { @@ -527,6 +505,40 @@ private String retrieveAuthHeaderValue(RequestContext requestContext) { return headers.get(FilterUtils.getAuthHeaderName(requestContext)); } + /** + * Extract the JWT token from the request context. + * + * @param requestContext Request context + * @return JWT token + * @throws APISecurityException If an error occurs while extracting the JWT token + */ + protected String retrieveTokenFromRequestCtx(RequestContext requestContext) throws APISecurityException { + + String authHeaderVal = retrieveAuthHeaderValue(requestContext); + if (authHeaderVal == null + && requestContext.getMatchedAPI().getApiType().equalsIgnoreCase(APIConstants.ApiType.WEB_SOCKET)) { + String tokenValue = extractJWTInWSProtocolHeader(requestContext); + if (StringUtils.isNotEmpty(tokenValue)) { + authHeaderVal = JWTConstants.BEARER + " " + tokenValue; + } + } + if (authHeaderVal == null || !authHeaderVal.toLowerCase().contains(JWTConstants.BEARER)) { + throw new APISecurityException(APIConstants.StatusCodes.UNAUTHENTICATED.getCode(), + APISecurityConstants.API_AUTH_MISSING_CREDENTIALS, "Missing Credentials"); + } + String[] splitToken = authHeaderVal.split("\\s"); + String token = authHeaderVal; + // Extract the token when it is sent as bearer token. i.e Authorization: Bearer + if (splitToken.length > 1) { + token = splitToken[1]; + } + // Handle PAT logic + if (isPATEnabled && token.startsWith(APIKeyConstants.PAT_PREFIX)) { + token = exchangeJWTForPAT(requestContext, token); + } + return token; + } + @Override public int getPriority() { return 10; @@ -612,9 +624,9 @@ private APIKeyValidationInfoDTO validateSubscriptionUsingKeyManager(RequestConte * If the subscription information is not found, return a null object. * @throws APISecurityException if the user is not subscribed to the API */ - private JSONObject validateSubscriptionFromClaim(String name, String version, JWTClaimsSet payload, - String[] splitToken, APIKeyValidationInfoDTO validationInfo, - boolean isOauth) throws APISecurityException { + private JSONObject validateSubscriptionFromClaim(String name, String version, JWTClaimsSet payload, String token, + APIKeyValidationInfoDTO validationInfo, boolean isOauth) + throws APISecurityException { JSONObject api = null; try { validationInfo.setEndUserName(payload.getSubject()); @@ -678,7 +690,7 @@ private JSONObject validateSubscriptionFromClaim(String name, String version, JW } if (log.isDebugEnabled()) { log.debug("User is subscribed to the API: " + name + ", " + - "version: " + version + ". Token: " + FilterUtils.getMaskedToken(splitToken[0])); + "version: " + version + ". Token: " + FilterUtils.getMaskedToken(token)); } break; } @@ -686,7 +698,7 @@ private JSONObject validateSubscriptionFromClaim(String name, String version, JW if (api == null) { if (log.isDebugEnabled()) { log.debug("User is not subscribed to access the API: " + name + - ", version: " + version + ". Token: " + FilterUtils.getMaskedToken(splitToken[0])); + ", version: " + version + ". Token: " + FilterUtils.getMaskedToken(token)); } log.error("User is not subscribed to access the API."); throw new APISecurityException(APIConstants.StatusCodes.UNAUTHORIZED.getCode(),