Skip to content

Commit

Permalink
Added coarsening levels and append mode
Browse files Browse the repository at this point in the history
  • Loading branch information
goord committed Jan 13, 2025
1 parent d90d564 commit 6d9d8d3
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
26 changes: 22 additions & 4 deletions dales2zarr/convert_int8.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#!/usr/bin/env python

import argparse
import logging
import xarray as xr
import yaml
import zarr
import yaml
from dales2zarr.zarr_cast import multi_cast_to_int8


# Parse command-line arguments
def parse_args(arg_list=None):
"""Parse command-line arguments for the convert_int8 script.
Expand All @@ -24,6 +25,12 @@ def parse_args(arg_list=None):
help="Path to the output zarr file")
parser.add_argument("--config", metavar="FILE", type=str, required=False, default=None,
help="Path to the input configuration file (yaml)")
parser.add_argument("--levels", metavar="INT", type=int, required=False, default=0,
help="Number of coarsening levels")
parser.add_argument("--timestamps" , metavar="INT", type=int, required=False, default=0,
help="Number of timestamps to keep")
parser.add_argument("--mode", metavar="w|a", type=str, required=False, default="a", choices=["w", "a"],
help="Write or append mode")
return parser.parse_args(args=arg_list)


Expand Down Expand Up @@ -51,15 +58,26 @@ def main(arg_list=None):
with open(args.config, "r") as f:
input_config = yaml.safe_load(f)

# Keep only the first args.timestamps timestamps
if args.timestamps > 0:
input_ds = input_ds.isel(time=slice(0, args.timestamps))

# Call multi_cast_to_int8 on the input dataset
output_ds, output_variables = multi_cast_to_int8(input_ds, input_config)

outfile = args.output if args.output is not None else args.input.replace(".nc", "_int8.zarr")

# Write the result to zarr with Blosc compression
compressor = zarr.Blosc(cname="zstd", clevel=3, shuffle=zarr.Blosc.BITSHUFFLE)
compressor = zarr.Blosc(cname="lz4", clevel=6, shuffle=zarr.Blosc.BITSHUFFLE)
var_encoding = {"dtype": "uint8", "compressor": compressor}
output_ds.to_zarr(outfile, mode="w", encoding={var: var_encoding for var in output_variables})
output_ds.to_zarr(outfile, mode=args.mode, encoding={var: var_encoding for var in output_variables})

# Coarsen the dataset and write to zarr
ds = output_ds
for level in range(1, args.levels + 1):
ds = ds.coarsen({dim: 2 for dim in ds.dims if dim != "time"}, boundary="trim").mean()
ds.to_zarr(outfile.replace(".zarr", f"-{level}.zarr"), mode="a", encoding={var: var_encoding for var in output_variables})


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions dales2zarr/zarr_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def cast_to_int8_3d(input_ds, var_name, mode='linear', epsilon=1e-10):

# Compute the maximum values for each layer
layer_maxes = input_data.max(['xt', 'yt']).values
log.info(f'layer maxes are {layer_maxes}')
log.info(f'layer maxes computed')

# Compute the global maximum
glob_max = np.max(layer_maxes)
Expand All @@ -105,7 +105,7 @@ def cast_to_int8_3d(input_ds, var_name, mode='linear', epsilon=1e-10):
log.info(f'computed kbot and ktop: {kbot}, {ktop}')

# Get the heights associated with the bottom and top layers
zbot, ztop = input_ds.zt[kbot], input_ds.zt[ktop]
zbot, ztop = input_ds.zt[kbot].values, input_ds.zt[ktop].values
log.info(f'computed associated heights are: {zbot}, {ztop}')

# Create a new array with the appropriate shape
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
dependencies = [
"numpy",
"xarray[io,accell]",
"zarr",
"zarr >= 2.18.0, < 3",
"pyyaml",
"argparse",
"dask",
Expand Down

0 comments on commit 6d9d8d3

Please sign in to comment.