From 4044c95fd28ccb2139151fe425587ee946a4dc1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Tue, 14 Jan 2025 21:01:10 +0100 Subject: [PATCH] Fix: Always propagate pytorch task worker process exception timestamp to task exception MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- flytekit/exceptions/user.py | 5 +-- .../flytekitplugins/kfpytorch/task.py | 4 +-- .../tests/test_elastic_task.py | 34 +++++++++++++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index 3413d172ff..05a783d034 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -15,14 +15,15 @@ class FlyteUserException(_FlyteException): class FlyteUserRuntimeException(_FlyteException): _ERROR_CODE = "USER:RuntimeError" - def __init__(self, exc_value: Exception): + def __init__(self, exc_value: Exception, timestamp: typing.Optional[float] = None): """ FlyteUserRuntimeException is thrown when a user code raises an exception. :param exc_value: The exception that was raised from user code. + :param timestamp: The timestamp as fractional seconds since epoch when the exception was raised. 
""" self._exc_value = exc_value - super().__init__(str(exc_value)) + super().__init__(str(exc_value), timestamp=timestamp) @property def value(self): diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index a951bea0a5..4bbcb814a4 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -18,7 +18,7 @@ from flytekit.core.context_manager import FlyteContextManager, OutputMetadata from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import convert_resources_to_resource_model -from flytekit.exceptions.user import FlyteRecoverableException +from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException from flytekit.extend import IgnoreOutputs, TaskPlugins from flytekit.loggers import logger @@ -475,7 +475,7 @@ def fn_partial(): # the automatically assigned timestamp based on exception creation time raise FlyteRecoverableException(e.format_msg(), timestamp=first_failure.timestamp) else: - raise RuntimeError(e.format_msg()) + raise FlyteUserRuntimeException(e, timestamp=first_failure.timestamp) except SignalException as e: logger.exception(f"Elastic launch agent process terminating: {e}") raise IgnoreOutputs() diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index faadc1019f..7fa921aaef 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -276,3 +276,37 @@ def test_task_omp_set(): assert os.environ["OMP_NUM_THREADS"] == "42" test_task_omp_set() + + +def test_exception_timestamp() -> None: + """Test that the timestamp of the worker process exception is propagated to the task exception.""" + @task( + task_config=Elastic( + nnodes=1, + nproc_per_node=2, + ) + ) + def test_task(): + raise 
Exception("Test exception") + + with pytest.raises(Exception) as e: + test_task() + + assert e.value.timestamp is not None + + +def test_recoverable_exception_timestamp() -> None: + """Test that the timestamp of a recoverable worker process exception is propagated to the task exception.""" + @task( + task_config=Elastic( + nnodes=1, + nproc_per_node=2, + ) + ) + def test_task(): + raise FlyteRecoverableException("Recoverable test exception") + + with pytest.raises(Exception) as e: + test_task() + + assert e.value.timestamp is not None