Skip to content

Commit

Permalink
move retry utils over, test
Browse files Browse the repository at this point in the history
  • Loading branch information
swayne275 committed Jun 30, 2024
1 parent 0c91d53 commit e81c06f
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 45 deletions.
13 changes: 0 additions & 13 deletions backoff/backoff_constant.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@
package backoff

import (
"context"
"fmt"
"time"

"github.com/swayne275/go-retry/common/backoff"
"github.com/swayne275/go-retry/retry"
)

// ConstantRetry is a wrapper around retry that uses a constant backoff. It will
// retry the function f until it returns an error, or the context is canceled.
func ConstantRetry(ctx context.Context, t time.Duration, f retry.RetryFunc) error {
b, err := NewConstant(t)
if err != nil {
return fmt.Errorf("failed to create constant backoff: %w", err)
}

return retry.Do(ctx, b, f)
}

// NewConstant creates a new constant backoff using the value t. The wait time
// is the provided constant value. It returns an error if t is not greater than 0.
func NewConstant(t time.Duration) (backoff.Backoff, error) {
Expand Down
14 changes: 0 additions & 14 deletions backoff/backoff_exponential.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,19 @@
package backoff

import (
"context"
"fmt"
"math"
"sync/atomic"
"time"

"github.com/swayne275/go-retry/common/backoff"
"github.com/swayne275/go-retry/retry"
)

type exponentialBackoff struct {
base time.Duration
attempt uint64
}

// ExponentialRetry is a wrapper around retry that uses an exponential backoff. See
// NewExponential.
// TODO is this useful or fine as an example?
func ExponentialRetry(ctx context.Context, base time.Duration, f retry.RetryFunc) error {
b, err := NewExponential(base)
if err != nil {
return fmt.Errorf("failed to create exponential backoff: %w", err)
}

return retry.Do(ctx, b, f)
}

// NewExponential creates a new exponential backoff using the starting value of
// base and doubling on each failure (1, 2, 4, 8, 16, 32, 64...), up to max.
//
Expand Down
14 changes: 0 additions & 14 deletions backoff/backoff_fibonacci.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package backoff

import (
"context"
"fmt"
"math"
"sync/atomic"
"time"
"unsafe"

"github.com/swayne275/go-retry/common/backoff"
"github.com/swayne275/go-retry/retry"
)

type state [2]time.Duration
Expand All @@ -19,18 +17,6 @@ type fibonacciBackoff struct {
base time.Duration
}

// FibonacciRetry is a wrapper around retry that uses a FibonacciRetry backoff. See
// NewFibonacci.
// TODO is this useful or should we move to example?
func FibonacciRetry(ctx context.Context, base time.Duration, f retry.RetryFunc) error {
b, err := NewFibonacci(base)
if err != nil {
return fmt.Errorf("failed to create fibonacci backoff: %w", err)

}
return retry.Do(ctx, b, f)
}

// NewFibonacci creates a new Fibonacci backoff that follows the fibonacci sequence
// multipled by base. The wait time is the sum of the previous two wait times on each
// previous attempt base * (1, 2, 3, 5, 8, 13...).
Expand Down
15 changes: 11 additions & 4 deletions retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@ package retry
import (
"context"
"errors"
"fmt"
"time"

"github.com/swayne275/go-retry/common/backoff"
)

var errFunctionReturnedNonRetryableError = fmt.Errorf("function returned non retryable error")

// TODO do we need to write a test for this case
var errBackoffSignaledToStop = fmt.Errorf("backoff signaled to stop")

// RetryFunc is a function passed to retry.
type RetryFunc func(ctx context.Context) error

Expand Down Expand Up @@ -48,8 +54,9 @@ func (e *retryableError) Error() string {
return "retryable: " + e.err.Error()
}

// Do wraps a function with a backoff to retry. The provided context is the same
// context passed to the RetryFunc.
// Do wraps a function with a backoff to retry. It will retry until f returns either
// nil or a non-retryable error.
// The provided context is the same context passed to the RetryFunc.
func Do(ctx context.Context, b backoff.Backoff, f RetryFunc) error {
for {
// Return immediately if ctx is canceled
Expand All @@ -67,12 +74,12 @@ func Do(ctx context.Context, b backoff.Backoff, f RetryFunc) error {
// Not retryable
var rerr *retryableError
if !errors.As(err, &rerr) {
return err
return fmt.Errorf("%w: %w", errFunctionReturnedNonRetryableError, err)
}

next, stop := b.Next()
if stop {
return rerr.Unwrap()
return fmt.Errorf("%w: %w", errBackoffSignaledToStop, rerr.Unwrap())
}

// ctx.Done() has priority, so we test it alone first
Expand Down
44 changes: 44 additions & 0 deletions retry/retry_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package retry

import (
"context"
"fmt"
"time"

"github.com/swayne275/go-retry/backoff"
)

// TODO tests should include retryable errors and non retryable errors

// ConstantRetry is a wrapper around retry that uses a constant backoff. It will
// retry the function f until it returns a non-retryable error, or the context is canceled.
func ConstantRetry(ctx context.Context, t time.Duration, f RetryFunc) error {
b, err := backoff.NewConstant(t)
if err != nil {
return fmt.Errorf("failed to create constant backoff: %w", err)
}

return Do(ctx, b, f)
}

// ExponentialRetry is a wrapper around retry that uses an exponential backoff. It will
// retry the function f until it returns a non-retryable error, or the context is canceled.
func ExponentialRetry(ctx context.Context, base time.Duration, f RetryFunc) error {
b, err := backoff.NewExponential(base)
if err != nil {
return fmt.Errorf("failed to create exponential backoff: %w", err)
}

return Do(ctx, b, f)
}

// FibonacciRetry is a wrapper around retry that uses a FibonacciRetry backoff. It will
// retry the function f until it returns a non-retryable error, or the context is canceled.
func FibonacciRetry(ctx context.Context, base time.Duration, f RetryFunc) error {
b, err := backoff.NewFibonacci(base)
if err != nil {
return fmt.Errorf("failed to create fibonacci backoff: %w", err)

}
return Do(ctx, b, f)
}
156 changes: 156 additions & 0 deletions retry/retry_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package retry

import (
"context"
"errors"
"fmt"
"testing"
"time"
)

func TestConstantRetry(t *testing.T) {
t.Parallel()

t.Run("exit_on_context_cancelled", func(t *testing.T) {
t.Parallel()

f := func(_ context.Context) error {
return RetryableError(fmt.Errorf("some retryable err"))
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Nanosecond)
cancel()
}()

if err := ConstantRetry(ctx, 1*time.Nanosecond, f); err != context.Canceled {
t.Errorf("expected %q to be %q", err, context.Canceled)
}
})

t.Run("exit_on_RetryFunc_nonretryable_error", func(t *testing.T) {
t.Parallel()

cnt := 0
nonRetryableCnt := 3
errNonRetryable := fmt.Errorf("some non-retryable error")
f := func(_ context.Context) error {
cnt++

if cnt > nonRetryableCnt {
return errNonRetryable
}

return RetryableError(fmt.Errorf("some non-retryable error"))
}

err := ConstantRetry(context.Background(), 1*time.Nanosecond, f)
if !errors.Is(err, errFunctionReturnedNonRetryableError) {
t.Errorf("expected %q to be %q", err, errFunctionReturnedNonRetryableError)
}
if !errors.Is(err, errNonRetryable) {
t.Errorf("expected %q to be %q", err, errNonRetryable)
}
if cnt != nonRetryableCnt+1 {
t.Errorf("expected %d to be %d", cnt, nonRetryableCnt+1)
}
})
}

func TestExponentialRetry(t *testing.T) {
t.Parallel()

t.Run("exit_on_context_cancelled", func(t *testing.T) {
t.Parallel()

f := func(_ context.Context) error {
return RetryableError(fmt.Errorf("some retryable err"))
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Nanosecond)
cancel()
}()

if err := ExponentialRetry(ctx, 1*time.Nanosecond, f); err != context.Canceled {
t.Errorf("expected %q to be %q", err, context.Canceled)
}
})

t.Run("exit_on_RetryFunc_nonretryable_error", func(t *testing.T) {
t.Parallel()

cnt := 0
nonRetryableCnt := 3
errNonRetryable := fmt.Errorf("some non-retryable error")
f := func(_ context.Context) error {
cnt++

if cnt > nonRetryableCnt {
return errNonRetryable
}

return RetryableError(fmt.Errorf("some non-retryable error"))
}

err := ExponentialRetry(context.Background(), 1*time.Nanosecond, f)
if !errors.Is(err, errFunctionReturnedNonRetryableError) {
t.Errorf("expected %q to be %q", err, errFunctionReturnedNonRetryableError)
}
if !errors.Is(err, errNonRetryable) {
t.Errorf("expected %q to be %q", err, errNonRetryable)
}
if cnt != nonRetryableCnt+1 {
t.Errorf("expected %d to be %d", cnt, nonRetryableCnt+1)
}
})
}

func TestFibonacciRetry(t *testing.T) {
t.Parallel()

t.Run("exit_on_context_cancelled", func(t *testing.T) {
t.Parallel()

f := func(_ context.Context) error {
return RetryableError(fmt.Errorf("some retryable err"))
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Nanosecond)
cancel()
}()

if err := FibonacciRetry(ctx, 1*time.Nanosecond, f); err != context.Canceled {
t.Errorf("expected %q to be %q", err, context.Canceled)
}
})

t.Run("exit_on_RetryFunc_nonretryable_error", func(t *testing.T) {
t.Parallel()

cnt := 0
nonRetryableCnt := 3
errNonRetryable := fmt.Errorf("some non-retryable error")
f := func(_ context.Context) error {
cnt++

if cnt > nonRetryableCnt {
return errNonRetryable
}

return RetryableError(fmt.Errorf("some non-retryable error"))
}

err := FibonacciRetry(context.Background(), 1*time.Nanosecond, f)
if !errors.Is(err, errFunctionReturnedNonRetryableError) {
t.Errorf("expected %q to be %q", err, errFunctionReturnedNonRetryableError)
}
if !errors.Is(err, errNonRetryable) {
t.Errorf("expected %q to be %q", err, errNonRetryable)
}
if cnt != nonRetryableCnt+1 {
t.Errorf("expected %d to be %d", cnt, nonRetryableCnt+1)
}
})
}

0 comments on commit e81c06f

Please sign in to comment.