diff --git a/e3sm_diags/parameter/meridional_mean_2d_parameter.py b/e3sm_diags/parameter/meridional_mean_2d_parameter.py index 44da9ad60..20346f384 100644 --- a/e3sm_diags/parameter/meridional_mean_2d_parameter.py +++ b/e3sm_diags/parameter/meridional_mean_2d_parameter.py @@ -10,7 +10,7 @@ def __init__(self): super(MeridionalMean2dParameter, self).__init__() # Override existing attributes # ============================= - self.plevs = numpy.logspace(2.0, 3.0, num=17).tolist() + self.plevs = numpy.logspace(2.0, 3.0, num=17).tolist() # type: ignore self.plot_log_plevs = False self.plot_plevs = False # Granulating plevs causes duplicate plots in this case. diff --git a/e3sm_diags/parameter/zonal_mean_2d_parameter.py b/e3sm_diags/parameter/zonal_mean_2d_parameter.py index 554c35ce6..a5ba31d45 100644 --- a/e3sm_diags/parameter/zonal_mean_2d_parameter.py +++ b/e3sm_diags/parameter/zonal_mean_2d_parameter.py @@ -14,7 +14,7 @@ def __init__(self): super(ZonalMean2dParameter, self).__init__() # Override existing attributes # ============================= - self.plevs = copy.deepcopy(DEFAULT_PLEVS) + self.plevs = copy.deepcopy(DEFAULT_PLEVS) # type: ignore self.plot_log_plevs = False self.plot_plevs = False # Granulating plevs causes duplicate plots in this case. diff --git a/e3sm_diags/parameter/zonal_mean_2d_stratosphere_parameter.py b/e3sm_diags/parameter/zonal_mean_2d_stratosphere_parameter.py index bb8adb350..78694e022 100644 --- a/e3sm_diags/parameter/zonal_mean_2d_stratosphere_parameter.py +++ b/e3sm_diags/parameter/zonal_mean_2d_stratosphere_parameter.py @@ -12,5 +12,5 @@ def __init__(self): super(ZonalMean2dStratosphereParameter, self).__init__() # Override existing attributes # ============================= - self.plevs = copy.deepcopy(DEFAULT_PLEVS) + self.plevs = copy.deepcopy(DEFAULT_PLEVS) # type: ignore self.plot_log_plevs = True diff --git a/tests/complete_run_script.py b/tests/complete_run_script.py index 3e723f87f..3a6192856 100644 --- a/tests/complete_run_script.py +++ b/tests/complete_run_script.py @@ -46,7 +46,10 @@ case = "extendedOutput.v3.LR.historical_0101" short_name = "v3.LR.historical_0101" -results_dir = "/global/cfs/cdirs/e3sm/www/chengzhu/tutorial2024/e3sm_diags_extended_int" + +# TODO: Update `MAIN_DIR` as needed. +MAIN_DIR = "24-12-09-main" +results_dir = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/{MAIN_DIR}/" test_climo = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/clim/15yr" test_ts = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/ts/monthly/15yr" @@ -78,6 +81,7 @@ param.output_format_subplot = [] param.multiprocessing = True param.num_workers = 24 +param.save_netcdf = True param.seasons = ["ANN"] params = [param] @@ -93,6 +97,7 @@ enso_param.ref_start_yr = start_yr enso_param.ref_end_yr = end_yr +enso_param.save_netcdf = True params.append(enso_param) trop_param = TropicalSubseasonalParameter() @@ -106,7 +111,9 @@ trop_param.ref_start_yr = "2001" trop_param.ref_end_yr = "2010" +trop_param.save_netcdf = True params.append(trop_param) + qbo_param = QboParameter() qbo_param.test_data_path = test_ts # qbo_param.test_name = short_name @@ -118,7 +125,9 @@ # Obs qbo_param.reference_data_path = ref_ts +qbo_param.save_netcdf = True params.append(qbo_param) + dc_param = DiurnalCycleParameter() dc_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/clim_diurnal_8xdaily/" # dc_param.short_test_name = short_name @@ -128,7 +137,9 @@ # Obs dc_param.reference_data_path = ref_climo +dc_param.save_netcdf = True params.append(dc_param) + streamflow_param = StreamflowParameter() streamflow_param.reference_data_path = ref_ts streamflow_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/rof/native/ts/monthly/15yr/" @@ -143,7 +154,9 @@ ) streamflow_param.ref_end_yr = "1995" +streamflow_param.save_netcdf = True params.append(streamflow_param) + tc_param = TCAnalysisParameter() tc_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/tc-analysis_2000_2014" # tc_param.short_test_name = short_name @@ -159,6 +172,7 @@ tc_param.ref_start_yr = "1979" tc_param.ref_end_yr = "2018" +tc_param.save_netcdf = True params.append(tc_param) arm_param = ARMDiagsParameter() @@ -177,6 +191,7 @@ arm_param.ref_start_yr = "0001" arm_param.ref_end_yr = "0001" +arm_param.save_netcdf = True params.append(arm_param) # Run diff --git a/tests/test_regression.py b/tests/test_regression.py index 0f721da34..af04ded2c 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -1,4 +1,7 @@ import glob +import subprocess +from datetime import datetime +from typing import List, TypedDict import numpy as np import pytest @@ -8,18 +11,37 @@ from e3sm_diags.logger import custom_logger from tests.complete_run_script import params, runner - logger = custom_logger(__name__) -DEV_DIR = "843-migration-phase3-model-vs-obs" -DEV_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{DEV_DIR}/" +def _get_git_branch_name() -> str: + """Get the current git branch name.""" + try: + branch_name = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + stderr=subprocess.DEVNULL, + ) + .strip() + .decode("utf-8") + ) + except subprocess.CalledProcessError: + branch_name = "unknown" + + return branch_name + + +BRANCH_NAME = _get_git_branch_name() +DEV_TIMESTAMP = datetime.now().strftime("%y-%m-%d") +DEV_DIR = f"{DEV_TIMESTAMP}-{BRANCH_NAME}" +DEV_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/{DEV_DIR}" DEV_GLOB = sorted(glob.glob(DEV_PATH + "**/**/*.nc")) DEV_NUM_FILES = len(DEV_GLOB) -MAIN_DIR = "main" -MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{MAIN_DIR}/" +# TODO: Update `MAIN_DIR` as needed. +MAIN_DIR = "24-12-09-main" +MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/{MAIN_DIR}/" MAIN_GLOB = sorted(glob.glob(MAIN_PATH + "**/**/*.nc")) MAIN_NUM_FILES = len(MAIN_GLOB) @@ -54,7 +76,6 @@ def run_diags_and_get_results_dir() -> str: class TestRegression: @pytest.fixture(autouse=True) def setup(self, run_diags_and_get_results_dir): - # TODO: We need to store `main` results on a data container self.results_dir = run_diags_and_get_results_dir def test_check_if_files_found(self): @@ -90,8 +111,19 @@ def test_get_relative_diffs(self): assert len(results["key_errors"]) == 0 -def _get_relative_diffs(): - results = { +class DiffResults(TypedDict): + """Type annotation for the results of the relative differences comparison.""" + + missing_files: List[str] + missing_vars: List[str] + matching_files: List[str] + mismatch_errors: List[str] + not_equal_errors: List[str] + key_errors: List[str] + + +def _get_relative_diffs() -> DiffResults: + results: DiffResults = { "missing_files": [], "missing_vars": [], "matching_files": [], @@ -192,7 +224,7 @@ def _get_var_data(ds: xr.Dataset, var_key: str) -> np.ndarray | None: except KeyError: var_keys = DERIVED_VARIABLES[var_key.upper()].keys() - var_keys = [var_key] + list(sum(var_keys, ())) + var_keys = [var_key] + list(sum(var_keys, ())) # type: ignore for key in var_keys: if key in ds.data_vars.keys():