Skip to content

Commit

Permalink
Fix: Always propagate pytorch task worker process exception timestamp…
Browse files Browse the repository at this point in the history
… to task exception

Signed-off-by: Fabio Grätz <[email protected]>
  • Loading branch information
Fabio Grätz committed Jan 14, 2025
1 parent 34af2e2 commit 4044c95
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
5 changes: 3 additions & 2 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ class FlyteUserException(_FlyteException):
class FlyteUserRuntimeException(_FlyteException):
_ERROR_CODE = "USER:RuntimeError"

def __init__(self, exc_value: Exception):
def __init__(self, exc_value: Exception, timestamp: typing.Optional[float] = None):
"""
FlyteUserRuntimeException is thrown when a user code raises an exception.
:param exc_value: The exception that was raised from user code.
:param timestamp: The timestamp as fractional seconds since epoch when the exception was raised.
"""
self._exc_value = exc_value
super().__init__(str(exc_value))
super().__init__(str(exc_value), timestamp=timestamp)

Check warning on line 26 in flytekit/exceptions/user.py

View check run for this annotation

Codecov / codecov/patch

flytekit/exceptions/user.py#L26

Added line #L26 was not covered by tests

@property
def value(self):
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException
from flytekit.extend import IgnoreOutputs, TaskPlugins
from flytekit.loggers import logger

Expand Down Expand Up @@ -475,7 +475,7 @@ def fn_partial():
# the automatically assigned timestamp based on exception creation time
raise FlyteRecoverableException(e.format_msg(), timestamp=first_failure.timestamp)
else:
raise RuntimeError(e.format_msg())
raise FlyteUserRuntimeException(e, timestamp=first_failure.timestamp)
except SignalException as e:
logger.exception(f"Elastic launch agent process terminating: {e}")
raise IgnoreOutputs()
Expand Down
34 changes: 34 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,37 @@ def test_task_omp_set():
assert os.environ["OMP_NUM_THREADS"] == "42"

test_task_omp_set()


def test_exception_timestamp() -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
@task(
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise Exception("Test exception")

with pytest.raises(Exception) as e:
test_task()

assert e.value.timestamp is not None


def test_recoverable_exception_timestamp() -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
@task(
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise FlyteRecoverableException("Recoverable test exception")

with pytest.raises(Exception) as e:
test_task()

assert e.value.timestamp is not None

0 comments on commit 4044c95

Please sign in to comment.