Skip to content

Commit

Permalink
Merge pull request #155 from lincc-frameworks/caching_expts
Browse files Browse the repository at this point in the history
User Controlled Caching
  • Loading branch information
aritraghsh09 authored Jan 10, 2025
2 parents aaccba0 + f956fb8 commit 9d1fb12
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/fibad/data_sets/hsc_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def __init__(self, config):
crop_to = config["data_set"]["crop_to"]
filters = config["data_set"]["filters"]

self.use_cache = config["data_set"]["use_cache"]

if config["data_set"]["filter_catalog"]:
filter_catalog = Path(config["data_set"]["filter_catalog"])
elif not config.get("rebuild_manifest", False):
Expand Down Expand Up @@ -1017,9 +1019,10 @@ def _object_id_to_tensor(self, object_id: str) -> torch.Tensor:
torch.Tensor
A tensor with dimension (self.num_filters, self.cutout_shape[0], self.cutout_shape[1])
"""
data_torch = self.tensors.get(object_id, None)
if data_torch is not None:
return data_torch
if self.use_cache is True:
data_torch = self.tensors.get(object_id, None)
if data_torch is not None:
return data_torch

# Read all the files corresponding to this object
data = []
Expand All @@ -1035,5 +1038,7 @@ def _object_id_to_tensor(self, object_id: str) -> torch.Tensor:
# Apply our transform stack
data_torch = self.transform(data_torch) if self.transform is not None else data_torch

self.tensors[object_id] = data_torch
if self.use_cache is True:
self.tensors[object_id] = data_torch

return data_torch
4 changes: 4 additions & 0 deletions src/fibad/fibad_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ test_size = 0.6
# a system source at runtime.
seed = false

# Controls whether images are cached in memory during data loading. For training,
# this reduces runtimes after the first epoch.
use_cache = true

[data_loader]
# Default PyTorch DataLoader parameters
batch_size = 32
Expand Down
2 changes: 2 additions & 0 deletions tests/fibad/test_hsc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def mkconfig(
validate_size=0.1,
seed=False,
filter_catalog=False,
use_cache=False,
):
"""Makes a configuration that points at nonexistent path so HSCDataSet.__init__ will create an object,
and our FakeFitsFS shim can be called.
Expand All @@ -77,6 +78,7 @@ def mkconfig(
"train_size": train_size,
"test_size": test_size,
"validate_size": validate_size,
"use_cache": use_cache,
},
}

Expand Down

0 comments on commit 9d1fb12

Please sign in to comment.