diff --git a/src/access_nri_intake/cli.py b/src/access_nri_intake/cli.py index 04176ac7..fbb67427 100644 --- a/src/access_nri_intake/cli.py +++ b/src/access_nri_intake/cli.py @@ -632,11 +632,33 @@ def use_esm_datastore(argv: Sequence[str] | None = None) -> int: ), ) + parser.add_argument( + "--datastore-name", + type=str, + help=( + "Name of the datastore to use. If not provided, this will default to" + " 'experiment_datastore'." + ), + default="experiment_datastore", + ) + + parser.add_argument( + "--description", + type=str, + help=( + "Description of the datastore. If not provided, a default description will be used:" + " 'esm_datastore for the model output in {--expt-dir}'" + ), + default=None, + ) + args = parser.parse_args(argv) builder = args.builder experiment_dir = Path(args.expt_dir) catalog_dir = Path(args.cat_dir) if args.cat_dir else experiment_dir builder_kwargs = args.builder_kwargs or {} + datastore_name = args.datastore_name + description = args.description try: builder = getattr(builders, builder) @@ -664,6 +686,8 @@ def use_esm_datastore(argv: Sequence[str] | None = None) -> int: builder, catalog_dir, builder_kwargs=builder_kwargs, + datastore_name=datastore_name, + description=description, open_ds=False, ) diff --git a/src/access_nri_intake/experiment/utils.py b/src/access_nri_intake/experiment/utils.py index d23103af..4fe92b4c 100644 --- a/src/access_nri_intake/experiment/utils.py +++ b/src/access_nri_intake/experiment/utils.py @@ -1,3 +1,4 @@ +import ast import json import re import warnings @@ -305,10 +306,12 @@ def parse_kwargs(kwargs: str) -> dict: for item in kwargs.split(): kw, arg = item.split("=") if kw == "ensemble": - if arg.lower() not in ["true", "false"]: + try: + ret[kw] = ast.literal_eval(arg.capitalize()) + if not isinstance(ret[kw], bool): + raise ValueError + except (ValueError, SyntaxError): raise TypeError(f"Ensemble kwarg must be a boolean, not {arg}.") - else: - ret[kw] = True if arg.lower() == "true" else False return ret diff --git a/tests/test_cli.py b/tests/test_cli.py index f16e454a..a76a7f98 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,13 +4,14 @@ import glob import os import shutil -from pathlib import Path +from pathlib import Path, PosixPath from unittest import mock import intake import pytest import yaml +import access_nri_intake from access_nri_intake.catalog.manager import CatalogManager from access_nri_intake.cli import ( MetadataCheckError, @@ -1152,18 +1153,71 @@ def test_use_esm_datastore_nonexistent_dirs(expt_dir, cat_dir): @mock.patch("access_nri_intake.cli.use_datastore") -def test_use_esm_datastore_valid(use_datastore): +@pytest.mark.parametrize( + "argv, expected_call_args, expected_call_kwargs", + [ + ( + ["--builder", "AccessOm2Builder"], + ( + PosixPath("."), + access_nri_intake.source.builders.AccessOm2Builder, + PosixPath("."), + ), + { + "builder_kwargs": {}, + "datastore_name": "experiment_datastore", + "description": None, + "open_ds": False, + }, + ), + ( + ["--builder", "Mom6Builder", "--datastore-name", "VERY_BAD_NAME"], + ( + PosixPath("."), + access_nri_intake.source.builders.Mom6Builder, + PosixPath("."), + ), + { + "builder_kwargs": {}, + "datastore_name": "VERY_BAD_NAME", + "description": None, + "open_ds": False, + }, + ), + ( + [ + "--builder", + "AccessOm2Builder", + "--description", + "meaningless_description", + ], + ( + PosixPath("."), + access_nri_intake.source.builders.AccessOm2Builder, + PosixPath("."), + ), + { + "builder_kwargs": {}, + "datastore_name": "experiment_datastore", + "description": "meaningless_description", + "open_ds": False, + }, + ), + ], +) +def test_use_esm_datastore_valid( + use_datastore, argv, expected_call_args, expected_call_kwargs +): """I'm not using any args here, so we should get defaults. This should return zero. I'm going to mock the use_datastore function so it doesn't do anything, just returns none""" use_datastore.return_value = None - ret = use_esm_datastore( - [ - "--builder", - "AccessOm2Builder", - ] - ) + ret = use_esm_datastore(argv) + + args, kwargs = use_datastore.call_args + assert args == expected_call_args + assert kwargs == expected_call_kwargs assert ret == 0 diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 56f14587..a0dd9e55 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -458,6 +458,11 @@ def test_validate_args(builder: str, kwargs, fails, err_msg): True, None, ), + ( + "ensemble=1", + True, + None, + ), ( "esnmebel=True", False,