Skip to content

Commit

Permalink
more clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
aaschwanden committed Nov 27, 2024
1 parent 4b97ac5 commit c24c6dd
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
with:
micromamba-version: "latest"
generate-run-shell: true
environment-file: environment.yml
environment-file: environment-dev.yml
cache-environment: true
post-cleanup: 'all'
- name: Run tests
Expand Down
89 changes: 76 additions & 13 deletions analysis/analyze_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

# pylint: disable=unused-import,too-many-positional-arguments

"""
Analyze RAGIS ensemble.
"""

import time
Expand Down Expand Up @@ -104,30 +103,33 @@ def filter_config(ds: xr.Dataset, params: List[str]) -> xr.DataArray:

@timeit
def prepare_simulations(
filenames: List[Path | str],
filenames: List[Union[Path, str]],
config: Dict,
reference_date: str,
parallel: bool = True,
engine: str = "netcdf4",
engine: str = "h5netcdf",
) -> xr.Dataset:
"""
Prepare simulations by loading and processing ensemble datasets.
This function loads ensemble datasets from the specified filenames, processes them
according to the provided configuration, and returns the processed dataset. The
processing steps include sorting, dropping NaNs, standardizing variable names,
calculating cumulative variables, and normalizing cumulative variables.
processing steps include sorting, converting byte strings to strings, dropping NaNs,
standardizing variable names, calculating cumulative variables, and normalizing
cumulative variables.
Parameters
----------
filenames : List[Union[Path, str]]
A list of file paths to the ensemble datasets.
config : Dict
A dictionary containing configuration settings for processing the datasets.
reference_date : str
The reference date for normalizing cumulative variables.
parallel : bool, optional
Whether to load the datasets in parallel, by default True.
engine : str, optional
The engine to use for loading the datasets, by default "netcdf4".
The engine to use for loading the datasets, by default "h5netcdf".
Returns
-------
Expand All @@ -148,7 +150,8 @@ def prepare_simulations(
... "smb_flux": "smb_flux"
... }
... }
>>> ds = prepare_simulations(filenames, config)
>>> reference_date = "2000-01-01"
>>> ds = prepare_simulations(filenames, config, reference_date)
"""
ds = prp.load_ensemble(filenames, parallel=parallel, engine=engine).sortby("basin")
ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized")
Expand Down Expand Up @@ -192,6 +195,8 @@ def prepare_observations(
A dictionary containing configuration settings for processing the datasets.
reference_date : str
The reference date for normalizing cumulative variables.
engine : str, optional
The engine to use for loading the datasets, by default "h5netcdf".
Returns
-------
Expand Down Expand Up @@ -237,22 +242,49 @@ def prepare_observations(


@timeit
def config_to_dataframe(config: xr.DataArray, ensemble: str | None = None):
def config_to_dataframe(
config: xr.DataArray, ensemble: Union[str, None] = None
) -> pd.DataFrame:
"""
Convert an xarray DataArray configuration to a pandas DataFrame.
This function converts the input DataArray containing configuration data into a
pandas DataFrame. The dimensions of the DataArray (excluding 'pism_config_axis')
are used as the index, and the 'pism_config_axis' values are used as columns.
Parameters
----------
config : xr.DataArray
The input DataArray containing the configuration data.
ensemble : Union[str, None], optional
An optional string to add as a column named 'Ensemble' in the DataFrame, by default None.
Returns
-------
pd.DataFrame
A DataFrame where the dimensions of the DataArray (excluding 'pism_config_axis')
are used as the index, and the 'pism_config_axis' values are used as columns.
Examples
--------
>>> config = xr.DataArray(
... data=[[1, 2, 3], [4, 5, 6]],
... dims=["time", "pism_config_axis"],
... coords={"time": [0, 1], "pism_config_axis": ["param1", "param2", "param3"]}
... )
>>> df = config_to_dataframe(config)
>>> print(df)
pism_config_axis time param1 param2 param3
0 0 1 2 3
1 1 4 5 6
>>> df = config_to_dataframe(config, ensemble="Ensemble1")
>>> print(df)
pism_config_axis time param1 param2 param3 Ensemble
0 0 1 2 3 Ensemble1
1 1 4 5 6 Ensemble1
"""
dims = [dim for dim in config.dims if not dim in ["pism_config_axis"]]
dims = [dim for dim in config.dims if dim != "pism_config_axis"]
df = config.to_dataframe().reset_index()
df = df.pivot(index=dims, columns="pism_config_axis", values="pism_config")
df.reset_index(inplace=True)
Expand Down Expand Up @@ -282,10 +314,41 @@ def convert_bstrings_to_str(element: Any) -> Any:


def plot_outliers(
filtered_da: xr.DataArray, outliers_da: xr.DataArray, filename: Path | str
filtered_da: xr.DataArray, outliers_da: xr.DataArray, filename: Union[Path, str]
):
"""
Plot outliers.
Plot outliers in the given DataArrays and save the plot to a file.
This function creates a plot with the filtered data and outliers, and saves the plot
to the specified filename. The filtered data is plotted in black, and the outliers
are plotted in red.
Parameters
----------
filtered_da : xr.DataArray
The DataArray containing the filtered data.
outliers_da : xr.DataArray
The DataArray containing the outliers.
filename : Union[Path, str]
The path or filename where the plot will be saved.
Returns
-------
None
Examples
--------
>>> filtered_da = xr.DataArray(
... data=[[1, 2, 3], [4, 5, 6]],
... dims=["time", "exp_id"],
... coords={"time": [0, 1], "exp_id": [0, 1, 2]}
... )
>>> outliers_da = xr.DataArray(
... data=[[7, 8, 9], [10, 11, 12]],
... dims=["time", "exp_id"],
... coords={"time": [0, 1], "exp_id": [0, 1, 2]}
... )
>>> plot_outliers(filtered_da, outliers_da, "outliers_plot.png")
"""
fig, ax = plt.subplots(1, 1)
if filtered_da.size > 0:
Expand Down Expand Up @@ -1295,7 +1358,7 @@ def plot_obs_sims_3(
hue="sensitivity_indices_group", ax=ax, lw=0.75, label=g.values
)
ax.legend()
ax.set_title(f"S1 for {basin} filtered by {filter_var}")
ax.set_title(f"{indices_var} for {basin} filtered by {filter_var}")
fn = (
result_dir
/ Path("figures")
Expand Down

0 comments on commit c24c6dd

Please sign in to comment.