Skip to content

Commit

Permalink
Switching to concurrent.Futures
Browse files Browse the repository at this point in the history
  • Loading branch information
aaschwanden committed Nov 26, 2024
1 parent c168bb6 commit 026c4cf
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
68 changes: 37 additions & 31 deletions analysis/analyze_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from functools import wraps
from importlib.resources import files
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Dict, Hashable, List, Mapping, Union

Expand All @@ -40,7 +41,6 @@
import xarray as xr
from dask.diagnostics import ProgressBar
from dask.distributed import Client, progress
from joblib import Parallel, delayed
from tqdm.auto import tqdm

import pism_ragis.processing as prp
Expand All @@ -52,6 +52,7 @@

logger = get_logger("pism_ragis")


xr.set_options(keep_attrs=True)
plt.style.use("tableau-colorblind10")

Expand Down Expand Up @@ -426,7 +427,9 @@ def plot_obs_sims(
"""

import pism_ragis.processing # pylint: disable=import-outside-toplevel,reimported

import matplotlib # pylint: disable=import-outside-toplevel,reimported
matplotlib.use('Agg')

Path(fig_dir).mkdir(exist_ok=True)

percentile_range = (percentiles[1] - percentiles[0]) * 100
Expand Down Expand Up @@ -544,7 +547,7 @@ def plot_obs_sims(

axs[0].xaxis.set_tick_params(labelbottom=False)

axs[0].set_ylabel(f"Cumulative mass\nloss since {reference_date} (Gt)")
axs[0].set_ylabel(f"Cumulative mass\nchange since {reference_date} (Gt)")
axs[0].set_xlabel("")
axs[0].set_title(f"{basin} filtered by {filtering_var}")
axs[1].set_title("")
Expand Down Expand Up @@ -862,12 +865,6 @@ def plot_obs_sims_3(
type=str,
default="grounding_line_flux",
)
parser.add_argument(
"--ensemble",
help="""Name of the ensemble. Default=RAGIS.""",
type=str,
default="RAGIS",
)
parser.add_argument(
"--fudge_factor",
help="""Observational uncertainty multiplier. Default=3""",
Expand All @@ -890,7 +887,7 @@ def plot_obs_sims_3(
"--resampling_frequency",
help="""Resampling data to resampling_frequency for importance sampling. Default is "MS".""",
type=str,
default="MS",
default="YS",
)
parser.add_argument(
"--reference_date",
Expand Down Expand Up @@ -925,7 +922,6 @@ def plot_obs_sims_3(

options, unknown = parser.parse_known_args()
basin_files = options.FILES
ensemble = options.ensemble
engine = options.engine
filter_start_year, filter_end_year = options.filter_range
fudge_factor = options.fudge_factor
Expand Down Expand Up @@ -1004,15 +1000,10 @@ def plot_obs_sims_3(
# Path(fig_dir) / Path(f"{outlier_variable}_filtering.pdf"),
# )

start = time.time()

prior_config = simulated_ds.sel(pism_config_axis=params).pism_config
prior_df = config_to_dataframe(prior_config, ensemble="Prior")

end = time.time()
time_elapsed = end - start
print(f"Time elapsed {time_elapsed:.0f}s")

outliers_config = filter_config(outliers_ds, params)
outliers_df = config_to_dataframe(outliers_config, ensemble="Outliers")

Expand Down Expand Up @@ -1110,20 +1101,28 @@ def plot_obs_sims_3(
total=len(observed_mankoff_basins_resampled_ds.basin),
)
) as progress_bar:
result = Parallel(n_jobs=options.n_jobs)(
delayed(plot_obs_sims)(
observed_mankoff_basins_resampled_ds.sel(basin=basin),
sim_prior.sel(basin=basin),
sim_posterior.sel(basin=basin),
config=ragis_config,
filtering_var=obs_mean_var,
filter_range=[filter_start_year, filter_end_year],
fig_dir=fig_dir,
obs_alpha=obs_alpha,
sim_alpha=sim_alpha,
)
for basin in observed_mankoff_basins_resampled_ds.basin
)

with ThreadPoolExecutor(max_workers=options.n_jobs) as executor:
futures = []
for basin in observed_mankoff_basins_resampled_ds.basin:
futures.append(executor.submit(plot_obs_sims,
observed_mankoff_basins_resampled_ds.sel(basin=basin),
sim_prior.sel(basin=basin),
sim_posterior.sel(basin=basin),
config=ragis_config,
filtering_var=obs_mean_var,
filter_range=[filter_start_year, filter_end_year],
fig_dir=fig_dir,
obs_alpha=obs_alpha,
sim_alpha=sim_alpha,
))
for future in as_completed(futures):
try:
future.result()
except Exception as e:
print(f"An error occurred: {e}")



prior_posterior = pd.concat(prior_posterior_list).reset_index()
prior_posterior = prior_posterior.apply(prp.convert_column_to_numeric)
Expand Down Expand Up @@ -1203,9 +1202,16 @@ def plot_obs_sims_3(
"calving.rate_scaling.file"
].map(calving_dict)

retreat_dict = {
v: k for k, v in enumerate(ensemble_df["geometry.front_retreat.prescribed.file"].unique())
}

ensemble_df["geometry.front_retreat.prescribed.file"] = ensemble_df[
"geometry.front_retreat.prescribed.file"
].map(retreat_dict)
to_analyze = simulated_ds.sel(time=slice("1980-01-01", "2020-01-01"))
all_delta_indices = run_delta_analysis(
to_analyze, ensemble_df, list(flux_vars.values())[:2], notebook=notebook
to_analyze, ensemble_df, list(flux_vars.values())[1:2], notebook=notebook
)

# Extract the prefix from each coordinate value
Expand Down
1 change: 1 addition & 0 deletions pism_ragis/data/ragis_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"atmosphere" = "Climate"
"ocean" = "Ocean"
"calving" = "Calving"
"geometry" = "Calving"
"frontal_melt" = "Frontal Melt"
"basal_resistance" = "Flow"
"basal_yield_stress" = "Flow"
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ joblib
jupyter
matplotlib
netCDF4<1.7
numpy
numpy>2.0.0
openpyxl
pandas
pint_xarray
Expand Down

0 comments on commit 026c4cf

Please sign in to comment.