Commit

feat: cleanup headless
lorenzocerrone committed Oct 14, 2024
1 parent 320b175 commit 10f6b4f
Showing 5 changed files with 113 additions and 89 deletions.
12 changes: 12 additions & 0 deletions plantseg/headless/__init__.py
@@ -0,0 +1,12 @@
"""This module contains the headless workflow for PlantSeg.
To build a headless workflow, you can:
- Register a new workflow manually using the plantseg API.
- Run a workflow from the napari viewer and export it as a configuration file.
The headless workflow configured can be run using the `run_headless_workflow` function.
"""

from plantseg.headless.headless import run_headles_workflow_from_config, run_headless_workflow

__all__ = ["run_headles_workflow_from_config", "run_headless_workflow"]
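The module docstring above describes the intended entry point. As a minimal usage sketch (not part of this commit; the file name is illustrative), a workflow exported from the napari viewer can be replayed headlessly like this:

    from plantseg.headless import run_headless_workflow

    # "my_workflow.yaml" is a placeholder for a workflow configuration exported
    # from the napari viewer or registered via the plantseg API.
    run_headless_workflow("my_workflow.yaml")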
167 changes: 84 additions & 83 deletions plantseg/headless/headless.py
@@ -1,11 +1,12 @@
 import logging
 from pathlib import Path
-from typing import Literal
+from typing import Any, Literal

 import yaml

 from plantseg.headless.basic_runner import SerialRunner
-from plantseg.tasks.workflow_handler import TaskUserInput
+from plantseg.io import allowed_data_format
+from plantseg.tasks.workflow_handler import RunTimeInputSchema

 logger = logging.getLogger(__name__)

@@ -14,102 +15,115 @@
 _implemented_runners = {'serial': SerialRunner}


-def parse_input_path(user_input: TaskUserInput):
-    value = user_input.value
-    if value is None:
-        raise ValueError("Input path must be provided.")
+def validate_config(config: dict):
+    if "inputs" not in config:
+        raise ValueError("The workflow configuration does not contain an 'inputs' section.")

-    elif isinstance(value, str):
-        path = Path(value)
-        if not path.exists():
-            raise FileNotFoundError(f"Input path {path} does not exist.")
+    if "infos" not in config:
+        raise ValueError("The workflow configuration does not contain an 'infos' section.")

-        return [path]
+    if "list_tasks" not in config:
+        raise ValueError("The workflow configuration does not contain an 'list_tasks' section.")

-    elif isinstance(value, list):
-        paths = [Path(p) for p in value]
-        for path in paths:
-            if not path.exists():
-                raise FileNotFoundError(f"Input path {path} does not exist.")
+    if "runner" not in config:
+        logger.warning(
+            "The workflow configuration does not contain a 'runner' section. Using the default serial runner."
+        )
+        config["runner"] = "serial"

-        return paths
+    return config

-    else:
-        raise ValueError("Input path must be a string or a list of strings.")

+def parse_import_image_task(input_path, allow_dir: bool) -> list[Path]:
+    if isinstance(input_path, str):
+        input_path = Path(input_path)

-def output_directory(user_input: TaskUserInput):
-    user_input = user_input.value
-    if user_input is None:
-        raise ValueError("Output directory must be provided.")
+    if not input_path.exists():
+        raise FileNotFoundError(f"File {input_path} does not exist.")

-    if not isinstance(user_input, str):
-        raise ValueError("Output directory must be a string.")
+    if input_path.is_file():
+        list_files = [input_path]
+    elif input_path.is_dir():
+        if not allow_dir:
+            raise ValueError(f"Directory {input_path} is not allowed when multiple input files are expected.")

-    output_dir = Path(user_input)
+        list_files = list(input_path.glob("*"))
+    else:
+        raise ValueError(f"Path {input_path} is not a file or a directory.")

-    if not output_dir.exists():
-        output_dir.mkdir(parents=True)
-    return output_dir
+    list_files = [f for f in list_files if f.suffix in allowed_data_format]
+    if not list_files:
+        raise ValueError(f"No valid files found in {input_path}.")

+    return list_files

-def parse_generic_input(user_input: TaskUserInput):
-    value = user_input.value

-    if value is None:
-        value = user_input.headless_default
+def collect_jobs_list(inputs: dict | list[dict], inputs_schema: dict[str, RunTimeInputSchema]) -> list[dict[str, Any]]:
+    """
+    Parse the inputs and create a list of jobs to run.
+    """

-    if value is None and user_input.user_input_required:
-        raise ValueError(f"Input must be provided. {user_input}")
+    if isinstance(inputs, dict):
+        inputs = [inputs]

-    return value
+    num_is_input_file = sum([1 for schema in inputs_schema.values() if schema.is_input_file])

+    if num_is_input_file == 0:
+        raise ValueError("No input files found in the inputs schema. The workflow cannot run.")
+    elif num_is_input_file > 1:
+        allow_dir = False
+    else:
+        allow_dir = True

-def parse_input_config(inputs_config: dict):
-    inputs_config = {k: TaskUserInput(**v) for k, v in inputs_config.items()}
+    all_jobs = []
+    for input_dict in inputs:
+        if not isinstance(input_dict, dict):
+            raise ValueError(f"Input {input_dict} should be a dictionary.")

-    list_input_keys = {}
-    single_input_keys = {}
-    has_input_path = False
-    has_output_dir = False
+        inputs_files = {}
+        for name, schema in inputs_schema.items():
+            if schema.is_input_file:
+                if name not in inputs_files:
+                    inputs_files[name] = []

-    for key, value in inputs_config.items():
-        if key.find("input_path") != -1:
-            list_input_keys[key] = parse_input_path(value)
-            has_input_path = True
+                inputs_files[name].extend(parse_import_image_task(input_dict[name], allow_dir=allow_dir))

-        elif key.find("export_directory") != -1:
-            single_input_keys[key] = output_directory(value)
-            has_output_dir = True
+        list_len = [len(files) for files in inputs_files.values()]
+        if len(set(list_len)) != 1:
+            raise ValueError(f"Inputs have different number of files. found {inputs_files}")

-        else:
-            single_input_keys[key] = parse_generic_input(value)
+        list_files = list(zip(*inputs_files.values()))
+        list_keys = list(inputs_files.keys())
+        list_jobs = [dict(zip(list_keys, files)) for files in list_files]

-    if not has_input_path:
-        raise ValueError("The provided workflow configuration does not contain an input path.")
+        for job in list_jobs:
+            for key, value in input_dict.items():
+                if key not in job:
+                    job[key] = value
+            all_jobs.append(job)
+    return all_jobs

-    if not has_output_dir:
-        raise ValueError("The provided workflow configuration does not contain an output directory.")

-    all_length = [len(v) for v in list_input_keys.values()]
-    # check if all input paths have the same length
-    if not all([_l == all_length[0] for _l in all_length]):
-        raise ValueError("All input paths must have the same length.")
+def run_headles_workflow_from_config(config: dict, path: str | Path = None):
+    config = validate_config(config)

-    num_inputs = all_length[0]
+    inputs = config["inputs"]
+    inputs_schema = config["infos"]["inputs_schema"]
+    inputs_schema = {k: RunTimeInputSchema(**v) for k, v in inputs_schema.items()}

-    jobs_inputs = []
-    for i in range(num_inputs):
-        job_input = {}
-        for key in list_input_keys:
-            job_input[key] = list_input_keys[key][i]
+    jobs_list = collect_jobs_list(inputs, inputs_schema)

-        for key in single_input_keys:
-            job_input[key] = single_input_keys[key]
+    runner = config.get("runner")
+    if runner not in _implemented_runners:
+        raise ValueError(f"Runner {runner} is not implemented.")
+
+    runner = _implemented_runners[runner](path)

-        jobs_inputs.append(job_input)
+    for job_input in jobs_list:
+        logger.info(f"Submitting job with input: {job_input}")
+        runner.submit_job(job_input)

-    return jobs_inputs
+    logger.info("All jobs have been submitted. Running the workflow...")


 def run_headless_workflow(path: str | Path):
@@ -127,17 +141,4 @@ def run_headless_workflow(path: str | Path):
     with path.open("r") as file:
         config = yaml.safe_load(file)

-    job_inputs = parse_input_config(config["inputs"])
-
-    runner = config.get("runner", "serial")
-    if runner not in _implemented_runners:
-        raise ValueError(f"Runner {runner} is not implemented.")
-
-    runner = _implemented_runners[runner](path)
-
-    for job_input in job_inputs:
-        logger.info(f"Submitting job with input: {job_input}")
-        runner.submit_job(job_input)
-
-    logger.info("All jobs have been submitted. Running the workflow...")
-    # TODO: When parallel runners are implemented, hew we need to add something to wait for all jobs to finish
+    run_headles_workflow_from_config(config, path=path)
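For orientation, the new entry point takes the whole workflow configuration (normally loaded from the exported YAML) rather than only the 'inputs' section. A rough sketch of the expected shape, with assumed schema fields and illustrative file names based on the diff above and the updated test below:

    from plantseg.headless.headless import run_headles_workflow_from_config

    # Hypothetical configuration mirroring a napari-exported workflow.yaml.
    # "inputs" may be a single mapping or a list of per-job mappings; every key
    # whose schema entry sets is_input_file=True is resolved to concrete files by
    # parse_import_image_task, and the remaining keys are copied into each job.
    config = {
        "infos": {
            "inputs_schema": {
                "input_path": {"description": "...", "required": True, "is_input_file": True},
                "export_directory": {"description": "...", "required": True, "is_input_file": False},
            }
        },
        "inputs": [
            {"input_path": "raw_1.tiff", "export_directory": "output"},
            {"input_path": "raw_2.tiff", "export_directory": "output"},
        ],
        "list_tasks": [],  # task entries as exported by the workflow handler (omitted here)
        "runner": "serial",  # optional; validate_config falls back to the serial runner
    }

    run_headles_workflow_from_config(config, path="workflow.yaml")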
2 changes: 1 addition & 1 deletion plantseg/tasks/io_tasks.py
@@ -11,7 +11,7 @@
     "input_path": RunTimeInputSchema(
         description="Path to a file, or a directory containing files (all files will be imported) or list of paths.",
         required=True,
-        is_input=True,
+        is_input_file=True,
     ),
     "image_name": RunTimeInputSchema(
         description="Name of the image (if None, the file name will be used)",
4 changes: 2 additions & 2 deletions plantseg/tasks/workflow_handler.py
@@ -78,7 +78,7 @@ class Task(BaseModel):

 class DAG(BaseModel):
     infos: Infos = Field(default_factory=Infos)
-    inputs: dict[str, Any] = Field(default_factory=dict)
+    inputs: dict[str, Any] | list[dict[str, Any]] = Field(default_factory=dict)
     list_tasks: list[Task] = Field(default_factory=list)

     """
@@ -93,7 +93,7 @@ class DAG(BaseModel):

     @property
     def list_inputs(self):
-        return list(self.inputs.keys())
+        return list(self.infos.inputs_schema.keys())


 def prune_dag(dag: DAG) -> DAG:
17 changes: 14 additions & 3 deletions tests/headless/test_headless.py
@@ -51,9 +51,20 @@ def test_create_workflow(tmp_path):
     with open(tmp_path / 'workflow.yaml', 'r') as file:
         config = yaml.safe_load(file)

-    config['inputs']['input_path']['value'] = [str(path_tiff_1), str(path_tiff_2)]
-    config['inputs']['export_directory']['value'] = str(tmp_path / 'output')
-    config['inputs']['name_pattern']['value'] = '{original_name}_export'
+    job_list = [
+        {
+            "input_path": str(path_tiff_1),
+            "export_directory": str(tmp_path / 'output'),
+            "name_pattern": "{original_name}_export",
+        },
+        {
+            "input_path": str(path_tiff_2),
+            "export_directory": str(tmp_path / 'output'),
+            "name_pattern": "{original_name}_export",
+        },
+    ]
+
+    config['inputs'] = job_list

     with open(tmp_path / 'workflow.yaml', 'w') as file:
         yaml.dump(config, file)
