diff --git a/predicate.go b/predicate.go index d039f76..0ddb95f 100644 --- a/predicate.go +++ b/predicate.go @@ -51,6 +51,21 @@ func If(predicate Predicate, originalStep Step) Step { return wrappedStep } +// IfOrElse returns a new step that wraps the given steps and executes its action based on the given Predicate. +// The name of the step is taken from `trueStep`. +// The context.Context from the pipeline is passed through the given actions. +func IfOrElse(predicate Predicate, trueStep Step, falseStep Step) Step { + wrappedStep := Step{Name: trueStep.Name} + wrappedStep.F = func(ctx context.Context) Result { + if predicate(ctx) { + return trueStep.F(ctx) + } else { + return falseStep.F(ctx) + } + } + return wrappedStep +} + // Bool returns a Predicate that simply returns v when evaluated. // Use BoolPtr() over Bool() if the value can change between setting up the pipeline and evaluating the predicate. func Bool(v bool) Predicate { diff --git a/predicate_test.go b/predicate_test.go index 9859a08..c8abbd1 100644 --- a/predicate_test.go +++ b/predicate_test.go @@ -131,6 +131,41 @@ func TestIf(t *testing.T) { } } +func TestIfOrElse(t *testing.T) { + counter := 0 + tests := map[string]struct { + givenPredicate Predicate + expectedCalls int + }{ + "GivenWrappedStep_WhenPredicateEvalsTrue_ThenRunMainAction": { + givenPredicate: truePredicate(&counter), + expectedCalls: 2, + }, + "GivenWrappedStep_WhenPredicateEvalsFalse_ThenRunAlternativeAction": { + givenPredicate: falsePredicate(&counter), + expectedCalls: -2, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + counter = 0 + trueStep := NewStep("true", func(_ context.Context) Result { + counter++ + return newEmptyResult("true") + }) + falseStep := NewStep("false", func(ctx context.Context) Result { + counter-- + return newEmptyResult("false") + }) + wrapped := IfOrElse(tt.givenPredicate, trueStep, falseStep) + result := wrapped.F(nil) + require.NoError(t, result.Err()) + assert.Equal(t, tt.expectedCalls, counter) + assert.Equal(t, trueStep.Name, wrapped.Name) + }) + } +} func TestBoolPtr(t *testing.T) { called := false b := false