diff --git a/middlewares/logger.go b/middlewares/logger.go new file mode 100644 index 0000000..7034bc5 --- /dev/null +++ b/middlewares/logger.go @@ -0,0 +1,168 @@ +//go:build go1.21 + +package middlewares + +import ( + "context" + "io" + "log/slog" + "os" + "time" + + "github.com/mojixcoder/kid" +) + +type ( + // LoggerConfig is the config used to build logger middleware. + LoggerConfig struct { + // Logger is the logger instance. + // Optional. If set, Out, Level and Type configs won't be used. + Logger *slog.Logger + + // Out is the writer that logs will be written at. + // Defaults to os.Stdout. + Out io.Writer + + // Level is the log level used for initializing a logger instance. + // Defaults to slog.LevelInfo. + Level slog.Leveler + + // SuccessLevel is the log level when status code < 400. + // Defaults to slog.LevelInfo. + SuccessLevel slog.Leveler + + // ClientErrorLevel is the log level when status code is between 400 and 499. + // Defaults to slog.LevelWarn. + ClientErrorLevel slog.Leveler + + // ServerErrorLevel is the log level when status code >= 500. + // Defaults to slog.LevelError. + ServerErrorLevel slog.Leveler + + // Type is the logger type. + // Defaults to JSON. + Type LoggerType + + // Skipper is a function used for skipping middleware execution. + // Defaults to nil. + Skipper func(c *kid.Context) bool + } + + // LoggerType is the type for specifying logger type. + LoggerType string +) + +const ( + // JSONLogger is the JSON logger type. + TypeJSON LoggerType = "JSON" + + // TextLogger is the text logger type. + TypeText LoggerType = "TEXT" +) + +// DefaultLoggerConfig is the default logger config. +var DefaultLoggerConfig = LoggerConfig{ + Out: os.Stdout, + Level: slog.LevelInfo, + SuccessLevel: slog.LevelInfo, + ClientErrorLevel: slog.LevelWarn, + ServerErrorLevel: slog.LevelError, + Type: TypeJSON, +} + +// NewLogger returns a new logger middleware. +func NewLogger() kid.MiddlewareFunc { + return NewLoggerWithConfig(DefaultLoggerConfig) +} + +// NewLoggerWithConfig returns a new logger middleware with the given config. +func NewLoggerWithConfig(cfg LoggerConfig) kid.MiddlewareFunc { + setLoggerDefaults(&cfg) + + logger := cfg.getLogger() + + successLvl := cfg.SuccessLevel.Level() + clientErrLvl := cfg.ClientErrorLevel.Level() + serverErrLvl := cfg.ServerErrorLevel.Level() + + return func(next kid.HandlerFunc) kid.HandlerFunc { + return func(c *kid.Context) { + // Skip if necessary. + if cfg.Skipper != nil && cfg.Skipper(c) { + next(c) + return + } + + start := time.Now() + + next(c) + + end := time.Now() + req := c.Request() + duration := end.Sub(start) + + status := c.Response().Status() + + attrs := []slog.Attr{ + slog.Time("time", end), + slog.Duration("latency_ns", duration), + slog.String("latency", duration.String()), + slog.Int("status", status), + slog.String("path", req.URL.Path), + slog.String("method", req.Method), + slog.String("user_agent", req.Header.Get("User-Agent")), + } + + if status < 400 { + logger.LogAttrs(context.Background(), successLvl, "SUCCESS", attrs...) + } else if status <= 499 { + logger.LogAttrs(context.Background(), clientErrLvl, "CLIENT ERROR", attrs...) + } else { // 5xx status codes. + logger.LogAttrs(context.Background(), serverErrLvl, "SERVER ERROR", attrs...) + } + } + } +} + +// getLogger returns the appropriate logger instance. +func (cfg LoggerConfig) getLogger() *slog.Logger { + if cfg.Logger != nil { + return cfg.Logger + } + + switch cfg.Type { + case TypeJSON: + return slog.New(slog.NewJSONHandler(cfg.Out, &slog.HandlerOptions{Level: cfg.Level})) + case TypeText: + return slog.New(slog.NewTextHandler(cfg.Out, &slog.HandlerOptions{Level: cfg.Level})) + default: + panic("invalid logger type") + } +} + +// setLoggerDefaults sets logger default values. +func setLoggerDefaults(cfg *LoggerConfig) { + if cfg.Out == nil { + cfg.Out = DefaultLoggerConfig.Out + } + + if cfg.Level == nil { + cfg.Level = DefaultLoggerConfig.Level + } + + if cfg.SuccessLevel == nil { + cfg.SuccessLevel = DefaultLoggerConfig.SuccessLevel + } + + if cfg.ClientErrorLevel == nil { + cfg.ClientErrorLevel = DefaultLoggerConfig.ClientErrorLevel + } + + if cfg.ServerErrorLevel == nil { + cfg.ServerErrorLevel = DefaultLoggerConfig.ServerErrorLevel + } + + if cfg.Type == "" { + cfg.Type = DefaultLoggerConfig.Type + } +} diff --git a/middlewares/logger_test.go b/middlewares/logger_test.go new file mode 100644 index 0000000..990c133 --- /dev/null +++ b/middlewares/logger_test.go @@ -0,0 +1,150 @@ +//go:build go1.21 + +package middlewares + +import ( + "bytes" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mojixcoder/kid" + "github.com/stretchr/testify/assert" +) + +type logRecord struct { + Msg string `json:"msg"` + Time time.Time `json:"time"` + LatenyNS int64 `json:"latency_ns"` + Latency string `json:"latency"` + Status int `json:"status"` + Path string `json:"path"` + Method string `json:"method"` + UserAgent string `json:"user_agent"` +} + +func TestNewLogger(t *testing.T) { + middleware := NewLogger() + + assert.NotNil(t, middleware) +} + +func TestSetLoggerDefaults(t *testing.T) { + var cfg LoggerConfig + + setLoggerDefaults(&cfg) + + assert.Equal(t, DefaultLoggerConfig.Out, cfg.Out) + assert.Equal(t, DefaultLoggerConfig.Logger, cfg.Logger) + assert.Equal(t, DefaultLoggerConfig.Level, cfg.Level) + assert.Equal(t, DefaultLoggerConfig.ServerErrorLevel, cfg.ServerErrorLevel) + assert.Equal(t, DefaultLoggerConfig.ClientErrorLevel, cfg.ClientErrorLevel) + assert.Equal(t, DefaultLoggerConfig.SuccessLevel, cfg.SuccessLevel) + assert.Equal(t, DefaultLoggerConfig.Type, cfg.Type) +} + +func TestLoggerConfig_getLogger(t *testing.T) { + var cfg LoggerConfig + setLoggerDefaults(&cfg) + + logger := cfg.getLogger() + assert.IsType(t, &slog.JSONHandler{}, logger.Handler()) + + cfg.Type = TypeText + + logger = cfg.getLogger() + assert.IsType(t, &slog.TextHandler{}, logger.Handler()) + + cfg.Logger = slog.New(slog.NewJSONHandler(io.Discard, nil)) + assert.Equal(t, cfg.Logger, cfg.getLogger()) + + assert.PanicsWithValue(t, "invalid logger type", func() { + cfg.Logger = nil + cfg.Type = "" + cfg.getLogger() + }) +} + +func TestNewLoggerWithConfig(t *testing.T) { + var buf bytes.Buffer + + cfg := DefaultLoggerConfig + cfg.Out = &buf + + k := kid.New() + k.Use(NewLoggerWithConfig(cfg)) + + k.Get("/", func(c *kid.Context) { + time.Sleep(time.Millisecond) + c.String(http.StatusOK, "Ok") + }) + + k.Get("/server-error", func(c *kid.Context) { + time.Sleep(time.Millisecond) + c.String(http.StatusInternalServerError, "Internal Server Error") + }) + + k.Get("/not-found", func(c *kid.Context) { + time.Sleep(time.Millisecond) + c.String(http.StatusNotFound, "Not Found") + }) + + testCases := []struct { + path string + msg string + status int + }{ + {path: "/not-found", msg: "CLIENT ERROR", status: http.StatusNotFound}, + {path: "/", msg: "SUCCESS", status: http.StatusOK}, + {path: "/server-error", msg: "SERVER ERROR", status: http.StatusInternalServerError}, + } + + for _, testCase := range testCases { + t.Run(testCase.msg, func(t *testing.T) { + res := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, testCase.path, nil) + req.Header.Set("User-Agent", "Go Test") + + k.ServeHTTP(res, req) + + var logRecord logRecord + err := json.Unmarshal(buf.Bytes(), &logRecord) + assert.NoError(t, err) + + buf.Reset() + + assert.Equal(t, testCase.status, logRecord.Status) + assert.Equal(t, testCase.path, logRecord.Path) + assert.Equal(t, http.MethodGet, logRecord.Method) + assert.Equal(t, "Go Test", logRecord.UserAgent) + assert.NotZero(t, logRecord.Time) + assert.NotEmpty(t, logRecord.Latency) + assert.NotEmpty(t, logRecord.LatenyNS) + assert.Equal(t, testCase.msg, logRecord.Msg) + }) + } +} + +func TestLogger_Skipper(t *testing.T) { + var buf bytes.Buffer + + cfg := DefaultLoggerConfig + cfg.Out = &buf + cfg.Skipper = func(c *kid.Context) bool { + return true + } + + k := kid.New() + k.Use(NewLoggerWithConfig(cfg)) + + res := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + + k.ServeHTTP(res, req) + + assert.Empty(t, buf.Bytes()) +}