From edcccb2eb927efa99b10b5596e0a44dc6af51b80 Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Mon, 16 Dec 2024 21:36:38 +0000 Subject: [PATCH 1/6] Feat: MultiPartUploadBase class for multipart upload --- odc/geo/cog/_multipart.py | 81 +++++++++++++++++++++++++++++++++++++++ odc/geo/cog/_s3.py | 7 ++-- odc/geo/cog/_tifffile.py | 52 ++++++++++++++----------- 3 files changed, 114 insertions(+), 26 deletions(-) create mode 100644 odc/geo/cog/_multipart.py diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py new file mode 100644 index 00000000..0ba3c9b1 --- /dev/null +++ b/odc/geo/cog/_multipart.py @@ -0,0 +1,81 @@ +""" +Multipart upload interface. +""" + +from abc import ABC, abstractmethod +from typing import Any, Union + +import dask.bag + + +class MultiPartUploadBase(ABC): + """Abstract base class for multipart upload.""" + + @abstractmethod + def initiate(self, **kwargs) -> str: + """Initiate a multipart upload and return an identifier.""" + pass + + @abstractmethod + def write_part(self, part: int, data: bytes) -> dict[str, Any]: + """Upload a single part.""" + pass + + @abstractmethod + def finalise(self, parts: list[dict[str, Any]]) -> str: + """Finalise the upload with a list of parts.""" + pass + + @abstractmethod + def cancel(self, other: str = ""): + """Cancel the multipart upload.""" + pass + + @property + @abstractmethod + def url(self) -> str: + """Return the URL of the upload target.""" + pass + + @property + @abstractmethod + def started(self) -> bool: + """Check if the multipart upload has been initiated.""" + pass + + @abstractmethod + def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any: + """ + Return a Dask-compatible writer for multipart uploads. + + :param kw: Additional parameters for the writer. + :param client: Dask client for distributed execution. + """ + pass + + @abstractmethod + def upload( + self, + chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]], + *, + mk_header: Any = None, + mk_footer: Any = None, + user_kw: dict[str, Any] = None, + writes_per_chunk: int = 1, + spill_sz: int = 20 * (1 << 20), + client: Any = None, + **kw, + ) -> Any: + """ + Orchestrate the upload process with multipart uploads. + + :param chunks: Dask bag of chunks to upload. + :param mk_header: Function to create header data. + :param mk_footer: Function to create footer data. + :param user_kw: User-provided metadata for the upload. + :param writes_per_chunk: Number of writes per chunk. + :param spill_sz: Spill size for buffering data. + :param client: Dask client for distributed execution. + :return: A Dask delayed object representing the finalised upload. + """ + pass diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py index 6da5cf85..2434c9da 100644 --- a/odc/geo/cog/_s3.py +++ b/odc/geo/cog/_s3.py @@ -10,13 +10,15 @@ from cachetools import cached from ._mpu import PartsWriter, SomeData, mpu_write +from ._multipart import MultiPartUploadBase if TYPE_CHECKING: import dask.bag - import distributed from botocore.credentials import ReadOnlyCredentials from dask.delayed import Delayed + import distributed + _state: dict[str, Any] = {} @@ -68,7 +70,7 @@ def max_part(self) -> int: return 10_000 -class MultiPartUpload(S3Limits): +class MultiPartUpload(S3Limits, MultiPartUploadBase): """ Dask to S3 dumper. 
""" @@ -195,7 +197,6 @@ def writer(self, kw, *, client: Any = None) -> PartsWriter: writer.prep_client(client) return writer - # pylint: disable=too-many-arguments def upload( self, chunks: "dask.bag.Bag" | list["dask.bag.Bag"], diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index 34066293..7ca6f3fe 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -21,9 +21,10 @@ from ..geobox import GeoBox from ..math import resolve_nodata from ..types import Shape2d, SomeNodata, Unset, shape_ +from ._az import MultiPartUpload as AzMultiPartUpload from ._mpu import mpu_write from ._mpu_fs import MPUFileSink -from ._s3 import MultiPartUpload, s3_parse_url +from ._s3 import MultiPartUpload as S3MultiPartUpload, s3_parse_url from ._shared import ( GDAL_COMP, GEOTIFF_TAGS, @@ -631,10 +632,10 @@ def save_cog_with_dask( **kw, ) -> Any: """ - Save a Cloud Optimized GeoTIFF to S3 or file with Dask. + Save a Cloud Optimized GeoTIFF to S3, Azure Blob Storage, or file with Dask. :param xx: Pixels as :py:class:`xarray.DataArray` backed by Dask - :param dst: S3 url or a file path on shared storage + :param dst: S3, Azure URL, or file path :param compression: Compression to use, default is ``DEFLATE`` :param level: Compression "level", depends on chosen compression :param predictor: TIFF predictor setting @@ -643,6 +644,7 @@ def save_cog_with_dask( :param blocksize: Configure blocksizes for main and overview images :param bigtiff: Generate BigTIFF by default, set to ``False`` to disable :param aws: Configure AWS write access + :param azure: Azure credentials/config :param client: Dask client :param stats: Set to ``False`` to disable stats computation @@ -699,7 +701,7 @@ def save_cog_with_dask( if band_names and len(band_names) != meta.nsamples: raise ValueError( - f"Found {len(band_names)} band names ({band_names}) but there are {meta.nsamples} bands." + f"Found {len(band_names)} band names ({band_names}), expected {meta.nsamples} bands." 
) layers = _pyramids_from_cog_metadata(xx, meta, resampling=overview_resampling) @@ -731,19 +733,23 @@ def save_cog_with_dask( "_stats": _stats, } - tiles_write_order = _tiles[::-1] - if len(tiles_write_order) > 4: - tiles_write_order = [ - dask.bag.concat(tiles_write_order[:4]), - *tiles_write_order[4:], - ] - - bucket, key = s3_parse_url(dst) - if not bucket: - # assume disk output + # Determine output type and initiate uploader + parsed_url = urlparse(dst) + if parsed_url.scheme == "s3": + bucket, key = s3_parse_url(dst) + uploader = S3MultiPartUpload(bucket, key, **aws) + elif parsed_url.scheme == "az": + uploader = AzMultiPartUpload( + account_url=azure.get("account_url"), + container=parsed_url.netloc, + blob=parsed_url.path.lstrip("/"), + credential=azure.get("credential"), + ) + else: + # Assume local disk write = MPUFileSink(dst, parts_base=parts_base) return mpu_write( - tiles_write_order, + _tiles[::-1], write, mk_header=_patch_hdr, user_kw={ @@ -755,15 +761,15 @@ def save_cog_with_dask( **upload_params, ) - upload_params["ContentType"] = ( - "image/tiff;application=geotiff;profile=cloud-optimized" - ) + # Upload tiles + tiles_write_order = _tiles[::-1] # Reverse tiles for writing + if len(tiles_write_order) > 4: # Optimize for larger datasets + tiles_write_order = [ + dask.bag.concat(tiles_write_order[:4]), + *tiles_write_order[4:], + ] - cleanup = aws.pop("cleanup", False) - s3_sink = MultiPartUpload(bucket, key, **aws) - if cleanup: - s3_sink.cancel("all") - return s3_sink.upload( + return uploader.upload( tiles_write_order, mk_header=_patch_hdr, user_kw={ From 141bf21e11bb814bec5f346b4b0ec7a0644d8bfd Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Mon, 16 Dec 2024 21:37:26 +0000 Subject: [PATCH 2/6] Feat: azure backend to save cogs with dask --- odc/geo/cog/_az.py | 191 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_az.py | 74 ++++++++++++++++++ 2 files changed, 265 insertions(+) create mode 100644 odc/geo/cog/_az.py create mode 100644 tests/test_az.py diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py new file mode 100644 index 00000000..3371c6da --- /dev/null +++ b/odc/geo/cog/_az.py @@ -0,0 +1,191 @@ +import base64 +from typing import Any, Union + +from azure.storage.blob import BlobBlock, BlobServiceClient +from dask.delayed import Delayed + +from ._mpu import mpu_write +from ._multipart import MultiPartUploadBase + + +class AzureLimits: + """ + Common Azure writer settings. + """ + + @property + def min_write_sz(self) -> int: + # Azure minimum write size for blocks (default is 4 MiB) + return 4 * (1 << 20) + + @property + def max_write_sz(self) -> int: + # Azure maximum write size for blocks (default is 100 MiB) + return 100 * (1 << 20) + + @property + def min_part(self) -> int: + return 1 + + @property + def max_part(self) -> int: + # Azure supports up to 50,000 blocks per blob + return 50_000 + + +class MultiPartUpload(AzureLimits, MultiPartUploadBase): + def __init__( + self, account_url: str, container: str, blob: str, credential: Any = None + ): + """ + Initialise Azure multipart upload. + + :param account_url: URL of the Azure storage account. + :param container: Name of the container. + :param blob: Name of the blob. + :param credential: Authentication credentials (e.g., SAS token or key). 
+ """ + self.account_url = account_url + self.container = container + self.blob = blob + self.credential = credential + + # Initialise Azure Blob service client + self.blob_service_client = BlobServiceClient( + account_url=account_url, credential=credential + ) + self.container_client = self.blob_service_client.get_container_client(container) + self.blob_client = self.container_client.get_blob_client(blob) + + self.block_ids: list[str] = [] + + def initiate(self, **kwargs) -> str: + """ + Initialise the upload. No-op for Azure. + """ + return "azure-block-upload" + + def write_part(self, part: int, data: bytes) -> dict[str, Any]: + """ + Stage a block in Azure. + + :param part: Part number (unique). + :param data: Data for this part. + :return: A dictionary containing part information. + """ + block_id = base64.b64encode(f"block-{part}".encode()).decode() + self.blob_client.stage_block(block_id=block_id, data=data) + self.block_ids.append(block_id) + return {"PartNumber": part, "BlockId": block_id} + + def finalise(self, parts: list[dict[str, Any]]) -> str: + """ + Commit the block list to finalise the upload. + + :param parts: List of uploaded parts metadata. + :return: The ETag of the finalised blob. + """ + block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts] + self.blob_client.commit_block_list(block_list) + return self.blob_client.get_blob_properties().etag + + def cancel(self): + """ + Cancel the upload by clearing the block list. + """ + self.block_ids.clear() + + @property + def url(self) -> str: + """ + Get the Azure blob URL. + + :return: The full URL of the blob. + """ + return self.blob_client.url + + @property + def started(self) -> bool: + """ + Check if any blocks have been staged. + + :return: True if blocks have been staged, False otherwise. + """ + return bool(self.block_ids) + + def writer(self, kw: dict[str, Any], client: Any = None): + """ + Return a stateless writer compatible with Dask. + """ + return DelayedAzureWriter(self, kw) + + def upload( + self, + chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]], + *, + mk_header: Any = None, + mk_footer: Any = None, + user_kw: dict[str, Any] = None, + writes_per_chunk: int = 1, + spill_sz: int = 20 * (1 << 20), + client: Any = None, + **kw, + ) -> "Delayed": + """ + Upload chunks to Azure Blob Storage with multipart uploads. + + :param chunks: Dask bag of chunks to upload. + :param mk_header: Function to create header data. + :param mk_footer: Function to create footer data. + :param user_kw: User-provided metadata for the upload. + :param writes_per_chunk: Number of writes per chunk. + :param spill_sz: Spill size for buffering data. + :param client: Dask client for distributed execution. + :return: A Dask delayed object representing the finalised upload. + """ + write = self.writer(kw, client=client) if spill_sz else None + return mpu_write( + chunks, + write, + mk_header=mk_header, + mk_footer=mk_footer, + user_kw=user_kw, + writes_per_chunk=writes_per_chunk, + spill_sz=spill_sz, + dask_name_prefix="azure-finalise", + ) + + +class DelayedAzureWriter(AzureLimits): + """ + Dask-compatible writer for Azure Blob Storage multipart uploads. + """ + + def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]): + """ + Initialise the Azure writer. + + :param mpu: MultiPartUpload instance. + :param kw: Additional parameters for the writer. 
+ """ + self.mpu = mpu + self.kw = kw # Additional metadata like ContentType + + def __call__(self, part: int, data: bytes) -> dict[str, Any]: + """ + Write a single part to Azure Blob Storage. + + :param part: Part number. + :param data: Chunk data. + :return: Metadata for the written part. + """ + return self.mpu.write_part(part, data) + + def finalise(self, parts: list[dict[str, Any]]) -> str: + """ + Finalise the upload by committing the block list. + + :param parts: List of uploaded parts metadata. + :return: ETag of the finalised blob. + """ + return self.mpu.finalise(parts) diff --git a/tests/test_az.py b/tests/test_az.py new file mode 100644 index 00000000..00325357 --- /dev/null +++ b/tests/test_az.py @@ -0,0 +1,74 @@ +"""Tests for the Azure MultiPartUpload class.""" + +import unittest +from unittest.mock import MagicMock, patch + +from odc.geo.cog._az import MultiPartUpload + + +def test_mpu_init(): + """Basic test for the MultiPartUpload class.""" + account_url = "https://account_name.blob.core.windows.net" + mpu = MultiPartUpload(account_url, "container", "some.blob", None) + if mpu.account_url != account_url: + raise AssertionError(f"mpu.account_url should be '{account_url}'.") + if mpu.container != "container": + raise AssertionError("mpu.container should be 'container'.") + if mpu.blob != "some.blob": + raise AssertionError("mpu.blob should be 'some.blob'.") + if mpu.credential is not None: + raise AssertionError("mpu.credential should be 'None'.") + + +class TestMultiPartUpload(unittest.TestCase): + """Test the MultiPartUpload class.""" + + @patch("odc.geo.cog._az.BlobServiceClient") + def test_azure_multipart_upload(self, mock_blob_service_client): + """Test the MultiPartUpload class.""" + # Arrange - mock the Azure Blob SDK + # Mock the blob client and its methods + mock_blob_client = MagicMock() + mock_container_client = MagicMock() + mcc = mock_container_client + mock_blob_service_client.return_value.get_container_client.return_value = mcc + mock_container_client.get_blob_client.return_value = mock_blob_client + + # Simulate return values for Azure Blob SDK methods + mock_blob_client.get_blob_properties.return_value.etag = "mock-etag" + + # Test parameters + account_url = "https://mockaccount.blob.core.windows.net" + container = "mock-container" + blob = "mock-blob" + credential = "mock-sas-token" + + # Act - create an instance of MultiPartUpload and call its methods + azure_upload = MultiPartUpload(account_url, container, blob, credential) + upload_id = azure_upload.initiate() + part1 = azure_upload.write_part(1, b"first chunk of data") + part2 = azure_upload.write_part(2, b"second chunk of data") + etag = azure_upload.finalise([part1, part2]) + + # Assert - check the results + # Check that the initiate method behaves as expected + self.assertEqual(upload_id, "azure-block-upload") + + # Verify the calls to Azure Blob SDK methods + mock_blob_service_client.assert_called_once_with( + account_url=account_url, credential=credential + ) + mock_blob_client.stage_block.assert_any_call( + part1["BlockId"], b"first chunk of data" + ) + mock_blob_client.stage_block.assert_any_call( + part2["BlockId"], b"second chunk of data" + ) + mock_blob_client.commit_block_list.assert_called_once() + self.assertEqual(etag, "mock-etag") + + # Verify block list passed during finalise + block_list = mock_blob_client.commit_block_list.call_args[0][0] + self.assertEqual(len(block_list), 2) + self.assertEqual(block_list[0].id, part1["BlockId"]) + self.assertEqual(block_list[1].id, 
part2["BlockId"]) From 0591a68293c7540820ecd84b28a3629a0467f0b1 Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Tue, 17 Dec 2024 09:18:27 +0000 Subject: [PATCH 3/6] Feat: safely import azure-storage-blob and boto3 --- odc/geo/cog/_az.py | 6 +-- odc/geo/cog/_s3.py | 6 +-- odc/geo/cog/_tifffile.py | 24 +++++++++-- setup.cfg | 4 ++ tests/test_az.py | 88 +++++++++++++++++++++++++--------------- tests/test_s3.py | 30 +++++++++++--- 6 files changed, 111 insertions(+), 47 deletions(-) diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py index 3371c6da..5c79497a 100644 --- a/odc/geo/cog/_az.py +++ b/odc/geo/cog/_az.py @@ -33,7 +33,7 @@ def max_part(self) -> int: return 50_000 -class MultiPartUpload(AzureLimits, MultiPartUploadBase): +class AzMultiPartUpload(AzureLimits, MultiPartUploadBase): def __init__( self, account_url: str, container: str, blob: str, credential: Any = None ): @@ -161,11 +161,11 @@ class DelayedAzureWriter(AzureLimits): Dask-compatible writer for Azure Blob Storage multipart uploads. """ - def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]): + def __init__(self, mpu: AzMultiPartUpload, kw: dict[str, Any]): """ Initialise the Azure writer. - :param mpu: MultiPartUpload instance. + :param mpu: AzMultiPartUpload instance. :param kw: Additional parameters for the writer. """ self.mpu = mpu diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py index 2434c9da..367b31c9 100644 --- a/odc/geo/cog/_s3.py +++ b/odc/geo/cog/_s3.py @@ -70,7 +70,7 @@ def max_part(self) -> int: return 10_000 -class MultiPartUpload(S3Limits, MultiPartUploadBase): +class S3MultiPartUpload(S3Limits, MultiPartUploadBase): """ Dask to S3 dumper. """ @@ -237,7 +237,7 @@ class DelayedS3Writer(S3Limits): # pylint: disable=import-outside-toplevel,import-error - def __init__(self, mpu: MultiPartUpload, kw: dict[str, Any]): + def __init__(self, mpu: S3MultiPartUpload, kw: dict[str, Any]): self.mpu = mpu self.kw = kw # mostly ContentType= kinda thing self._shared_var: Optional["distributed.Variable"] = None @@ -263,7 +263,7 @@ def _shared(self, client: "distributed.Client") -> "distributed.Variable": self._shared_var = Variable(self._build_name("MPUpload"), client) return self._shared_var - def _ensure_init(self, final_write: bool = False) -> MultiPartUpload: + def _ensure_init(self, final_write: bool = False) -> S3MultiPartUpload: # pylint: disable=too-many-return-statements mpu = self.mpu if mpu.started: diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index 7ca6f3fe..8d813c4d 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -16,15 +16,29 @@ import numpy as np import xarray as xr - from .._interop import have from ..geobox import GeoBox from ..math import resolve_nodata from ..types import Shape2d, SomeNodata, Unset, shape_ -from ._az import MultiPartUpload as AzMultiPartUpload from ._mpu import mpu_write from ._mpu_fs import MPUFileSink -from ._s3 import MultiPartUpload as S3MultiPartUpload, s3_parse_url + +try: + from ._az import AzMultiPartUpload + + HAVE_AZURE = True +except ImportError: + AzMultiPartUpload = None + HAVE_AZURE = False +try: + from ._s3 import S3MultiPartUpload, s3_parse_url + + HAVE_S3 = True +except ImportError: + S3MultiPartUpload = None + s3_parse_url = None + HAVE_S3 = False + from ._shared import ( GDAL_COMP, GEOTIFF_TAGS, @@ -736,9 +750,13 @@ def save_cog_with_dask( # Determine output type and initiate uploader parsed_url = urlparse(dst) if parsed_url.scheme == "s3": + if not HAVE_S3: + raise ImportError("Install `boto3` to enable S3 
support.") bucket, key = s3_parse_url(dst) uploader = S3MultiPartUpload(bucket, key, **aws) elif parsed_url.scheme == "az": + if not HAVE_AZURE: + raise ImportError("Install azure-storage-blob` to enable Azure support.") uploader = AzMultiPartUpload( account_url=azure.get("account_url"), container=parsed_url.netloc, diff --git a/setup.cfg b/setup.cfg index 4f09ae56..a19ab783 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,10 +54,14 @@ tiff = s3 = boto3 +az = + azure-storage-blob + all = %(warp)s %(tiff)s %(s3)s + %(az)s test = pytest diff --git a/tests/test_az.py b/tests/test_az.py index 00325357..024e01ce 100644 --- a/tests/test_az.py +++ b/tests/test_az.py @@ -1,37 +1,50 @@ -"""Tests for the Azure MultiPartUpload class.""" +"""Tests for the Azure AzMultiPartUpload class.""" +import base64 import unittest from unittest.mock import MagicMock, patch -from odc.geo.cog._az import MultiPartUpload +# Conditional import for Azure support +try: + from odc.geo.cog._az import AzMultiPartUpload + HAVE_AZURE = True +except ImportError: + AzMultiPartUpload = None + HAVE_AZURE = False -def test_mpu_init(): - """Basic test for the MultiPartUpload class.""" - account_url = "https://account_name.blob.core.windows.net" - mpu = MultiPartUpload(account_url, "container", "some.blob", None) - if mpu.account_url != account_url: - raise AssertionError(f"mpu.account_url should be '{account_url}'.") - if mpu.container != "container": - raise AssertionError("mpu.container should be 'container'.") - if mpu.blob != "some.blob": - raise AssertionError("mpu.blob should be 'some.blob'.") - if mpu.credential is not None: - raise AssertionError("mpu.credential should be 'None'.") +def require_azure(test_func): + """Decorator to skip tests if Azure dependencies are not installed.""" + return unittest.skipUnless(HAVE_AZURE, "Azure dependencies are not installed")( + test_func + ) -class TestMultiPartUpload(unittest.TestCase): - """Test the MultiPartUpload class.""" +class TestAzMultiPartUpload(unittest.TestCase): + """Test the AzMultiPartUpload class.""" + + @require_azure + def test_mpu_init(self): + """Basic test for AzMultiPartUpload initialization.""" + account_url = "https://account_name.blob.core.windows.net" + mpu = AzMultiPartUpload(account_url, "container", "some.blob", None) + + self.assertEqual(mpu.account_url, account_url) + self.assertEqual(mpu.container, "container") + self.assertEqual(mpu.blob, "some.blob") + self.assertIsNone(mpu.credential) + + @require_azure @patch("odc.geo.cog._az.BlobServiceClient") def test_azure_multipart_upload(self, mock_blob_service_client): - """Test the MultiPartUpload class.""" - # Arrange - mock the Azure Blob SDK - # Mock the blob client and its methods + """Test the full Azure AzMultiPartUpload functionality.""" + # Arrange - Mock Azure Blob SDK client structure mock_blob_client = MagicMock() mock_container_client = MagicMock() - mcc = mock_container_client - mock_blob_service_client.return_value.get_container_client.return_value = mcc + mock_blob_service_client.return_value.get_container_client.return_value = ( + mock_container_client + ) mock_container_client.get_blob_client.return_value = mock_blob_client # Simulate return values for Azure Blob SDK methods @@ -43,32 +56,41 @@ def test_azure_multipart_upload(self, mock_blob_service_client): blob = "mock-blob" credential = "mock-sas-token" - # Act - create an instance of MultiPartUpload and call its methods - azure_upload = MultiPartUpload(account_url, container, blob, credential) + # Act + azure_upload = 
AzMultiPartUpload(account_url, container, blob, credential) upload_id = azure_upload.initiate() part1 = azure_upload.write_part(1, b"first chunk of data") part2 = azure_upload.write_part(2, b"second chunk of data") etag = azure_upload.finalise([part1, part2]) - # Assert - check the results - # Check that the initiate method behaves as expected + # Correctly calculate block IDs + block_id1 = base64.b64encode(b"block-1").decode("utf-8") + block_id2 = base64.b64encode(b"block-2").decode("utf-8") + + # Assert self.assertEqual(upload_id, "azure-block-upload") + self.assertEqual(etag, "mock-etag") - # Verify the calls to Azure Blob SDK methods + # Verify BlobServiceClient instantiation mock_blob_service_client.assert_called_once_with( account_url=account_url, credential=credential ) + + # Verify stage_block calls mock_blob_client.stage_block.assert_any_call( - part1["BlockId"], b"first chunk of data" + block_id=block_id1, data=b"first chunk of data" ) mock_blob_client.stage_block.assert_any_call( - part2["BlockId"], b"second chunk of data" + block_id=block_id2, data=b"second chunk of data" ) - mock_blob_client.commit_block_list.assert_called_once() - self.assertEqual(etag, "mock-etag") - # Verify block list passed during finalise + # Verify commit_block_list was called correctly block_list = mock_blob_client.commit_block_list.call_args[0][0] self.assertEqual(len(block_list), 2) - self.assertEqual(block_list[0].id, part1["BlockId"]) - self.assertEqual(block_list[1].id, part2["BlockId"]) + self.assertEqual(block_list[0].id, block_id1) + self.assertEqual(block_list[1].id, block_id2) + mock_blob_client.commit_block_list.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_s3.py b/tests/test_s3.py index 8349bd81..d04a462a 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -1,9 +1,29 @@ -from odc.geo.cog._s3 import MultiPartUpload +"""Tests for odc.geo.cog._s3.""" -# TODO: moto +import unittest +from odc.geo.cog._s3 import S3MultiPartUpload +# Conditional import for S3 support +try: + from odc.geo.cog._s3 import S3MultiPartUpload + + HAVE_S3 = True +except ImportError: + S3MultiPartUpload = None + HAVE_S3 = False + + +def require_s3(test_func): + """Decorator to skip tests if s3 dependencies are not installed.""" + return unittest.skipUnless(HAVE_S3, "s3 dependencies are not installed")(test_func) + + +@require_s3 def test_s3_mpu(): - mpu = MultiPartUpload("bucket", "file.dat") - assert mpu.bucket == "bucket" - assert mpu.key == "file.dat" + """Test S3MultiPartUpload class initialization.""" + mpu = S3MultiPartUpload("bucket", "file.dat") + if mpu.bucket != "bucket": + raise ValueError("Invalid bucket") + if mpu.key != "file.dat": + raise ValueError("Invalid key") From bc8811cc95912424d52255844ee69ccee1411bb5 Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Thu, 19 Dec 2024 22:40:18 +0000 Subject: [PATCH 4/6] Feat: allow s3 and or az dependencies Merged branch 'develop' into feat/save-cog-to-azure. 
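
Usage sketch (illustrative only; the bucket/container names, credentials, and the
array-construction helpers below are placeholders, not part of this diff): with the
matching optional dependency installed (`odc-geo[s3]` for boto3, `odc-geo[az]` for
azure-storage-blob), `save_cog_with_dask` picks the uploader from the destination
URL scheme, mapping `az://<container>/<blob>` onto the container/blob pair and
taking `account_url`/`credential` from the `azure` dict.

    import dask.array as da
    from odc.geo.geobox import GeoBox
    from odc.geo.xr import wrap_xr
    from odc.geo.cog import save_cog_with_dask

    # Small synthetic dask-backed raster; in practice `xx` would come from
    # odc-stac / datacube or any other dask-backed DataArray with a geobox.
    gbox = GeoBox.from_bbox((-2.0, 40.0, 2.0, 44.0), crs="epsg:4326", resolution=0.01)
    xx = wrap_xr(da.zeros(gbox.shape.yx, chunks=(256, 256), dtype="uint8"), gbox)

    # S3 destination (needs the `s3` extra, i.e. boto3)
    save_cog_with_dask(xx, "s3://my-bucket/example/out.tif")

    # Azure Blob destination (needs the `az` extra, i.e. azure-storage-blob);
    # netloc becomes the container, the path becomes the blob name
    save_cog_with_dask(
        xx,
        "az://my-container/example/out.tif",
        azure={
            "account_url": "https://myaccount.blob.core.windows.net",
            "credential": "<sas-token-or-account-key>",
        },
    )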
--- odc/geo/_interop.py | 8 ++ odc/geo/cog/_az.py | 13 +++- odc/geo/cog/_multipart.py | 4 + odc/geo/cog/_tifffile.py | 60 +++++++-------- tests/test_az.py | 154 +++++++++++++++++--------------------- 5 files changed, 116 insertions(+), 123 deletions(-) diff --git a/odc/geo/_interop.py b/odc/geo/_interop.py index e7c61d49..f25ea5d6 100644 --- a/odc/geo/_interop.py +++ b/odc/geo/_interop.py @@ -43,6 +43,14 @@ def datacube(self) -> bool: def tifffile(self) -> bool: return self._check("tifffile") + @property + def azure(self) -> bool: + return self._check("azure.storage.blob") + + @property + def botocore(self) -> bool: + return self._check("botocore") + @staticmethod def _check(lib_name: str) -> bool: return importlib.util.find_spec(lib_name) is not None diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py index 5c79497a..9995a0ab 100644 --- a/odc/geo/cog/_az.py +++ b/odc/geo/cog/_az.py @@ -1,8 +1,7 @@ import base64 from typing import Any, Union -from azure.storage.blob import BlobBlock, BlobServiceClient -from dask.delayed import Delayed +import dask from ._mpu import mpu_write from ._multipart import MultiPartUploadBase @@ -51,6 +50,9 @@ def __init__( self.credential = credential # Initialise Azure Blob service client + # pylint: disable=import-outside-toplevel,import-error + from azure.storage.blob import BlobServiceClient + self.blob_service_client = BlobServiceClient( account_url=account_url, credential=credential ) @@ -85,6 +87,9 @@ def finalise(self, parts: list[dict[str, Any]]) -> str: :param parts: List of uploaded parts metadata. :return: The ETag of the finalised blob. """ + # pylint: disable=import-outside-toplevel,import-error + from azure.storage.blob import BlobBlock + block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts] self.blob_client.commit_block_list(block_list) return self.blob_client.get_blob_properties().etag @@ -121,7 +126,7 @@ def writer(self, kw: dict[str, Any], client: Any = None): def upload( self, - chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]], + chunks: Union[dask.bag.Bag, list[dask.bag.Bag]], *, mk_header: Any = None, mk_footer: Any = None, @@ -130,7 +135,7 @@ def upload( spill_sz: int = 20 * (1 << 20), client: Any = None, **kw, - ) -> "Delayed": + ) -> dask.delayed.Delayed: """ Upload chunks to Azure Blob Storage with multipart uploads. diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py index 0ba3c9b1..c9060bee 100644 --- a/odc/geo/cog/_multipart.py +++ b/odc/geo/cog/_multipart.py @@ -1,5 +1,9 @@ """ Multipart upload interface. + +Defines the `MultiPartUploadBase` class for implementing multipart upload functionality. +This interface standardises methods for initiating, uploading, and finalising +multipart uploads across storage backends. 
""" from abc import ABC, abstractmethod diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index 8d813c4d..fae90408 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -11,11 +11,13 @@ from functools import partial from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from urllib.parse import urlparse from xml.sax.saxutils import escape as xml_escape import numpy as np import xarray as xr + from .._interop import have from ..geobox import GeoBox from ..math import resolve_nodata @@ -23,22 +25,6 @@ from ._mpu import mpu_write from ._mpu_fs import MPUFileSink -try: - from ._az import AzMultiPartUpload - - HAVE_AZURE = True -except ImportError: - AzMultiPartUpload = None - HAVE_AZURE = False -try: - from ._s3 import S3MultiPartUpload, s3_parse_url - - HAVE_S3 = True -except ImportError: - S3MultiPartUpload = None - s3_parse_url = None - HAVE_S3 = False - from ._shared import ( GDAL_COMP, GEOTIFF_TAGS, @@ -641,6 +627,7 @@ def save_cog_with_dask( bigtiff: bool = True, overview_resampling: Union[int, str] = "nearest", aws: Optional[dict[str, Any]] = None, + azure: Optional[dict[str, Any]] = None, client: Any = None, stats: bool | int = True, **kw, @@ -669,13 +656,12 @@ def save_cog_with_dask( from ..xr import ODCExtensionDa - if aws is None: - aws = {} + aws = aws or {} + azure = azure or {} - upload_params = {k: kw.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in kw} - upload_params.update( - {k: aws.pop(k) for k in ["writes_per_chunk", "spill_sz"] if k in aws} - ) + upload_params = { + k: kw.pop(k, None) for k in ["writes_per_chunk", "spill_sz"] if k in kw + } parts_base = kw.pop("parts_base", None) # Normalise compression settings and remove GDAL compat options from kw @@ -750,19 +736,25 @@ def save_cog_with_dask( # Determine output type and initiate uploader parsed_url = urlparse(dst) if parsed_url.scheme == "s3": - if not HAVE_S3: - raise ImportError("Install `boto3` to enable S3 support.") - bucket, key = s3_parse_url(dst) - uploader = S3MultiPartUpload(bucket, key, **aws) + if have.s3: + from ._s3 import S3MultiPartUpload, s3_parse_url + + bucket, key = s3_parse_url(dst) + uploader = S3MultiPartUpload(bucket, key, **aws) + else: + raise RuntimeError("Please install `boto3` to use S3") elif parsed_url.scheme == "az": - if not HAVE_AZURE: - raise ImportError("Install azure-storage-blob` to enable Azure support.") - uploader = AzMultiPartUpload( - account_url=azure.get("account_url"), - container=parsed_url.netloc, - blob=parsed_url.path.lstrip("/"), - credential=azure.get("credential"), - ) + if have.azure: + from ._az import AzMultiPartUpload + + uploader = AzMultiPartUpload( + account_url=azure.get("account_url"), + container=parsed_url.netloc, + blob=parsed_url.path.lstrip("/"), + credential=azure.get("credential"), + ) + else: + raise RuntimeError("Please install `azure-storage-blob` to use Azure") else: # Assume local disk write = MPUFileSink(dst, parts_base=parts_base) diff --git a/tests/test_az.py b/tests/test_az.py index 024e01ce..9462f091 100644 --- a/tests/test_az.py +++ b/tests/test_az.py @@ -1,96 +1,80 @@ """Tests for the Azure AzMultiPartUpload class.""" import base64 -import unittest from unittest.mock import MagicMock, patch -# Conditional import for Azure support -try: - from odc.geo.cog._az import AzMultiPartUpload +import pytest - HAVE_AZURE = True -except ImportError: - AzMultiPartUpload = None - HAVE_AZURE = False +pytest.importorskip("azure.storage.blob") +from odc.geo.cog._az import 
AzMultiPartUpload # noqa: E402 -def require_azure(test_func): - """Decorator to skip tests if Azure dependencies are not installed.""" - return unittest.skipUnless(HAVE_AZURE, "Azure dependencies are not installed")( - test_func +@pytest.fixture +def azure_mpu(): + """Fixture for initializing AzMultiPartUpload.""" + account_url = "https://account_name.blob.core.windows.net" + return AzMultiPartUpload(account_url, "container", "some.blob", None) + + +def test_mpu_init(azure_mpu): + """Basic test for AzMultiPartUpload initialization.""" + assert azure_mpu.account_url == "https://account_name.blob.core.windows.net" + assert azure_mpu.container == "container" + assert azure_mpu.blob == "some.blob" + assert azure_mpu.credential is None + + +@patch("odc.geo.cog._az.BlobServiceClient") +def test_azure_multipart_upload(mock_blob_service_client): + """Test the full Azure AzMultiPartUpload functionality.""" + # Mock Azure Blob SDK client structure + mock_blob_client = MagicMock() + mock_container_client = MagicMock() + mock_blob_service_client.return_value.get_container_client.return_value = ( + mock_container_client + ) + mock_container_client.get_blob_client.return_value = mock_blob_client + + # Simulate return values for Azure Blob SDK methods + mock_blob_client.get_blob_properties.return_value.etag = "mock-etag" + + # Test parameters + account_url = "https://mockaccount.blob.core.windows.net" + container = "mock-container" + blob = "mock-blob" + credential = "mock-sas-token" + + # Create an instance of AzMultiPartUpload and call its methods + azure_upload = AzMultiPartUpload(account_url, container, blob, credential) + upload_id = azure_upload.initiate() + part1 = azure_upload.write_part(1, b"first chunk of data") + part2 = azure_upload.write_part(2, b"second chunk of data") + etag = azure_upload.finalise([part1, part2]) + + # Define block IDs + block_id1 = base64.b64encode(b"block-1").decode("utf-8") + block_id2 = base64.b64encode(b"block-2").decode("utf-8") + + # Verify the results + assert upload_id == "azure-block-upload" + assert etag == "mock-etag" + + # Verify BlobServiceClient instantiation + mock_blob_service_client.assert_called_once_with( + account_url=account_url, credential=credential ) + # Verify stage_block calls + mock_blob_client.stage_block.assert_any_call( + block_id=block_id1, data=b"first chunk of data" + ) + mock_blob_client.stage_block.assert_any_call( + block_id=block_id2, data=b"second chunk of data" + ) -class TestAzMultiPartUpload(unittest.TestCase): - """Test the AzMultiPartUpload class.""" - - @require_azure - def test_mpu_init(self): - """Basic test for AzMultiPartUpload initialization.""" - account_url = "https://account_name.blob.core.windows.net" - mpu = AzMultiPartUpload(account_url, "container", "some.blob", None) - - self.assertEqual(mpu.account_url, account_url) - self.assertEqual(mpu.container, "container") - self.assertEqual(mpu.blob, "some.blob") - self.assertIsNone(mpu.credential) - - @require_azure - @patch("odc.geo.cog._az.BlobServiceClient") - def test_azure_multipart_upload(self, mock_blob_service_client): - """Test the full Azure AzMultiPartUpload functionality.""" - # Arrange - Mock Azure Blob SDK client structure - mock_blob_client = MagicMock() - mock_container_client = MagicMock() - mock_blob_service_client.return_value.get_container_client.return_value = ( - mock_container_client - ) - mock_container_client.get_blob_client.return_value = mock_blob_client - - # Simulate return values for Azure Blob SDK methods - 
mock_blob_client.get_blob_properties.return_value.etag = "mock-etag" - - # Test parameters - account_url = "https://mockaccount.blob.core.windows.net" - container = "mock-container" - blob = "mock-blob" - credential = "mock-sas-token" - - # Act - azure_upload = AzMultiPartUpload(account_url, container, blob, credential) - upload_id = azure_upload.initiate() - part1 = azure_upload.write_part(1, b"first chunk of data") - part2 = azure_upload.write_part(2, b"second chunk of data") - etag = azure_upload.finalise([part1, part2]) - - # Correctly calculate block IDs - block_id1 = base64.b64encode(b"block-1").decode("utf-8") - block_id2 = base64.b64encode(b"block-2").decode("utf-8") - - # Assert - self.assertEqual(upload_id, "azure-block-upload") - self.assertEqual(etag, "mock-etag") - - # Verify BlobServiceClient instantiation - mock_blob_service_client.assert_called_once_with( - account_url=account_url, credential=credential - ) - - # Verify stage_block calls - mock_blob_client.stage_block.assert_any_call( - block_id=block_id1, data=b"first chunk of data" - ) - mock_blob_client.stage_block.assert_any_call( - block_id=block_id2, data=b"second chunk of data" - ) - - # Verify commit_block_list was called correctly - block_list = mock_blob_client.commit_block_list.call_args[0][0] - self.assertEqual(len(block_list), 2) - self.assertEqual(block_list[0].id, block_id1) - self.assertEqual(block_list[1].id, block_id2) - mock_blob_client.commit_block_list.assert_called_once() - - -if __name__ == "__main__": - unittest.main() + # Verify commit_block_list was called correctly + block_list = mock_blob_client.commit_block_list.call_args[0][0] + assert len(block_list) == 2 + assert block_list[0].id == block_id1 + assert block_list[1].id == block_id2 + mock_blob_client.commit_block_list.assert_called_once() From 77bd4c2c9c1bd63f9ac0c8aa3c04451687574349 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Tue, 7 Jan 2025 13:46:22 +1100 Subject: [PATCH 5/6] Fix: mypy errors in geo/cog modules --- odc/geo/cog/_az.py | 16 ++++++++++++---- odc/geo/cog/_multipart.py | 16 +++++----------- odc/geo/cog/_tifffile.py | 13 +++++++++---- odc/geo/gcp.py | 11 +++++++++-- odc/geo/geom.py | 12 +++++++++--- odc/geo/roi.py | 2 +- 6 files changed, 45 insertions(+), 25 deletions(-) diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py index 9995a0ab..d0f91a4b 100644 --- a/odc/geo/cog/_az.py +++ b/odc/geo/cog/_az.py @@ -2,6 +2,7 @@ from typing import Any, Union import dask +from dask.delayed import Delayed from ._mpu import mpu_write from ._multipart import MultiPartUploadBase @@ -33,6 +34,12 @@ def max_part(self) -> int: class AzMultiPartUpload(AzureLimits, MultiPartUploadBase): + """ + Azure Blob Storage multipart upload. + """ + + # pylint: disable=too-many-instance-attributes + def __init__( self, account_url: str, container: str, blob: str, credential: Any = None ): @@ -94,10 +101,11 @@ def finalise(self, parts: list[dict[str, Any]]) -> str: self.blob_client.commit_block_list(block_list) return self.blob_client.get_blob_properties().etag - def cancel(self): + def cancel(self, other: str = ""): """ Cancel the upload by clearing the block list. """ + assert other == "" self.block_ids.clear() @property @@ -118,7 +126,7 @@ def started(self) -> bool: """ return bool(self.block_ids) - def writer(self, kw: dict[str, Any], client: Any = None): + def writer(self, kw: dict[str, Any], *, client: Any = None): """ Return a stateless writer compatible with Dask. 
""" @@ -130,12 +138,12 @@ def upload( *, mk_header: Any = None, mk_footer: Any = None, - user_kw: dict[str, Any] = None, + user_kw: dict[str, Any] | None = None, writes_per_chunk: int = 1, spill_sz: int = 20 * (1 << 20), client: Any = None, **kw, - ) -> dask.delayed.Delayed: + ) -> Delayed: """ Upload chunks to Azure Blob Storage with multipart uploads. diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py index c9060bee..0fc9b4c8 100644 --- a/odc/geo/cog/_multipart.py +++ b/odc/geo/cog/_multipart.py @@ -7,9 +7,11 @@ """ from abc import ABC, abstractmethod -from typing import Any, Union +from typing import Any, Union, TYPE_CHECKING -import dask.bag +if TYPE_CHECKING: + # pylint: disable=import-outside-toplevel,import-error + import dask.bag class MultiPartUploadBase(ABC): @@ -18,34 +20,28 @@ class MultiPartUploadBase(ABC): @abstractmethod def initiate(self, **kwargs) -> str: """Initiate a multipart upload and return an identifier.""" - pass @abstractmethod def write_part(self, part: int, data: bytes) -> dict[str, Any]: """Upload a single part.""" - pass @abstractmethod def finalise(self, parts: list[dict[str, Any]]) -> str: """Finalise the upload with a list of parts.""" - pass @abstractmethod def cancel(self, other: str = ""): """Cancel the multipart upload.""" - pass @property @abstractmethod def url(self) -> str: """Return the URL of the upload target.""" - pass @property @abstractmethod def started(self) -> bool: """Check if the multipart upload has been initiated.""" - pass @abstractmethod def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any: @@ -55,7 +51,6 @@ def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any: :param kw: Additional parameters for the writer. :param client: Dask client for distributed execution. """ - pass @abstractmethod def upload( @@ -64,7 +59,7 @@ def upload( *, mk_header: Any = None, mk_footer: Any = None, - user_kw: dict[str, Any] = None, + user_kw: dict[str, Any] | None = None, writes_per_chunk: int = 1, spill_sz: int = 20 * (1 << 20), client: Any = None, @@ -82,4 +77,3 @@ def upload( :param client: Dask client for distributed execution. :return: A Dask delayed object representing the finalised upload. 
""" - pass diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index fae90408..3ab86e21 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -24,6 +24,7 @@ from ..types import Shape2d, SomeNodata, Unset, shape_ from ._mpu import mpu_write from ._mpu_fs import MPUFileSink +from ._multipart import MultiPartUploadBase from ._shared import ( GDAL_COMP, @@ -736,22 +737,26 @@ def save_cog_with_dask( # Determine output type and initiate uploader parsed_url = urlparse(dst) if parsed_url.scheme == "s3": - if have.s3: + if have.botocore: from ._s3 import S3MultiPartUpload, s3_parse_url bucket, key = s3_parse_url(dst) - uploader = S3MultiPartUpload(bucket, key, **aws) + uploader: MultiPartUploadBase = S3MultiPartUpload(bucket, key, **aws) else: raise RuntimeError("Please install `boto3` to use S3") elif parsed_url.scheme == "az": if have.azure: from ._az import AzMultiPartUpload + assert azure is not None + assert "account_url" in azure + assert "credential" in azure + uploader = AzMultiPartUpload( - account_url=azure.get("account_url"), + account_url=azure["account_url"], container=parsed_url.netloc, blob=parsed_url.path.lstrip("/"), - credential=azure.get("credential"), + credential=azure["credential"], ) else: raise RuntimeError("Please install `azure-storage-blob` to use Azure") diff --git a/odc/geo/gcp.py b/odc/geo/gcp.py index 56fade35..f72f7dd4 100644 --- a/odc/geo/gcp.py +++ b/odc/geo/gcp.py @@ -102,9 +102,16 @@ def resolution(self) -> Resolution: def points(self) -> Tuple[Geometry, Geometry]: """Return multipoint geometries for (Pixel, World).""" + pix_points: list[tuple[float, float]] = [ + (float(p[0]), float(p[1])) for p in self._pix.tolist() + ] + wld_points: list[tuple[float, float]] = [ + (float(p[0]), float(p[1])) for p in self._wld.tolist() + ] + return ( - multipoint(self._pix.tolist(), None), - multipoint(self._wld.tolist(), self.crs), + multipoint(pix_points, None), + multipoint(wld_points, self.crs), ) def __dask_tokenize__(self): diff --git a/odc/geo/geom.py b/odc/geo/geom.py index 043ac6c7..6aff6ef4 100644 --- a/odc/geo/geom.py +++ b/odc/geo/geom.py @@ -320,6 +320,7 @@ def boundary(self, pts_per_side: int = 2) -> "Geometry": self.crs, ) + def qr2sample( self, n: int, @@ -350,13 +351,15 @@ def qr2sample( ny = y1 - y0 pts = quasi_random_r2(n, offset=offset) s = numpy.asarray([nx, ny], dtype="float32") - edge_pts = [] + edge_pts: list[tuple[float, float]] = [] if with_edges: sample_density = numpy.sqrt(n / (nx * ny)) n_side = int(numpy.round(sample_density * min(nx, ny))) + 1 n_side = max(2, n_side) - edge_pts = self.boundary(n_side).coords[:-1] + edge_pts = [ + (float(ep[0]), float(ep[1])) for ep in list(self.boundary(n_side).coords[:-1]) + ] if padding is None: padding = 0.3 * min(nx, ny) / (n_side - 1) @@ -368,8 +371,11 @@ def qr2sample( pts[:, 0] += x0 pts[:, 1] += y0 - return multipoint(pts.tolist() + edge_pts, self.crs) + coords: list[tuple[float, float]] = [ + (float(p[0]), float(p[1])) for p in pts.tolist() + ] + edge_pts + return multipoint(coords, self.crs) def wrap_shapely(method): """ diff --git a/odc/geo/roi.py b/odc/geo/roi.py index 7e4ee423..50759ed3 100644 --- a/odc/geo/roi.py +++ b/odc/geo/roi.py @@ -284,7 +284,7 @@ def base(self) -> Shape2d: @property def chunks(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: """Dask compatible chunk rerpesentation.""" - y, x = (tuple(np.diff(idx).tolist()) for idx in self._offsets) + y, x = (tuple(map(int, np.diff(idx))) for idx in self._offsets) return (y, x) def locate(self, pix: SomeIndex2d) 
-> Tuple[int, int]: From 9fcc5d59d6e2c6dc18ab511ca60a56569b1d1b89 Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Tue, 7 Jan 2025 17:17:09 +0000 Subject: [PATCH 6/6] Feat: Refactor upload method and reduce duplication - Restored type hints in `upload` for improved type safety. - Added `get_mpu_kwargs` to centralize shared keyword arguments. - Simplified `upload` and `mpu_upload` implementations by reusing `get_mpu_kwargs`. - Reduced code duplication across `_mpu.py` and `_multipart.py`. --- odc/geo/cog/_az.py | 44 ++++--------------------------------- odc/geo/cog/_mpu.py | 46 +++++++++++++++++++++++++++++++++++++++ odc/geo/cog/_multipart.py | 36 ++++++++++++++++++------------ odc/geo/cog/_s3.py | 29 ++++-------------------- odc/geo/geom.py | 5 +++-- 5 files changed, 79 insertions(+), 81 deletions(-) diff --git a/odc/geo/cog/_az.py b/odc/geo/cog/_az.py index d0f91a4b..78c5e212 100644 --- a/odc/geo/cog/_az.py +++ b/odc/geo/cog/_az.py @@ -1,10 +1,6 @@ import base64 -from typing import Any, Union +from typing import Any -import dask -from dask.delayed import Delayed - -from ._mpu import mpu_write from ._multipart import MultiPartUploadBase @@ -132,41 +128,9 @@ def writer(self, kw: dict[str, Any], *, client: Any = None): """ return DelayedAzureWriter(self, kw) - def upload( - self, - chunks: Union[dask.bag.Bag, list[dask.bag.Bag]], - *, - mk_header: Any = None, - mk_footer: Any = None, - user_kw: dict[str, Any] | None = None, - writes_per_chunk: int = 1, - spill_sz: int = 20 * (1 << 20), - client: Any = None, - **kw, - ) -> Delayed: - """ - Upload chunks to Azure Blob Storage with multipart uploads. - - :param chunks: Dask bag of chunks to upload. - :param mk_header: Function to create header data. - :param mk_footer: Function to create footer data. - :param user_kw: User-provided metadata for the upload. - :param writes_per_chunk: Number of writes per chunk. - :param spill_sz: Spill size for buffering data. - :param client: Dask client for distributed execution. - :return: A Dask delayed object representing the finalised upload. - """ - write = self.writer(kw, client=client) if spill_sz else None - return mpu_write( - chunks, - write, - mk_header=mk_header, - mk_footer=mk_footer, - user_kw=user_kw, - writes_per_chunk=writes_per_chunk, - spill_sz=spill_sz, - dask_name_prefix="azure-finalise", - ) + def dask_name_prefix(self) -> str: + """Return the Dask name prefix for Azure.""" + return "azure-finalise" class DelayedAzureWriter(AzureLimits): diff --git a/odc/geo/cog/_mpu.py b/odc/geo/cog/_mpu.py index f1776d8d..ddc453a5 100644 --- a/odc/geo/cog/_mpu.py +++ b/odc/geo/cog/_mpu.py @@ -495,3 +495,49 @@ def _finalizer_dask_op( _, rr = _root.flush(write, leftPartId=1, finalise=True) return rr + + +def get_mpu_kwargs( + mk_header=None, + mk_footer=None, + user_kw=None, + writes_per_chunk=1, + spill_sz=20 * (1 << 20), + client=None, +) -> dict: + """ + Construct shared keyword arguments for multipart uploads. 
+ """ + return { + "mk_header": mk_header, + "mk_footer": mk_footer, + "user_kw": user_kw, + "writes_per_chunk": writes_per_chunk, + "spill_sz": spill_sz, + "client": client, + } + + +def mpu_upload( + chunks: Union[dask.bag.Bag, list[dask.bag.Bag]], + *, + writer: Any, + dask_name_prefix: str, + **kw, +) -> "Delayed": + """Shared logic for multipart uploads to storage services.""" + client = kw.pop("client", None) + writer_kw = dict(kw) + if client is not None: + writer_kw["client"] = client + spill_sz = kw.get("spill_sz", 20 * (1 << 20)) + if spill_sz: + write = writer(writer_kw) + else: + write = None + return mpu_write( + chunks, + write, + dask_name_prefix=dask_name_prefix, + **kw, # everything else remains + ) diff --git a/odc/geo/cog/_multipart.py b/odc/geo/cog/_multipart.py index 0fc9b4c8..c7376bfc 100644 --- a/odc/geo/cog/_multipart.py +++ b/odc/geo/cog/_multipart.py @@ -9,6 +9,9 @@ from abc import ABC, abstractmethod from typing import Any, Union, TYPE_CHECKING +from dask.delayed import Delayed +from ._mpu import get_mpu_kwargs, mpu_upload + if TYPE_CHECKING: # pylint: disable=import-outside-toplevel,import-error import dask.bag @@ -53,6 +56,9 @@ def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any: """ @abstractmethod + def dask_name_prefix(self) -> str: + """Return the dask name prefix specific to the backend.""" + def upload( self, chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]], @@ -63,17 +69,19 @@ def upload( writes_per_chunk: int = 1, spill_sz: int = 20 * (1 << 20), client: Any = None, - **kw, - ) -> Any: - """ - Orchestrate the upload process with multipart uploads. - - :param chunks: Dask bag of chunks to upload. - :param mk_header: Function to create header data. - :param mk_footer: Function to create footer data. - :param user_kw: User-provided metadata for the upload. - :param writes_per_chunk: Number of writes per chunk. - :param spill_sz: Spill size for buffering data. - :param client: Dask client for distributed execution. - :return: A Dask delayed object representing the finalised upload. 
- """ + ) -> Delayed: + """High-level upload that calls mpu_upload under the hood.""" + kwargs = get_mpu_kwargs( + mk_header=mk_header, + mk_footer=mk_footer, + user_kw=user_kw, + writes_per_chunk=writes_per_chunk, + spill_sz=spill_sz, + client=client, + ) + return mpu_upload( + chunks, + writer=self.writer, + dask_name_prefix=self.dask_name_prefix(), + **kwargs, + ) diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py index 367b31c9..0ae6981f 100644 --- a/odc/geo/cog/_s3.py +++ b/odc/geo/cog/_s3.py @@ -9,7 +9,7 @@ from cachetools import cached -from ._mpu import PartsWriter, SomeData, mpu_write +from ._mpu import PartsWriter, SomeData from ._multipart import MultiPartUploadBase if TYPE_CHECKING: @@ -197,30 +197,9 @@ def writer(self, kw, *, client: Any = None) -> PartsWriter: writer.prep_client(client) return writer - def upload( - self, - chunks: "dask.bag.Bag" | list["dask.bag.Bag"], - *, - mk_header: Any = None, - mk_footer: Any = None, - user_kw: dict[str, Any] | None = None, - writes_per_chunk: int = 1, - spill_sz: int = 20 * (1 << 20), - client: Any = None, - **kw, - ) -> "Delayed": - """Upload chunks to S3 with multipart uploads.""" - write = self.writer(kw, client=client) if spill_sz else None - return mpu_write( - chunks, - write, - mk_header=mk_header, - mk_footer=mk_footer, - user_kw=user_kw, - writes_per_chunk=writes_per_chunk, - spill_sz=spill_sz, - dask_name_prefix="s3finalise", - ) + def dask_name_prefix(self) -> str: + """Return the Dask name prefix for S3.""" + return "s3finalise" def _safe_get(v, timeout=0.1): diff --git a/odc/geo/geom.py b/odc/geo/geom.py index 6aff6ef4..f1e89850 100644 --- a/odc/geo/geom.py +++ b/odc/geo/geom.py @@ -320,7 +320,6 @@ def boundary(self, pts_per_side: int = 2) -> "Geometry": self.crs, ) - def qr2sample( self, n: int, @@ -358,7 +357,8 @@ def qr2sample( n_side = int(numpy.round(sample_density * min(nx, ny))) + 1 n_side = max(2, n_side) edge_pts = [ - (float(ep[0]), float(ep[1])) for ep in list(self.boundary(n_side).coords[:-1]) + (float(ep[0]), float(ep[1])) + for ep in list(self.boundary(n_side).coords[:-1]) ] if padding is None: padding = 0.3 * min(nx, ny) / (n_side - 1) @@ -377,6 +377,7 @@ def qr2sample( return multipoint(coords, self.crs) + def wrap_shapely(method): """ Takes a method that expects shapely geometry arguments and converts it to a method that operates