Skip to content

Commit

Permalink
fix: polish headless
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzocerrone committed Oct 21, 2024
1 parent a59370c commit 2f503a2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
9 changes: 4 additions & 5 deletions plantseg/tasks/io_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
required=True,
is_input_file=True,
),
"image_name": RunTimeInputSchema(
description="Name of the image (if None, the file name will be used)",
required=False,
),
},
)
def import_image_task(
Expand Down Expand Up @@ -68,7 +64,10 @@ def import_image_task(
"export_directory": RunTimeInputSchema(
description="Output directory path where the image will be saved", required=True
),
"name_pattern": RunTimeInputSchema(description="Output file name", required=False),
"name_pattern": RunTimeInputSchema(
description="Output file name pattern. Can contain the special {image_name} or {file_name} tokens ",
required=False,
),
},
)
def export_image_task(
Expand Down
10 changes: 5 additions & 5 deletions plantseg/tasks/workflow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ class Task(BaseModel):

class DAG(BaseModel):
infos: Infos = Field(default_factory=Infos)
inputs: dict[str, Any] | list[dict[str, Any]] = Field(default_factory=dict)
inputs: list[dict[str, Any]] = Field(default_factory=lambda: [{}])
list_tasks: list[Task] = Field(default_factory=list)

"""
This model represents the Directed Acyclic Graph (DAG) of the workflow.
Attributes:
infos (Infos): A dictionary with the information of the workflow.
inputs (dict[str, Any]): A dictionary of the inputs of the workflow. For example path to the images and other runtime parameters.
inputs (list[dict[str, Any]): A dictionary of the inputs of the workflow. For example path to the images and other runtime parameters.
list_tasks (list[Task]): A list of the tasks in the workflow.
"""
Expand Down Expand Up @@ -136,9 +136,9 @@ def prune_dag(dag: DAG) -> DAG:
if task.id in reachable:
new_dag.list_tasks.append(task)

for input_key, text in dag.inputs.items():
for input_key, text in dag.inputs[0].items():
if input_key in reachable_inputs:
new_dag.inputs[input_key] = text
new_dag.inputs[0][input_key] = text
new_dag.infos.inputs_schema[input_key] = dag.infos.inputs_schema[input_key]

return new_dag
Expand Down Expand Up @@ -242,7 +242,7 @@ def _unique_input(name, id: int = 0):
value_schema.task = func_name
self._dag.infos.inputs_schema[unique_name] = value_schema

self._dag.inputs[unique_name] = value
self._dag.inputs[0][unique_name] = value
return unique_name

def clean_dag(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/headless/test_headless.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def test_create_workflow(tmp_path):
},
]

config['inputs'] = job_list

with open(tmp_path / 'workflow.yaml', 'w') as file:
yaml.dump(config, file)

Expand Down

0 comments on commit 2f503a2

Please sign in to comment.