From 69198e518462e2d7575401f5477ab8e8110f3f4f Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Tue, 17 Dec 2024 09:18:27 +0000 Subject: [PATCH] Feat: safely import azure-storage-blob and boto3 --- dev-env.yml | 2 +- odc/geo/cog/_az.py | 6 +-- odc/geo/cog/_s3.py | 6 +-- odc/geo/cog/_tifffile.py | 24 +++++++++-- setup.cfg | 4 ++ tests/test_az.py | 88 +++++++++++++++++++++++++--------------- tests/test_s3.py | 13 ++++-- 7 files changed, 96 insertions(+), 47 deletions(-) diff --git a/dev-env.yml b/dev-env.yml index 0f87332b..2b432f4b 100644 --- a/dev-env.yml +++ b/dev-env.yml @@ -3,7 +3,7 @@ channels: - conda-forge dependencies: - - python =3.8 + - python =3.10 # odc-geo dependencies - pyproj diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py index 3371c6da..5c79497a 100644 --- a/odc/geo/cog/_az.py +++ b/odc/geo/cog/_az.py @@ -33,7 +33,7 @@ def max_part(self) -> int: return 50_000 -class MultiPartUpload(AzureLimits, MultiPartUploadBase): +class AzMultiPartUpload(AzureLimits, MultiPartUploadBase): def __init__( self, account_url: str, container: str, blob: str, credential: Any = None ): @@ -161,11 +161,11 @@ class DelayedAzureWriter(AzureLimits): Dask-compatible writer for Azure Blob Storage multipart uploads. """ - def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]): + def __init__(self, mpu: AzMultiPartUpload, kw: dict[str, Any]): """ Initialise the Azure writer. - :param mpu: MultiPartUpload instance. + :param mpu: AzMultiPartUpload instance. :param kw: Additional parameters for the writer. """ self.mpu = mpu diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py index 2434c9da..367b31c9 100644 --- a/odc/geo/cog/_s3.py +++ b/odc/geo/cog/_s3.py @@ -70,7 +70,7 @@ def max_part(self) -> int: return 10_000 -class MultiPartUpload(S3Limits, MultiPartUploadBase): +class S3MultiPartUpload(S3Limits, MultiPartUploadBase): """ Dask to S3 dumper. 
""" @@ -237,7 +237,7 @@ class DelayedS3Writer(S3Limits): # pylint: disable=import-outside-toplevel,import-error - def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]): + def __init__(self, mpu: S3MultiPartUpload, kw: dict[str, Any]): self.mpu = mpu self.kw = kw # mostly ContentType= kinda thing self._shared_var: Optional["distributed.Variable"] = None @@ -263,7 +263,7 @@ def _shared(self, client: "distributed.Client") -> "distributed.Variable": self._shared_var = Variable(self._build_name("MPUpload"), client) return self._shared_var - def _ensure_init(self, final_write: bool = False) -> MultiPartUpload: + def _ensure_init(self, final_write: bool = False) -> S3MultiPartUpload: # pylint: disable=too-many-return-statements mpu = self.mpu if mpu.started: diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index 9bdd1ccc..7a895869 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -17,15 +17,29 @@ import numpy as np import xarray as xr - from .._interop import have from ..geobox import GeoBox from ..math import resolve_nodata from ..types import Shape2d, SomeNodata, Unset, shape_ -from ._az import MultiPartUpload as AzMultiPartUpload from ._mpu import mpu_write from ._mpu_fs import MPUFileSink -from ._s3 import MultiPartUpload as S3MultiPartUpload, s3_parse_url + +try: + from ._az import AzMultiPartUpload + + HAVE_AZURE = True +except ImportError: + AzMultiPartUpload = None + HAVE_AZURE = False +try: + from ._s3 import S3MultiPartUpload, s3_parse_url + + HAVE_S3 = True +except ImportError: + S3MultiPartUpload = None + s3_parse_url = None + HAVE_S3 = False + from ._shared import ( GDAL_COMP, GEOTIFF_TAGS, @@ -738,9 +752,13 @@ def save_cog_with_dask( # Determine output type and initiate uploader parsed_url = urlparse(dst) if parsed_url.scheme == "s3": + if not HAVE_S3: + raise ImportError("Install `boto3` to enable S3 support.") bucket, key = s3_parse_url(dst) uploader = S3MultiPartUpload(bucket, key, **aws) elif parsed_url.scheme == "az": + if not HAVE_AZURE: + raise ImportError("Install azure-storage-blob` to enable Azure support.") uploader = AzMultiPartUpload( account_url=azure.get("account_url"), container=parsed_url.netloc, diff --git a/setup.cfg b/setup.cfg index 4f09ae56..a19ab783 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,10 +54,14 @@ tiff = s3 = boto3 +az = + azure-storage-blob + all = %(warp)s %(tiff)s %(s3)s + %(az)s test = pytest diff --git a/tests/test_az.py b/tests/test_az.py index 00325357..024e01ce 100644 --- a/tests/test_az.py +++ b/tests/test_az.py @@ -1,37 +1,50 @@ -"""Tests for the Azure MultiPartUpload class.""" +"""Tests for the Azure AzMultiPartUpload class.""" +import base64 import unittest from unittest.mock import MagicMock, patch -from odc.geo.cog._az import MultiPartUpload +# Conditional import for Azure support +try: + from odc.geo.cog._az import AzMultiPartUpload + HAVE_AZURE = True +except ImportError: + AzMultiPartUpload = None + HAVE_AZURE = False -def test_mpu_init(): - """Basic test for the MultiPartUpload class.""" - account_url = "https://account_name.blob.core.windows.net" - mpu = MultiPartUpload(account_url, "container", "some.blob", None) - if mpu.account_url != account_url: - raise AssertionError(f"mpu.account_url should be '{account_url}'.") - if mpu.container != "container": - raise AssertionError("mpu.container should be 'container'.") - if mpu.blob != "some.blob": - raise AssertionError("mpu.blob should be 'some.blob'.") - if mpu.credential is not None: - raise AssertionError("mpu.credential should be 
+def require_azure(test_func):
+    """Decorator to skip tests if Azure dependencies are not installed."""
+    return unittest.skipUnless(HAVE_AZURE, "Azure dependencies are not installed")(
+        test_func
+    )
 
 
-class TestMultiPartUpload(unittest.TestCase):
-    """Test the MultiPartUpload class."""
+class TestAzMultiPartUpload(unittest.TestCase):
+    """Test the AzMultiPartUpload class."""
+
+    @require_azure
+    def test_mpu_init(self):
+        """Basic test for AzMultiPartUpload initialization."""
+        account_url = "https://account_name.blob.core.windows.net"
+        mpu = AzMultiPartUpload(account_url, "container", "some.blob", None)
+
+        self.assertEqual(mpu.account_url, account_url)
+        self.assertEqual(mpu.container, "container")
+        self.assertEqual(mpu.blob, "some.blob")
+        self.assertIsNone(mpu.credential)
+
+    @require_azure
     @patch("odc.geo.cog._az.BlobServiceClient")
     def test_azure_multipart_upload(self, mock_blob_service_client):
-        """Test the MultiPartUpload class."""
-        # Arrange - mock the Azure Blob SDK
-        # Mock the blob client and its methods
+        """Test the full AzMultiPartUpload upload flow."""
+        # Arrange - Mock Azure Blob SDK client structure
         mock_blob_client = MagicMock()
         mock_container_client = MagicMock()
-        mcc = mock_container_client
-        mock_blob_service_client.return_value.get_container_client.return_value = mcc
+        mock_blob_service_client.return_value.get_container_client.return_value = (
+            mock_container_client
+        )
         mock_container_client.get_blob_client.return_value = mock_blob_client
 
         # Simulate return values for Azure Blob SDK methods
@@ -43,32 +56,41 @@
         blob = "mock-blob"
         credential = "mock-sas-token"
 
-        # Act - create an instance of MultiPartUpload and call its methods
-        azure_upload = MultiPartUpload(account_url, container, blob, credential)
+        # Act
+        azure_upload = AzMultiPartUpload(account_url, container, blob, credential)
         upload_id = azure_upload.initiate()
         part1 = azure_upload.write_part(1, b"first chunk of data")
         part2 = azure_upload.write_part(2, b"second chunk of data")
         etag = azure_upload.finalise([part1, part2])
 
-        # Assert - check the results
-        # Check that the initiate method behaves as expected
+        # Expected block IDs (base64-encoded, as staged by the uploader)
+        block_id1 = base64.b64encode(b"block-1").decode("utf-8")
+        block_id2 = base64.b64encode(b"block-2").decode("utf-8")
+
+        # Assert
         self.assertEqual(upload_id, "azure-block-upload")
+        self.assertEqual(etag, "mock-etag")
 
-        # Verify the calls to Azure Blob SDK methods
+        # Verify BlobServiceClient instantiation
         mock_blob_service_client.assert_called_once_with(
             account_url=account_url, credential=credential
         )
+
+        # Verify stage_block calls
         mock_blob_client.stage_block.assert_any_call(
-            part1["BlockId"], b"first chunk of data"
+            block_id=block_id1, data=b"first chunk of data"
         )
         mock_blob_client.stage_block.assert_any_call(
-            part2["BlockId"], b"second chunk of data"
+            block_id=block_id2, data=b"second chunk of data"
         )
-        mock_blob_client.commit_block_list.assert_called_once()
-        self.assertEqual(etag, "mock-etag")
 
-        # Verify block list passed during finalise
+        # Verify commit_block_list was called correctly
         block_list = mock_blob_client.commit_block_list.call_args[0][0]
         self.assertEqual(len(block_list), 2)
-        self.assertEqual(block_list[0].id, part1["BlockId"])
-        self.assertEqual(block_list[1].id, part2["BlockId"])
+        self.assertEqual(block_list[0].id, block_id1)
+        self.assertEqual(block_list[1].id, block_id2)
+        mock_blob_client.commit_block_list.assert_called_once()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_s3.py b/tests/test_s3.py
index 8349bd81..accde52a 100644
--- a/tests/test_s3.py
+++ b/tests/test_s3.py
@@ -1,9 +1,14 @@
-from odc.geo.cog._s3 import MultiPartUpload
+"""Tests for odc.geo.cog._s3."""
+
+from odc.geo.cog._s3 import S3MultiPartUpload
 
 
 # TODO: moto
 def test_s3_mpu():
-    mpu = MultiPartUpload("bucket", "file.dat")
-    assert mpu.bucket == "bucket"
-    assert mpu.key == "file.dat"
+    """Test S3MultiPartUpload class initialization."""
+    mpu = S3MultiPartUpload("bucket", "file.dat")
+    if mpu.bucket != "bucket":
+        raise AssertionError("mpu.bucket should be 'bucket'.")
+    if mpu.key != "file.dat":
+        raise AssertionError("mpu.key should be 'file.dat'.")