diff --git a/README.md b/README.md index 092d742..9e94120 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,36 @@ Flags: -p, --port int port to listen on (default 8080) ``` +## Usage as a library +`github.com/coreruleset/albedo/server` package provides a handler that can be used for testing purposes. +Usage example: +```go +package albedo_test + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coreruleset/albedo/server" + "github.com/stretchr/testify/require" +) + +func TestAlbedo(t *testing.T) { + testServer := httptest.NewServer(server.Handler()) + defer testServer.Close() + + client := http.Client{ + Timeout: time.Duration(1 * time.Second), + } + + _, err := client.Get(testServer.URL) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) +} +``` + ## Endpoints ```yaml diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..93993b3 --- /dev/null +++ b/main_test.go @@ -0,0 +1,24 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coreruleset/albedo/server" + "github.com/stretchr/testify/require" +) + +func TestAlbedoLibrary(t *testing.T) { + testServer := httptest.NewServer(server.Handler()) + defer testServer.Close() + + client := http.Client{ + Timeout: time.Duration(1 * time.Second), + } + + resp, err := client.Get(testServer.URL) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/server/server.go b/server/server.go index fac13d8..b6a68c0 100644 --- a/server/server.go +++ b/server/server.go @@ -66,19 +66,26 @@ const capabilitiesDescription = ` func Start(binding string, port int) *http.Server { server := &http.Server{ - Addr: fmt.Sprintf("%s:%d", binding, port), + Addr: fmt.Sprintf("%s:%d", binding, port), + Handler: Handler(), } - http.HandleFunc("/", handleDefault) - http.HandleFunc("/capabilities", handleCapabilities) - http.HandleFunc("/capabilities/", handleCapabilities) - http.HandleFunc("POST /reflect", handleReflect) - http.HandleFunc("POST /reflect/", handleReflect) - log.Fatal(server.ListenAndServe()) return server } +func Handler() http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/", handleDefault) + mux.HandleFunc("/capabilities", handleCapabilities) + mux.HandleFunc("/capabilities/", handleCapabilities) + mux.HandleFunc("POST /reflect", handleReflect) + mux.HandleFunc("POST /reflect/", handleReflect) + + return mux +} + // Respond with empty 200 for all requests by default func handleDefault(w http.ResponseWriter, r *http.Request) { log.Printf("Received default request to %s", r.URL)