Skip to content

Commit

Permalink
fix(http): add basic auth middleware for http server (#1372)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ja7ad authored Jun 26, 2024
1 parent d15c3b9 commit 8955d0c
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 132 deletions.
88 changes: 0 additions & 88 deletions www/http/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,6 @@ import (

func (s *Server) BlockchainHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

res, err := s.blockchain.GetBlockchainInfo(ctx,
&pactus.GetBlockchainInfoRequest{})
Expand All @@ -53,17 +42,6 @@ func (s *Server) BlockchainHandler(w http.ResponseWriter, r *http.Request) {

func (s *Server) GetBlockByHeightHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

vars := mux.Vars(r)
height, err := strconv.ParseInt(vars["height"], 10, 32)
Expand All @@ -77,17 +55,6 @@ func (s *Server) GetBlockByHeightHandler(w http.ResponseWriter, r *http.Request)

func (s *Server) GetBlockByHashHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

vars := mux.Vars(r)
blockHash, err := hash.FromString(vars["hash"])
Expand Down Expand Up @@ -155,17 +122,6 @@ func (s *Server) blockByHeight(ctx context.Context, w http.ResponseWriter, block
// GetAccountHandler returns a handler to get account by address.
func (s *Server) GetAccountHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

vars := mux.Vars(r)
res, err := s.blockchain.GetAccount(ctx,
Expand All @@ -189,17 +145,6 @@ func (s *Server) GetAccountHandler(w http.ResponseWriter, r *http.Request) {
// GetValidatorHandler returns a handler to get validator by address.
func (s *Server) GetValidatorHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

vars := mux.Vars(r)
res, err := s.blockchain.GetValidator(ctx,
Expand All @@ -217,17 +162,6 @@ func (s *Server) GetValidatorHandler(w http.ResponseWriter, r *http.Request) {
// GetValidatorByNumberHandler returns a handler to get validator by number.
func (s *Server) GetValidatorByNumberHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

vars := mux.Vars(r)

Expand All @@ -254,17 +188,6 @@ func (s *Server) GetValidatorByNumberHandler(w http.ResponseWriter, r *http.Requ

func (s *Server) GetTxPoolContentHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

res, err := s.blockchain.GetTxPoolContent(ctx, &pactus.GetTxPoolContentRequest{})
if err != nil {
Expand Down Expand Up @@ -298,17 +221,6 @@ func (*Server) writeValidatorTable(val *pactus.ValidatorInfo) *tableMaker {

func (s *Server) ConsensusHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

res, err := s.blockchain.GetConsensusInfo(ctx,
&pactus.GetConsensusInfoRequest{})
Expand Down
28 changes: 28 additions & 0 deletions www/http/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package http

import (
"net/http"

"github.com/pactus-project/pactus/www/grpc/basicauth"
"google.golang.org/grpc/metadata"
)

func basicAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ba := basicauth.New(user, password)
tokens, _ := ba.GetRequestMetadata(r.Context())
md := metadata.New(tokens)

r = r.WithContext(metadata.NewOutgoingContext(r.Context(), md))

next.ServeHTTP(w, r)
})
}
72 changes: 72 additions & 0 deletions www/http/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package http

import (
"encoding/base64"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)

func TestBasicAuthMiddleware(t *testing.T) {
handler := basicAuth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("authorized"))
assert.NoError(t, err)
}))

t.Run("NoAuth", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.Equal(t, `Basic realm="restricted", charset="UTF-8"`, rr.Header().Get("WWW-Authenticate"))
})

t.Run("WithAuth", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.SetBasicAuth("username", "password")
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "authorized", rr.Body.String())
})

t.Run("CheckMetadata", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.SetBasicAuth("username", "password")
rr := httptest.NewRecorder()

checkMetadataHandler := basicAuth(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
md, ok := metadata.FromOutgoingContext(r.Context())
if !ok {
t.Errorf("No metadata in context")
}

auth := md["authorization"][0]

const prefix = "Basic "
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
assert.NoError(t, err)
cs := string(c)
username, password, ok := strings.Cut(cs, ":")
assert.True(t, ok)

assert.Equal(t, "username", username)
assert.Equal(t, "password", password)

w.WriteHeader(http.StatusOK)
}))

checkMetadataHandler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
})
}
22 changes: 0 additions & 22 deletions www/http/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,6 @@ import (

func (s *Server) NetworkHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

onlyConnected := false

Expand Down Expand Up @@ -115,17 +104,6 @@ func (s *Server) NetworkHandler(w http.ResponseWriter, r *http.Request) {

func (s *Server) NodeHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

res, err := s.network.GetNodeInfo(ctx,
&pactus.GetNodeInfoRequest{})
Expand Down
17 changes: 6 additions & 11 deletions www/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ import (
ret "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/pactus-project/pactus/types/amount"
"github.com/pactus-project/pactus/util/logger"
"github.com/pactus-project/pactus/www/grpc/basicauth"
pactus "github.com/pactus-project/pactus/www/grpc/gen/go"
"github.com/prometheus/client_golang/prometheus/promhttp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)

type Server struct {
Expand Down Expand Up @@ -87,7 +85,12 @@ func (s *Server) StartServer(grpcServer string) error {
s.router.HandleFunc("/validator/address/{address}", s.GetValidatorHandler)
s.router.HandleFunc("/validator/number/{number}", s.GetValidatorByNumberHandler)
s.router.HandleFunc("/metrics/prometheus", promhttp.Handler().ServeHTTP)
http.Handle("/", handlers.RecoveryHandler()(s.router))

if s.enableAuth {
http.Handle("/", handlers.RecoveryHandler()(basicAuth(s.router)))
} else {
http.Handle("/", handlers.RecoveryHandler()(s.router))
}

listener, err := net.Listen("tcp", s.config.Listen)
if err != nil {
Expand Down Expand Up @@ -190,14 +193,6 @@ func (*Server) writeHTML(w http.ResponseWriter, html string) int {
return n
}

func (*Server) basicAuth(ctx context.Context, username, password string) context.Context {
ba := basicauth.New(username, password)
tokens, _ := ba.GetRequestMetadata(ctx)
md := metadata.New(tokens)

return metadata.NewOutgoingContext(ctx, md)
}

type tableMaker struct {
w *bytes.Buffer
}
Expand Down
11 changes: 0 additions & 11 deletions www/http/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,6 @@ import (

func (s *Server) GetTransactionHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if s.enableAuth {
user, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)

return
}

ctx = s.basicAuth(ctx, user, password)
}

vars := mux.Vars(r)
id, err := hex.DecodeString(vars["id"])
Expand Down

0 comments on commit 8955d0c

Please sign in to comment.