diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 786f2dd..f36970c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,7 @@ repos: # Can run individually with `pre-commit run mypy --all-files` - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.13.0 + rev: v1.14.1 hooks: - id: mypy args: ["--config-file", "pyproject.toml", "--ignore-missing-imports"] diff --git a/analysis/analyze_scalar.py b/analysis/analyze_scalar.py index 668c933..a3a0627 100644 --- a/analysis/analyze_scalar.py +++ b/analysis/analyze_scalar.py @@ -28,7 +28,7 @@ from importlib.resources import files from itertools import chain from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Callable import matplotlib as mpl import matplotlib.pyplot as plt @@ -36,25 +36,17 @@ import pandas as pd import toml import xarray as xr -from dask.distributed import Client, progress import pism_ragis.processing as prp -from pism_ragis.analyze import delta_analysis -from pism_ragis.decorators import timeit +from pism_ragis.analyze import run_sensitivity_analysis from pism_ragis.filtering import filter_outliers, run_importance_sampling from pism_ragis.logger import get_logger from pism_ragis.plotting import ( plot_basins, - plot_outliers, plot_posteriors, plot_prior_posteriors, plot_sensitivity_indices, ) -from pism_ragis.processing import ( - config_to_dataframe, - filter_config, - filter_retreat_experiments, -) logger = get_logger("pism_ragis") @@ -70,385 +62,6 @@ ) -def sort_columns(df: pd.DataFrame, sorted_columns: List[str]) -> pd.DataFrame: - """ - Sort columns of a DataFrame. - - This function sorts the columns of a DataFrame such that the columns specified in - `sorted_columns` appear in the specified order, while all other columns appear before - the sorted columns in their original order. - - Parameters - ---------- - df : pd.DataFrame - The input DataFrame to be sorted. - sorted_columns : List[str] - A list of column names to be sorted. - - Returns - ------- - pd.DataFrame - The DataFrame with columns sorted as specified. - """ - # Identify columns that are not in the list - other_columns = [col for col in df.columns if col not in sorted_columns] - - # Concatenate other columns with the sorted columns - new_column_order = other_columns + sorted_columns - - # Reindex the DataFrame - return df.reindex(columns=new_column_order) - - -def add_prefix_coord( - sensitivity_indices: xr.Dataset, parameter_groups: Dict[str, str] -) -> xr.Dataset: - """ - Add prefix coordinates to an xarray Dataset. - - This function extracts the prefix from each coordinate value in the 'pism_config_axis' - and adds it as a new coordinate. It also maps the prefixes to their corresponding - sensitivity indices groups. - - Parameters - ---------- - sensitivity_indices : xr.Dataset - The input dataset containing sensitivity indices. - parameter_groups : Dict[str, str] - A dictionary mapping parameter names to their corresponding groups. - - Returns - ------- - xr.Dataset - The dataset with added prefix coordinates and sensitivity indices groups. 
- """ - prefixes = [ - name.split(".")[0] for name in sensitivity_indices.pism_config_axis.values - ] - - sensitivity_indices = sensitivity_indices.assign_coords( - prefix=("pism_config_axis", prefixes) - ) - si_prefixes = [parameter_groups[name] for name in sensitivity_indices.prefix.values] - - sensitivity_indices = sensitivity_indices.assign_coords( - sensitivity_indices_group=("pism_config_axis", si_prefixes) - ) - return sensitivity_indices - - -def prepare_input( - df: pd.DataFrame, - params: List[str] = [ - "surface.given.file", - "ocean.th.file", - "calving.rate_scaling.file", - "geometry.front_retreat.prescribed.file", - ], -) -> pd.DataFrame: - """ - Prepare the input DataFrame by converting columns to numeric and mapping unique values to integers. - - This function processes the input DataFrame by converting specified columns to numeric values, - dropping specified columns, and mapping unique values in the specified parameters to integers. - - Parameters - ---------- - df : pd.DataFrame - The input DataFrame to be processed. - params : List[str], optional - A list of column names to be processed. Unique values in these columns will be mapped to integers. - By default, the list includes: - ["surface.given.file", "ocean.th.file", "calving.rate_scaling.file", "geometry.front_retreat.prescribed.file"]. - - Returns - ------- - pd.DataFrame - The processed DataFrame with specified columns converted to numeric and unique values mapped to integers. - - Examples - -------- - >>> df = pd.DataFrame({ - ... "surface.given.file": ["file1", "file2", "file1"], - ... "ocean.th.file": ["fileA", "fileB", "fileA"], - ... "calving.rate_scaling.file": ["fileX", "fileY", "fileX"], - ... "geometry.front_retreat.prescribed.file": ["fileM", "fileN", "fileM"], - ... "ensemble": [1, 2, 3], - ... "exp_id": [101, 102, 103] - ... }) - >>> prepare_input(df) - surface.given.file ocean.th.file calving.rate_scaling.file geometry.front_retreat.prescribed.file - 0 0 0 0 0 - 1 1 1 1 1 - 2 0 0 0 0 - """ - df = df.apply(prp.convert_column_to_numeric).drop( - columns=["ensemble", "exp_id"], errors="ignore" - ) - - for param in params: - m_dict: Dict[str, int] = {v: k for k, v in enumerate(df[param].unique())} - df[param] = df[param].map(m_dict) - - return df - - -@timeit -def prepare_simulations( - filenames: List[Union[Path, str]], - config: Dict[str, Any], - reference_date: str, - parallel: bool = True, - engine: str = "h5netcdf", -) -> xr.Dataset: - """ - Prepare simulations by loading and processing ensemble datasets. - - This function loads ensemble datasets from the specified filenames, processes them - according to the provided configuration, and returns the processed dataset. The - processing steps include sorting, converting byte strings to strings, dropping NaNs, - standardizing variable names, calculating cumulative variables, and normalizing - cumulative variables. - - Parameters - ---------- - filenames : List[Union[Path, str]] - A list of file paths to the ensemble datasets. - config : Dict[str, Any] - A dictionary containing configuration settings for processing the datasets. - reference_date : str - The reference date for normalizing cumulative variables. - parallel : bool, optional - Whether to load the datasets in parallel, by default True. - engine : str, optional - The engine to use for loading the datasets, by default "h5netcdf". - - Returns - ------- - xr.Dataset - The processed xarray dataset. - - Examples - -------- - >>> filenames = ["file1.nc", "file2.nc"] - >>> config = { - ... 
"PISM Spatial": {...}, - ... "Cumulative Variables": { - ... "cumulative_grounding_line_flux": "cumulative_gl_flux", - ... "cumulative_smb": "cumulative_smb_flux" - ... }, - ... "Flux Variables": { - ... "grounding_line_flux": "gl_flux", - ... "smb_flux": "smb_flux" - ... } - ... } - >>> reference_date = "2000-01-01" - >>> ds = prepare_simulations(filenames, config, reference_date) - """ - ds = prp.load_ensemble(filenames, parallel=parallel, engine=engine).sortby("basin") - ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized") - - ds = prp.standardize_variable_names(ds, config["PISM Spatial"]) - ds[config["Cumulative Variables"]["cumulative_grounding_line_flux"]] = ds[ - config["Flux Variables"]["grounding_line_flux"] - ].cumsum() / len(ds.time) - ds[config["Cumulative Variables"]["cumulative_smb"]] = ds[ - config["Flux Variables"]["smb_flux"] - ].cumsum() / len(ds.time) - ds = prp.normalize_cumulative_variables( - ds, - list(config["Cumulative Variables"].values()), - reference_date=reference_date, - ) - return ds - - -@timeit -def prepare_observations( - basin_url: Union[Path, str], - grace_url: Union[Path, str], - config: Dict[str, Any], - reference_date: str, - engine: str = "h5netcdf", -) -> tuple[xr.Dataset, xr.Dataset]: - """ - Prepare observation datasets by normalizing cumulative variables. - - This function loads observation datasets from the specified URLs, sorts them by basin, - normalizes the cumulative variables, and returns the processed datasets. - - Parameters - ---------- - basin_url : Union[Path, str] - The URL or path to the basin observation dataset. - grace_url : Union[Path, str] - The URL or path to the GRACE observation dataset. - config : Dict[str, Any] - A dictionary containing configuration settings for processing the datasets. - reference_date : str - The reference date for normalizing cumulative variables. - engine : str, optional - The engine to use for loading the datasets, by default "h5netcdf". - - Returns - ------- - tuple[xr.Dataset, xr.Dataset] - A tuple containing the processed basin and GRACE observation datasets. - - Examples - -------- - >>> config = { - ... "Cumulative Variables": {"cumulative_mass_balance": "mass_balance"}, - ... "Cumulative Uncertainty Variables": {"cumulative_mass_balance_uncertainty": "mass_balance_uncertainty"} - ... } - >>> prepare_observations("basin.nc", "grace.nc", config, "2000-01-1") - (, ) - """ - obs_basin = xr.open_dataset(basin_url, engine=engine, chunks=-1) - obs_basin = obs_basin.sortby("basin") - - cumulative_vars = config["Cumulative Variables"] - cumulative_uncertainty_vars = config["Cumulative Uncertainty Variables"] - - obs_basin = prp.normalize_cumulative_variables( - obs_basin, - list(cumulative_vars.values()) + list(cumulative_uncertainty_vars.values()), - reference_date, - ) - - obs_grace = xr.open_dataset(grace_url, engine=engine, chunks=-1) - obs_grace = obs_grace.sortby("basin") - - cumulative_vars = config["Cumulative Variables"]["cumulative_mass_balance"] - cumulative_uncertainty_vars = config["Cumulative Uncertainty Variables"][ - "cumulative_mass_balance_uncertainty" - ] - - obs_grace = prp.normalize_cumulative_variables( - obs_grace, - [cumulative_vars] + [cumulative_uncertainty_vars], - reference_date, - ) - - return obs_basin, obs_grace - - -def convert_bstrings_to_str(element: Any) -> Any: - """ - Convert byte strings to regular strings. - - Parameters - ---------- - element : Any - The element to be checked and potentially converted. 
If the element is a byte string, - it will be converted to a regular string. Otherwise, the element will be returned as is. - - Returns - ------- - Any - The converted element if it was a byte string, otherwise the original element. - """ - if isinstance(element, bytes): - return element.decode("utf-8") - return element - - -@timeit -def run_sensitivity_analysis( - input_df: pd.DataFrame, - response_ds: xr.Dataset, - filter_vars: List[str], - group_dim: str = "basin", - iter_dim: str = "time", - notebook: bool = False, -) -> xr.Dataset: - """ - Run delta sensitivity analysis on the given dataset. - - This function calculates sensitivity indices for each basin in the dataset, - filtered by the specified variables. It uses Dask for parallel processing - to improve performance. - - Parameters - ---------- - input_df : pd.DataFrame - DataFrame containing ensemble information, with a 'basin' column to group by. - response_ds : xr.Dataset - The input dataset containing the data to be analyzed. - filter_vars : List[str] - List of variables to filter by for sensitivity analysis. - group_dim : str, optional - The dimension to group by, by default "basin". - iter_dim : str, optional - The dimension to iterate over, by default "time". - notebook : bool, optional - Whether to display a nicer progress bar when running in a notebook, by default False. - - Returns - ------- - xr.Dataset - A dataset containing the calculated sensitivity indices for each basin and filter variable. - - Notes - ----- - It is imperative to load the dataset before starting the Dask client, - to avoid each Dask worker loading the dataset separately, which would - significantly slow down the computation. - """ - print("Calculating Sensitivity Indices") - print("===============================") - - client = Client() - print(f"Open client in browser: {client.dashboard_link}") - sensitivity_indices_list = [] - for gdim, df in input_df.groupby(by=group_dim): - df = df.drop(columns=[group_dim]) - problem = { - "num_vars": len(df.columns), - "names": df.columns, # Parameter names - "bounds": zip( - df.min().values, - df.max().values, - ), # Parameter bounds - } - for filter_var in filter_vars: - print( - f" ...sensitivity indices for basin {gdim} filtered by {filter_var} ", - ) - - responses = response_ds.sel({"basin": gdim})[filter_var].load() - responses_scattered = client.scatter( - [ - responses.isel({"time": k}).to_numpy() - for k in range(len(responses[iter_dim])) - ] - ) - - futures = client.map( - delta_analysis, - responses_scattered, - X=df.to_numpy(), - problem=problem, - ) - progress(futures, notebook=notebook) - result = client.gather(futures) - - sensitivity_indices = xr.concat( - [r.expand_dims(iter_dim) for r in result], dim=iter_dim - ) - sensitivity_indices[iter_dim] = responses[iter_dim] - sensitivity_indices = sensitivity_indices.expand_dims(group_dim, axis=1) - sensitivity_indices[group_dim] = [gdim] - sensitivity_indices = sensitivity_indices.expand_dims("filtered_by", axis=2) - sensitivity_indices["filtered_by"] = [filter_var] - sensitivity_indices_list.append(sensitivity_indices) - - all_sensitivity_indices: xr.Dataset = xr.merge(sensitivity_indices_list) - client.close() - - return all_sensitivity_indices - - if __name__ == "__main__": __spec__ = None # type: ignore @@ -523,13 +136,6 @@ def run_sensitivity_analysis( type=str, default="2020-01-1", ) - parser.add_argument( - "--retreat_method", - help="""Sub-select retreat method. 
Default='all'.""", - type=str, - choices=["all", "free", "prescribed"], - default="all", - ) parser.add_argument( "--n_jobs", help="""Number of parallel jobs. Default=8.""", @@ -567,7 +173,6 @@ def run_sensitivity_analysis( parallel = options.parallel reference_date = options.reference_date resampling_frequency = options.resampling_frequency - retreat_method = options.retreat_method outlier_variable = options.outlier_variable outlier_range = options.outlier_range ragis_config_file = Path( @@ -581,18 +186,18 @@ def run_sensitivity_analysis( sim_cmap = config["Plotting"]["sim_cmap"] grace_fudge_factor = config["Importance Sampling"]["grace_fudge_factor"] mankoff_fudge_factor = config["Importance Sampling"]["mankoff_fudge_factor"] + retreat_methods = ["All", "Free", "Prescribed"] result_dir = Path(options.result_dir) data_dir = result_dir / Path("posteriors") data_dir.mkdir(parents=True, exist_ok=True) - fig_dir = result_dir / Path("figures") - fig_dir.mkdir(parents=True, exist_ok=True) - plot_dir = fig_dir / Path("basin_timeseries") - plot_dir.mkdir(parents=True, exist_ok=True) - pdf_dir = plot_dir / Path("pdfs") - pdf_dir.mkdir(parents=True, exist_ok=True) - png_dir = plot_dir / Path("pngs") - png_dir.mkdir(parents=True, exist_ok=True) + + column_function_mapping: dict[str, list[Callable]] = { + "surface.given.file": [prp.simplify_path, prp.simplify_climate], + "ocean.th.file": [prp.simplify_path, prp.simplify_ocean], + "calving.rate_scaling.file": [prp.simplify_path, prp.simplify_calving], + "geometry.front_retreat.prescribed.file": [prp.simplify_retreat], + } rcparams = { "axes.linewidth": 0.25, @@ -607,11 +212,11 @@ def run_sensitivity_analysis( mpl.rcParams.update(rcparams) - simulated_ds = prepare_simulations( + simulated_ds = prp.prepare_simulations( basin_files, config, reference_date, parallel=parallel, engine=engine ) - observed_mankoff_ds, observed_grace_ds = prepare_observations( + observed_mankoff_ds, observed_grace_ds = prp.prepare_observations( options.mankoff_url, options.grace_url, config, @@ -619,75 +224,101 @@ def run_sensitivity_analysis( engine=engine, ) - simulated_ds = filter_retreat_experiments(simulated_ds, retreat_method) + da = observed_mankoff_ds.sel( + time=slice(f"{filter_range[0]}", f"{filter_range[-1]}") + )["grounding_line_flux"].mean(dim="time") + posterior_basins_sorted = observed_mankoff_ds.basin.sortby(da).values - filtered_ds, outliers_ds = filter_outliers( - simulated_ds, - outlier_range=outlier_range, - outlier_variable=outlier_variable, - subset={"basin": "GIS"}, - ) + bins_dict = config["Posterior Bins"] + parameter_categories = config["Parameter Categories"] + params_sorted_by_category: dict = { + group: [] for group in sorted(parameter_categories.values()) + } + for param in params: + prefix = param.split(".")[0] + if prefix in parameter_categories: + group = parameter_categories[prefix] + if param not in params_sorted_by_category[group]: + params_sorted_by_category[group].append(param) - plot_outliers( - filtered_ds.sel(basin="GIS")[outlier_variable], - outliers_ds.sel(basin="GIS")[outlier_variable], - Path(pdf_dir) / Path(f"{outlier_variable}_filtering.pdf"), - ) + params_sorted_list = list(chain(*params_sorted_by_category.values())) + params_sorted_dict = {k: params_short_dict[k] for k in params_sorted_list} + + pp_retreat_list: list[pd.DataFrame] = [] + for retreat_method in retreat_methods: + fig_dir = ( + result_dir / Path(f"retreat_{retreat_method.lower()}") / Path("figures") + ) + fig_dir.mkdir(parents=True, exist_ok=True) - outliers_config 
= filter_config(outliers_ds, params) - outliers_df = config_to_dataframe(outliers_config, ensemble="Outliers") + simulated_retreat_ds = prp.filter_retreat_experiments( + simulated_ds, retreat_method + ) - filtered_config = filter_config(filtered_ds, params) - filtered_df = config_to_dataframe(filtered_config, ensemble="Filtered") + filtered_ds, outliers_ds = filter_outliers( + simulated_retreat_ds, + outlier_range=outlier_range, + outlier_variable=outlier_variable, + subset={"basin": "GIS"}, + ) - obs_mankoff_basins = set(observed_mankoff_ds.basin.values) - obs_grace_basins = set(observed_grace_ds.basin.values) + outliers_config = prp.filter_config(outliers_ds, params) + outliers_df = prp.config_to_dataframe(outliers_config, ensemble="Outliers") - simulated = filtered_ds + filtered_config = prp.filter_config(filtered_ds, params) + filtered_df = prp.config_to_dataframe(filtered_config, ensemble="Filtered") - sim_basins = set(simulated.basin.values) - sim_grace = set(simulated.basin.values) + obs_mankoff_basins = set(observed_mankoff_ds.basin.values) + obs_grace_basins = set(observed_grace_ds.basin.values) - intersection_mankoff = list(sim_basins.intersection(obs_mankoff_basins)) - intersection_grace = list(sim_grace.intersection(obs_grace_basins)) + simulated = filtered_ds - observed_mankoff_basins_ds = observed_mankoff_ds.sel( - {"basin": intersection_mankoff} - ) - simulated_mankoff_basins_ds = simulated.sel({"basin": intersection_mankoff}) - - observed_mankoff_basins_resampled_ds = observed_mankoff_basins_ds.resample( - {"time": resampling_frequency} - ).mean() - simulated_mankoff_basins_resampled_ds = simulated_mankoff_basins_ds.resample( - {"time": resampling_frequency} - ).mean() - - observed_grace_basins_ds = observed_grace_ds.sel({"basin": intersection_grace}) - simulated_grace_basins_ds = simulated.sel({"basin": intersection_grace}) - - observed_grace_basins_resampled_ds = observed_grace_basins_ds.resample( - {"time": resampling_frequency} - ).mean() - simulated_grace_basins_resampled_ds = simulated_grace_basins_ds.resample( - {"time": resampling_frequency} - ).mean() - - obs_mean_vars_mankoff: List[str] = ["grounding_line_flux", "mass_balance"] - obs_std_vars_mankoff: List[str] = [ - "grounding_line_flux_uncertainty", - "mass_balance_uncertainty", - ] - sim_vars_mankoff: List[str] = ["grounding_line_flux", "mass_balance"] - - sim_plot_vars = ( - [ragis_config["Cumulative Variables"]["cumulative_mass_balance"]] - + list(ragis_config["Flux Variables"].values()) - + ["ensemble"] - ) + sim_basins = set(simulated.basin.values) + sim_grace = set(simulated.basin.values) - prior_posterior_mankoff, simulated_prior_mankoff, simulated_posterior_mankoff = ( - run_importance_sampling( + intersection_mankoff = list(sim_basins.intersection(obs_mankoff_basins)) + intersection_grace = list(sim_grace.intersection(obs_grace_basins)) + + observed_mankoff_basins_ds = observed_mankoff_ds.sel( + {"basin": intersection_mankoff} + ) + simulated_mankoff_basins_ds = simulated.sel({"basin": intersection_mankoff}) + + observed_mankoff_basins_resampled_ds = observed_mankoff_basins_ds.resample( + {"time": resampling_frequency} + ).mean() + simulated_mankoff_basins_resampled_ds = simulated_mankoff_basins_ds.resample( + {"time": resampling_frequency} + ).mean() + + observed_grace_basins_ds = observed_grace_ds.sel({"basin": intersection_grace}) + simulated_grace_basins_ds = simulated.sel({"basin": intersection_grace}) + + observed_grace_basins_resampled_ds = observed_grace_basins_ds.resample( + {"time": 
resampling_frequency} + ).mean() + simulated_grace_basins_resampled_ds = simulated_grace_basins_ds.resample( + {"time": resampling_frequency} + ).mean() + + obs_mean_vars_mankoff: list[str] = ["grounding_line_flux", "mass_balance"] + obs_std_vars_mankoff: list[str] = [ + "grounding_line_flux_uncertainty", + "mass_balance_uncertainty", + ] + sim_vars_mankoff: list[str] = ["grounding_line_flux", "mass_balance"] + + sim_plot_vars = ( + [ragis_config["Cumulative Variables"]["cumulative_mass_balance"]] + + list(ragis_config["Flux Variables"].values()) + + ["ensemble"] + ) + + ( + prior_posterior_mankoff, + simulated_prior_mankoff, + simulated_posterior_mankoff, + ) = run_importance_sampling( observed=observed_mankoff_basins_resampled_ds, simulated=simulated_mankoff_basins_resampled_ds, obs_mean_vars=obs_mean_vars_mankoff, @@ -697,131 +328,118 @@ def run_sensitivity_analysis( fudge_factor=mankoff_fudge_factor, params=params, ) - ) - for filter_var in obs_mean_vars_mankoff: - plot_basins( - observed_mankoff_basins_resampled_ds, - simulated_prior_mankoff[sim_plot_vars], - simulated_posterior_mankoff.sel({"filtered_by": filter_var})[sim_plot_vars], - filter_var=filter_var, - filter_range=filter_range, - fig_dir=fig_dir, - fudge_factor=mankoff_fudge_factor, - config=config, + for filter_var in obs_mean_vars_mankoff: + plot_basins( + observed_mankoff_basins_resampled_ds, + simulated_prior_mankoff[sim_plot_vars], + simulated_posterior_mankoff.sel({"filtered_by": filter_var})[ + sim_plot_vars + ], + filter_var=filter_var, + filter_range=filter_range, + fig_dir=fig_dir, + fudge_factor=mankoff_fudge_factor, + config=config, + ) + + obs_mean_vars_grace: list[str] = ["mass_balance"] + obs_std_vars_grace: list[str] = [ + "mass_balance_uncertainty", + ] + sim_vars_grace: list[str] = ["mass_balance"] + + prior_posterior_grace, simulated_prior_grace, simulated_posterior_grace = ( + run_importance_sampling( + observed=observed_grace_basins_resampled_ds, + simulated=simulated_grace_basins_resampled_ds, + obs_mean_vars=obs_mean_vars_grace, + obs_std_vars=obs_std_vars_grace, + sim_vars=sim_vars_grace, + fudge_factor=grace_fudge_factor, + filter_range=filter_range, + params=params, + ) ) - obs_mean_vars_grace: List[str] = ["mass_balance"] - obs_std_vars_grace: List[str] = [ - "mass_balance_uncertainty", - ] - sim_vars_grace: List[str] = ["mass_balance"] - - prior_posterior_grace, simulated_prior_grace, simulated_posterior_grace = ( - run_importance_sampling( - observed=observed_grace_basins_resampled_ds, - simulated=simulated_grace_basins_resampled_ds, - obs_mean_vars=obs_mean_vars_grace, - obs_std_vars=obs_std_vars_grace, - sim_vars=sim_vars_grace, - fudge_factor=grace_fudge_factor, - filter_range=filter_range, - params=params, + for filter_var in obs_mean_vars_grace: + plot_basins( + observed_grace_basins_resampled_ds, + simulated_prior_grace[sim_plot_vars], + simulated_posterior_grace.sel({"filtered_by": filter_var})[ + sim_plot_vars + ], + filter_var=filter_var, + filter_range=filter_range, + fudge_factor=grace_fudge_factor, + fig_dir=fig_dir, + config=config, + ) + + prior_posterior = pd.concat( + [prior_posterior_mankoff, prior_posterior_grace] + ).reset_index(drop=True) + + # Apply the functions to the corresponding columns + for col, functions in column_function_mapping.items(): + for func in functions: + prior_posterior[col] = prior_posterior[col].apply(func) + + if "frontal_melt.routing.parameter_a" in prior_posterior.columns: + prior_posterior["frontal_melt.routing.parameter_a"] *= 10**4 + if 
"ocean.th.gamma_T" in prior_posterior.columns: + prior_posterior["ocean.th.gamma_T"] *= 10**4 + if "calving.vonmises_calving.sigma_max" in prior_posterior.columns: + prior_posterior["calving.vonmises_calving.sigma_max"] *= 10**-3 + + prior_posterior.to_parquet( + data_dir + / Path( + f"""prior_posterior_retreat_{retreat_method}_{filter_range[0]}-{filter_range[1]}.parquet""" + ) ) - ) - for filter_var in obs_mean_vars_grace: - plot_basins( - observed_grace_basins_resampled_ds, - simulated_prior_grace[sim_plot_vars], - simulated_posterior_grace.sel({"filtered_by": filter_var})[sim_plot_vars], - filter_var=filter_var, - filter_range=filter_range, - fudge_factor=grace_fudge_factor, + plot_prior_posteriors( + prior_posterior.rename(columns=params_sorted_dict), + x_order=params_sorted_dict.values(), fig_dir=fig_dir, - config=config, + bins_dict=bins_dict, ) - column_function_mapping: Dict[str, List[Callable]] = { - "surface.given.file": [prp.simplify_path, prp.simplify_climate], - "ocean.th.file": [prp.simplify_path, prp.simplify_ocean], - "calving.rate_scaling.file": [prp.simplify_path, prp.simplify_calving], - "geometry.front_retreat.prescribed.file": [prp.simplify_retreat], - } - - # Apply the functions to the corresponding columns - for col, functions in column_function_mapping.items(): - for func in functions: - prior_posterior_mankoff[col] = prior_posterior_mankoff[col].apply(func) - prior_posterior_grace[col] = prior_posterior_grace[col].apply(func) - - bins_dict = config["Posterior Bins"] - parameter_catetories = config["Parameter Categories"] - params_sorted_by_category: dict = { - group: [] for group in sorted(parameter_catetories.values()) - } - for param in params: - prefix = param.split(".")[0] - if prefix in parameter_catetories: - group = parameter_catetories[prefix] - if param not in params_sorted_by_category[group]: - params_sorted_by_category[group].append(param) - - params_sorted_list = list(chain(*params_sorted_by_category.values())) - if "frontal_melt.routing.parameter_a" in prior_posterior_mankoff.columns: - prior_posterior_mankoff["frontal_melt.routing.parameter_a"] *= 10**4 - if "frontal_melt.routing.parameter_a" in prior_posterior_grace.columns: - prior_posterior_grace["frontal_melt.routing.parameter_a"] *= 10**4 - if "ocean.th.gamma_T" in prior_posterior_mankoff.columns: - prior_posterior_mankoff["ocean.th.gamma_T"] *= 10**4 - if "ocean.th.gamma_T" in prior_posterior_grace.columns: - prior_posterior_grace["ocean.th.gamma_T"] *= 10**4 - if "calving.vonmises_calving.sigma_max" in prior_posterior_mankoff.columns: - prior_posterior_mankoff["calving.vonmises_calving.sigma_max"] *= 10**-3 - if "calving.vonmises_calving.sigma_max" in prior_posterior_grace.columns: - prior_posterior_mankoff["calving.vonmises_calving.sigma_max"] *= 10**-3 - - prior_posterior = pd.concat( - [prior_posterior_mankoff, prior_posterior_grace] - ).reset_index(drop=True) - - prior_posterior.to_parquet( - data_dir - / Path(f"""prior_posterior_{filter_range[0]}-{filter_range[1]}.parquet""") - ) - - prior_posterior_sorted = sort_columns(prior_posterior, params_sorted_list) - prior_posterior_mankoff_sorted = sort_columns( - prior_posterior_mankoff, params_sorted_list - ) - prior_posterior_grace = _sorted = sort_columns( - prior_posterior_grace, params_sorted_list - ) - - params_sorted_dict = {k: params_short_dict[k] for k in params_sorted_list} - bins_sorted_dict = {params_short_dict[k]: bins_dict[k] for k in params_sorted_list} - plot_prior_posteriors( - 
prior_posterior_sorted.rename(columns=params_sorted_dict), - fig_dir=fig_dir, - bins_dict=bins_sorted_dict, - ) + plot_posteriors( + prior_posterior.rename(columns=params_sorted_dict), + x_order=params_sorted_dict.values(), + y_order=posterior_basins_sorted, + hue="filtered_by", + fig_dir=fig_dir, + ) - da = observed_mankoff_basins_resampled_ds.sel( - time=slice(f"{filter_range[0]}", f"{filter_range[-1]}") - ).grounding_line_flux.mean(dim="time") - posterior_basins_sorted = observed_mankoff_basins_resampled_ds.basin.sortby( - da - ).values - - plot_posteriors( - prior_posterior_mankoff_sorted.rename(columns=params_sorted_dict), - order=posterior_basins_sorted, - fig_dir=fig_dir, - ) + p_df = prior_posterior + p_df["retreat_method"] = retreat_method + pp_retreat_list.append(p_df) + + retreat_df = pd.concat(pp_retreat_list).reset_index(drop=True) + + for f_var in ["grounding_line_flux", "mass_balance"]: + fig_p_dir = result_dir / Path(f"filtered_by_{f_var.lower()}") / Path("figures") + fig_p_dir.mkdir(parents=True, exist_ok=True) + + df = retreat_df[ + (retreat_df["filtered_by"] == f_var) + & (retreat_df["ensemble"] == "Posterior") + ] + df = df[df["retreat_method"] != "All"].drop(columns=["filtered_by"]) + plot_posteriors( + df.rename(columns=params_sorted_dict), + x_order=params_sorted_dict.values(), + y_order=["GIS", "CW"], + hue="retreat_method", + fig_dir=fig_p_dir, + ) - prior_config = filter_config(simulated.isel({"time": 0}), params) - prior_df = config_to_dataframe(prior_config, ensemble="Prior") - params_df = prepare_input(prior_df) + prior_config = prp.filter_config(simulated.isel({"time": 0}), params) + prior_df = prp.config_to_dataframe(prior_config, ensemble="Prior") + params_df = prp.prepare_input(prior_df) sensitivity_indices_list = [] for basin_group, intersection, filtering_vars in zip( @@ -846,7 +464,9 @@ def run_sensitivity_analysis( si_dir.mkdir(parents=True, exist_ok=True) sensitivity_indices.to_netcdf(si_dir / Path("sensitivity_indices.nc")) - sensitivity_indices = add_prefix_coord(sensitivity_indices, parameter_catetories) + sensitivity_indices = prp.add_prefix_coord( + sensitivity_indices, parameter_categories + ) # Group by the new coordinate and compute the sum for each group indices_vars = [v for v in sensitivity_indices.data_vars if "_conf" not in v] diff --git a/pism_ragis/analyze.py b/pism_ragis/analyze.py index 666a25e..f8319a8 100644 --- a/pism_ragis/analyze.py +++ b/pism_ragis/analyze.py @@ -16,22 +16,123 @@ # along with PISM; if not, write to the Free Software # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +# pylint: disable=too-many-positional-arguments + """ Module for sensitivity analysis. """ -from typing import Any, Dict +from typing import Any import numpy as np import pandas as pd import xarray as xr +from dask.distributed import Client, progress from SALib.analyze import delta, sobol +from pism_ragis.decorators import timeit + + +@timeit +def run_sensitivity_analysis( + input_df: pd.DataFrame, + response_ds: xr.Dataset, + filter_vars: list[str], + group_dim: str = "basin", + iter_dim: str = "time", + notebook: bool = False, +) -> xr.Dataset: + """ + Run delta sensitivity analysis on the given dataset. + + This function calculates sensitivity indices for each basin in the dataset, + filtered by the specified variables. It uses Dask for parallel processing + to improve performance. + + Parameters + ---------- + input_df : pd.DataFrame + DataFrame containing ensemble information, with a 'basin' column to group by. 
+ response_ds : xr.Dataset + The input dataset containing the data to be analyzed. + filter_vars : list[str] + List of variables to filter by for sensitivity analysis. + group_dim : str, optional + The dimension to group by, by default "basin". + iter_dim : str, optional + The dimension to iterate over, by default "time". + notebook : bool, optional + Whether to display a nicer progress bar when running in a notebook, by default False. + + Returns + ------- + xr.Dataset + A dataset containing the calculated sensitivity indices for each basin and filter variable. + + Notes + ----- + It is imperative to load the dataset before starting the Dask client, + to avoid each Dask worker loading the dataset separately, which would + significantly slow down the computation. + """ + print("Calculating Sensitivity Indices") + print("===============================") + + client = Client() + print(f"Open client in browser: {client.dashboard_link}") + sensitivity_indices_list = [] + for gdim, df in input_df.groupby(by=group_dim): + df = df.drop(columns=[group_dim]) + problem = { + "num_vars": len(df.columns), + "names": df.columns, # Parameter names + "bounds": zip( + df.min().values, + df.max().values, + ), # Parameter bounds + } + for filter_var in filter_vars: + print( + f" ...sensitivity indices for basin {gdim} filtered by {filter_var} ", + ) + + responses = response_ds.sel({"basin": gdim})[filter_var].load() + responses_scattered = client.scatter( + [ + responses.isel({"time": k}).to_numpy() + for k in range(len(responses[iter_dim])) + ] + ) + + futures = client.map( + delta_analysis, + responses_scattered, + X=df.to_numpy(), + problem=problem, + ) + progress(futures, notebook=notebook) + result = client.gather(futures) + + sensitivity_indices = xr.concat( + [r.expand_dims(iter_dim) for r in result], dim=iter_dim + ) + sensitivity_indices[iter_dim] = responses[iter_dim] + sensitivity_indices = sensitivity_indices.expand_dims(group_dim, axis=1) + sensitivity_indices[group_dim] = [gdim] + sensitivity_indices = sensitivity_indices.expand_dims("filtered_by", axis=2) + sensitivity_indices["filtered_by"] = [filter_var] + sensitivity_indices_list.append(sensitivity_indices) + + all_sensitivity_indices: xr.Dataset = xr.merge(sensitivity_indices_list) + client.close() + + return all_sensitivity_indices + def delta_analysis( Y: np.ndarray, X: np.ndarray, - problem: Dict[str, Any], + problem: dict[str, Any], dim: str = "pism_config_axis", ) -> xr.Dataset: """ @@ -74,7 +175,7 @@ def delta_analysis( def sobol_analysis( response: np.ndarray, - problem: Dict[str, Any], + problem: dict[str, Any], ensemble_df: pd.DataFrame, dim: str = "pism_config_axis", ) -> xr.Dataset: diff --git a/pism_ragis/plotting.py b/pism_ragis/plotting.py index 89a1f40..eaccc36 100644 --- a/pism_ragis/plotting.py +++ b/pism_ragis/plotting.py @@ -49,7 +49,9 @@ @timeit def plot_posteriors( df: pd.DataFrame, - order: list[str] | None = None, + x_order: list[str], + y_order: list[str] | None = None, + hue: str | None = "filtered_by", figsize: tuple[float, float] | None = (6.4, 5.2), fig_dir: str | Path = "figures", fontsize: float = 4, @@ -61,8 +63,12 @@ def plot_posteriors( ---------- df : pd.DataFrame DataFrame containing the data to plot. - order : list[str] or None, optional + x_order : Iterable[str] or None, optional + Order of the variables for the x-axis, by default None. + y_order : Iterable[str] or None, optional Order of the basins for the y-axis, by default None. 
+ hue : str or None, optional + Variable name for the hue, by default "filtered_by". figsize : tuple[float, float] or None, optional Size of the figure, by default (6.4, 5.2). fig_dir : str or Path, optional @@ -84,31 +90,28 @@ def plot_posteriors( } with mpl.rc_context(rc=rc_params): - m_df = df.drop(columns=["exp_id"]) fig, axs = plt.subplots( 3, 6, sharey=True, figsize=figsize, ) - fig.subplots_adjust(hspace=0.1, wspace=0.1) - for k, v in enumerate( - m_df.drop(columns=["ensemble", "basin", "filtered_by"]).columns - ): + fig.subplots_adjust(hspace=0.75, wspace=0.1) + for k, v in enumerate(x_order): legend = bool(k == 0) ax = axs.ravel()[k] try: _ = sns.violinplot( - data=m_df, + data=df, x=v, y="basin", - order=order, + order=y_order, linewidth=0.25, cut=0, gap=0.1, split=True, inner="quart", - hue="filtered_by", + hue=hue, orient="h", palette=["#DDCC77", "#CC6677"], ax=ax, @@ -118,14 +121,24 @@ def plot_posteriors( pass if legend: - ax.get_legend().set_title(None) - ax.get_legend().get_frame().set_linewidth(0.0) - ax.get_legend().get_frame().set_alpha(0.0) + ax.get_legend().remove() - if k > len(m_df.drop(columns=["ensemble", "basin", "filtered_by"]).columns): + if k > len(x_order): ax.set_visible(False) + # Create a legend outside the figure at the bottom middle + handles, labels = axs[0, 0].get_legend_handles_labels() + legend_main = fig.legend( + handles, labels, loc="lower center", bbox_to_anchor=(0.5, 0.0), ncol=2 + ) + legend_main.set_title(None) + legend_main.get_frame().set_linewidth(0.0) + legend_main.get_frame().set_alpha(0.0) + + # Adjust layout to make room for the legend fig.tight_layout() + fig.subplots_adjust(bottom=0.1) + fn = pdf_dir / Path("posteriors_violinplots.pdf") fig.savefig(fn) fn = png_dir / Path("posteriors_violinplots.png") @@ -140,7 +153,9 @@ def plot_prior_posteriors( figsize: tuple[float, float] | None = (6.4, 3.2), fig_dir: str | Path = "figures", fontsize: float = 4, - bins_dict: dict = {}, + x_order: list[str] = [], + bins_dict: dict | None = None, + group_columns: list = ["basin", "filtered_by"], ): """ Plot histograms of prior and posterior distributions. @@ -155,8 +170,12 @@ def plot_prior_posteriors( Directory to save the figures, by default "figures". fontsize : float, optional Font size for the plot, by default 4. + x_order : Iterable[str] or None, optional + Order of the variables for the x-axis, by default None. bins_dict : dict, optional Dictionary specifying the number of bins for each variable, by default {}. + group_columns : list, optional + List of columns to group by, by default ["basin", "filtered_by"]. 
""" plot_dir = fig_dir / Path("basin_histograms") @@ -166,8 +185,6 @@ def plot_prior_posteriors( png_dir = plot_dir / Path("pngs") png_dir.mkdir(parents=True, exist_ok=True) - group_columns = ["basin", "filtered_by"] - rc_params = { "font.size": fontsize, # Add other rcParams settings if needed @@ -176,10 +193,9 @@ def plot_prior_posteriors( with mpl.rc_context(rc=rc_params): with tqdm( desc="Plotting prior and posterior histograms", - total=len(df["basin"].unique()), + total=len(df.groupby(by=group_columns)), ) as progress_bar: for (basin, filter_var), m_df in df.groupby(by=group_columns): - m_df = m_df.drop(columns=group_columns + ["exp_id"]) fig, axs = plt.subplots( 3, 6, @@ -187,7 +203,11 @@ def plot_prior_posteriors( figsize=figsize, ) fig.subplots_adjust(hspace=0.5, wspace=0.1) - for k, v in enumerate(m_df.drop(columns=["ensemble"]).columns): + for k, v in enumerate(x_order): + if bins_dict is not None: + bins = bins_dict.get(v, "auto") + else: + bins = None legend = bool(k == 1) try: ax = axs.ravel()[k] @@ -197,7 +217,7 @@ def plot_prior_posteriors( hue="ensemble", hue_order=["Prior", "Posterior"], palette=sim_cmap, - bins=bins_dict[v], + bins=bins, common_norm=False, stat="probability", multiple="dodge", diff --git a/pism_ragis/processing.py b/pism_ragis/processing.py index a2f6c6f..efd5655 100644 --- a/pism_ragis/processing.py +++ b/pism_ragis/processing.py @@ -329,6 +329,73 @@ def preprocess_scalar_nc( ) +def sort_columns(df: pd.DataFrame, sorted_columns: List[str]) -> pd.DataFrame: + """ + Sort columns of a DataFrame. + + This function sorts the columns of a DataFrame such that the columns specified in + `sorted_columns` appear in the specified order, while all other columns appear before + the sorted columns in their original order. + + Parameters + ---------- + df : pd.DataFrame + The input DataFrame to be sorted. + sorted_columns : List[str] + A list of column names to be sorted. + + Returns + ------- + pd.DataFrame + The DataFrame with columns sorted as specified. + """ + # Identify columns that are not in the list + other_columns = [col for col in df.columns if col not in sorted_columns] + + # Concatenate other columns with the sorted columns + new_column_order = other_columns + sorted_columns + + # Reindex the DataFrame + return df.reindex(columns=new_column_order) + + +def add_prefix_coord( + sensitivity_indices: xr.Dataset, parameter_groups: Dict[str, str] +) -> xr.Dataset: + """ + Add prefix coordinates to an xarray Dataset. + + This function extracts the prefix from each coordinate value in the 'pism_config_axis' + and adds it as a new coordinate. It also maps the prefixes to their corresponding + sensitivity indices groups. + + Parameters + ---------- + sensitivity_indices : xr.Dataset + The input dataset containing sensitivity indices. + parameter_groups : Dict[str, str] + A dictionary mapping parameter names to their corresponding groups. + + Returns + ------- + xr.Dataset + The dataset with added prefix coordinates and sensitivity indices groups. 
+ """ + prefixes = [ + name.split(".")[0] for name in sensitivity_indices.pism_config_axis.values + ] + + sensitivity_indices = sensitivity_indices.assign_coords( + prefix=("pism_config_axis", prefixes) + ) + si_prefixes = [parameter_groups[name] for name in sensitivity_indices.prefix.values] + + sensitivity_indices = sensitivity_indices.assign_coords( + sensitivity_indices_group=("pism_config_axis", si_prefixes) + ) + return sensitivity_indices + + def compute_basin( ds: xr.Dataset, name: str = "basin", @@ -644,16 +711,12 @@ def load_ensemble( The loaded xarray Dataset containing the ensemble data. """ print("Loading ensemble files... ", end="", flush=True) - ds = ( - xr.open_mfdataset( - filenames, - parallel=parallel, - preprocess=preprocess, - engine=engine, - ) - .drop_vars(["spatial_ref", "mapping"], errors="ignore") - .dropna(dim="exp_id") - ) + ds = xr.open_mfdataset( + filenames, + parallel=parallel, + preprocess=preprocess, + engine=engine, + ).drop_vars(["spatial_ref", "mapping"], errors="ignore") print("Done.") return ds @@ -1053,7 +1116,7 @@ def config_to_dataframe( @timeit def filter_retreat_experiments( - ds: xr.Dataset, retreat_method: Literal["free", "prescribed", "all"] + ds: xr.Dataset, retreat_method: Literal["Free", "Prescribed", "All"] ) -> xr.Dataset: """ Filter retreat experiments based on the retreat method. @@ -1065,7 +1128,7 @@ def filter_retreat_experiments( ---------- ds : xr.Dataset The input dataset containing the retreat experiments. - retreat_method : {"free", "prescribed", "all"} + retreat_method : {"Free", "Prescribed", "All"} The retreat method to filter by. "free" selects experiments with no prescribed retreat, "prescribed" selects experiments with a prescribed retreat, and "all" selects all experiments. @@ -1092,11 +1155,11 @@ def filter_retreat_experiments( pism_config_axis="geometry.front_retreat.prescribed.file" ).compute() - if retreat_method == "free": + if retreat_method == "Free": retreat_exp_ids = retreat.where( retreat["pism_config"] == "false", drop=True ).exp_id.values - elif retreat_method == "prescribed": + elif retreat_method == "Prescribed": retreat_exp_ids = retreat.where( retreat["pism_config"] != "false", drop=True ).exp_id.values @@ -1107,3 +1170,219 @@ def filter_retreat_experiments( ds = ds.sel(exp_id=retreat_exp_ids) return ds + + +def prepare_input( + df: pd.DataFrame, + params: List[str] = [ + "surface.given.file", + "ocean.th.file", + "calving.rate_scaling.file", + "geometry.front_retreat.prescribed.file", + ], +) -> pd.DataFrame: + """ + Prepare the input DataFrame by converting columns to numeric and mapping unique values to integers. + + This function processes the input DataFrame by converting specified columns to numeric values, + dropping specified columns, and mapping unique values in the specified parameters to integers. + + Parameters + ---------- + df : pd.DataFrame + The input DataFrame to be processed. + params : List[str], optional + A list of column names to be processed. Unique values in these columns will be mapped to integers. + By default, the list includes: + ["surface.given.file", "ocean.th.file", "calving.rate_scaling.file", "geometry.front_retreat.prescribed.file"]. + + Returns + ------- + pd.DataFrame + The processed DataFrame with specified columns converted to numeric and unique values mapped to integers. + + Examples + -------- + >>> df = pd.DataFrame({ + ... "surface.given.file": ["file1", "file2", "file1"], + ... "ocean.th.file": ["fileA", "fileB", "fileA"], + ... 
"calving.rate_scaling.file": ["fileX", "fileY", "fileX"], + ... "geometry.front_retreat.prescribed.file": ["fileM", "fileN", "fileM"], + ... "ensemble": [1, 2, 3], + ... "exp_id": [101, 102, 103] + ... }) + >>> prepare_input(df) + surface.given.file ocean.th.file calving.rate_scaling.file geometry.front_retreat.prescribed.file + 0 0 0 0 0 + 1 1 1 1 1 + 2 0 0 0 0 + """ + df = df.apply(convert_column_to_numeric).drop( + columns=["ensemble", "exp_id"], errors="ignore" + ) + + for param in params: + m_dict: Dict[str, int] = {v: k for k, v in enumerate(df[param].unique())} + df[param] = df[param].map(m_dict) + + return df + + +@timeit +def prepare_simulations( + filenames: List[Union[Path, str]], + config: Dict[str, Any], + reference_date: str, + parallel: bool = True, + engine: str = "h5netcdf", +) -> xr.Dataset: + """ + Prepare simulations by loading and processing ensemble datasets. + + This function loads ensemble datasets from the specified filenames, processes them + according to the provided configuration, and returns the processed dataset. The + processing steps include sorting, converting byte strings to strings, dropping NaNs, + standardizing variable names, calculating cumulative variables, and normalizing + cumulative variables. + + Parameters + ---------- + filenames : List[Union[Path, str]] + A list of file paths to the ensemble datasets. + config : Dict[str, Any] + A dictionary containing configuration settings for processing the datasets. + reference_date : str + The reference date for normalizing cumulative variables. + parallel : bool, optional + Whether to load the datasets in parallel, by default True. + engine : str, optional + The engine to use for loading the datasets, by default "h5netcdf". + + Returns + ------- + xr.Dataset + The processed xarray dataset. + + Examples + -------- + >>> filenames = ["file1.nc", "file2.nc"] + >>> config = { + ... "PISM Spatial": {...}, + ... "Cumulative Variables": { + ... "cumulative_grounding_line_flux": "cumulative_gl_flux", + ... "cumulative_smb": "cumulative_smb_flux" + ... }, + ... "Flux Variables": { + ... "grounding_line_flux": "gl_flux", + ... "smb_flux": "smb_flux" + ... } + ... } + >>> reference_date = "2000-01-01" + >>> ds = prepare_simulations(filenames, config, reference_date) + """ + ds = load_ensemble(filenames, parallel=parallel, engine=engine).sortby("basin") + ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized") + + ds = standardize_variable_names(ds, config["PISM Spatial"]) + ds[config["Cumulative Variables"]["cumulative_grounding_line_flux"]] = ds[ + config["Flux Variables"]["grounding_line_flux"] + ].cumsum() / len(ds.time) + ds[config["Cumulative Variables"]["cumulative_smb"]] = ds[ + config["Flux Variables"]["smb_flux"] + ].cumsum() / len(ds.time) + ds = normalize_cumulative_variables( + ds, + list(config["Cumulative Variables"].values()), + reference_date=reference_date, + ) + return ds + + +@timeit +def prepare_observations( + basin_url: Union[Path, str], + grace_url: Union[Path, str], + config: Dict[str, Any], + reference_date: str, + engine: str = "h5netcdf", +) -> tuple[xr.Dataset, xr.Dataset]: + """ + Prepare observation datasets by normalizing cumulative variables. + + This function loads observation datasets from the specified URLs, sorts them by basin, + normalizes the cumulative variables, and returns the processed datasets. + + Parameters + ---------- + basin_url : Union[Path, str] + The URL or path to the basin observation dataset. 
+ grace_url : Union[Path, str] + The URL or path to the GRACE observation dataset. + config : Dict[str, Any] + A dictionary containing configuration settings for processing the datasets. + reference_date : str + The reference date for normalizing cumulative variables. + engine : str, optional + The engine to use for loading the datasets, by default "h5netcdf". + + Returns + ------- + tuple[xr.Dataset, xr.Dataset] + A tuple containing the processed basin and GRACE observation datasets. + + Examples + -------- + >>> config = { + ... "Cumulative Variables": {"cumulative_mass_balance": "mass_balance"}, + ... "Cumulative Uncertainty Variables": {"cumulative_mass_balance_uncertainty": "mass_balance_uncertainty"} + ... } + >>> prepare_observations("basin.nc", "grace.nc", config, "2000-01-1") + (, ) + """ + obs_basin = xr.open_dataset(basin_url, engine=engine, chunks=-1) + obs_basin = obs_basin.sortby("basin") + + cumulative_vars = config["Cumulative Variables"] + cumulative_uncertainty_vars = config["Cumulative Uncertainty Variables"] + + obs_basin = normalize_cumulative_variables( + obs_basin, + list(cumulative_vars.values()) + list(cumulative_uncertainty_vars.values()), + reference_date, + ) + + obs_grace = xr.open_dataset(grace_url, engine=engine, chunks=-1) + obs_grace = obs_grace.sortby("basin") + + cumulative_vars = config["Cumulative Variables"]["cumulative_mass_balance"] + cumulative_uncertainty_vars = config["Cumulative Uncertainty Variables"][ + "cumulative_mass_balance_uncertainty" + ] + + obs_grace = normalize_cumulative_variables( + obs_grace, + [cumulative_vars] + [cumulative_uncertainty_vars], + reference_date, + ) + + return obs_basin, obs_grace + + +def convert_bstrings_to_str(element: Any) -> Any: + """ + Convert byte strings to regular strings. + + Parameters + ---------- + element : Any + The element to be checked and potentially converted. If the element is a byte string, + it will be converted to a regular string. Otherwise, the element will be returned as is. + + Returns + ------- + Any + The converted element if it was a byte string, otherwise the original element. + """ + if isinstance(element, bytes): + return element.decode("utf-8") + return element
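
Note (not part of the patch): a minimal, self-contained sketch of how the hoisted `column_function_mapping` dict is consumed inside the new retreat loop in `analysis/analyze_scalar.py`. The `simplify_*` helpers and the file path below are hypothetical stand-ins for `prp.simplify_path` / `prp.simplify_climate`; only the nested apply-loop mirrors the patch.

import pandas as pd

def simplify_path(p: str) -> str:      # hypothetical stand-in for prp.simplify_path
    return p.split("/")[-1]

def simplify_climate(p: str) -> str:   # hypothetical stand-in for prp.simplify_climate
    return p.removesuffix(".nc")

column_function_mapping = {"surface.given.file": [simplify_path, simplify_climate]}
prior_posterior = pd.DataFrame({"surface.given.file": ["/data/climate/MARv3.nc"]})

# Apply each listed function, in order, to its column -- same loop as in the script.
for col, functions in column_function_mapping.items():
    for func in functions:
        prior_posterior[col] = prior_posterior[col].apply(func)

print(prior_posterior["surface.given.file"].tolist())  # ['MARv3']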
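
Note (not part of the patch): `run_sensitivity_analysis`, now living in `pism_ragis/analyze.py`, scatters one time slice of the response per Dask task and hands it to `delta_analysis`, which wraps SALib's delta estimator. A hedged, synthetic-data sketch of that per-slice computation (parameter names and sample sizes are made up):

import numpy as np
from SALib.analyze import delta

rng = np.random.default_rng(0)
X = rng.uniform(size=(200, 3))           # 200 ensemble members, 3 parameters
Y = 2.0 * X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.05, size=200)

problem = {
    "num_vars": 3,
    "names": ["p0", "p1", "p2"],
    "bounds": [[0.0, 1.0]] * 3,
}

# Borgonovo delta and first-order Sobol indices for one response slice.
Si = delta.analyze(problem, X, Y)
print(Si["delta"], Si["S1"])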
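
Note (not part of the patch): the legend change in `plot_posteriors` (per-axes legends removed, one figure-level legend at the bottom centre built from the first axes' handles) follows a standard matplotlib pattern. A minimal synthetic sketch, with labels and output file name of my choosing:

import matplotlib
matplotlib.use("Agg")                    # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(1, 3, figsize=(6.4, 2.4), sharey=True)
x = np.arange(5)
for ax in axs.ravel():
    ax.plot(x, x, label="Prior")
    ax.plot(x, x**2, label="Posterior")
    ax.legend().remove()                 # drop the per-axes legend

# One legend for the whole figure, placed below the panels.
handles, labels = axs[0].get_legend_handles_labels()
legend = fig.legend(handles, labels, loc="lower center",
                    bbox_to_anchor=(0.5, 0.0), ncol=2)
legend.get_frame().set_linewidth(0.0)
fig.tight_layout()
fig.subplots_adjust(bottom=0.25)         # leave room for the figure legend
fig.savefig("legend_demo.png", dpi=150)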
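
Note (not part of the patch): `sort_columns`, relocated into `pism_ragis.processing`, keeps unlisted columns first in their original order and appends the listed ones in the given order. A toy demonstration of that reindexing with made-up column names:

import pandas as pd

df = pd.DataFrame({"b": [1], "basin": ["GIS"], "a": [2]})
sorted_columns = ["a", "b"]

other_columns = [c for c in df.columns if c not in sorted_columns]
out = df.reindex(columns=other_columns + sorted_columns)
print(out.columns.tolist())   # ['basin', 'a', 'b']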
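
Note (not part of the patch): with the retreat-method names now capitalized ("All", "Free", "Prescribed"), the selection inside `filter_retreat_experiments` reduces to comparing `geometry.front_retreat.prescribed.file` against the literal string "false". A toy xarray sketch of that branch with synthetic exp_ids, not the function's full dataset handling:

import numpy as np
import xarray as xr

pism_config = xr.DataArray(
    np.array(["false", "retreat.nc", "false"], dtype=object),
    dims="exp_id",
    coords={"exp_id": [0, 1, 2]},
)

retreat_method = "Free"
if retreat_method == "Free":
    retreat_exp_ids = pism_config.where(pism_config == "false", drop=True).exp_id.values
elif retreat_method == "Prescribed":
    retreat_exp_ids = pism_config.where(pism_config != "false", drop=True).exp_id.values
else:  # "All"
    retreat_exp_ids = pism_config.exp_id.values

print(retreat_exp_ids)   # [0 2]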
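
Note (not part of the patch): `prepare_input`, also moved to `pism_ragis.processing`, turns file-valued parameters into integer codes so they can enter the SALib problem as numeric factors; the mapping itself is just an enumerate over the unique values, as in this toy example:

import pandas as pd

df = pd.DataFrame({"ocean.th.file": ["fileA", "fileB", "fileA"]})

m_dict = {v: k for k, v in enumerate(df["ocean.th.file"].unique())}
df["ocean.th.file"] = df["ocean.th.file"].map(m_dict)
print(df["ocean.th.file"].tolist())   # [0, 1, 0]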