Skip to content

Commit

Permalink
Proposed fix for missing WWW-Authenticate header
Browse files Browse the repository at this point in the history
Current implementation does not include the WWW-Authenticate
header when returning a 401 for missing/invalid credentials when
attempting to access the token endpoints. This PR would change
to use the standard BasicAuthenticationEntryPoint in order to
populate this header correctly.

Fixes-468

Signed-off-by: Lucian Holland <[email protected]>
  • Loading branch information
symposion committed Jan 27, 2025
1 parent b76300b commit 552a4a0
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.springframework.context.event.GenericApplicationListenerAdapter;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
Expand All @@ -48,7 +47,7 @@
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.web.NimbusJwkSetEndpointFilter;
import org.springframework.security.web.authentication.HttpStatusEntryPoint;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ServerAuthenticationEntryPoint;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
import org.springframework.security.web.context.SecurityContextHolderFilter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
Expand Down Expand Up @@ -344,7 +343,8 @@ public void init(HttpSecurity httpSecurity) throws Exception {
ExceptionHandlingConfigurer<HttpSecurity> exceptionHandling = httpSecurity
.getConfigurer(ExceptionHandlingConfigurer.class);
if (exceptionHandling != null) {
exceptionHandling.defaultAuthenticationEntryPointFor(new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED),
var entryPoint = new OAuth2ServerAuthenticationEntryPoint();
exceptionHandling.defaultAuthenticationEntryPointFor(entryPoint,
new OrRequestMatcher(getRequestMatcher(OAuth2TokenEndpointConfigurer.class),
getRequestMatcher(OAuth2TokenIntrospectionEndpointConfigurer.class),
getRequestMatcher(OAuth2TokenRevocationEndpointConfigurer.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,18 @@
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientConfigurationAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
Expand Down Expand Up @@ -85,13 +84,11 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi

private final HttpMessageConverter<OidcClientRegistration> clientRegistrationHttpMessageConverter = new OidcClientRegistrationHttpMessageConverter();

private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();

private AuthenticationConverter authenticationConverter = new OidcClientRegistrationAuthenticationConverter();

private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendClientRegistrationResponse;

private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();

/**
* Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided
Expand Down Expand Up @@ -224,22 +221,4 @@ private void sendClientRegistrationResponse(HttpServletRequest request, HttpServ
this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse);
}

private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authenticationException) throws IOException {
OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
if (OAuth2ErrorCodes.INVALID_TOKEN.equals(error.getErrorCode())) {
httpStatus = HttpStatus.UNAUTHORIZED;
}
else if (OAuth2ErrorCodes.INSUFFICIENT_SCOPE.equals(error.getErrorCode())) {
httpStatus = HttpStatus.FORBIDDEN;
}
else if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
httpStatus = HttpStatus.UNAUTHORIZED;
}
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponse.setStatusCode(httpStatus);
this.errorHttpResponseConverter.write(error, null, httpResponse);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,19 @@

import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcUserInfoHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
Expand Down Expand Up @@ -73,13 +71,11 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {

private final HttpMessageConverter<OidcUserInfo> userInfoHttpMessageConverter = new OidcUserInfoHttpMessageConverter();

private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();

private AuthenticationConverter authenticationConverter = this::createAuthentication;

private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendUserInfoResponse;

private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();

/**
* Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters.
Expand Down Expand Up @@ -193,19 +189,4 @@ private void sendUserInfoResponse(HttpServletRequest request, HttpServletRespons
this.userInfoHttpMessageConverter.write(userInfoAuthenticationToken.getUserInfo(), null, httpResponse);
}

private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authenticationException) throws IOException {
OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) {
httpStatus = HttpStatus.UNAUTHORIZED;
}
else if (error.getErrorCode().equals(OAuth2ErrorCodes.INSUFFICIENT_SCOPE)) {
httpStatus = HttpStatus.FORBIDDEN;
}
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponse.setStatusCode(httpStatus);
this.errorHttpResponseConverter.write(error, null, httpResponse);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,15 @@
import jakarta.servlet.http.HttpServletResponse;

import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.JwtClientAssertionAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
Expand All @@ -46,6 +41,7 @@
import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretBasicAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretPostAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.JwtClientAssertionAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
import org.springframework.security.oauth2.server.authorization.web.authentication.PublicClientAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.X509ClientCertificateAuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationConverter;
Expand Down Expand Up @@ -86,15 +82,13 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter

private final RequestMatcher requestMatcher;

private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();

private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();

private AuthenticationConverter authenticationConverter;

private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;

private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure;
private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();

/**
* Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided
Expand Down Expand Up @@ -199,35 +193,6 @@ private void onAuthenticationSuccess(HttpServletRequest request, HttpServletResp
}
}

private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException {

SecurityContextHolder.clearContext();

// TODO
// The authorization server MAY return an HTTP 401 (Unauthorized) status code
// to indicate which HTTP authentication schemes are supported.
// If the client attempted to authenticate via the "Authorization" request header
// field,
// the authorization server MUST respond with an HTTP 401 (Unauthorized) status
// code and
// include the "WWW-Authenticate" response header field
// matching the authentication scheme used by the client.

OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
}
else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
}
// We don't want to reveal too much information to the caller so just return the
// error code
OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
}

private static void validateClientIdentifier(Authentication authentication) {
if (!(authentication instanceof OAuth2ClientAuthenticationToken)) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.util.Assert;
Expand All @@ -49,17 +51,27 @@ public final class OAuth2ErrorAuthenticationFailureHandler implements Authentica

private HttpMessageConverter<OAuth2Error> errorResponseConverter = new OAuth2ErrorHttpMessageConverter();

private final String realmName = "oauth2"; // TODO configure this properly

@Override
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authenticationException) throws IOException, ServletException {

SecurityContextHolder.clearContext();

ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);

if (authenticationException instanceof OAuth2AuthenticationException) {
OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
var status = getHttpStatus(error);
httpResponse.setStatusCode(status);
if (status == HttpStatus.UNAUTHORIZED && this.realmName != null) {
httpResponse.getHeaders().set("WWW-Authenticate", "Basic realm=\"" + this.realmName + "\"");
}
this.errorResponseConverter.write(error, null, httpResponse);
}
else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
if (this.logger.isWarnEnabled()) {
this.logger.warn(AuthenticationException.class.getSimpleName() + " must be of type "
+ OAuth2AuthenticationException.class.getName() + " but was "
Expand All @@ -68,6 +80,14 @@ public void onAuthenticationFailure(HttpServletRequest request, HttpServletRespo
}
}

private HttpStatus getHttpStatus(OAuth2Error error) {
return switch (error.getErrorCode()) {
case OAuth2ErrorCodes.INVALID_CLIENT, OAuth2ErrorCodes.INVALID_TOKEN -> HttpStatus.UNAUTHORIZED;
case OAuth2ErrorCodes.INSUFFICIENT_SCOPE, OAuth2ErrorCodes.ACCESS_DENIED -> HttpStatus.FORBIDDEN;
default -> HttpStatus.BAD_REQUEST;
};
}

/**
* Sets the {@link HttpMessageConverter} used for converting an {@link OAuth2Error} to
* an HTTP response.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package org.springframework.security.oauth2.server.authorization.web.authentication;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.InsufficientAuthenticationException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.web.AuthenticationEntryPoint;

import java.io.IOException;

public class OAuth2ServerAuthenticationEntryPoint implements AuthenticationEntryPoint {

private final OAuth2ErrorAuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();

@Override
public void commence(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authException) throws IOException, ServletException {
var convertedException = convertInsufficientAccessException(authException);

if (authException instanceof OAuth2AuthenticationException) {
authenticationFailureHandler.onAuthenticationFailure(request, response, convertedException);
}
else {
response.sendError(HttpStatus.BAD_REQUEST.value(), HttpStatus.BAD_REQUEST.getReasonPhrase());
}
}

private AuthenticationException convertInsufficientAccessException(AuthenticationException authException) {
if (authException instanceof InsufficientAuthenticationException) {
return new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT),
authException.getCause());
}
return authException;
}

}

0 comments on commit 552a4a0

Please sign in to comment.