From 320b175d08da7c0c79abd3e21e86cfa8b03040c8 Mon Sep 17 00:00:00 2001
From: lorenzocerrone
Date: Mon, 14 Oct 2024 22:58:23 +0200
Subject: [PATCH 1/6] feat: refactor workflow tracker

---
 plantseg/tasks/io_tasks.py         | 41 ++++++++--------
 plantseg/tasks/workflow_handler.py | 79 +++++++++++++++++++-----------
 2 files changed, 72 insertions(+), 48 deletions(-)

diff --git a/plantseg/tasks/io_tasks.py b/plantseg/tasks/io_tasks.py
index d8097ab3..ef3b016c 100644
--- a/plantseg/tasks/io_tasks.py
+++ b/plantseg/tasks/io_tasks.py
@@ -2,24 +2,20 @@
 from plantseg.core.image import PlantSegImage, import_image, save_image
 from plantseg.tasks import task_tracker
-from plantseg.tasks.workflow_handler import TaskUserInput
+from plantseg.tasks.workflow_handler import RunTimeInputSchema


 @task_tracker(
     is_root=True,
-    list_private_params=["semantic_type", "stack_layout"],
     list_inputs={
-        "input_path": TaskUserInput(
-            allowed_types=['str', 'list[str]'],
+        "input_path": RunTimeInputSchema(
             description="Path to a file, a directory containing files (all files will be imported), or a list of paths.",
-            headless_default=None,
-            user_input_required=True,
+            required=True,
+            is_input=True,
         ),
-        "image_name": TaskUserInput(
-            allowed_types=['None', 'str'],
+        "image_name": RunTimeInputSchema(
             description="Name of the image (if None, the file name will be used)",
-            headless_default=None,
-            user_input_required=False,
+            required=False,
         ),
     },
 )
@@ -38,8 +34,8 @@ def import_image_task(
         input_path (Path): path to the image file
         semantic_type (str): semantic type of the image (raw, segmentation, prediction)
         stack_layout (str): stack layout of the image (3D, 2D, 2D_time)
         image_name (str | None): name of the image (if None, the file name will be used)
         key (str | None): key for the image (used only for h5 and zarr formats)
         m_slicing (str | None): m_slicing of the image (None, time, z, y, x)
     """

@@ -59,15 +65,10 @@ def import_image_task(

 @task_tracker(
     is_leaf=True,
     list_inputs={
-        "export_directory": TaskUserInput(
-            allowed_types=['str'],
-            description="Output directory path where the image will be saved",
-            headless_default=None,
-            user_input_required=True,
-        ),
-        "name_pattern": TaskUserInput(
-            allowed_types=['str'], description="Output file name", headless_default=None, user_input_required=False
+        "export_directory": RunTimeInputSchema(
+            description="Output directory path where the image will be saved", required=True
         ),
+        "name_pattern": RunTimeInputSchema(description="Output file name", required=False),
     },
 )
 def export_image_task(
diff --git a/plantseg/tasks/workflow_handler.py b/plantseg/tasks/workflow_handler.py
index 56aeb73b..1a083f73 100644
--- a/plantseg/tasks/workflow_handler.py
+++ b/plantseg/tasks/workflow_handler.py
@@ -1,5 +1,7 @@
 import json
+from datetime import datetime
 from enum import Enum
+from inspect import signature
 from pathlib import Path
 from typing import Any, Callable
 from uuid import UUID, uuid4
@@ -17,23 +19,45 @@ class NodeType(str, Enum):
     LEAF = "leaf"


-class TaskUserInput(BaseModel):
-    value: Any = None
-    allowed_types: list[str] = Field(default_factory=list)
+class RunTimeInputSchema(BaseModel):
+    description: str | None = None
+    task: str | None = None
+    required: bool = True
+    default: Any = None
+    is_input_file: bool = False
+
+
+class Infos(BaseModel):
+    """
+    Information about the workflow.
+
+    Attributes:
+        description (str): An optional description of the workflow. This is only used for documentation purposes.
+        creation_date (str): The date and time when the workflow was created. This is automatically generated.
+        version (str): The version of PlantSeg used to create the workflow.
+        inputs_schema (dict[str, RunTimeInputSchema]): A dictionary with the schema of the inputs of the workflow.
+        instructions (str): Instructions on how to customize the workflow file.
+    """
+
     description: str = "No description provided"
-    headless_default: Any = None
-    task: str = ""
-    user_input_required: bool = True
+    creation_date: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
+    version: str = __version__
+    inputs_schema: dict[str, RunTimeInputSchema] = Field(default_factory=dict)
+    instructions: str = (
+        "This configuration file was generated by PlantSeg and can be used to run a headless workflow. "
+        "To run the workflow, you will need to customize this file with the correct paths and parameters. "
+        "To read detailed instructions on how to do this, please visit the PlantSeg documentation (https://kreshuklab.github.io/plant-seg/)."
+    )


 class Task(BaseModel):
     func: str
     images_inputs: dict
     parameters: dict
-    list_private_parameters: list[str]
     outputs: list[str]
     node_type: NodeType
     id: UUID = Field(default_factory=uuid4)
+    skip: bool = False

     """
     A task is a single operation in the workflow. It is defined by:
@@ -43,16 +67,17 @@ class Task(BaseModel):
         images_inputs (dict): An image input represents an Image object. The key is the name of the parameter
             in the function, and the value is the name of the image.
         parameters (dict): The kwargs parameters of the workflow function.
-        list_private_parameters (list[str]): A list of the names of the private parameters.
         outputs (list[str]): A list of the names of the output images.
         node_type (NodeType): The type of the node in the workflow (ROOT, LEAF, NODE)
         id (UUID): A unique identifier for the task.
+        skip (bool): If True, the task will be skipped during the execution of the workflow.
+            This is useful to remove a task that requires manual intervention (like a cropping task from the GUI).
     """


 class DAG(BaseModel):
-    plantseg_version: str = Field(default=__version__)
+    infos: Infos = Field(default_factory=Infos)
     inputs: dict[str, Any] = Field(default_factory=dict)
     list_tasks: list[Task] = Field(default_factory=list)

     """
     This model represents the Directed Acyclic Graph (DAG) of the workflow.

     Attributes:
-        plantseg_version (str): The version of PlantSeg used to create the workflow.
+        infos (Infos): Metadata describing the workflow (description, creation date, version, inputs schema).
         inputs (dict[str, Any]): A dictionary of the inputs of the workflow. For example, paths to the images and other runtime parameters.
         list_tasks (list[Task]): A list of the tasks in the workflow.
@@ -114,6 +139,7 @@ def prune_dag(dag: DAG) -> DAG:
     for input_key, text in dag.inputs.items():
         if input_key in reachable_inputs:
             new_dag.inputs[input_key] = text
+            new_dag.infos.inputs_schema[input_key] = dag.infos.inputs_schema[input_key]

     return new_dag


@@ -174,7 +200,6 @@ def add_task(
         func: Callable,
         images_inputs: dict,
         parameters: dict,
-        list_private_parameters: list[str],
         outputs: list[str],
         node_type: NodeType,
     ):
@@ -186,7 +211,6 @@ def add_task(
             images_inputs (dict): A dictionary of the image inputs. The key is the name of the parameter
                 in the function, and the value is the unique_name of the image.
             parameters (dict): The kwargs parameters of the workflow function.
-            list_private_parameters (list[str]): A list of the names of the private parameters.
             outputs (list[str]): A list of the names of the output images.
             node_type (NodeType): The type of the node in the workflow (ROOT, LEAF, NODE)

@@ -197,13 +221,12 @@ def add_task(
             func=func.__name__,
             images_inputs=images_inputs,
             parameters=parameters,
-            list_private_parameters=list_private_parameters,
             outputs=outputs,
             node_type=node_type,
         )
         self._dag.list_tasks.append(task)

-    def add_input(self, name: str, value=TaskUserInput, func_name: str | None = None):
+    def add_input(self, name: str, value: Any, value_schema: RunTimeInputSchema, func_name: str | None = None):
         def _unique_input(name, id: int = 0):
             new_name = f"{name}_{id}"
             if new_name not in self._dag.list_inputs:
@@ -216,7 +239,9 @@ def _unique_input(name, id: int = 0):
         else:
             unique_name = _unique_input(name)

-        value.task = func_name
+        value_schema.task = func_name
+        self._dag.infos.inputs_schema[unique_name] = value_schema
+
         self._dag.inputs[unique_name] = value
         return unique_name

@@ -231,7 +256,7 @@ def prune_dag(self) -> DAG:

     def save_to_yaml(self, path: Path | str):
         clean_dag = self.prune_dag()
-        dag_dict = json.loads(clean_dag.model_dump_json())
+        dag_dict = json.loads(clean_dag.model_dump_json(exclude_none=True))

         if isinstance(path, str):
             path = Path(path)
@@ -256,8 +281,7 @@ def task_tracker(
     func: Callable | None = None,
     is_root=False,
     is_leaf=False,
-    list_inputs: dict[str, TaskUserInput] | None = None,
-    list_private_params: list[str] | None = None,
+    list_inputs: dict[str, RunTimeInputSchema] | None = None,
 ):
     """
     Decorator to register a function as a task in the workflow.

         is_root (bool): If True, the function is a root node in the workflow (usually an import task).
         is_leaf (bool): If True, the function is a leaf node in the workflow (usually a writer task).
-        list_inputs (dict[str, TaskUserInput]): A dictionary of the inputs of the function. The key is the name of the parameter
+        list_inputs (dict[str, RunTimeInputSchema]): A dictionary of the inputs of the function. The key is the name of the parameter
-        list_private_params (list[str]): A list of the names of the private parameters. If a gui will
-            be used to run the workflow, these parameters should not be exposed to the user.
""" if is_root and is_leaf: @@ -282,33 +304,34 @@ def task_tracker( node_type = NodeType.NODE list_inputs = list_inputs or {} - list_private_params = list_private_params or [] def _inner_decorator(func): workflow_handler.register_func(func) def wrapper(*args, **kwargs): assert len(args) == 0, "Workflow functions should not have positional arguments" + func_signature = signature(func) + parameters = {param: func_signature.parameters[param].default for param in func_signature.parameters} images_inputs = {} - parameters = {} for name, arg in kwargs.items(): if isinstance(arg, PlantSegImage): images_inputs[name] = arg.unique_name + parameters.pop(name) elif name in list_inputs.keys(): value = list_inputs[name] - input_name = workflow_handler.add_input(name, value=value, func_name=func.__name__) + input_name = workflow_handler.add_input( + name, value=arg, value_schema=value, func_name=func.__name__ + ) images_inputs[name] = input_name + parameters.pop(name) else: + # Replace the default value with the provided value parameters[name] = arg - for private_param in list_private_params: - if private_param not in parameters: - raise ValueError(f"Private parameter {private_param} not found in the function parameters") - # Execute the function out_image = func(*args, **kwargs) @@ -336,11 +359,11 @@ def wrapper(*args, **kwargs): f"Function {func.__name__} has an output image with the same name as an input image: {name}" ) + # Add the task to the workflow workflow_handler.add_task( func=func, images_inputs=images_inputs, parameters=parameters, - list_private_parameters=list_private_params, outputs=list_outputs, node_type=node_type, ) From 10f6b4fb105e692b049c79c9c7b9151006f4f898 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Mon, 14 Oct 2024 23:27:57 +0200 Subject: [PATCH 2/6] feat: cleanup headless --- plantseg/headless/__init__.py | 12 +++ plantseg/headless/headless.py | 167 +++++++++++++++-------------- plantseg/tasks/io_tasks.py | 2 +- plantseg/tasks/workflow_handler.py | 4 +- tests/headless/test_headless.py | 17 ++- 5 files changed, 113 insertions(+), 89 deletions(-) diff --git a/plantseg/headless/__init__.py b/plantseg/headless/__init__.py index e69de29b..7ab8276f 100644 --- a/plantseg/headless/__init__.py +++ b/plantseg/headless/__init__.py @@ -0,0 +1,12 @@ +"""This module contains the headless workflow for PlantSeg. + +To build a headless workflow, you can: + - Register a new workflow manually using the plantseg API. + - Run a workflow from the napari viewer and export it as a configuration file. + +The headless workflow configured can be run using the `run_headless_workflow` function. 
+""" + +from plantseg.headless.headless import run_headles_workflow_from_config, run_headless_workflow + +__all__ = ["run_headles_workflow_from_config", "run_headless_workflow"] diff --git a/plantseg/headless/headless.py b/plantseg/headless/headless.py index 8352e356..27111e18 100644 --- a/plantseg/headless/headless.py +++ b/plantseg/headless/headless.py @@ -1,11 +1,12 @@ import logging from pathlib import Path -from typing import Literal +from typing import Any, Literal import yaml from plantseg.headless.basic_runner import SerialRunner -from plantseg.tasks.workflow_handler import TaskUserInput +from plantseg.io import allowed_data_format +from plantseg.tasks.workflow_handler import RunTimeInputSchema logger = logging.getLogger(__name__) @@ -14,102 +15,115 @@ _implemented_runners = {'serial': SerialRunner} -def parse_input_path(user_input: TaskUserInput): - value = user_input.value - if value is None: - raise ValueError("Input path must be provided.") +def validate_config(config: dict): + if "inputs" not in config: + raise ValueError("The workflow configuration does not contain an 'inputs' section.") - elif isinstance(value, str): - path = Path(value) - if not path.exists(): - raise FileNotFoundError(f"Input path {path} does not exist.") + if "infos" not in config: + raise ValueError("The workflow configuration does not contain an 'infos' section.") - return [path] + if "list_tasks" not in config: + raise ValueError("The workflow configuration does not contain an 'list_tasks' section.") - elif isinstance(value, list): - paths = [Path(p) for p in value] - for path in paths: - if not path.exists(): - raise FileNotFoundError(f"Input path {path} does not exist.") + if "runner" not in config: + logger.warning( + "The workflow configuration does not contain a 'runner' section. Using the default serial runner." + ) + config["runner"] = "serial" - return paths + return config - else: - raise ValueError("Input path must be a string or a list of strings.") +def parse_import_image_task(input_path, allow_dir: bool) -> list[Path]: + if isinstance(input_path, str): + input_path = Path(input_path) -def output_directory(user_input: TaskUserInput): - user_input = user_input.value - if user_input is None: - raise ValueError("Output directory must be provided.") + if not input_path.exists(): + raise FileNotFoundError(f"File {input_path} does not exist.") - if not isinstance(user_input, str): - raise ValueError("Output directory must be a string.") + if input_path.is_file(): + list_files = [input_path] + elif input_path.is_dir(): + if not allow_dir: + raise ValueError(f"Directory {input_path} is not allowed when multiple input files are expected.") - output_dir = Path(user_input) + list_files = list(input_path.glob("*")) + else: + raise ValueError(f"Path {input_path} is not a file or a directory.") - if not output_dir.exists(): - output_dir.mkdir(parents=True) - return output_dir + list_files = [f for f in list_files if f.suffix in allowed_data_format] + if not list_files: + raise ValueError(f"No valid files found in {input_path}.") + return list_files -def parse_generic_input(user_input: TaskUserInput): - value = user_input.value - if value is None: - value = user_input.headless_default +def collect_jobs_list(inputs: dict | list[dict], inputs_schema: dict[str, RunTimeInputSchema]) -> list[dict[str, Any]]: + """ + Parse the inputs and create a list of jobs to run. + """ - if value is None and user_input.user_input_required: - raise ValueError(f"Input must be provided. 
{user_input}") + if isinstance(inputs, dict): + inputs = [inputs] - return value + num_is_input_file = sum([1 for schema in inputs_schema.values() if schema.is_input_file]) + if num_is_input_file == 0: + raise ValueError("No input files found in the inputs schema. The workflow cannot run.") + elif num_is_input_file > 1: + allow_dir = False + else: + allow_dir = True -def parse_input_config(inputs_config: dict): - inputs_config = {k: TaskUserInput(**v) for k, v in inputs_config.items()} + all_jobs = [] + for input_dict in inputs: + if not isinstance(input_dict, dict): + raise ValueError(f"Input {input_dict} should be a dictionary.") - list_input_keys = {} - single_input_keys = {} - has_input_path = False - has_output_dir = False + inputs_files = {} + for name, schema in inputs_schema.items(): + if schema.is_input_file: + if name not in inputs_files: + inputs_files[name] = [] - for key, value in inputs_config.items(): - if key.find("input_path") != -1: - list_input_keys[key] = parse_input_path(value) - has_input_path = True + inputs_files[name].extend(parse_import_image_task(input_dict[name], allow_dir=allow_dir)) - elif key.find("export_directory") != -1: - single_input_keys[key] = output_directory(value) - has_output_dir = True + list_len = [len(files) for files in inputs_files.values()] + if len(set(list_len)) != 1: + raise ValueError(f"Inputs have different number of files. found {inputs_files}") - else: - single_input_keys[key] = parse_generic_input(value) + list_files = list(zip(*inputs_files.values())) + list_keys = list(inputs_files.keys()) + list_jobs = [dict(zip(list_keys, files)) for files in list_files] - if not has_input_path: - raise ValueError("The provided workflow configuration does not contain an input path.") + for job in list_jobs: + for key, value in input_dict.items(): + if key not in job: + job[key] = value + all_jobs.append(job) + return all_jobs - if not has_output_dir: - raise ValueError("The provided workflow configuration does not contain an output directory.") - all_length = [len(v) for v in list_input_keys.values()] - # check if all input paths have the same length - if not all([_l == all_length[0] for _l in all_length]): - raise ValueError("All input paths must have the same length.") +def run_headles_workflow_from_config(config: dict, path: str | Path = None): + config = validate_config(config) - num_inputs = all_length[0] + inputs = config["inputs"] + inputs_schema = config["infos"]["inputs_schema"] + inputs_schema = {k: RunTimeInputSchema(**v) for k, v in inputs_schema.items()} - jobs_inputs = [] - for i in range(num_inputs): - job_input = {} - for key in list_input_keys: - job_input[key] = list_input_keys[key][i] + jobs_list = collect_jobs_list(inputs, inputs_schema) - for key in single_input_keys: - job_input[key] = single_input_keys[key] + runner = config.get("runner") + if runner not in _implemented_runners: + raise ValueError(f"Runner {runner} is not implemented.") + + runner = _implemented_runners[runner](path) - jobs_inputs.append(job_input) + for job_input in jobs_list: + logger.info(f"Submitting job with input: {job_input}") + runner.submit_job(job_input) - return jobs_inputs + logger.info("All jobs have been submitted. 
Running the workflow...") def run_headless_workflow(path: str | Path): @@ -127,17 +141,4 @@ def run_headless_workflow(path: str | Path): with path.open("r") as file: config = yaml.safe_load(file) - job_inputs = parse_input_config(config["inputs"]) - - runner = config.get("runner", "serial") - if runner not in _implemented_runners: - raise ValueError(f"Runner {runner} is not implemented.") - - runner = _implemented_runners[runner](path) - - for job_input in job_inputs: - logger.info(f"Submitting job with input: {job_input}") - runner.submit_job(job_input) - - logger.info("All jobs have been submitted. Running the workflow...") - # TODO: When parallel runners are implemented, hew we need to add something to wait for all jobs to finish + run_headles_workflow_from_config(config, path=path) diff --git a/plantseg/tasks/io_tasks.py b/plantseg/tasks/io_tasks.py index ef3b016c..a3c503c9 100644 --- a/plantseg/tasks/io_tasks.py +++ b/plantseg/tasks/io_tasks.py @@ -11,7 +11,7 @@ "input_path": RunTimeInputSchema( description="Path to a file, or a directory containing files (all files will be imported) or list of paths.", required=True, - is_input=True, + is_input_file=True, ), "image_name": RunTimeInputSchema( description="Name of the image (if None, the file name will be used)", diff --git a/plantseg/tasks/workflow_handler.py b/plantseg/tasks/workflow_handler.py index 1a083f73..964bf624 100644 --- a/plantseg/tasks/workflow_handler.py +++ b/plantseg/tasks/workflow_handler.py @@ -78,7 +78,7 @@ class Task(BaseModel): class DAG(BaseModel): infos: Infos = Field(default_factory=Infos) - inputs: dict[str, Any] = Field(default_factory=dict) + inputs: dict[str, Any] | list[dict[str, Any]] = Field(default_factory=dict) list_tasks: list[Task] = Field(default_factory=list) """ @@ -93,7 +93,7 @@ class DAG(BaseModel): @property def list_inputs(self): - return list(self.inputs.keys()) + return list(self.infos.inputs_schema.keys()) def prune_dag(dag: DAG) -> DAG: diff --git a/tests/headless/test_headless.py b/tests/headless/test_headless.py index d0803e2b..62691f99 100644 --- a/tests/headless/test_headless.py +++ b/tests/headless/test_headless.py @@ -51,9 +51,20 @@ def test_create_workflow(tmp_path): with open(tmp_path / 'workflow.yaml', 'r') as file: config = yaml.safe_load(file) - config['inputs']['input_path']['value'] = [str(path_tiff_1), str(path_tiff_2)] - config['inputs']['export_directory']['value'] = str(tmp_path / 'output') - config['inputs']['name_pattern']['value'] = '{original_name}_export' + job_list = [ + { + "input_path": str(path_tiff_1), + "export_directory": str(tmp_path / 'output'), + "name_pattern": "{original_name}_export", + }, + { + "input_path": str(path_tiff_2), + "export_directory": str(tmp_path / 'output'), + "name_pattern": "{original_name}_export", + }, + ] + + config['inputs'] = job_list with open(tmp_path / 'workflow.yaml', 'w') as file: yaml.dump(config, file) From c14d146efaf8b8f06e85f3e924caddbb389c5183 Mon Sep 17 00:00:00 2001 From: lorenzocerrone Date: Mon, 21 Oct 2024 21:58:33 +0200 Subject: [PATCH 3/6] feat: add scheleton for headless gui --- plantseg/headless_gui/__init__.py | 0 plantseg/headless_gui/headless_gui.py | 33 ++++++++++ plantseg/headless_gui/plantseg_classic.py | 75 +++++++++++++++++++++++ plantseg/run_plantseg.py | 4 +- 4 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 plantseg/headless_gui/__init__.py create mode 100644 plantseg/headless_gui/headless_gui.py create mode 100644 plantseg/headless_gui/plantseg_classic.py diff --git 
From c14d146efaf8b8f06e85f3e924caddbb389c5183 Mon Sep 17 00:00:00 2001
From: lorenzocerrone
Date: Mon, 21 Oct 2024 21:58:33 +0200
Subject: [PATCH 3/6] feat: add skeleton for headless gui

---
 plantseg/headless_gui/__init__.py         |  0
 plantseg/headless_gui/headless_gui.py     | 33 ++++++++++
 plantseg/headless_gui/plantseg_classic.py | 75 +++++++++++++++++++++++
 plantseg/run_plantseg.py                  |  4 +-
 4 files changed, 110 insertions(+), 2 deletions(-)
 create mode 100644 plantseg/headless_gui/__init__.py
 create mode 100644 plantseg/headless_gui/headless_gui.py
 create mode 100644 plantseg/headless_gui/plantseg_classic.py

diff --git a/plantseg/headless_gui/__init__.py b/plantseg/headless_gui/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/plantseg/headless_gui/headless_gui.py b/plantseg/headless_gui/headless_gui.py
new file mode 100644
index 00000000..9e62934d
--- /dev/null
+++ b/plantseg/headless_gui/headless_gui.py
@@ -0,0 +1,33 @@
+from enum import Enum
+
+from magicgui import magicgui
+from magicgui.experimental import guiclass
+from magicgui.widgets import Container
+
+# from plantseg.headless.headless import run_headless_workflow_from_path
+from plantseg.headless_gui.plantseg_classic import widget_plantseg_classic
+
+all_workflows = {
+    "PlantsegClassic": widget_plantseg_classic,
+}
+
+
+@magicgui(
+    auto_call=True,
+    name={
+        'label': 'Mode',
+        'tooltip': 'Select the workflow to run',
+        'choices': list(all_workflows.keys()),
+    },
+)
+def workflow_selector(name: str = list(all_workflows.keys())[0]):
+    for workflow in all_workflows.values():
+        workflow.hide()
+
+    all_workflows[name].show()
+
+
+if __name__ == "__main__":
+    gui_container = Container(widgets=[workflow_selector, *all_workflows.values()], labels=False)
+    workflow_selector()
+    gui_container.show(run=True)
diff --git a/plantseg/headless_gui/plantseg_classic.py b/plantseg/headless_gui/plantseg_classic.py
new file mode 100644
index 00000000..48fe6db0
--- /dev/null
+++ b/plantseg/headless_gui/plantseg_classic.py
@@ -0,0 +1,75 @@
+from enum import Enum
+from pathlib import Path
+
+from magicgui import magicgui
+from magicgui.widgets import Container
+
+########################################################################################################################
+#
+# Input Setup Widget
+#
+########################################################################################################################
+
+
+class FilePickMode(Enum):
+    File = 'File'
+    Directory = 'Directory'
+
+    @classmethod
+    def to_choices(cls) -> list[str]:
+        return [mode.value for mode in FilePickMode]
+
+
+@magicgui(
+    call_button=False,
+    file_pick_mode={
+        'label': 'Input Mode',
+        'tooltip': 'Select how the input files are picked',
+        'choices': FilePickMode.to_choices(),
+    },
+    file={
+        'label': 'File',
+        'mode': 'r',
+        'layout': 'vertical',
+        'tooltip': 'Select the files to process one by one',
+    },
+    directory={
+        'label': 'Directory',
+        'mode': 'd',
+        'tooltip': 'Process all files in the directory',
+    },
+)
+def widget_input_model(
+    file_pick_mode: str = FilePickMode.File.value,
+    file: Path = Path('.').absolute(),
+    directory: Path = Path('.').absolute(),
+):
+    pass
+
+
+@widget_input_model.file_pick_mode.changed.connect
+def _on_mode_change(file_pick_mode):
+    if file_pick_mode == FilePickMode.File.value:
+        widget_input_model.file.show()
+        widget_input_model.directory.hide()
+    else:
+        widget_input_model.file.hide()
+        widget_input_model.directory.show()
+
+
+widget_input_model.directory.hide()
+
+
+########################################################################################################################
+#
+# PlantSeg Classic Workflow
+#
+########################################################################################################################
+@magicgui(
+    call_button="Run - PlantSeg Classic",
+)
+def widget_setup_workflow_config():
+    pass
+
+
+widget_plantseg_classic = Container(widgets=[widget_input_model, widget_setup_workflow_config], labels=False)
diff --git a/plantseg/run_plantseg.py b/plantseg/run_plantseg.py
index 50c1dd62..af1eec48 100644
--- a/plantseg/run_plantseg.py
+++ b/plantseg/run_plantseg.py
@@ -47,9 +47,9 @@ def launch_napari():

 def launch_workflow_headless(path: Path):
     """Run a workflow in headless mode."""
-    from plantseg.headless.headless import run_headless_workflow
+    from plantseg.headless.headless import run_headless_workflow_from_path

-    run_headless_workflow(path)
+    run_headless_workflow_from_path(path)


 def launch_training(path: Path):
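The selector is intentionally generic: every workflow is a magicgui Container registered in all_workflows, and workflow_selector only toggles widget visibility. A sketch of how a second workflow could be plugged in (the widget names and label here are hypothetical):

    from magicgui import magicgui
    from magicgui.widgets import Container

    @magicgui(call_button="Run - My Workflow")
    def widget_my_workflow():
        pass

    widget_my_container = Container(widgets=[widget_my_workflow], labels=False)

    # Registering the container is enough for the Mode dropdown to offer it:
    # all_workflows["MyWorkflow"] = widget_my_container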
From a59370caa01cd8b2a3564cf44ff83def38af0a22 Mon Sep 17 00:00:00 2001
From: lorenzocerrone
Date: Mon, 21 Oct 2024 22:26:39 +0200
Subject: [PATCH 4/6] fix: bug in headless tasks with multiple outputs

---
 plantseg/headless/basic_runner.py | 22 ++++++++++++++++------
 plantseg/headless/headless.py     |  2 +-
 plantseg/run_plantseg.py          |  4 ++--
 3 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/plantseg/headless/basic_runner.py b/plantseg/headless/basic_runner.py
index 6e244cfc..3723a6e4 100644
--- a/plantseg/headless/basic_runner.py
+++ b/plantseg/headless/basic_runner.py
@@ -1,6 +1,7 @@
 import logging
 from pathlib import Path

+from plantseg.core.image import PlantSegImage
 from plantseg.tasks.workflow_handler import DAG, Task, WorkflowHandler

 logger = logging.getLogger(__name__)

@@ -43,12 +44,21 @@ def run_task(self, task: Task, var_space: dict):
         func = self.func_registry.get_func(task.func)
         outputs = func(**inputs, **task.parameters)

-        # Save outputs in var_space
-        for i, name in enumerate(task.outputs):
-            if isinstance(outputs, tuple):
-                var_space[name] = outputs[i]
-            else:
-                var_space[name] = outputs
+        if isinstance(outputs, PlantSegImage):
+            outputs = [outputs]
+
+        elif outputs is None:
+            outputs = []
+
+        assert isinstance(
+            outputs, (list, tuple)
+        ), f"Task {task.func} should return a list of PlantSegImage, got {type(outputs)}"
+        assert len(outputs) == len(
+            task.outputs
+        ), f"Task {task.func} should return {len(task.outputs)} outputs, got {len(outputs)}"
+
+        for name, output in zip(task.outputs, outputs, strict=True):
+            var_space[name] = output

         return var_space

diff --git a/plantseg/headless/headless.py b/plantseg/headless/headless.py
index 27111e18..116b721f 100644
--- a/plantseg/headless/headless.py
+++ b/plantseg/headless/headless.py
@@ -104,7 +104,7 @@ def collect_jobs_list(inputs: dict | list[dict], inputs_schema: dict[str, RunTim

-def run_headless_workflow_from_config(config: dict, path: str | Path = None):
+def run_headless_workflow_from_config(config: dict, path: str | Path):
     config = validate_config(config)

diff --git a/plantseg/run_plantseg.py b/plantseg/run_plantseg.py
index af1eec48..e7c4b226 100644
--- a/plantseg/run_plantseg.py
+++ b/plantseg/run_plantseg.py
@@ -47,9 +47,9 @@ def launch_napari():

 def launch_workflow_headless(path: Path):
     """Run a workflow in headless mode."""
-    from plantseg.headless.headless import run_headless_workflow_from_path
+    from plantseg.headless.headless import run_headless_workflow

-    run_headless_workflow_from_path(path)
+    run_headless_workflow(path=path)


 def launch_training(path: Path):
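The normalization in run_task makes the runner's contract with task return values explicit. A sketch of the three accepted shapes, with a hypothetical two-output task for illustration:

    # run_task now maps the return value of a task as follows:
    #   a single PlantSegImage      -> [image]
    #   None (e.g. an export task)  -> []
    #   a list/tuple of images      -> used as-is
    # and asserts len(outputs) == len(task.outputs) before binding each
    # output to its name in var_space.

    def split_stain_task(image):
        # Hypothetical task declared with two outputs: compute both results,
        # then return them as a tuple; the runner zips them with task.outputs.
        nuclei, membranes = image, image  # placeholder computation
        return nuclei, membranes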
From 2f503a2f46347ccd163cda3341511a802f5cd73f Mon Sep 17 00:00:00 2001
From: lorenzocerrone
Date: Mon, 21 Oct 2024 22:48:29 +0200
Subject: [PATCH 5/6] fix: polish headless

---
 plantseg/tasks/io_tasks.py         |  9 ++++-----
 plantseg/tasks/workflow_handler.py | 10 +++++-----
 tests/headless/test_headless.py    |  2 ++
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/plantseg/tasks/io_tasks.py b/plantseg/tasks/io_tasks.py
index 6cdb3842..00046950 100644
--- a/plantseg/tasks/io_tasks.py
+++ b/plantseg/tasks/io_tasks.py
@@ -13,10 +13,6 @@
             required=True,
             is_input_file=True,
         ),
-        "image_name": RunTimeInputSchema(
-            description="Name of the image (if None, the file name will be used)",
-            required=False,
-        ),
     },
 )
 def import_image_task(
@@ -68,7 +64,10 @@ def import_image_task(
         "export_directory": RunTimeInputSchema(
             description="Output directory path where the image will be saved", required=True
         ),
-        "name_pattern": RunTimeInputSchema(description="Output file name", required=False),
+        "name_pattern": RunTimeInputSchema(
+            description="Output file name pattern. Can contain the special {image_name} or {file_name} tokens.",
+            required=False,
+        ),
     },
 )
 def export_image_task(
diff --git a/plantseg/tasks/workflow_handler.py b/plantseg/tasks/workflow_handler.py
index 964bf624..991c490f 100644
--- a/plantseg/tasks/workflow_handler.py
+++ b/plantseg/tasks/workflow_handler.py
@@ -78,7 +78,7 @@ class Task(BaseModel):

 class DAG(BaseModel):
     infos: Infos = Field(default_factory=Infos)
-    inputs: dict[str, Any] | list[dict[str, Any]] = Field(default_factory=dict)
+    inputs: list[dict[str, Any]] = Field(default_factory=lambda: [{}])
     list_tasks: list[Task] = Field(default_factory=list)

     """
@@ -86,7 +86,7 @@ class DAG(BaseModel):

     Attributes:
         infos (Infos): Metadata describing the workflow (description, creation date, version, inputs schema).
-        inputs (dict[str, Any]): A dictionary of the inputs of the workflow. For example, paths to the images and other runtime parameters.
+        inputs (list[dict[str, Any]]): A list of dictionaries with the inputs of the workflow, one entry per job. For example, paths to the images and other runtime parameters.
         list_tasks (list[Task]): A list of the tasks in the workflow.
     """
@@ -136,9 +136,9 @@ def prune_dag(dag: DAG) -> DAG:
         if task.id in reachable:
             new_dag.list_tasks.append(task)

-    for input_key, text in dag.inputs.items():
+    for input_key, text in dag.inputs[0].items():
         if input_key in reachable_inputs:
-            new_dag.inputs[input_key] = text
+            new_dag.inputs[0][input_key] = text
             new_dag.infos.inputs_schema[input_key] = dag.infos.inputs_schema[input_key]

     return new_dag
@@ -242,7 +242,7 @@ def _unique_input(name, id: int = 0):
         value_schema.task = func_name
         self._dag.infos.inputs_schema[unique_name] = value_schema
-        self._dag.inputs[unique_name] = value
+        self._dag.inputs[0][unique_name] = value
         return unique_name

     def clean_dag(self):
diff --git a/tests/headless/test_headless.py b/tests/headless/test_headless.py
index a363f91f..975c8b0f 100644
--- a/tests/headless/test_headless.py
+++ b/tests/headless/test_headless.py
@@ -64,6 +64,8 @@ def test_create_workflow(tmp_path):
         },
     ]

+    config['inputs'] = job_list
+
     with open(tmp_path / 'workflow.yaml', 'w') as file:
         yaml.dump(config, file)

From e67168292492b5fbc194b67fa73586643ad31bd5 Mon Sep 17 00:00:00 2001
From: lorenzocerrone
Date: Mon, 21 Oct 2024 23:03:06 +0200
Subject: [PATCH 6/6] fix: minor bug in testing

---
 tests/headless/test_headless.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/headless/test_headless.py b/tests/headless/test_headless.py
index 975c8b0f..ab91a73b 100644
--- a/tests/headless/test_headless.py
+++ b/tests/headless/test_headless.py
@@ -41,7 +41,7 @@ def test_create_workflow(tmp_path):
     dag = workflow_handler.dag

     assert len(dag.list_tasks) == 3
-    assert len(dag.inputs.keys()) == 3
+    assert len(dag.inputs[0].keys()) == 3

     # Run the headless workflow
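After the full series, a pruned workflow file serializes with the new top-level layout. Sketched here as the equivalent Python structure (field values are illustrative placeholders, not the output of a real run):

    workflow = {
        "infos": {
            "description": "No description provided",
            "creation_date": "2024-10-21-23:03:06",
            "version": "...",
            "inputs_schema": {
                "input_path": {"description": "...", "task": "import_image_task", "required": True, "is_input_file": True},
                "export_directory": {"description": "...", "task": "export_image_task", "required": True},
                "name_pattern": {"description": "...", "task": "export_image_task", "required": False},
            },
            "instructions": "...",
        },
        # One dict per job; each entry is matched against inputs_schema.
        "inputs": [
            {"input_path": "...", "export_directory": "...", "name_pattern": "{original_name}_export"},
        ],
        "list_tasks": ["..."],
    }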