Skip to content

Commit

Permalink
Slight refactor of TrialParameters to localize everything needed to build a ts spec
Browse files Browse the repository at this point in the history
  • Loading branch information
shinzlet committed Jul 29, 2024
1 parent 2442983 commit bab4cf3
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "czpeedy"
version = "0.2.1"
version = "0.2.2"
description = "A command-line tool used to determine the tensorstore settings which yield the fastest write speed on a given machine."
authors = [
{ name = "Seth Hinz", email = "[email protected]" }
Expand Down
2 changes: 1 addition & 1 deletion src/czpeedy/czpeedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def main() -> None:
parameter_space = ParameterSpace(
data.shape,
args.chunk_size,
args.dest,
data.dtype,
args.zarr_version,
args.clevel,
Expand All @@ -243,7 +244,6 @@ def main() -> None:
runner = Runner(
parameter_space.all_combinations(),
data,
args.dest,
args.repetitions,
parameter_space.num_combinations,
)
Expand Down
5 changes: 5 additions & 0 deletions src/czpeedy/parameter_space.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Iterable, Iterator
from itertools import product
from pathlib import Path

import numpy as np
from numpy.typing import ArrayLike
Expand All @@ -15,6 +16,7 @@ class ParameterSpace:

shape: tuple[int, ...]
dtype: np.dtype
dest: Path
zarr_versions: list[int]
clevels: list[int]
compressors: list[str]
Expand All @@ -27,6 +29,7 @@ def __init__(
self,
shape: ArrayLike,
chunk_sizes: Iterable[ArrayLike],
dest: Path,
dtype: np.dtype,
# Default parameters would be nice, but the user needs to be able to explicitly pass None and
# have the parameters still be set when needed (i.e. in ParameterSpace(..., clevels=args.clevels, ...),
Expand Down Expand Up @@ -90,6 +93,7 @@ def __init__(

self.shape = tuple(shape)
self.chunk_sizes = [tuple(chunk_size) for chunk_size in chunk_sizes]
self.dest = dest
self.dtype = dtype
self.zarr_versions = zarr_versions
self.clevels = list(clevels)
Expand Down Expand Up @@ -150,6 +154,7 @@ def to_trial_parameters(
return TrialParameters(
self.shape,
chunk_size,
self.dest,
self.dtype,
zarr_version=zarr_version,
clevel=clevel,
Expand Down
5 changes: 1 addition & 4 deletions src/czpeedy/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
class Runner:
trial_params: Iterable[TrialParameters]
data: np.ndarray
dest: Path
repetitions: int
batch_count: int
results: dict[TrialParameters, list[float]]
Expand All @@ -24,13 +23,11 @@ def __init__(
self,
trial_params: Iterable[TrialParameters],
data: np.ndarray,
dest: Path,
repetitions: int,
batch_count: int | None = None,
):
self.trial_params = trial_params
self.data = data
self.dest = dest
self.repetitions = repetitions
self.batch_count = batch_count
self.results = {}
Expand All @@ -53,7 +50,7 @@ def run_all(self):

for batch_id, trial_param in enumerate(self.trial_params):
result = []
spec = trial_param.to_spec(self.dest)
spec = trial_param.to_spec()
codecs = ts.CodecSpec(trial_param.codecs())

dataset = ts.open(spec, codec=codecs).result()
Expand Down
9 changes: 6 additions & 3 deletions src/czpeedy/trial_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ class TrialParameters:
compressor: str
shuffle: int
endianness: int
output_path: Path

# All parameters are either self-explanatory or documented in tensorstore's spec documentation, with the exception of `endianness`.
# `endianness` is -1 for little endian, 0 for indeterminate endianness (applies only to 1-byte values), and +1 for big endian.
def __init__(
self,
shape: ArrayLike[int],
chunk_size: ArrayLike,
output_path: Path,
dtype: np.dtype,
zarr_version: int,
clevel: int,
Expand All @@ -31,6 +33,7 @@ def __init__(
):
self.shape = shape
self.chunk_size = list(chunk_size)
self.output_path = output_path
self.dtype = dtype
self.zarr_version = zarr_version
self.clevel = clevel
Expand All @@ -55,13 +58,13 @@ def dtype_json_v3(self) -> str:

# Produces a jsonable dict that communicates all the trial parameters to tensorstore.
# Usage: `ts.open(trial_parameters.to_spec()).result()`
def to_spec(self, output_path: Path) -> dict:
def to_spec(self) -> dict:
if self.zarr_version == 2:
return {
"driver": "zarr",
"kvstore": {
"driver": "file",
"path": str(output_path.absolute()),
"path": str(self.output_path.absolute()),
},
"metadata": {
"compressor": {
Expand Down Expand Up @@ -119,7 +122,7 @@ def to_spec(self, output_path: Path) -> dict:
"driver": "zarr3",
"kvstore": {
"driver": "file",
"path": str(output_path.absolute()),
"path": str(self.output_path.absolute()),
},
"metadata": {
"shape": self.shape,
Expand Down

0 comments on commit bab4cf3

Please sign in to comment.