Skip to content

Commit

Permalink
- streamline error messages & moved them to unified package.
Browse files Browse the repository at this point in the history
- harden siws verification.
- remove redundant checks.
- add /nonce? endpoint.
- create new nonce table (nonce usage without db is incompliant with siws).
  • Loading branch information
Bewinxed committed Jan 19, 2025
1 parent 0e96f8a commit 54fdc0a
Show file tree
Hide file tree
Showing 11 changed files with 635 additions and 286 deletions.
1 change: 1 addition & 0 deletions internal/api/provider/eip4361.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,4 @@ Expiration Time: %s`,
return "", fmt.Errorf("message generation not implemented for %s", chainCfg.NetworkName)
}
}

95 changes: 92 additions & 3 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import (
"github.com/xeipuuv/gojsonschema"

"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/crypto"
"github.com/supabase/auth/internal/hooks"
"github.com/supabase/auth/internal/metering"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/storage"
siws "github.com/supabase/auth/internal/utilities/solana"
)

// AccessTokenClaims is a struct thats used for JWT claims
Expand Down Expand Up @@ -311,6 +313,87 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
return sendJSON(w, http.StatusOK, token)
}

type StoredNonce struct {
ID uuid.UUID `db:"id"`
Nonce string `db:"nonce"`
Address string `db:"address"` // Optional: can be empty until signature verification
CreatedAt time.Time `db:"created_at"`
ExpiresAt time.Time `db:"expires_at"`
Used bool `db:"used"`
}

const NonceExpiration = 5 * time.Minute

// GetNonce handles nonce generation requests
func (a *API) GetNonce(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)

nonce := crypto.SecureToken()

storedNonce := &StoredNonce{
ID: uuid.Must(uuid.NewV4()),
Nonce: nonce,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(NonceExpiration),
Used: false,
}

err := db.Transaction(func(tx *storage.Connection) error {
// Store the nonce
_, err := tx.TX.Exec(`
INSERT INTO auth.nonces (id, nonce, created_at, expires_at, used)
VALUES ($1, $2, $3, $4, $5)
`, storedNonce.ID, storedNonce.Nonce, storedNonce.CreatedAt,
storedNonce.ExpiresAt, storedNonce.Used)
return err
})

if err != nil {
return internalServerError("Error storing nonce").WithInternalError(err)
}

return sendJSON(w, http.StatusOK, map[string]interface{}{
"nonce": nonce,
"expiresAt": storedNonce.ExpiresAt,
})
}

func (a *API) verifyAndConsumeNonce(ctx context.Context, nonce string, address string) error {
db := a.db.WithContext(ctx)

var storedNonce StoredNonce
err := db.Transaction(func(tx *storage.Connection) error {
// Find the nonce
err := tx.TX.QueryRow(`
SELECT id, nonce, address, created_at, expires_at, used
FROM auth.nonces
WHERE nonce = $1 AND used = false
`, nonce).Scan(&storedNonce.ID, &storedNonce.Nonce,
&storedNonce.Address, &storedNonce.CreatedAt,
&storedNonce.ExpiresAt, &storedNonce.Used)
if err != nil {
return err
}

// Check expiration
if time.Now().After(storedNonce.ExpiresAt) {
return fmt.Errorf("nonce expired")
}

// Mark as used
_, err = tx.TX.Exec(`
UPDATE auth.nonces
SET used = true, address = $1
WHERE id = $2
`, address, storedNonce.ID)
return err
})

return err
}


func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)

Expand All @@ -319,7 +402,12 @@ func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Requ
return err
}

web3Provider, err := provider.Web3Provider(ctx, a.config.External.Web3)
// Verify and consume nonce first
if err := a.verifyAndConsumeNonce(ctx, params.Nonce, params.Address); err != nil {
return siws.ErrorCodeInvalidNonce
}

web3Provider, err := provider.NewWeb3Provider(ctx, a.config.External.Web3)
if err != nil {
return err
}
Expand All @@ -333,7 +421,6 @@ func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Requ
}

userData, err := web3Provider.VerifySignedMessage(msg)

if err != nil {
return oauthError("invalid_grant", "Signature verification failed").WithInternalError(err)
}
Expand All @@ -348,7 +435,6 @@ func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Requ
return terr
}

// Log the auth attempt
if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{
"provider": "web3",
"chain": msg.Chain,
Expand Down Expand Up @@ -379,6 +465,7 @@ func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Requ
return sendJSON(w, http.StatusOK, token)
}


func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
config := a.config
if sessionId == nil {
Expand Down Expand Up @@ -576,3 +663,5 @@ func validateTokenClaims(outputClaims map[string]interface{}) error {

return nil
}


2 changes: 2 additions & 0 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -855,3 +855,5 @@ $$;`
})
}
}


1 change: 1 addition & 0 deletions internal/api/web3.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ type Web3GrantParams struct {
Signature string `json:"signature"`
Address string `json:"address"`
Chain string `json:"chain"`
Nonce string `json:"nonce"` // Added nonce field
}
162 changes: 99 additions & 63 deletions internal/crypto/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math"
"math/big"
"net/http"
"net/url"
"strconv"
"strings"

Expand All @@ -20,8 +21,6 @@ import (

"golang.org/x/crypto/hkdf"

"encoding/hex"

"github.com/btcsuite/btcutil/base58"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -169,68 +168,103 @@ func NewEncryptedString(id string, data []byte, keyID string, keyBase64URL strin
return &es, nil
}

// VerifySIWS fully verifies:
// - The domain in msg matches expected domain
// - The ed25519 signature matches the parsed SIWS message text
// - The base58-encoded public key is valid
// - The message is within the allowed time window (if requested)
func VerifySIWS(
rawMessage string,
signature []byte,
msg *siws.SIWSMessage,
params siws.SIWSVerificationParams,
rawMessage string,
signature []byte,
msg *siws.SIWSMessage,
params siws.SIWSVerificationParams,
) error {
// 1) Domain check
if params.ExpectedDomain == "" {
// Server misconfiguration
return siws.NewSIWSError("expected domain is not specified", http.StatusInternalServerError)
}

if !siws.IsValidDomain(msg.Domain) {
// Malformed request
return siws.NewSIWSError("invalid domain", http.StatusBadRequest)
}

if msg.Domain != params.ExpectedDomain {
// Per RFC 7235, 403 is more appropriate than 401 here since we're not requesting new credentials
return siws.NewSIWSError("domain mismatch", http.StatusForbidden)
}

// 2) Base58 decode -> ed25519.PublicKey
pubKey := base58.Decode(msg.Address)
if !siws.IsBase58PubKey(msg.Address) {
// Malformed credentials
return siws.NewSIWSError("invalid base58 public key or wrong size (must be 32 bytes)", http.StatusBadRequest)
}

// 3) Verify signature
if !ed25519.Verify(pubKey, []byte(rawMessage), signature) {
// Per RFC 7235, 401 indicates the credentials were rejected and new ones should be provided
return siws.NewSIWSError("signature verification failed", http.StatusUnauthorized)
}

// 4) Time check if requested
if params.CheckTime && params.TimeDuration > 0 {
if msg.IssuedAt.IsZero() {
// Malformed request
return siws.NewSIWSError("issuedAt not set, but time check requested", http.StatusBadRequest)
}

now := time.Now().UTC()
expiry := msg.IssuedAt.Add(params.TimeDuration)

if now.Before(msg.IssuedAt) {
// Invalid timestamp in request
return siws.NewSIWSError("message is issued in the future", http.StatusBadRequest)
}

if now.After(expiry) {
// Per RFC 7235, expired credentials should prompt for new ones
return siws.NewSIWSError("message is expired", http.StatusUnauthorized)
}
}

return nil
// 1) Basic input validation
if rawMessage == "" {
return siws.ErrEmptyRawMessage
}
if len(signature) == 0 {
return siws.ErrEmptySignature
}
if msg == nil {
return siws.ErrNilMessage
}

// 2) Domain validation
if params.ExpectedDomain == "" {
return siws.ErrMissingDomain
}
if !siws.IsValidDomain(msg.Domain) {
return siws.ErrInvalidDomainFormat
}
if msg.Domain != params.ExpectedDomain {
return siws.ErrDomainMismatch
}

// 3) Address/Public Key validation (combined checks)
pubKey := base58.Decode(msg.Address)
if !siws.IsBase58PubKey(pubKey) {
return siws.ErrInvalidPubKeySize
}

// 4) Version validation
if msg.Version != "1" {
return siws.ErrInvalidVersion
}

// 5) Chain ID validation (using helper)
if msg.ChainID != "" {
if !siws.IsValidSolanaNetwork(msg.ChainID) {

return siws.ErrInvalidChainID
}
}

// 6) Nonce validation (consolidated)
if msg.Nonce != "" {
if len(msg.Nonce) < 8 {
return siws.ErrNonceTooShort
}
}

// 7) URI and Resources validation
if msg.URI != "" {
if _, err := url.Parse(msg.URI); err != nil {
return siws.ErrInvalidURI
}
}

for _, resource := range msg.Resources {
if _, err := url.Parse(resource); err != nil {
return siws.ErrInvalidResourceURI
}
}

// 8) Signature verification
if !ed25519.Verify(pubKey, []byte(rawMessage), signature) {
return siws.ErrSignatureVerification
}

// 9) Time validations (consolidated)
now := time.Now().UTC()

if !msg.IssuedAt.IsZero() {
if now.Before(msg.IssuedAt) {
return siws.ErrFutureMessage
}

if params.CheckTime && params.TimeDuration > 0 {
expiry := msg.IssuedAt.Add(params.TimeDuration)
if now.After(expiry) {
return siws.ErrMessageExpired
}
}
}

if !msg.NotBefore.IsZero() && now.Before(msg.NotBefore) {
return siws.ErrNotYetValid
}

if !msg.ExpirationTime.IsZero() && now.After(msg.ExpirationTime) {
return siws.ErrMessageExpired
}

return nil
}

func VerifyEthereumSignature(message string, signature string, address string) error {
Expand Down Expand Up @@ -280,3 +314,5 @@ func removeHexPrefix(signature string) string {
}
return signature
}


Loading

0 comments on commit 54fdc0a

Please sign in to comment.