diff --git a/oauthproxy.go b/oauthproxy.go index 10a69b311d..7181cc3e4b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -98,6 +98,7 @@ type OAuthProxy struct { forceJSONErrors bool realClientIPParser ipapi.RealClientIPParser trustedIPs *ip.NetSet + statePostfix string sessionChain alice.Chain headersChain alice.Chain @@ -235,6 +236,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr upstreamProxy: upstreamProxy, redirectValidator: redirectValidator, appDirector: appDirector, + statePostfix: opts.StatePostfix, } p.buildServeMux(opts.ProxyPrefix) @@ -787,7 +789,7 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove callbackRedirect := p.getOAuthRedirectURI(req) loginURL := p.provider.GetLoginURL( callbackRedirect, - encodeState(csrf.HashOAuthState(), appRedirect), + encodeState(csrf.HashOAuthState(), appRedirect, p.statePostfix), csrf.HashOIDCNonce(), extraParams, ) @@ -845,7 +847,8 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { csrf.ClearCookie(rw, req) - nonce, appRedirect, err := decodeState(req) + nonce, appRedirect, _, err := decodeState(req) + if err != nil { logger.Errorf("Error while parsing OAuth2 state: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) @@ -1185,18 +1188,20 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool { // encodedState builds the OAuth state param out of our nonce and // original application redirect -func encodeState(nonce string, redirect string) string { - return fmt.Sprintf("%v:%v", nonce, redirect) +func encodeState(nonce string, redirect string, additional string) string { + return fmt.Sprintf("%v|%v|%v", nonce, redirect, additional) } // decodeState splits the reflected OAuth state response back into // the nonce and original application redirect -func decodeState(req *http.Request) (string, string, error) { - state := strings.SplitN(req.Form.Get("state"), ":", 2) - if len(state) != 2 { - return "", "", errors.New("invalid length") +func decodeState(req *http.Request) (string, string, string, error) { + state := strings.SplitN(req.Form.Get("state"), "|", 3) + + if len(state) != 3 { + return "", "", "", errors.New("invalid length") } - return state[0], state[1], nil + + return state[0], state[1], state[2], nil } // addHeadersForProxying adds the appropriate headers the request / response for proxying diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 0d8bc91a6e..689111181e 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -413,7 +413,7 @@ func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie http.MethodGet, fmt.Sprintf( "/oauth2/callback?code=callback_code&state=%s", - encodeState(csrf.HashOAuthState(), "%2F"), + encodeState(csrf.HashOAuthState(), "", "%2F"), ), strings.NewReader(""), ) diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 0af8df3fc6..0ca47a5abd 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -27,6 +27,7 @@ type Options struct { TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"` ForceHTTPS bool `flag:"force-https" cfg:"force_https"` RawRedirectURL string `flag:"redirect-url" cfg:"redirect_url"` + StatePostfix string `flag:"state-postfix" cfg:"state_postfix"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` EmailDomains []string `flag:"email-domain" cfg:"email_domains"` @@ -151,6 +152,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") + flagSet.String("state-postfix", "", "state_postifx") flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(loggingFlagSet())