Skip to content

Commit

Permalink
Adding default_executor
Browse files Browse the repository at this point in the history
Signed-off-by: Marc Romeyn <mromeijn@nvidia.com>
  • Loading branch information
marcromeyn committed Aug 23, 2024
1 parent 2fbaab6 commit 3a0adfe
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
21 changes: 10 additions & 11 deletions examples/entrypoint/task_with_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,6 @@ def my_optimizer(
)


def defaults() -> run.Partial["train_model"]:
return run.Partial(
train_model,
model=my_model(),
optimizer=my_optimizer(),
epochs=40,
batch_size=1024,
)


def train_model(
model: Model,
optimizer: Optimizer,
Expand Down Expand Up @@ -89,5 +79,14 @@ def custom_defaults() -> run.Partial[train_model]:
)


@run.autoconvert
def local_executor() -> run.Executor:
return run.LocalExecutor()


if __name__ == "__main__":
run.cli.main(train_model, default_factory=custom_defaults)
run.cli.main(
train_model,
default_factory=custom_defaults,
default_executor=local_executor(),
)
31 changes: 28 additions & 3 deletions src/nemo_run/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,19 @@ def wrapper(f: F) -> F:
return wrapper(fn)


def main(fn: F, default_factory: Optional[Callable] = None, **kwargs):
def main(
fn: F,
default_factory: Optional[Callable] = None,
default_executor: Optional[Config[Executor]] = None,
**kwargs,
):
"""
Execute the main CLI entrypoint for the given function.
Args:
fn (F): The function to be executed as a CLI entrypoint.
default_factory (Optional[Callable]): A custom default factory to use for this execution.
default_factory (Optional[Callable]): A custom default factory to use in the context of this main function.
default_executor (Optional[Config[Executor]]): A custom default executor to use in the context of this main function.
**kwargs: Additional keyword arguments to pass to the entrypoint decorator.
Example:
Expand All @@ -218,8 +224,14 @@ def my_cli_function():
if default_factory:
fn.cli_entrypoint.default_factory = default_factory

_original_default_executor = fn.cli_entrypoint.default_executor
if default_executor:
fn.cli_entrypoint.default_executor = default_executor

fn.cli_entrypoint.main()

fn.cli_entrypoint.default_factory = _original_default_factory
fn.cli_entrypoint.default_executor = _original_default_executor


@overload
Expand Down Expand Up @@ -681,6 +693,7 @@ def cli_command(
name: str,
fn: Callable,
default_factory: Optional[Callable] = None,
default_executor: Optional[Executor] = None,
type: Literal["task", "experiment"] = "task",
command_kwargs: Dict[str, Any] = {},
):
Expand Down Expand Up @@ -738,6 +751,10 @@ def command(
require_confirmation=require_confirmation,
tail_logs=tail_logs,
)

if default_executor:
self.executor = default_executor

try:
_load_entrypoints()
_load_workspace()
Expand Down Expand Up @@ -932,7 +949,12 @@ def parse_args(self, args: List[str]):
if executor_name:
self.executor = self.parse_executor(executor_name, *executor_args)
else:
self.executor = None
if hasattr(self, "executor"):
if isinstance(self.executor, Config):
self.executor = fdl.build(self.executor)
parse_cli_args(self.executor, args, self.executor)
else:
self.executor = None
if plugin_name:
plugins = self.parse_plugin(plugin_name, *plugin_args)
if not isinstance(plugins, list):
Expand Down Expand Up @@ -1040,6 +1062,7 @@ def __init__(
fn: Callable[Params, ReturnType],
namespace: str,
default_factory: Optional[Callable] = None,
default_executor: Optional[Config[Executor]] = None,
env=None,
name=None,
help_str=None,
Expand Down Expand Up @@ -1075,6 +1098,7 @@ def __init__(
self.require_confirmation = require_confirmation
self.type = type
self.default_factory = default_factory
self.default_executor = default_executor

def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> ReturnType:
return self.fn(*args, **kwargs)
Expand Down Expand Up @@ -1139,6 +1163,7 @@ class CLITaskCommand(EntrypointCommand):
self.fn,
type=self.type,
default_factory=self.default_factory,
default_executor=self.default_executor,
command_kwargs=dict(
help=colored_help,
cls=CLITaskCommand,
Expand Down

0 comments on commit 3a0adfe

Please sign in to comment.