diff --git a/xee/ext.py b/xee/ext.py index eb68d4c..9a59fe7 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -102,6 +102,11 @@ class EarthEngineStore(common.AbstractDataStore): 'height': 256, } + GETITEM_KWARGS: Dict[str, int] = { + 'max_retries': 6, + 'initial_delay': 500, + } + SCALE_UNITS: Dict[str, int] = { 'degree': 1, 'metre': 10_000, @@ -147,6 +152,7 @@ def open( ee_init_kwargs: Optional[Dict[str, Any]] = None, ee_init_if_necessary: bool = False, executor_kwargs: Optional[Dict[str, Any]] = None, + getitem_kwargs: Optional[Dict[str, int]] = None, ) -> 'EarthEngineStore': if mode != 'r': raise ValueError( @@ -168,6 +174,7 @@ def open( ee_init_kwargs=ee_init_kwargs, ee_init_if_necessary=ee_init_if_necessary, executor_kwargs=executor_kwargs, + getitem_kwargs=getitem_kwargs, ) def __init__( @@ -186,6 +193,7 @@ def __init__( ee_init_kwargs: Optional[Dict[str, Any]] = None, ee_init_if_necessary: bool = False, executor_kwargs: Optional[Dict[str, Any]] = None, + getitem_kwargs: Optional[Dict[str, int]] = None, ): self.ee_init_kwargs = ee_init_kwargs self.ee_init_if_necessary = ee_init_if_necessary @@ -195,6 +203,8 @@ def __init__( executor_kwargs = {} self.executor_kwargs = executor_kwargs + self.getitem_kwargs = {**self.GETITEM_KWARGS, **(getitem_kwargs or {})} + self.image_collection = image_collection if n_images != -1: self.image_collection = image_collection.limit(n_images) @@ -478,7 +488,10 @@ def image_to_array( **kwargs, } raw = common.robust_getitem( - pixels_getter, params, catch=ee.ee_exception.EEException + pixels_getter, + params, + catch=ee.ee_exception.EEException, + **self.getitem_kwargs, ) # Extract out the shape information from EE response. @@ -960,6 +973,7 @@ def open_dataset( ee_init_if_necessary: bool = False, ee_init_kwargs: Optional[Dict[str, Any]] = None, executor_kwargs: Optional[Dict[str, Any]] = None, + getitem_kwargs: Optional[Dict[str, int]] = None, ) -> xarray.Dataset: # type: ignore """Open an Earth Engine ImageCollection as an Xarray Dataset. @@ -1032,7 +1046,11 @@ def open_dataset( executor_kwargs (optional): A dictionary of keyword arguments to pass to the ThreadPoolExecutor that handles the parallel computation of pixels i.e. {'max_workers': 2}. - + getitem_kwargs (optional): Exponential backoff kwargs passed into the + xarray function to index the array (`robust_getitem`). + - 'max_retries', the maximum number of retry attempts. Defaults to 6. + - 'initial_delay', the initial delay in milliseconds before the first + retry. Defaults to 500. Returns: An xarray.Dataset that streams in remote data from Earth Engine. """ @@ -1062,6 +1080,7 @@ def open_dataset( ee_init_kwargs=ee_init_kwargs, ee_init_if_necessary=ee_init_if_necessary, executor_kwargs=executor_kwargs, + getitem_kwargs=getitem_kwargs, ) store_entrypoint = backends_store.StoreBackendEntrypoint() diff --git a/xee/ext_integration_test.py b/xee/ext_integration_test.py index 35cb340..880c3e6 100644 --- a/xee/ext_integration_test.py +++ b/xee/ext_integration_test.py @@ -69,6 +69,7 @@ def setUp(self): '2017-01-01', '2017-01-03' ), n_images=64, + getitem_kwargs={'max_retries': 10, 'initial_delay': 1500}, ) self.store_with_neg_mask_value = xee.EarthEngineStore( ee.ImageCollection('LANDSAT/LC08/C01/T1').filterDate( @@ -87,6 +88,7 @@ def setUp(self): '2020-03-30', '2020-04-01' ), n_images=64, + getitem_kwargs={'max_retries': 9}, ) self.all_img_store = xee.EarthEngineStore( ee.ImageCollection('LANDSAT/LC08/C01/T1').filterDate( @@ -267,6 +269,19 @@ def __getitem__(self, params): self.assertEqual(getter.count, 3) + def test_getitem_kwargs(self): + arr = xee.EarthEngineBackendArray('B4', self.store) + self.assertEqual(arr.store.getitem_kwargs['initial_delay'], 1500) + self.assertEqual(arr.store.getitem_kwargs['max_retries'], 10) + + arr1 = xee.EarthEngineBackendArray('longitude', self.lnglat_store) + self.assertEqual(arr1.store.getitem_kwargs['initial_delay'], 500) + self.assertEqual(arr1.store.getitem_kwargs['max_retries'], 6) + + arr2 = xee.EarthEngineBackendArray('spi2y', self.conus_store) + self.assertEqual(arr2.store.getitem_kwargs['initial_delay'], 500) + self.assertEqual(arr2.store.getitem_kwargs['max_retries'], 9) + class EEBackendEntrypointTest(absltest.TestCase):