diff --git a/README.md b/README.md index 87f67bc..f39cb71 100644 --- a/README.md +++ b/README.md @@ -7,15 +7,9 @@ Builds off of the wonderful work of https://github.com/sethvargo/go-retry but ad TODO: - update godoc to my version - update this documentation with changes -- put unit tests in same package? - update Int63n to not panic? -- figure out if we can delete stuff (eg Constant() and the TODOs around that sort of function) -- potentially update the import to be like go-backoff -- split up into better package isolation -- setup github actions to run tests? -- tune up repeat.go to be useful -- add tests around stuff in repeat.go -- perhaps have repeat and retry be separate packages so each can have a Do method? +- add tests to rand.go +- make rand its own package? ## Added features diff --git a/backoff.go b/backoff/backoff.go similarity index 81% rename from backoff.go rename to backoff/backoff.go index 4698e6b..f2690ca 100644 --- a/backoff.go +++ b/backoff/backoff.go @@ -1,21 +1,15 @@ -package retry +package backoff import ( "sync" "time" -) -// Backoff is an interface that backs off. -type Backoff interface { - // Next returns the time duration to wait and whether to stop. - Next() (next time.Duration, stop bool) - // Reset sets the undecorated backoff back to its initial parameters - Reset() -} + "github.com/swayne275/go-retry/common/backoff" +) // TODO clean up interface, struct, etc -var _ Backoff = (BackoffFunc)(nil) +var _ backoff.Backoff = (BackoffFunc)(nil) // BackoffFunc is a backoff expressed as a function. type BackoffFunc func() (time.Duration, bool) @@ -28,7 +22,7 @@ func (b BackoffFunc) Next() (time.Duration, bool) { func (b BackoffFunc) Reset() {} type ResettableBackoff struct { - Backoff + backoff.Backoff // reset returns the backoff to its initial state. reset func() } @@ -41,7 +35,7 @@ func (b *ResettableBackoff) Reset() { b.reset() } -func WithReset(reset func() Backoff, next Backoff) *ResettableBackoff { +func WithReset(reset func() backoff.Backoff, next backoff.Backoff) *ResettableBackoff { rb := &ResettableBackoff{ Backoff: next, } @@ -56,7 +50,7 @@ func WithReset(reset func() Backoff, next Backoff) *ResettableBackoff { // interpreted as "+/- j". For example, if j were 5 seconds and the backoff // returned 20s, the value could be between 15 and 25 seconds. The value can // never be less than 0. -func WithJitter(j time.Duration, next Backoff) *ResettableBackoff { +func WithJitter(j time.Duration, next backoff.Backoff) *ResettableBackoff { r := newLockedRandom(time.Now().UnixNano()) nextWithJitter := BackoffFunc(func() (time.Duration, bool) { @@ -73,7 +67,7 @@ func WithJitter(j time.Duration, next Backoff) *ResettableBackoff { return val, false }) - reset := func() Backoff { + reset := func() backoff.Backoff { next.Reset() return nextWithJitter } @@ -85,7 +79,7 @@ func WithJitter(j time.Duration, next Backoff) *ResettableBackoff { // percentage. j can be interpreted as "+/- j%". For example, if j were 5 and // the backoff returned 20s, the value could be between 19 and 21 seconds. The // value can never be less than 0 or greater than 100. -func WithJitterPercent(j uint64, next Backoff) *ResettableBackoff { +func WithJitterPercent(j uint64, next backoff.Backoff) *ResettableBackoff { r := newLockedRandom(time.Now().UnixNano()) nextWithJitterPercent := BackoffFunc(func() (time.Duration, bool) { @@ -105,7 +99,7 @@ func WithJitterPercent(j uint64, next Backoff) *ResettableBackoff { return val, false }) - reset := func() Backoff { + reset := func() backoff.Backoff { next.Reset() return nextWithJitterPercent } @@ -114,7 +108,7 @@ func WithJitterPercent(j uint64, next Backoff) *ResettableBackoff { } // WithMaxRetries executes the backoff function up until the maximum attempts. -func WithMaxRetries(max uint64, next Backoff) *ResettableBackoff { +func WithMaxRetries(max uint64, next backoff.Backoff) *ResettableBackoff { var l sync.Mutex var attempt uint64 @@ -135,7 +129,7 @@ func WithMaxRetries(max uint64, next Backoff) *ResettableBackoff { return val, false }) - reset := func() Backoff { + reset := func() backoff.Backoff { l.Lock() defer l.Unlock() attempt = 0 @@ -151,7 +145,7 @@ func WithMaxRetries(max uint64, next Backoff) *ResettableBackoff { // backoff. This is NOT a total backoff time, but rather a cap on the maximum // value a backoff can return. Without another middleware, the backoff will // continue infinitely. -func WithCappedDuration(cap time.Duration, next Backoff) *ResettableBackoff { +func WithCappedDuration(cap time.Duration, next backoff.Backoff) *ResettableBackoff { nextWithCappedDuration := BackoffFunc(func() (time.Duration, bool) { val, stop := next.Next() if stop { @@ -164,7 +158,7 @@ func WithCappedDuration(cap time.Duration, next Backoff) *ResettableBackoff { return val, false }) - reset := func() Backoff { + reset := func() backoff.Backoff { next.Reset() return nextWithCappedDuration } @@ -175,7 +169,7 @@ func WithCappedDuration(cap time.Duration, next Backoff) *ResettableBackoff { // WithMaxDuration sets a maximum on the total amount of time a backoff should // execute. It's best-effort, and should not be used to guarantee an exact // amount of time. -func WithMaxDuration(timeout time.Duration, next Backoff) *ResettableBackoff { +func WithMaxDuration(timeout time.Duration, next backoff.Backoff) *ResettableBackoff { var l sync.RWMutex start := time.Now() @@ -199,7 +193,7 @@ func WithMaxDuration(timeout time.Duration, next Backoff) *ResettableBackoff { return val, false }) - reset := func() Backoff { + reset := func() backoff.Backoff { l.Lock() defer l.Unlock() start = time.Now() diff --git a/backoff/backoff_constant.go b/backoff/backoff_constant.go new file mode 100644 index 0000000..4958562 --- /dev/null +++ b/backoff/backoff_constant.go @@ -0,0 +1,20 @@ +package backoff + +import ( + "fmt" + "time" + + "github.com/swayne275/go-retry/common/backoff" +) + +// NewConstant creates a new constant backoff using the value t. The wait time +// is the provided constant value. It returns an error if t is not greater than 0. +func NewConstant(t time.Duration) (backoff.Backoff, error) { + if t <= 0 { + return nil, fmt.Errorf("constant backoff must be greater than zero") + } + + return BackoffFunc(func() (time.Duration, bool) { + return t, false + }), nil +} diff --git a/backoff_constant_test.go b/backoff/backoff_constant_test.go similarity index 88% rename from backoff_constant_test.go rename to backoff/backoff_constant_test.go index bc719aa..1cfc936 100644 --- a/backoff_constant_test.go +++ b/backoff/backoff_constant_test.go @@ -1,4 +1,4 @@ -package retry_test +package backoff import ( "reflect" @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/swayne275/go-retry" + cb "github.com/swayne275/go-retry/common/backoff" ) func TestConstantBackoff(t *testing.T) { @@ -75,7 +75,7 @@ func TestConstantBackoff(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - b, err := retry.NewConstant(tc.base) + b, err := NewConstant(tc.base) if tc.expectErr && err == nil { t.Fatal("expected an error") } @@ -113,12 +113,12 @@ func TestConstantBackoff(t *testing.T) { func TestConstantBackoff_WithReset(t *testing.T) { expectedDuration := 3 * time.Second - b, err := retry.NewConstant(expectedDuration) + b, err := NewConstant(expectedDuration) if err != nil { t.Fatalf("failed to create constant backoff: %v", err) } - resettableB := retry.WithReset(func() retry.Backoff { + resettableB := WithReset(func() cb.Backoff { return b }, b) resettableB.Reset() @@ -133,12 +133,12 @@ func TestConstantBackoff_WithCappedDuration_WithReset(t *testing.T) { expectedDuration := 3 * time.Second cappedDuration := 2 * time.Second - b, err := retry.NewConstant(expectedDuration) + b, err := NewConstant(expectedDuration) if err != nil { t.Fatalf("failed to create constant backoff: %v", err) } - resettableB := retry.WithCappedDuration(cappedDuration, b) + resettableB := WithCappedDuration(cappedDuration, b) val, _ := resettableB.Next() if val != cappedDuration { @@ -156,12 +156,12 @@ func TestConstantBackoff_ExplicitReset(t *testing.T) { expectedDuration := 3 * time.Second cappedDuration := 2 * time.Second - b, err := retry.NewConstant(expectedDuration) + b, err := NewConstant(expectedDuration) if err != nil { t.Fatalf("failed to create constant backoff: %v", err) } - resettableB := retry.WithCappedDuration(cappedDuration, b) + resettableB := WithCappedDuration(cappedDuration, b) val, _ := resettableB.Next() if val != cappedDuration { @@ -171,8 +171,8 @@ func TestConstantBackoff_ExplicitReset(t *testing.T) { // now we're going to explicitly pass in a reset function that DOES NOT observe the cap, // and we expect the reset to no longer have the cap - explicitylyResettableB := retry.WithReset(func() retry.Backoff { - b, err := retry.NewConstant(expectedDuration) + explicitylyResettableB := WithReset(func() cb.Backoff { + b, err := NewConstant(expectedDuration) if err != nil { t.Fatalf("failed to create constant backoff: %v", err) } diff --git a/backoff_exponential.go b/backoff/backoff_exponential.go similarity index 67% rename from backoff_exponential.go rename to backoff/backoff_exponential.go index e88845e..c143038 100644 --- a/backoff_exponential.go +++ b/backoff/backoff_exponential.go @@ -1,11 +1,12 @@ -package retry +package backoff import ( - "context" "fmt" "math" "sync/atomic" "time" + + "github.com/swayne275/go-retry/common/backoff" ) type exponentialBackoff struct { @@ -13,18 +14,6 @@ type exponentialBackoff struct { attempt uint64 } -// Exponential is a wrapper around Retry that uses an exponential backoff. See -// NewExponential. -// TODO is this useful or fine as an example? -func Exponential(ctx context.Context, base time.Duration, f RetryFunc) error { - b, err := NewExponential(base) - if err != nil { - return fmt.Errorf("failed to create exponential backoff: %w", err) - } - - return Do(ctx, b, f) -} - // NewExponential creates a new exponential backoff using the starting value of // base and doubling on each failure (1, 2, 4, 8, 16, 32, 64...), up to max. // @@ -32,7 +21,7 @@ func Exponential(ctx context.Context, base time.Duration, f RetryFunc) error { // for a 64-bit integer. // // It returns an error if the given base is less than zero. -func NewExponential(base time.Duration) (Backoff, error) { +func NewExponential(base time.Duration) (backoff.Backoff, error) { if base <= 0 { return nil, fmt.Errorf("base must be greater than 0") } diff --git a/backoff_exponential_test.go b/backoff/backoff_exponential_test.go similarity index 91% rename from backoff_exponential_test.go rename to backoff/backoff_exponential_test.go index 75d663f..dbaa634 100644 --- a/backoff_exponential_test.go +++ b/backoff/backoff_exponential_test.go @@ -1,4 +1,4 @@ -package retry_test +package backoff import ( "math" @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/swayne275/go-retry" + cb "github.com/swayne275/go-retry/common/backoff" ) func TestExponentialBackoff(t *testing.T) { @@ -81,7 +81,7 @@ func TestExponentialBackoff(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - b, err := retry.NewExponential(tc.base) + b, err := NewExponential(tc.base) if tc.expectErr && err == nil { t.Fatal("expected an error") } @@ -126,13 +126,13 @@ func TestExponentialBackoff_WithReset(t *testing.T) { 8 * time.Second, } - b, err := retry.NewExponential(base) + b, err := NewExponential(base) if err != nil { t.Fatalf("failed to create exponential backoff: %v", err) } - resettableB := retry.WithReset(func() retry.Backoff { - newB, err := retry.NewExponential(base) + resettableB := WithReset(func() cb.Backoff { + newB, err := NewExponential(base) if err != nil { t.Fatalf("failed to reset exponential backoff: %v", err) } @@ -168,12 +168,12 @@ func TestExponentialBackoff_WithCappedDuration_WithReset(t *testing.T) { 4 * time.Second, } - b, err := retry.NewExponential(base) + b, err := NewExponential(base) if err != nil { t.Fatalf("failed to create exponential backoff: %v", err) } - cappedB := retry.WithCappedDuration(cappedDuration, b) + cappedB := WithCappedDuration(cappedDuration, b) // test pre reset for i := 0; i < numRounds; i++ { @@ -200,9 +200,9 @@ func TestExponentialBackoff_WithCappedDuration_WithReset(t *testing.T) { 4 * time.Second, 8 * time.Second, } - resettableB := retry.WithReset(func() retry.Backoff { + resettableB := WithReset(func() cb.Backoff { // don't set a cap on the explicit reset - newB, err := retry.NewExponential(base) + newB, err := NewExponential(base) if err != nil { t.Fatalf("failed to reset exponential backoff: %v", err) } diff --git a/backoff_fibonacci.go b/backoff/backoff_fibonacci.go similarity index 73% rename from backoff_fibonacci.go rename to backoff/backoff_fibonacci.go index f3e988f..33d43fc 100644 --- a/backoff_fibonacci.go +++ b/backoff/backoff_fibonacci.go @@ -1,12 +1,13 @@ -package retry +package backoff import ( - "context" "fmt" "math" "sync/atomic" "time" "unsafe" + + "github.com/swayne275/go-retry/common/backoff" ) type state [2]time.Duration @@ -16,18 +17,6 @@ type fibonacciBackoff struct { base time.Duration } -// Fibonacci is a wrapper around Retry that uses a Fibonacci backoff. See -// NewFibonacci. -// TODO is this useful or should we move to example? -func Fibonacci(ctx context.Context, base time.Duration, f RetryFunc) error { - b, err := NewFibonacci(base) - if err != nil { - return fmt.Errorf("failed to create fibonacci backoff: %w", err) - - } - return Do(ctx, b, f) -} - // NewFibonacci creates a new Fibonacci backoff that follows the fibonacci sequence // multipled by base. The wait time is the sum of the previous two wait times on each // previous attempt base * (1, 2, 3, 5, 8, 13...). @@ -36,7 +25,7 @@ func Fibonacci(ctx context.Context, base time.Duration, f RetryFunc) error { // for a 64-bit integer. // // It returns an error if the given base is less than zero. -func NewFibonacci(base time.Duration) (Backoff, error) { +func NewFibonacci(base time.Duration) (backoff.Backoff, error) { if base <= 0 { return nil, fmt.Errorf("base must be greater than 0") } diff --git a/backoff_fibonacci_test.go b/backoff/backoff_fibonacci_test.go similarity index 91% rename from backoff_fibonacci_test.go rename to backoff/backoff_fibonacci_test.go index 2f87831..a05c460 100644 --- a/backoff_fibonacci_test.go +++ b/backoff/backoff_fibonacci_test.go @@ -1,4 +1,4 @@ -package retry_test +package backoff import ( "math" @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/swayne275/go-retry" + cb "github.com/swayne275/go-retry/common/backoff" ) func TestFibonacciBackoff(t *testing.T) { @@ -93,7 +93,7 @@ func TestFibonacciBackoff(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - b, err := retry.NewFibonacci(tc.base) + b, err := NewFibonacci(tc.base) if tc.expectErr && err == nil { t.Fatal("expected an error") } @@ -140,13 +140,13 @@ func TestFibonacciBackoff_WithReset(t *testing.T) { 8 * time.Second, } - b, err := retry.NewFibonacci(base) + b, err := NewFibonacci(base) if err != nil { t.Fatalf("failed to create fibonacci backoff: %v", err) } - resettableB := retry.WithReset(func() retry.Backoff { - newB, err := retry.NewFibonacci(base) + resettableB := WithReset(func() cb.Backoff { + newB, err := NewFibonacci(base) if err != nil { t.Fatalf("failed to reset fibonacci backoff: %v", err) } @@ -183,7 +183,7 @@ func TestFibonacciBackoff_WithReset_ChangeBase(t *testing.T) { 8 * time.Second, } - b, err := retry.NewFibonacci(base) + b, err := NewFibonacci(base) if err != nil { t.Fatalf("failed to create fibonacci backoff: %v", err) } @@ -196,8 +196,8 @@ func TestFibonacciBackoff_WithReset_ChangeBase(t *testing.T) { 10 * time.Second, 16 * time.Second, } - resettableB := retry.WithReset(func() retry.Backoff { - newB, err := retry.NewFibonacci(newBase) + resettableB := WithReset(func() cb.Backoff { + newB, err := NewFibonacci(newBase) if err != nil { t.Fatalf("failed to reset fibonacci backoff: %v", err) } @@ -235,12 +235,12 @@ func TestFibonacciBackoff_WithCappedDuration_WithReset(t *testing.T) { 5 * time.Second, } - b, err := retry.NewFibonacci(base) + b, err := NewFibonacci(base) if err != nil { t.Fatalf("failed to create fibonacci backoff: %v", err) } - cappedB := retry.WithCappedDuration(cappedDuration, b) + cappedB := WithCappedDuration(cappedDuration, b) // test pre reset for i := 0; i < numRounds; i++ { @@ -270,9 +270,9 @@ func TestFibonacciBackoff_WithCappedDuration_WithReset(t *testing.T) { 8 * time.Second, } - resettableB := retry.WithReset(func() retry.Backoff { + resettableB := WithReset(func() cb.Backoff { // don't set a cap on the explicit reset - newB, err := retry.NewFibonacci(base) + newB, err := NewFibonacci(base) if err != nil { t.Fatalf("failed to reset fibonacci backoff: %v", err) } diff --git a/backoff_test.go b/backoff/backoff_test.go similarity index 83% rename from backoff_test.go rename to backoff/backoff_test.go index 808819a..d8e7102 100644 --- a/backoff_test.go +++ b/backoff/backoff_test.go @@ -1,8 +1,10 @@ -package retry +package backoff import ( "testing" "time" + + "github.com/swayne275/go-retry/common/backoff" ) func TestWithJitter(t *testing.T) { @@ -129,7 +131,7 @@ func TestWithMaxDuration(t *testing.T) { func TestResettableBackoff(t *testing.T) { var attempt uint64 - b := WithReset(func() Backoff { + b := WithReset(func() backoff.Backoff { attempt = 0 return BackoffFunc(func() (time.Duration, bool) { @@ -303,6 +305,83 @@ func TestResettableBackoff_WithCappedDuration(t *testing.T) { } } +func TestResettableBackoff_WithMaxDuration(t *testing.T) { + t.Parallel() + + baseDuration := 1 * time.Second + maxDuration := 250 * time.Millisecond + b := WithMaxDuration(maxDuration, BackoffFunc(func() (time.Duration, bool) { + return baseDuration, false + })) + + validateMaxDuration(t, b, maxDuration) + + // a reset should clear it, and we do the process again + b.reset() + + validateMaxDuration(t, b, maxDuration) +} + +// TestResettableBackoff_MultipleDecorators ensures that multiple decorators can be applied to a ResettableBackoff +// and that the decorators are still observed after a reset. +func TestResettableBackoff_MultipleDecorators(t *testing.T) { + base := 1 * time.Second + cappedDuration := 5 * time.Second + maxRetries := uint64(7) + expected := []time.Duration{ + 1 * time.Second, + 2 * time.Second, + 3 * time.Second, + 5 * time.Second, + 5 * time.Second, + 5 * time.Second, + 5 * time.Second, + } + + b, err := NewFibonacci(base) + if err != nil { + t.Fatalf("failed to create fibonacci backoff: %v", err) + } + + cappedB := WithCappedDuration(cappedDuration, b) + maxRetriesB := WithMaxRetries(maxRetries, cappedB) + + for _, tc := range expected { + val, stop := maxRetriesB.Next() + if stop { + t.Errorf("pre reset should not stop") + } + if val != tc { + t.Errorf("pre reset expected %v to be %v", val, tc) + } + } + + // we expect it to stop after the max number of retries + _, stop := maxRetriesB.Next() + if !stop { + t.Errorf("pre reset should stop") + } + + // reset it and verify that we repeat the above + maxRetriesB.Reset() + + for _, tc := range expected { + val, stop := maxRetriesB.Next() + if stop { + t.Errorf("post reset should not stop") + } + if val != tc { + t.Errorf("post reset expected %v to be %v", val, tc) + } + } + + // we again expect it to stop after the max number of retries + _, stop = maxRetriesB.Next() + if !stop { + t.Errorf("post reset should stop") + } +} + func validateMaxDuration(t *testing.T, b *ResettableBackoff, maxDuration time.Duration) { t.Helper() @@ -343,20 +422,3 @@ func validateMaxDuration(t *testing.T, b *ResettableBackoff, maxDuration time.Du t.Errorf("expected %v to be %v", val, 0) } } - -func TestResettableBackoff_WithMaxDuration(t *testing.T) { - t.Parallel() - - baseDuration := 1 * time.Second - maxDuration := 250 * time.Millisecond - b := WithMaxDuration(maxDuration, BackoffFunc(func() (time.Duration, bool) { - return baseDuration, false - })) - - validateMaxDuration(t, b, maxDuration) - - // a reset should clear it, and we do the process again - b.reset() - - validateMaxDuration(t, b, maxDuration) -} diff --git a/rand.go b/backoff/rand.go similarity index 98% rename from rand.go rename to backoff/rand.go index 4799fb0..0f93932 100644 --- a/rand.go +++ b/backoff/rand.go @@ -1,4 +1,4 @@ -package retry +package backoff import ( "math/rand" diff --git a/backoff_constant.go b/backoff_constant.go deleted file mode 100644 index 3d94bda..0000000 --- a/backoff_constant.go +++ /dev/null @@ -1,31 +0,0 @@ -package retry - -import ( - "context" - "fmt" - "time" -) - -// Constant is a wrapper around Retry that uses a constant backoff. It will -// retry the function f until it returns an error, or the context is canceled. -// TODO is this really useful vs an example? would have to extend with a repeat version too. -func Constant(ctx context.Context, t time.Duration, f RetryFunc) error { - b, err := NewConstant(t) - if err != nil { - return fmt.Errorf("failed to create constant backoff: %w", err) - } - - return Do(ctx, b, f) -} - -// NewConstant creates a new constant backoff using the value t. The wait time -// is the provided constant value. It returns an error if t is not greater than 0. -func NewConstant(t time.Duration) (Backoff, error) { - if t <= 0 { - return nil, fmt.Errorf("constant backoff must be greater than zero") - } - - return BackoffFunc(func() (time.Duration, bool) { - return t, false - }), nil -} diff --git a/common/backoff/backoff.go b/common/backoff/backoff.go new file mode 100644 index 0000000..18ad969 --- /dev/null +++ b/common/backoff/backoff.go @@ -0,0 +1,11 @@ +package backoff + +import "time" + +// Backoff is an interface that backs off. +type Backoff interface { + // Next returns the time duration to wait and whether to stop. + Next() (next time.Duration, stop bool) + // Reset sets the undecorated backoff back to its initial parameters + Reset() +} diff --git a/internal/example/example_test.go b/internal/example/example.go similarity index 55% rename from internal/example/example_test.go rename to internal/example/example.go index 6085b93..0a5ff8b 100644 --- a/internal/example/example_test.go +++ b/internal/example/example.go @@ -3,16 +3,19 @@ package example import ( "context" "fmt" + "net/http" "time" - "github.com/swayne275/go-retry" + "github.com/swayne275/go-retry/backoff" + cb "github.com/swayne275/go-retry/common/backoff" + "github.com/swayne275/go-retry/retry" ) func ExampleBackoffFunc() { ctx := context.Background() // Example backoff middleware that adds the provided duration t to the result. - withShift := func(t time.Duration, next retry.Backoff) retry.BackoffFunc { + withShift := func(t time.Duration, next cb.Backoff) backoff.BackoffFunc { return func() (time.Duration, bool) { val, stop := next.Next() if stop { @@ -23,7 +26,7 @@ func ExampleBackoffFunc() { } // Middlewrap wrap another backoff: - b, err := retry.NewFibonacci(1 * time.Second) + b, err := backoff.NewFibonacci(1 * time.Second) if err != nil { // handle the error here, likely from bad input } @@ -40,11 +43,11 @@ func ExampleBackoffFunc() { func ExampleWithJitter() { ctx := context.Background() - b, err := retry.NewFibonacci(1 * time.Second) + b, err := backoff.NewFibonacci(1 * time.Second) if err != nil { // handle the error here, likely from bad input } - b = retry.WithJitter(1*time.Second, b) + b = backoff.WithJitter(1*time.Second, b) if err := retry.Do(ctx, b, func(_ context.Context) error { // your retry logic here @@ -57,11 +60,11 @@ func ExampleWithJitter() { func ExampleWithJitterPercent() { ctx := context.Background() - b, err := retry.NewFibonacci(1 * time.Second) + b, err := backoff.NewFibonacci(1 * time.Second) if err != nil { // handle err } - b = retry.WithJitterPercent(5, b) + b = backoff.WithJitterPercent(5, b) if err := retry.Do(ctx, b, func(_ context.Context) error { // your retry logic here @@ -74,11 +77,11 @@ func ExampleWithJitterPercent() { func ExampleWithMaxRetries() { ctx := context.Background() - b, err := retry.NewFibonacci(1 * time.Second) + b, err := backoff.NewFibonacci(1 * time.Second) if err != nil { // handle err } - b = retry.WithMaxRetries(3, b) + b = backoff.WithMaxRetries(3, b) if err := retry.Do(ctx, b, func(_ context.Context) error { // your retry logic here @@ -91,11 +94,11 @@ func ExampleWithMaxRetries() { func ExampleWithCappedDuration() { ctx := context.Background() - b, err := retry.NewFibonacci(1 * time.Second) + b, err := backoff.NewFibonacci(1 * time.Second) if err != nil { // handle err } - b = retry.WithCappedDuration(3*time.Second, b) + b = backoff.WithCappedDuration(3*time.Second, b) if err := retry.Do(ctx, b, func(_ context.Context) error { // your retry logic here @@ -108,11 +111,11 @@ func ExampleWithCappedDuration() { func ExampleWithMaxDuration() { ctx := context.Background() - b, err := retry.NewFibonacci(1 * time.Second) + b, err := backoff.NewFibonacci(1 * time.Second) if err != nil { // handle err } - b = retry.WithMaxDuration(5*time.Second, b) + b = backoff.WithMaxDuration(5*time.Second, b) if err := retry.Do(ctx, b, func(_ context.Context) error { // your retry logic here @@ -123,7 +126,7 @@ func ExampleWithMaxDuration() { } func ExampleNewConstant() { - b, err := retry.NewConstant(1 * time.Second) + b, err := backoff.NewConstant(1 * time.Second) if err != nil { // handle the error here, likely from bad input return @@ -142,7 +145,7 @@ func ExampleNewConstant() { } func ExampleNewExponential() { - b, err := retry.NewExponential(1 * time.Second) + b, err := backoff.NewExponential(1 * time.Second) if err != nil { // handle the error here, likely from bad input return @@ -161,7 +164,7 @@ func ExampleNewExponential() { } func ExampleNewFibonacci() { - b, err := retry.NewFibonacci(1 * time.Second) + b, err := backoff.NewFibonacci(1 * time.Second) if err != nil { // handle err } @@ -177,3 +180,57 @@ func ExampleNewFibonacci() { // 5s // 8s } + +func ExampleDo_simple() { + ctx := context.Background() + + b, err := backoff.NewFibonacci(1 * time.Nanosecond) + if err != nil { + // handle error + } + + i := 0 + if err := retry.Do(ctx, backoff.WithMaxRetries(3, b), func(ctx context.Context) error { + fmt.Printf("%d\n", i) + i++ + return retry.RetryableError(fmt.Errorf("oops")) + }); err != nil { + // handle error + } + + // Output: + // 0 + // 1 + // 2 + // 3 +} + +func ExampleDo_customRetry() { + ctx := context.Background() + + b, err := backoff.NewFibonacci(1 * time.Nanosecond) + if err != nil { + // handle error + } + + // This example demonstrates selectively retrying specific errors. Only errors + // wrapped with RetryableError are eligible to be retried. + if err := retry.Do(ctx, backoff.WithMaxRetries(3, b), func(ctx context.Context) error { + resp, err := http.Get("https://google.com/") + if err != nil { + return err + } + defer resp.Body.Close() + + switch resp.StatusCode / 100 { + case 4: + return fmt.Errorf("bad response: %v", resp.StatusCode) + case 5: + return retry.RetryableError(fmt.Errorf("bad response: %v", resp.StatusCode)) + default: + return nil + } + }); err != nil { + // handle error + } +} diff --git a/repeat.go b/repeat.go deleted file mode 100644 index 4eba5c0..0000000 --- a/repeat.go +++ /dev/null @@ -1,51 +0,0 @@ -package retry - -import ( - "context" - "time" -) - -// RepeatFunc is a function passed to retry. -type RepeatFunc func(ctx context.Context) error - -// Repeat wraps a function with a backoff to repeat until it returns an error, or the backoff -// signals to stop. -// The provided context is passed to the RepeatFunc. -func Repeat(ctx context.Context, b Backoff, f RepeatFunc) error { - for { - // Return immediately if ctx is canceled - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - if err := f(ctx); err != nil { - return err - } - - next, stop := b.Next() - if stop { - return nil - } - - // ctx.Done() has priority, so we test it alone first - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - t := time.NewTimer(next) - select { - case <-ctx.Done(): - t.Stop() - return ctx.Err() - case <-t.C: - continue - } - } -} - -// TODO make the above like repeat.DoUntilError and then have a repeat.Do that takes an -// error handling function and keeps going diff --git a/repeat/repeat.go b/repeat/repeat.go new file mode 100644 index 0000000..83c83c3 --- /dev/null +++ b/repeat/repeat.go @@ -0,0 +1,98 @@ +package repeat + +import ( + "context" + "fmt" + "time" + + "github.com/swayne275/go-retry/common/backoff" +) + +var errFunctionSignaledToStop = fmt.Errorf("function signaled to stop") +var errBackoffSignaledToStop = fmt.Errorf("backoff signaled to stop") + +// RepeatFunc is a function passed to retry. +// It returns true if the function should be repeated, false otherwise. +type RepeatFunc func(ctx context.Context) bool + +// Do wraps a function with a backoff to repeat as long as f returns true, or until +// the backoff signals to stop. +// The provided context is passed to the RepeatFunc. +func Do(ctx context.Context, b backoff.Backoff, f RepeatFunc) error { + for { + // Return immediately if ctx is canceled + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if !f(ctx) { + return errFunctionSignaledToStop + } + + next, stop := b.Next() + if stop { + return errBackoffSignaledToStop + } + + // ctx.Done() has priority, so we test it alone first + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + t := time.NewTimer(next) + select { + case <-ctx.Done(): + t.Stop() + return ctx.Err() + case <-t.C: + continue + } + } +} + +// RepeatFunc is a function passed to retry. +// It returns true if the function should be repeated, false otherwise. +type RepeatUntilErrorFunc func(ctx context.Context) error + +// DoUntilError wraps a function with a backoff to repeat until f returns an error, or +// until the backoff signals to stop. +// The provided context is passed to the RepeatFunc. +func DoUntilError(ctx context.Context, b backoff.Backoff, f RepeatUntilErrorFunc) error { + for { + // Return immediately if ctx is canceled + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if err := f(ctx); err != nil { + return fmt.Errorf("%w: %w", errFunctionSignaledToStop, err) + } + + next, stop := b.Next() + if stop { + return errBackoffSignaledToStop + } + + // ctx.Done() has priority, so we test it alone first + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + t := time.NewTimer(next) + select { + case <-ctx.Done(): + t.Stop() + return ctx.Err() + case <-t.C: + continue + } + } +} diff --git a/repeat/repeat_test.go b/repeat/repeat_test.go new file mode 100644 index 0000000..f9fef8d --- /dev/null +++ b/repeat/repeat_test.go @@ -0,0 +1,136 @@ +package repeat + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/swayne275/go-retry/backoff" +) + +func TestDo(t *testing.T) { + t.Parallel() + + t.Run("exit_on_context_cancelled", func(t *testing.T) { + t.Parallel() + + b, err := backoff.NewConstant(1 * time.Nanosecond) + if err != nil { + t.Fatalf("failed to create constant backoff: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + retryFunc := func(_ context.Context) bool { return true } + + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + if err = Do(ctx, b, retryFunc); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RepeatFunc_false", func(t *testing.T) { + t.Parallel() + + b, err := backoff.NewConstant(1 * time.Nanosecond) + if err != nil { + t.Fatalf("failed to create constant backoff: %v", err) + } + + cnt := 0 + maxCnt := 3 + retryFunc := func(_ context.Context) bool { + cnt++ + return cnt <= maxCnt + } + + if err = Do(context.Background(), b, retryFunc); err != errFunctionSignaledToStop { + t.Errorf("expected %q to be %q", err, errFunctionSignaledToStop) + } + if cnt != maxCnt+1 { + t.Errorf("expected %d to be %d", cnt, maxCnt+1) + } + }) + + t.Run("exit_on_backoff_stop", func(t *testing.T) { + t.Parallel() + + b := backoff.WithMaxRetries(3, backoff.BackoffFunc(func() (time.Duration, bool) { + return 1 * time.Nanosecond, false + })) + + retryFunc := func(_ context.Context) bool { return true } + + if err := Do(context.Background(), b, retryFunc); err != errBackoffSignaledToStop { + t.Errorf("expected %q to be %q", err, errBackoffSignaledToStop) + } + }) +} + +func TestDoUntilError(t *testing.T) { + t.Parallel() + + t.Run("exit_on_context_cancelled", func(t *testing.T) { + t.Parallel() + + b, err := backoff.NewConstant(1 * time.Nanosecond) + if err != nil { + t.Fatalf("failed to create constant backoff: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + retryFunc := func(_ context.Context) error { return nil } + + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + if err = DoUntilError(ctx, b, retryFunc); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RepeatFunc_error", func(t *testing.T) { + t.Parallel() + + b, err := backoff.NewConstant(1 * time.Nanosecond) + if err != nil { + t.Fatalf("failed to create constant backoff: %v", err) + } + + cnt := 0 + maxCnt := 3 + retryFunc := func(_ context.Context) error { + cnt++ + if cnt > maxCnt { + return fmt.Errorf("function error") + } + return nil + } + + if err = DoUntilError(context.Background(), b, retryFunc); !errors.Is(err, errFunctionSignaledToStop) { + t.Errorf("expected %q to contain %q", err, errFunctionSignaledToStop) + } + if cnt != maxCnt+1 { + t.Errorf("expected %d to be %d", cnt, maxCnt+1) + } + }) + + t.Run("exit_on_backoff_stop", func(t *testing.T) { + t.Parallel() + + b := backoff.WithMaxRetries(3, backoff.BackoffFunc(func() (time.Duration, bool) { + return 1 * time.Nanosecond, false + })) + + retryFunc := func(_ context.Context) error { return nil } + + if err := DoUntilError(context.Background(), b, retryFunc); err != errBackoffSignaledToStop { + t.Errorf("expected %q to be %q", err, errBackoffSignaledToStop) + } + }) +} diff --git a/repeat/repeat_utils.go b/repeat/repeat_utils.go new file mode 100644 index 0000000..1043f19 --- /dev/null +++ b/repeat/repeat_utils.go @@ -0,0 +1,42 @@ +package repeat + +import ( + "context" + "fmt" + "time" + + "github.com/swayne275/go-retry/backoff" +) + +// ConstantRepeat is a wrapper around repeat that uses a constant backoff. It will +// repeat the function f until it returns false, or the context is canceled. +func ConstantRepeat(ctx context.Context, t time.Duration, f RepeatFunc) error { + b, err := backoff.NewConstant(t) + if err != nil { + return fmt.Errorf("failed to create constant backoff: %w", err) + } + + return Do(ctx, b, f) +} + +// ExponentialRetry is a wrapper around repeat that uses an exponential backoff. It will +// repeat the function f until it returns false, or the context is canceled. +func ExponentialRepeat(ctx context.Context, base time.Duration, f RepeatFunc) error { + b, err := backoff.NewExponential(base) + if err != nil { + return fmt.Errorf("failed to create exponential backoff: %w", err) + } + + return Do(ctx, b, f) +} + +// FibonacciRepeat is a wrapper around repeat that uses a FibonacciRetry backoff. It will +// repeat the function f until it returns false, or the context is canceled. +func FibonacciRepeat(ctx context.Context, base time.Duration, f RepeatFunc) error { + b, err := backoff.NewFibonacci(base) + if err != nil { + return fmt.Errorf("failed to create fibonacci backoff: %w", err) + + } + return Do(ctx, b, f) +} diff --git a/repeat/repeat_utils_test.go b/repeat/repeat_utils_test.go new file mode 100644 index 0000000..a1f5996 --- /dev/null +++ b/repeat/repeat_utils_test.go @@ -0,0 +1,121 @@ +package repeat + +import ( + "context" + "testing" + "time" +) + +func TestConstantRepeat(t *testing.T) { + t.Parallel() + + t.Run("exit_on_context_cancelled", func(t *testing.T) { + t.Parallel() + + f := func(_ context.Context) bool { return true } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + + if err := ConstantRepeat(ctx, 1*time.Nanosecond, f); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RepeatFunc_false", func(t *testing.T) { + t.Parallel() + + cnt := 0 + maxCnt := 3 + f := func(_ context.Context) bool { + cnt++ + + return cnt <= maxCnt + } + + if err := ConstantRepeat(context.Background(), 1*time.Nanosecond, f); err != errFunctionSignaledToStop { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + if cnt != maxCnt+1 { + t.Errorf("expected %d to be %d", cnt, maxCnt+1) + } + }) +} + +func TestExponentialRepeat(t *testing.T) { + t.Parallel() + + t.Run("exit_on_context_cancelled", func(t *testing.T) { + t.Parallel() + + f := func(_ context.Context) bool { return true } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + + if err := ExponentialRepeat(ctx, 1*time.Nanosecond, f); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RepeatFunc_false", func(t *testing.T) { + t.Parallel() + + cnt := 0 + maxCnt := 3 + f := func(_ context.Context) bool { + cnt++ + + return cnt <= maxCnt + } + + if err := ExponentialRepeat(context.Background(), 1*time.Nanosecond, f); err != errFunctionSignaledToStop { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + if cnt != maxCnt+1 { + t.Errorf("expected %d to be %d", cnt, maxCnt+1) + } + }) +} + +func TestFibonacciRepeat(t *testing.T) { + t.Parallel() + + t.Run("exit_on_context_cancelled", func(t *testing.T) { + t.Parallel() + + f := func(_ context.Context) bool { return true } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + + if err := FibonacciRepeat(ctx, 1*time.Nanosecond, f); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RepeatFunc_false", func(t *testing.T) { + t.Parallel() + + cnt := 0 + maxCnt := 3 + f := func(_ context.Context) bool { + cnt++ + + return cnt <= maxCnt + } + + if err := FibonacciRepeat(context.Background(), 1*time.Nanosecond, f); err != errFunctionSignaledToStop { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + if cnt != maxCnt+1 { + t.Errorf("expected %d to be %d", cnt, maxCnt+1) + } + }) +} diff --git a/retry.go b/retry/retry.go similarity index 74% rename from retry.go rename to retry/retry.go index a7e10b9..60b7d15 100644 --- a/retry.go +++ b/retry/retry.go @@ -15,9 +15,15 @@ package retry import ( "context" "errors" + "fmt" "time" + + "github.com/swayne275/go-retry/common/backoff" ) +var errFunctionReturnedNonRetryableError = fmt.Errorf("function returned non retryable error") +var errBackoffSignaledToStop = fmt.Errorf("backoff signaled to stop") + // RetryFunc is a function passed to retry. type RetryFunc func(ctx context.Context) error @@ -46,9 +52,10 @@ func (e *retryableError) Error() string { return "retryable: " + e.err.Error() } -// Do wraps a function with a backoff to retry. The provided context is the same -// context passed to the RetryFunc. -func Do(ctx context.Context, b Backoff, f RetryFunc) error { +// Do wraps a function with a backoff to retry. It will retry until f returns either +// nil or a non-retryable error. +// The provided context is the same context passed to the RetryFunc. +func Do(ctx context.Context, b backoff.Backoff, f RetryFunc) error { for { // Return immediately if ctx is canceled select { @@ -65,12 +72,12 @@ func Do(ctx context.Context, b Backoff, f RetryFunc) error { // Not retryable var rerr *retryableError if !errors.As(err, &rerr) { - return err + return fmt.Errorf("%w: %w", errFunctionReturnedNonRetryableError, err) } next, stop := b.Next() if stop { - return rerr.Unwrap() + return fmt.Errorf("%w: %w", errBackoffSignaledToStop, rerr.Unwrap()) } // ctx.Done() has priority, so we test it alone first diff --git a/retry/retry_test.go b/retry/retry_test.go new file mode 100644 index 0000000..cdf27b8 --- /dev/null +++ b/retry/retry_test.go @@ -0,0 +1,190 @@ +package retry + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "testing" + "time" + + "github.com/swayne275/go-retry/backoff" +) + +func TestRetryableError(t *testing.T) { + t.Parallel() + + err := RetryableError(fmt.Errorf("oops")) + if got, want := err.Error(), "retryable: "; !strings.Contains(got, want) { + t.Errorf("expected %v to contain %v", got, want) + } +} + +func TestDo(t *testing.T) { + t.Parallel() + + t.Run("exit_on_non_retryable", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + b := backoff.WithMaxRetries(3, backoff.BackoffFunc(func() (time.Duration, bool) { + return 1 * time.Nanosecond, false + })) + + var i int + if err := Do(ctx, b, func(_ context.Context) error { + i++ + return fmt.Errorf("oops") // not retryable + }); err == nil { + t.Fatal("expected err") + } + + if got, want := i, 1; got != want { + t.Errorf("expected %v to be %v", got, want) + } + }) + + t.Run("unwraps", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + b := backoff.WithMaxRetries(1, backoff.BackoffFunc(func() (time.Duration, bool) { + return 1 * time.Nanosecond, false + })) + + err := Do(ctx, b, func(_ context.Context) error { + return RetryableError(io.EOF) + }) + if err == nil { + t.Fatal("expected err") + } + + if !errors.Is(err, io.EOF) { + t.Errorf("expected %q to be %q", err, io.EOF) + } + }) + + t.Run("exit_no_error", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + b := backoff.WithMaxRetries(3, backoff.BackoffFunc(func() (time.Duration, bool) { + return 1 * time.Nanosecond, false + })) + + var i int + if err := Do(ctx, b, func(_ context.Context) error { + i++ + return nil // no error + }); err != nil { + t.Fatal("expected no err") + } + + if got, want := i, 1; got != want { + t.Errorf("expected %v to be %v", got, want) + } + }) + + t.Run("exit_on_context_canceled", func(t *testing.T) { + t.Parallel() + + b, err := backoff.NewConstant(1 * time.Nanosecond) + if err != nil { + t.Fatalf("failed to create constant backoff: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + retryFunc := func(_ context.Context) error { + return RetryableError(fmt.Errorf("some retryable error")) + } + + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + if err = Do(ctx, b, retryFunc); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RetryFunc_nonretryable_error", func(t *testing.T) { + t.Parallel() + + b, err := backoff.NewConstant(1 * time.Nanosecond) + if err != nil { + t.Fatalf("failed to create constant backoff: %v", err) + } + + cnt := 0 + maxCnt := 3 + retryFunc := func(_ context.Context) error { + cnt++ + if cnt > maxCnt { + return fmt.Errorf("function error") + } + + return RetryableError(fmt.Errorf("some retryable error")) + } + + if err = Do(context.Background(), b, retryFunc); !errors.Is(err, errFunctionReturnedNonRetryableError) { + t.Errorf("expected %q to contain %q", err, errFunctionReturnedNonRetryableError) + } + if cnt != maxCnt+1 { + t.Errorf("expected %d to be %d", cnt, maxCnt+1) + } + }) + + t.Run("exit_on_backoff_stop", func(t *testing.T) { + t.Parallel() + + b := backoff.WithMaxRetries(3, backoff.BackoffFunc(func() (time.Duration, bool) { + return 1 * time.Nanosecond, false + })) + + errUnderlyingRetryable := RetryableError(fmt.Errorf("some retryable error")) + err := Do(context.Background(), b, func(_ context.Context) error { + return RetryableError(errUnderlyingRetryable) + }) + if !errors.Is(err, errBackoffSignaledToStop) { + t.Errorf("expected %q to be %q", err, errBackoffSignaledToStop) + } + if !errors.Is(err, errUnderlyingRetryable) { + t.Errorf("expected %q to be %q", err, errUnderlyingRetryable) + } + }) +} + +func TestCancel(t *testing.T) { + for i := 0; i < 100000; i++ { + ctx, cancel := context.WithCancel(context.Background()) + + calls := 0 + rf := func(ctx context.Context) error { + calls++ + // Never succeed. + // Always return a RetryableError + return RetryableError(errors.New("nope")) + } + + const delay time.Duration = time.Millisecond + b, err := backoff.NewConstant(delay) + if err != nil { + t.Fatalf("failed to create constant backoff: %v", err) + } + + const maxRetries = 5 + b = backoff.WithMaxRetries(maxRetries, b) + + const jitter time.Duration = 5 * time.Millisecond + b = backoff.WithJitter(jitter, b) + + // Here we cancel the Context *before* the call to Do + cancel() + Do(ctx, b, rf) + + if calls > 1 { + t.Errorf("rf was called %d times instead of 0 or 1", calls) + } + } +} diff --git a/retry/retry_utils.go b/retry/retry_utils.go new file mode 100644 index 0000000..1cc08b7 --- /dev/null +++ b/retry/retry_utils.go @@ -0,0 +1,44 @@ +package retry + +import ( + "context" + "fmt" + "time" + + "github.com/swayne275/go-retry/backoff" +) + +// TODO tests should include retryable errors and non retryable errors + +// ConstantRetry is a wrapper around retry that uses a constant backoff. It will +// retry the function f until it returns a non-retryable error, or the context is canceled. +func ConstantRetry(ctx context.Context, t time.Duration, f RetryFunc) error { + b, err := backoff.NewConstant(t) + if err != nil { + return fmt.Errorf("failed to create constant backoff: %w", err) + } + + return Do(ctx, b, f) +} + +// ExponentialRetry is a wrapper around retry that uses an exponential backoff. It will +// retry the function f until it returns a non-retryable error, or the context is canceled. +func ExponentialRetry(ctx context.Context, base time.Duration, f RetryFunc) error { + b, err := backoff.NewExponential(base) + if err != nil { + return fmt.Errorf("failed to create exponential backoff: %w", err) + } + + return Do(ctx, b, f) +} + +// FibonacciRetry is a wrapper around retry that uses a FibonacciRetry backoff. It will +// retry the function f until it returns a non-retryable error, or the context is canceled. +func FibonacciRetry(ctx context.Context, base time.Duration, f RetryFunc) error { + b, err := backoff.NewFibonacci(base) + if err != nil { + return fmt.Errorf("failed to create fibonacci backoff: %w", err) + + } + return Do(ctx, b, f) +} diff --git a/retry/retry_utils_test.go b/retry/retry_utils_test.go new file mode 100644 index 0000000..aed11f3 --- /dev/null +++ b/retry/retry_utils_test.go @@ -0,0 +1,156 @@ +package retry + +import ( + "context" + "errors" + "fmt" + "testing" + "time" +) + +func TestConstantRetry(t *testing.T) { + t.Parallel() + + t.Run("exit_on_context_cancelled", func(t *testing.T) { + t.Parallel() + + f := func(_ context.Context) error { + return RetryableError(fmt.Errorf("some retryable err")) + } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + + if err := ConstantRetry(ctx, 1*time.Nanosecond, f); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RetryFunc_nonretryable_error", func(t *testing.T) { + t.Parallel() + + cnt := 0 + nonRetryableCnt := 3 + errNonRetryable := fmt.Errorf("some non-retryable error") + f := func(_ context.Context) error { + cnt++ + + if cnt > nonRetryableCnt { + return errNonRetryable + } + + return RetryableError(fmt.Errorf("some non-retryable error")) + } + + err := ConstantRetry(context.Background(), 1*time.Nanosecond, f) + if !errors.Is(err, errFunctionReturnedNonRetryableError) { + t.Errorf("expected %q to be %q", err, errFunctionReturnedNonRetryableError) + } + if !errors.Is(err, errNonRetryable) { + t.Errorf("expected %q to be %q", err, errNonRetryable) + } + if cnt != nonRetryableCnt+1 { + t.Errorf("expected %d to be %d", cnt, nonRetryableCnt+1) + } + }) +} + +func TestExponentialRetry(t *testing.T) { + t.Parallel() + + t.Run("exit_on_context_cancelled", func(t *testing.T) { + t.Parallel() + + f := func(_ context.Context) error { + return RetryableError(fmt.Errorf("some retryable err")) + } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + + if err := ExponentialRetry(ctx, 1*time.Nanosecond, f); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RetryFunc_nonretryable_error", func(t *testing.T) { + t.Parallel() + + cnt := 0 + nonRetryableCnt := 3 + errNonRetryable := fmt.Errorf("some non-retryable error") + f := func(_ context.Context) error { + cnt++ + + if cnt > nonRetryableCnt { + return errNonRetryable + } + + return RetryableError(fmt.Errorf("some non-retryable error")) + } + + err := ExponentialRetry(context.Background(), 1*time.Nanosecond, f) + if !errors.Is(err, errFunctionReturnedNonRetryableError) { + t.Errorf("expected %q to be %q", err, errFunctionReturnedNonRetryableError) + } + if !errors.Is(err, errNonRetryable) { + t.Errorf("expected %q to be %q", err, errNonRetryable) + } + if cnt != nonRetryableCnt+1 { + t.Errorf("expected %d to be %d", cnt, nonRetryableCnt+1) + } + }) +} + +func TestFibonacciRetry(t *testing.T) { + t.Parallel() + + t.Run("exit_on_context_cancelled", func(t *testing.T) { + t.Parallel() + + f := func(_ context.Context) error { + return RetryableError(fmt.Errorf("some retryable err")) + } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Nanosecond) + cancel() + }() + + if err := FibonacciRetry(ctx, 1*time.Nanosecond, f); err != context.Canceled { + t.Errorf("expected %q to be %q", err, context.Canceled) + } + }) + + t.Run("exit_on_RetryFunc_nonretryable_error", func(t *testing.T) { + t.Parallel() + + cnt := 0 + nonRetryableCnt := 3 + errNonRetryable := fmt.Errorf("some non-retryable error") + f := func(_ context.Context) error { + cnt++ + + if cnt > nonRetryableCnt { + return errNonRetryable + } + + return RetryableError(fmt.Errorf("some non-retryable error")) + } + + err := FibonacciRetry(context.Background(), 1*time.Nanosecond, f) + if !errors.Is(err, errFunctionReturnedNonRetryableError) { + t.Errorf("expected %q to be %q", err, errFunctionReturnedNonRetryableError) + } + if !errors.Is(err, errNonRetryable) { + t.Errorf("expected %q to be %q", err, errNonRetryable) + } + if cnt != nonRetryableCnt+1 { + t.Errorf("expected %d to be %d", cnt, nonRetryableCnt+1) + } + }) +} diff --git a/retry_test.go b/retry_test.go deleted file mode 100644 index 2459d0f..0000000 --- a/retry_test.go +++ /dev/null @@ -1,216 +0,0 @@ -package retry_test - -import ( - "context" - "errors" - "fmt" - "io" - "net/http" - "strings" - "testing" - "time" - - "github.com/swayne275/go-retry" -) - -func TestRetryableError(t *testing.T) { - t.Parallel() - - err := retry.RetryableError(fmt.Errorf("oops")) - if got, want := err.Error(), "retryable: "; !strings.Contains(got, want) { - t.Errorf("expected %v to contain %v", got, want) - } -} - -func TestDo(t *testing.T) { - t.Parallel() - - t.Run("exit_on_max_attempt", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - b := retry.WithMaxRetries(3, retry.BackoffFunc(func() (time.Duration, bool) { - return 1 * time.Nanosecond, false - })) - - var i int - if err := retry.Do(ctx, b, func(_ context.Context) error { - i++ - return retry.RetryableError(fmt.Errorf("oops")) - }); err == nil { - t.Fatal("expected err") - } - - // 1 + retries - if got, want := i, 4; got != want { - t.Errorf("expected %v to be %v", got, want) - } - }) - - t.Run("exit_on_non_retryable", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - b := retry.WithMaxRetries(3, retry.BackoffFunc(func() (time.Duration, bool) { - return 1 * time.Nanosecond, false - })) - - var i int - if err := retry.Do(ctx, b, func(_ context.Context) error { - i++ - return fmt.Errorf("oops") // not retryable - }); err == nil { - t.Fatal("expected err") - } - - if got, want := i, 1; got != want { - t.Errorf("expected %v to be %v", got, want) - } - }) - - t.Run("unwraps", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - b := retry.WithMaxRetries(1, retry.BackoffFunc(func() (time.Duration, bool) { - return 1 * time.Nanosecond, false - })) - - err := retry.Do(ctx, b, func(_ context.Context) error { - return retry.RetryableError(io.EOF) - }) - if err == nil { - t.Fatal("expected err") - } - - if got, want := err, io.EOF; got != want { - t.Errorf("expected %#v to be %#v", got, want) - } - }) - - t.Run("exit_no_error", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - b := retry.WithMaxRetries(3, retry.BackoffFunc(func() (time.Duration, bool) { - return 1 * time.Nanosecond, false - })) - - var i int - if err := retry.Do(ctx, b, func(_ context.Context) error { - i++ - return nil // no error - }); err != nil { - t.Fatal("expected no err") - } - - if got, want := i, 1; got != want { - t.Errorf("expected %v to be %v", got, want) - } - }) - - t.Run("context_canceled", func(t *testing.T) { - t.Parallel() - - b := retry.BackoffFunc(func() (time.Duration, bool) { - return 5 * time.Second, false - }) - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - if err := retry.Do(ctx, b, func(_ context.Context) error { - return retry.RetryableError(fmt.Errorf("oops")) // no error - }); err != context.DeadlineExceeded { - t.Errorf("expected %v to be %v", err, context.DeadlineExceeded) - } - }) -} - -func ExampleDo_simple() { - ctx := context.Background() - - b, err := retry.NewFibonacci(1 * time.Nanosecond) - if err != nil { - // handle error - } - - i := 0 - if err := retry.Do(ctx, retry.WithMaxRetries(3, b), func(ctx context.Context) error { - fmt.Printf("%d\n", i) - i++ - return retry.RetryableError(fmt.Errorf("oops")) - }); err != nil { - // handle error - } - - // Output: - // 0 - // 1 - // 2 - // 3 -} - -func ExampleDo_customRetry() { - ctx := context.Background() - - b, err := retry.NewFibonacci(1 * time.Nanosecond) - if err != nil { - // handle error - } - - // This example demonstrates selectively retrying specific errors. Only errors - // wrapped with RetryableError are eligible to be retried. - if err := retry.Do(ctx, retry.WithMaxRetries(3, b), func(ctx context.Context) error { - resp, err := http.Get("https://google.com/") - if err != nil { - return err - } - defer resp.Body.Close() - - switch resp.StatusCode / 100 { - case 4: - return fmt.Errorf("bad response: %v", resp.StatusCode) - case 5: - return retry.RetryableError(fmt.Errorf("bad response: %v", resp.StatusCode)) - default: - return nil - } - }); err != nil { - // handle error - } -} - -func TestCancel(t *testing.T) { - for i := 0; i < 100000; i++ { - ctx, cancel := context.WithCancel(context.Background()) - - calls := 0 - rf := func(ctx context.Context) error { - calls++ - // Never succeed. - // Always return a RetryableError - return retry.RetryableError(errors.New("nope")) - } - - const delay time.Duration = time.Millisecond - b, err := retry.NewConstant(delay) - if err != nil { - t.Fatalf("failed to create constant backoff: %v", err) - } - - const maxRetries = 5 - b = retry.WithMaxRetries(maxRetries, b) - - const jitter time.Duration = 5 * time.Millisecond - b = retry.WithJitter(jitter, b) - - // Here we cancel the Context *before* the call to Do - cancel() - retry.Do(ctx, b, rf) - - if calls > 1 { - t.Errorf("rf was called %d times instead of 0 or 1", calls) - } - } -}