Skip to content

Commit

Permalink
Merge pull request #22 from AllenNeuralDynamics/feat-add-json-cli-ent…
Browse files Browse the repository at this point in the history
…ry-point

Allow json instances to be passed to constructor and cli
  • Loading branch information
bruno-f-cruz authored Dec 14, 2024
2 parents b6d543c + d68c718 commit 799d195
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
extra_args: Optional[str] = None,
delete_src: bool = False,
overwrite: bool = False,
force_dir: bool = True
force_dir: bool = True,
):
self.source = source
self.destination = destination
Expand Down
85 changes: 79 additions & 6 deletions src/aind_behavior_experiment_launcher/launcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AindBehaviorSessionModel,
AindBehaviorTaskLogicModel,
)
from aind_behavior_services.utils import model_from_json_file

from aind_behavior_experiment_launcher import logging_helper, ui_helper
from aind_behavior_experiment_launcher.services import ServicesFactoryManager
Expand All @@ -37,6 +38,7 @@ class BaseLauncher(Generic[TRig, TSession, TTaskLogic]):

def __init__(
self,
*,
rig_schema_model: Type[TRig],
session_schema_model: Type[TSession],
task_logic_schema_model: Type[TTaskLogic],
Expand All @@ -51,6 +53,9 @@ def __init__(
services: Optional[ServicesFactoryManager] = None,
validate_init: bool = True,
attached_logger: Optional[logging.Logger] = None,
rig_schema_path: Optional[os.PathLike] = None,
task_logic_schema: Optional[os.PathLike] = None,
subject: Optional[str] = None,
) -> None:
self.temp_dir = self.abspath(temp_dir) / secrets.token_hex(nbytes=16)
self.temp_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -64,7 +69,7 @@ def __init__(
_logger.setLevel(logging.DEBUG)

self._ui_helper = ui_helper.UIHelper()
self._cli_args = self._cli_wrapper()
self._cli_args: _CliArgs = self._cli_wrapper()
self._bind_launcher_services(services)

repository_dir = (
Expand All @@ -88,6 +93,8 @@ def __init__(
self._rig_schema: Optional[TRig] = None
self._session_schema: Optional[TSession] = None
self._task_logic_schema: Optional[TTaskLogic] = None
self._solve_schema_instances(rig_path_path=rig_schema_path, task_logic_path=task_logic_schema)
self._subject: Optional[str] = self._cli_args.subject if self._cli_args.subject else subject

# Directories
self.data_dir = Path(self._cli_args.data_dir) if self._cli_args.data_dir is not None else self.abspath(data_dir)
Expand Down Expand Up @@ -182,8 +189,10 @@ def _ui_prompt(self) -> Self:
self._print_diagnosis()

self._session_schema = self._prompt_session_input()
self._task_logic_schema = self._prompt_task_logic_input()
self._rig_schema = self._prompt_rig_input()
if self._task_logic_schema is None:
self._task_logic_schema = self._prompt_task_logic_input()
if self._rig_schema is None:
self._rig_schema = self._prompt_rig_input()
return self

def _prompt_session_input(self) -> TSession:
Expand Down Expand Up @@ -332,13 +341,25 @@ def _get_default_arg_parser() -> argparse.ArgumentParser:
default=False,
)

# These should default to None
parser.add_argument("--subject", help="Specifies the name of the subject")
parser.add_argument("--task-logic-path", help="Specifies the path to a json file containing task logic")
parser.add_argument("--rig-path", help="Specifies the path to a json file containing rig configuration")

# Catch all additional arguments
# Syntax is a bit clunky, but it works
# e.g. "python script.py -- --arg1 --arg"
# This will capture "--arg1 --arg2" in the "extras" list
parser.add_argument(
"extras", nargs=argparse.REMAINDER, help="Capture all remaining arguments after -- separator"
)
return parser

@classmethod
def _cli_wrapper(cls) -> argparse.Namespace:
def _cli_wrapper(cls) -> _CliArgs:
parser = cls._get_default_arg_parser()
args, _ = parser.parse_known_args()
return args
args = vars(parser.parse_args())
return _CliArgs(**args)

def _copy_tmp_directory(self, dst: os.PathLike) -> None:
dst = Path(dst) / ".launcher"
Expand All @@ -351,3 +372,55 @@ def _bind_launcher_services(
if self._services_factory_manager is not None:
self._services_factory_manager.register_launcher(self)
return self._services_factory_manager

def _solve_schema_instances(
self, rig_path_path: Optional[os.PathLike] = None, task_logic_path: Optional[os.PathLike] = None
) -> None:
rig_path_path = self._cli_args.rig_path if self._cli_args.rig_path is not None else rig_path_path
task_logic_path = (
self._cli_args.task_logic_path if self._cli_args.task_logic_path is not None else task_logic_path
)
if rig_path_path is not None:
logging.info("Loading rig schema from %s", self._cli_args.rig_path)
self._rig_schema = model_from_json_file(rig_path_path, self.rig_schema_model)
if task_logic_path is not None:
logging.info("Loading task logic schema from %s", self._cli_args.task_logic_path)
self._task_logic_schema = model_from_json_file(task_logic_path, self.task_logic_schema_model)


@pydantic.dataclasses.dataclass
class _CliArgs:
data_dir: Optional[os.PathLike] = None
repository_dir: Optional[os.PathLike] = None
config_library_dir: Optional[os.PathLike] = None
create_directories: bool = False
debug: bool = False
allow_dirty: bool = False
skip_hardware_validation: bool = False
subject: Optional[str] = None
task_logic_path: Optional[os.PathLike] = None
rig_path: Optional[os.PathLike] = None
extras: dict[str, str] = pydantic.Field(default_factory=dict)

@pydantic.field_validator("extras", mode="before")
@classmethod
def _validate_extras(cls, v):
if isinstance(v, list):
v = cls._parse_extra_args(v)
return v

@staticmethod
def _parse_extra_args(args: list[str]) -> dict[str, str]:
extra_kwargs: dict[str, str] = {}
if len(args) == 0:
return extra_kwargs
_ = args.pop(0) # remove the "--" separator
for arg in args:
if arg.startswith("--"):
key_value = arg.lstrip("--").split("=", 1)
if len(key_value) == 2:
key, value = key_value
extra_kwargs[key] = value
else:
logger.error("Skipping invalid argument format: %s", arg)
return extra_kwargs
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@ def _post_init(self, validate: bool = True) -> None:

@override
def _prompt_session_input(self, directory: Optional[str] = None) -> TSession:
_local_config_directory = (
Path(os.path.join(self.config_library_dir, directory)) if directory is not None else self._subject_dir
)
available_batches = self._get_available_batches(_local_config_directory)
experimenter = self._ui_helper.prompt_experimenter(strict=True)
subject_list = self._get_subject_list(available_batches)
subject = self._ui_helper.choose_subject(subject_list)
self._subject_db_data = subject_list.get_subject(subject)
if self._subject is not None:
logging.info("Subject provided via CLABE: %s", self._cli_args.subject)
subject = self._subject
else:
_local_config_directory = (
Path(os.path.join(self.config_library_dir, directory)) if directory is not None else self._subject_dir
)
available_batches = self._get_available_batches(_local_config_directory)
subject_list = self._get_subject_list(available_batches)
subject = self._ui_helper.choose_subject(subject_list)
self._subject = subject
self._subject_db_data = subject_list.get_subject(subject)

notes = self._ui_helper.prompt_get_notes()

return self.session_schema_model(
Expand Down Expand Up @@ -142,7 +148,8 @@ def _prompt_task_logic_input(
Path(os.path.join(self.config_library_dir, directory)) if directory is not None else self._task_logic_dir
)
hint_input: Optional[SubjectEntry] = self._subject_db_data
task_logic: Optional[TTaskLogic] = None
task_logic: Optional[TTaskLogic] = self._task_logic_schema
# If the task logic is already set (e.g. from CLI), skip the prompt
while task_logic is None:
try:
if hint_input is None:
Expand Down
51 changes: 37 additions & 14 deletions tests/test_base_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from aind_behavior_services import AindBehaviorRigModel, AindBehaviorSessionModel, AindBehaviorTaskLogicModel

from aind_behavior_experiment_launcher.launcher import BaseLauncher
from aind_behavior_experiment_launcher.launcher import BaseLauncher, _CliArgs
from aind_behavior_experiment_launcher.services import ServicesFactoryManager


Expand Down Expand Up @@ -88,25 +88,22 @@ def test_cli_wrapper(self, mock_parse_known_args):
[],
)
args = BaseLauncher._cli_wrapper()
self.assertEqual(args.data_dir, "/tmp/fake/data/dir")
self.assertEqual(args.data_dir, Path("/tmp/fake/data/dir"))
self.assertFalse(args.create_directories)
self.assertFalse(args.debug)
self.assertFalse(args.allow_dirty)
self.assertFalse(args.skip_hardware_validation)

@patch("argparse.ArgumentParser.parse_known_args")
@patch("argparse.ArgumentParser.parse_args")
def test_cli_args_integration(self, mock_parse_known_args):
mock_parse_known_args.return_value = (
argparse.Namespace(
data_dir="/tmp/fake/data/dir",
repository_dir=None,
config_library_dir=None,
create_directories=True,
debug=True,
allow_dirty=True,
skip_hardware_validation=True,
),
[],
mock_parse_known_args.return_value = argparse.Namespace(
data_dir="/tmp/fake/data/dir",
repository_dir=None,
config_library_dir=None,
create_directories=True,
debug=True,
allow_dirty=True,
skip_hardware_validation=True,
)
launcher = BaseLauncher(
rig_schema_model=self.rig_schema_model,
Expand All @@ -122,5 +119,31 @@ def test_cli_args_integration(self, mock_parse_known_args):
self.assertTrue(launcher._cli_args.skip_hardware_validation)


class TestCliArgs(unittest.TestCase):
def test_parse_extra_args_valid(self):
args = ["--", "--key1=value1", "--key2=value2"]
expected_output = {"key1": "value1", "key2": "value2"}
self.assertEqual(_CliArgs._parse_extra_args(args), expected_output)

def test_parse_extra_args_invalid_format_no_equals(self):
args = ["--", "--key1=value1", "--invalid_arg a"]
result = _CliArgs._parse_extra_args(args)
self.assertEqual(result, {"key1": "value1"})

def test_parse_extra_args_invalid_format_is_a_flag(self):
args = ["--", "--key1=value1", "--invalid_flag"]
result = _CliArgs._parse_extra_args(args)
self.assertEqual(result, {"key1": "value1"})

def test_validate_extras_with_list(self):
extras = ["--", "--key1=value1", "--key2=value2"]
expected_output = {"key1": "value1", "key2": "value2"}
self.assertEqual(_CliArgs._validate_extras(extras), expected_output)

def test_validate_extras_with_dict(self):
extras = {"key1": "value1", "key2": "value2"}
self.assertEqual(_CliArgs._validate_extras(extras), extras)


if __name__ == "__main__":
unittest.main()
48 changes: 48 additions & 0 deletions tests/test_behavior_launcher.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import unittest
from pathlib import Path
from unittest.mock import MagicMock, create_autospec, patch
Expand Down Expand Up @@ -149,5 +150,52 @@ def test_validate_service_type_invalid(self):
self.factory_manager._validate_service_type(service, str)


class TestBehaviorLauncherSaveTempModel(unittest.TestCase):
def setUp(self):
self.services_factory_manager = create_autospec(BehaviorServicesFactoryManager)
self.launcher = BehaviorLauncher(
rig_schema_model=MagicMock(),
task_logic_schema_model=MagicMock(),
session_schema_model=MagicMock(),
data_dir="/path/to/data",
config_library_dir="/path/to/config",
temp_dir="/path/to/temp",
repository_dir=None,
allow_dirty=False,
skip_hardware_validation=False,
debug_mode=False,
group_by_subject_log=False,
services=self.services_factory_manager,
validate_init=False,
attached_logger=None,
)

@patch("aind_behavior_experiment_launcher.launcher.behavior_launcher.os.makedirs")
def test_save_temp_model_creates_directory(self, mock_makedirs):
model = MagicMock()
model.__class__.__name__ = "TestModel"
model.model_dump_json.return_value = '{"key": "value"}'
self.launcher._save_temp_model(model, "/path/to/temp")
mock_makedirs.assert_called_once_with(Path("/path/to/temp"), exist_ok=True)

@patch("aind_behavior_experiment_launcher.launcher.behavior_launcher.os.makedirs")
def test_save_temp_model_default_directory(self, mock_makedirs):
model = MagicMock()
model.__class__.__name__ = "TestModel"
model.model_dump_json.return_value = '{"key": "value"}'
path = self.launcher._save_temp_model(model, None)
self.assertTrue(path.endswith("TestModel.json"))

@patch("aind_behavior_experiment_launcher.launcher.behavior_launcher.os.makedirs")
@patch("builtins.open", new_callable=unittest.mock.mock_open)
def test_save_temp_model_returns_correct_path(self, mock_open, mock_makedirs):
model = MagicMock()
model.__class__.__name__ = "TestModel"
model.model_dump_json.return_value = '{"key": "value"}'
path = self.launcher._save_temp_model(model, Path("/path/to/temp"))
expected_path = os.path.join(Path("/path/to/temp"), "TestModel.json")
self.assertEqual(path, expected_path)


if __name__ == "__main__":
unittest.main()

0 comments on commit 799d195

Please sign in to comment.