Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Client Circuit Breaker #930

Merged
merged 1 commit into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions circuit_breaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package resty

import (
"errors"
"net/http"
"sync/atomic"
"time"
)

// CircuitBreaker can be in one of three states: Closed, Open, or Half-Open.
// - When the CircuitBreaker is Closed, requests are allowed to pass through.
// - If a failure count threshold is reached within a specified time-frame,
// the CircuitBreaker transitions to the Open state.
// - When the CircuitBreaker is Open, requests are blocked.
// - After a specified timeout, the CircuitBreaker transitions to the Half-Open state.
// - When the CircuitBreaker is Half-Open, a single request is allowed to pass through.
// - If that request fails, the CircuitBreaker returns to the Open state.
// - If the number of successes reaches a specified threshold,
// the CircuitBreaker transitions to the Closed state.
type CircuitBreaker struct {
policies []CircuitBreakerPolicy
timeout time.Duration
failThreshold, successThreshold uint32

state atomic.Value // circuitBreakerState
failCount, successCount atomic.Uint32
lastFail time.Time
}

// NewCircuitBreaker creates a new [CircuitBreaker] with default settings.
// The default settings are:
// - Timeout: 10 seconds
// - FailThreshold: 3
// - SuccessThreshold: 1
// - Policies: CircuitBreaker5xxPolicy
func NewCircuitBreaker() *CircuitBreaker {
cb := &CircuitBreaker{
policies: []CircuitBreakerPolicy{CircuitBreaker5xxPolicy},
timeout: 10 * time.Second,
failThreshold: 3,
successThreshold: 1,
}
cb.state.Store(circuitBreakerStateClosed)
return cb
}

// SetPolicies sets the CircuitBreakerPolicy's that the [CircuitBreaker] will use to determine whether a response is a failure.
func (cb *CircuitBreaker) SetPolicies(policies []CircuitBreakerPolicy) *CircuitBreaker {
cb.policies = policies
return cb
}

// SetTimeout sets the timeout duration for the [CircuitBreaker].
func (cb *CircuitBreaker) SetTimeout(timeout time.Duration) *CircuitBreaker {
cb.timeout = timeout
return cb
}

// SetFailThreshold sets the number of failures that must occur within the timeout duration for the [CircuitBreaker] to
// transition to the Open state.
func (cb *CircuitBreaker) SetFailThreshold(threshold uint32) *CircuitBreaker {
cb.failThreshold = threshold
return cb
}

// SetSuccessThreshold sets the number of successes that must occur to transition the [CircuitBreaker] from the Half-Open state
// to the Closed state.
func (cb *CircuitBreaker) SetSuccessThreshold(threshold uint32) *CircuitBreaker {
cb.successThreshold = threshold
return cb
}

// CircuitBreakerPolicy is a function that determines whether a response should trip the [CircuitBreaker].
type CircuitBreakerPolicy func(resp *http.Response) bool

// CircuitBreaker5xxPolicy is a [CircuitBreakerPolicy] that trips the [CircuitBreaker] if the response status code is 500 or greater.
func CircuitBreaker5xxPolicy(resp *http.Response) bool {
return resp.StatusCode > 499
}

var ErrCircuitBreakerOpen = errors.New("resty: circuit breaker open")

type circuitBreakerState uint32

const (
circuitBreakerStateClosed circuitBreakerState = iota
circuitBreakerStateOpen
circuitBreakerStateHalfOpen
)

func (cb *CircuitBreaker) getState() circuitBreakerState {
return cb.state.Load().(circuitBreakerState)
}

func (cb *CircuitBreaker) allow() error {
if cb == nil {
return nil
}

if cb.getState() == circuitBreakerStateOpen {
return ErrCircuitBreakerOpen
}

return nil
}

func (cb *CircuitBreaker) applyPolicies(resp *http.Response) {
if cb == nil {
return
}

failed := false
for _, policy := range cb.policies {
if policy(resp) {
failed = true
break
}
}

if failed {
if cb.failCount.Load() > 0 && time.Since(cb.lastFail) > cb.timeout {
cb.failCount.Store(0)
}

switch cb.getState() {
case circuitBreakerStateClosed:
failCount := cb.failCount.Add(1)
if failCount >= cb.failThreshold {
cb.open()
} else {
cb.lastFail = time.Now()
}
case circuitBreakerStateHalfOpen:
cb.open()
}
} else {
switch cb.getState() {
case circuitBreakerStateClosed:
return
case circuitBreakerStateHalfOpen:
successCount := cb.successCount.Add(1)
if successCount >= cb.successThreshold {
cb.changeState(circuitBreakerStateClosed)
}
}
}

return
}

func (cb *CircuitBreaker) open() {
cb.changeState(circuitBreakerStateOpen)
go func() {
time.Sleep(cb.timeout)
cb.changeState(circuitBreakerStateHalfOpen)
}()
}

func (cb *CircuitBreaker) changeState(state circuitBreakerState) {
cb.failCount.Store(0)
cb.successCount.Store(0)
cb.state.Store(state)
}
19 changes: 19 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ type Client struct {
contentDecompresserKeys []string
contentDecompressers map[string]ContentDecompresser
certWatcherStopChan chan bool
circuitBreaker *CircuitBreaker
}

// CertWatcherOptions allows configuring a watcher that reloads dynamically TLS certs.
Expand Down Expand Up @@ -942,6 +943,18 @@ func (c *Client) SetContentDecompresserKeys(keys []string) *Client {
return c
}

// SetCircuitBreaker method sets the Circuit Breaker instance into the client.
// It is used to prevent the client from sending requests that are likely to fail.
// For Example: To use the default Circuit Breaker:
//
// client.SetCircuitBreaker(NewCircuitBreaker())
func (c *Client) SetCircuitBreaker(b *CircuitBreaker) *Client {
segevda marked this conversation as resolved.
Show resolved Hide resolved
c.lock.Lock()
defer c.lock.Unlock()
c.circuitBreaker = b
return c
}

// IsDebug method returns `true` if the client is in debug mode; otherwise, it is `false`.
func (c *Client) IsDebug() bool {
c.lock.RLock()
Expand Down Expand Up @@ -2094,6 +2107,10 @@ func (c *Client) executeRequestMiddlewares(req *Request) (err error) {
// Executes method executes the given `Request` object and returns
// response or error.
func (c *Client) execute(req *Request) (*Response, error) {
segevda marked this conversation as resolved.
Show resolved Hide resolved
if err := c.circuitBreaker.allow(); err != nil {
jeevatkm marked this conversation as resolved.
Show resolved Hide resolved
return nil, err
}

if err := c.executeRequestMiddlewares(req); err != nil {
return nil, err
}
Expand All @@ -2118,6 +2135,8 @@ func (c *Client) execute(req *Request) (*Response, error) {
}
}
if resp != nil {
c.circuitBreaker.applyPolicies(resp)

response.Body = resp.Body
if err = response.wrapContentDecompresser(); err != nil {
return response, err
Expand Down
69 changes: 69 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1421,3 +1421,72 @@ func TestClientDebugf(t *testing.T) {
assertEqual(t, "", b.String())
})
}

var _ CircuitBreakerPolicy = CircuitBreaker5xxPolicy

func TestClientCircuitBreaker(t *testing.T) {
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
t.Logf("Method: %v", r.Method)
t.Logf("Path: %v", r.URL.Path)

switch r.URL.Path {
case "/200":
w.WriteHeader(http.StatusOK)
return
case "/500":
w.WriteHeader(http.StatusInternalServerError)
return
}
})
defer ts.Close()

failThreshold := uint32(2)
successThreshold := uint32(1)
timeout := 1 * time.Second

c := dcnl().SetCircuitBreaker(
NewCircuitBreaker().
SetTimeout(timeout).
SetFailThreshold(failThreshold).
SetSuccessThreshold(successThreshold).
SetPolicies([]CircuitBreakerPolicy{CircuitBreaker5xxPolicy}))

for i := uint32(0); i < failThreshold; i++ {
_, err := c.R().Get(ts.URL + "/500")
assertNil(t, err)
}
resp, err := c.R().Get(ts.URL + "/500")
assertErrorIs(t, ErrCircuitBreakerOpen, err)
assertNil(t, resp)
assertEqual(t, circuitBreakerStateOpen, c.circuitBreaker.getState())

time.Sleep(timeout + 1*time.Millisecond)
assertEqual(t, circuitBreakerStateHalfOpen, c.circuitBreaker.getState())

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, circuitBreakerStateOpen, c.circuitBreaker.getState())

time.Sleep(timeout + 1*time.Millisecond)
assertEqual(t, circuitBreakerStateHalfOpen, c.circuitBreaker.getState())

for i := uint32(0); i < successThreshold; i++ {
_, err := c.R().Get(ts.URL + "/200")
assertNil(t, err)
}
assertEqual(t, circuitBreakerStateClosed, c.circuitBreaker.getState())

resp, err = c.R().Get(ts.URL + "/200")
assertNil(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, uint32(1), c.circuitBreaker.failCount.Load())

time.Sleep(timeout)

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, uint32(1), c.circuitBreaker.failCount.Load())
}
Loading