diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index a494e0c5..0fdfb06e 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -19,8 +19,9 @@ properties not included in any tree mapping operations. ### Added - The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`. -- `CommitFuture` and `HandlerAwaitableSignal` for signalling between -Checkpointing layers to enable async directory creation. +- `CommitFutureAwaitDirectorySignals`, `CommitFuture` and +`HandlerAwaitableSignal` for signalling between Checkpointing layers to enable +async directory creation. - User-provided custom PyTree metadata. ### Fixed diff --git a/checkpoint/orbax/checkpoint/_src/futures/future.py b/checkpoint/orbax/checkpoint/_src/futures/future.py index c8fa051f..9b3a6d71 100644 --- a/checkpoint/orbax/checkpoint/_src/futures/future.py +++ b/checkpoint/orbax/checkpoint/_src/futures/future.py @@ -19,18 +19,73 @@ from typing import Any, Callable, Coroutine, Optional, Sequence from absl import logging +import jax from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint._src.futures import synchronization from orbax.checkpoint._src.multihost import multihost from typing_extensions import Protocol -get_unique_barrier_key = ( - synchronization.HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key +HandlerAwaitableSignal = synchronization.HandlerAwaitableSignal +HandlerAwaitableSignalOperationIdGenerator = ( + synchronization.HandlerAwaitableSignalOperationIdGenerator ) +is_intialized = HandlerAwaitableSignalOperationIdGenerator.is_intialized _SIGNAL_ACTION_SUCCESS = 'signal_action_success' +def _get_unique_barrier_key( + signal: HandlerAwaitableSignal, operation_id: str +) -> str: + """Returns a unique barrier key for the signal. + + Args: + signal: The signal to generate a barrier key for. + operation_id: The operation id to use as a suffix for the barrier key. + """ + return multihost.unique_barrier_key(signal.value, suffix=operation_id) + + +def get_awaitable_signals_from_contract() -> Sequence[HandlerAwaitableSignal]: + """Gets the awaitable signals that may be sent for the current operation id.""" + client = multihost.get_jax_distributed_client() + barrier_key = _get_unique_barrier_key( + HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT, + HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id(), + ) + try: + values_str = str(client.key_value_try_get(barrier_key)) + return [HandlerAwaitableSignal(value) for value in values_str.split(',')] + except jax.errors.JaxRuntimeError: + # If the key is not found, then there are no awaitable signals yet. + return [] + + +def add_to_awaitable_signals_contract( + signals: Sequence[HandlerAwaitableSignal], +): + """Adds awaitable signals to `AWAITABLE_SIGNALS_CONTRACT` for lower checkpointing layers to wait on. + + These signals are added to the list of awaitable signals for the current + opertation id in `HandlerAwaitableSignalOperationIdGenerator`. + + Args: + signals: The signals to add to the list of awaitable signals. + """ + if not signals: + return + + current_signals = list(get_awaitable_signals_from_contract()) + current_signals.extend(signals) + keys = ','.join([current_signal.value for current_signal in current_signals]) + client = multihost.get_jax_distributed_client() + barrier_key = _get_unique_barrier_key( + HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT, + HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id(), + ) + client.key_value_set(barrier_key, keys, allow_overwrite=True) + + class Future(Protocol): """Abstracted Orbax Future class. @@ -108,8 +163,10 @@ def join(self, timeout=None): class _SignalingThread(threading.Thread): """Thread that raises an exception if it encounters an error. - Waits for signals to be received before proceeding with the target function. - Then sends signals to indicate that the target function has completed. + Waits for signals to be received for the current operation id before + proceeding with the target function. Then sends signals to indicate that the + target function has completed with + same operation id. """ _exception: Optional[Exception] = None @@ -117,9 +174,10 @@ class _SignalingThread(threading.Thread): def __init__( self, *, - send_signals: list[synchronization.HandlerAwaitableSignal], - receive_signals: list[synchronization.HandlerAwaitableSignal], + send_signals: Sequence[HandlerAwaitableSignal], + receive_signals: Sequence[HandlerAwaitableSignal], timeout_secs: int = 600, + operation_id: str | None = None, **kwargs, ): """Constructor. @@ -130,12 +188,19 @@ def __init__( receive_signals: Signals to wait for before proceeding with the target function. timeout_secs: Timeout in seconds for waiting for signals. + operation_id: The operation id to use for the barrier keys. If None, the + current operation id is used. **kwargs: Keyword arguments passed to the base class. """ super().__init__(**kwargs) self._send_signals = send_signals self._receive_signals = receive_signals self._timeout_secs = timeout_secs + # Capture the current operation id syncronously. + self._operation_id = ( + operation_id + or HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id() + ) def _wait_for_signals(self): """Waits for signals to be set.""" @@ -148,7 +213,7 @@ def _wait_for_signals(self): signal.value, self._timeout_secs, ) - barrier_key = get_unique_barrier_key(signal) + barrier_key = _get_unique_barrier_key(signal, self._operation_id) client = multihost.get_jax_distributed_client() client.blocking_key_value_get(barrier_key, self._timeout_secs * 1000) @@ -162,7 +227,7 @@ def _set_signals(self): threading.current_thread().name, signal.value, ) - barrier_key = get_unique_barrier_key(signal) + barrier_key = _get_unique_barrier_key(signal, self._operation_id) client = multihost.get_jax_distributed_client() client.key_value_set(barrier_key, _SIGNAL_ACTION_SUCCESS) @@ -194,11 +259,10 @@ def __init__( coro: Coroutine[Any, Any, None], *, name: str | None = None, - send_signals: list[synchronization.HandlerAwaitableSignal] | None = None, - receive_signals: ( - list[synchronization.HandlerAwaitableSignal] | None - ) = None, + send_signals: Sequence[HandlerAwaitableSignal] | None = None, + receive_signals: Sequence[HandlerAwaitableSignal] | None = None, timeout_secs: int = 600, + operation_id: str | None = None, ): """Constructor. @@ -208,6 +272,8 @@ def __init__( send_signals: Signals to send to indicate that the commit has completed. receive_signals: Signals to wait for before proceeding with the commit. timeout_secs: Timeout in seconds for waiting for signals. + operation_id: The operation id to use for the barrier keys. If None, the + current operation id is used. """ super().__init__() send_signals = send_signals or [] @@ -216,6 +282,7 @@ def __init__( send_signals=send_signals, receive_signals=receive_signals, timeout_secs=timeout_secs, + operation_id=operation_id, target=lambda: asyncio_utils.run_sync(coro), name=name, ) @@ -253,3 +320,53 @@ def __init__(self, coro, name: Optional[str] = None): def result(self, timeout: Optional[int] = None) -> Any: return self._t.join(timeout=timeout) + + +class CommitFutureAwaitingContractedSignals(Future): + """Represents the result of a background commit. + + May send signals to indicate that the commit has completed. Waits for all + awaitable signals in the `AWAITABLE_SIGNALS_CONTRACT` to be set before + proceeding with the commit. + """ + + def __init__( + self, + coro: Coroutine[Any, Any, None], + *, + name: str | None = None, + send_signals: Sequence[HandlerAwaitableSignal] | None = None, + skip_if_not_initialized: bool = True, + timeout_secs: int = 600, + operation_id: str | None = None, + ): + """Constructor. + + Synchronously gets all awaitable signals in the contract and waits to + receive them in background before proceeding with the commit. + + Args: + coro: The coroutine to run. + name: The name of the thread. + send_signals: Signals to send to indicate that the commit has completed. + skip_if_not_initialized: If True, skip fetching signals if the + `HandlerAwaitableSignalOperationIdGenerator` is not initialized. + timeout_secs: Timeout in seconds for waiting for signals. + operation_id: The operation id to use for the barrier keys. If None, the + current operation id is used. + """ + super().__init__() + receive_signals = [] + if is_intialized() or not skip_if_not_initialized: + receive_signals = get_awaitable_signals_from_contract() + self._f = CommitFuture( + coro, + name=name, + send_signals=send_signals, + receive_signals=receive_signals, + timeout_secs=timeout_secs, + operation_id=operation_id, + ) + + def result(self, timeout: Optional[float] = None) -> Any: + return self._f.result(timeout=timeout) diff --git a/checkpoint/orbax/checkpoint/_src/futures/synchronization.py b/checkpoint/orbax/checkpoint/_src/futures/synchronization.py index e160cda8..0b91d458 100644 --- a/checkpoint/orbax/checkpoint/_src/futures/synchronization.py +++ b/checkpoint/orbax/checkpoint/_src/futures/synchronization.py @@ -16,7 +16,6 @@ import enum import itertools -from orbax.checkpoint._src.multihost import multihost class HandlerAwaitableSignal(enum.Enum): @@ -26,6 +25,8 @@ class HandlerAwaitableSignal(enum.Enum): `CheckpointHandler or below.` Attributes: + AWAITABLE_SIGNALS_CONTRACT: Contract that contains a list of signals that + may be sent and can be awaited by the handlers. STEP_DIRECTORY_CREATION: When recieved, indicates that the step directory has been created. The handler should not attempt to write files before the directory is created. @@ -34,15 +35,16 @@ class HandlerAwaitableSignal(enum.Enum): directory is created. """ + AWAITABLE_SIGNALS_CONTRACT = "awaitable_signals_contract" STEP_DIRECTORY_CREATION = "step_directory_creation" ITEM_DIRECTORY_CREATION = "item_directory_creation" -class HandlerAwaitableSignalBarrierKeyGenerator: - """A unique barrier key generator for a `HandlerAwaitableSignal`.""" +class HandlerAwaitableSignalOperationIdGenerator: + """A unique operation id generator for `HandlerAwaitableSignal`.""" _operation_id_counter = itertools.count() - _operation_id = None + _operation_id = next(_operation_id_counter) @classmethod def next_operation_id(cls) -> int: @@ -51,20 +53,11 @@ def next_operation_id(cls) -> int: return cls._operation_id @classmethod - def get_unique_barrier_key(cls, signal: HandlerAwaitableSignal) -> str: - """Returns a unique barrier key for the signal. + def get_current_operation_id(cls) -> str: + """Returns the current operation id.""" + return str(cls._operation_id) - Args: - signal: The signal to generate a barrier key for. - - Raises: - ValueError: If `_operation_id` is not initialized. - """ - if cls._operation_id is None: - raise ValueError( - "_operation_id is not initialized. Please call `next_operation_id()`" - " first." - ) - return multihost.unique_barrier_key( - signal.value, suffix=str(cls._operation_id) - ) + @classmethod + def is_intialized(cls) -> bool: + """Returns whether the operation id counter is initialized by calling `next_operation_id`.""" + return cls._operation_id > 0 diff --git a/checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py b/checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py index 5c2912a5..b8363b56 100644 --- a/checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py +++ b/checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py @@ -14,57 +14,34 @@ from absl.testing import absltest from orbax.checkpoint._src.futures import synchronization -from orbax.checkpoint._src.multihost import multihost -HandlerAwaitableSignalBarrierKeyGenerator = ( - synchronization.HandlerAwaitableSignalBarrierKeyGenerator +HandlerAwaitableSignalOperationIdGenerator = ( + synchronization.HandlerAwaitableSignalOperationIdGenerator ) -class HandlerAwaitableSignalBarrierKeyGeneratorTest(absltest.TestCase): +class HandlerAwaitableSignalOperationIdGeneratorTest(absltest.TestCase): - def test_get_unique_barrier_key_without_operation_id_raises_error(self): - step_directory_creation_signal = ( - synchronization.HandlerAwaitableSignal.STEP_DIRECTORY_CREATION - ) - HandlerAwaitableSignalBarrierKeyGenerator._operation_id = None + def test_is_operation_id_initialized(self): + HandlerAwaitableSignalOperationIdGenerator._operation_id = 0 - with self.assertRaisesWithLiteralMatch( - ValueError, - "_operation_id is not initialized. Please call `next_operation_id()`" - " first.", - ): - HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key( - step_directory_creation_signal - ) + self.assertFalse(HandlerAwaitableSignalOperationIdGenerator.is_intialized()) - def test_get_unique_barrier_key(self): - step_directory_creation_signal = ( - synchronization.HandlerAwaitableSignal.STEP_DIRECTORY_CREATION - ) - expected_barrier_key_0 = multihost.unique_barrier_key( - step_directory_creation_signal.value, suffix="0" - ) - expected_barrier_key_1 = multihost.unique_barrier_key( - step_directory_creation_signal.value, suffix="1" + def test_get_operation_id(self): + HandlerAwaitableSignalOperationIdGenerator.next_operation_id() + operation_id_1 = ( + HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id() ) - HandlerAwaitableSignalBarrierKeyGenerator.next_operation_id() - barrier_key_0 = ( - HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key( - step_directory_creation_signal - ) - ) - HandlerAwaitableSignalBarrierKeyGenerator.next_operation_id() - barrier_key_1 = ( - HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key( - step_directory_creation_signal - ) + HandlerAwaitableSignalOperationIdGenerator.next_operation_id() + operation_id_2 = ( + HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id() ) - self.assertEqual(barrier_key_0, expected_barrier_key_0) - self.assertEqual(barrier_key_1, expected_barrier_key_1) + self.assertTrue(HandlerAwaitableSignalOperationIdGenerator.is_intialized()) + self.assertEqual(operation_id_1, "1") + self.assertEqual(operation_id_2, "2") if __name__ == "__main__":