Skip to content

Commit

Permalink
add PRISME_PROXY_REQUEST_ID_HEADER configuration option
Browse files Browse the repository at this point in the history
  • Loading branch information
negrel committed Oct 23, 2024
1 parent 5dc0669 commit 194bcfd
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 46 deletions.
17 changes: 10 additions & 7 deletions pkg/config/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type Server struct {
TrustProxy bool
// X-Forwarded-For proxy header.
ProxyHeader string
// X-Request-Id proxy header.
ProxyRequestIdHeader string
// host:port address of admin http server.
AdminHostPort string
// Timeout for /api/v1/events/* handlers.
Expand All @@ -23,12 +25,13 @@ type Server struct {
// ServerFromEnv loads server related options from environment variables.
func ServerFromEnv() Server {
return Server{
AccessLog: GetEnvOrDefault("PRISME_ACCESS_LOG", "/dev/stdout"),
Debug: GetEnvOrDefault("PRISME_DEBUG", "false") != "false",
Port: uint16(ParseUintEnvOrDefault("PRISME_PORT", 80, 16)),
TrustProxy: GetEnvOrDefault("PRISME_TRUST_PROXY", "false") != "false",
ProxyHeader: GetEnvOrDefault("PRISME_PROXY_HEADER", "X-Forwarded-For"),
AdminHostPort: GetEnvOrDefault("PRISME_ADMIN_HOSTPORT", "127.0.0.1:9090"),
ApiEventsTimeout: ParseDurationEnvOrDefault("PRISME_API_EVENTS_TIMEOUT", 3*time.Second),
AccessLog: GetEnvOrDefault("PRISME_ACCESS_LOG", "/dev/stdout"),
Debug: GetEnvOrDefault("PRISME_DEBUG", "false") != "false",
Port: uint16(ParseUintEnvOrDefault("PRISME_PORT", 80, 16)),
TrustProxy: GetEnvOrDefault("PRISME_TRUST_PROXY", "false") != "false",
ProxyHeader: GetEnvOrDefault("PRISME_PROXY_HEADER", "X-Forwarded-For"),
ProxyRequestIdHeader: GetEnvOrDefault("PRISME_PROXY_REQUEST_ID_HEADER", "X-Request-ID"),
AdminHostPort: GetEnvOrDefault("PRISME_ADMIN_HOSTPORT", "127.0.0.1:9090"),
ApiEventsTimeout: ParseDurationEnvOrDefault("PRISME_API_EVENTS_TIMEOUT", 3*time.Second),
}
}
2 changes: 1 addition & 1 deletion pkg/middlewares/request_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func ProvideRequestId(cfg config.Server) RequestId {
var requestId string

if cfg.TrustProxy {
requestId = utils.UnsafeString(c.Request().Header.Peek("X-Request-Id"))
requestId = utils.UnsafeString(c.Request().Header.Peek(cfg.ProxyRequestIdHeader))
}

if requestId == "" {
Expand Down
138 changes: 100 additions & 38 deletions pkg/middlewares/request_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,34 +37,65 @@ func TestRequestIdMiddleware(t *testing.T) {
})

t.Run("WithRequestIdHeader", func(t *testing.T) {
middlewareCalled := false
reqRequestId := uuid.New()
t.Run("Default", func(t *testing.T) {
middlewareCalled := false
reqRequestId := uuid.New()

app := fiber.New()
app.Use(fiber.Handler(ProvideRequestId(cfg)))
app.Use(func(c *fiber.Ctx) error {
middlewareCalled = true
app := fiber.New()
app.Use(fiber.Handler(ProvideRequestId(cfg)))
app.Use(func(c *fiber.Ctx) error {
middlewareCalled = true

requestId := c.Locals(RequestIdKey{}).(string)
require.Regexp(t, "[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}", requestId)
require.NotEqual(t, reqRequestId.String(), requestId)
requestId := c.Locals(RequestIdKey{}).(string)
require.Regexp(t, "[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}", requestId)
require.NotEqual(t, reqRequestId.String(), requestId)

return nil
return nil
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
// Add request id.
req.Header.Add(fiber.HeaderXRequestID, reqRequestId.String())

_, err := app.Test(req)
require.NoError(t, err)
require.True(t, middlewareCalled)
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
// Add request id.
req.Header.Add(fiber.HeaderXRequestID, reqRequestId.String())
t.Run("Custom", func(t *testing.T) {
cfg := config.Server{
TrustProxy: false,
ProxyRequestIdHeader: "X-Custom-Request-Id",
}
middlewareCalled := false

_, err := app.Test(req)
require.NoError(t, err)
require.True(t, middlewareCalled)
app := fiber.New()
app.Use(fiber.Handler(ProvideRequestId(cfg)))
app.Use(func(c *fiber.Ctx) error {
middlewareCalled = true

requestId := c.Locals(RequestIdKey{}).(string)
require.Regexp(t, "[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}", requestId)

return nil
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
// Add request ids.
req.Header.Add(fiber.HeaderXRequestID, "bar")
req.Header.Add("X-Custom-Request-Id", "foo")

_, err := app.Test(req)
require.NoError(t, err)
require.True(t, middlewareCalled)
})
})
})

t.Run("TrustProxy", func(t *testing.T) {
cfg := config.Server{
TrustProxy: true,
TrustProxy: true,
ProxyRequestIdHeader: fiber.HeaderXRequestID,
}

t.Run("WithoutRequestIdHeader", func(t *testing.T) {
Expand All @@ -88,29 +119,60 @@ func TestRequestIdMiddleware(t *testing.T) {
})

t.Run("WithRequestIdHeader", func(t *testing.T) {
middlewareCalled := false
expectedRequestId := uuid.New()

app := fiber.New()
app.Use(fiber.Handler(ProvideRequestId(cfg)))
app.Use(func(c *fiber.Ctx) error {
middlewareCalled = true

require.Equal(t, expectedRequestId.String(), c.Locals(RequestIdKey{}))
return nil
t.Run("Default", func(t *testing.T) {
middlewareCalled := false
expectedRequestId := uuid.New()

app := fiber.New()
app.Use(fiber.Handler(ProvideRequestId(cfg)))
app.Use(func(c *fiber.Ctx) error {
middlewareCalled = true

require.Equal(t, expectedRequestId.String(), c.Locals(RequestIdKey{}))
return nil
})
app.Get("/", func(c *fiber.Ctx) error {
return nil
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
// Add request id.
req.Header.Add(fiber.HeaderXRequestID, expectedRequestId.String())

_, err := app.Test(req)
require.NoError(t, err)
require.True(t, middlewareCalled)
})
app.Get("/", func(c *fiber.Ctx) error {
t.Log("HELLO")
return nil
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
// Add request id.
req.Header.Add(fiber.HeaderXRequestID, expectedRequestId.String())

_, err := app.Test(req)
require.NoError(t, err)
require.True(t, middlewareCalled)
t.Run("Custom", func(t *testing.T) {
cfg := config.Server{
TrustProxy: true,
ProxyRequestIdHeader: "X-Custom-Request-Id",
}

middlewareCalled := false

app := fiber.New()
app.Use(fiber.Handler(ProvideRequestId(cfg)))
app.Use(func(c *fiber.Ctx) error {
middlewareCalled = true

require.Equal(t, "foo", c.Locals(RequestIdKey{}))
return nil
})
app.Get("/", func(c *fiber.Ctx) error {
return nil
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
// Add request id.
req.Header.Add(fiber.HeaderXRequestID, "bar")
req.Header.Add("X-Custom-Request-Id", "foo")

_, err := app.Test(req)
require.NoError(t, err)
require.True(t, middlewareCalled)
})
})
})
}

0 comments on commit 194bcfd

Please sign in to comment.