diff --git a/metaflow/plugins/kfp/kfp.py b/metaflow/plugins/kfp/kfp.py index 063f3ed07b0..26432512367 100644 --- a/metaflow/plugins/kfp/kfp.py +++ b/metaflow/plugins/kfp/kfp.py @@ -2,6 +2,7 @@ import inspect import json import marshal +import numbers import os import sys from dataclasses import dataclass @@ -47,7 +48,9 @@ from metaflow.plugins import EnvironmentDecorator, KfpInternalDecorator from metaflow.plugins.kfp.kfp_constants import ( S3_SENSOR_RETRY_COUNT, + PVC_CREATE_RETRY_COUNT, EXIT_HANDLER_RETRY_COUNT, + BACKOFF_DURATION, ) from metaflow.plugins.kfp.kfp_decorator import KfpException @@ -98,6 +101,7 @@ def __init__( interruptible_decorator: interruptibleDecorator, environment_decorator: EnvironmentDecorator, total_retries: int, + minutes_between_retries: str, ): self.step_name = step_name self.resource_requirements = resource_requirements @@ -106,6 +110,7 @@ def __init__( self.interruptible_decorator = interruptible_decorator self.environment_decorator = environment_decorator self.total_retries = total_retries + self.minutes_between_retries = minutes_between_retries self.preceding_kfp_func: Callable = ( kfp_decorator.attributes.get("preceding_component", None) @@ -281,6 +286,14 @@ def _get_retries(node: DAGNode) -> Tuple[int, int]: return max_user_code_retries, max_user_code_retries + max_error_retries + @staticmethod + def _get_minutes_between_retries(node: DAGNode) -> Optional[str]: + retry_deco = [deco for deco in node.decorators if deco.name == "retry"] + if retry_deco: + val = retry_deco[0].attributes.get("minutes_between_retries") + return f"{val}m" if isinstance(val, numbers.Number) else val + return None + @staticmethod def _get_resource_requirements(node: DAGNode) -> Dict[str, str]: """ @@ -416,6 +429,7 @@ def build_kfp_component(node: DAGNode, task_id: str) -> KfpComponent: user_code_retries, total_retries = KubeflowPipelines._get_retries(node) resource_requirements = self._get_resource_requirements(node) + minutes_between_retries = self._get_minutes_between_retries(node) return KfpComponent( step_name=node.name, @@ -453,6 +467,7 @@ def build_kfp_component(node: DAGNode, task_id: str) -> KfpComponent: None, # default ), total_retries=total_retries, + minutes_between_retries=minutes_between_retries, ) # Mapping of steps to their KfpComponent @@ -703,7 +718,9 @@ def _create_volume( k8s_resource=k8s_resource, attribute_outputs=attribute_outputs, ) - + resource.set_retry( + PVC_CREATE_RETRY_COUNT, policy="Always", backoff_duration=BACKOFF_DURATION + ) volume = PipelineVolume( name=f"{volume_name}-volume", pvc=resource.outputs["name"] ) @@ -904,7 +921,9 @@ def build_kfp_dag( if kfp_component.total_retries and kfp_component.total_retries > 0: metaflow_step_op.set_retry( - kfp_component.total_retries, policy="Always" + kfp_component.total_retries, + policy="Always", + backoff_duration=kfp_component.minutes_between_retries, ) if preceding_kfp_component_op: @@ -1297,7 +1316,9 @@ def _create_s3_sensor_op( ).set_display_name("s3_sensor") KubeflowPipelines._set_minimal_container_resources(s3_sensor_op) - s3_sensor_op.set_retry(S3_SENSOR_RETRY_COUNT, policy="Always") + s3_sensor_op.set_retry( + S3_SENSOR_RETRY_COUNT, policy="Always", backoff_duration=BACKOFF_DURATION + ) return s3_sensor_op def _create_exit_handler_op(self, package_commands: str) -> ContainerOp: @@ -1348,5 +1369,9 @@ def _create_exit_handler_op(self, package_commands: str) -> ContainerOp: command=exit_handler_command, ) .set_display_name("exit_handler") - .set_retry(EXIT_HANDLER_RETRY_COUNT, policy="Always") + .set_retry( + EXIT_HANDLER_RETRY_COUNT, + policy="Always", + backoff_duration=BACKOFF_DURATION, + ) ) diff --git a/metaflow/plugins/kfp/kfp_constants.py b/metaflow/plugins/kfp/kfp_constants.py index 31504d86762..a5b14062d8e 100644 --- a/metaflow/plugins/kfp/kfp_constants.py +++ b/metaflow/plugins/kfp/kfp_constants.py @@ -13,7 +13,9 @@ INPUT_PATHS_ENV_NAME = "INPUT_PATHS_ENV_NAME" RETRY_COUNT = "MF_ATTEMPT" S3_SENSOR_RETRY_COUNT = 7 +PVC_CREATE_RETRY_COUNT = 7 EXIT_HANDLER_RETRY_COUNT = 7 +BACKOFF_DURATION = "1m" # 1 minute STEP_ENVIRONMENT_VARIABLES = "/tmp/step-environment-variables.sh"