Skip to content

Commit

Permalink
feat: Support user-defined batch script content with SlurmShellTask
Browse files Browse the repository at this point in the history
`SlurmTask` and `SlurmShellTask` now share the same agent.

Signed-off-by: JiaWei Jiang <[email protected]>
  • Loading branch information
JiangJiaWei1103 committed Jan 14, 2025
1 parent 26cc201 commit 16d953e
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 5 deletions.
8 changes: 6 additions & 2 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,11 @@ def __init__(

if task_config is not None:
fully_qualified_class_name = task_config.__module__ + "." + task_config.__class__.__name__
if fully_qualified_class_name not in ["flytekitplugins.pod.task.Pod", "flytekitplugins.slurm.task.Slurm"]:
if fully_qualified_class_name not in [
"flytekitplugins.pod.task.Pod",
"flytekitplugins.slurm.task.Slurm",
"flytekitplugins.slurm.task.SlurmShell",
]:
raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.")

# Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
Expand All @@ -259,7 +263,7 @@ def __init__(
# errors.
# This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work.
plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
if plugin_class.__name__ == "SlurmTask":
if plugin_class.__name__ in ["SlurmTask", "SlurmShellTask"]:
self._config_task_instance = None
else:
self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .agent import SlurmAgent
from .task import Slurm, SlurmTask
from .task import Slurm, SlurmShell, SlurmShellTask, SlurmTask
28 changes: 26 additions & 2 deletions plugins/flytekit-slurm/flytekitplugins/slurm/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import tempfile
from dataclasses import dataclass
from typing import Dict, List, Optional

Expand Down Expand Up @@ -29,6 +30,12 @@ class SlurmAgent(AsyncAgentBase):
# _ssh_clients: Dict[str, SSHClientConnection]
_conn: Optional[SSHClientConnection] = None

# Tmp remote path of the batch script
REMOTE_PATH = "/tmp/echo_shell.slurm"

# Dummy script content
DUMMY_SCRIPT = "#!/bin/bash"

def __init__(self) -> None:
super(SlurmAgent, self).__init__(task_type_name="slurm", metadata_type=SlurmJobMetadata)

Expand All @@ -40,16 +47,33 @@ async def create(
) -> SlurmJobMetadata:
# Retrieve task config
slurm_host = task_template.custom["slurm_host"]
batch_script_path = task_template.custom["batch_script_path"]
batch_script_args = task_template.custom["batch_script_args"]
sbatch_conf = task_template.custom["sbatch_conf"]

# Construct sbatch command for Slurm cluster
cmd = _get_sbatch_cmd(sbatch_conf=sbatch_conf, batch_script_path=batch_script_path, batch_script_args=batch_script_args)
upload_script = False
if "script" in task_template.custom:
script = task_template.custom["script"]
assert script != self.DUMMY_SCRIPT, "Please write the user-defined batch script content."

batch_script_path = self.REMOTE_PATH
upload_script = True
else:
# Assume the batch script is already on Slurm
batch_script_path = task_template.custom["batch_script_path"]
cmd = _get_sbatch_cmd(
sbatch_conf=sbatch_conf, batch_script_path=batch_script_path, batch_script_args=batch_script_args
)

# Run Slurm job
if self._conn is None:
await self._connect(slurm_host)
if upload_script:
with tempfile.NamedTemporaryFile("w") as f:
f.write(script)
f.flush()
async with self._conn.start_sftp_client() as sftp:
await sftp.put(f.name, self.REMOTE_PATH)
res = await self._conn.run(cmd, check=True)

# Retrieve Slurm job id
Expand Down
41 changes: 41 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ def __post_init__(self):
self.sbatch_conf = {}


@dataclass
class SlurmShell(object):
"""Encounter collision if Slurm is shared btw SlurmTask and SlurmShellTask."""

slurm_host: str
batch_script_args: Optional[List[str]] = None
sbatch_conf: Optional[Dict[str, str]] = None

def __post_init__(self):
if self.sbatch_conf is None:
self.sbatch_conf = {}


class SlurmTask(AsyncAgentExecutorMixin, ShellTask[Slurm]):
"""
Actual Plugin that transforms the local python code for execution within a slurm context...
Expand Down Expand Up @@ -66,4 +79,32 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
}


class SlurmShellTask(AsyncAgentExecutorMixin, ShellTask[Slurm]):
_TASK_TYPE = "slurm"

def __init__(
self,
name: str,
task_config: SlurmShell,
script: Optional[str] = None,
**kwargs,
):
super(SlurmShellTask, self).__init__(
name,
task_config=task_config,
task_type=self._TASK_TYPE,
script=script,
**kwargs,
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
return {
"slurm_host": self.task_config.slurm_host,
"batch_script_args": self.task_config.batch_script_args,
"sbatch_conf": self.task_config.sbatch_conf,
"script": self._script,
}


TaskPlugins.register_pythontask_plugin(Slurm, SlurmTask)
TaskPlugins.register_pythontask_plugin(SlurmShell, SlurmShellTask)

0 comments on commit 16d953e

Please sign in to comment.