diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 49103319d0..cc0f7bafbe 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -266,7 +266,7 @@ def _dispatch_execute( exc_str = get_traceback_str(e) output_file_dict[error_file_name] = _error_models.ErrorDocument( _error_models.ContainerError( - code="USER", + code=e.error_code, message=exc_str, kind=kind, origin=_execution_models.ExecutionError.ErrorKind.USER, diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index 3413d172ff..af4dbf63c6 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -28,6 +28,10 @@ def __init__(self, exc_value: Exception): def value(self): return self._exc_value + @property + def error_code(self): + return self._ERROR_CODE + class FlyteTypeException(FlyteUserException, TypeError): _ERROR_CODE = "USER:TypeError" diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 323770bed7..8a1709d668 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -169,6 +169,7 @@ def verify_output(*args, **kwargs): assert error_filename_base.startswith("error-") uuid.UUID(hex=error_filename_base[6:], version=4) assert error_filename_ext == ".pb" + assert container_error.code == "USER:RuntimeError" mock_write_to_file.side_effect = verify_output _dispatch_execute(ctx, lambda: python_task, "inputs path", "outputs prefix") @@ -991,3 +992,35 @@ def t1(a: typing.List[int]) -> typing.List[typing.List[str]]: assert lit.literals["o0"].HasField("offloaded_metadata") == False else: assert False, f"Unexpected file {ff}" + + +@mock.patch("flytekit.core.utils.load_proto_from_file") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.utils.write_proto_to_file") +def test_dispatch_execute_custom_error_code_with_flyte_user_runtime_exception(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto): + class CustomException(FlyteUserRuntimeException): + _ERROR_CODE = "CUSTOM_ERROR_CODE" + + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + python_task = mock.MagicMock() + python_task.dispatch_execute.side_effect = CustomException("custom error") + + empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl() + mock_load_proto.return_value = empty_literal_map + + def verify_output(*args, **kwargs): + assert isinstance(args[0], ErrorDocument) + assert args[0].error.code == "CUSTOM_ERROR_CODE" + + mock_write_to_file.side_effect = verify_output + _dispatch_execute(ctx, lambda: python_task, "inputs path", "outputs prefix") + assert mock_write_to_file.call_count == 1 diff --git a/tests/flytekit/unit/exceptions/test_user.py b/tests/flytekit/unit/exceptions/test_user.py index fedacbebc6..15ae795250 100644 --- a/tests/flytekit/unit/exceptions/test_user.py +++ b/tests/flytekit/unit/exceptions/test_user.py @@ -11,6 +11,16 @@ def test_flyte_user_exception(): assert type(e).error_code == "USER:Unknown" assert isinstance(e, base.FlyteException) +def test_flyte_user_runtime_exception(): + try: + base_exn = Exception("everywhere is bad") + raise user.FlyteUserRuntimeException("bad") from base_exn + except Exception as e: + assert str(e) == "USER:RuntimeError: error=bad, cause=everywhere is bad" + assert isinstance(type(e), base._FlyteCodedExceptionMetaclass) + assert type(e).error_code == "USER:RuntimeError" + assert isinstance(e, base.FlyteException) + assert isinstance(e, user.FlyteUserRuntimeException) def test_flyte_type_exception(): try: