Skip to content

Commit

Permalink
Update complete_run_script.py to save netCDF
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Dec 10, 2024
1 parent 9136ea7 commit 5fa475c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
17 changes: 16 additions & 1 deletion tests/complete_run_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@

case = "extendedOutput.v3.LR.historical_0101"
short_name = "v3.LR.historical_0101"
results_dir = "/global/cfs/cdirs/e3sm/www/chengzhu/tutorial2024/e3sm_diags_extended_int"

# TODO: Update `MAIN_DIR` as needed.
MAIN_DIR = "24-12-09-main"
results_dir = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/{MAIN_DIR}/"

test_climo = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/clim/15yr"
test_ts = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/ts/monthly/15yr"
Expand Down Expand Up @@ -78,6 +81,7 @@
param.output_format_subplot = []
param.multiprocessing = True
param.num_workers = 24
param.save_netcdf = True
param.seasons = ["ANN"]
params = [param]

Expand All @@ -93,6 +97,7 @@
enso_param.ref_start_yr = start_yr
enso_param.ref_end_yr = end_yr

enso_param.save_netcdf = True
params.append(enso_param)

trop_param = TropicalSubseasonalParameter()
Expand All @@ -106,7 +111,9 @@
trop_param.ref_start_yr = "2001"
trop_param.ref_end_yr = "2010"

trop_param.save_netcdf = True
params.append(trop_param)

qbo_param = QboParameter()
qbo_param.test_data_path = test_ts
# qbo_param.test_name = short_name
Expand All @@ -118,7 +125,9 @@
# Obs
qbo_param.reference_data_path = ref_ts

qbo_param.save_netcdf = True
params.append(qbo_param)

dc_param = DiurnalCycleParameter()
dc_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/180x360_aave/clim_diurnal_8xdaily/"
# dc_param.short_test_name = short_name
Expand All @@ -128,7 +137,9 @@
# Obs
dc_param.reference_data_path = ref_climo

dc_param.save_netcdf = True
params.append(dc_param)

streamflow_param = StreamflowParameter()
streamflow_param.reference_data_path = ref_ts
streamflow_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/rof/native/ts/monthly/15yr/"
Expand All @@ -143,7 +154,9 @@
)
streamflow_param.ref_end_yr = "1995"

streamflow_param.save_netcdf = True
params.append(streamflow_param)

tc_param = TCAnalysisParameter()
tc_param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/tutorial2024/v3.LR.historical_0101/post/atm/tc-analysis_2000_2014"
# tc_param.short_test_name = short_name
Expand All @@ -159,6 +172,7 @@
tc_param.ref_start_yr = "1979"
tc_param.ref_end_yr = "2018"

tc_param.save_netcdf = True
params.append(tc_param)

arm_param = ARMDiagsParameter()
Expand All @@ -177,6 +191,7 @@
arm_param.ref_start_yr = "0001"
arm_param.ref_end_yr = "0001"

arm_param.save_netcdf = True
params.append(arm_param)

# Run
Expand Down
50 changes: 41 additions & 9 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import glob
import subprocess
from datetime import datetime
from typing import List, TypedDict

import numpy as np
import pytest
Expand All @@ -8,18 +11,37 @@
from e3sm_diags.logger import custom_logger
from tests.complete_run_script import params, runner


logger = custom_logger(__name__)


DEV_DIR = "843-migration-phase3-model-vs-obs"
DEV_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{DEV_DIR}/"
def _get_git_branch_name() -> str:
"""Get the current git branch name."""
try:
branch_name = (
subprocess.check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
stderr=subprocess.DEVNULL,
)
.strip()
.decode("utf-8")
)
except subprocess.CalledProcessError:
branch_name = "unknown"

return branch_name


BRANCH_NAME = _get_git_branch_name()
DEV_TIMESTAMP = datetime.now().strftime("%y-%m-%d")
DEV_DIR = f"{DEV_TIMESTAMP}-{BRANCH_NAME}"
DEV_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/{DEV_DIR}"

DEV_GLOB = sorted(glob.glob(DEV_PATH + "**/**/*.nc"))
DEV_NUM_FILES = len(DEV_GLOB)

MAIN_DIR = "main"
MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{MAIN_DIR}/"
# TODO: Update `MAIN_DIR` as needed.
MAIN_DIR = "24-12-09-main"
MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/{MAIN_DIR}/"
MAIN_GLOB = sorted(glob.glob(MAIN_PATH + "**/**/*.nc"))
MAIN_NUM_FILES = len(MAIN_GLOB)

Expand Down Expand Up @@ -54,7 +76,6 @@ def run_diags_and_get_results_dir() -> str:
class TestRegression:
@pytest.fixture(autouse=True)
def setup(self, run_diags_and_get_results_dir):
# TODO: We need to store `main` results on a data container
self.results_dir = run_diags_and_get_results_dir

def test_check_if_files_found(self):
Expand Down Expand Up @@ -90,8 +111,19 @@ def test_get_relative_diffs(self):
assert len(results["key_errors"]) == 0


def _get_relative_diffs():
results = {
class DiffResults(TypedDict):
"""Type annotation for the results of the relative differences comparison."""

missing_files: List[str]
missing_vars: List[str]
matching_files: List[str]
mismatch_errors: List[str]
not_equal_errors: List[str]
key_errors: List[str]


def _get_relative_diffs() -> DiffResults:
results: DiffResults = {
"missing_files": [],
"missing_vars": [],
"matching_files": [],
Expand Down Expand Up @@ -192,7 +224,7 @@ def _get_var_data(ds: xr.Dataset, var_key: str) -> np.ndarray | None:
except KeyError:
var_keys = DERIVED_VARIABLES[var_key.upper()].keys()

var_keys = [var_key] + list(sum(var_keys, ()))
var_keys = [var_key] + list(sum(var_keys, ())) # type: ignore

for key in var_keys:
if key in ds.data_vars.keys():
Expand Down

0 comments on commit 5fa475c

Please sign in to comment.