Skip to content

Commit

Permalink
Make FlyteUserRuntimeException to return error_code in Container Error (
Browse files Browse the repository at this point in the history
#3059)

* Make FlyteUserRuntimeException to return error_code in the ContainerError

Signed-off-by: Rafael Ribeiro Raposo <[email protected]>
  • Loading branch information
RRap0so authored Jan 16, 2025
1 parent 3b2b573 commit 4b50681
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 1 deletion.
2 changes: 1 addition & 1 deletion flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _dispatch_execute(
exc_str = get_traceback_str(e)
output_file_dict[error_file_name] = _error_models.ErrorDocument(
_error_models.ContainerError(
code="USER",
code=e.error_code,
message=exc_str,
kind=kind,
origin=_execution_models.ExecutionError.ErrorKind.USER,
Expand Down
4 changes: 4 additions & 0 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def __init__(self, exc_value: Exception):
def value(self):
return self._exc_value

@property
def error_code(self):
return self._ERROR_CODE


class FlyteTypeException(FlyteUserException, TypeError):
_ERROR_CODE = "USER:TypeError"
Expand Down
33 changes: 33 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def verify_output(*args, **kwargs):
assert error_filename_base.startswith("error-")
uuid.UUID(hex=error_filename_base[6:], version=4)
assert error_filename_ext == ".pb"
assert container_error.code == "USER:RuntimeError"

mock_write_to_file.side_effect = verify_output
_dispatch_execute(ctx, lambda: python_task, "inputs path", "outputs prefix")
Expand Down Expand Up @@ -991,3 +992,35 @@ def t1(a: typing.List[int]) -> typing.List[typing.List[str]]:
assert lit.literals["o0"].HasField("offloaded_metadata") == False
else:
assert False, f"Unexpected file {ff}"


@mock.patch("flytekit.core.utils.load_proto_from_file")
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data")
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
@mock.patch("flytekit.core.utils.write_proto_to_file")
def test_dispatch_execute_custom_error_code_with_flyte_user_runtime_exception(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
class CustomException(FlyteUserRuntimeException):
_ERROR_CODE = "CUSTOM_ERROR_CODE"

mock_get_data.return_value = True
mock_upload_dir.return_value = True

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
)
) as ctx:
python_task = mock.MagicMock()
python_task.dispatch_execute.side_effect = CustomException("custom error")

empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl()
mock_load_proto.return_value = empty_literal_map

def verify_output(*args, **kwargs):
assert isinstance(args[0], ErrorDocument)
assert args[0].error.code == "CUSTOM_ERROR_CODE"

mock_write_to_file.side_effect = verify_output
_dispatch_execute(ctx, lambda: python_task, "inputs path", "outputs prefix")
assert mock_write_to_file.call_count == 1
10 changes: 10 additions & 0 deletions tests/flytekit/unit/exceptions/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ def test_flyte_user_exception():
assert type(e).error_code == "USER:Unknown"
assert isinstance(e, base.FlyteException)

def test_flyte_user_runtime_exception():
try:
base_exn = Exception("everywhere is bad")
raise user.FlyteUserRuntimeException("bad") from base_exn
except Exception as e:
assert str(e) == "USER:RuntimeError: error=bad, cause=everywhere is bad"
assert isinstance(type(e), base._FlyteCodedExceptionMetaclass)
assert type(e).error_code == "USER:RuntimeError"
assert isinstance(e, base.FlyteException)
assert isinstance(e, user.FlyteUserRuntimeException)

def test_flyte_type_exception():
try:
Expand Down

0 comments on commit 4b50681

Please sign in to comment.