Skip to content

Commit

Permalink
refactor: update AuthMiddleware to handle Bearer token
Browse files Browse the repository at this point in the history
This commit updates the AuthMiddleware function in the middlewares package. It modifies the logic to handle Bearer tokens instead of a generic token. The function now checks if the Authorization header starts with "Bearer " and extracts the token accordingly. It also checks if the extracted token matches the expected token. If not, it returns a 401 unauthorized response. This change improves the security and authentication process in the application.
  • Loading branch information
WangYihang committed Oct 24, 2024
1 parent 6add212 commit 2fa7b08
Show file tree
Hide file tree
Showing 13 changed files with 141 additions and 99 deletions.
28 changes: 11 additions & 17 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@ package main

import (
"context"
"fmt"
"log/slog"
"os"

"github.com/WangYihang/Platypus/pkg/config"
"github.com/WangYihang/Platypus/pkg/dependencies"
"github.com/WangYihang/Platypus/pkg/listeners"
"github.com/WangYihang/Platypus/pkg/models"
"github.com/WangYihang/Platypus/pkg/options"
"github.com/google/uuid"
"go.uber.org/fx"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -42,20 +39,17 @@ func main() {
lc.Append(fx.Hook{
OnStart: func(context.Context) error {
logger.Info("config loaded", zap.Any("config", cfg))
for _, l := range cfg.Listeners {
switch l.Type {
case listeners.ListenerTypePlainShell:
go listeners.NewPlainShellListener(l.BindHost, l.BindPort).Start(logger)
case listeners.ListenerTypeEncryptedShell:
go listeners.NewEncryptedShellListener(l.BindHost, l.BindPort).Start(logger)
case listeners.ListenerTypeRESTful:
token := uuid.New().String()
logger.Info("using generated uuid as token", zap.String("token", token))
go listeners.NewRESTfulListener(l.BindHost, l.BindPort, token).Start(logger)
default:
logger.Error("unsupported listener type", zap.String("type", string(l.Type)))
return fmt.Errorf("unsupported listener type: %s", l.Type)
}
logger.Info("starting plain listeners")
for _, listener := range cfg.PlainListeners {
go listener.Start(logger)
}
logger.Info("starting encrypted listeners")
for _, listener := range cfg.EncryptedListeners {
go listener.Start(logger)
}
logger.Info("starting RESTful listeners")
for _, listener := range cfg.RestfulListeners {
go listener.Start(logger)
}
return nil
},
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/google/uuid v1.3.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jessevdk/go-flags v1.6.1
github.com/matishsiao/goInfo v0.0.0-20240924010139-10388a85396f
github.com/sevlyar/go-daemon v0.1.5
github.com/shirou/gopsutil v3.21.11+incompatible
go.uber.org/fx v1.22.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/matishsiao/goInfo v0.0.0-20240924010139-10388a85396f h1:XDrsC/9hdgiU9ecceSmYsS2E3fBtFiYc34dAMFgegnM=
github.com/matishsiao/goInfo v0.0.0-20240924010139-10388a85396f/go.mod h1:aEt7p9Rvh67BYApmZwNDPpgircTO2kgdmDUoF/1QmwA=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
Expand Down
6 changes: 4 additions & 2 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import (

// Config represents the configuration.
type Config struct {
Listeners []listeners.Listener `json:"listeners" yaml:"listeners" toml:"listeners"`
PlainListeners []listeners.PlainListener `json:"plain" yaml:"plain" toml:"plain"`
EncryptedListeners []listeners.EncryptedListener `json:"encrypted" yaml:"encrypted" toml:"encrypted"`
RestfulListeners []listeners.RESTfulListener `json:"restful" yaml:"restful" toml:"restful"`
}

// LoadConfig loads the configuration from the given path.
Expand All @@ -24,7 +26,7 @@ func LoadConfig(path string) (*Config, error) {
return loadJSONConfig(path)
case ".yaml", ".yml":
return loadYAMLConfig(path)
case ".toml":
case ".toml", ".tml":
return loadTOMLConfig(path)
default:
return nil, fmt.Errorf("unsupported config file format: %s", ext)
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// NewStatusController returns a new status controller
// Status controller returns the status of the server, including hostname, CPU usage, disk usage, and memory usage
func NewStatusController() func(c *gin.Context) {
cache := expirable.NewLRU[string, models.Status](1, nil, time.Second*60)
cache := expirable.NewLRU[string, models.Status](1, nil, time.Minute)
return func(c *gin.Context) {
r, ok := cache.Get("status")
if ok {
Expand Down
23 changes: 23 additions & 0 deletions pkg/listeners/encrypted.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package listeners

import "go.uber.org/zap"

// EncryptedListener represents an encrypted shell listener.
type EncryptedListener struct {
commonListener
}

// NewEncryptedListener creates a new encrypted shell listener.
func NewEncryptedListener(host string, port uint16) *EncryptedListener {
return &EncryptedListener{
commonListener: commonListener{
BindHost: host,
BindPort: port,
},
}
}

// Start starts the encrypted shell listener.
func (l *EncryptedListener) Start(logger *zap.Logger) {
logger.Info("starting encrypted listener", zap.String("host", l.BindHost), zap.Uint16("port", l.BindPort))
}
23 changes: 0 additions & 23 deletions pkg/listeners/encrypted_shell_listener.go

This file was deleted.

20 changes: 3 additions & 17 deletions pkg/listeners/listener.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,6 @@
package listeners

// ListenerType represents the type of listener.
type ListenerType string

const (
// ListenerTypePlainShell represents a plain shell listener.
ListenerTypePlainShell ListenerType = "plain_shell"
// ListenerTypeEncryptedShell represents an encrypted shell listener.
ListenerTypeEncryptedShell ListenerType = "encrypted_shell"
// ListenerTypeRESTful represents a RESTful listener.
ListenerTypeRESTful ListenerType = "restful"
)

// Listener is a struct that represents a listener
type Listener struct {
BindHost string `json:"bind_host" yaml:"bind_host" toml:"bind_host"`
BindPort uint16 `json:"bind_port" yaml:"bind_port" toml:"bind_port"`
Type ListenerType `json:"type" yaml:"type" toml:"type"`
type commonListener struct {
BindHost string `json:"bind_host" yaml:"bind_host" toml:"bind_host"`
BindPort uint16 `json:"bind_port" yaml:"bind_port" toml:"bind_port"`
}
23 changes: 23 additions & 0 deletions pkg/listeners/plain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package listeners

import "go.uber.org/zap"

// PlainListener represents a plain shell listener.
type PlainListener struct {
commonListener
}

// NewPlainListener creates a new plain shell listener.
func NewPlainListener(host string, port uint16) *PlainListener {
return &PlainListener{
commonListener: commonListener{
BindHost: host,
BindPort: port,
},
}
}

// Start starts the plain shell listener.
func (l *PlainListener) Start(logger *zap.Logger) {
logger.Info("starting plain listener", zap.String("host", l.BindHost), zap.Uint16("port", l.BindPort))
}
23 changes: 0 additions & 23 deletions pkg/listeners/plain_shell_listener.go

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import (

// RESTfulListener represents a RESTful listener.
type RESTfulListener struct {
Listener
commonListener
Token string `json:"token" yaml:"token" toml:"token"`
}

// NewRESTfulListener creates a new RESTful listener.
func NewRESTfulListener(host string, port uint16, token string) *RESTfulListener {
return &RESTfulListener{
Listener: Listener{
commonListener: commonListener{
BindHost: host,
BindPort: port,
},
Expand All @@ -30,6 +30,7 @@ func (l *RESTfulListener) Start(logger *zap.Logger) {
logger.Info("starting RESTful listener", zap.String("host", l.BindHost), zap.Uint16("port", l.BindPort))
gin.SetMode(gin.ReleaseMode)
r := gin.Default()
logger.Info("configuring routes with token", zap.String("token", l.Token))
routes.ConfigureRoutes(r, logger, l.Token)
err := r.Run(fmt.Sprintf("%s:%d", l.BindHost, l.BindPort))
if err != nil {
Expand Down
27 changes: 24 additions & 3 deletions pkg/middlewares/auth.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,37 @@
package middlewares

import (
"strings"

"github.com/gin-gonic/gin"
)

// AuthMiddleware is a middleware that checks if the request has the correct token.
func AuthMiddleware(token string) gin.HandlerFunc {
// AuthMiddleware is a middleware that checks if the request has the correct Bearer token.
func AuthMiddleware(expectedToken string) gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Authorization") != token {
// Get the Authorization header
authorizationHeader := c.GetHeader("Authorization")
if authorizationHeader == "" {
c.AbortWithStatusJSON(401, gin.H{"error": "missing authorization header"})
return
}

// Check if the header starts with "Bearer " and extract the token
if !strings.HasPrefix(authorizationHeader, "Bearer ") {
c.AbortWithStatusJSON(401, gin.H{"error": "invalid authorization scheme"})
return
}

// Extract the token
token := strings.TrimPrefix(authorizationHeader, "Bearer ")

// Check if the token matches the expected token
if token != expectedToken {
c.AbortWithStatusJSON(401, gin.H{"error": "unauthorized"})
return
}

// Token is valid; proceed to the next handler
c.Next()
}
}
Loading

0 comments on commit 2fa7b08

Please sign in to comment.