Skip to content

Commit

Permalink
Fix class property and default sets
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Nov 30, 2023
1 parent a8f06a2 commit 69aa3a5
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 33 deletions.
53 changes: 27 additions & 26 deletions e3sm_diags/parameter/core_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,34 @@

logger = custom_logger(__name__)

# FIXME: There is probably a better way of defining default sets because most of
# this is repeated in SETS_TO_PARAMETERS and SETS_TO_PARSERS.
# Also integration tests will break if "mp_partition" is included because
# we did not take it into account yet.
DEFAULT_SETS = [
"zonal_mean_xy",
"zonal_mean_2d",
"zonal_mean_2d_stratosphere",
"meridional_mean_2d",
"lat_lon",
"polar",
"area_mean_time_series",
"cosp_histogram",
"enso_diags",
"qbo",
"streamflow",
"diurnal_cycle",
"arm_diags",
"tc_analysis",
"annual_cycle_zonal_mean",
"lat_lon_land",
"lat_lon_river",
"aerosol_aeronet",
"aerosol_budget",
]

class CoreParameter:
# FIXME: mp_partition was not originally included as a default set
# in `CoreParameter.sets`. Including it will break an integration
# test, so it needs to be removed.
DEFAULT_SETS = [
"zonal_mean_xy",
"zonal_mean_2d",
"zonal_mean_2d_stratosphere",
"meridional_mean_2d",
"lat_lon",
"polar",
"area_mean_time_series",
"cosp_histogram",
"enso_diags",
"qbo",
"streamflow",
"diurnal_cycle",
"arm_diags",
"tc_analysis",
"annual_cycle_zonal_mean",
"lat_lon_land",
"lat_lon_river",
"aerosol_aeronet",
"aerosol_budget",
]

class CoreParameter:
def __init__(self):
# File I/O
# ------------------------
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(self):
self.run_type: str = "model_vs_obs"

# A list of the sets to be run, by default all sets.
self.sets: List[str] = CoreParameter.DEFAULT_SETS
self.sets: List[str] = DEFAULT_SETS

# The current set that is being ran when looping over sets in
# `e3sm_diags_driver.run_diag()`.
Expand Down
16 changes: 10 additions & 6 deletions e3sm_diags/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from e3sm_diags.e3sm_diags_driver import get_default_diags_path, main
from e3sm_diags.logger import custom_logger, move_log_to_prov_dir
from e3sm_diags.parameter import SET_TO_PARAMETERS
from e3sm_diags.parameter.core_parameter import CoreParameter
from e3sm_diags.parameter.core_parameter import DEFAULT_SETS, CoreParameter
from e3sm_diags.parser.core_parser import CoreParser

logger = custom_logger(__name__)
Expand All @@ -21,9 +21,6 @@ class Run:

def __init__(self):
self.parser = CoreParser()
args = self.parser.view_args()

self.has_cfg_params = len(args.other_parameters) > 0

# The list of sets to run using parameter objects.
self.sets_to_run = []
Expand Down Expand Up @@ -134,15 +131,16 @@ def _get_cfg_parameters(

# Get parameters from user-defined .cfg file or default diags .cfg
# file.
if self.has_cfg_params:

if self.has_cfg_file_arg:
cfg_params = self._get_diags_from_cfg_file()
else:
run_type = parameters[0].run_type
cfg_params = self._get_default_diags_from_cfg_file(run_type)

# Loop over the sets to run and get the related parameters.
if len(self.sets_to_run) == 0:
self.sets_to_run = CoreParameter.DEFAULT_SETS
self.sets_to_run = DEFAULT_SETS

for set_name in self.sets_to_run:
# For each of the set_names, get the corresponding parameter.
Expand Down Expand Up @@ -177,6 +175,12 @@ def _get_cfg_parameters(

return run_params

@property
def has_cfg_file_arg(self):
args = self.parser.view_args()

self.has_cfg_params = len(args.other_parameters) > 0

def _get_diags_from_cfg_file(self) -> Union[List, List[CoreParameter]]:
"""
Get parameters defined by the cfg file passed to -d/--diags (if set).
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_all_sets_image_diffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from tests.integration.config import TEST_IMAGES_PATH, TEST_ROOT_PATH
from tests.integration.utils import _get_test_params

CFG_PATH = f"{TEST_ROOT_PATH}/all_sets.cfg"
CFG_PATH = os.path.join(TEST_ROOT_PATH, "all_sets.cfg")


logger = custom_logger(__name__)

Expand Down

0 comments on commit 69aa3a5

Please sign in to comment.