feat: save COGs with dask to azure #195

Merged · 6 commits · Jan 16, 2025
8 changes: 8 additions & 0 deletions odc/geo/_interop.py
@@ -43,6 +43,14 @@ def datacube(self) -> bool:
    def tifffile(self) -> bool:
        return self._check("tifffile")

    @property
    def azure(self) -> bool:
        return self._check("azure.storage.blob")

    @property
    def botocore(self) -> bool:
        return self._check("botocore")

    @staticmethod
    def _check(lib_name: str) -> bool:
        return importlib.util.find_spec(lib_name) is not None
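These flags let callers probe for optional dependencies without importing them at module load. A minimal sketch of the guard pattern, assuming the module's `have` singleton (the error message is illustrative):

from odc.geo._interop import have

if not have.azure:
    raise RuntimeError("azure-storage-blob is required for Azure COG uploads")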
168 changes: 168 additions & 0 deletions odc/geo/cog/_az.py
@@ -0,0 +1,168 @@
import base64
from typing import Any

from ._multipart import MultiPartUploadBase


class AzureLimits:
    """
    Common Azure writer settings.
    """

    @property
    def min_write_sz(self) -> int:
        # Azure minimum write size for blocks (default is 4 MiB)
        return 4 * (1 << 20)

    @property
    def max_write_sz(self) -> int:
        # Azure maximum write size for blocks (default is 100 MiB)
        return 100 * (1 << 20)

    @property
    def min_part(self) -> int:
        return 1

    @property
    def max_part(self) -> int:
        # Azure supports up to 50,000 blocks per blob
        return 50_000


class AzMultiPartUpload(AzureLimits, MultiPartUploadBase):
    """
    Azure Blob Storage multipart upload.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self, account_url: str, container: str, blob: str, credential: Any = None
    ):
        """
        Initialise Azure multipart upload.

        :param account_url: URL of the Azure storage account.
        :param container: Name of the container.
        :param blob: Name of the blob.
        :param credential: Authentication credentials (e.g., SAS token or key).
        """
        self.account_url = account_url
        self.container = container
        self.blob = blob
        self.credential = credential

        # Initialise Azure Blob service client
        # pylint: disable=import-outside-toplevel,import-error
        from azure.storage.blob import BlobServiceClient

        self.blob_service_client = BlobServiceClient(
            account_url=account_url, credential=credential
        )
        self.container_client = self.blob_service_client.get_container_client(container)
        self.blob_client = self.container_client.get_blob_client(blob)

        self.block_ids: list[str] = []
    def initiate(self, **kwargs) -> str:
        """
        Start the upload. Azure block blobs need no explicit initiation,
        so this returns a placeholder upload id.
        """
        return "azure-block-upload"

    def write_part(self, part: int, data: bytes) -> dict[str, Any]:
        """
        Stage a block in Azure.

        :param part: Part number (unique).
        :param data: Data for this part.
        :return: A dictionary containing part information.
        """
        # Azure requires all block ids within a blob to have the same length,
        # so zero-pad the part number before encoding.
        block_id = base64.b64encode(f"block-{part:06d}".encode()).decode()
        self.blob_client.stage_block(block_id=block_id, data=data)
        self.block_ids.append(block_id)
        return {"PartNumber": part, "BlockId": block_id}

    def finalise(self, parts: list[dict[str, Any]]) -> str:
        """
        Commit the block list to finalise the upload.

        :param parts: List of uploaded parts metadata.
        :return: The ETag of the finalised blob.
        """
        # pylint: disable=import-outside-toplevel,import-error
        from azure.storage.blob import BlobBlock

        block_list = [BlobBlock(block_id=part["BlockId"]) for part in parts]
        self.blob_client.commit_block_list(block_list)
        return self.blob_client.get_blob_properties().etag

    def cancel(self, other: str = ""):
        """
        Cancel the upload by clearing the block list.
        """
        assert other == ""
        self.block_ids.clear()

    @property
    def url(self) -> str:
        """
        Get the Azure blob URL.

        :return: The full URL of the blob.
        """
        return self.blob_client.url

    @property
    def started(self) -> bool:
        """
        Check if any blocks have been staged.

        :return: True if blocks have been staged, False otherwise.
        """
        return bool(self.block_ids)

    def writer(self, kw: dict[str, Any], *, client: Any = None):
        """
        Return a stateless writer compatible with Dask.
        """
        return DelayedAzureWriter(self, kw)

    def dask_name_prefix(self) -> str:
        """Return the Dask name prefix for Azure."""
        return "azure-finalise"


class DelayedAzureWriter(AzureLimits):
    """
    Dask-compatible writer for Azure Blob Storage multipart uploads.
    """

    def __init__(self, mpu: AzMultiPartUpload, kw: dict[str, Any]):
        """
        Initialise the Azure writer.

        :param mpu: AzMultiPartUpload instance.
        :param kw: Additional parameters for the writer.
        """
        self.mpu = mpu
        self.kw = kw  # Additional metadata like ContentType

    def __call__(self, part: int, data: bytes) -> dict[str, Any]:
        """
        Write a single part to Azure Blob Storage.

        :param part: Part number.
        :param data: Chunk data.
        :return: Metadata for the written part.
        """
        return self.mpu.write_part(part, data)

    def finalise(self, parts: list[dict[str, Any]]) -> str:
        """
        Finalise the upload by committing the block list.

        :param parts: List of uploaded parts metadata.
        :return: ETag of the finalised blob.
        """
        return self.mpu.finalise(parts)
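A sketch of driving this class end to end with Dask; the account URL, container, blob name, and chunk contents below are placeholders, and real use would go through the COG-saving entry points rather than direct construction:

import dask.bag

from odc.geo.cog._az import AzMultiPartUpload

# Hypothetical destination; supply a SAS token or account key as `credential`.
mpu = AzMultiPartUpload(
    account_url="https://example.blob.core.windows.net",
    container="my-container",
    blob="output.tif",
    credential=None,
)

# Each bag item is a bytes chunk; `upload` (inherited from MultiPartUploadBase)
# stages every chunk as a block, commits the block list, and yields the ETag.
chunks = dask.bag.from_sequence([b"part-one", b"part-two"], npartitions=2)
etag = mpu.upload(chunks).compute()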
46 changes: 46 additions & 0 deletions odc/geo/cog/_mpu.py
@@ -495,3 +495,49 @@ def _finalizer_dask_op(

    _, rr = _root.flush(write, leftPartId=1, finalise=True)
    return rr


def get_mpu_kwargs(
    mk_header=None,
    mk_footer=None,
    user_kw=None,
    writes_per_chunk=1,
    spill_sz=20 * (1 << 20),
    client=None,
) -> dict:
    """
    Construct shared keyword arguments for multipart uploads.
    """
    return {
        "mk_header": mk_header,
        "mk_footer": mk_footer,
        "user_kw": user_kw,
        "writes_per_chunk": writes_per_chunk,
        "spill_sz": spill_sz,
        "client": client,
    }


def mpu_upload(
    chunks: Union[dask.bag.Bag, list[dask.bag.Bag]],
    *,
    writer: Any,
    dask_name_prefix: str,
    **kw,
) -> "Delayed":
    """Shared logic for multipart uploads to storage services."""
    client = kw.pop("client", None)
    writer_kw = dict(kw)
    if client is not None:
        writer_kw["client"] = client
    spill_sz = kw.get("spill_sz", 20 * (1 << 20))
    if spill_sz:
        write = writer(writer_kw)
    else:
        write = None
    return mpu_write(
        chunks,
        write,
        dask_name_prefix=dask_name_prefix,
        **kw,  # everything else remains
    )
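`get_mpu_kwargs` just normalises the shared knobs into one dict, so each backend's `upload` entry point stays in sync. A quick illustration of the defaults (a sketch, not part of the diff):

from odc.geo.cog._mpu import get_mpu_kwargs

kw = get_mpu_kwargs()
assert kw["spill_sz"] == 20 * (1 << 20)  # 20 MiB default spill threshold
assert kw["writes_per_chunk"] == 1
assert kw["client"] is None and kw["mk_header"] is None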
87 changes: 87 additions & 0 deletions odc/geo/cog/_multipart.py
@@ -0,0 +1,87 @@
"""
Multipart upload interface.

Defines the `MultiPartUploadBase` class for implementing multipart upload functionality.
This interface standardises methods for initiating, uploading, and finalising
multipart uploads across storage backends.
"""

from abc import ABC, abstractmethod
from typing import Any, Union, TYPE_CHECKING

from dask.delayed import Delayed
from ._mpu import get_mpu_kwargs, mpu_upload

if TYPE_CHECKING:
    # pylint: disable=import-outside-toplevel,import-error
    import dask.bag


class MultiPartUploadBase(ABC):
    """Abstract base class for multipart upload."""

    @abstractmethod
    def initiate(self, **kwargs) -> str:
        """Initiate a multipart upload and return an identifier."""

    @abstractmethod
    def write_part(self, part: int, data: bytes) -> dict[str, Any]:
        """Upload a single part."""

    @abstractmethod
    def finalise(self, parts: list[dict[str, Any]]) -> str:
        """Finalise the upload with a list of parts."""

    @abstractmethod
    def cancel(self, other: str = ""):
        """Cancel the multipart upload."""

    @property
    @abstractmethod
    def url(self) -> str:
        """Return the URL of the upload target."""

    @property
    @abstractmethod
    def started(self) -> bool:
        """Check if the multipart upload has been initiated."""

    @abstractmethod
    def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any:
        """
        Return a Dask-compatible writer for multipart uploads.

        :param kw: Additional parameters for the writer.
        :param client: Dask client for distributed execution.
        """

    @abstractmethod
    def dask_name_prefix(self) -> str:
        """Return the dask name prefix specific to the backend."""

    def upload(
        self,
        chunks: Union["dask.bag.Bag", list["dask.bag.Bag"]],
        *,
        mk_header: Any = None,
        mk_footer: Any = None,
        user_kw: dict[str, Any] | None = None,
        writes_per_chunk: int = 1,
        spill_sz: int = 20 * (1 << 20),
        client: Any = None,
    ) -> Delayed:
        """High-level upload that calls mpu_upload under the hood."""
        kwargs = get_mpu_kwargs(
            mk_header=mk_header,
            mk_footer=mk_footer,
            user_kw=user_kw,
            writes_per_chunk=writes_per_chunk,
            spill_sz=spill_sz,
            client=client,
        )
        return mpu_upload(
            chunks,
            writer=self.writer,
            dask_name_prefix=self.dask_name_prefix(),
            **kwargs,
        )
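To make the contract concrete, a minimal in-memory implementation of the interface; this is an illustrative test double, not part of the PR:

from typing import Any


class FakeMPU(MultiPartUploadBase):
    """Records parts in memory instead of uploading them."""

    def __init__(self):
        self.parts: dict[int, bytes] = {}

    def initiate(self, **kwargs) -> str:
        return "fake-upload"

    def write_part(self, part: int, data: bytes) -> dict[str, Any]:
        self.parts[part] = data
        return {"PartNumber": part}

    def finalise(self, parts: list[dict[str, Any]]) -> str:
        return f"etag-{len(parts)}"

    def cancel(self, other: str = ""):
        self.parts.clear()

    @property
    def url(self) -> str:
        return "memory://fake"

    @property
    def started(self) -> bool:
        return bool(self.parts)

    def writer(self, kw: dict[str, Any], *, client: Any = None) -> Any:
        return self.write_part  # callable with the (part, data) signature

    def dask_name_prefix(self) -> str:
        return "fake-finalise"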