diff --git a/examples/entrypoint/task_with_defaults.py b/examples/entrypoint/task_with_defaults.py index 5773bd2..0826ea7 100644 --- a/examples/entrypoint/task_with_defaults.py +++ b/examples/entrypoint/task_with_defaults.py @@ -41,16 +41,6 @@ def my_optimizer( ) -def defaults() -> run.Partial["train_model"]: - return run.Partial( - train_model, - model=my_model(), - optimizer=my_optimizer(), - epochs=40, - batch_size=1024, - ) - - def train_model( model: Model, optimizer: Optimizer, @@ -89,5 +79,14 @@ def custom_defaults() -> run.Partial[train_model]: ) +@run.autoconvert +def local_executor() -> run.Executor: + return run.LocalExecutor() + + if __name__ == "__main__": - run.cli.main(train_model, default_factory=custom_defaults) + run.cli.main( + train_model, + default_factory=custom_defaults, + default_executor=local_executor(), + ) diff --git a/src/nemo_run/cli/api.py b/src/nemo_run/cli/api.py index cc6d077..cec5875 100644 --- a/src/nemo_run/cli/api.py +++ b/src/nemo_run/cli/api.py @@ -192,13 +192,19 @@ def wrapper(f: F) -> F: return wrapper(fn) -def main(fn: F, default_factory: Optional[Callable] = None, **kwargs): +def main( + fn: F, + default_factory: Optional[Callable] = None, + default_executor: Optional[Config[Executor]] = None, + **kwargs, +): """ Execute the main CLI entrypoint for the given function. Args: fn (F): The function to be executed as a CLI entrypoint. - default_factory (Optional[Callable]): A custom default factory to use for this execution. + default_factory (Optional[Callable]): A custom default factory to use in the context of this main function. + default_executor (Optional[Config[Executor]]): A custom default executor to use in the context of this main function. **kwargs: Additional keyword arguments to pass to the entrypoint decorator. Example: @@ -218,8 +224,14 @@ def my_cli_function(): if default_factory: fn.cli_entrypoint.default_factory = default_factory + _original_default_executor = fn.cli_entrypoint.default_executor + if default_executor: + fn.cli_entrypoint.default_executor = default_executor + fn.cli_entrypoint.main() + fn.cli_entrypoint.default_factory = _original_default_factory + fn.cli_entrypoint.default_executor = _original_default_executor @overload @@ -681,6 +693,7 @@ def cli_command( name: str, fn: Callable, default_factory: Optional[Callable] = None, + default_executor: Optional[Executor] = None, type: Literal["task", "experiment"] = "task", command_kwargs: Dict[str, Any] = {}, ): @@ -738,6 +751,10 @@ def command( require_confirmation=require_confirmation, tail_logs=tail_logs, ) + + if default_executor: + self.executor = default_executor + try: _load_entrypoints() _load_workspace() @@ -932,7 +949,12 @@ def parse_args(self, args: List[str]): if executor_name: self.executor = self.parse_executor(executor_name, *executor_args) else: - self.executor = None + if hasattr(self, "executor"): + if isinstance(self.executor, Config): + self.executor = fdl.build(self.executor) + parse_cli_args(self.executor, args, self.executor) + else: + self.executor = None if plugin_name: plugins = self.parse_plugin(plugin_name, *plugin_args) if not isinstance(plugins, list): @@ -1040,6 +1062,7 @@ def __init__( fn: Callable[Params, ReturnType], namespace: str, default_factory: Optional[Callable] = None, + default_executor: Optional[Config[Executor]] = None, env=None, name=None, help_str=None, @@ -1075,6 +1098,7 @@ def __init__( self.require_confirmation = require_confirmation self.type = type self.default_factory = default_factory + self.default_executor = default_executor def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> ReturnType: return self.fn(*args, **kwargs) @@ -1139,6 +1163,7 @@ class CLITaskCommand(EntrypointCommand): self.fn, type=self.type, default_factory=self.default_factory, + default_executor=self.default_executor, command_kwargs=dict( help=colored_help, cls=CLITaskCommand,