diff --git a/CHANGELOG.md b/CHANGELOG.md index dffa68557a..b74deda546 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,25 @@ **Note**: Numbers like (\#1234) point to closed Pull Requests on the fractal-server repository. -# 2.11.0 (unreleased) +> WARNING: Notes for 2.11.0 prereleases are currently separated, and they should be merged at a later stage. + +# 2.11.0a1 + +> Note: This release requires running a `fractalctl update-db-data` + +(changes that affect API, database lifecycle, runner, ...) + +* Split filters into attribute and types (\#2168). +* Support multiple options for attribute filters (\#2168). +* Deprecate support for attribute filters in workflowtask (\#2168). +* Introduce support for attribute filters in jobs (\#2168). +* Data migration script (\#2168). + +# 2.11.0a0 -* Runner - * Integrate database write access in runner component (\#2169). * API: * Update and simplify `/api/v2/project/{project_id}/status/` (\#2169). +* Runner + * Integrate database write access in runner component (\#2169). # 2.10.5 diff --git a/benchmarks/runner/mocks.py b/benchmarks/runner/mocks.py index 20e16c34a9..26c1ce8013 100644 --- a/benchmarks/runner/mocks.py +++ b/benchmarks/runner/mocks.py @@ -1,5 +1,4 @@ from typing import Any -from typing import Literal from typing import Optional from pydantic import BaseModel @@ -13,21 +12,14 @@ class DatasetV2Mock(BaseModel): name: str zarr_dir: str images: list[dict[str, Any]] = Field(default_factory=list) - filters: dict[Literal["types", "attributes"], dict[str, Any]] = Field( - default_factory=dict - ) + type_filters: dict[str, bool] = Field(default_factory=dict) + attribute_filters: dict[str, list[Any]] = Field(default_factory=dict) history: list = Field(default_factory=list) @property def image_zarr_urls(self) -> list[str]: return [image["zarr_urls"] for image in self.images] - @validator("filters", always=True) - def _default_filters(cls, value): - if value == {}: - return {"types": {}, "attributes": {}} - return value - class TaskV2Mock(BaseModel): id: int @@ -77,13 +69,8 @@ class WorkflowTaskV2Mock(BaseModel): meta_parallel: Optional[dict[str, Any]] = Field() meta_non_parallel: Optional[dict[str, Any]] = Field() task: TaskV2Mock = None - input_filters: dict[str, Any] = Field(default_factory=dict) + type_filters: dict[str, bool] = Field(default_factory=dict) order: int id: int workflow_id: int = 0 task_id: int - - @validator("input_filters", always=True) - def _default_filters(cls, value): - if value == {}: - return {"types": {}, "attributes": {}} diff --git a/fractal_server/app/models/v2/dataset.py b/fractal_server/app/models/v2/dataset.py index 91ef47006c..7f4cbfe2ae 100644 --- a/fractal_server/app/models/v2/dataset.py +++ b/fractal_server/app/models/v2/dataset.py @@ -11,6 +11,7 @@ from sqlmodel import SQLModel from ....utils import get_timestamp +from fractal_server.images.models import AttributeFiltersType class DatasetV2(SQLModel, table=True): @@ -41,12 +42,14 @@ class Config: sa_column=Column(JSON, server_default="[]", nullable=False) ) - filters: dict[Literal["attributes", "types"], dict[str, Any]] = Field( - sa_column=Column( - JSON, - nullable=False, - server_default='{"attributes": {}, "types": {}}', - ) + filters: Optional[ + dict[Literal["attributes", "types"], dict[str, Any]] + ] = Field(sa_column=Column(JSON, nullable=True, server_default="null")) + type_filters: dict[str, bool] = Field( + sa_column=Column(JSON, nullable=False, server_default="{}") + ) + attribute_filters: AttributeFiltersType = Field( + sa_column=Column(JSON, 
nullable=False, server_default="{}") ) @property diff --git a/fractal_server/app/models/v2/job.py b/fractal_server/app/models/v2/job.py index 2b5789a53b..26efd12b07 100644 --- a/fractal_server/app/models/v2/job.py +++ b/fractal_server/app/models/v2/job.py @@ -10,6 +10,7 @@ from ....utils import get_timestamp from ...schemas.v2 import JobStatusTypeV2 +from fractal_server.images.models import AttributeFiltersType class JobV2(SQLModel, table=True): @@ -49,3 +50,7 @@ class Config: ) status: str = JobStatusTypeV2.SUBMITTED log: Optional[str] = None + + attribute_filters: AttributeFiltersType = Field( + sa_column=Column(JSON, nullable=False, server_default="{}") + ) diff --git a/fractal_server/app/models/v2/workflowtask.py b/fractal_server/app/models/v2/workflowtask.py index 32f64215a7..30d29ae356 100644 --- a/fractal_server/app/models/v2/workflowtask.py +++ b/fractal_server/app/models/v2/workflowtask.py @@ -25,14 +25,11 @@ class Config: args_parallel: Optional[dict[str, Any]] = Field(sa_column=Column(JSON)) args_non_parallel: Optional[dict[str, Any]] = Field(sa_column=Column(JSON)) - input_filters: dict[ - Literal["attributes", "types"], dict[str, Any] - ] = Field( - sa_column=Column( - JSON, - nullable=False, - server_default='{"attributes": {}, "types": {}}', - ) + input_filters: Optional[ + dict[Literal["attributes", "types"], dict[str, Any]] + ] = Field(sa_column=Column(JSON, nullable=True, server_default="null")) + type_filters: dict[str, bool] = Field( + sa_column=Column(JSON, nullable=False, server_default="{}") ) # Task diff --git a/fractal_server/app/routes/api/v2/_aux_functions.py b/fractal_server/app/routes/api/v2/_aux_functions.py index 63762d9cdd..358ac1d26e 100644 --- a/fractal_server/app/routes/api/v2/_aux_functions.py +++ b/fractal_server/app/routes/api/v2/_aux_functions.py @@ -21,7 +21,6 @@ from ....models.v2 import WorkflowTaskV2 from ....models.v2 import WorkflowV2 from ....schemas.v2 import JobStatusTypeV2 -from fractal_server.images import Filters async def _get_project_check_owner( @@ -336,7 +335,7 @@ async def _workflow_insert_task( meta_non_parallel: Optional[dict[str, Any]] = None, args_non_parallel: Optional[dict[str, Any]] = None, args_parallel: Optional[dict[str, Any]] = None, - input_filters: Optional[Filters] = None, + type_filters: Optional[dict[str, bool]] = None, db: AsyncSession, ) -> WorkflowTaskV2: """ @@ -350,7 +349,7 @@ async def _workflow_insert_task( meta_non_parallel: args_non_parallel: args_parallel: - input_filters: + type_filters: db: """ db_workflow = await db.get(WorkflowV2, workflow_id) @@ -376,12 +375,6 @@ async def _workflow_insert_task( if final_meta_non_parallel == {}: final_meta_non_parallel = None - # Prepare input_filters attribute - if input_filters is None: - input_filters_kwarg = {} - else: - input_filters_kwarg = dict(input_filters=input_filters) - # Create DB entry wf_task = WorkflowTaskV2( task_type=task_type, @@ -390,7 +383,7 @@ async def _workflow_insert_task( args_parallel=args_parallel, meta_parallel=final_meta_parallel, meta_non_parallel=final_meta_non_parallel, - **input_filters_kwarg, + type_filters=(type_filters or dict()), ) db_workflow.task_list.append(wf_task) flag_modified(db_workflow, "task_list") diff --git a/fractal_server/app/routes/api/v2/images.py b/fractal_server/app/routes/api/v2/images.py index 01ccd5e187..d41e925768 100644 --- a/fractal_server/app/routes/api/v2/images.py +++ b/fractal_server/app/routes/api/v2/images.py @@ -8,6 +8,8 @@ from fastapi import status from pydantic import BaseModel from pydantic import 
Field +from pydantic import root_validator +from pydantic import validator from sqlalchemy.orm.attributes import flag_modified from ._aux_functions import _get_dataset_check_owner @@ -15,9 +17,14 @@ from fractal_server.app.db import get_async_db from fractal_server.app.models import UserOAuth from fractal_server.app.routes.auth import current_active_user -from fractal_server.images import Filters +from fractal_server.app.schemas._filter_validators import ( + validate_attribute_filters, +) +from fractal_server.app.schemas._filter_validators import validate_type_filters +from fractal_server.app.schemas._validators import root_validate_dict_keys from fractal_server.images import SingleImage from fractal_server.images import SingleImageUpdate +from fractal_server.images.models import AttributeFiltersType from fractal_server.images.tools import find_image_by_zarr_url from fractal_server.images.tools import match_filter @@ -38,7 +45,18 @@ class ImagePage(BaseModel): class ImageQuery(BaseModel): zarr_url: Optional[str] - filters: Filters = Field(default_factory=Filters) + type_filters: dict[str, bool] = Field(default_factory=dict) + attribute_filters: AttributeFiltersType = Field(default_factory=dict) + + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _type_filters = validator("type_filters", allow_reuse=True)( + validate_type_filters + ) + _attribute_filters = validator("attribute_filters", allow_reuse=True)( + validate_attribute_filters + ) @router.post( @@ -124,7 +142,11 @@ async def query_dataset_images( images = [ image for image in images - if match_filter(image, Filters(**dataset.filters)) + if match_filter( + image=image, + type_filters=dataset.type_filters, + attribute_filters=dataset.attribute_filters, + ) ] attributes = {} @@ -154,13 +176,14 @@ async def query_dataset_images( else: images = [image] - if query.filters.attributes or query.filters.types: + if query.attribute_filters or query.type_filters: images = [ image for image in images if match_filter( - image, - Filters(**query.filters.dict()), + image=image, + type_filters=query.type_filters, + attribute_filters=query.attribute_filters, ) ] diff --git a/fractal_server/app/routes/api/v2/submit.py b/fractal_server/app/routes/api/v2/submit.py index fc34f2929b..2c6284d277 100644 --- a/fractal_server/app/routes/api/v2/submit.py +++ b/fractal_server/app/routes/api/v2/submit.py @@ -159,7 +159,11 @@ async def apply_workflow( dataset_id=dataset_id, workflow_id=workflow_id, user_email=user.email, - dataset_dump=json.loads(dataset.json(exclude={"images", "history"})), + # The 'filters' field is not supported any more but still exists as a + # database column, therefore we manually exclude it from dumps. 
+ dataset_dump=json.loads( + dataset.json(exclude={"images", "history", "filters"}) + ), workflow_dump=json.loads(workflow.json(exclude={"task_list"})), project_dump=json.loads(project.json(exclude={"user_list"})), **job_create.dict(), diff --git a/fractal_server/app/routes/api/v2/workflowtask.py b/fractal_server/app/routes/api/v2/workflowtask.py index d43bcf0a25..ee4ac3caab 100644 --- a/fractal_server/app/routes/api/v2/workflowtask.py +++ b/fractal_server/app/routes/api/v2/workflowtask.py @@ -109,7 +109,7 @@ async def replace_workflowtask( task_type=task.type, task=task, # old-task values - input_filters=old_workflow_task.input_filters, + type_filters=old_workflow_task.type_filters, # possibly new values args_non_parallel=_args_non_parallel, args_parallel=_args_parallel, @@ -183,7 +183,7 @@ async def create_workflowtask( meta_parallel=new_task.meta_parallel, args_non_parallel=new_task.args_non_parallel, args_parallel=new_task.args_parallel, - input_filters=new_task.input_filters, + type_filters=new_task.type_filters, db=db, ) @@ -274,7 +274,7 @@ async def update_workflowtask( if not actual_args: actual_args = None setattr(db_wf_task, key, actual_args) - elif key in ["meta_parallel", "meta_non_parallel", "input_filters"]: + elif key in ["meta_parallel", "meta_non_parallel", "type_filters"]: setattr(db_wf_task, key, value) else: raise HTTPException( diff --git a/fractal_server/app/runner/v2/__init__.py b/fractal_server/app/runner/v2/__init__.py index fee385c738..61438a9630 100644 --- a/fractal_server/app/runner/v2/__init__.py +++ b/fractal_server/app/runner/v2/__init__.py @@ -327,6 +327,7 @@ async def submit_workflow( worker_init=worker_init, first_task_index=job.first_task_index, last_task_index=job.last_task_index, + job_attribute_filters=job.attribute_filters, **backend_specific_kwargs, ) diff --git a/fractal_server/app/runner/v2/_local/__init__.py b/fractal_server/app/runner/v2/_local/__init__.py index 74ca9fc78a..5230b870e2 100644 --- a/fractal_server/app/runner/v2/_local/__init__.py +++ b/fractal_server/app/runner/v2/_local/__init__.py @@ -29,6 +29,7 @@ from ..runner import execute_tasks_v2 from ._submit_setup import _local_submit_setup from .executor import FractalThreadPoolExecutor +from fractal_server.images.models import AttributeFiltersType def _process_workflow( @@ -39,6 +40,7 @@ def _process_workflow( workflow_dir_local: Path, first_task_index: int, last_task_index: int, + job_attribute_filters: AttributeFiltersType, ) -> None: """ Run the workflow using a `FractalThreadPoolExecutor`. 
@@ -54,6 +56,7 @@ def _process_workflow( workflow_dir_remote=workflow_dir_local, logger_name=logger_name, submit_setup_call=_local_submit_setup, + job_attribute_filters=job_attribute_filters, ) @@ -66,6 +69,7 @@ async def process_workflow( first_task_index: Optional[int] = None, last_task_index: Optional[int] = None, logger_name: str, + job_attribute_filters: AttributeFiltersType, # Slurm-specific user_cache_dir: Optional[str] = None, slurm_user: Optional[str] = None, @@ -146,4 +150,5 @@ async def process_workflow( workflow_dir_local=workflow_dir_local, first_task_index=first_task_index, last_task_index=last_task_index, + job_attribute_filters=job_attribute_filters, ) diff --git a/fractal_server/app/runner/v2/_local_experimental/__init__.py b/fractal_server/app/runner/v2/_local_experimental/__init__.py index c298779ec2..36a20eeb4b 100644 --- a/fractal_server/app/runner/v2/_local_experimental/__init__.py +++ b/fractal_server/app/runner/v2/_local_experimental/__init__.py @@ -11,6 +11,7 @@ from ..runner import execute_tasks_v2 from ._submit_setup import _local_submit_setup from .executor import FractalProcessPoolExecutor +from fractal_server.images.models import AttributeFiltersType def _process_workflow( @@ -21,6 +22,7 @@ def _process_workflow( workflow_dir_local: Path, first_task_index: int, last_task_index: int, + job_attribute_filters: AttributeFiltersType, ) -> None: """ Run the workflow using a `FractalProcessPoolExecutor`. @@ -39,6 +41,7 @@ def _process_workflow( workflow_dir_remote=workflow_dir_local, logger_name=logger_name, submit_setup_call=_local_submit_setup, + job_attribute_filters=job_attribute_filters, ) except BrokenProcessPool as e: raise JobExecutionError( @@ -58,6 +61,7 @@ async def process_workflow( first_task_index: Optional[int] = None, last_task_index: Optional[int] = None, logger_name: str, + job_attribute_filters: AttributeFiltersType, # Slurm-specific user_cache_dir: Optional[str] = None, slurm_user: Optional[str] = None, @@ -138,4 +142,5 @@ async def process_workflow( workflow_dir_local=workflow_dir_local, first_task_index=first_task_index, last_task_index=last_task_index, + job_attribute_filters=job_attribute_filters, ) diff --git a/fractal_server/app/runner/v2/_slurm_ssh/__init__.py b/fractal_server/app/runner/v2/_slurm_ssh/__init__.py index def74eb560..8cfeeedecb 100644 --- a/fractal_server/app/runner/v2/_slurm_ssh/__init__.py +++ b/fractal_server/app/runner/v2/_slurm_ssh/__init__.py @@ -29,9 +29,9 @@ from ...set_start_and_last_task_index import set_start_and_last_task_index from ..runner import execute_tasks_v2 from ._submit_setup import _slurm_submit_setup +from fractal_server.images.models import AttributeFiltersType from fractal_server.logger import set_logger - logger = set_logger(__name__) @@ -46,6 +46,7 @@ def _process_workflow( last_task_index: int, fractal_ssh: FractalSSH, worker_init: Optional[Union[str, list[str]]] = None, + job_attribute_filters: AttributeFiltersType, ) -> None: """ Run the workflow using a `FractalSlurmSSHExecutor`. 
@@ -86,6 +87,7 @@ def _process_workflow( workflow_dir_remote=workflow_dir_remote, logger_name=logger_name, submit_setup_call=_slurm_submit_setup, + job_attribute_filters=job_attribute_filters, ) @@ -98,12 +100,13 @@ async def process_workflow( first_task_index: Optional[int] = None, last_task_index: Optional[int] = None, logger_name: str, - # Not used + job_attribute_filters: AttributeFiltersType, fractal_ssh: FractalSSH, + worker_init: Optional[str] = None, + # Not used user_cache_dir: Optional[str] = None, slurm_user: Optional[str] = None, slurm_account: Optional[str] = None, - worker_init: Optional[str] = None, ) -> None: """ Process workflow (SLURM backend public interface) @@ -127,4 +130,5 @@ async def process_workflow( last_task_index=last_task_index, worker_init=worker_init, fractal_ssh=fractal_ssh, + job_attribute_filters=job_attribute_filters, ) diff --git a/fractal_server/app/runner/v2/_slurm_sudo/__init__.py b/fractal_server/app/runner/v2/_slurm_sudo/__init__.py index 2fafc67ef5..2653bc2a5e 100644 --- a/fractal_server/app/runner/v2/_slurm_sudo/__init__.py +++ b/fractal_server/app/runner/v2/_slurm_sudo/__init__.py @@ -27,6 +27,7 @@ from ...set_start_and_last_task_index import set_start_and_last_task_index from ..runner import execute_tasks_v2 from ._submit_setup import _slurm_submit_setup +from fractal_server.images.models import AttributeFiltersType def _process_workflow( @@ -42,6 +43,7 @@ def _process_workflow( slurm_account: Optional[str] = None, user_cache_dir: str, worker_init: Optional[Union[str, list[str]]] = None, + job_attribute_filters: AttributeFiltersType, ) -> None: """ Run the workflow using a `FractalSlurmExecutor`. @@ -79,6 +81,7 @@ def _process_workflow( workflow_dir_remote=workflow_dir_remote, logger_name=logger_name, submit_setup_call=_slurm_submit_setup, + job_attribute_filters=job_attribute_filters, ) @@ -91,6 +94,7 @@ async def process_workflow( first_task_index: Optional[int] = None, last_task_index: Optional[int] = None, logger_name: str, + job_attribute_filters: AttributeFiltersType, # Slurm-specific user_cache_dir: Optional[str] = None, slurm_user: Optional[str] = None, @@ -120,4 +124,5 @@ async def process_workflow( slurm_user=slurm_user, slurm_account=slurm_account, worker_init=worker_init, + job_attribute_filters=job_attribute_filters, ) diff --git a/fractal_server/app/runner/v2/merge_outputs.py b/fractal_server/app/runner/v2/merge_outputs.py index bf84c94b8b..c1a6cbe2cf 100644 --- a/fractal_server/app/runner/v2/merge_outputs.py +++ b/fractal_server/app/runner/v2/merge_outputs.py @@ -1,38 +1,35 @@ -from copy import copy - from fractal_server.app.runner.v2.deduplicate_list import deduplicate_list from fractal_server.app.runner.v2.task_interface import TaskOutput def merge_outputs(task_outputs: list[TaskOutput]) -> TaskOutput: + if len(task_outputs) == 0: + return TaskOutput() + final_image_list_updates = [] final_image_list_removals = [] - last_new_filters = None - for ind, task_output in enumerate(task_outputs): + for task_output in task_outputs: final_image_list_updates.extend(task_output.image_list_updates) final_image_list_removals.extend(task_output.image_list_removals) - # Check that all filters are the same - current_new_filters = task_output.filters - if ind == 0: - last_new_filters = copy(current_new_filters) - if current_new_filters != last_new_filters: - raise ValueError(f"{current_new_filters=} but {last_new_filters=}") - last_new_filters = copy(current_new_filters) + # Check that all type_filters are the same + if task_output.type_filters 
!= task_outputs[0].type_filters: + raise ValueError( + f"{task_output.type_filters=} " + f"but {task_outputs[0].type_filters=}" + ) + # Note: the ordering of `image_list_removals` is not guaranteed final_image_list_updates = deduplicate_list(final_image_list_updates) - - additional_args = {} - if last_new_filters is not None: - additional_args["filters"] = last_new_filters + final_image_list_removals = list(set(final_image_list_removals)) final_output = TaskOutput( image_list_updates=final_image_list_updates, image_list_removals=final_image_list_removals, - **additional_args, + type_filters=task_outputs[0].type_filters, ) return final_output diff --git a/fractal_server/app/runner/v2/runner.py b/fractal_server/app/runner/v2/runner.py index 93c645d6c3..25fe1f80aa 100644 --- a/fractal_server/app/runner/v2/runner.py +++ b/fractal_server/app/runner/v2/runner.py @@ -8,7 +8,6 @@ from sqlalchemy.orm.attributes import flag_modified -from ....images import Filters from ....images import SingleImage from ....images.tools import filter_image_list from ....images.tools import find_image_by_zarr_url @@ -24,9 +23,11 @@ from fractal_server.app.models.v2 import WorkflowTaskV2 from fractal_server.app.schemas.v2.dataset import _DatasetHistoryItemV2 from fractal_server.app.schemas.v2.workflowtask import WorkflowTaskStatusTypeV2 +from fractal_server.images.models import AttributeFiltersType def execute_tasks_v2( + *, wf_task_list: list[WorkflowTaskV2], dataset: DatasetV2, executor: ThreadPoolExecutor, @@ -34,6 +35,7 @@ def execute_tasks_v2( workflow_dir_remote: Optional[Path] = None, logger_name: Optional[str] = None, submit_setup_call: Callable = no_op_submit_setup_call, + job_attribute_filters: AttributeFiltersType, ) -> None: logger = logging.getLogger(logger_name) @@ -47,7 +49,7 @@ def execute_tasks_v2( # Initialize local dataset attributes zarr_dir = dataset.zarr_dir tmp_images = deepcopy(dataset.images) - tmp_filters = deepcopy(dataset.filters) + tmp_type_filters = deepcopy(dataset.type_filters) for wftask in wf_task_list: task = wftask.task @@ -57,19 +59,20 @@ def execute_tasks_v2( # PRE TASK EXECUTION # Get filtered images - pre_filters = dict( - types=copy(tmp_filters["types"]), - attributes=copy(tmp_filters["attributes"]), - ) - pre_filters["types"].update(wftask.input_filters["types"]) - pre_filters["attributes"].update(wftask.input_filters["attributes"]) + pre_type_filters = copy(tmp_type_filters) + pre_type_filters.update(wftask.type_filters) filtered_images = filter_image_list( images=tmp_images, - filters=Filters(**pre_filters), + type_filters=pre_type_filters, + attribute_filters=job_attribute_filters, ) # Verify that filtered images comply with task input_types for image in filtered_images: - if not match_filter(image, Filters(types=task.input_types)): + if not match_filter( + image=image, + type_filters=task.input_types, + attribute_filters={}, + ): raise JobExecutionError( "Invalid filtered image list\n" f"Task input types: {task.input_types=}\n" @@ -259,38 +262,30 @@ def execute_tasks_v2( else: tmp_images.pop(img_search["index"]) - # Update filters.attributes: - # current + (task_output: not really, in current examples..) 
- if current_task_output.filters is not None: - tmp_filters["attributes"].update( - current_task_output.filters.attributes - ) - - # Find manifest ouptut types - types_from_manifest = task.output_types + # Update type_filters - # Find task-output types - if current_task_output.filters is not None: - types_from_task = current_task_output.filters.types - else: - types_from_task = {} + # Assign the type filters based on different sources + # (task manifest and post-execution task output) + type_filters_from_task_manifest = task.output_types + type_filters_from_task_output = current_task_output.type_filters # Check that key sets are disjoint - set_types_from_manifest = set(types_from_manifest.keys()) - set_types_from_task = set(types_from_task.keys()) - if not set_types_from_manifest.isdisjoint(set_types_from_task): - overlap = set_types_from_manifest.intersection(set_types_from_task) + keys_from_manifest = set(type_filters_from_task_manifest.keys()) + keys_from_task_output = set(type_filters_from_task_output.keys()) + if not keys_from_manifest.isdisjoint(keys_from_task_output): + overlap = keys_from_manifest.intersection(keys_from_task_output) raise JobExecutionError( "Some type filters are being set twice, " f"for task '{task_name}'.\n" - f"Types from task output: {types_from_task}\n" - f"Types from task maniest: {types_from_manifest}\n" + f"Types from task output: {type_filters_from_task_output}\n" + "Types from task manifest: " + f"{type_filters_from_task_manifest}\n" f"Overlapping keys: {overlap}" ) # Update filters.types - tmp_filters["types"].update(types_from_manifest) - tmp_filters["types"].update(types_from_task) + tmp_type_filters.update(type_filters_from_task_manifest) + tmp_type_filters.update(type_filters_from_task_output) # Write current dataset attributes (history, images, filters) into the # database. They can be used (1) to retrieve the latest state @@ -299,9 +294,13 @@ def execute_tasks_v2( with next(get_sync_db()) as db: db_dataset = db.get(DatasetV2, dataset.id) db_dataset.history[-1]["status"] = WorkflowTaskStatusTypeV2.DONE - db_dataset.filters = tmp_filters + db_dataset.type_filters = tmp_type_filters db_dataset.images = tmp_images - for attribute_name in ["filters", "history", "images"]: + for attribute_name in [ + "type_filters", + "history", + "images", + ]: flag_modified(db_dataset, attribute_name) db.merge(db_dataset) db.commit() diff --git a/fractal_server/app/runner/v2/task_interface.py b/fractal_server/app/runner/v2/task_interface.py index ab1a92aa90..f00523a72d 100644 --- a/fractal_server/app/runner/v2/task_interface.py +++ b/fractal_server/app/runner/v2/task_interface.py @@ -1,22 +1,47 @@ from typing import Any +from typing import Optional from pydantic import BaseModel from pydantic import Extra from pydantic import Field +from pydantic import root_validator from pydantic import validator from ....images import SingleImageTaskOutput -from fractal_server.images import Filters +from fractal_server.app.schemas._filter_validators import validate_type_filters +from fractal_server.app.schemas._validators import root_validate_dict_keys from fractal_server.urls import normalize_url +class LegacyFilters(BaseModel, extra=Extra.forbid): + """ + For fractal-server<2.11, task output could include both + `filters["attributes"]` and `filters["types"]`. In the new version + there is a single field, named `type_filters`. + The current schema is only used to convert old type filters into the + new form, but it will reject any attribute filters. 
+ """ + + types: dict[str, bool] = Field(default_factory=dict) + _types = validator("types", allow_reuse=True)(validate_type_filters) + + class TaskOutput(BaseModel, extra=Extra.forbid): image_list_updates: list[SingleImageTaskOutput] = Field( default_factory=list ) image_list_removals: list[str] = Field(default_factory=list) - filters: Filters = Field(default_factory=Filters) + + filters: Optional[LegacyFilters] = None + type_filters: dict[str, bool] = Field(default_factory=dict) + + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _type_filters = validator("type_filters", allow_reuse=True)( + validate_type_filters + ) def check_zarr_urls_are_unique(self) -> None: zarr_urls = [img.zarr_url for img in self.image_list_updates] @@ -37,6 +62,20 @@ def check_zarr_urls_are_unique(self) -> None: msg = f"{msg}\n{duplicate}" raise ValueError(msg) + @root_validator() + def update_legacy_filters(cls, values): + if values["filters"] is not None: + if values["type_filters"] != {}: + raise ValueError( + "Cannot set both (legacy) 'filters' and 'type_filters'." + ) + else: + # Convert legacy filters.types into new type_filters + values["type_filters"] = values["filters"].types + values["filters"] = None + + return values + @validator("image_list_removals") def normalize_paths(cls, v: list[str]) -> list[str]: return [normalize_url(zarr_url) for zarr_url in v] diff --git a/fractal_server/app/schemas/_filter_validators.py b/fractal_server/app/schemas/_filter_validators.py new file mode 100644 index 0000000000..c862b9c049 --- /dev/null +++ b/fractal_server/app/schemas/_filter_validators.py @@ -0,0 +1,47 @@ +from typing import Optional + +from ._validators import valdict_keys +from fractal_server.images.models import AttributeFiltersType + + +def validate_type_filters( + type_filters: Optional[dict[str, bool]] +) -> dict[str, bool]: + if type_filters is None: + raise ValueError("'type_filters' cannot be 'None'.") + + type_filters = valdict_keys("type_filters")(type_filters) + return type_filters + + +def validate_attribute_filters( + attribute_filters: Optional[AttributeFiltersType], +) -> AttributeFiltersType: + if attribute_filters is None: + raise ValueError("'attribute_filters' cannot be 'None'.") + + attribute_filters = valdict_keys("attribute_filters")(attribute_filters) + for key, values in attribute_filters.items(): + if values is None: + # values=None corresponds to not applying any filter for + # attribute `key` + pass + elif values == []: + # WARNING: in this case, no image can match with the current + # filter. In the future we may deprecate this possibility. + pass + else: + # values is a non-empty list, and its items must all be of the + # same scalar non-None type + _type = type(values[0]) + if not all(isinstance(value, _type) for value in values): + raise ValueError( + f"attribute_filters[{key}] has values with " + f"non-homogeneous types: {values}." + ) + if _type not in (int, float, str, bool): + raise ValueError( + f"attribute_filters[{key}] has values with " + f"invalid types: {values}." 
+ ) + return attribute_filters diff --git a/fractal_server/app/schemas/_validators.py b/fractal_server/app/schemas/_validators.py index bdce6d3507..b2c9d6e65e 100644 --- a/fractal_server/app/schemas/_validators.py +++ b/fractal_server/app/schemas/_validators.py @@ -27,7 +27,7 @@ def val(string: Optional[str]) -> Optional[str]: return val -def valdictkeys(attribute: str): +def valdict_keys(attribute: str): def val(d: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: """ Apply valstr to every key of the dictionary, and fail if there are @@ -38,7 +38,7 @@ def val(d: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: new_keys = [valstr(f"{attribute}[{key}]")(key) for key in old_keys] if len(new_keys) != len(set(new_keys)): raise ValueError( - f"Dictionary contains multiple identical keys: {d}." + f"Dictionary contains multiple identical keys: '{d}'." ) for old_key, new_key in zip(old_keys, new_keys): if new_key != old_key: @@ -101,3 +101,14 @@ def val(must_be_unique: Optional[list]) -> Optional[list]: return must_be_unique return val + + +def root_validate_dict_keys(cls, object: dict) -> dict: + """ + For each dictionary in `object.values()`, + checks that that dictionary has only keys of type str. + """ + for dictionary in (v for v in object.values() if isinstance(v, dict)): + if not all(isinstance(key, str) for key in dictionary.keys()): + raise ValueError("Dictionary keys must be strings.") + return object diff --git a/fractal_server/app/schemas/v2/dataset.py b/fractal_server/app/schemas/v2/dataset.py index e680106e30..6b3eb3cb95 100644 --- a/fractal_server/app/schemas/v2/dataset.py +++ b/fractal_server/app/schemas/v2/dataset.py @@ -1,17 +1,22 @@ from datetime import datetime +from typing import Any from typing import Optional from pydantic import BaseModel from pydantic import Extra from pydantic import Field +from pydantic import root_validator from pydantic import validator +from .._filter_validators import validate_attribute_filters +from .._filter_validators import validate_type_filters +from .._validators import root_validate_dict_keys from .._validators import valstr from .dumps import WorkflowTaskDumpV2 from .project import ProjectReadV2 from .workflowtask import WorkflowTaskStatusTypeV2 -from fractal_server.images import Filters from fractal_server.images import SingleImage +from fractal_server.images.models import AttributeFiltersType from fractal_server.urls import normalize_url @@ -34,17 +39,29 @@ class DatasetCreateV2(BaseModel, extra=Extra.forbid): zarr_dir: Optional[str] = None - filters: Filters = Field(default_factory=Filters) + type_filters: dict[str, bool] = Field(default_factory=dict) + attribute_filters: AttributeFiltersType = Field(default_factory=dict) # Validators + + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _type_filters = validator("type_filters", allow_reuse=True)( + validate_type_filters + ) + _attribute_filters = validator("attribute_filters", allow_reuse=True)( + validate_attribute_filters + ) + + _name = validator("name", allow_reuse=True)(valstr("name")) + @validator("zarr_dir") def normalize_zarr_dir(cls, v: Optional[str]) -> Optional[str]: if v is not None: return normalize_url(v) return v - _name = validator("name", allow_reuse=True)(valstr("name")) - class DatasetReadV2(BaseModel): @@ -59,24 +76,37 @@ class DatasetReadV2(BaseModel): timestamp_created: datetime zarr_dir: str - filters: Filters = Field(default_factory=Filters) + type_filters: dict[str, bool] + attribute_filters: AttributeFiltersType 
class DatasetUpdateV2(BaseModel, extra=Extra.forbid): name: Optional[str] zarr_dir: Optional[str] - filters: Optional[Filters] + type_filters: Optional[dict[str, bool]] + attribute_filters: Optional[dict[str, list[Any]]] # Validators + + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _type_filters = validator("type_filters", allow_reuse=True)( + validate_type_filters + ) + _attribute_filters = validator("attribute_filters", allow_reuse=True)( + validate_attribute_filters + ) + + _name = validator("name", allow_reuse=True)(valstr("name")) + @validator("zarr_dir") def normalize_zarr_dir(cls, v: Optional[str]) -> Optional[str]: if v is not None: return normalize_url(v) return v - _name = validator("name", allow_reuse=True)(valstr("name")) - class DatasetImportV2(BaseModel, extra=Extra.forbid): """ @@ -86,15 +116,29 @@ class DatasetImportV2(BaseModel, extra=Extra.forbid): name: zarr_dir: images: - filters: + type_filters: + attribute_filters: """ name: str zarr_dir: str images: list[SingleImage] = Field(default_factory=list) - filters: Filters = Field(default_factory=Filters) + + type_filters: dict[str, bool] = Field(default_factory=dict) + attribute_filters: AttributeFiltersType = Field(default_factory=dict) # Validators + + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _type_filters = validator("type_filters", allow_reuse=True)( + validate_type_filters + ) + _attribute_filters = validator("attribute_filters", allow_reuse=True)( + validate_attribute_filters + ) + @validator("zarr_dir") def normalize_zarr_dir(cls, v: str) -> str: return normalize_url(v) @@ -108,10 +152,12 @@ class DatasetExportV2(BaseModel): name: zarr_dir: images: - filters: + type_filters: + attribute_filters: """ name: str zarr_dir: str images: list[SingleImage] - filters: Filters + type_filters: dict[str, bool] + attribute_filters: AttributeFiltersType diff --git a/fractal_server/app/schemas/v2/dumps.py b/fractal_server/app/schemas/v2/dumps.py index d72feec30e..d2f61c20fe 100644 --- a/fractal_server/app/schemas/v2/dumps.py +++ b/fractal_server/app/schemas/v2/dumps.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from pydantic import Extra -from fractal_server.images import Filters +from fractal_server.images.models import AttributeFiltersType class ProjectDumpV2(BaseModel, extra=Extra.forbid): @@ -39,19 +39,16 @@ class TaskDumpV2(BaseModel): class WorkflowTaskDumpV2(BaseModel): """ - Before v2.5.0, WorkflowTaskV2 could have `task_id=task=None` and - non-`None` `task_legacy_id` and `task_legacy`. Since these objects - may still exist in the database after version updates, we are setting - `task_id` and `task` to `Optional` to avoid response-validation errors + We do not include 'extra=Extra.forbid' because legacy data may include + 'input_filters' field and we want to avoid response-validation errors for the endpoints that GET datasets. - Ref issue #1783. 
""" id: int workflow_id: int order: Optional[int] - input_filters: Filters + type_filters: dict[str, bool] task_id: Optional[int] task: Optional[TaskDumpV2] @@ -71,4 +68,5 @@ class DatasetDumpV2(BaseModel, extra=Extra.forbid): timestamp_created: str zarr_dir: str - filters: Filters + type_filters: dict[str, bool] + attribute_filters: AttributeFiltersType diff --git a/fractal_server/app/schemas/v2/job.py b/fractal_server/app/schemas/v2/job.py index e09232ea44..6625c6318f 100644 --- a/fractal_server/app/schemas/v2/job.py +++ b/fractal_server/app/schemas/v2/job.py @@ -4,13 +4,18 @@ from pydantic import BaseModel from pydantic import Extra +from pydantic import Field +from pydantic import root_validator from pydantic import validator from pydantic.types import StrictStr +from .._filter_validators import validate_attribute_filters +from .._validators import root_validate_dict_keys from .._validators import valstr from .dumps import DatasetDumpV2 from .dumps import ProjectDumpV2 from .dumps import WorkflowDumpV2 +from fractal_server.images.models import AttributeFiltersType class JobStatusTypeV2(str, Enum): @@ -41,10 +46,18 @@ class JobCreateV2(BaseModel, extra=Extra.forbid): slurm_account: Optional[StrictStr] = None worker_init: Optional[str] + attribute_filters: AttributeFiltersType = Field(default_factory=dict) + # Validators _worker_init = validator("worker_init", allow_reuse=True)( valstr("worker_init") ) + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _attribute_filters = validator("attribute_filters", allow_reuse=True)( + validate_attribute_filters + ) @validator("first_task_index", always=True) def first_task_index_non_negative(cls, v, values): @@ -99,6 +112,7 @@ class JobReadV2(BaseModel): first_task_index: Optional[int] last_task_index: Optional[int] worker_init: Optional[str] + attribute_filters: AttributeFiltersType class JobUpdateV2(BaseModel, extra=Extra.forbid): diff --git a/fractal_server/app/schemas/v2/task.py b/fractal_server/app/schemas/v2/task.py index 166aec66e4..0ae5377061 100644 --- a/fractal_server/app/schemas/v2/task.py +++ b/fractal_server/app/schemas/v2/task.py @@ -10,7 +10,7 @@ from pydantic import validator from fractal_server.app.schemas._validators import val_unique_list -from fractal_server.app.schemas._validators import valdictkeys +from fractal_server.app.schemas._validators import valdict_keys from fractal_server.app.schemas._validators import valstr from fractal_server.string_tools import validate_cmd @@ -66,25 +66,25 @@ def validate_commands(cls, values): _version = validator("version", allow_reuse=True)(valstr("version")) _meta_non_parallel = validator("meta_non_parallel", allow_reuse=True)( - valdictkeys("meta_non_parallel") + valdict_keys("meta_non_parallel") ) _meta_parallel = validator("meta_parallel", allow_reuse=True)( - valdictkeys("meta_parallel") + valdict_keys("meta_parallel") ) _args_schema_non_parallel = validator( "args_schema_non_parallel", allow_reuse=True - )(valdictkeys("args_schema_non_parallel")) + )(valdict_keys("args_schema_non_parallel")) _args_schema_parallel = validator( "args_schema_parallel", allow_reuse=True - )(valdictkeys("args_schema_parallel")) + )(valdict_keys("args_schema_parallel")) _args_schema_version = validator("args_schema_version", allow_reuse=True)( valstr("args_schema_version") ) _input_types = validator("input_types", allow_reuse=True)( - valdictkeys("input_types") + valdict_keys("input_types") ) _output_types = validator("output_types", allow_reuse=True)( - 
valdictkeys("output_types") + valdict_keys("output_types") ) _category = validator("category", allow_reuse=True)( @@ -158,10 +158,10 @@ def val_is_dict(cls, v): "command_non_parallel", allow_reuse=True )(valstr("command_non_parallel")) _input_types = validator("input_types", allow_reuse=True)( - valdictkeys("input_types") + valdict_keys("input_types") ) _output_types = validator("output_types", allow_reuse=True)( - valdictkeys("output_types") + valdict_keys("output_types") ) _category = validator("category", allow_reuse=True)( diff --git a/fractal_server/app/schemas/v2/task_group.py b/fractal_server/app/schemas/v2/task_group.py index 254f6e9d17..c744050227 100644 --- a/fractal_server/app/schemas/v2/task_group.py +++ b/fractal_server/app/schemas/v2/task_group.py @@ -8,7 +8,7 @@ from pydantic import validator from .._validators import val_absolute_path -from .._validators import valdictkeys +from .._validators import valdict_keys from .._validators import valstr from .task import TaskReadV2 @@ -57,7 +57,7 @@ class TaskGroupCreateV2(BaseModel, extra=Extra.forbid): ) _pinned_package_versions = validator( "pinned_package_versions", allow_reuse=True - )(valdictkeys("pinned_package_versions")) + )(valdict_keys("pinned_package_versions")) _pip_extras = validator("pip_extras", allow_reuse=True)( valstr("pip_extras") ) diff --git a/fractal_server/app/schemas/v2/workflowtask.py b/fractal_server/app/schemas/v2/workflowtask.py index 46a6e51d2a..d8c2f37d15 100644 --- a/fractal_server/app/schemas/v2/workflowtask.py +++ b/fractal_server/app/schemas/v2/workflowtask.py @@ -6,14 +6,16 @@ from pydantic import BaseModel from pydantic import Extra from pydantic import Field +from pydantic import root_validator from pydantic import validator -from .._validators import valdictkeys +from .._filter_validators import validate_type_filters +from .._validators import root_validate_dict_keys +from .._validators import valdict_keys from .task import TaskExportV2 from .task import TaskImportV2 from .task import TaskImportV2Legacy from .task import TaskReadV2 -from fractal_server.images import Filters RESERVED_ARGUMENTS = {"zarr_dir", "zarr_url", "zarr_urls", "init_args"} @@ -43,21 +45,28 @@ class WorkflowTaskCreateV2(BaseModel, extra=Extra.forbid): meta_parallel: Optional[dict[str, Any]] args_non_parallel: Optional[dict[str, Any]] args_parallel: Optional[dict[str, Any]] - input_filters: Filters = Field(default_factory=Filters) + type_filters: dict[str, bool] = Field(default_factory=dict) # Validators + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _type_filters = validator("type_filters", allow_reuse=True)( + validate_type_filters + ) + _meta_non_parallel = validator("meta_non_parallel", allow_reuse=True)( - valdictkeys("meta_non_parallel") + valdict_keys("meta_non_parallel") ) _meta_parallel = validator("meta_parallel", allow_reuse=True)( - valdictkeys("meta_parallel") + valdict_keys("meta_parallel") ) @validator("args_non_parallel") def validate_args_non_parallel(cls, value): if value is None: return - valdictkeys("args_non_parallel")(value) + valdict_keys("args_non_parallel")(value) args_keys = set(value.keys()) intersect_keys = RESERVED_ARGUMENTS.intersection(args_keys) if intersect_keys: @@ -71,7 +80,7 @@ def validate_args_non_parallel(cls, value): def validate_args_parallel(cls, value): if value is None: return - valdictkeys("args_parallel")(value) + valdict_keys("args_parallel")(value) args_keys = set(value.keys()) intersect_keys = 
RESERVED_ARGUMENTS.intersection(args_keys) if intersect_keys: @@ -101,7 +110,7 @@ class WorkflowTaskReadV2(BaseModel): args_non_parallel: Optional[dict[str, Any]] args_parallel: Optional[dict[str, Any]] - input_filters: Filters + type_filters: dict[str, bool] task_type: str task_id: int @@ -118,21 +127,28 @@ class WorkflowTaskUpdateV2(BaseModel, extra=Extra.forbid): meta_parallel: Optional[dict[str, Any]] args_non_parallel: Optional[dict[str, Any]] args_parallel: Optional[dict[str, Any]] - input_filters: Optional[Filters] + type_filters: Optional[dict[str, bool]] # Validators + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _type_filters = validator("type_filters", allow_reuse=True)( + validate_type_filters + ) + _meta_non_parallel = validator("meta_non_parallel", allow_reuse=True)( - valdictkeys("meta_non_parallel") + valdict_keys("meta_non_parallel") ) _meta_parallel = validator("meta_parallel", allow_reuse=True)( - valdictkeys("meta_parallel") + valdict_keys("meta_parallel") ) @validator("args_non_parallel") def validate_args_non_parallel(cls, value): if value is None: return - valdictkeys("args_non_parallel")(value) + valdict_keys("args_non_parallel")(value) args_keys = set(value.keys()) intersect_keys = RESERVED_ARGUMENTS.intersection(args_keys) if intersect_keys: @@ -146,7 +162,7 @@ def validate_args_non_parallel(cls, value): def validate_args_parallel(cls, value): if value is None: return - valdictkeys("args_parallel")(value) + valdict_keys("args_parallel")(value) args_keys = set(value.keys()) intersect_keys = RESERVED_ARGUMENTS.intersection(args_keys) if intersect_keys: @@ -164,21 +180,28 @@ class WorkflowTaskImportV2(BaseModel, extra=Extra.forbid): args_non_parallel: Optional[dict[str, Any]] = None args_parallel: Optional[dict[str, Any]] = None - input_filters: Optional[Filters] = None + type_filters: Optional[dict[str, bool]] = None task: Union[TaskImportV2, TaskImportV2Legacy] + _dict_keys = root_validator(pre=True, allow_reuse=True)( + root_validate_dict_keys + ) + _type_filters = validator("type_filters", allow_reuse=True)( + validate_type_filters + ) + _meta_non_parallel = validator("meta_non_parallel", allow_reuse=True)( - valdictkeys("meta_non_parallel") + valdict_keys("meta_non_parallel") ) _meta_parallel = validator("meta_parallel", allow_reuse=True)( - valdictkeys("meta_parallel") + valdict_keys("meta_parallel") ) _args_non_parallel = validator("args_non_parallel", allow_reuse=True)( - valdictkeys("args_non_parallel") + valdict_keys("args_non_parallel") ) _args_parallel = validator("args_parallel", allow_reuse=True)( - valdictkeys("args_parallel") + valdict_keys("args_parallel") ) @@ -188,6 +211,6 @@ class WorkflowTaskExportV2(BaseModel): meta_parallel: Optional[dict[str, Any]] = None args_non_parallel: Optional[dict[str, Any]] = None args_parallel: Optional[dict[str, Any]] = None - input_filters: Filters = Field(default_factory=Filters) + type_filters: dict[str, bool] = Field(default_factory=dict) task: TaskExportV2 diff --git a/fractal_server/data_migrations/2_11_0.py b/fractal_server/data_migrations/2_11_0.py new file mode 100644 index 0000000000..6033b103c5 --- /dev/null +++ b/fractal_server/data_migrations/2_11_0.py @@ -0,0 +1,67 @@ +import logging + +from sqlalchemy.orm.attributes import flag_modified +from sqlmodel import select + +from fractal_server.app.db import get_sync_db +from fractal_server.app.models import DatasetV2 +from fractal_server.app.models import JobV2 +from fractal_server.app.models import WorkflowTaskV2 
+ +logger = logging.getLogger("fix_db") +logger.setLevel(logging.INFO) + + +def fix_db(): + + logger.info("START execution of fix_db function") + + with next(get_sync_db()) as db: + + # DatasetV2.filters + # DatasetV2.history[].workflowtask.input_filters + stm = select(DatasetV2).order_by(DatasetV2.id) + datasets = db.execute(stm).scalars().all() + for ds in datasets: + ds.attribute_filters = ds.filters["attributes"] + ds.type_filters = ds.filters["types"] + ds.filters = None + for i, h in enumerate(ds.history): + ds.history[i]["workflowtask"]["type_filters"] = h[ + "workflowtask" + ]["input_filters"]["types"] + flag_modified(ds, "history") + db.add(ds) + logger.info(f"Fixed filters in DatasetV2[{ds.id}]") + + # WorkflowTaskV2.input_filters + stm = select(WorkflowTaskV2).order_by(WorkflowTaskV2.id) + wftasks = db.execute(stm).scalars().all() + for wft in wftasks: + wft.type_filters = wft.input_filters["types"] + if wft.input_filters["attributes"]: + logger.warning( + f"Removing WorkflowTaskV2[{wft.id}].input_filters" + f"['attributes'] = {wft.input_filters['attributes']}" + ) + wft.input_filters = None + flag_modified(wft, "input_filters") + db.add(wft) + logger.info(f"Fixed filters in WorkflowTaskV2[{wft.id}]") + + # JOBS V2 + stm = select(JobV2).order_by(JobV2.id) + jobs = db.execute(stm).scalars().all() + for job in jobs: + job.dataset_dump["type_filters"] = job.dataset_dump["filters"][ + "types" + ] + job.dataset_dump["attribute_filters"] = job.dataset_dump[ + "filters" + ]["attributes"] + job.dataset_dump.pop("filters") + flag_modified(job, "dataset_dump") + logger.info(f"Fixed filters in JobV2[{job.id}].datasetdump") + + db.commit() + logger.info("Changes committed.") diff --git a/fractal_server/images/__init__.py b/fractal_server/images/__init__.py index 1d2e8dc55d..960ac2105a 100644 --- a/fractal_server/images/__init__.py +++ b/fractal_server/images/__init__.py @@ -1,4 +1,3 @@ -from .models import Filters # noqa: F401 from .models import SingleImage # noqa: F401 from .models import SingleImageTaskOutput # noqa: F401 from .models import SingleImageUpdate # noqa: F401 diff --git a/fractal_server/images/models.py b/fractal_server/images/models.py index 32e289d42d..4536d4f495 100644 --- a/fractal_server/images/models.py +++ b/fractal_server/images/models.py @@ -3,15 +3,16 @@ from typing import Union from pydantic import BaseModel -from pydantic import Extra from pydantic import Field from pydantic import validator -from fractal_server.app.schemas._validators import valdictkeys +from fractal_server.app.schemas._validators import valdict_keys from fractal_server.urls import normalize_url +AttributeFiltersType = dict[str, Optional[list[Any]]] -class SingleImageBase(BaseModel): + +class _SingleImageBase(BaseModel): """ Base for SingleImage and SingleImageTaskOutput. @@ -30,9 +31,9 @@ class SingleImageBase(BaseModel): # Validators _attributes = validator("attributes", allow_reuse=True)( - valdictkeys("attributes") + valdict_keys("attributes") ) - _types = validator("types", allow_reuse=True)(valdictkeys("types")) + _types = validator("types", allow_reuse=True)(valdict_keys("types")) @validator("zarr_url") def normalize_zarr_url(cls, v: str) -> str: @@ -44,7 +45,7 @@ def normalize_orig(cls, v: Optional[str]) -> Optional[str]: return normalize_url(v) -class SingleImageTaskOutput(SingleImageBase): +class SingleImageTaskOutput(_SingleImageBase): """ `SingleImageBase`, with scalar `attributes` values (`None` included). 
""" @@ -63,7 +64,7 @@ def validate_attributes( return v -class SingleImage(SingleImageBase): +class SingleImage(_SingleImageBase): """ `SingleImageBase`, with scalar `attributes` values (`None` excluded). """ @@ -83,8 +84,8 @@ def validate_attributes( class SingleImageUpdate(BaseModel): zarr_url: str - attributes: Optional[dict[str, Any]] - types: Optional[dict[str, bool]] + attributes: Optional[dict[str, Any]] = None + types: Optional[dict[str, bool]] = None @validator("zarr_url") def normalize_zarr_url(cls, v: str) -> str: @@ -96,7 +97,7 @@ def validate_attributes( ) -> dict[str, Union[int, float, str, bool]]: if v is not None: # validate keys - valdictkeys("attributes")(v) + valdict_keys("attributes")(v) # validate values for key, value in v.items(): if not isinstance(value, (int, float, str, bool)): @@ -107,28 +108,4 @@ def validate_attributes( ) return v - _types = validator("types", allow_reuse=True)(valdictkeys("types")) - - -class Filters(BaseModel, extra=Extra.forbid): - attributes: dict[str, Any] = Field(default_factory=dict) - types: dict[str, bool] = Field(default_factory=dict) - - # Validators - _attributes = validator("attributes", allow_reuse=True)( - valdictkeys("attributes") - ) - _types = validator("types", allow_reuse=True)(valdictkeys("types")) - - @validator("attributes") - def validate_attributes( - cls, v: dict[str, Any] - ) -> dict[str, Union[int, float, str, bool, None]]: - for key, value in v.items(): - if not isinstance(value, (int, float, str, bool, type(None))): - raise ValueError( - f"Filters.attributes[{key}] must be a scalar " - "(int, float, str, bool, or None). " - f"Given {value} ({type(value)})" - ) - return v + _types = validator("types", allow_reuse=True)(valdict_keys("types")) diff --git a/fractal_server/images/tools.py b/fractal_server/images/tools.py index 6452d08dd9..aaa0cc6bc8 100644 --- a/fractal_server/images/tools.py +++ b/fractal_server/images/tools.py @@ -4,8 +4,7 @@ from typing import Optional from typing import Union -from fractal_server.images import Filters - +from fractal_server.images.models import AttributeFiltersType ImageSearch = dict[Literal["image", "index"], Union[int, dict[str, Any]]] @@ -33,52 +32,69 @@ def find_image_by_zarr_url( return dict(image=copy(images[ind]), index=ind) -def match_filter(image: dict[str, Any], filters: Filters) -> bool: +def match_filter( + *, + image: dict[str, Any], + type_filters: dict[str, bool], + attribute_filters: AttributeFiltersType, +) -> bool: """ Find whether an image matches a filter set. Arguments: image: A single image. - filters: A set of filters. + type_filters: + attribute_filters: Returns: Whether the image matches the filter set. """ + # Verify match with types (using a False default) - for key, value in filters.types.items(): + for key, value in type_filters.items(): if image["types"].get(key, False) != value: return False - # Verify match with attributes (only for non-None filters) - for key, value in filters.attributes.items(): - if value is None: + + # Verify match with attributes (only for not-None filters) + for key, values in attribute_filters.items(): + if values is None: continue - if image["attributes"].get(key) != value: + if image["attributes"].get(key) not in values: return False + return True def filter_image_list( images: list[dict[str, Any]], - filters: Filters, + type_filters: Optional[dict[str, bool]] = None, + attribute_filters: Optional[AttributeFiltersType] = None, ) -> list[dict[str, Any]]: """ Compute a sublist with images that match a filter set. 
Arguments: images: A list of images. - filters: A set of filters. + type_filters: + attribute_filters: Returns: List of the `images` elements which match the filter set. """ # When no filter is provided, return all images - if filters.attributes == {} and filters.types == {}: + if type_filters is None and attribute_filters is None: return images + actual_type_filters = type_filters or {} + actual_attribute_filters = attribute_filters or {} filtered_images = [ copy(this_image) for this_image in images - if match_filter(this_image, filters=filters) + if match_filter( + image=this_image, + type_filters=actual_type_filters, + attribute_filters=actual_attribute_filters, + ) ] return filtered_images diff --git a/fractal_server/migrations/versions/db09233ad13a_split_filters_and_keep_old_columns.py b/fractal_server/migrations/versions/db09233ad13a_split_filters_and_keep_old_columns.py new file mode 100644 index 0000000000..5cc9848fe4 --- /dev/null +++ b/fractal_server/migrations/versions/db09233ad13a_split_filters_and_keep_old_columns.py @@ -0,0 +1,96 @@ +"""split filters and keep old columns + +Revision ID: db09233ad13a +Revises: 316140ff7ee1 +Create Date: 2025-01-14 14:50:46.007222 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "db09233ad13a" +down_revision = "316140ff7ee1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("datasetv2", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "type_filters", sa.JSON(), server_default="{}", nullable=False + ) + ) + batch_op.add_column( + sa.Column( + "attribute_filters", + sa.JSON(), + server_default="{}", + nullable=False, + ) + ) + batch_op.alter_column( + "filters", + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True, + server_default="null", + ) + + with op.batch_alter_table("jobv2", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "attribute_filters", + sa.JSON(), + server_default="{}", + nullable=False, + ) + ) + + with op.batch_alter_table("workflowtaskv2", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "type_filters", sa.JSON(), server_default="{}", nullable=False + ) + ) + batch_op.alter_column( + "input_filters", + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True, + server_default="null", + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("workflowtaskv2", schema=None) as batch_op: + batch_op.alter_column( + "input_filters", + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False, + existing_server_default=sa.text( + '\'{"attributes": {}, "types": {}}\'::json' + ), + ) + batch_op.drop_column("type_filters") + + with op.batch_alter_table("jobv2", schema=None) as batch_op: + batch_op.drop_column("attribute_filters") + + with op.batch_alter_table("datasetv2", schema=None) as batch_op: + batch_op.alter_column( + "filters", + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False, + existing_server_default=sa.text( + '\'{"attributes": {}, "types": {}}\'::json' + ), + ) + batch_op.drop_column("attribute_filters") + batch_op.drop_column("type_filters") + + # ### end Alembic commands ### diff --git a/tests/fixtures_server_v2.py b/tests/fixtures_server_v2.py index d297531e90..eaf2cbc6cf 100644 --- a/tests/fixtures_server_v2.py +++ b/tests/fixtures_server_v2.py @@ -147,7 +147,7 @@ async def __job_factory( dataset_id=dataset_id, workflow_id=workflow_id, dataset_dump=json.loads( - dataset.json(exclude={"history", "images"}) + dataset.json(exclude={"history", "images", "filters"}) ), workflow_dump=json.loads(workflow.json(exclude={"task_list"})), project_dump=json.loads(project.json(exclude={"user_list"})), diff --git a/tests/v2/01_schemas/test_schemas_dataset.py b/tests/v2/01_schemas/test_schemas_dataset.py index 883ca75844..ae0b647eb2 100644 --- a/tests/v2/01_schemas/test_schemas_dataset.py +++ b/tests/v2/01_schemas/test_schemas_dataset.py @@ -1,4 +1,5 @@ import pytest +from devtools import debug from pydantic import ValidationError from fractal_server.app.models.v2 import DatasetV2 @@ -10,29 +11,66 @@ from fractal_server.urls import normalize_url -async def test_schemas_dataset_v2(): +VALID_ATTRIBUTE_FILTERS = ( + {}, + {"key1": []}, + {"key1": ["A"]}, + {"key1": ["A", "B"]}, + {"key1": [1, 2]}, + {"key1": [True, False]}, + {"key1": [1.5, -1.2]}, + {"key1": None}, + {"key1": [1, 2], "key2": ["A", "B"]}, +) + +INVALID_ATTRIBUTE_FILTERS = ( + {True: ["value"]}, # non-string key + {1: ["value"]}, # non-string key + {"key1": 1}, # not a list + {"key1": True}, # not a list + {"key1": "something"}, # not a list + {"key1": [1], " key1": [1]}, # non-unique normalized keys + {"key1": [None]}, # None value + {"key1": [1, 1.0]}, # non-homogeneous types + {"key1": [1, "a"]}, # non-homogeneous types + {"key1": [[1, 2], [1, 2]]}, # non-scalar type + # {"key1": [1, True]}, # non-homogeneous types - FIXME unsupported +) + + +@pytest.mark.parametrize("attribute_filters", VALID_ATTRIBUTE_FILTERS) +def test_valid_attribute_filters(attribute_filters: dict): + DatasetCreateV2( + name="x", + zarr_dir="/x", + attribute_filters=attribute_filters, + ) - project = ProjectV2(id=1, name="project") - # Create - with pytest.raises(ValidationError): - # Non-scalar attribute +@pytest.mark.parametrize("attribute_filters", INVALID_ATTRIBUTE_FILTERS) +def test_invalid_attribute_filters(attribute_filters: dict): + debug(attribute_filters) + with pytest.raises(ValueError) as e: DatasetCreateV2( - name="name", - zarr_dir="/zarr", - filters={"attributes": {"x": [1, 0]}}, + name="x", + zarr_dir="/x", + attribute_filters=attribute_filters, ) + debug(e.value) + + +async def test_schemas_dataset_v2(): + + project = ProjectV2(id=1, name="project") + with pytest.raises(ValidationError): # Non-boolean types - DatasetCreateV2( - name="name", zarr_dir="/zarr", filters={"types": {"a": "b"}} - ) + 
DatasetCreateV2(name="name", zarr_dir="/zarr", type_filters={"a": "b"}) # Test zarr_dir=None is valid DatasetCreateV2(name="name", zarr_dir=None) dataset_create = DatasetCreateV2( name="name", - filters={"attributes": {"x": 10}}, zarr_dir="/tmp/", ) assert dataset_create.zarr_dir == normalize_url(dataset_create.zarr_dir) @@ -42,7 +80,7 @@ async def test_schemas_dataset_v2(): dataset_import = DatasetImportV2( name="name", - filters={"attributes": {"x": 10}}, + attribute_filters={"x": [10]}, zarr_dir="/tmp/", images=[{"zarr_url": "/tmp/image/"}], ) @@ -61,10 +99,14 @@ async def test_schemas_dataset_v2(): # Update - # validation accepts `zarr_dir` and `filters` as None, but not `name` - DatasetUpdateV2(zarr_dir=None, filters=None) + # validation accepts `zarr_dir` as None, but not `name` and `filters` + DatasetUpdateV2(zarr_dir=None) + with pytest.raises(ValidationError): + DatasetUpdateV2(name=None) + with pytest.raises(ValidationError): + DatasetUpdateV2(attribute_filters=None) with pytest.raises(ValidationError): - DatasetUpdateV2(name=None, zarr_dir=None, filters=None) + DatasetUpdateV2(type_filters=None) dataset_update = DatasetUpdateV2(name="new name", zarr_dir="/zarr/") assert not dataset_update.zarr_dir.endswith("/") diff --git a/tests/v2/01_schemas/test_unit_schemas_v2.py b/tests/v2/01_schemas/test_unit_schemas_v2.py index cc84d89372..66d7c5316e 100644 --- a/tests/v2/01_schemas/test_unit_schemas_v2.py +++ b/tests/v2/01_schemas/test_unit_schemas_v2.py @@ -11,7 +11,6 @@ from fractal_server.app.schemas.v2 import WorkflowCreateV2 from fractal_server.app.schemas.v2 import WorkflowTaskCreateV2 from fractal_server.app.schemas.v2 import WorkflowTaskDumpV2 -from fractal_server.images import Filters def test_extra_on_create_models(): @@ -103,7 +102,7 @@ def test_workflow_task_dump(): WorkflowTaskDumpV2( id=1, workflow_id=1, - input_filters=Filters(), + type_filters={}, task_id=1, task=TaskDumpV2( id=1, diff --git a/tests/v2/03_api/test_api_dataset.py b/tests/v2/03_api/test_api_dataset.py index 881723aac7..8ddcc68581 100644 --- a/tests/v2/03_api/test_api_dataset.py +++ b/tests/v2/03_api/test_api_dataset.py @@ -58,7 +58,7 @@ async def test_new_dataset_v2(client, MockCurrentUser): f"api/v2/project/{p2_id}/dataset/", json=dict( name="dataset", - filters={"attributes": {"x": 10}}, + attribute_filters={"x": [10]}, zarr_dir="/tmp", ), ) @@ -374,7 +374,11 @@ async def test_patch_dataset( ): async with MockCurrentUser() as user: project = await project_factory_v2(user) - dataset = await dataset_factory_v2(project_id=project.id) + dataset = await dataset_factory_v2( + project_id=project.id, + attribute_filters={"a": [1, 2], "b": [3]}, + type_filters={"c": True, "d": False}, + ) project_id = project.id dataset_id = dataset.id @@ -419,6 +423,62 @@ async def test_patch_dataset( ) assert res.status_code == 422 + # Patch `attribute_filters` + res = await client.get( + f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/" + ) + assert res.json()["attribute_filters"] == {"a": [1, 2], "b": [3]} + assert res.json()["type_filters"] == {"c": True, "d": False} + res = await client.patch( + f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/", + json=dict(attribute_filters={"c": 3}), + ) + assert res.status_code == 422 + res = await client.patch( + f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/", + json=dict(attribute_filters={"c": [3]}), + ) + assert res.status_code == 200 + assert res.json()["attribute_filters"] == {"c": [3]} + assert res.json()["type_filters"] == {"c": True, "d": False} + res = await 
client.patch( + f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/", + json=dict(type_filters={"x": 42}), + ) + assert res.status_code == 422 + res = await client.patch( + f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/", + json=dict(type_filters={"x": True}), + ) + assert res.status_code == 200 + assert res.json()["name"] == "something-new" + assert res.json()["zarr_dir"] == "/new_zarr_dir" + assert res.json()["attribute_filters"] == {"c": [3]} + assert res.json()["type_filters"] == {"x": True} + res = await client.patch( + f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/", + json=dict(attribute_filters={}, type_filters=None), + ) + assert res.status_code == 422 + res = await client.patch( + f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/", + json=dict(attribute_filters={}), + ) + assert res.status_code == 200 + assert res.json()["name"] == "something-new" + assert res.json()["zarr_dir"] == "/new_zarr_dir" + assert res.json()["attribute_filters"] == {} + assert res.json()["type_filters"] == {"x": True} + res = await client.patch( + f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/", + json=dict(type_filters={}), + ) + assert res.status_code == 200 + assert res.json()["name"] == "something-new" + assert res.json()["zarr_dir"] == "/new_zarr_dir" + assert res.json()["attribute_filters"] == {} + assert res.json()["type_filters"] == {} + async def test_dataset_export( app, client, MockCurrentUser, project_factory_v2, dataset_factory_v2 @@ -442,10 +502,9 @@ async def test_dataset_export( assert res_dataset["name"] == "My Dataset" assert res_dataset["zarr_dir"] == "/zarr_dir" assert res_dataset["images"] == IMAGES - assert res_dataset["filters"] == dict( - attributes={}, - types={}, - ) + assert res_dataset["attribute_filters"] == dict() + assert res_dataset["type_filters"] == dict() + assert "filters" not in res_dataset.keys() async def test_dataset_import( @@ -460,10 +519,8 @@ async def test_dataset_import( name="Dataset", zarr_dir="/somewhere/invalid/", images=IMAGES, - filters=dict( - attributes={}, - types={}, - ), + attribute_filters={}, + type_filters={}, ) res = await client.post( f"{PREFIX}/project/{project.id}/dataset/import/", json=dataset @@ -476,10 +533,8 @@ async def test_dataset_import( name="Dataset", zarr_dir=ZARR_DIR, images=IMAGES, - filters=dict( - attributes={}, - types={}, - ), + attribute_filters=dict(), + type_filters=dict(), ) res = await client.post( f"{PREFIX}/project/{project.id}/dataset/import/", json=dataset @@ -489,7 +544,5 @@ async def test_dataset_import( debug(res_dataset) assert res_dataset["name"] == "Dataset" assert res_dataset["zarr_dir"] == ZARR_DIR - assert res_dataset["filters"] == dict( - attributes={}, - types={}, - ) + assert res_dataset["attribute_filters"] == dict() + assert res_dataset["type_filters"] == dict() diff --git a/tests/v2/03_api/test_api_dataset_images.py b/tests/v2/03_api/test_api_dataset_images.py index 534c39ddff..c23a302abd 100644 --- a/tests/v2/03_api/test_api_dataset_images.py +++ b/tests/v2/03_api/test_api_dataset_images.py @@ -1,4 +1,3 @@ -from fractal_server.images import Filters from fractal_server.images import SingleImage from fractal_server.images.tools import find_image_by_zarr_url from fractal_server.images.tools import match_filter @@ -136,14 +135,16 @@ async def test_query_images( # use `query.attributes` res = await client.post( f"{PREFIX}/project/{project.id}/dataset/{dataset.id}/images/query/", - json=dict(filters=dict(types=dict(flag=False))), + json=dict(type_filters=dict(flag=False)), ) 
assert res.status_code == 200 assert res.json()["total_count"] == len( [ image for image in images - if match_filter(image, Filters(types={"flag": False})) + if match_filter( + image=image, type_filters={"flag": False}, attribute_filters={} + ) ] ) assert res.json()["current_page"] == 1 @@ -154,14 +155,16 @@ async def test_query_images( res = await client.post( f"{PREFIX}/project/{project.id}/dataset/{dataset.id}/images/query/" "?page_size=1000", - json=dict(filters=dict(types={"flag": True})), + json=dict(type_filters={"flag": True}), ) assert res.status_code == 200 assert res.json()["total_count"] == len( [ image for image in images - if match_filter(image, filters=Filters(types={"flag": 1})) + if match_filter( + image=image, type_filters={"flag": 1}, attribute_filters={} + ) ] ) assert res.json()["page_size"] == 1000 @@ -171,7 +174,7 @@ async def test_query_images( res = await client.post( f"{PREFIX}/project/{project.id}/dataset/{dataset.id}/images/query/" "?page_size=42", - json=dict(filters=dict(types={"foo": True})), + json=dict(type_filters={"foo": True}), ) assert res.status_code == 200 assert res.json()["total_count"] == 0 @@ -180,7 +183,7 @@ async def test_query_images( assert res.json()["images"] == [] res = await client.post( f"{PREFIX}/project/{project.id}/dataset/{dataset.id}/images/query/", - json=dict(filters=dict(types={"foo": False})), + json=dict(type_filters={"foo": False}), ) assert res.status_code == 200 assert res.json()["total_count"] == N diff --git a/tests/v2/03_api/test_api_job.py b/tests/v2/03_api/test_api_job.py index 946f2c3e10..6d41720c2b 100644 --- a/tests/v2/03_api/test_api_job.py +++ b/tests/v2/03_api/test_api_job.py @@ -353,7 +353,9 @@ async def test_project_apply_workflow_subset( **json.loads(workflow.json(exclude={"task_list"})) ).dict() expected_dataset_dump = DatasetDumpV2( - **json.loads(dataset1.json(exclude={"history", "images"})) + **json.loads( + dataset1.json(exclude={"history", "images", "filters"}) + ) ).dict() assert res.json()["project_dump"] == expected_project_dump assert res.json()["workflow_dump"] == expected_workflow_dump diff --git a/tests/v2/03_api/test_api_workflow_task.py b/tests/v2/03_api/test_api_workflow_task.py index 8148eba797..1f14291c51 100644 --- a/tests/v2/03_api/test_api_workflow_task.py +++ b/tests/v2/03_api/test_api_workflow_task.py @@ -8,7 +8,7 @@ from fractal_server.app.models import UserGroup from fractal_server.app.models.v2 import WorkflowTaskV2 from fractal_server.app.models.v2 import WorkflowV2 -from fractal_server.images.models import Filters + PREFIX = "api/v2" @@ -296,10 +296,7 @@ async def test_patch_workflow_task( args_parallel={"c": 333, "d": 444}, meta_non_parallel={"non": "parallel"}, meta_parallel={"executor": "cpu-low"}, - input_filters={ - "attributes": {"a": "b", "c": "d"}, - "types": {"e": True, "f": False, "g": True}, - }, + type_filters={"e": True, "f": False, "g": True}, ) res = await client.patch( f"{PREFIX}/project/{project.id}/workflow/{workflow['id']}/" @@ -323,9 +320,7 @@ async def test_patch_workflow_task( assert ( patched_workflow_task["meta_parallel"] == payload["meta_parallel"] ) - assert ( - patched_workflow_task["input_filters"] == payload["input_filters"] - ) + assert patched_workflow_task["type_filters"] == payload["type_filters"] assert res.status_code == 200 payload_up = dict( @@ -377,15 +372,13 @@ async def test_patch_workflow_task( f"wftask/{workflow['task_list'][0]['id']}/", json=dict( args_non_parallel={}, - input_filters=Filters().dict(), + type_filters=dict(), ), ) 
patched_workflow_task = res.json() debug(patched_workflow_task["args_non_parallel"]) assert patched_workflow_task["args_non_parallel"] is None - assert patched_workflow_task["input_filters"] == dict( - attributes={}, types={} - ) + assert patched_workflow_task["type_filters"] == dict() assert res.status_code == 200 # Test 422 diff --git a/tests/v2/03_api/test_unit_issue_1783.py b/tests/v2/03_api/test_unit_issue_1783.py index 0b544cfd98..c15eb35e07 100644 --- a/tests/v2/03_api/test_unit_issue_1783.py +++ b/tests/v2/03_api/test_unit_issue_1783.py @@ -14,7 +14,7 @@ def test_issue_1783(): id=1, workflow_id=1, order=0, - input_filters=dict(attributes=dict(), types=dict()), + type_filters=dict(), task_id=1, task=dict( id=1, @@ -36,7 +36,7 @@ def test_issue_1783(): id=1, workflow_id=1, order=0, - input_filters=dict(attributes=dict(), types=dict()), + type_filters=dict(), task_legacy_id=1, task_legacy=dict( id=1, diff --git a/tests/v2/04_runner/aux_get_dataset_attrs.py b/tests/v2/04_runner/aux_get_dataset_attrs.py index 6fed8b0185..93ee375505 100644 --- a/tests/v2/04_runner/aux_get_dataset_attrs.py +++ b/tests/v2/04_runner/aux_get_dataset_attrs.py @@ -7,6 +7,6 @@ async def _get_dataset_attrs(db, dataset_id) -> dict[str, Any]: await db.close() db_dataset = await db.get(DatasetV2, dataset_id) dataset_attrs = db_dataset.model_dump( - include={"filters", "history", "images"} + include={"type_filters", "attribute_filters", "history", "images"} ) return dataset_attrs diff --git a/tests/v2/04_runner/test_dummy_examples.py b/tests/v2/04_runner/test_dummy_examples.py index 4c535e5ae0..0731a0901c 100644 --- a/tests/v2/04_runner/test_dummy_examples.py +++ b/tests/v2/04_runner/test_dummy_examples.py @@ -29,6 +29,7 @@ def execute_tasks_v2(wf_task_list, workflow_dir_local, **kwargs) -> None: raw_execute_tasks_v2( wf_task_list=wf_task_list, workflow_dir_local=workflow_dir_local, + job_attribute_filters={}, **kwargs, ) diff --git a/tests/v2/04_runner/test_fractal_examples.py b/tests/v2/04_runner/test_fractal_examples.py index 1882a87902..0312616b21 100644 --- a/tests/v2/04_runner/test_fractal_examples.py +++ b/tests/v2/04_runner/test_fractal_examples.py @@ -32,6 +32,7 @@ def execute_tasks_v2(wf_task_list, workflow_dir_local, **kwargs): raw_execute_tasks_v2( wf_task_list=wf_task_list, workflow_dir_local=workflow_dir_local, + job_attribute_filters={}, **kwargs, ) @@ -106,8 +107,8 @@ async def test_fractal_demos_01( assert _task_names_from_history(dataset_attrs["history"]) == [ "create_ome_zarr_compound" ] - assert dataset_attrs["filters"]["attributes"] == {} - assert dataset_attrs["filters"]["types"] == {} + assert dataset_attrs["attribute_filters"] == {} + assert dataset_attrs["type_filters"] == {} _assert_image_data_exist(dataset_attrs["images"]) assert len(dataset_attrs["images"]) == 2 @@ -134,8 +135,8 @@ async def test_fractal_demos_01( "create_ome_zarr_compound", "illumination_correction", ] - assert dataset_attrs["filters"]["attributes"] == {} - assert dataset_attrs["filters"]["types"] == { + assert dataset_attrs["attribute_filters"] == {} + assert dataset_attrs["type_filters"] == { "illumination_correction": True, } assert set(img["zarr_url"] for img in dataset_attrs["images"]) == { @@ -187,8 +188,8 @@ async def test_fractal_demos_01( "MIP_compound", ] - assert dataset_attrs["filters"]["attributes"] == {} - assert dataset_attrs["filters"]["types"] == { + assert dataset_attrs["attribute_filters"] == {} + assert dataset_attrs["type_filters"] == { "illumination_correction": True, "3D": False, } @@ -310,8 +311,8 
@@ async def test_fractal_demos_01_no_overwrite( "create_ome_zarr_compound", "illumination_correction", ] - assert dataset_attrs["filters"]["attributes"] == {} - assert dataset_attrs["filters"]["types"] == { + assert dataset_attrs["attribute_filters"] == {} + assert dataset_attrs["type_filters"] == { "illumination_correction": True, } assert [img["zarr_url"] for img in dataset_attrs["images"]] == [ @@ -390,8 +391,8 @@ async def test_fractal_demos_01_no_overwrite( "illumination_correction", "MIP_compound", ] - assert dataset_attrs["filters"]["attributes"] == {} - assert dataset_attrs["filters"]["types"] == { + assert dataset_attrs["attribute_filters"] == {} + assert dataset_attrs["type_filters"] == { "3D": False, "illumination_correction": True, } @@ -429,8 +430,8 @@ async def test_fractal_demos_01_no_overwrite( }, } - assert dataset_attrs["filters"]["attributes"] == {} - assert dataset_attrs["filters"]["types"] == { + assert dataset_attrs["attribute_filters"] == {} + assert dataset_attrs["type_filters"] == { "3D": False, "illumination_correction": True, } diff --git a/tests/v2/04_runner/test_no_images_parallelization.py b/tests/v2/04_runner/test_no_images_parallelization.py index 3ae0653d98..cbe6ce6043 100644 --- a/tests/v2/04_runner/test_no_images_parallelization.py +++ b/tests/v2/04_runner/test_no_images_parallelization.py @@ -23,6 +23,7 @@ def execute_tasks_v2(wf_task_list, workflow_dir_local, **kwargs): raw_execute_tasks_v2( wf_task_list=wf_task_list, workflow_dir_local=workflow_dir_local, + job_attribute_filters={}, **kwargs, ) diff --git a/tests/v2/04_runner/test_unit_aux_functions_v2.py b/tests/v2/04_runner/test_unit_aux_functions_v2.py index f3407fe560..4cf676b006 100644 --- a/tests/v2/04_runner/test_unit_aux_functions_v2.py +++ b/tests/v2/04_runner/test_unit_aux_functions_v2.py @@ -4,6 +4,7 @@ from fractal_server.app.runner.exceptions import JobExecutionError from fractal_server.app.runner.v2.deduplicate_list import deduplicate_list +from fractal_server.app.runner.v2.merge_outputs import merge_outputs from fractal_server.app.runner.v2.runner_functions import ( _cast_and_validate_InitTaskOutput, ) @@ -12,6 +13,7 @@ ) from fractal_server.app.runner.v2.task_interface import InitArgsModel from fractal_server.app.runner.v2.task_interface import TaskOutput +from fractal_server.images import SingleImageTaskOutput def test_deduplicate_list_of_dicts(): @@ -84,3 +86,85 @@ def test_cast_and_validate_functions(): ) with pytest.raises(JobExecutionError): _cast_and_validate_InitTaskOutput(dict(invalid=True)) + + +def test_merge_outputs(): + + # 1 + task_outputs = [ + TaskOutput(type_filters={"a": True}), + TaskOutput(type_filters={"a": True}), + ] + merged = merge_outputs(task_outputs) + assert merged.type_filters == {"a": True} + + # 2 + task_outputs = [ + TaskOutput(type_filters={"a": True}), + TaskOutput(type_filters={"b": True}), + ] + with pytest.raises(ValueError): + merge_outputs(task_outputs) + + # 3 + task_outputs = [ + TaskOutput(type_filters={"a": True}), + TaskOutput(type_filters={"a": False}), + ] + with pytest.raises(ValueError): + merge_outputs(task_outputs) + + # 4 + merged = merge_outputs([]) + assert merged == TaskOutput() + + # 5 + task_outputs = [ + TaskOutput( + image_list_updates=[ + SingleImageTaskOutput(zarr_url="/a"), + SingleImageTaskOutput(zarr_url="/b"), + ], + image_list_removals=["/x", "/y", "/z"], + ), + TaskOutput( + image_list_updates=[ + SingleImageTaskOutput(zarr_url="/c"), + SingleImageTaskOutput(zarr_url="/a"), + ], + image_list_removals=["/x", "/w", "/z"], 
+ ), + ] + merged = merge_outputs(task_outputs) + assert merged.image_list_updates == [ + SingleImageTaskOutput(zarr_url="/a"), + SingleImageTaskOutput(zarr_url="/b"), + SingleImageTaskOutput(zarr_url="/c"), + ] + assert set(merged.image_list_removals) == set(["/x", "/y", "/z", "/w"]) + + +def test_update_legacy_filters(): + + legacy_filters = {"types": {"a": True}} + + # 1 + output = TaskOutput(filters=legacy_filters) + assert output.filters is None + assert output.type_filters == legacy_filters["types"] + + # 2 + output = TaskOutput(type_filters=legacy_filters["types"]) + assert output.filters is None + assert output.type_filters == legacy_filters["types"] + + # 3 + with pytest.raises(ValidationError): + TaskOutput( + filters=legacy_filters, type_filters=legacy_filters["types"] + ) + + # 4 + output = TaskOutput() + assert output.filters is None + assert output.type_filters == {} diff --git a/tests/v2/04_runner/v2_mock_models.py b/tests/v2/04_runner/v2_mock_models.py index 4321f2f9ec..a842f33186 100644 --- a/tests/v2/04_runner/v2_mock_models.py +++ b/tests/v2/04_runner/v2_mock_models.py @@ -55,13 +55,8 @@ class WorkflowTaskV2Mock(BaseModel): meta_parallel: Optional[dict[str, Any]] = Field() meta_non_parallel: Optional[dict[str, Any]] = Field() task: TaskV2Mock - input_filters: dict[str, Any] = Field(default_factory=dict) + type_filters: dict[str, bool] = Field(default_factory=dict) order: int id: int workflow_id: int = 0 task_id: int - - @validator("input_filters", always=True) - def _default_filters(cls, value): - if value == {}: - return {"types": {}, "attributes": {}} diff --git a/tests/v2/05_images/test_benchmark_helper_functions.py b/tests/v2/05_images/test_benchmark_helper_functions.py index 9d49b6269d..60aba1b152 100644 --- a/tests/v2/05_images/test_benchmark_helper_functions.py +++ b/tests/v2/05_images/test_benchmark_helper_functions.py @@ -6,7 +6,6 @@ from fractal_server.app.runner.v2.deduplicate_list import deduplicate_list from fractal_server.app.runner.v2.task_interface import InitArgsModel -from fractal_server.images import Filters from fractal_server.images import SingleImage from fractal_server.images.tools import filter_image_list @@ -46,10 +45,8 @@ def test_filter_image_list_with_filters( ): new_list = filter_image_list( images=images, - filters=Filters( - attributes=dict(a1=0, a2="a2", a3=None), - types=dict(t1=True, t2=False), - ), + attribute_filters=dict(a1=[0], a2=["a2"], a3=None), + type_filters=dict(t1=True, t2=False), ) debug(len(images), len(new_list)) assert len(new_list) == len(images) // 4 @@ -65,7 +62,7 @@ def test_filter_image_list_few_filters( ): new_list = filter_image_list( images=images, - filters=Filters(attributes=dict(a1=0)), + attribute_filters=dict(a1=[0]), ) debug(len(images), len(new_list)) assert len(new_list) == len(images) // 2 diff --git a/tests/v2/05_images/test_filters.py b/tests/v2/05_images/test_filters.py index 8ac82a6c53..c650f99a2b 100644 --- a/tests/v2/05_images/test_filters.py +++ b/tests/v2/05_images/test_filters.py @@ -2,7 +2,6 @@ from devtools import debug from pydantic import ValidationError -from fractal_server.images import Filters from fractal_server.images import SingleImage from fractal_server.images.tools import filter_image_list @@ -96,38 +95,18 @@ def test_singleimage_attributes_validation(): ) -def test_filters_attributes_validation(): - invalid = [ - ["l", "i", "s", "t"], - {"d": "i", "c": "t"}, - {"s", "e", "t"}, - ("t", "u", "p", "l", "e"), - bool, # type - lambda x: x, # function - ] - - for item in invalid: - with 
pytest.raises(ValidationError) as e: - Filters(attributes={"key": item}) - debug(e) - - valid = ["string", -7, 3.14, True, None] - for item in valid: - assert Filters(attributes={"key": item}).attributes["key"] == item - - @pytest.mark.parametrize( "attribute_filters,type_filters,expected_number", [ # No filter ({}, {}, 6), # Key is not part of attribute keys - ({"missing_key": "whatever"}, {}, 0), + ({"missing_key": ["whatever"]}, {}, 0), # Key is not part of type keys (default is False) ({}, {"missing_key": True}, 0), ({}, {"missing_key": False}, 6), # Key is part of attribute keys, but value is missing - ({"plate": "missing_plate.zarr"}, {}, 0), + ({"plate": ["missing_plate.zarr"]}, {}, 0), # Meaning of None for attributes: skip a given filter ({"plate": None}, {}, 6), # Single type filter @@ -138,24 +117,24 @@ def test_filters_attributes_validation(): ({}, {"3D": True, "illumination_correction": True}, 2), # Both attribute and type filters ( - {"plate": "plate.zarr"}, + {"plate": ["plate.zarr"]}, {"3D": True, "illumination_correction": True}, 2, ), # Both attribute and type filters ( - {"plate": "plate_2d.zarr"}, + {"plate": ["plate_2d.zarr"]}, {"3D": True, "illumination_correction": True}, 0, ), # Both attribute and type filters ( - {"plate": "plate.zarr", "well": "A01"}, + {"plate": ["plate.zarr"], "well": ["A01"]}, {"3D": True, "illumination_correction": True}, 1, ), # Single attribute filter - ({"well": "A01"}, {}, 3), + ({"well": ["A01"]}, {}, 3), ], ) def test_filter_image_list_SingleImage( @@ -165,7 +144,8 @@ def test_filter_image_list_SingleImage( ): filtered_list = filter_image_list( images=IMAGES, - filters=Filters(attributes=attribute_filters, types=type_filters), + attribute_filters=attribute_filters, + type_filters=type_filters, ) debug(attribute_filters) diff --git a/tests/v2/05_images/test_image_models.py b/tests/v2/05_images/test_image_models.py index 89960e99ac..e52d338754 100644 --- a/tests/v2/05_images/test_image_models.py +++ b/tests/v2/05_images/test_image_models.py @@ -1,133 +1,245 @@ -import pytest -from pydantic import ValidationError - -from fractal_server.images import Filters -from fractal_server.images import SingleImage -from fractal_server.images import SingleImageTaskOutput -from fractal_server.images import SingleImageUpdate -from fractal_server.images.models import SingleImageBase +from typing import TypeVar +from pydantic import ValidationError -def test_single_image(): +from fractal_server.images.models import _SingleImageBase +from fractal_server.images.models import SingleImage +from fractal_server.images.models import SingleImageTaskOutput +from fractal_server.images.models import SingleImageUpdate - with pytest.raises(ValidationError): - SingleImage() +T = TypeVar("T") - assert SingleImage(zarr_url="/somewhere").zarr_url == "/somewhere" - assert SingleImage(zarr_url="/somewhere", origin="/foo").origin == "/foo" - assert SingleImage(zarr_url="/somewhere", origin=None).origin is None +def image_ok(model: T, **kwargs) -> T: + return model(**kwargs) - valid_attributes = dict(a="string", b=3, c=0.33, d=True) - assert ( - SingleImage( - zarr_url="/somewhere", attributes=valid_attributes - ).attributes - == valid_attributes - ) - invalid_attributes = [ - dict(a=None), - dict(a=["l", "i", "s", "t"]), - dict(a={"d": "i", "c": "t"}), - ] - for attr in invalid_attributes: - with pytest.raises(ValidationError): - SingleImage(zarr_url="/somewhere", attributes=attr) - valid_types = dict(a=True, b=False) - assert ( - SingleImage(zarr_url="/somewhere", 
types=valid_types).types - == valid_types - ) +def image_fail(model: T, **kwargs) -> str: + try: + model(**kwargs) + raise AssertionError(f"{model=}, {kwargs=}") + except ValidationError as e: + return str(e) - invalid_types = dict(a="not a bool") - with pytest.raises(ValidationError): - SingleImage(zarr_url="/somewhere", types=invalid_types) +def test_SingleImageBase(): -def test_url_normalization(): + image_fail(model=_SingleImageBase) # zarr_url - assert SingleImage(zarr_url="/valid/url").zarr_url == "/valid/url" - assert SingleImage(zarr_url="/remove/slash/").zarr_url == "/remove/slash" - - with pytest.raises(ValidationError) as e: - SingleImage(zarr_url="s3/foo") - assert "S3 handling" in e._excinfo[1].errors()[0]["msg"] - - with pytest.raises(ValidationError) as e: - SingleImage(zarr_url="https://foo.bar") - assert "URLs must begin" in e._excinfo[1].errors()[0]["msg"] + image = image_ok(model=_SingleImageBase, zarr_url="/x") + assert image.dict() == { + "zarr_url": "/x", + "origin": None, + "attributes": {}, + "types": {}, + } + image_fail( + model=_SingleImageBase, zarr_url="x" + ) # see 'test_url_normalization' + image_fail(model=_SingleImageBase, zarr_url=None) # origin - assert SingleImage(zarr_url="/valid/url", origin=None).origin is None - assert ( - SingleImage(zarr_url="/valid/url", origin="/valid/origin").origin - == "/valid/origin" + image = image_ok(model=_SingleImageBase, zarr_url="/x", origin="/y") + assert image.origin == "/y" + image = image_ok(model=_SingleImageBase, zarr_url="/x", origin=None) + assert image.origin is None + image_fail(model=_SingleImageBase, zarr_url="/x", origin="y") + image_fail(model=_SingleImageBase, origin="/y") + + # attributes + valid_attributes = { + "int": 1, + "float": 1.2, + "string": "abc", + "bool": True, + "null": None, + "list": ["l", "i", "s", "t"], + "dict": {"d": "i", "c": "t"}, + "function": lambda x: x, + "type": int, + } # Any + image = image_ok( + model=_SingleImageBase, zarr_url="/x", attributes=valid_attributes ) - assert ( - SingleImage(zarr_url="/valid/url", origin="/remove/slash//").origin - == "/remove/slash" + assert image.attributes == valid_attributes + invalid_attributes = { + "repeated": 1, + " repeated ": 2, + } + image_fail( + model=_SingleImageBase, zarr_url="/x", attributes=invalid_attributes ) - with pytest.raises(ValidationError) as e: - SingleImage(zarr_url="/valid/url", origin="s3/foo") - assert "S3 handling" in e._excinfo[1].errors()[0]["msg"] - with pytest.raises(ValidationError) as e: - SingleImage(zarr_url="/valid/url", origin="http://foo.bar") - assert "URLs must begin" in e._excinfo[1].errors()[0]["msg"] - -def test_single_image_task_output(): - base = SingleImageBase(zarr_url="/zarr/url", attributes={"x": None}) - - # SingleImageTaskOutput accepts 'None' as value - SingleImageTaskOutput(**base.dict()) - # SingleImage does not accept 'None' as value - with pytest.raises(ValidationError): - SingleImage(**base.dict()) + # types + valid_types = {"y": True, "n": False} # only booleans + image = image_ok(model=_SingleImageBase, zarr_url="/x", types=valid_types) + assert image.types == valid_types + image_fail( + model=_SingleImageBase, zarr_url="/x", types={"a": "not a bool"} + ) + image_fail( + model=_SingleImageBase, zarr_url="/x", types={"a": True, " a": True} + ) + image_ok(model=_SingleImageBase, zarr_url="/x", types={1: True}) -def test_filters(): +def test_url_normalization(): - Filters() + image = image_ok(model=_SingleImageBase, zarr_url="/valid/url") + assert image.zarr_url == "/valid/url" + image 
= image_ok(model=_SingleImageBase, zarr_url="/remove/slash/") + assert image.zarr_url == "/remove/slash" + + e = image_fail(model=_SingleImageBase, zarr_url="s3/foo") + assert "S3 handling" in e + e = image_fail(model=_SingleImageBase, zarr_url="https://foo.bar") + assert "URLs must begin" in e + + image_ok(model=_SingleImageBase, zarr_url="/x", origin=None) + image_ok(model=_SingleImageBase, zarr_url="/x", origin="/y") + image = image_ok(model=_SingleImageBase, zarr_url="/x", origin="/y///") + assert image.origin == "/y" + + e = image_fail(model=_SingleImageBase, zarr_url="/x", origin="s3/foo") + assert "S3 handling" in e + e = image_fail( + model=_SingleImageBase, zarr_url="/x", origin="https://foo.bar" + ) + assert "URLs must begin" in e - valid_attributes = dict(a="string", b=3, c=0.33, d=True, e=None) - assert Filters(attributes=valid_attributes).attributes == valid_attributes - invalid_attributes = [ - dict(a=["l", "i", "s", "t"]), - dict(a={"d": "i", "c": "t"}), - ] - for attr in invalid_attributes: - with pytest.raises(ValidationError): - Filters(attributes=attr) +def test_SingleImageTaskOutput(): - valid_types = dict(a=True, b=False) - assert Filters(types=valid_types).types == valid_types + image_ok( + model=SingleImageTaskOutput, + zarr_url="/x", + attributes={ + "int": 1, + "float": 1.2, + "string": "abc", + "bool": True, + "null": None, + }, + ) + image_fail( + model=SingleImageTaskOutput, + zarr_url="/x", + attributes={"list": ["l", "i", "s", "t"]}, + ) + image_fail( + model=SingleImageTaskOutput, + zarr_url="/x", + attributes={"dict": {"d": "i", "c": "t"}}, + ) + image_fail( + model=SingleImageTaskOutput, + zarr_url="/x", + attributes={"function": lambda x: x}, + ) + image_fail( + model=SingleImageTaskOutput, + zarr_url="/x", + attributes={"type": int}, + ) + image_fail( + model=SingleImageTaskOutput, + zarr_url="/x", + attributes={"repeated": 1, " repeated": 2}, + ) - invalid_types = dict(a="not a bool") - with pytest.raises(ValidationError): - Filters(types=invalid_types) +def test_SingleImage(): -def test_single_image_update(): + image_ok( + model=SingleImage, + zarr_url="/x", + attributes={ + "int": 1, + "float": 1.2, + "string": "abc", + "bool": True, + }, + ) + image_fail( + model=SingleImage, + zarr_url="/x", + attributes={"null": None}, + ) + image_fail( + model=SingleImage, + zarr_url="/x", + attributes={"list": ["l", "i", "s", "t"]}, + ) + image_fail( + model=SingleImage, + zarr_url="/x", + attributes={"dict": {"d": "i", "c": "t"}}, + ) + image_fail( + model=SingleImage, + zarr_url="/x", + attributes={"function": lambda x: x}, + ) + image_fail( + model=SingleImage, + zarr_url="/x", + attributes={"type": int}, + ) + image_fail( + model=SingleImage, + zarr_url="/x", + attributes={"repeated": 1, " repeated": 2}, + ) - with pytest.raises(ValidationError): - SingleImageUpdate() - SingleImageUpdate(zarr_url="/something") - # override SingleImageBase validation - args = dict(zarr_url="/something", attributes=None) - with pytest.raises(ValidationError): - SingleImageBase(**args) - SingleImageUpdate(**args) +def test_SingleImageUpdate(): - args = dict(zarr_url="/something", types=None) - with pytest.raises(ValidationError): - SingleImageBase(**args) - SingleImageUpdate(**args) + image_fail(model=SingleImageUpdate) - with pytest.raises(ValidationError): - SingleImageUpdate( - zarr_url="/something", attributes={"invalid": ["l", "i", "s", "t"]} + # zarr_url + image = image_ok(model=SingleImageUpdate, zarr_url="/x") + assert image.dict() == { + "zarr_url": "/x", + "attributes": 
None, + "types": None, + } + image_fail(model=SingleImageUpdate, zarr_url="x") + image_fail(model=SingleImageUpdate, zarr_url=None) + + # attributes + valid_attributes = { + "int": 1, + "float": 1.2, + "string": "abc", + "bool": True, + } + image = image_ok( + model=SingleImageUpdate, zarr_url="/x", attributes=valid_attributes + ) + assert image.attributes == valid_attributes + for invalid_attributes in [ + {"null": None}, + {"list": ["l", "i", "s", "t"]}, + {"dict": {"d": "i", "c": "t"}}, + {"function": lambda x: x}, + {"type": int}, + {"repeated": 1, " repeated ": 2}, + ]: + image_fail( + model=SingleImageUpdate, + zarr_url="/x", + attributes=invalid_attributes, ) + + # types + valid_types = {"y": True, "n": False} # only booleans + image = image_ok(model=SingleImageUpdate, zarr_url="/x", types=valid_types) + assert image.types == valid_types + image_fail( + model=SingleImageUpdate, zarr_url="/x", types={"a": "not a bool"} + ) + image_fail( + model=SingleImageUpdate, zarr_url="/x", types={"a": True, " a": True} + ) + image_ok(model=SingleImageUpdate, zarr_url="/x", types={1: True}) diff --git a/tests/v2/05_images/test_unit_image_tools.py b/tests/v2/05_images/test_unit_image_tools.py index 87520782dc..86a0131530 100644 --- a/tests/v2/05_images/test_unit_image_tools.py +++ b/tests/v2/05_images/test_unit_image_tools.py @@ -1,232 +1,158 @@ -from fractal_server.images import Filters -from fractal_server.images import SingleImage from fractal_server.images.tools import filter_image_list from fractal_server.images.tools import find_image_by_zarr_url from fractal_server.images.tools import match_filter -N = 100 -images = [ - SingleImage( - zarr_url=f"/a/b/c{i}.zarr", - attributes=dict( - name=("a" if i % 2 == 0 else "b"), - num=i % 3, - ), - types=dict( - a=(i <= N // 2), - b=(i >= N // 3), - ), - ).dict() - for i in range(N) -] - def test_find_image_by_zarr_url(): + images = [{"zarr_url": "/x"}, {"zarr_url": "/y"}, {"zarr_url": "/z"}] + res = find_image_by_zarr_url(zarr_url="/x", images=images) + assert res == { + "index": 0, + "image": {"zarr_url": "/x"}, + } + res = find_image_by_zarr_url(zarr_url="/y", images=images) + assert res == { + "index": 1, + "image": {"zarr_url": "/y"}, + } + res = find_image_by_zarr_url(zarr_url="/z", images=images) + assert res == { + "index": 2, + "image": {"zarr_url": "/z"}, + } + res = find_image_by_zarr_url(zarr_url="/k", images=images) + assert res is None - for i in range(N): - image_search = find_image_by_zarr_url( - zarr_url=f"/a/b/c{i}.zarr", images=images - ) - assert image_search["image"]["zarr_url"] == f"/a/b/c{i}.zarr" - assert image_search["index"] == i - image_search = find_image_by_zarr_url(zarr_url="/xxx", images=images) - assert image_search is None +def test_match_filter(): + # empty filters (always match) + assert match_filter(image=..., type_filters={}, attribute_filters={}) -def test_match_filter(): + image = {"types": {"a": True, "b": False}, "attributes": {"a": 1, "b": 2}} - image = SingleImage( - zarr_url="/a/b/c0.zarr", - attributes=dict( - name="a", - num=0, - ), - types=dict( - a=True, - b=False, - ), - ).dict() - - # Empty - assert match_filter(image, Filters()) is True - - # Attributes - - f = Filters(attributes=dict(foo="bar")) # not existing attribute - assert match_filter(image, f) is False - - f = Filters(attributes=dict(name="a")) - assert match_filter(image, f) is True - - f = Filters(attributes=dict(num=0)) - assert match_filter(image, f) is True - - f = Filters( - attributes=dict( - name="a", - num=0, - ) - ) - assert 
match_filter(image, f) is True - - f = Filters( - attributes=dict( - name="a", - num=0, - foo="bar", # not existing attribute - ) - ) - assert match_filter(image, f) is False - - f = Filters( - attributes=dict( - name="a", - num="0", # int as string - ) - ) - assert match_filter(image, f) is False - - # Types - - f = Filters(types=dict(a=True)) - assert match_filter(image, f) is True - f = Filters(types=dict(b=False)) - assert match_filter(image, f) is True - f = Filters( - types=dict( - a=True, - b=False, - ) - ) - assert match_filter(image, f) is True - f = Filters( - types=dict( - a=False, - ) - ) - assert match_filter(image, f) is False - f = Filters( - types=dict( - a=True, - b=True, - ) - ) - assert match_filter(image, f) is False - f = Filters( - types=dict( - c=True, # not existing 'True' types are checked - ) - ) - assert match_filter(image, f) is False - f = Filters( - types=dict( - c=False, # not existing 'False' types are ignored - ) - ) - assert match_filter(image, f) is True - f = Filters( - types=dict( - a=True, - b=False, - c=True, - ) - ) - assert match_filter(image, f) is False - f = Filters( - types=dict( - a=True, - b=False, - c=False, - ) - ) - assert match_filter(image, f) is True - - # Both - - f = Filters( - attributes=dict( - name="a", - ), - types=dict(a=True), - ) - assert match_filter(image, f) is True - f = Filters( - attributes=dict( - name="a", - ), - types=dict(a=False), - ) - assert match_filter(image, f) is False - f = Filters( - attributes=dict( - name="b", - ), - types=dict(a=True), - ) - assert match_filter(image, f) is False - f = Filters( - attributes=dict( - name="a", - ), - types=dict( - x=False, - y=False, - z=False, - ), - ) - assert match_filter(image, f) is True - f = Filters( - attributes=dict( - name="a", - ), - types=dict( - x=False, - y=False, - z=True, - ), - ) - assert match_filter(image, f) is False + # type filters + # a + assert match_filter( + image=image, type_filters={"a": True}, attribute_filters={} + ) + assert not match_filter( + image=image, type_filters={"a": False}, attribute_filters={} + ) + # b + assert not match_filter( + image=image, type_filters={"b": True}, attribute_filters={} + ) + assert match_filter( + image=image, type_filters={"b": False}, attribute_filters={} + ) + # c + assert not match_filter( + image=image, type_filters={"c": True}, attribute_filters={} + ) + assert match_filter( + image=image, type_filters={"c": False}, attribute_filters={} + ) + # a b c + assert match_filter( + image=image, + type_filters={"a": True, "b": False, "c": False}, + attribute_filters={}, + ) + assert not match_filter( + image=image, + type_filters={"a": False, "b": False, "c": False}, + attribute_filters={}, + ) + assert not match_filter( + image=image, + type_filters={"a": True, "b": True, "c": False}, + attribute_filters={}, + ) + assert not match_filter( + image=image, + type_filters={"a": False, "b": True, "c": False}, + attribute_filters={}, + ) + + # attribute filters + assert match_filter( + image=image, type_filters={}, attribute_filters={"a": [1]} + ) + assert match_filter( + image=image, type_filters={}, attribute_filters={"a": [1], "b": [1, 2]} + ) + assert not match_filter( + image=image, type_filters={}, attribute_filters={"a": [0], "b": [1, 2]} + ) + assert match_filter( + image=image, + type_filters={}, + attribute_filters={"a": None, "b": [1, 2]}, + ) + + # both + assert match_filter( + image=image, type_filters={"a": True}, attribute_filters={"a": [1]} + ) + assert not match_filter( + image=image, type_filters={"a": 
False}, attribute_filters={"a": [1]} + ) + assert not match_filter( + image=image, type_filters={"a": True}, attribute_filters={"a": [0]} + ) def test_filter_image_list(): - # Empty - res = filter_image_list(images, Filters()) + + images = [ + {"types": {"a": True}, "attributes": {"a": 1, "b": 2}}, + {"types": {"a": True}, "attributes": {"a": 2, "b": 2}}, + {"types": {"a": False}, "attributes": {"a": 1, "b": 1}}, + {"types": {}, "attributes": {"a": 1, "b": 1}}, + {"types": {}, "attributes": {}}, + ] + + # empty + res = filter_image_list(images) assert res == images - # Attributes - f = Filters(attributes=dict(name="a")) - res = filter_image_list(images, f) - k = (N // 2) if not N % 2 else (N + 1) // 2 - assert len(res) == k - f = Filters(attributes=dict(name="b")) - res = filter_image_list(images, f) - assert len(res) == N - k - f = Filters(attributes=dict(num=0)) - res = filter_image_list(images, f) - assert len(res) == len([i for i in range(N) if i % 3 == 0]) - f = Filters(attributes=dict(num=1)) - res = filter_image_list(images, f) - assert len(res) == len([i for i in range(N) if i % 3 == 1]) - f = Filters(attributes=dict(num=2)) - res = filter_image_list(images, f) - assert len(res) == len([i for i in range(N) if i % 3 == 2]) - f = Filters(attributes=dict(name="foo")) - res = filter_image_list(images, f) - assert len(res) == 0 - f = Filters(attributes=dict(num=3)) - res = filter_image_list(images, f) - assert len(res) == 0 - f = Filters(attributes=dict(name="a", num=3)) - res = filter_image_list(images, f) + res = filter_image_list(images, type_filters={}) + assert res == images + res = filter_image_list(images, attribute_filters={}) + assert res == images + res = filter_image_list(images, type_filters={}, attribute_filters={}) + assert res == images + + # type filters + res = filter_image_list(images, type_filters={"a": True}) + assert len(res) == 2 + res = filter_image_list(images, type_filters={"a": False}) + assert len(res) == 3 # complementary of 2 + res = filter_image_list(images, type_filters={"b": True}) assert len(res) == 0 - f = Filters(attributes=dict(name="foo", num=0)) - res = filter_image_list(images, f) + res = filter_image_list(images, type_filters={"b": False}) + assert len(res) == 5 + + # attribute filters + res = filter_image_list(images, attribute_filters={"a": [1]}) + assert len(res) == 3 + res = filter_image_list(images, attribute_filters={"a": [1, 2]}) + assert len(res) == 4 + res = filter_image_list(images, attribute_filters={"a": None, "b": None}) + assert len(res) == 5 + + # both + res = filter_image_list( + images, type_filters={"a": True}, attribute_filters={"a": [1]} + ) + assert len(res) == 1 + res = filter_image_list( + images, type_filters={"a": True}, attribute_filters={"a": [1, 2]} + ) + assert len(res) == 2 + res = filter_image_list( + images, + type_filters={"a": True}, + attribute_filters={"a": [1, 2], "b": [-1]}, + ) assert len(res) == 0 - f = Filters( - types=dict( - a=True, - b=True, - ) - ) - res = filter_image_list(images, f) - assert len(res) == N // 2 - N // 3 + 1 diff --git a/tests/v2/08_full_workflow/common_functions.py b/tests/v2/08_full_workflow/common_functions.py index 5cfc2932cb..aa159e73b7 100644 --- a/tests/v2/08_full_workflow/common_functions.py +++ b/tests/v2/08_full_workflow/common_functions.py @@ -128,7 +128,7 @@ async def full_workflow( assert res.status_code == 200 dataset = res.json() assert len(dataset["history"]) == 2 - assert dataset["filters"]["types"] == {"3D": False} + assert dataset["type_filters"] == {"3D": False} res 
= await client.post( f"{PREFIX}/project/{project_id}/dataset/{dataset_id}/" "images/query/", @@ -281,13 +281,10 @@ async def full_workflow_TaskExecutionError( ) assert res.status_code == 200 dataset = res.json() - EXPECTED_FILTERS = { - "attributes": {}, - "types": { - "3D": False, - }, - } - assert dataset["filters"] == EXPECTED_FILTERS + EXPECTED_TYPE_FILTERS = {"3D": False} + EXPECTED_ATTRIBUTE_FILTERS = {} + assert dataset["type_filters"] == EXPECTED_TYPE_FILTERS + assert dataset["attribute_filters"] == EXPECTED_ATTRIBUTE_FILTERS assert len(dataset["history"]) == 3 assert [item["status"] for item in dataset["history"]] == [ "done", diff --git a/tests/v2/09_backends/test_local_experimental.py b/tests/v2/09_backends/test_local_experimental.py index d4f259d6da..957f7fae42 100644 --- a/tests/v2/09_backends/test_local_experimental.py +++ b/tests/v2/09_backends/test_local_experimental.py @@ -59,6 +59,7 @@ async def test_unit_process_workflow(): logger_name=None, workflow_dir_local="/foo", workflow_dir_remote="/bar", + job_attribute_filters={}, ) @@ -218,6 +219,7 @@ async def test_indirect_shutdown_during_process_workflow( workflow_dir_local=tmp_path, first_task_index=0, last_task_index=0, + job_attribute_filters={}, ) tmp_stdout.close() tmp_stderr.close() diff --git a/tests/v2/09_backends/test_slurm_config.py b/tests/v2/09_backends/test_slurm_config.py index 601977c032..7ab8f279f4 100644 --- a/tests/v2/09_backends/test_slurm_config.py +++ b/tests/v2/09_backends/test_slurm_config.py @@ -53,7 +53,7 @@ class WorkflowTaskV2Mock(BaseModel, extra=Extra.forbid): meta_parallel: Optional[dict[str, Any]] = Field() meta_non_parallel: Optional[dict[str, Any]] = Field() task: TaskV2Mock - input_filters: dict[str, Any] = Field(default_factory=dict) + type_filters: dict[str, bool] = Field(default_factory=dict) order: int = 0 id: int = 1 workflow_id: int = 0
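
Illustrative note (not part of the patch): the tests above show how the legacy `filters` object is replaced by separate `type_filters` and `attribute_filters` arguments. Below is a minimal usage sketch, assuming this branch of fractal-server and only the call sites that appear in this diff (`filter_image_list` and `match_filter` from `fractal_server.images.tools`, with images as plain dicts carrying `types` and `attributes`); the image contents are made-up examples.

# Minimal sketch of the split-filter API (names and image contents are
# illustrative; signatures follow the call sites in the tests above).
from fractal_server.images.tools import filter_image_list
from fractal_server.images.tools import match_filter

images = [
    {"zarr_url": "/a", "types": {"3D": True}, "attributes": {"well": "A01"}},
    {"zarr_url": "/b", "types": {"3D": False}, "attributes": {"well": "B02"}},
]

# Type filters are {key: bool}; a type missing from an image counts as False.
only_2d = filter_image_list(images, type_filters={"3D": False})
assert [img["zarr_url"] for img in only_2d] == ["/b"]

# Attribute filters map each key to a list of accepted values; a None value
# skips that key (the old scalar form, e.g. {"well": "A01"}, is rejected).
some_wells = filter_image_list(
    images,
    attribute_filters={"well": ["A01", "B02"], "plate": None},
)
assert len(some_wells) == 2

# Per-image check, combining both filter kinds.
assert match_filter(
    image=images[0],
    type_filters={"3D": True},
    attribute_filters={"well": ["A01"]},
)

The split keeps type filtering strictly boolean (missing type keys are treated as `False`), while attribute filters accept several values per key, which is why the old scalar form now fails validation in the schema and API tests above.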