Skip to content

Commit

Permalink
Add skip decorator; A few clean ups
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudw committed Aug 19, 2024
1 parent 898b0c0 commit 5411035
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 8 deletions.
7 changes: 5 additions & 2 deletions metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def get_plugin_cli():
from .frameworks.pytorch import PytorchParallelDecorator
from .aip.aip_decorator import AIPInternalDecorator
from .aip.accelerator_decorator import AcceleratorDecorator
from .aip.interruptible_decorator import interruptibleDecorator
from .aip.interruptible_decorator import InterruptibleDecorator
from .aip.skip_decorator import SkipDecorator


STEP_DECORATORS = [
Expand All @@ -134,8 +135,9 @@ def get_plugin_cli():
PytorchParallelDecorator,
InternalTestUnboundedForeachDecorator,
AcceleratorDecorator,
interruptibleDecorator,
InterruptibleDecorator,
AIPInternalDecorator,
SkipDecorator,
]
_merge_lists(STEP_DECORATORS, _ext_plugins["STEP_DECORATORS"], "name")

Expand All @@ -159,6 +161,7 @@ def get_plugin_cli():
from .aws.step_functions.schedule_decorator import ScheduleDecorator
from .project_decorator import ProjectDecorator
from .aip.s3_sensor_decorator import S3SensorDecorator

from .aip.exit_handler_decorator import ExitHandlerDecorator

FLOW_DECORATORS = [
Expand Down
6 changes: 3 additions & 3 deletions metaflow/plugins/aip/aip.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from metaflow.plugins.aip.aip_decorator import AIPException
from .accelerator_decorator import AcceleratorDecorator
from .argo_client import ArgoClient
from .interruptible_decorator import interruptibleDecorator
from .interruptible_decorator import InterruptibleDecorator
from .aip_foreach_splits import graph_to_task_ids
from ..aws.batch.batch_decorator import BatchDecorator
from ..aws.step_functions.schedule_decorator import ScheduleDecorator
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
resource_requirements: Dict[str, str],
aip_decorator: AIPInternalDecorator,
accelerator_decorator: AcceleratorDecorator,
interruptible_decorator: interruptibleDecorator,
interruptible_decorator: InterruptibleDecorator,
environment_decorator: EnvironmentDecorator,
total_retries: int,
minutes_between_retries: str,
Expand Down Expand Up @@ -741,7 +741,7 @@ def build_aip_component(node: DAGNode, task_id: str) -> AIPComponent:
(
deco
for deco in node.decorators
if isinstance(deco, interruptibleDecorator)
if isinstance(deco, InterruptibleDecorator)
),
None, # default
),
Expand Down
2 changes: 1 addition & 1 deletion metaflow/plugins/aip/interruptible_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def _get_ec2_metadata(path: str) -> Optional[str]:
return response.text


class interruptibleDecorator(StepDecorator):
class InterruptibleDecorator(StepDecorator):
"""
For AIP orchestrator plugin only.
Expand Down
1 change: 0 additions & 1 deletion metaflow/plugins/aip/s3_sensor_decorator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from types import FunctionType
from typing import Tuple
from urllib.parse import urlparse

from metaflow.decorators import FlowDecorator
Expand Down
52 changes: 52 additions & 0 deletions metaflow/plugins/aip/skip_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Skip decorator is a workaround solution to implement conditional branching in metaflow.
# When condition variable is_skipping is evaluated to True,
# it will skip current step and execute the supplied next step.

from functools import wraps
from metaflow.decorators import StepDecorator


class SkipDecorator(StepDecorator):
"""
The @skip decorator is a workaround for conditional branching. The @skip decorator checks an artifact
and if it is false, skips the evaluation of the step function and jumps to the supplied next step.
**The `start` and `end` steps are always expected and should not be skipped.**
Usage:
class SkipFlow(FlowSpec):
condition = Parameter("condition", default=False)
@step
def start(self):
print("Should skip:", self.condition)
self.next(self.middle)
@skip(check='condition', next='end')
@step
def middle(self):
print("Running the middle step - not skipping")
self.next(self.end)
@step
def end(self):
pass
"""

name = "skip"

def __init__(self, check="", next=""):
super().__init__()
self.check = check
self.next = next

def __call__(self, f):
@wraps(f)
def func(step):
if getattr(step, self.check):
step.next(getattr(step, self.next))
else:
return f(step)

return func
1 change: 0 additions & 1 deletion metaflow/plugins/aip/tests/flows/resources_flow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import pprint
import subprocess
import time
from typing import Dict, List
from multiprocessing.shared_memory import SharedMemory

Expand Down
41 changes: 41 additions & 0 deletions metaflow/plugins/aip/tests/flows/skip_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from metaflow import Parameter, FlowSpec, step, skip


class SkipFlow(FlowSpec):

condition_true = Parameter("condition-true", default=True)

@step
def start(self):
print("Should skip:", self.condition)
self.desired_step_executed = False
self.condition_false = False
self.next(self.skipped_step)

@skip(check="condition_true", next="desired_step")
@step
def skipped_step(self):
raise Exception(
"Unexpectedly ran the skipped_step step. This step should have been skipped."
)
self.next(self.unreachable)

def unreachable(self):
raise Exception(
"Unexpectedly ran the unreachable step. This step should have been skipped."
)
self.next(self.end)

@skip(check="condition_false", next="end")
@step
def desired_step(self):
self.desired_step_executed = True
self.next(self.end)

@step
def end(self):
assert self.desired_step_executed, "Desired step was not executed"


if __name__ == "__main__":
SkipFlow()

0 comments on commit 5411035

Please sign in to comment.