From 69e85d9ae259084b711e25f494c78db56bc09e71 Mon Sep 17 00:00:00 2001 From: Segev Dagan Date: Mon, 23 Dec 2024 13:28:46 +0200 Subject: [PATCH] feat: client circuit breaker #448 Co-authored-by: ccoVeille <3875889+ccoVeille@users.noreply.github.com> --- circuit_breaker.go | 163 +++++++++++++++++++++++++++++++++++++++++++++ client.go | 19 ++++++ client_test.go | 69 +++++++++++++++++++ 3 files changed, 251 insertions(+) create mode 100644 circuit_breaker.go diff --git a/circuit_breaker.go b/circuit_breaker.go new file mode 100644 index 0000000..fac650c --- /dev/null +++ b/circuit_breaker.go @@ -0,0 +1,163 @@ +package resty + +import ( + "errors" + "net/http" + "sync/atomic" + "time" +) + +// CircuitBreaker can be in one of three states: Closed, Open, or Half-Open. +// - When the CircuitBreaker is Closed, requests are allowed to pass through. +// - If a failure count threshold is reached within a specified time-frame, +// the CircuitBreaker transitions to the Open state. +// - When the CircuitBreaker is Open, requests are blocked. +// - After a specified timeout, the CircuitBreaker transitions to the Half-Open state. +// - When the CircuitBreaker is Half-Open, a single request is allowed to pass through. +// - If that request fails, the CircuitBreaker returns to the Open state. +// - If the number of successes reaches a specified threshold, +// the CircuitBreaker transitions to the Closed state. +type CircuitBreaker struct { + policies []CircuitBreakerPolicy + timeout time.Duration + failThreshold, successThreshold uint32 + + state atomic.Value // circuitBreakerState + failCount, successCount atomic.Uint32 + lastFail time.Time +} + +// NewCircuitBreaker creates a new [CircuitBreaker] with default settings. +// The default settings are: +// - Timeout: 10 seconds +// - FailThreshold: 3 +// - SuccessThreshold: 1 +// - Policies: CircuitBreaker5xxPolicy +func NewCircuitBreaker() *CircuitBreaker { + cb := &CircuitBreaker{ + policies: []CircuitBreakerPolicy{CircuitBreaker5xxPolicy}, + timeout: 10 * time.Second, + failThreshold: 3, + successThreshold: 1, + } + cb.state.Store(circuitBreakerStateClosed) + return cb +} + +// SetPolicies sets the CircuitBreakerPolicy's that the [CircuitBreaker] will use to determine whether a response is a failure. +func (cb *CircuitBreaker) SetPolicies(policies []CircuitBreakerPolicy) *CircuitBreaker { + cb.policies = policies + return cb +} + +// SetTimeout sets the timeout duration for the [CircuitBreaker]. +func (cb *CircuitBreaker) SetTimeout(timeout time.Duration) *CircuitBreaker { + cb.timeout = timeout + return cb +} + +// SetFailThreshold sets the number of failures that must occur within the timeout duration for the [CircuitBreaker] to +// transition to the Open state. +func (cb *CircuitBreaker) SetFailThreshold(threshold uint32) *CircuitBreaker { + cb.failThreshold = threshold + return cb +} + +// SetSuccessThreshold sets the number of successes that must occur to transition the [CircuitBreaker] from the Half-Open state +// to the Closed state. +func (cb *CircuitBreaker) SetSuccessThreshold(threshold uint32) *CircuitBreaker { + cb.successThreshold = threshold + return cb +} + +// CircuitBreakerPolicy is a function that determines whether a response should trip the [CircuitBreaker]. +type CircuitBreakerPolicy func(resp *http.Response) bool + +// CircuitBreaker5xxPolicy is a [CircuitBreakerPolicy] that trips the [CircuitBreaker] if the response status code is 500 or greater. +func CircuitBreaker5xxPolicy(resp *http.Response) bool { + return resp.StatusCode > 499 +} + +var ErrCircuitBreakerOpen = errors.New("resty: circuit breaker open") + +type circuitBreakerState uint32 + +const ( + circuitBreakerStateClosed circuitBreakerState = iota + circuitBreakerStateOpen + circuitBreakerStateHalfOpen +) + +func (cb *CircuitBreaker) getState() circuitBreakerState { + return cb.state.Load().(circuitBreakerState) +} + +func (cb *CircuitBreaker) allow() error { + if cb == nil { + return nil + } + + if cb.getState() == circuitBreakerStateOpen { + return ErrCircuitBreakerOpen + } + + return nil +} + +func (cb *CircuitBreaker) applyPolicies(resp *http.Response) { + if cb == nil { + return + } + + failed := false + for _, policy := range cb.policies { + if policy(resp) { + failed = true + break + } + } + + if failed { + if cb.failCount.Load() > 0 && time.Since(cb.lastFail) > cb.timeout { + cb.failCount.Store(0) + } + + switch cb.getState() { + case circuitBreakerStateClosed: + failCount := cb.failCount.Add(1) + if failCount >= cb.failThreshold { + cb.open() + } else { + cb.lastFail = time.Now() + } + case circuitBreakerStateHalfOpen: + cb.open() + } + } else { + switch cb.getState() { + case circuitBreakerStateClosed: + return + case circuitBreakerStateHalfOpen: + successCount := cb.successCount.Add(1) + if successCount >= cb.successThreshold { + cb.changeState(circuitBreakerStateClosed) + } + } + } + + return +} + +func (cb *CircuitBreaker) open() { + cb.changeState(circuitBreakerStateOpen) + go func() { + time.Sleep(cb.timeout) + cb.changeState(circuitBreakerStateHalfOpen) + }() +} + +func (cb *CircuitBreaker) changeState(state circuitBreakerState) { + cb.failCount.Store(0) + cb.successCount.Store(0) + cb.state.Store(state) +} diff --git a/client.go b/client.go index 0d089f6..d505bc8 100644 --- a/client.go +++ b/client.go @@ -223,6 +223,7 @@ type Client struct { contentDecompresserKeys []string contentDecompressers map[string]ContentDecompresser certWatcherStopChan chan bool + circuitBreaker *CircuitBreaker } // CertWatcherOptions allows configuring a watcher that reloads dynamically TLS certs. @@ -942,6 +943,18 @@ func (c *Client) SetContentDecompresserKeys(keys []string) *Client { return c } +// SetCircuitBreaker method sets the Circuit Breaker instance into the client. +// It is used to prevent the client from sending requests that are likely to fail. +// For Example: To use the default Circuit Breaker: +// +// client.SetCircuitBreaker(NewCircuitBreaker()) +func (c *Client) SetCircuitBreaker(b *CircuitBreaker) *Client { + c.lock.Lock() + defer c.lock.Unlock() + c.circuitBreaker = b + return c +} + // IsDebug method returns `true` if the client is in debug mode; otherwise, it is `false`. func (c *Client) IsDebug() bool { c.lock.RLock() @@ -2094,6 +2107,10 @@ func (c *Client) executeRequestMiddlewares(req *Request) (err error) { // Executes method executes the given `Request` object and returns // response or error. func (c *Client) execute(req *Request) (*Response, error) { + if err := c.circuitBreaker.allow(); err != nil { + return nil, err + } + if err := c.executeRequestMiddlewares(req); err != nil { return nil, err } @@ -2118,6 +2135,8 @@ func (c *Client) execute(req *Request) (*Response, error) { } } if resp != nil { + c.circuitBreaker.applyPolicies(resp) + response.Body = resp.Body if err = response.wrapContentDecompresser(); err != nil { return response, err diff --git a/client_test.go b/client_test.go index db81f4c..2bd71d8 100644 --- a/client_test.go +++ b/client_test.go @@ -1421,3 +1421,72 @@ func TestClientDebugf(t *testing.T) { assertEqual(t, "", b.String()) }) } + +var _ CircuitBreakerPolicy = CircuitBreaker5xxPolicy + +func TestClientCircuitBreaker(t *testing.T) { + ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { + t.Logf("Method: %v", r.Method) + t.Logf("Path: %v", r.URL.Path) + + switch r.URL.Path { + case "/200": + w.WriteHeader(http.StatusOK) + return + case "/500": + w.WriteHeader(http.StatusInternalServerError) + return + } + }) + defer ts.Close() + + failThreshold := uint32(2) + successThreshold := uint32(1) + timeout := 1 * time.Second + + c := dcnl().SetCircuitBreaker( + NewCircuitBreaker(). + SetTimeout(timeout). + SetFailThreshold(failThreshold). + SetSuccessThreshold(successThreshold). + SetPolicies([]CircuitBreakerPolicy{CircuitBreaker5xxPolicy})) + + for i := uint32(0); i < failThreshold; i++ { + _, err := c.R().Get(ts.URL + "/500") + assertNil(t, err) + } + resp, err := c.R().Get(ts.URL + "/500") + assertErrorIs(t, ErrCircuitBreakerOpen, err) + assertNil(t, resp) + assertEqual(t, circuitBreakerStateOpen, c.circuitBreaker.getState()) + + time.Sleep(timeout + 1*time.Millisecond) + assertEqual(t, circuitBreakerStateHalfOpen, c.circuitBreaker.getState()) + + resp, err = c.R().Get(ts.URL + "/500") + assertError(t, err) + assertEqual(t, circuitBreakerStateOpen, c.circuitBreaker.getState()) + + time.Sleep(timeout + 1*time.Millisecond) + assertEqual(t, circuitBreakerStateHalfOpen, c.circuitBreaker.getState()) + + for i := uint32(0); i < successThreshold; i++ { + _, err := c.R().Get(ts.URL + "/200") + assertNil(t, err) + } + assertEqual(t, circuitBreakerStateClosed, c.circuitBreaker.getState()) + + resp, err = c.R().Get(ts.URL + "/200") + assertNil(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) + + resp, err = c.R().Get(ts.URL + "/500") + assertError(t, err) + assertEqual(t, uint32(1), c.circuitBreaker.failCount.Load()) + + time.Sleep(timeout) + + resp, err = c.R().Get(ts.URL + "/500") + assertError(t, err) + assertEqual(t, uint32(1), c.circuitBreaker.failCount.Load()) +}