Skip to content

Commit

Permalink
create handler to add standard headers to all requests instead of inj…
Browse files Browse the repository at this point in the history
…ecting them in individual routes (#43)
  • Loading branch information
aro5000 authored May 3, 2022
1 parent aeacabd commit a6718f7
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 34 deletions.
17 changes: 0 additions & 17 deletions internal/headers/headers.go

This file was deleted.

3 changes: 0 additions & 3 deletions internal/health/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@ import (
"net/http"

log "github.com/sirupsen/logrus"

h "tmpnotes/internal/headers"
)

func HealthCheck(w http.ResponseWriter, r *http.Request) {
h.AddStandardHeaders(w.Header())
log.Info(r.RequestURI)

if r.Method == "GET" {
Expand Down
3 changes: 0 additions & 3 deletions internal/notes/counts.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"net/http"

log "github.com/sirupsen/logrus"

h "tmpnotes/internal/headers"
)

type counts struct {
Expand All @@ -16,7 +14,6 @@ type counts struct {
}

func GetCounts(w http.ResponseWriter, r *http.Request) {
h.AddStandardHeaders(w.Header())
log.Info(r.RequestURI)

if r.Method != "GET" {
Expand Down
2 changes: 0 additions & 2 deletions internal/notes/notes.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (

cfg "tmpnotes/internal/config"
"tmpnotes/internal/crypto"
h "tmpnotes/internal/headers"
)

const maxLength = 1000
Expand Down Expand Up @@ -143,7 +142,6 @@ func GetNote(w http.ResponseWriter, r *http.Request) {
id := full[:8]
key := full[8:]
log.Info(id)
h.AddStandardHeaders(w.Header())

if r.Method != "GET" {
log.Errorf("%s Invalid request method: %s", id, r.Method)
Expand Down
37 changes: 28 additions & 9 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
log "github.com/sirupsen/logrus"

cfg "tmpnotes/internal/config"
h "tmpnotes/internal/headers"
"tmpnotes/internal/health"
"tmpnotes/internal/notes"
"tmpnotes/internal/version"
Expand All @@ -22,24 +21,44 @@ func init() {
notes.RedisInit()
}

func addStandardHeaders(h http.Header) {
h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' https://cdnjs.cloudflare.com; style-src 'self' https://cdn.jsdelivr.net")
h.Set("X-Frame-Options", "DENY")
h.Set("X-Content-Type-Options", "nosniff")
h.Set("X-XSS-Protection", "1; mode=block")
if cfg.Config.EnableHsts {
h.Set("Strict-Transport-Security", "max-age=15552000")
}
}

type tmpnotesHandler func(http.ResponseWriter, *http.Request)

func (th tmpnotesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Anything we want to add for all requests can go in this handler
if r.Method == "GET" {
addStandardHeaders(w.Header())
}
th(w, r)
}

func main() {
port := fmt.Sprint(":", cfg.Config.Port)

fs := http.FileServer(http.Dir("./static"))
http.Handle("/", addHeaders(fs))
http.HandleFunc("/new", notes.AddNote)
http.HandleFunc("/id/", notes.GetNote)
http.HandleFunc("/counts", notes.GetCounts)
http.HandleFunc("/healthz", health.HealthCheck)
http.HandleFunc("/version", version.GetVersion)
http.Handle("/", serveStatic(fs))
http.Handle("/new", tmpnotesHandler(notes.AddNote))
http.Handle("/id/", tmpnotesHandler(notes.GetNote))
http.Handle("/counts", tmpnotesHandler(notes.GetCounts))
http.Handle("/healthz", tmpnotesHandler(health.HealthCheck))
http.Handle("/version", tmpnotesHandler(version.GetVersion))
log.Info("Server listening at ", port)
http.ListenAndServe(port, nil)
}

func addHeaders(fs http.Handler) http.HandlerFunc {
func serveStatic(fs http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
log.Info(r.RequestURI)
h.AddStandardHeaders(w.Header())
addStandardHeaders(w.Header())
fs.ServeHTTP(w, r)
}
}

0 comments on commit a6718f7

Please sign in to comment.