Skip to content

Commit

Permalink
Merge pull request #356 from kreshuklab/headless_gui
Browse files Browse the repository at this point in the history
Improved Headless + gui
  • Loading branch information
lorenzocerrone authored Oct 21, 2024
2 parents fd7ff08 + e671682 commit 5a9ae4f
Show file tree
Hide file tree
Showing 10 changed files with 314 additions and 149 deletions.
12 changes: 12 additions & 0 deletions plantseg/headless/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""This module contains the headless workflow for PlantSeg.
To build a headless workflow, you can:
- Register a new workflow manually using the plantseg API.
- Run a workflow from the napari viewer and export it as a configuration file.
The headless workflow configured can be run using the `run_headless_workflow` function.
"""

from plantseg.headless.headless import run_headles_workflow_from_config, run_headless_workflow

__all__ = ["run_headles_workflow_from_config", "run_headless_workflow"]
22 changes: 16 additions & 6 deletions plantseg/headless/basic_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from pathlib import Path

from plantseg.core.image import PlantSegImage
from plantseg.tasks.workflow_handler import DAG, Task, WorkflowHandler

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,12 +44,21 @@ def run_task(self, task: Task, var_space: dict):
func = self.func_registry.get_func(task.func)
outputs = func(**inputs, **task.parameters)

# Save outputs in var_space
for i, name in enumerate(task.outputs):
if isinstance(outputs, tuple):
var_space[name] = outputs[i]
else:
var_space[name] = outputs
if isinstance(outputs, PlantSegImage):
outputs = [outputs]

elif outputs is None:
outputs = []

assert isinstance(
outputs, (list, tuple)
), f"Task {task.func} should return a list of PlantSegImage, got {type(outputs)}"
assert len(outputs) == len(
task.outputs
), f"Task {task.func} should return {len(task.outputs)} outputs, got {len(outputs)}"

for name, output in zip(task.outputs, outputs, strict=True):
var_space[name] = output

return var_space

Expand Down
167 changes: 84 additions & 83 deletions plantseg/headless/headless.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from pathlib import Path
from typing import Literal
from typing import Any, Literal

import yaml

from plantseg.headless.basic_runner import SerialRunner
from plantseg.tasks.workflow_handler import TaskUserInput
from plantseg.io import allowed_data_format
from plantseg.tasks.workflow_handler import RunTimeInputSchema

logger = logging.getLogger(__name__)

Expand All @@ -14,102 +15,115 @@
_implemented_runners = {'serial': SerialRunner}


def parse_input_path(user_input: TaskUserInput):
value = user_input.value
if value is None:
raise ValueError("Input path must be provided.")
def validate_config(config: dict):
if "inputs" not in config:
raise ValueError("The workflow configuration does not contain an 'inputs' section.")

elif isinstance(value, str):
path = Path(value)
if not path.exists():
raise FileNotFoundError(f"Input path {path} does not exist.")
if "infos" not in config:
raise ValueError("The workflow configuration does not contain an 'infos' section.")

return [path]
if "list_tasks" not in config:
raise ValueError("The workflow configuration does not contain an 'list_tasks' section.")

elif isinstance(value, list):
paths = [Path(p) for p in value]
for path in paths:
if not path.exists():
raise FileNotFoundError(f"Input path {path} does not exist.")
if "runner" not in config:
logger.warning(
"The workflow configuration does not contain a 'runner' section. Using the default serial runner."
)
config["runner"] = "serial"

return paths
return config

else:
raise ValueError("Input path must be a string or a list of strings.")

def parse_import_image_task(input_path, allow_dir: bool) -> list[Path]:
if isinstance(input_path, str):
input_path = Path(input_path)

def output_directory(user_input: TaskUserInput):
user_input = user_input.value
if user_input is None:
raise ValueError("Output directory must be provided.")
if not input_path.exists():
raise FileNotFoundError(f"File {input_path} does not exist.")

if not isinstance(user_input, str):
raise ValueError("Output directory must be a string.")
if input_path.is_file():
list_files = [input_path]
elif input_path.is_dir():
if not allow_dir:
raise ValueError(f"Directory {input_path} is not allowed when multiple input files are expected.")

output_dir = Path(user_input)
list_files = list(input_path.glob("*"))
else:
raise ValueError(f"Path {input_path} is not a file or a directory.")

if not output_dir.exists():
output_dir.mkdir(parents=True)
return output_dir
list_files = [f for f in list_files if f.suffix in allowed_data_format]
if not list_files:
raise ValueError(f"No valid files found in {input_path}.")

return list_files

def parse_generic_input(user_input: TaskUserInput):
value = user_input.value

if value is None:
value = user_input.headless_default
def collect_jobs_list(inputs: dict | list[dict], inputs_schema: dict[str, RunTimeInputSchema]) -> list[dict[str, Any]]:
"""
Parse the inputs and create a list of jobs to run.
"""

if value is None and user_input.user_input_required:
raise ValueError(f"Input must be provided. {user_input}")
if isinstance(inputs, dict):
inputs = [inputs]

return value
num_is_input_file = sum([1 for schema in inputs_schema.values() if schema.is_input_file])

if num_is_input_file == 0:
raise ValueError("No input files found in the inputs schema. The workflow cannot run.")
elif num_is_input_file > 1:
allow_dir = False
else:
allow_dir = True

def parse_input_config(inputs_config: dict):
inputs_config = {k: TaskUserInput(**v) for k, v in inputs_config.items()}
all_jobs = []
for input_dict in inputs:
if not isinstance(input_dict, dict):
raise ValueError(f"Input {input_dict} should be a dictionary.")

list_input_keys = {}
single_input_keys = {}
has_input_path = False
has_output_dir = False
inputs_files = {}
for name, schema in inputs_schema.items():
if schema.is_input_file:
if name not in inputs_files:
inputs_files[name] = []

for key, value in inputs_config.items():
if key.find("input_path") != -1:
list_input_keys[key] = parse_input_path(value)
has_input_path = True
inputs_files[name].extend(parse_import_image_task(input_dict[name], allow_dir=allow_dir))

elif key.find("export_directory") != -1:
single_input_keys[key] = output_directory(value)
has_output_dir = True
list_len = [len(files) for files in inputs_files.values()]
if len(set(list_len)) != 1:
raise ValueError(f"Inputs have different number of files. found {inputs_files}")

else:
single_input_keys[key] = parse_generic_input(value)
list_files = list(zip(*inputs_files.values()))
list_keys = list(inputs_files.keys())
list_jobs = [dict(zip(list_keys, files)) for files in list_files]

if not has_input_path:
raise ValueError("The provided workflow configuration does not contain an input path.")
for job in list_jobs:
for key, value in input_dict.items():
if key not in job:
job[key] = value
all_jobs.append(job)
return all_jobs

if not has_output_dir:
raise ValueError("The provided workflow configuration does not contain an output directory.")

all_length = [len(v) for v in list_input_keys.values()]
# check if all input paths have the same length
if not all([_l == all_length[0] for _l in all_length]):
raise ValueError("All input paths must have the same length.")
def run_headles_workflow_from_config(config: dict, path: str | Path):
config = validate_config(config)

num_inputs = all_length[0]
inputs = config["inputs"]
inputs_schema = config["infos"]["inputs_schema"]
inputs_schema = {k: RunTimeInputSchema(**v) for k, v in inputs_schema.items()}

jobs_inputs = []
for i in range(num_inputs):
job_input = {}
for key in list_input_keys:
job_input[key] = list_input_keys[key][i]
jobs_list = collect_jobs_list(inputs, inputs_schema)

for key in single_input_keys:
job_input[key] = single_input_keys[key]
runner = config.get("runner")
if runner not in _implemented_runners:
raise ValueError(f"Runner {runner} is not implemented.")

runner = _implemented_runners[runner](path)

jobs_inputs.append(job_input)
for job_input in jobs_list:
logger.info(f"Submitting job with input: {job_input}")
runner.submit_job(job_input)

return jobs_inputs
logger.info("All jobs have been submitted. Running the workflow...")


def run_headless_workflow(path: str | Path):
Expand All @@ -127,17 +141,4 @@ def run_headless_workflow(path: str | Path):
with path.open("r") as file:
config = yaml.safe_load(file)

job_inputs = parse_input_config(config["inputs"])

runner = config.get("runner", "serial")
if runner not in _implemented_runners:
raise ValueError(f"Runner {runner} is not implemented.")

runner = _implemented_runners[runner](path)

for job_input in job_inputs:
logger.info(f"Submitting job with input: {job_input}")
runner.submit_job(job_input)

logger.info("All jobs have been submitted. Running the workflow...")
# TODO: When parallel runners are implemented, here we need to add something to wait for all jobs to finish
run_headles_workflow_from_config(config, path=path)
Empty file.
33 changes: 33 additions & 0 deletions plantseg/headless_gui/headless_gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from enum import Enum

from magicgui import magicgui
from magicgui.experimental import guiclass
from magicgui.widgets import Container

# from plantseg.headless.headless import run_headless_workflow_from_path
from plantseg.headless_gui.plantseg_classic import widget_plantseg_classic

all_workflows = {
"PlantsegClassic": widget_plantseg_classic,
}


@magicgui(
    auto_call=True,
    name={
        'label': 'Mode',
        'tooltip': 'Select the workflow to run',
        'choices': list(all_workflows.keys()),
    },
)
def workflow_selector(name: str = list(all_workflows.keys())[0]):
    """Make the chosen workflow's widget the only visible one."""
    # Hide every workflow widget first, then reveal just the selected one.
    for workflow_widget in all_workflows.values():
        workflow_widget.hide()

    all_workflows[name].show()


if __name__ == "__main__":
gui_container = Container(widgets=[workflow_selector, *all_workflows.values()], labels=False)
workflow_selector()
gui_container.show(run=True)
75 changes: 75 additions & 0 deletions plantseg/headless_gui/plantseg_classic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from enum import Enum
from pathlib import Path

from magicgui import magicgui
from magicgui.widgets import Container

########################################################################################################################
#
# Input Setup Widget
#
########################################################################################################################


class FilePickMode(Enum):
    """Input-selection mode for the headless GUI: a single file or a directory."""

    File = 'File'
    Directory = 'Directory'

    @classmethod
    def to_choices(cls) -> list[str]:
        """Return the enum member values as strings, for use as widget choices.

        Iterates over ``cls`` (not the hard-coded class name) so the method
        stays correct if the enum is renamed or subclassed.
        """
        return [mode.value for mode in cls]


@magicgui(
    call_button=False,
    file_pick_mode={
        'label': 'Input Mode',
        # Fixed copy-pasted tooltip (was "Select the workflow to run", which
        # belongs to the workflow selector, not the input picker).
        'tooltip': 'Select whether to process a single file or a whole directory',
        'choices': FilePickMode.to_choices(),
    },
    file={
        'label': 'File',
        'mode': 'r',
        'layout': 'vertical',
        'tooltip': 'Select the file to process one by one',
    },
    directory={
        'label': 'Directory',
        'mode': 'd',
        'tooltip': 'Process all files in the directory',
    },
)
def widget_input_model(
    file_pick_mode: str = FilePickMode.File.value,
    file: Path = Path('.').absolute(),
    directory: Path = Path('.').absolute(),
):
    """Input picker for the headless GUI.

    Shows either a single-file picker or a directory picker depending on
    ``file_pick_mode``. The body is intentionally a no-op: the current widget
    values are presumably read by the workflow launcher — TODO confirm.
    """
    pass


@widget_input_model.file_pick_mode.changed.connect
def _on_mode_change(file_pick_mode):
    """Swap the visible picker whenever the input mode changes."""
    in_file_mode = file_pick_mode == FilePickMode.File.value
    if in_file_mode:
        widget_input_model.file.show()
        widget_input_model.directory.hide()
    else:
        widget_input_model.file.hide()
        widget_input_model.directory.show()


# The default mode is 'File', so start with the directory picker hidden.
widget_input_model.directory.hide()


########################################################################################################################
#
# PlantSeg Classic Workflow
#
########################################################################################################################
@magicgui(
    call_button="Run - PlantSeg Classic",
)
def widget_setup_workflow_config():
    """Configuration panel for the PlantSeg Classic workflow.

    NOTE(review): the body is currently a no-op, so pressing the run button
    does nothing yet — presumably to be wired to the headless runner later.
    TODO confirm.
    """
    pass


# Composite widget exposed to the workflow selector: the input picker on top,
# the workflow configuration (with the run button) below.
widget_plantseg_classic = Container(widgets=[widget_input_model, widget_setup_workflow_config], labels=False)
2 changes: 1 addition & 1 deletion plantseg/run_plantseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def launch_workflow_headless(path: Path):
"""Run a workflow in headless mode."""
from plantseg.headless.headless import run_headless_workflow

run_headless_workflow(path)
run_headless_workflow(path=path)


def launch_training(path: Path):
Expand Down
Loading

0 comments on commit 5a9ae4f

Please sign in to comment.