diff --git a/cmd/registration_relay/main.go b/cmd/registration_relay/main.go index a817e77..2e9d742 100644 --- a/cmd/registration_relay/main.go +++ b/cmd/registration_relay/main.go @@ -34,6 +34,12 @@ func main() { "Metrics listen address", ) + validateAuthURL := flag.String( + "validateAuthURL", + flagenv.StringEnvWithDefault("REGISTRATION_RELAY_VALIDATE_AUTH_URL", ""), + "Validate auth header URL", + ) + flag.Parse() if *prettyLogs { @@ -48,6 +54,7 @@ func main() { cfg := config.Config{} cfg.API.Listen = *listenAddr + cfg.API.ValidateAuthURL = *validateAuthURL log.Info().Str("commit", Commit).Str("build_time", BuildTime).Msg("registration-relay starting") diff --git a/internal/api/api.go b/internal/api/api.go index b3364ee..a556e2c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -41,7 +41,15 @@ func NewAPI(cfg config.Config) *api { r.Get("/health", health.Health) r.Get("/api/v1/provider", api.providerWebsocket) - r.Post("/api/v1/bridge/{command}", api.bridgeExecuteCommand) + + commandHandler := api.bridgeExecuteCommand + if cfg.API.ValidateAuthURL != "" { + commandHandler = api.requireAuthHandler( + cfg.API.ValidateAuthURL, + commandHandler, + ) + } + r.Post("/api/v1/bridge/{command}", commandHandler) api.server = &http.Server{Addr: cfg.API.Listen, Handler: r} diff --git a/internal/api/auth.go b/internal/api/auth.go new file mode 100644 index 0000000..ef00176 --- /dev/null +++ b/internal/api/auth.go @@ -0,0 +1,76 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" +) + +var httpClient = &http.Client{} + +type authResp struct { + Identifier string `json:"identifier"` +} + +func (a *api) requireAuthHandler( + validateURL string, + next http.HandlerFunc, +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authToken := r.Header.Get("Authorization") + + if authToken == "" { + a.log.Warn().Msg("Request missing auth header") + w.WriteHeader(http.StatusUnauthorized) + return + } + + req, err := http.NewRequest(http.MethodGet, validateURL, nil) + if err != nil { + a.log.Err(err).Msg("Failed to create request to auth validation URL") + w.WriteHeader(http.StatusInternalServerError) + return + } + + req.Header.Add("Authorization", authToken) + + resp, err := httpClient.Do(req) + if err != nil { + a.log.Err(err).Msg("Failed to make request to auth validation URL") + w.WriteHeader(http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + if resp.StatusCode >= 500 { + a.log.Error(). + Int("status_code", resp.StatusCode). + Msg("Unexpected status from auth URL") + w.WriteHeader(http.StatusInternalServerError) + return + } + + if resp.StatusCode != 200 { + a.log.Warn(). + Int("status_code", resp.StatusCode). + Msg("Unauthorized status from auth URL") + w.WriteHeader(http.StatusUnauthorized) + return + } + + var response authResp + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + a.log.Err(err).Msg("Failed to decode auth response") + w.WriteHeader(http.StatusInternalServerError) + return + } + + hlog.FromRequest(r).UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("identifier", response.Identifier) + }) + + next(w, r) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index fe82c24..46cfa56 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config type Config struct { Version string API struct { - Listen string + Listen string + ValidateAuthURL string } }