Skip to content

Commit

Permalink
Add support for zarr archive loading
Browse files Browse the repository at this point in the history
  • Loading branch information
shinzlet committed Dec 11, 2024
1 parent 2f84c94 commit 66eb9d4
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/czpeedy/czpeedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import numpy as np
from termcolor import colored
import tensorstore as ts

from czpeedy.zarr_util import identify_zarr_format
from czpeedy.runner import Runner
from czpeedy.parameter_space import ParameterSpace

Expand Down Expand Up @@ -97,13 +99,11 @@ def shuffle_type(text: str) -> str:
f"\"{text}\" is not a valid shuffle type. Valid shuffle types: {", ".join(shuffles.keys())}"
)


# Takes all the information that the user provided about the input source and attempts to load it into a numpy array.
# A regular file is read as a raw numpy data dump; any other path is opened as a zarr (v2 or v3) archive.
def load_input(
source: Path, shape: list[int] | None = None, dtype: np.dtype | None = None
) -> np.ndarray:
if source.is_file:
if source.is_file():
# Raw numpy data dump (or known type):
print(f"{colored("Reading input file", "green")} as raw numpy dump")
if shape is None:
Expand All @@ -117,7 +117,18 @@ def load_input(
with open(source, "rb") as f:
return np.fromfile(f, dtype=dtype).reshape(shape)
else:
raise NotImplementedError("Loading from zarr is not yet supported.")
version = identify_zarr_format(source)
if version is None:
raise ValueError(
f"Could not identify the zarr version of the input dataset at {source}. The archive is either unsupported or invalid."
)

dataset = ts.open({
"driver": "zarr" if version == 2 else "zarr_v3",
"kvstore": {"driver": "file", "path": str(source.absolute())},
}).result()

return dataset.read().result()


# Given a callable that can be used as a type in argparse (i.e. it can convert a string to a more specific type),
Expand Down
40 changes: 40 additions & 0 deletions src/czpeedy/zarr_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import functools
import json
import logging
import operator
from logging import Logger
from pathlib import Path
from typing import Literal, Optional

LOG = logging.getLogger(__name__)

# Names of the metadata files that can sit at the root of a zarr node, keyed by
# the zarr format version that defines them.
METADATA_FILES_BY_VERSION = {
    2: [".zarray", ".zattrs", ".zgroup"],
    3: ["zarr.json"],
}
# Flattened view of every metadata file name above. This is a set, so its
# iteration order is NOT stable between interpreter runs; sort it before probing.
ALL_METADATA_FILES = set(functools.reduce(operator.iadd, METADATA_FILES_BY_VERSION.values(), []))
KNOWN_VERSIONS = set(METADATA_FILES_BY_VERSION.keys())


def identify_zarr_format(archive_path: Path, log: Logger = LOG) -> Optional[Literal[2, 3]]:
    """
    Identify the zarr format version of the archive rooted at `archive_path`.

    Every known metadata file name is probed in a fixed (sorted) order. The first
    file that parses as a JSON object whose `zarr_format` value is one of
    `KNOWN_VERSIONS` decides the result. Files that are missing, are not regular
    files, cannot be parsed as JSON, lack the `zarr_format` key, or declare an
    unknown version are skipped so that a later metadata file can still identify
    the archive.

    Args:
        archive_path: Directory expected to contain the zarr metadata file(s).
        log: Logger that receives debug messages about the detection process.

    Returns:
        The zarr format version (2 or 3), or None when no metadata file
        identifies a known version.
    """
    # Sorting makes the probe order deterministic; iterating the set directly
    # would visit the files in a hash-dependent order that varies between runs.
    for candidate_file in sorted(ALL_METADATA_FILES):
        metadata_file = archive_path / candidate_file

        # is_file() (rather than exists()) also rejects directories that happen
        # to carry a metadata file name, which open() could not read.
        if not metadata_file.is_file():
            continue

        try:
            with open(metadata_file, encoding="utf-8") as f:
                metadata = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as err:
            log.debug("Could not parse metadata file %s: %s", metadata_file, err)
            continue

        # Only a JSON object can carry the key; anything else counts as "no version".
        zarr_format = metadata.get("zarr_format") if isinstance(metadata, dict) else None

        # The isinstance check keeps unhashable or non-integer values away from
        # the set membership test.
        if isinstance(zarr_format, int) and zarr_format in KNOWN_VERSIONS:
            log.debug("Identified zarr version %s from metadata file %s", zarr_format, metadata_file)
            return zarr_format

        # Do not give up here: some metadata files (e.g. a v2 `.zattrs`, which
        # holds user attributes) normally have no `zarr_format` key even though
        # a sibling metadata file does.
        log.debug("Invalid zarr version %s in metadata file %s", zarr_format, metadata_file)

    log.debug("Could not identify zarr version from metadata files in archive folder %s", archive_path)
    return None

0 comments on commit 66eb9d4

Please sign in to comment.