Skip to content

Commit

Permalink
define custom apple claims with validation
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCuse committed Mar 29, 2024
1 parent 5645188 commit 619b2e9
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 130 deletions.
40 changes: 40 additions & 0 deletions lib/oauth/apple/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package apple

import (
"fmt"

"github.com/go-jose/go-jose/v3/jwt"
)

type Claims struct {
Email string `json:"email"`
jwt.Claims
}

// Validate performs apple-specific id_token validation.
// `email` is the only additional claim we currently require.
// See https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_rest_api/authenticating_users_with_sign_in_with_apple#3383773
// for more details.
func (c Claims) Validate(clientID string) error {
if clientID == "" {
return fmt.Errorf("cannot validate with empty clientID")
}

if c.Email == "" {
return fmt.Errorf("missing claim 'email'")
}

if c.Expiry == nil {
return fmt.Errorf("missing claim 'exp'")
}

if c.IssuedAt == nil {
return fmt.Errorf("missing claim 'iat'")
}

// is default 1m leeway OK here?
return c.Claims.Validate(jwt.Expected{
Issuer: BaseURL,
Audience: jwt.Audience{clientID},
})
}
151 changes: 151 additions & 0 deletions lib/oauth/apple/claims_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package apple_test

import (
"testing"
"time"

"github.com/go-jose/go-jose/v3/jwt"
"github.com/keratin/authn-server/lib/oauth/apple"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

func TestClaimsValidate(t *testing.T) {
t.Run("invalid", func(t *testing.T) {
t.Run("empty clientID", func(t *testing.T) {
claims := apple.Claims{}
err := claims.Validate("")
assert.EqualError(t, err, "cannot validate with empty clientID")
})

t.Run("missing email", func(t *testing.T) {
claims := apple.Claims{}
err := claims.Validate("audience")
assert.EqualError(t, err, "missing claim 'email'")
})

t.Run("missing Expiry", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{},
}
err := claims.Validate("audience")
assert.EqualError(t, err, "missing claim 'exp'")
})

t.Run("missing IssuedAt", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{
Expiry: jwt.NewNumericDate(time.Now()),
},
}
err := claims.Validate("audience")
assert.EqualError(t, err, "missing claim 'iat'")
})
})

// ensure we don't break underlying jwt.Claims validation
t.Run("jwt.Claims.Validate", func(t *testing.T) {
t.Run("issuer", func(t *testing.T) {
t.Run("missing", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{
Expiry: jwt.NewNumericDate(time.Unix(0, 0)),
IssuedAt: jwt.NewNumericDate(time.Unix(0, 0)),
},
}
err := claims.Validate("audience")
assert.True(t, errors.Is(err, jwt.ErrInvalidIssuer))
})

t.Run("invalid", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{
Issuer: "invalid",
Expiry: jwt.NewNumericDate(time.Unix(0, 0)),
IssuedAt: jwt.NewNumericDate(time.Unix(0, 0)),
},
}
err := claims.Validate("audience")
assert.True(t, errors.Is(err, jwt.ErrInvalidIssuer))
})
})

t.Run("audience", func(t *testing.T) {
t.Run("missing", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{
Issuer: apple.BaseURL,
Expiry: jwt.NewNumericDate(time.Unix(0, 0)),
IssuedAt: jwt.NewNumericDate(time.Unix(0, 0)),
},
}
err := claims.Validate("audience")
assert.True(t, errors.Is(err, jwt.ErrInvalidAudience))
})

t.Run("invalid", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{
Issuer: apple.BaseURL,
Audience: jwt.Audience{"invalid"},
Expiry: jwt.NewNumericDate(time.Unix(0, 0)),
IssuedAt: jwt.NewNumericDate(time.Unix(0, 0)),
},
}
err := claims.Validate("audience")
assert.True(t, errors.Is(err, jwt.ErrInvalidAudience))
})
})

t.Run("expired", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{
Issuer: apple.BaseURL,
Audience: jwt.Audience{"audience"},
// Default leeway is 1 minute
Expiry: jwt.NewNumericDate(time.Now().Add(-2 * time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Unix(0, 0)),
},
}
err := claims.Validate("audience")
assert.True(t, errors.Is(err, jwt.ErrExpired))
})

t.Run("issued in the future", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{
Issuer: apple.BaseURL,
Audience: jwt.Audience{"audience"},
Expiry: jwt.NewNumericDate(time.Now()),
// Default leeway is 1 minute
IssuedAt: jwt.NewNumericDate(time.Now().Add(2 * time.Minute)),
},
}
err := claims.Validate("audience")
assert.Error(t, err)
assert.True(t, errors.Is(err, jwt.ErrIssuedInTheFuture))
})
})

t.Run("valid", func(t *testing.T) {
claims := apple.Claims{
Email: "email",
Claims: jwt.Claims{
Issuer: apple.BaseURL,
Audience: jwt.Audience{"audience"},
IssuedAt: jwt.NewNumericDate(time.Now().Add(-30 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(30 * time.Second)),
},
}
err := claims.Validate("audience")
assert.NoError(t, err)
})
}
117 changes: 30 additions & 87 deletions lib/oauth/apple/idtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package apple
import (
"fmt"
"net/http"
"strings"
"time"

"github.com/go-jose/go-jose/v3"
Expand All @@ -16,38 +15,44 @@ type TokenReader struct {
clientID string
}

func NewTokenReader(clientID string) *TokenReader {
return &TokenReader{
func NewTokenReader(clientID string, opts ...func(*TokenReader)) *TokenReader {
tr := &TokenReader{
clientID: clientID,
keyStore: newSigningKeyStore(&http.Client{
}

for _, opt := range opts {
opt(tr)
}

if tr.keyStore == nil {
tr.keyStore = newSigningKeyStore(&http.Client{
Timeout: 10 * time.Second,
}),
})
}

return tr
}

func (tr *TokenReader) GetUserDetailsFromToken(t *oauth2.Token) (string, string, error) {
claims, err := tr.getAppleIDTokenClaims(t)
if err != nil {
return "", "", fmt.Errorf("failed to get apple ID token claims: %w", err)
func WithKeyStore(ks rsaKeyStore) func(*TokenReader) {
return func(tr *TokenReader) {
tr.keyStore = ks
}

return extractUserFromClaims(claims, tr.clientID)
}

func (tr *TokenReader) getAppleIDTokenClaims(t *oauth2.Token) (map[string]interface{}, error) {
func (tr *TokenReader) GetUserDetailsFromToken(t *oauth2.Token) (string, string, error) {
idTokenVal := t.Extra("id_token")
if idTokenVal == nil {
return nil, fmt.Errorf("missing id_token")
return "", "", fmt.Errorf("missing id_token")
}

idToken, ok := idTokenVal.(string)
if !ok {
return nil, fmt.Errorf("id_token is not a string")
return "", "", fmt.Errorf("id_token is not a string")
}

parsedIDToken, err := jwt.ParseSigned(idToken)
if err != nil {
return nil, err
return "", "", err
}

var hdr *jose.Header
Expand All @@ -59,87 +64,25 @@ func (tr *TokenReader) getAppleIDTokenClaims(t *oauth2.Token) (map[string]interf
}
}
if hdr == nil {
return nil, fmt.Errorf("no RS256 key header found")
return "", "", fmt.Errorf("no RS256 key header found")
}

appleRSA, err := tr.keyStore.get(hdr.KeyID)
appleRSA, err := tr.keyStore.Get(hdr.KeyID)
if err != nil {
return nil, fmt.Errorf("failed to get apple RSA key: %w", err)
return "", "", fmt.Errorf("failed to Get apple RSA key: %w", err)
}

claims := make(map[string]interface{})
claims := Claims{}
err = parsedIDToken.Claims(appleRSA, &claims)
if err != nil {
return nil, fmt.Errorf("failed to verify claims: %w", err)
}

return claims, nil
}

func extractUserFromClaims(claims map[string]interface{}, clientID string) (string, string, error) {
// We could validate iat here if we had a good minimum value to use.
// A nonce claim is also available but would need to be sent on code exchange.
if iss, ok := claims["iss"]; !ok || !strings.Contains(iss.(string), BaseURL) {
return "", "", fmt.Errorf("invalid or missing issuer")
}

if aud, ok := claims["aud"]; !ok || aud.(string) != clientID {
return "", "", fmt.Errorf("invalid or missing audience")
}

if exp, ok := claims["exp"]; !ok {
return "", "", fmt.Errorf("missing exp")
} else {
expErr := validateExp(exp)
if expErr != nil {
return "", "", expErr
}
}

id, ok := claims["sub"]

if !ok {
return "", "", fmt.Errorf("missing claim 'sub'")
}

idString, ok := id.(string)
if !ok {
return "", "", fmt.Errorf("claim 'sub' is not a string")
return "", "", fmt.Errorf("failed to verify claims: %w", err)
}

email, ok := claims["email"]
// TODO: figure out a way to cleanly pass the nonce in authorize request and make available for validation here
err = claims.Validate(tr.clientID)

if !ok {
return "", "", fmt.Errorf("missing claim 'email'")
}

emailString, ok := email.(string)
if !ok {
return "", "", fmt.Errorf("claim 'email' is not a string")
}

return idString, emailString, nil
}

func validateExp(exp interface{}) error {
switch v := exp.(type) {
case float64:
return validateExpInt64(int64(v))
case int:
return validateExpInt64(int64(v))
case int32:
return validateExpInt64(int64(v))
case int64:
return validateExpInt64(v)
default:
return fmt.Errorf("invalid exp")
}
}

func validateExpInt64(exp int64) error {
if exp < time.Now().Unix() {
return fmt.Errorf("token expired")
if err != nil {
return "", "", fmt.Errorf("failed to validate claims: %w", err)
}

return nil
return claims.Subject, claims.Email, nil
}
Loading

0 comments on commit 619b2e9

Please sign in to comment.