Skip to content

Commit

Permalink
add resterror package
Browse files Browse the repository at this point in the history
  • Loading branch information
negrel committed Jan 10, 2024
1 parent 91cfd9f commit e864c4f
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 0 deletions.
30 changes: 30 additions & 0 deletions cmd/server/fiber.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package main

import (
"strings"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/proxy"
"github.com/prismelabs/prismeanalytics/internal/config"
"github.com/prismelabs/prismeanalytics/internal/middlewares"
"github.com/prismelabs/prismeanalytics/internal/renderer"
Expand Down Expand Up @@ -36,5 +39,32 @@ func ProvideFiber(
app.Use(middlewares.RequestId(cfg.Server))
app.Use(middlewares.AccessLog(accessLogger.Logger))

// Error handler.
// Handle error manually before access log middleware.
app.Use(middlewares.RestError)

apiGroup := app.Group("/api")

// Auth API is forwarded to auth service.
apiAuthGroup := apiGroup.Group("/auth")
apiAuthGroup.Use(proxy.Balancer(proxy.Config{
Servers: []string{cfg.Auth.ServiceUrl.String()},
ModifyRequest: func(c *fiber.Ctx) error {
// Strip prefix.
c.Path(strings.TrimPrefix(c.Path(), "/api/auth"))
return nil
},
Next: func(c *fiber.Ctx) bool {
path := strings.TrimPrefix(c.Path(), "/api/auth")
switch path {
case "/token", "/logout", "/verify", "/signup", "/recover", "/user":
return false

default:
return true
}
},
}))

return app
}
23 changes: 23 additions & 0 deletions internal/middlewares/resterror.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package middlewares

import (
"github.com/gofiber/fiber/v2"
"github.com/prismelabs/prismeanalytics/internal/resterror"
)

func RestError(c *fiber.Ctx) error {
err := c.Next()
if err != nil {
// Handler error.
handlerErr := resterror.FiberErrorHandler(c, err)
// If failed to handle error, return it.
if handlerErr != nil {
return handlerErr
}

// Return error so it can be logged.
return err
}

return err
}
73 changes: 73 additions & 0 deletions internal/resterror/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package resterror

import (
"errors"
"fmt"

"github.com/gofiber/fiber/v2"
)

// RestError is JSON serializable error type that can be used with any web framework.
type RestError struct {
Code string
}

// RestError implements the error interface.
func (re RestError) Error() string {
return fmt.Sprintf("RestError(%v)", re.Code)
}

// Create a new error joining a RestError with the given code and errs.
func New(code string, errs ...error) error {
errs = append(errs, RestError{code})
return errors.Join(errs...)
}

// NewFiber returns a new error with the given error code, response status and
// error description.
func NewFiber(code string, status int, desc string) error {
return New(code, fiber.NewError(status, desc))
}

// NewFiberf works the same as NewFiber but adds formatting support.
func NewFiberf(code string, status int, format string, args ...any) error {
return NewFiber(code, status, fmt.Sprintf(format, args...))
}

const InternalServerErrorCode = "InternalServerError"

// FiberErrorHandler is a fiber error handler that returns JSON encoded response.
// It supports RestError and *fiber.Error and errors.Join of them. If error handled
// contains none of them, a generic InternalServerError is sent without disclosing
// it. Fiber error with a status >= 500 are also hidden.
func FiberErrorHandler(c *fiber.Ctx, err error) error {
responseBody := struct {
Code string `json:"error_code"`
Desc string `json:"error_description"`
}{
Code: InternalServerErrorCode,
Desc: "internal server error, check server logs for more information",
}
c.Response().SetStatusCode(fiber.StatusInternalServerError)

var restErr RestError
if errors.As(err, &restErr) {
c.Response().SetStatusCode(fiber.StatusBadRequest)
responseBody.Code = restErr.Code
responseBody.Desc = ""
}

var fiberErr *fiber.Error
if errors.As(err, &fiberErr) {
c.Response().SetStatusCode(fiberErr.Code)
if fiberErr.Code < 500 {
responseBody.Desc = fiberErr.Message
}

if responseBody.Code == InternalServerErrorCode {
responseBody.Code = "UnexpectedError"
}
}

return c.JSON(responseBody)
}
88 changes: 88 additions & 0 deletions internal/resterror/error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package resterror

import (
"errors"
"io"
"net/http/httptest"
"testing"

"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)

func TestFiberErrorHandler(t *testing.T) {
app := fiber.New(fiber.Config{
ErrorHandler: FiberErrorHandler,
})

t.Run("FiberErrorOnly", func(t *testing.T) {
app.Get("/fiber_error", func(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUnauthorized, "authorization header missing or invalid")
})

req := httptest.NewRequest("GET", "http://server.local/fiber_error", nil)

resp, err := app.Test(req)
require.NoError(t, err)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

require.Equal(t, `{"error_code":"UnexpectedError","error_description":"authorization header missing or invalid"}`, string(body))
require.Equal(t, resp.StatusCode, fiber.StatusUnauthorized)
})

t.Run("RestErrorOnly", func(t *testing.T) {
app.Get("/rest_error", func(c *fiber.Ctx) error {
return New("InvalidAuthorizationHeader")
})

req := httptest.NewRequest("GET", "http://server.local/rest_error", nil)

resp, err := app.Test(req)
require.NoError(t, err)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

require.Equal(t, `{"error_code":"InvalidAuthorizationHeader","error_description":""}`, string(body))
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
})

t.Run("FiberRestErrorsJoined", func(t *testing.T) {
app.Get("/fiber_rest_error", func(c *fiber.Ctx) error {
return New(
"InvalidAuthorizationHeader",
fiber.NewError(fiber.StatusUnauthorized, "invalid or missing authorization header"),
)
})

req := httptest.NewRequest("GET", "http://server.local/fiber_rest_error", nil)

resp, err := app.Test(req)
require.NoError(t, err)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

require.Equal(t, `{"error_code":"InvalidAuthorizationHeader","error_description":"invalid or missing authorization header"}`, string(body))
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
})

t.Run("GoError", func(t *testing.T) {
app.Get("/go_error", func(c *fiber.Ctx) error {
return errors.New("runtime error")
})

req := httptest.NewRequest("GET", "http://server.local/go_error", nil)

resp, err := app.Test(req)
require.NoError(t, err)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

require.Equal(t, `{"error_code":"InternalServerError","error_description":"internal server error, check server logs for more information"}`, string(body))
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
})
}

0 comments on commit e864c4f

Please sign in to comment.