From e666dbe78111786d278b52772ca2b9668eb23254 Mon Sep 17 00:00:00 2001 From: Michael Tauraso Date: Fri, 8 Nov 2024 11:01:55 -0800 Subject: [PATCH] Initial version of filter_catalog feature. (#113) - We can take a fits file as a config - We filter objects_ids out of a big dataset based on it - We also skip filesystem checks if there is enough info in the filter catalog. - Lacks any unit testing - Added the prepare verb, but right now it just gives you the dataset object when run from a notebook. --- src/fibad/data_sets/hsc_data_set.py | 99 ++++++++++++++++++++-- src/fibad/downloadCutout/downloadCutout.py | 5 ++ src/fibad/fibad.py | 10 ++- src/fibad/fibad_default_config.toml | 4 + src/fibad/prepare.py | 21 +++++ tests/fibad/test_hsc_dataset.py | 13 ++- 6 files changed, 141 insertions(+), 11 deletions(-) create mode 100644 src/fibad/prepare.py diff --git a/src/fibad/data_sets/hsc_data_set.py b/src/fibad/data_sets/hsc_data_set.py index e7d0730..7df1bd8 100644 --- a/src/fibad/data_sets/hsc_data_set.py +++ b/src/fibad/data_sets/hsc_data_set.py @@ -9,6 +9,7 @@ import numpy as np import torch from astropy.io import fits +from astropy.table import Table from torch.utils.data import Dataset from torchvision.transforms.v2 import CenterCrop, Compose, Lambda @@ -240,6 +241,10 @@ def __getitem__(self, idx: int) -> torch.Tensor: return self.data[self.indexes[idx]] +dim_dict = dict[str, tuple[int, int]] +files_dict = dict[str, dict[str, str]] + + class HSCDataSetContainer(Dataset): def __init__(self, config): # TODO: What will be a reasonable set of tranformations? @@ -250,12 +255,14 @@ def __init__(self, config): crop_to = config["data_set"]["crop_to"] filters = config["data_set"]["filters"] + filter_catalog = config["data_set"]["filter_catalog"] self._init_from_path( config["general"]["data_dir"], transform=transform, cutout_shape=crop_to if crop_to else None, filters=filters if filters else None, + filter_catalog=Path(filter_catalog) if filter_catalog else None, ) def _init_from_path( @@ -265,6 +272,7 @@ def _init_from_path( transform=None, cutout_shape: Optional[tuple[int, int]] = None, filters: Optional[list[str]] = None, + filter_catalog: Optional[Path] = None, ): """__init__ helper. Initialize an HSC data set from a path. This involves several filesystem scan operations and will ultimately open and read the header info of every fits file in the given directory @@ -284,12 +292,31 @@ def _init_from_path( cutouts which do not have fits files corresponding to every filter in the list will be dropped from the data set. Defaults to None. If not provided, the filters available on the filesystem for the first object in the directory will be used. + filter_catalog: Path, optional + Path to a .fits file which specifies objects and or files to use directly, bypassing the default + of attempting to use every file in the path. + Columns for this fits file are object_id (required), filter (optional), filename (optional), and + dims (optional tuple of x/y pixel size of images). + - Filenames must be relative to the path provided to this function. + - When filters and filenames are both provided, initialization skips a directory listing, which + can provide better performance on large datasets. + - When filters, filenames, and dims are specified we also skip opening the files to get + the dimensions. This can also provide better performance on large datasets. """ self.path = path self.transform = transform - self.files = self._scan_file_names(filters) - self.dims = self._scan_file_dimensions() + self.filter_catalog = self._read_filter_catalog(filter_catalog) + if isinstance(self.filter_catalog, tuple): + self.files = self.filter_catalog[0] + self.dims = self.filter_catalog[1] + print(self.dims) + elif isinstance(self.filter_catalog, dict): + self.files = self.filter_catalog + self.dims = self._scan_file_dimensions() + else: + self.files = self._scan_file_names(filters) + self.dims = self._scan_file_dimensions() # If no filters provided, we choose the first file in the dict as the prototypical set of filters # Any objects lacking this full set of filters will be pruned by _prune_objects @@ -313,7 +340,7 @@ def _init_from_path( logger.info(f"HSC Data set loader has {len(self)} objects") - def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dict[str, str]]: + def _scan_file_names(self, filters: Optional[list[str]] = None) -> files_dict: """Class initialization helper Parameters @@ -335,11 +362,17 @@ def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dic files = {} # Go scan the path for object ID's so we have a list. - for filepath in Path(self.path).glob("[0-9]*.fits"): + for filepath in Path(self.path).iterdir(): filename = filepath.name - m = re.match(full_regex, filename) - # Skip files that don't match the pattern. + # If we are filtering based off a user-provided catalog of object ids, Filter out any + # objects_ids not in the catalog. Do this before regex match for speed of discarding + # irrelevant files. + if isinstance(self.filter_catalog, list) and filename[:17] not in self.filter_catalog: + continue + + m = re.match(full_regex, filename) + # Skip files that don't allow us to extract both object_id and filter if m is None: continue @@ -359,7 +392,57 @@ def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dic return files - def _scan_file_dimensions(self) -> dict[str, tuple[int, int]]: + def _read_filter_catalog( + self, filter_catalog_path: Optional[Path] + ) -> Optional[Union[list[str], files_dict, tuple[files_dict, dim_dict]]]: + if filter_catalog_path is None: + return None + + if not filter_catalog_path.exists(): + logger.error(f"Filter catalog file {filter_catalog_path} given in config does not exist.") + return None + + table = Table.read(filter_catalog_path, format="fits") + colnames = table.colnames + if "object_id" not in colnames: + logger.error(f"Filter catalog file {filter_catalog_path} has no column object_id") + return None + + # We are dealing with just a list of object_ids + if "filter" not in colnames and "filename" not in colnames: + return list(table["object_id"]) + + # Or a table that lacks both filter and filename + elif "filter" not in colnames or "filename" not in colnames: + msg = f"Filter catalog file {filter_catalog_path} provides one of filters or filenames " + msg += "without the other. Filesystem scan will still occur without both defined." + logger.warning(msg) + return list(set(table["object_id"])) + + # We have filter and filename defined so we can assemble the catalog at file level. + filter_catalog = {} + if "dim" in colnames: + dim_catalog = {} + + for row in table: + object_id = row["object_id"] + filter = row["filter"] + filename = row["filename"] + + if object_id not in filter_catalog: + filter_catalog[object_id] = {} + + filter_catalog[object_id][filter] = filename + + # Dimension is optional + if "dim" in colnames: + if object_id not in dim_catalog: + dim_catalog[object_id] = [] + dim_catalog[object_id].append(tuple(row["dim"])) + + return (filter_catalog, dim_catalog) if "dim" in colnames else filter_catalog + + def _scan_file_dimensions(self) -> dim_dict: # Scan the filesystem to get the widths and heights of all images into a dict return { object_id: [self._fits_file_dims(filepath) for filepath in self._object_files(object_id)] @@ -445,7 +528,7 @@ def _check_file_dimensions(self) -> tuple[int, int]: The minimum width and height in pixels of the entire dataset. In other words: the maximal image size in pixels that can be generated from ALL cutout images via cropping. """ - # Find the makximal cutout size that all images can support + # Find the maximal cutout size that all images can support all_widths = [shape[0] for shape_list in self.dims.values() for shape in shape_list] cutout_width = np.min(all_widths) diff --git a/src/fibad/downloadCutout/downloadCutout.py b/src/fibad/downloadCutout/downloadCutout.py index 2f28a51..0aa27e1 100644 --- a/src/fibad/downloadCutout/downloadCutout.py +++ b/src/fibad/downloadCutout/downloadCutout.py @@ -21,6 +21,8 @@ from collections.abc import Generator from typing import IO, Any, Callable, Optional, Union, cast +import numpy as np + __all__ = [] @@ -762,6 +764,9 @@ def parse_bool(s: Union[str, bool]) -> bool: if isinstance(s, bool): return s + if isinstance(s, np.bool): + return s + return { "false": False, "f": False, diff --git a/src/fibad/fibad.py b/src/fibad/fibad.py index 23dd6c5..474a005 100644 --- a/src/fibad/fibad.py +++ b/src/fibad/fibad.py @@ -14,7 +14,7 @@ class Fibad: CLI functions in fibad_cli are implemented by calling this class """ - verbs = ["train", "predict", "download"] + verbs = ["train", "predict", "download", "prepare"] def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool = True): """Initialize fibad. Always applies the default config, and merges it with any provided config file. @@ -177,3 +177,11 @@ def predict(self, **kwargs): from .predict import run return run(config=self.config, **kwargs) + + def prepare(self, **kwargs): + """ + See Fibad.predict.run() + """ + from .prepare import run + + return run(config=self.config, **kwargs) diff --git a/src/fibad/fibad_default_config.toml b/src/fibad/fibad_default_config.toml index 9df057e..96616a0 100644 --- a/src/fibad/fibad_default_config.toml +++ b/src/fibad/fibad_default_config.toml @@ -90,6 +90,10 @@ crop_to = false #filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"] filters = false +# A fits file which specifies object IDs to filter a large dataset in [general].data_dir down +# Implementation is dataset class dependent. Default is false meaning now filtering. +filter_catalog = false + [data_loader] # Default PyTorch DataLoader parameters batch_size = 32 diff --git a/src/fibad/prepare.py b/src/fibad/prepare.py new file mode 100644 index 0000000..6be43c9 --- /dev/null +++ b/src/fibad/prepare.py @@ -0,0 +1,21 @@ +import logging + +from fibad.pytorch_ignite import setup_model_and_dataset + +logger = logging.getLogger(__name__) + + +def run(config): + """Prepare the dataset for a given model and data loader. + + Parameters + ---------- + config : dict + The parsed config file as a nested + dict + """ + + _, data_set = setup_model_and_dataset(config, split=config["train"]["split"]) + + logger.info("Finished Prepare") + return data_set diff --git a/tests/fibad/test_hsc_dataset.py b/tests/fibad/test_hsc_dataset.py index 7a32e49..2717685 100644 --- a/tests/fibad/test_hsc_dataset.py +++ b/tests/fibad/test_hsc_dataset.py @@ -31,7 +31,7 @@ def __init__(self, test_files: dict): self.test_files = test_files mock_paths = [Path(x) for x in list(test_files.keys())] - target = "fibad.data_sets.hsc_data_set.Path.glob" + target = "fibad.data_sets.hsc_data_set.Path.iterdir" self.patchers.append(mock.patch(target, return_value=mock_paths)) mock_fits_open = mock.Mock(side_effect=self._open_file) @@ -53,7 +53,15 @@ def __exit__(self, *exc): patcher.stop() -def mkconfig(crop_to=False, filters=False, train_size=0.2, test_size=0.6, validate_size=0, seed=False): +def mkconfig( + crop_to=False, + filters=False, + train_size=0.2, + test_size=0.6, + validate_size=0, + seed=False, + filter_catalog=False, +): """Makes a configuration that points at nonexistent path so HSCDataSet.__init__ will create an object, and our FakeFitsFS shim can be called. """ @@ -62,6 +70,7 @@ def mkconfig(crop_to=False, filters=False, train_size=0.2, test_size=0.6, valida "data_set": { "crop_to": crop_to, "filters": filters, + "filter_catalog": filter_catalog, }, "prepare": { "seed": seed,