Skip to content

Commit

Permalink
CR fixes: rename state enum, use uint32
Browse files Browse the repository at this point in the history
  • Loading branch information
segevda committed Dec 31, 2024
1 parent a27ebab commit 34ae727
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
38 changes: 19 additions & 19 deletions circuit_breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
type CircuitBreaker struct {
policies []CircuitBreakerPolicy
timeout time.Duration
failThreshold, successThreshold int
failThreshold, successThreshold uint32

state atomic.Value // circuitBreakerState
failCount, successCount atomic.Uint32
Expand All @@ -38,7 +38,7 @@ func NewCircuitBreaker() *CircuitBreaker {
failThreshold: 3,
successThreshold: 1,
}
cb.state.Store(closed)
cb.state.Store(circuitBreakerStateClosed)
return cb
}

Expand All @@ -56,14 +56,14 @@ func (cb *CircuitBreaker) SetTimeout(timeout time.Duration) *CircuitBreaker {

// SetFailThreshold sets the number of failures that must occur within the timeout duration for the CircuitBreaker to
// transition to the Open state.
func (cb *CircuitBreaker) SetFailThreshold(threshold int) *CircuitBreaker {
func (cb *CircuitBreaker) SetFailThreshold(threshold uint32) *CircuitBreaker {
cb.failThreshold = threshold
return cb
}

// SetSuccessThreshold sets the number of successes that must occur to transition the CircuitBreaker from the Half-Open state
// to the Closed state.
func (cb *CircuitBreaker) SetSuccessThreshold(threshold int) *CircuitBreaker {
func (cb *CircuitBreaker) SetSuccessThreshold(threshold uint32) *CircuitBreaker {
cb.successThreshold = threshold
return cb
}
Expand All @@ -81,9 +81,9 @@ var ErrCircuitBreakerOpen = errors.New("resty: circuit breaker open")
type circuitBreakerState uint32

const (
closed circuitBreakerState = iota
open
halfOpen
circuitBreakerStateClosed circuitBreakerState = iota
circuitBreakerStateOpen
circuitBreakerStateHalfOpen
)

func (cb *CircuitBreaker) getState() circuitBreakerState {
Expand All @@ -95,7 +95,7 @@ func (cb *CircuitBreaker) allow() error {
return nil
}

if cb.getState() == open {
if cb.getState() == circuitBreakerStateOpen {
return ErrCircuitBreakerOpen
}

Expand All @@ -121,24 +121,24 @@ func (cb *CircuitBreaker) applyPolicies(resp *http.Response) {
}

switch cb.getState() {
case closed:
cb.failCount.Add(1)
if cb.failCount.Load() >= uint32(cb.failThreshold) {
case circuitBreakerStateClosed:
failCount := cb.failCount.Add(1)
if failCount >= cb.failThreshold {
cb.open()
} else {
cb.lastFail = time.Now()
}
case halfOpen:
case circuitBreakerStateHalfOpen:
cb.open()
}
} else {
switch cb.getState() {
case closed:
case circuitBreakerStateClosed:
return
case halfOpen:
cb.successCount.Add(1)
if cb.successCount.Load() >= uint32(cb.successThreshold) {
cb.changeState(closed)
case circuitBreakerStateHalfOpen:
successCount := cb.successCount.Add(1)
if successCount >= cb.successThreshold {
cb.changeState(circuitBreakerStateClosed)
}
}
}
Expand All @@ -147,10 +147,10 @@ func (cb *CircuitBreaker) applyPolicies(resp *http.Response) {
}

func (cb *CircuitBreaker) open() {
cb.changeState(open)
cb.changeState(circuitBreakerStateOpen)
go func() {
time.Sleep(cb.timeout)
cb.changeState(halfOpen)
cb.changeState(circuitBreakerStateHalfOpen)
}()
}

Expand Down
18 changes: 9 additions & 9 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1440,8 +1440,8 @@ func TestClientCircuitBreaker(t *testing.T) {
})
defer ts.Close()

failThreshold := 2
successThreshold := 1
failThreshold := uint32(2)
successThreshold := uint32(1)
timeout := 1 * time.Second

c := dcnl().SetCircuitBreaker(
Expand All @@ -1451,30 +1451,30 @@ func TestClientCircuitBreaker(t *testing.T) {
SetSuccessThreshold(successThreshold).
SetPolicies([]CircuitBreakerPolicy{CircuitBreaker5xxPolicy}))

for i := 0; i < failThreshold; i++ {
for i := uint32(0); i < failThreshold; i++ {
_, err := c.R().Get(ts.URL + "/500")
assertNil(t, err)
}
resp, err := c.R().Get(ts.URL + "/500")
assertErrorIs(t, ErrCircuitBreakerOpen, err)
assertNil(t, resp)
assertEqual(t, c.circuitBreaker.getState(), open)
assertEqual(t, c.circuitBreaker.getState(), circuitBreakerStateOpen)

time.Sleep(timeout + 1*time.Millisecond)
assertEqual(t, c.circuitBreaker.getState(), halfOpen)
assertEqual(t, c.circuitBreaker.getState(), circuitBreakerStateHalfOpen)

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, c.circuitBreaker.getState(), open)
assertEqual(t, c.circuitBreaker.getState(), circuitBreakerStateOpen)

time.Sleep(timeout + 1*time.Millisecond)
assertEqual(t, c.circuitBreaker.getState(), halfOpen)
assertEqual(t, c.circuitBreaker.getState(), circuitBreakerStateHalfOpen)

for i := 0; i < successThreshold; i++ {
for i := uint32(0); i < successThreshold; i++ {
_, err := c.R().Get(ts.URL + "/200")
assertNil(t, err)
}
assertEqual(t, c.circuitBreaker.getState(), closed)
assertEqual(t, c.circuitBreaker.getState(), circuitBreakerStateClosed)

resp, err = c.R().Get(ts.URL + "/200")
assertNil(t, err)
Expand Down

0 comments on commit 34ae727

Please sign in to comment.