diff --git a/README.md b/README.md index 83dfdfb..5d0d651 100644 --- a/README.md +++ b/README.md @@ -176,9 +176,9 @@ from pypipeline import Pipeline pipeline = Pipeline() -def custom_pipe(passable, next): +def custom_pipe(passable, next_pipe): passable = passable.replace('hello', 'goodbye') - return next(passable) + return next_pipe(passable) pipeline.through([ custom_pipe, @@ -202,15 +202,15 @@ optionally implement the `StellarWP\Pipeline\Contracts\Pipe` interface to enforc First class: ```python class CapitalizePipe: - def handle(self, passable, next): - return next(passable.capitalize()) + def handle(self, passable, next_pipe): + return next_pipe(passable.capitalize()) ``` Second class: ```python class StripPipe: - def handle(self, passable, next): - return next(passable.strip()) + def handle(self, passable, next_pipe): + return next_pipe(passable.strip()) ``` #### Example pipeline @@ -444,7 +444,7 @@ pipeline.through([ str.capitalize, str.strip ]) pipeline.pipes([ str.capitalize, str.strip ]) # Pass closures as pipes. -pipeline.through([ str.capitalize, lambda passable, next: next(passable.strip)]) +pipeline.through([ str.capitalize, lambda passable, next: next_pipe(passable.strip)]) # Pass objects as pipes. pipeline.through([ CapitalizePipe(), StripPipe() ]) diff --git a/pypipeline/pipeline.py b/pypipeline/pipeline.py index 38babd5..18fa3c3 100644 --- a/pypipeline/pipeline.py +++ b/pypipeline/pipeline.py @@ -1,4 +1,5 @@ from typing import Any, Callable, List, Optional, Union +from functools import reduce from inspect import signature class Pipeline: @@ -39,17 +40,23 @@ def carry(self, next_pipe: Callable, current_pipe: Any) -> Callable: Callable: The callable for the pipe. """ def wrapper(passable): - print(passable) try: - # Determine how many parameters the current_pipe accepts - params = signature(current_pipe).parameters - if callable(current_pipe) and len(params) == 2 and 'next' in params: + params = None + if callable(current_pipe): + params = signature(current_pipe).parameters + + is_object = isinstance(current_pipe, object) and not isinstance(current_pipe, (int, float, str, bool, list, dict, tuple, set)) + if is_object == True and hasattr(current_pipe, self.method): + method = getattr(current_pipe, self.method, None) + return method(passable, next_pipe) + elif callable(current_pipe) and 'next_pipe' in params: return current_pipe(passable, next_pipe) elif callable(current_pipe): - result = current_pipe(passable) - return next_pipe(result) + passable = current_pipe(passable) + return next_pipe(passable) else: raise TypeError("The pipe must be callable") + except Exception as e: return self.handle_exception(passable, e) return wrapper @@ -146,11 +153,7 @@ def then(self, destination: Optional[Callable] = None): if destination is None: destination = lambda x: x - pipeline = self.prepare_destination(destination) - - # We reverse the - for pipe in reversed(self.pipes): - pipeline = self.carry(pipeline, pipe) + pipeline = reduce(self.carry, reversed(self.pipes), self.prepare_destination(destination)) return pipeline(self.passable) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e29d571..65ce635 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -4,60 +4,111 @@ class TestPipeline(unittest.TestCase): + def test_it_runs_a_pipeline_with_one_callable(self): + pipeline = Pipeline() + result = pipeline.send('hello ') \ + .through([str.strip]) \ + .then_return() + self.assertEqual('hello', result) + def test_it_runs_a_pipeline_with_callables(self): pipeline = Pipeline() - result = pipeline.send('a sample string that is passed through to all pipes. ') \ + result = pipeline.send('hello world ') \ .through([str.title, str.strip]) \ .then_return() - self.assertEqual('A Sample String That Is Passed Through To All Pipes.', result) + self.assertEqual('Hello World', result) def test_it_runs_a_pipeline_with_callables_and_executes_the_destination(self): pipeline = Pipeline() - result = pipeline.send('a sample string that is passed through to all pipes. ') \ + result = pipeline.send('hello world ') \ .through([str.title, str.strip]) \ - .then(lambda x: x.replace('A Sample', 'A Nice Long')) - self.assertEqual('A Nice Long String That Is Passed Through To All Pipes.', result) + .then(lambda x: x.replace('Hello', 'Goodbye')) + self.assertEqual('Goodbye World', result) def test_it_runs_a_pipeline_with_callables_and_closures(self): pipeline = Pipeline() - result = pipeline.send('a sample string that is passed through to all pipes. ') \ + result = pipeline.send('hello world ') \ .through([ - lambda x, next: next(x.replace('all', 'all the')), + lambda x, next_pipe: next_pipe(x.replace('hello', 'goodbye')), str.title, str.strip ]) \ .then_return() - self.assertEqual('A Sample String That Is Passed Through To All The Pipes.', result) + self.assertEqual('Goodbye World', result) def test_it_runs_a_pipeline_with_closures(self): pipeline = Pipeline() - result = pipeline.send('a sample string that is passed through to all pipes.') \ + result = pipeline.send('hello world') \ + .through([ + lambda x, next_pipe: next_pipe(x.title()), + lambda x, next_pipe: next_pipe(x.replace('Hello', 'Goodbye')) + ]) \ + .then_return() + self.assertEqual('Goodbye World', result) + + def test_it_runs_a_pipeline_with_custom_pipes(self): + def custom_pipe(passable, next_pipe): + passable = passable.replace('Hello', 'Goodbye') + return next_pipe(passable) + + pipeline = Pipeline() + result = pipeline.send(' hello world ') \ + .through([ + lambda x, next_pipe: next_pipe(x.title()), + str.strip, + custom_pipe + ]) \ + .then_return() + self.assertEqual('Goodbye World', result) + + def test_it_runs_a_pipeline_with_classes(self): + class TitlePipe: + def handle(self, passable, next_pipe): + return next_pipe(passable.title()) + + pipeline = Pipeline() + result = pipeline.send(' hello world ') \ + .through([ + TitlePipe(), + str.strip + ]) \ + .then_return() + self.assertEqual('Hello World', result) + + def test_it_runs_a_pipeline_with_classes_and_custom_handler(self): + class TitlePipe: + def execute(self, passable, next_pipe): + return next_pipe(passable.title()) + + pipeline = Pipeline() + result = pipeline.send(' hello world ') \ + .via('execute') \ .through([ - lambda x, next: next(x.title()), - lambda x, next: next(x.replace('All', 'All The')) + TitlePipe(), + str.strip ]) \ .then_return() - self.assertEqual('A Sample String That Is Passed Through To All The Pipes.', result) + self.assertEqual('Hello World', result) def test_it_runs_a_pipeline_by_sending_late(self): pipeline = Pipeline() pipeline.through([str.title, str.strip]) - result = pipeline.send('a sample string that is passed through to all pipes. ') \ + result = pipeline.send('hello ') \ .then_return() - self.assertEqual('A Sample String That Is Passed Through To All Pipes.', result) + self.assertEqual('Hello', result) def test_it_runs_a_pipeline_setup_via_pipe(self): pipeline = Pipeline() pipeline.pipe([str.title, str.strip]) - result = pipeline.send('a sample string that is passed through to all pipes. ') \ + result = pipeline.send('hello ') \ .then_return() - self.assertEqual('A Sample String That Is Passed Through To All Pipes.', result) + self.assertEqual('Hello', result) def test_it_bails_early(self): pipeline = Pipeline() result = pipeline.send('bork') \ .through([ - lambda x, next: False, + lambda x, next_pipe: False, str.strip ]) \ .then() @@ -68,7 +119,7 @@ def test_it_bails_in_the_middle(self): result = pipeline.send('bork ') \ .through([ str.strip, - lambda x, next: x, + lambda x, next_pipe: x, str.title ]) \ .then()