diff --git a/Dockerfile.dev b/Dockerfile.dev index 1a839104e4..e6bd69740c 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -36,7 +36,7 @@ COPY . /flytekit RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \ SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEIDL=3.0.0dev0 \ uv pip install --system --no-cache-dir -U \ - "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" \ + "git+https://github.com/flyteorg/flyte.git@streaming-deck-v2#subdirectory=flyteidl" \ -e /flytekit \ -e /flytekit/plugins/flytekit-deck-standard \ -e /flytekit/plugins/flytekit-flyteinteractive \ diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 49103319d0..d41c95d6da 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -324,7 +324,7 @@ def _dispatch_execute( logger.info(f"Engine folder written successfully to the output prefix {output_prefix}") if task_def is not None and not getattr(task_def, "disable_deck", True): - _output_deck(task_def.name.split(".")[-1], ctx.user_space_params) + _output_deck(task_name=task_def.name.split(".")[-1], new_user_params=ctx.user_space_params) logger.debug("Finished _dispatch_execute") diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 6430aa9eac..866b9eb34f 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -42,6 +42,7 @@ from flyteidl.core import artifact_id_pb2 as art_id from flyteidl.core import tasks_pb2 +from google.protobuf.wrappers_pb2 import BoolValue from flytekit.configuration import LocalConfig, SerializationSettings from flytekit.core.artifact_utils import ( @@ -129,6 +130,7 @@ class TaskMetadata(object): timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task should be executed for. The execution will be terminated if the runtime exceeds the given timeout (approximately) + generates_deck (bool): Whether the task will generate a Deck URI. pod_template_name (Optional[str]): the name of existing PodTemplate resource in the cluster which will be used in this task. """ @@ -141,6 +143,7 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None + generates_deck: bool = False is_eager: bool = False def __post_init__(self): @@ -179,6 +182,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: discovery_version=self.cache_version, deprecated_error_message=self.deprecated, cache_serializable=self.cache_serialize, + generates_deck=BoolValue(value=self.generates_deck), pod_template_name=self.pod_template_name, cache_ignore_input_vars=self.cache_ignore_input_vars, is_eager=self.is_eager, @@ -720,8 +724,11 @@ def dispatch_execute( may be none * ``DynamicJobSpec`` is returned when a dynamic workflow is executed """ - if DeckField.TIMELINE.value in self.deck_fields and ctx.user_space_params is not None: - ctx.user_space_params.decks.append(ctx.user_space_params.timeline_deck) + if not self.disable_deck and ctx.user_space_params is not None: + ctx.user_space_params.builder().add_attr("ENABLE_DECK", True) + if DeckField.TIMELINE.value in self.deck_fields: + ctx.user_space_params.decks.append(ctx.user_space_params.timeline_deck) + # Invoked before the task is executed new_user_params = self.pre_execute(ctx.user_space_params) diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index 025306d47b..24b7729ec3 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -41,10 +41,6 @@ class Deck: scatter plots or Markdown text. In addition, users can create new decks to render their data with custom renderers. - .. warning:: - - This feature is in beta. - .. code-block:: python iris_df = px.data.iris() @@ -86,6 +82,12 @@ def name(self) -> str: def html(self) -> str: return self._html + @staticmethod + def publish(): + params = FlyteContextManager.current_context().user_space_params + task_name = params.task_id.name + _output_deck(task_name=task_name, new_user_params=params) + class TimeLineDeck(Deck): """ @@ -148,7 +150,8 @@ def generate_time_table(data: dict) -> str: def _get_deck( - new_user_params: ExecutionParameters, ignore_jupyter: bool = False + new_user_params: ExecutionParameters, + ignore_jupyter: bool = False, ) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ignore """ Get flyte deck html string @@ -176,11 +179,17 @@ def _get_deck( def _output_deck(task_name: str, new_user_params: ExecutionParameters): ctx = FlyteContext.current_context() + params = ctx.user_space_params + + if not params.has_attr("ENABLE_DECK") or not params.enable_deck: + logger.warning("Deck is disabled for this task, please don't call Deck.publish()") + return + local_dir = ctx.file_access.get_random_local_directory() local_path = f"{local_dir}{os.sep}{DECK_FILE_NAME}" try: with open(local_path, "w", encoding="utf-8") as f: - f.write(_get_deck(new_user_params, ignore_jupyter=True)) + f.write(_get_deck(new_user_params=new_user_params, ignore_jupyter=True)) logger.info(f"{task_name} task creates flyte deck html to file://{local_path}") if ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: fs = ctx.file_access.get_filesystem_for_path(new_user_params.output_metadata_prefix) @@ -197,6 +206,7 @@ def _output_deck(task_name: str, new_user_params: ExecutionParameters): def get_deck_template() -> Template: root = os.path.dirname(os.path.abspath(__file__)) templates_dir = os.path.join(root, "html", "template.html") + with open(templates_dir, "r") as f: template_content = f.read() return Template(template_content) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 960555fd9b..b440f86730 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -180,6 +180,7 @@ def __init__( pod_template_name, cache_ignore_input_vars, is_eager: bool = False, + generates_deck: bool = False, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -199,6 +200,7 @@ def __init__( receive deprecation warnings. :param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a single instance over identical inputs is executed, other concurrent executions wait for the cached results. + :param bool generates_deck: Whether the task will generate a Deck URI. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. :param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache. :param is_eager: @@ -214,6 +216,7 @@ def __init__( self._pod_template_name = pod_template_name self._cache_ignore_input_vars = cache_ignore_input_vars self._is_eager = is_eager + self._generates_deck = generates_deck @property def is_eager(self): @@ -295,6 +298,14 @@ def pod_template_name(self): """ return self._pod_template_name + @property + def generates_deck(self) -> bool: + """ + Whether the task will generate a Deck URI. + :rtype: bool + """ + return self._generates_deck + @property def cache_ignore_input_vars(self): """ @@ -315,6 +326,7 @@ def to_flyte_idl(self): discovery_version=self.discovery_version, deprecated_error_message=self.deprecated_error_message, cache_serializable=self.cache_serializable, + generates_deck=self.generates_deck, pod_template_name=self.pod_template_name, cache_ignore_input_vars=self.cache_ignore_input_vars, is_eager=self.is_eager, @@ -338,6 +350,7 @@ def from_flyte_idl(cls, pb2_object: _core_task.TaskMetadata): discovery_version=pb2_object.discovery_version, deprecated_error_message=pb2_object.deprecated_error_message, cache_serializable=pb2_object.cache_serializable, + generates_deck=pb2_object.generates_deck, pod_template_name=pb2_object.pod_template_name, cache_ignore_input_vars=pb2_object.cache_ignore_input_vars, is_eager=pb2_object.is_eager, diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 5c7a6d5eb4..bad5c27710 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -185,11 +185,13 @@ def get_serializable_task( entity.reset_command_fn() entity_config = entity.get_config(settings) or {} - extra_config = {} - if hasattr(entity, "task_function") and isinstance(entity.task_function, ClassDecorator): - extra_config = entity.task_function.get_extra_config() + if hasattr(entity, "task_function"): + if isinstance(entity.task_function, ClassDecorator): + extra_config = entity.task_function.get_extra_config() + if not entity.disable_deck: + entity.metadata.generates_deck = True merged_config = {**entity_config, **extra_config} diff --git a/tests/flytekit/unit/test_translator.py b/tests/flytekit/unit/test_translator.py index 24f2c14131..cafe9ffe1c 100644 --- a/tests/flytekit/unit/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -12,6 +12,7 @@ from flytekit.models.core import identifier as identifier_models from flytekit.models.task import Resources as resource_model from flytekit.tools.translator import get_serializable, Options +from google.protobuf.wrappers_pb2 import BoolValue default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -93,14 +94,41 @@ def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): def t2(a: str, b: str) -> str: return b + a - ssettings = ( + settings = ( serialization_settings.new_builder() .with_fast_serialization_settings(FastSerializationSettings(enabled=True)) .build() ) - task_spec = get_serializable(OrderedDict(), ssettings, t1) + task_spec = get_serializable(OrderedDict(), settings, t1) assert "pyflyte-fast-execute" in task_spec.template.container.args +def test_deck(): + from flytekit.deck import Deck + + @task(enable_deck=False) + def t_no_deck(): + pass + + @task(enable_deck=True) + def t_deck(): + Deck.publish() + + deck_settings = ( + serialization_settings.new_builder() + .with_fast_serialization_settings(FastSerializationSettings(enabled=True)) + .build() + ) + deck_task_spec = get_serializable(OrderedDict(), deck_settings, t_deck) + assert deck_task_spec.template.metadata.generates_deck == BoolValue(value=True) + + no_deck_settings = ( + serialization_settings.new_builder() + .with_fast_serialization_settings(FastSerializationSettings(enabled=True)) + .build() + ) + no_deck_task_spec = get_serializable(OrderedDict(), no_deck_settings, t_no_deck) + assert no_deck_task_spec.template.metadata.generates_deck == BoolValue(value=False) + def test_container(): @task