From 14fdfe00d39e8cc45f3388a77558d18ed6a7ae71 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 20 Dec 2023 16:34:23 +0200 Subject: [PATCH] Use safer secret checking --- internal/provider/provider.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 11d1fe1..a4fd414 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -35,10 +35,10 @@ func GetProvider(code string) (*provider, bool) { return p, exists } -func calculateSecret(globalSecret []byte, code string) string { +func calculateSecret(globalSecret []byte, code string) []byte { h := hmac.New(sha256.New, globalSecret) h.Write([]byte(code)) - return base64.RawStdEncoding.EncodeToString(h.Sum(nil)) + return h.Sum(nil) } func RegisterProvider(data registerCommandData, provider *provider) (*registerCommandData, error) { @@ -51,9 +51,13 @@ func RegisterProvider(data registerCommandData, provider *provider) (*registerCo if err != nil { return nil, err } - data.Secret = calculateSecret(provider.globalSecret, data.Code) + data.Secret = base64.RawStdEncoding.EncodeToString(calculateSecret(provider.globalSecret, data.Code)) } else { - if calculateSecret(provider.globalSecret, data.Code) != data.Secret { + if len(data.Code) != 19 || len(data.Secret) > 64 { + return nil, fmt.Errorf("invalid secret") + } + decodedSecret, err := base64.RawStdEncoding.DecodeString(data.Secret) + if err != nil || !hmac.Equal(calculateSecret(provider.globalSecret, data.Code), decodedSecret) { return nil, fmt.Errorf("invalid secret") } if existing, exists := codeToProvider[data.Code]; exists {