Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Logging, auditing, warnings, and PingContext failure handling for switch DB #77

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/audit/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type QueryData struct {
Namespace string
Pod string
Timestamp int64
DBName string
}

type Audit interface {
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func Run(logger *zap.SugaredLogger) error {
r := mux.NewRouter()
r.Handle("/healthcheck", logHandler(healthLogOutput, handlers.Healthcheck(cfg))).Methods("GET")
r.Handle("/query", logHandler(defaultLogOutput, queryHandler)).Methods("POST")
r.Handle("/dbname", logHandler(defaultLogOutput, handlers.GetCurrentDBName(cfg))).Methods("GET")
r.Handle("/dbname", logHandler(defaultLogOutput, handlers.GetDBName(cfg))).Methods("GET")
r.Handle("/dbname/switch", logHandler(defaultLogOutput, handlers.SwitchDBName(cfg))).Methods("POST")

port := 8080
Expand Down
15 changes: 13 additions & 2 deletions pkg/handlers/getdbname.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,26 @@ package handlers
import (
"encoding/json"
"net/http"
"os"

gabi "github.com/app-sre/gabi/pkg"
"github.com/app-sre/gabi/pkg/models"
)

func GetCurrentDBName(cfg *gabi.Config) http.Handler {
func GetDBName(cfg *gabi.Config) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dbName := cfg.DBEnv.GetCurrentDBName()
defaultDBName := os.Getenv("DB_NAME")

response := models.DBNameResponse{DBName: dbName}

if dbName != defaultDBName {
warning := "Current database differs from the default"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add dbName and defaultDBName info to the warn msg?

cfg.Logger.Warnf(warning)
response.Warnings = append(response.Warnings, warning)
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(models.DBNameResponse{DBName: dbName})
json.NewEncoder(w).Encode(response)
})
}
51 changes: 33 additions & 18 deletions pkg/handlers/getdbname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package handlers
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/app-sre/gabi/internal/test"
Expand All @@ -16,53 +16,68 @@ import (
"github.com/stretchr/testify/require"
)

func TestGetCurrentDBName(t *testing.T) {
func TestGetDBName(t *testing.T) {
cases := []struct {
description string
dbName string
code int
body map[string]string
description string
dbName string
defaultDBName string
expectedStatus int
expectedBody string
want string
}{
{
"returns current database name",
"test_db",
"test_db",
200,
map[string]string{"db_name": "test_db"},
`{"db_name":"test_db"}`,
"",
},
{
"returns empty database name",
"",
"",
200,
`{"db_name":""}`,
"",
},
{
"returns warning when current db name is different from default",
"test_db",
"default_db",
200,
map[string]string{"db_name": ""},
`{"db_name":"test_db","warnings":["Current database differs from the default"]}`,
"Current database differs from the default",
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.description, func(t *testing.T) {
var body bytes.Buffer
var output bytes.Buffer

os.Setenv("DB_NAME", tc.defaultDBName)
defer os.Unsetenv("DB_NAME")

dbEnv := &db.Env{Name: tc.dbName}

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", &bytes.Buffer{})

logger := test.DummyLogger(io.Discard).Sugar()
logger := test.DummyLogger(&output).Sugar()

expected := &gabi.Config{DBEnv: dbEnv, Logger: logger}
GetCurrentDBName(expected).ServeHTTP(w, r.WithContext(context.TODO()))
GetDBName(expected).ServeHTTP(w, r.WithContext(context.TODO()))

actual := w.Result()
defer func() { _ = actual.Body.Close() }()

_, _ = io.Copy(&body, actual.Body)

var responseBody map[string]string
err := json.Unmarshal(body.Bytes(), &responseBody)

body, err := io.ReadAll(actual.Body)
require.NoError(t, err)
assert.Equal(t, tc.code, actual.StatusCode)
assert.Equal(t, tc.body, responseBody)

assert.Equal(t, tc.expectedStatus, actual.StatusCode)
assert.JSONEq(t, tc.expectedBody, string(body))
assert.Contains(t, output.String(), tc.want)
})
}
}
9 changes: 9 additions & 0 deletions pkg/handlers/healthcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"os"
"time"

"github.com/etherlabsio/healthcheck/v2"
Expand All @@ -14,17 +15,25 @@ import (
const healthcheckTimeout = 5 * time.Second

func Healthcheck(cfg *gabi.Config) http.Handler {
defaultDBName := os.Getenv("DB_NAME")
return healthcheck.Handler(
healthcheck.WithTimeout(healthcheckTimeout),
healthcheck.WithChecker(
"database", healthcheck.CheckerFunc(
func(ctx context.Context) error {
dbName := cfg.DBEnv.GetCurrentDBName()
err := cfg.DB.PingContext(ctx)
if err != nil {
l := "Unable to connect to the database"
cfg.Logger.Errorf("%s: %s", l, err)
return errors.New(l)
}

if dbName != defaultDBName {
l := "Current database differs from the default"
cfg.Logger.Warnf(l)
}

return nil
},
),
Expand Down
33 changes: 28 additions & 5 deletions pkg/handlers/healthcheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/app-sre/gabi/internal/test"
gabi "github.com/app-sre/gabi/pkg"
"github.com/app-sre/gabi/pkg/env/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -19,16 +21,20 @@ func TestHealthcheck(t *testing.T) {
t.Parallel()

cases := []struct {
description string
given func(sqlmock.Sqlmock)
code int
body string
description string
given func(sqlmock.Sqlmock)
dbName string
defaultDBName string
code int
body string
}{
{
"database is accessible and returns ping reply",
func(mock sqlmock.Sqlmock) {
mock.ExpectPing()
},
"default_db",
"default_db",
200,
`{"status":"OK"}`,
},
Expand All @@ -37,9 +43,21 @@ func TestHealthcheck(t *testing.T) {
func(mock sqlmock.Sqlmock) {
mock.ExpectPing().WillReturnError(errors.New("test"))
},
"default_db",
"default_db",
503,
`{"database":"Unable to connect to the database"}`,
},
{
"database name differs from the default",
func(mock sqlmock.Sqlmock) {
mock.ExpectPing()
},
"test_db",
"default_db",
200,
``,
},
}

for _, tc := range cases {
Expand All @@ -49,6 +67,11 @@ func TestHealthcheck(t *testing.T) {

var body bytes.Buffer

os.Setenv("DB_NAME", tc.defaultDBName)
defer os.Unsetenv("DB_NAME")

dbEnv := &db.Env{Name: tc.dbName}

db, mock, _ := sqlmock.New(sqlmock.MonitorPingsOption(true))
defer func() { _ = db.Close() }()

Expand All @@ -59,7 +82,7 @@ func TestHealthcheck(t *testing.T) {

tc.given(mock)

expected := &gabi.Config{DB: db, Logger: logger}
expected := &gabi.Config{DB: db, Logger: logger, DBEnv: dbEnv}
Healthcheck(expected).ServeHTTP(w, r)

actual := w.Result()
Expand Down
12 changes: 11 additions & 1 deletion pkg/handlers/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,17 @@ func Query(cfg *gabi.Config) http.HandlerFunc {
var (
base64Mode byte
request models.QueryRequest
warnings []string
)

defaultDBName := os.Getenv("DB_NAME")
currentDBName := cfg.DBEnv.GetCurrentDBName()
if currentDBName != defaultDBName {
l := "Current database differs from the default"
cfg.Logger.Warnf(l)
warnings = append(warnings, l)
}

if s := r.URL.Query().Get("base64_results"); s != "" {
if ok, err := strconv.ParseBool(s); err == nil && ok {
base64Mode |= base64EncodeResults
Expand Down Expand Up @@ -159,7 +168,8 @@ func Query(cfg *gabi.Config) http.HandlerFunc {
w.Header().Set("Cache-Control", "private, no-store")
w.Header().Set("Content-Type", "application/json; charset=utf-8")
_ = json.NewEncoder(w).Encode(&models.QueryResponse{
Result: result,
Result: result,
Warnings: warnings,
})
}
}
Expand Down
Loading