diff --git a/dev-requirements.txt b/dev-requirements.txt
index 34f20d8..b8027b5 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,6 +1,7 @@
 pytest
 torch
 coverage
 pytest-cov
 adlfs
+zarr
 -r requirements.txt
diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index da80995..4518e21 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -1,6 +1,7 @@
 """Classes for iterating through xarray datarrays / datasets in batches."""
 
 import itertools
+import json
 from collections import OrderedDict
 from typing import Any, Dict, Hashable, Iterator
 
@@ -65,7 +66,21 @@ def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
     return ds_stack.transpose(*dim_order)
 
 
-class BatchGenerator:
+class BatchGeneratorBase:
+    def __init__(
+        self,
+        input_dims: Dict[Hashable, int],
+        input_overlap: Dict[Hashable, int] = {},
+        batch_dims: Dict[Hashable, int] = {},
+        concat_input_dims: bool = False,
+    ):
+        self.input_dims = OrderedDict(input_dims)
+        self.input_overlap = input_overlap
+        self.batch_dims = OrderedDict(batch_dims)
+        self.concat_input_dims = concat_input_dims
+
+
+class BatchGenerator(BatchGeneratorBase):
     """Create generator for iterating through xarray datarrays / datasets
     in batches.
 
@@ -107,13 +122,15 @@ def __init__(
         concat_input_dims: bool = False,
         preload_batch: bool = True,
     ):
+        super().__init__(
+            input_dims=input_dims,
+            input_overlap=input_overlap,
+            batch_dims=batch_dims,
+            concat_input_dims=concat_input_dims,
+        )
         self.ds = _as_xarray_dataset(ds)
 
         # should be a dict
-        self.input_dims = OrderedDict(input_dims)
-        self.input_overlap = input_overlap
-        self.batch_dims = OrderedDict(batch_dims)
-        self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch
 
         self._batches: Dict[
@@ -178,3 +195,74 @@ def _iterate_batch_dims(self, ds):
 
     def _iterate_input_dims(self, ds):
         return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)
+
+    def to_zarr(self, path, chunks={'batch': '1Gb'}):
+        """
+        Store batches into a zarr datastore in `path`. To speed up loading of
+        batches it is recommended that the chunk size across batches is set
+        close to the available RAM on the computer where you are doing the ML
+        model training.
+        """
+        batch_datasets = list(self)
+        # copy so the caller's dict (and the mutable default) isn't mutated
+        chunks = dict(chunks)
+        # can't call the batch dimension `batch` because Dataset.batch is used
+        # for the batch accessor. Instead we'll call it `batch_number`
+        ds_all = xr.concat(batch_datasets, dim='batch_number').reset_index(
+            'sample'
+        )
+        if 'batch' in chunks:
+            chunks['batch_number'] = chunks.pop('batch')
+
+        if len(chunks) > 0:
+            ds_all = ds_all.chunk(chunks)
+
+        for v in StoredBatchesGenerator.INIT_ARGS_TO_SERIALIZE:
+            ds_all.attrs[v] = json.dumps(getattr(self, v))
+        ds_all.to_zarr(path)
+
+    @staticmethod
+    def from_zarr(path):
+        """
+        Load a batch generator from the zarr datastore at a given `path`.
+        """
+        return StoredBatchesGenerator(path=path)
+
+
+class StoredBatchesGenerator(BatchGeneratorBase):
+    """
+    Create a generator which mimics the behaviour of BatchGenerator but loads
+    the batches from a zarr store that was previously created with
+    `BatchGenerator.to_zarr`. Arguments which the original BatchGenerator was
+    created with are serialized as JSON and saved as attributes in the
+    zarr store.
+    """
+
+    INIT_ARGS_TO_SERIALIZE = [
+        'input_dims',
+        'input_overlap',
+        'batch_dims',
+        'concat_input_dims',
+    ]
+
+    def __init__(self, path):
+        self.ds_batches = xr.open_zarr(path)
+        self.path = path
+
+        init_kws = {
+            v: json.loads(self.ds_batches.attrs[v])
+            for v in self.INIT_ARGS_TO_SERIALIZE
+        }
+        super().__init__(**init_kws)
+
+    def __iter__(self):
+        for batch_id in self.ds_batches.batch_number.values:
+            ds_batch = self.ds_batches.sel(batch_number=batch_id)
+            # recreate the MultiIndex we had before storing the batches
+            stacked_coords = [
+                d
+                for d in ds_batch.coords
+                if d not in ['sample', 'batch_number']
+            ]
+            ds_batch = ds_batch.set_index(sample=stacked_coords)
+            yield ds_batch
diff --git a/xbatcher/tests/test_to_zarr.py b/xbatcher/tests/test_to_zarr.py
new file mode 100644
index 0000000..52ef495
--- /dev/null
+++ b/xbatcher/tests/test_to_zarr.py
@@ -0,0 +1,37 @@
+import tempfile
+
+import numpy as np
+import xarray as xr
+
+import xbatcher
+
+
+def test_to_zarr():
+    da = xr.DataArray(
+        np.random.rand(1000, 100, 100), name='foo', dims=['time', 'y', 'x']
+    ).chunk({'time': 1})
+
+    bgen = xbatcher.BatchGenerator(da, {'time': 10}, preload_batch=False)
+
+    for ds_batch in bgen:
+        ds_first_batch = ds_batch
+        break
+
+    # use mkdtemp() rather than TemporaryDirectory().name, which deletes the
+    # directory again as soon as the object is garbage-collected
+    tempdir = tempfile.mkdtemp()
+    bgen.to_zarr(tempdir, chunks={})
+
+    bgen_loaded = xbatcher.BatchGenerator.from_zarr(tempdir)
+
+    for loaded_batch in bgen_loaded:
+        loaded_first_batch = loaded_batch
+        break
+
+    # DataArray.equals doesn't work while the DataArrays are still stacked
+    da_first_batch = ds_first_batch.unstack()
+    da_loaded_first_batch = loaded_first_batch.unstack()
+    # DataArray.equals doesn't work here, but DataArray.broadcast_equals does
+    assert da_loaded_first_batch.broadcast_equals(da_first_batch)
+    # the values themselves should also match exactly
+    assert (da_loaded_first_batch - da_first_batch).max() == 0.0
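
For reference, a minimal round-trip sketch of the API added above. The store path `batches.zarr` and the random input data are illustrative only, and `zarr` (plus `dask`) must be installed:

```python
import numpy as np
import xarray as xr

import xbatcher

# build a generator over a chunked DataArray, mirroring the test above
da = xr.DataArray(
    np.random.rand(1000, 100, 100), name='foo', dims=['time', 'y', 'x']
).chunk({'time': 1})
bgen = xbatcher.BatchGenerator(da, {'time': 10}, preload_batch=False)

# write all batches to a zarr store; chunks={} keeps the existing chunking
# instead of rechunking across batches
bgen.to_zarr('batches.zarr', chunks={})

# later, e.g. on the training machine, reload the stored batches;
# iterating yields the same stacked batches as the original generator
bgen_loaded = xbatcher.BatchGenerator.from_zarr('batches.zarr')
for batch in bgen_loaded:
    ...  # feed each batch to the training loop
```

Because the batches stay dask-backed when `preload_batch=False`, the final `ds_all.to_zarr(path)` write should stream them to disk rather than loading the whole dataset into memory at once.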