Skip to content

Commit

Permalink
fix: switch over status codes to errors
Browse files Browse the repository at this point in the history
  • Loading branch information
stebenz committed Dec 10, 2024
1 parent 4ed9e36 commit eab34b4
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 63 deletions.
12 changes: 6 additions & 6 deletions pkg/provider/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req

samlResponse, err := p.loginResponse(r.Context(), authRequest, response)
if err != nil {
response.sendBackResponse(r, w, response.makeFailedResponse(err.Error(), "failed to create response", p.TimeFormat))
response.sendBackResponse(r, w, response.makeFailedResponse(err, "failed to create response", p.TimeFormat))
return
}

Expand All @@ -66,29 +66,29 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
func (p *IdentityProvider) loginResponse(ctx context.Context, authRequest models.AuthRequestInt, response *Response) (*samlp.ResponseType, error) {
if !authRequest.Done() {
logging.Error(StatusCodeAuthNFailed)
return nil, fmt.Errorf(StatusCodeAuthNFailed)
return nil, StatusCodeAuthNFailed
}

attrs := &Attributes{}
if err := p.storage.SetUserinfoWithUserID(ctx, authRequest.GetApplicationID(), attrs, authRequest.GetUserID(), []int{}); err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
return nil, StatusCodeInvalidAttrNameOrValue
}

cert, key, err := getResponseCert(ctx, p.storage)
if err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
return nil, StatusCodeInvalidAttrNameOrValue
}

samlResponse := response.makeSuccessfulResponse(attrs, p.TimeFormat, p.Expiration)
if err := createSignature(response, samlResponse, key, cert, p.conf.SignatureAlgorithm); err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeResponder)
return nil, StatusCodeResponder
}
return samlResponse, nil
}

func (p *IdentityProvider) errorResponse(response *Response, reason, description string) *samlp.ResponseType {
func (p *IdentityProvider) errorResponse(response *Response, reason error, description string) *samlp.ResponseType {
return response.makeFailedResponse(reason, description, p.TimeFormat)
}
2 changes: 1 addition & 1 deletion pkg/provider/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func TestSSO_loginHandleFunc(t *testing.T) {
},
res{
code: 302,
state: StatusCodeAuthNFailed,
state: StatusCodeAuthNFailed.Error(),
err: false,
inflate: true,
b64: true,
Expand Down
8 changes: 4 additions & 4 deletions pkg/provider/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.TimeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to parse form: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.TimeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to decode request: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -72,7 +72,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
p.TimeFormat,
),
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error(), p.TimeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate request: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -83,7 +83,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return err
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat))
},
)

Expand Down
39 changes: 6 additions & 33 deletions pkg/provider/logout_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,55 +55,28 @@ func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *sam
}
}

func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeSuccess,
"",
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makeUnsupportedlLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeRequestUnsupported,
message,
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makePartialLogoutResponse(
func (r *LogoutResponse) makeFailedLogoutResponse(
reason error,
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodePartialLogout,
reason.Error(),
message,
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makeDeniedLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeRequestDenied,
message,
statusCodeSuccess.Error(),
"",
getIssuer(r.Issuer),
)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func (p *Provider) AuthCallbackResponse(ctx context.Context, authRequest models.
}

// AuthCallbackErrorResponse returns the SAMLResponse from as failed SAMLRequest
func (p *Provider) AuthCallbackErrorResponse(response *Response, reason, description string) *samlp.ResponseType {
func (p *Provider) AuthCallbackErrorResponse(response *Response, reason error, description string) *samlp.ResponseType {
return p.identityProvider.errorResponse(response, reason, description)
}

Expand Down
31 changes: 16 additions & 15 deletions pkg/provider/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package provider
import (
"crypto/rsa"
"encoding/base64"
"errors"
"fmt"
"html/template"
"net/http"
Expand All @@ -13,17 +14,17 @@ import (
"github.com/zitadel/saml/pkg/provider/xml/samlp"
)

const (
StatusCodeSuccess = "urn:oasis:names:tc:SAML:2.0:status:Success"
StatusCodeVersionMissmatch = "urn:oasis:names:tc:SAML:2.0:status:VersionMismatch"
StatusCodeAuthNFailed = "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed"
StatusCodeInvalidAttrNameOrValue = "urn:oasis:names:tc:SAML:2.0:status:InvalidAttrNameOrValue"
StatusCodeInvalidNameIDPolicy = "urn:oasis:names:tc:SAML:2.0:status:InvalidNameIDPolicy"
StatusCodeRequestDenied = "urn:oasis:names:tc:SAML:2.0:status:RequestDenied"
StatusCodeRequestUnsupported = "urn:oasis:names:tc:SAML:2.0:status:RequestUnsupported"
StatusCodeUnsupportedBinding = "urn:oasis:names:tc:SAML:2.0:status:UnsupportedBinding"
StatusCodeResponder = "urn:oasis:names:tc:SAML:2.0:status:Responder"
StatusCodePartialLogout = "urn:oasis:names:tc:SAML:2.0:status:PartialLogout"
var (
statusCodeSuccess = errors.New("urn:oasis:names:tc:SAML:2.0:status:Success")
StatusCodeVersionMissmatch = errors.New("urn:oasis:names:tc:SAML:2.0:status:VersionMismatch")
StatusCodeAuthNFailed = errors.New("urn:oasis:names:tc:SAML:2.0:status:AuthnFailed")
StatusCodeInvalidAttrNameOrValue = errors.New("urn:oasis:names:tc:SAML:2.0:status:InvalidAttrNameOrValue")
StatusCodeInvalidNameIDPolicy = errors.New("urn:oasis:names:tc:SAML:2.0:status:InvalidNameIDPolicy")
StatusCodeRequestDenied = errors.New("urn:oasis:names:tc:SAML:2.0:status:RequestDenied")
StatusCodeRequestUnsupported = errors.New("urn:oasis:names:tc:SAML:2.0:status:RequestUnsupported")
StatusCodeUnsupportedBinding = errors.New("urn:oasis:names:tc:SAML:2.0:status:UnsupportedBinding")
StatusCodeResponder = errors.New("urn:oasis:names:tc:SAML:2.0:status:Responder")
StatusCodePartialLogout = errors.New("urn:oasis:names:tc:SAML:2.0:status:PartialLogout")
)

type Response struct {
Expand Down Expand Up @@ -112,7 +113,7 @@ func createSignature(response *Response, samlResponse *samlp.ResponseType, key *
}

func (r *Response) makeFailedResponse(
reason string,
reason error,
message string,
timeFormat string,
) *samlp.ResponseType {
Expand All @@ -123,7 +124,7 @@ func (r *Response) makeFailedResponse(
r.RequestID,
r.AcsUrl,
nowStr,
reason,
reason.Error(),
message,
r.Issuer,
)
Expand Down Expand Up @@ -151,7 +152,7 @@ func (r *Response) makeAssertionResponse(
attributes *Attributes,
) *samlp.ResponseType {

response := makeResponse(NewID(), r.RequestID, r.AcsUrl, issueInstant, StatusCodeSuccess, "", r.Issuer)
response := makeResponse(NewID(), r.RequestID, r.AcsUrl, issueInstant, statusCodeSuccess.Error(), "", r.Issuer)
assertion := makeAssertion(r.RequestID, r.AcsUrl, r.SendIP, issueInstant, untilInstant, r.Issuer, attributes.GetNameID(), attributes.GetSAML(), r.Audience, true)
response.Assertion = *assertion
return response
Expand Down Expand Up @@ -194,7 +195,7 @@ func makeAttributeQueryResponse(
}
}

response := makeResponse(NewID(), requestID, "", nowStr, StatusCodeSuccess, "", issuer)
response := makeResponse(NewID(), requestID, "", nowStr, statusCodeSuccess.Error(), "", issuer)
assertion := makeAssertion(requestID, "", "", nowStr, fiveFromNowStr, issuer, attributes.GetNameID(), providedAttrs, entityID, false)
response.Assertion = *assertion
return response
Expand Down
6 changes: 3 additions & 3 deletions pkg/provider/sso_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
},
res{
code: 200,
state: StatusCodeRequestDenied,
state: StatusCodeRequestDenied.Error(),
err: false,
}},
{
Expand Down Expand Up @@ -557,7 +557,7 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
},
res{
code: 200,
state: StatusCodeRequestDenied,
state: StatusCodeRequestDenied.Error(),
err: false,
}},
{
Expand Down Expand Up @@ -590,7 +590,7 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
},
res{
code: 200,
state: StatusCodeRequestDenied,
state: StatusCodeRequestDenied.Error(),
err: false,
}},
{
Expand Down

0 comments on commit eab34b4

Please sign in to comment.