diff --git a/internal/headers/headers.go b/internal/headers/headers.go deleted file mode 100644 index 42e1df6..0000000 --- a/internal/headers/headers.go +++ /dev/null @@ -1,17 +0,0 @@ -package headers - -import ( - "net/http" - - cfg "tmpnotes/internal/config" -) - -func AddStandardHeaders(h http.Header) { - h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' https://cdnjs.cloudflare.com; style-src 'self' https://cdn.jsdelivr.net") - h.Set("X-Frame-Options", "DENY") - h.Set("X-Content-Type-Options", "nosniff") - h.Set("X-XSS-Protection", "1; mode=block") - if cfg.Config.EnableHsts { - h.Set("Strict-Transport-Security", "max-age=15552000") - } -} diff --git a/internal/health/health.go b/internal/health/health.go index 840d880..f7c1a12 100644 --- a/internal/health/health.go +++ b/internal/health/health.go @@ -5,12 +5,9 @@ import ( "net/http" log "github.com/sirupsen/logrus" - - h "tmpnotes/internal/headers" ) func HealthCheck(w http.ResponseWriter, r *http.Request) { - h.AddStandardHeaders(w.Header()) log.Info(r.RequestURI) if r.Method == "GET" { diff --git a/internal/notes/counts.go b/internal/notes/counts.go index 3c3ec01..32ef009 100644 --- a/internal/notes/counts.go +++ b/internal/notes/counts.go @@ -5,8 +5,6 @@ import ( "net/http" log "github.com/sirupsen/logrus" - - h "tmpnotes/internal/headers" ) type counts struct { @@ -16,7 +14,6 @@ type counts struct { } func GetCounts(w http.ResponseWriter, r *http.Request) { - h.AddStandardHeaders(w.Header()) log.Info(r.RequestURI) if r.Method != "GET" { diff --git a/internal/notes/notes.go b/internal/notes/notes.go index d2f4373..1d62e08 100644 --- a/internal/notes/notes.go +++ b/internal/notes/notes.go @@ -17,7 +17,6 @@ import ( cfg "tmpnotes/internal/config" "tmpnotes/internal/crypto" - h "tmpnotes/internal/headers" ) const maxLength = 1000 @@ -143,7 +142,6 @@ func GetNote(w http.ResponseWriter, r *http.Request) { id := full[:8] key := full[8:] log.Info(id) - h.AddStandardHeaders(w.Header()) if r.Method != "GET" { log.Errorf("%s Invalid request method: %s", id, r.Method) diff --git a/main.go b/main.go index 38d3368..1d4c1c9 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( log "github.com/sirupsen/logrus" cfg "tmpnotes/internal/config" - h "tmpnotes/internal/headers" "tmpnotes/internal/health" "tmpnotes/internal/notes" "tmpnotes/internal/version" @@ -22,24 +21,44 @@ func init() { notes.RedisInit() } +func addStandardHeaders(h http.Header) { + h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' https://cdnjs.cloudflare.com; style-src 'self' https://cdn.jsdelivr.net") + h.Set("X-Frame-Options", "DENY") + h.Set("X-Content-Type-Options", "nosniff") + h.Set("X-XSS-Protection", "1; mode=block") + if cfg.Config.EnableHsts { + h.Set("Strict-Transport-Security", "max-age=15552000") + } +} + +type tmpnotesHandler func(http.ResponseWriter, *http.Request) + +func (th tmpnotesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Anything we want to add for all requests can go in this handler + if r.Method == "GET" { + addStandardHeaders(w.Header()) + } + th(w, r) +} + func main() { port := fmt.Sprint(":", cfg.Config.Port) fs := http.FileServer(http.Dir("./static")) - http.Handle("/", addHeaders(fs)) - http.HandleFunc("/new", notes.AddNote) - http.HandleFunc("/id/", notes.GetNote) - http.HandleFunc("/counts", notes.GetCounts) - http.HandleFunc("/healthz", health.HealthCheck) - http.HandleFunc("/version", version.GetVersion) + http.Handle("/", serveStatic(fs)) + http.Handle("/new", tmpnotesHandler(notes.AddNote)) + http.Handle("/id/", tmpnotesHandler(notes.GetNote)) + http.Handle("/counts", tmpnotesHandler(notes.GetCounts)) + http.Handle("/healthz", tmpnotesHandler(health.HealthCheck)) + http.Handle("/version", tmpnotesHandler(version.GetVersion)) log.Info("Server listening at ", port) http.ListenAndServe(port, nil) } -func addHeaders(fs http.Handler) http.HandlerFunc { +func serveStatic(fs http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { log.Info(r.RequestURI) - h.AddStandardHeaders(w.Header()) + addStandardHeaders(w.Header()) fs.ServeHTTP(w, r) } }