Skip to content

Commit

Permalink
test: add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joshiste committed May 22, 2024
1 parent 461cac0 commit 50ae007
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 109 deletions.
17 changes: 15 additions & 2 deletions exthealth/health.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
/*
* Copyright 2024 steadybit GmbH. All rights reserved.
*/

// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 Steadybit GmbH

package exthealth

import (
"errors"
"fmt"
"github.com/kelseyhightower/envconfig"
"github.com/rs/zerolog/log"
Expand All @@ -14,6 +19,7 @@ import (

var (
isReady int32 = 1
server *http.Server
)

type HealthSpecification struct {
Expand Down Expand Up @@ -66,13 +72,20 @@ func StartProbes(port int) {
addReadinessProbe(serverMux.Handle)
go func() {
log.Info().Msgf("Starting probes server on port %d, ready: %t", healthPort, atomic.LoadInt32(&isReady) == 1)
err := http.ListenAndServe(fmt.Sprintf(":%d", healthPort), serverMux)
if err != nil {
server = &http.Server{Addr: fmt.Sprintf(":%d", healthPort), Handler: serverMux}

if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal().Err(err).Msgf("Failed to start probes server")
}
}()
}

func StopProbes() {
if server != nil {
_ = server.Close()
}
}

// SetReady sets the readiness state of the service. If the service is not ready the readiness probe will report an error.
func SetReady(ready bool) {
log.Info().Msgf("Update readiness probe - ready: %t", ready)
Expand Down
53 changes: 47 additions & 6 deletions exthealth/health_test.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
/*
* Copyright 2024 steadybit GmbH. All rights reserved.
*/

package exthealth

import (
"context"
"fmt"
"github.com/phayes/freeport"
"github.com/steadybit/extension-kit/exthttp"
"github.com/stretchr/testify/require"
"net"
"net/http"
"path/filepath"
"testing"
)

func TestShouldServeReadinessAndLiveness(t *testing.T) {
SetReady(false)
addLivenessProbe(http.Handle)
addReadinessProbe(http.Handle)

func TestServeProbes(t *testing.T) {
port, err := freeport.GetFreePort()
require.NoError(t, err)

go http.ListenAndServe(fmt.Sprintf(":%d", port), nil)
SetReady(false)
StartProbes(port)
defer StopProbes()

res, err := http.Get(fmt.Sprintf("http://localhost:%d/health/liveness", port))
require.NoError(t, err)
Expand All @@ -32,3 +38,38 @@ func TestShouldServeReadinessAndLiveness(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)
}

func TestServerProbesUsingUnixSocket(t *testing.T) {
sock := filepath.Join(t.TempDir(), "sock")

t.Setenv("STEADYBIT_EXTENSION_UNIX_SOCKET", sock)
go exthttp.Listen(exthttp.ListenOpts{})
exthttp.WaitForServe()
defer exthttp.StopListen()

client := http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", sock)
},
},
}

SetReady(false)
StartProbes(0)
defer StopProbes()

res, err := client.Get("http://localhost/health/liveness")
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

res, err = client.Get("http://localhost/health/readiness")
require.NoError(t, err)
require.Equal(t, http.StatusServiceUnavailable, res.StatusCode)

SetReady(true)

res, err = client.Get("http://localhost/health/readiness")
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)
}
138 changes: 99 additions & 39 deletions exthttp/listener.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
/*
* Copyright 2024 steadybit GmbH. All rights reserved.
*/

// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 Steadybit GmbH

Expand All @@ -6,6 +10,7 @@ package exthttp
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"github.com/kelseyhightower/envconfig"
"github.com/rs/zerolog/log"
Expand All @@ -15,6 +20,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
)

type ListenSpecification struct {
Expand All @@ -25,6 +31,11 @@ type ListenSpecification struct {
TlsClientCas []string `json:"tlsClientCas" split_words:"true" required:"false"`
}

var (
wrapper *httpServerWrapper
serveCond = sync.NewCond(&sync.Mutex{})
)

func (spec *ListenSpecification) parseConfigurationFromEnvironment() {
err := envconfig.Process("steadybit_extension", spec)
if err != nil {
Expand Down Expand Up @@ -60,15 +71,14 @@ type ListenOpts struct {
// Port Default port to bind to. Can be overridden through the environment variable STEADYBIT_EXTENSION_PORT.
Port int
}
type httpServerWrapper struct {
serve func() error
server *http.Server
}

func Listen(opts ListenOpts) {
_, start, err := listen(opts)
if err != nil {
log.Fatal().Err(err).Msgf("Failed to start extension server")
}

err = start()
if err != nil {
err := listen(opts)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal().Err(err).Msgf("Failed to start extension server")
}
}
Expand All @@ -79,26 +89,61 @@ func IsUnixSocketEnabled() bool {
return spec.UnixSocket != ""
}

func listen(opts ListenOpts) (*http.Server, func() error, error) {
func listen(opts ListenOpts) error {
success := false
serveCond.L.Lock()
defer func() {
if !success {
serveCond.L.Unlock()
}
}()

spec := ListenSpecification{}
spec.parseConfigurationFromEnvironment()
err := spec.validateSpecification()
if err != nil {
log.Fatal().Err(err).Msgf("Failed to validate HTTP server configuration.")
if err := spec.validateSpecification(); err != nil {
return fmt.Errorf("failed to validate listen specification: %w", err)
}

port := opts.Port
if spec.Port != 0 {
port = spec.Port
}

var err error
if spec.UnixSocket != "" {
return prepareUnixSocketServer(spec.UnixSocket)
wrapper, err = prepareUnixSocketServer(spec.UnixSocket)
} else if spec.isTlsEnabled() {
return prepareHttpsServer(port, spec)
wrapper, err = prepareHttpsServer(port, spec)
} else {
return prepareHttpServer(port)
wrapper, err = prepareHttpServer(port)
}
if err != nil {
return err
}

serveCond.Broadcast()
serveCond.L.Unlock()
success = true
if err = wrapper.serve(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
}

func WaitForServe() {
serveCond.L.Lock()
defer serveCond.L.Unlock()
serveCond.Wait()
}

func StopListen() {
if wrapper == nil || wrapper.server == nil {
return
}
if err := wrapper.server.Close(); err != nil {
log.Error().Err(err).Msgf("Failed to stop extension server")
}
wrapper = nil
}

type forwardToZeroLogWriter struct {
Expand All @@ -116,55 +161,63 @@ func (fw *forwardToZeroLogWriter) Write(p []byte) (n int, err error) {
return len([]byte(trimmed)), nil
}

func prepareHttpServer(port int) (*http.Server, func() error, error) {
server := &http.Server{
Addr: fmt.Sprintf(":%d", port),
ErrorLog: stdLog.New(&forwardToZeroLogWriter{}, "", 0),
}

start := func() error {
log.Info().Msgf("Starting extension http server on port %d", port)
return server.ListenAndServe()
func prepareHttpServer(port int) (*httpServerWrapper, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}

return server, start, nil
}

func prepareUnixSocketServer(path string) (*http.Server, func() error, error) {
server := &http.Server{
ErrorLog: stdLog.New(&forwardToZeroLogWriter{}, "", 0),
}

log.Info().Msgf("Starting extension http server on port %d", port)
return &httpServerWrapper{
serve: func() error {
return server.Serve(listener)
},
server: server,
}, nil
}

func prepareUnixSocketServer(path string) (*httpServerWrapper, error) {
if _, err := os.Stat(filepath.Dir(path)); os.IsNotExist(err) {
err = os.MkdirAll(filepath.Dir(path), 0755)
if err != nil {
return nil, nil, fmt.Errorf("failed to create directory for unix socket: %w", err)
return nil, fmt.Errorf("failed to create directory for unix socket: %w", err)
}
} else {
_ = os.Remove(path)
}

unixListener, err := net.Listen("unix", path)
if err != nil {
return nil, nil, fmt.Errorf("failed listen on unix socket: %w", err)
return nil, fmt.Errorf("failed listen on unix socket: %w", err)
}

server := &http.Server{
ErrorLog: stdLog.New(&forwardToZeroLogWriter{}, "", 0),
}

return server, func() error {
log.Info().Msgf("Starting extension http server on unix domain socket (%s)", path)
return server.Serve(unixListener)
return &httpServerWrapper{
serve: func() error {
log.Info().Msgf("Starting extension http server on unix domain socket (%s)", path)
return server.Serve(unixListener)
},
server: server,
}, nil
}

func prepareHttpsServer(port int, spec ListenSpecification) (*http.Server, func() error, error) {
func prepareHttpsServer(port int, spec ListenSpecification) (*httpServerWrapper, error) {
certReloader := NewCertReloader(spec.TlsServerCert, spec.TlsServerKey)

if _, err := certReloader.GetCertificate(nil); err != nil {
return nil, nil, fmt.Errorf("failed to load TLS certificate: %w", err)
return nil, fmt.Errorf("failed to load TLS certificate: %w", err)
}

clientCAs, err := loadCertPool(spec.TlsClientCas)
if err != nil {
return nil, nil, fmt.Errorf("failed to load TLS client CA certificates: %w", err)
return nil, fmt.Errorf("failed to load TLS client CA certificates: %w", err)
}

tlsConfig := tls.Config{
Expand All @@ -173,14 +226,21 @@ func prepareHttpsServer(port int, spec ListenSpecification) (*http.Server, func(
ClientCAs: clientCAs,
}

listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}

server := &http.Server{
Addr: fmt.Sprintf(":%d", port),
TLSConfig: &tlsConfig,
ErrorLog: stdLog.New(&forwardToZeroLogWriter{}, "", 0),
}
return server, func() error {
log.Info().Msgf("Starting extension https server on port %d (ClientAuth: %s)", port, spec.getClientAuthType())
return server.ListenAndServeTLS("", "")
return &httpServerWrapper{
serve: func() error {
log.Info().Msgf("Starting extension https server on port %d (ClientAuth: %s)", port, spec.getClientAuthType())
return server.ServeTLS(listener, "", "")
},
server: server,
}, nil
}

Expand Down
Loading

0 comments on commit 50ae007

Please sign in to comment.