-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package middlewares | ||
|
||
import ( | ||
"github.com/gofiber/fiber/v2" | ||
"github.com/prismelabs/prismeanalytics/internal/resterror" | ||
) | ||
|
||
func RestError(c *fiber.Ctx) error { | ||
err := c.Next() | ||
if err != nil { | ||
// Handler error. | ||
handlerErr := resterror.FiberErrorHandler(c, err) | ||
// If failed to handle error, return it. | ||
if handlerErr != nil { | ||
return handlerErr | ||
} | ||
|
||
// Return error so it can be logged. | ||
return err | ||
} | ||
|
||
return err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package resterror | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
|
||
"github.com/gofiber/fiber/v2" | ||
) | ||
|
||
// RestError is JSON serializable error type that can be used with any web framework. | ||
type RestError struct { | ||
Code string | ||
} | ||
|
||
// RestError implements the error interface. | ||
func (re RestError) Error() string { | ||
return fmt.Sprintf("RestError(%v)", re.Code) | ||
} | ||
|
||
// Create a new error joining a RestError with the given code and errs. | ||
func New(code string, errs ...error) error { | ||
errs = append(errs, RestError{code}) | ||
return errors.Join(errs...) | ||
} | ||
|
||
// NewFiber returns a new error with the given error code, response status and | ||
// error description. | ||
func NewFiber(code string, status int, desc string) error { | ||
return New(code, fiber.NewError(status, desc)) | ||
} | ||
|
||
// NewFiberf works the same as NewFiber but adds formatting support. | ||
func NewFiberf(code string, status int, format string, args ...any) error { | ||
return NewFiber(code, status, fmt.Sprintf(format, args...)) | ||
} | ||
|
||
const InternalServerErrorCode = "InternalServerError" | ||
|
||
// FiberErrorHandler is a fiber error handler that returns JSON encoded response. | ||
// It supports RestError and *fiber.Error and errors.Join of them. If error handled | ||
// contains none of them, a generic InternalServerError is sent without disclosing | ||
// it. Fiber error with a status >= 500 are also hidden. | ||
func FiberErrorHandler(c *fiber.Ctx, err error) error { | ||
responseBody := struct { | ||
Code string `json:"error_code"` | ||
Desc string `json:"error_description"` | ||
}{ | ||
Code: InternalServerErrorCode, | ||
Desc: "internal server error, check server logs for more information", | ||
} | ||
c.Response().SetStatusCode(fiber.StatusInternalServerError) | ||
|
||
var restErr RestError | ||
if errors.As(err, &restErr) { | ||
c.Response().SetStatusCode(fiber.StatusBadRequest) | ||
responseBody.Code = restErr.Code | ||
responseBody.Desc = "" | ||
} | ||
|
||
var fiberErr *fiber.Error | ||
if errors.As(err, &fiberErr) { | ||
c.Response().SetStatusCode(fiberErr.Code) | ||
if fiberErr.Code < 500 { | ||
responseBody.Desc = fiberErr.Message | ||
} | ||
|
||
if responseBody.Code == InternalServerErrorCode { | ||
responseBody.Code = "UnexpectedError" | ||
} | ||
} | ||
|
||
return c.JSON(responseBody) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package resterror | ||
|
||
import ( | ||
"errors" | ||
"io" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/gofiber/fiber/v2" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestFiberErrorHandler(t *testing.T) { | ||
app := fiber.New(fiber.Config{ | ||
ErrorHandler: FiberErrorHandler, | ||
}) | ||
|
||
t.Run("FiberErrorOnly", func(t *testing.T) { | ||
app.Get("/fiber_error", func(c *fiber.Ctx) error { | ||
return fiber.NewError(fiber.StatusUnauthorized, "authorization header missing or invalid") | ||
}) | ||
|
||
req := httptest.NewRequest("GET", "http://server.local/fiber_error", nil) | ||
|
||
resp, err := app.Test(req) | ||
require.NoError(t, err) | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, `{"error_code":"UnexpectedError","error_description":"authorization header missing or invalid"}`, string(body)) | ||
require.Equal(t, resp.StatusCode, fiber.StatusUnauthorized) | ||
}) | ||
|
||
t.Run("RestErrorOnly", func(t *testing.T) { | ||
app.Get("/rest_error", func(c *fiber.Ctx) error { | ||
return New("InvalidAuthorizationHeader") | ||
}) | ||
|
||
req := httptest.NewRequest("GET", "http://server.local/rest_error", nil) | ||
|
||
resp, err := app.Test(req) | ||
require.NoError(t, err) | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, `{"error_code":"InvalidAuthorizationHeader","error_description":""}`, string(body)) | ||
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) | ||
}) | ||
|
||
t.Run("FiberRestErrorsJoined", func(t *testing.T) { | ||
app.Get("/fiber_rest_error", func(c *fiber.Ctx) error { | ||
return New( | ||
"InvalidAuthorizationHeader", | ||
fiber.NewError(fiber.StatusUnauthorized, "invalid or missing authorization header"), | ||
) | ||
}) | ||
|
||
req := httptest.NewRequest("GET", "http://server.local/fiber_rest_error", nil) | ||
|
||
resp, err := app.Test(req) | ||
require.NoError(t, err) | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, `{"error_code":"InvalidAuthorizationHeader","error_description":"invalid or missing authorization header"}`, string(body)) | ||
require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) | ||
}) | ||
|
||
t.Run("GoError", func(t *testing.T) { | ||
app.Get("/go_error", func(c *fiber.Ctx) error { | ||
return errors.New("runtime error") | ||
}) | ||
|
||
req := httptest.NewRequest("GET", "http://server.local/go_error", nil) | ||
|
||
resp, err := app.Test(req) | ||
require.NoError(t, err) | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, `{"error_code":"InternalServerError","error_description":"internal server error, check server logs for more information"}`, string(body)) | ||
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) | ||
}) | ||
} |