Skip to content

Commit

Permalink
Remove _CommitFuture from type_handlers.py and use `futures.CommitF…
Browse files Browse the repository at this point in the history
…utureAwaitingContractedSignals` that also support asynchronous directory creation.

PiperOrigin-RevId: 717917250
  • Loading branch information
mridul-sahu authored and Orbax Authors committed Jan 22, 2025
1 parent 6ffc5dd commit 97473a9
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 94 deletions.
5 changes: 3 additions & 2 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ properties not included in any tree mapping operations.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between
Checkpointing layers to enable async directory creation.
- `CommitFutureAwaitDirectorySignals`, `CommitFuture` and
`HandlerAwaitableSignal` for signalling between Checkpointing layers to enable
async directory creation.
- User-provided custom PyTree metadata.

### Fixed
Expand Down
141 changes: 129 additions & 12 deletions checkpoint/orbax/checkpoint/_src/futures/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,73 @@
from typing import Any, Callable, Coroutine, Optional, Sequence

from absl import logging
import jax
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.multihost import multihost
from typing_extensions import Protocol


get_unique_barrier_key = (
synchronization.HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key
HandlerAwaitableSignal = synchronization.HandlerAwaitableSignal
HandlerAwaitableSignalOperationIdGenerator = (
synchronization.HandlerAwaitableSignalOperationIdGenerator
)
is_intialized = HandlerAwaitableSignalOperationIdGenerator.is_intialized
_SIGNAL_ACTION_SUCCESS = 'signal_action_success'


def _get_unique_barrier_key(
signal: HandlerAwaitableSignal, operation_id: str
) -> str:
"""Returns a unique barrier key for the signal.
Args:
signal: The signal to generate a barrier key for.
operation_id: The operation id to use as a suffix for the barrier key.
"""
return multihost.unique_barrier_key(signal.value, suffix=operation_id)


def get_awaitable_signals_from_contract() -> Sequence[HandlerAwaitableSignal]:
"""Gets the awaitable signals that may be sent for the current operation id."""
client = multihost.get_jax_distributed_client()
barrier_key = _get_unique_barrier_key(
HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT,
HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id(),
)
try:
values_str = str(client.key_value_try_get(barrier_key))
return [HandlerAwaitableSignal(value) for value in values_str.split(',')]
except jax.errors.JaxRuntimeError:
# If the key is not found, then there are no awaitable signals yet.
return []


def add_to_awaitable_signals_contract(
signals: Sequence[HandlerAwaitableSignal],
):
"""Adds awaitable signals to `AWAITABLE_SIGNALS_CONTRACT` for lower checkpointing layers to wait on.
These signals are added to the list of awaitable signals for the current
opertation id in `HandlerAwaitableSignalOperationIdGenerator`.
Args:
signals: The signals to add to the list of awaitable signals.
"""
if not signals:
return

current_signals = list(get_awaitable_signals_from_contract())
current_signals.extend(signals)
keys = ','.join([current_signal.value for current_signal in current_signals])
client = multihost.get_jax_distributed_client()
barrier_key = _get_unique_barrier_key(
HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT,
HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id(),
)
client.key_value_set(barrier_key, keys, allow_overwrite=True)


class Future(Protocol):
"""Abstracted Orbax Future class.
Expand Down Expand Up @@ -108,18 +163,21 @@ def join(self, timeout=None):
class _SignalingThread(threading.Thread):
"""Thread that raises an exception if it encounters an error.
Waits for signals to be received before proceeding with the target function.
Then sends signals to indicate that the target function has completed.
Waits for signals to be received for the current operation id before
proceeding with the target function. Then sends signals to indicate that the
target function has completed with
same operation id.
"""

_exception: Optional[Exception] = None

def __init__(
self,
*,
send_signals: list[synchronization.HandlerAwaitableSignal],
receive_signals: list[synchronization.HandlerAwaitableSignal],
send_signals: Sequence[HandlerAwaitableSignal],
receive_signals: Sequence[HandlerAwaitableSignal],
timeout_secs: int = 600,
operation_id: str | None = None,
**kwargs,
):
"""Constructor.
Expand All @@ -130,12 +188,19 @@ def __init__(
receive_signals: Signals to wait for before proceeding with the target
function.
timeout_secs: Timeout in seconds for waiting for signals.
operation_id: The operation id to use for the barrier keys. If None, the
current operation id is used.
**kwargs: Keyword arguments passed to the base class.
"""
super().__init__(**kwargs)
self._send_signals = send_signals
self._receive_signals = receive_signals
self._timeout_secs = timeout_secs
# Capture the current operation id syncronously.
self._operation_id = (
operation_id
or HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id()
)

def _wait_for_signals(self):
"""Waits for signals to be set."""
Expand All @@ -148,7 +213,7 @@ def _wait_for_signals(self):
signal.value,
self._timeout_secs,
)
barrier_key = get_unique_barrier_key(signal)
barrier_key = _get_unique_barrier_key(signal, self._operation_id)
client = multihost.get_jax_distributed_client()
client.blocking_key_value_get(barrier_key, self._timeout_secs * 1000)

Expand All @@ -162,7 +227,7 @@ def _set_signals(self):
threading.current_thread().name,
signal.value,
)
barrier_key = get_unique_barrier_key(signal)
barrier_key = _get_unique_barrier_key(signal, self._operation_id)
client = multihost.get_jax_distributed_client()
client.key_value_set(barrier_key, _SIGNAL_ACTION_SUCCESS)

Expand Down Expand Up @@ -194,11 +259,10 @@ def __init__(
coro: Coroutine[Any, Any, None],
*,
name: str | None = None,
send_signals: list[synchronization.HandlerAwaitableSignal] | None = None,
receive_signals: (
list[synchronization.HandlerAwaitableSignal] | None
) = None,
send_signals: Sequence[HandlerAwaitableSignal] | None = None,
receive_signals: Sequence[HandlerAwaitableSignal] | None = None,
timeout_secs: int = 600,
operation_id: str | None = None,
):
"""Constructor.
Expand All @@ -208,6 +272,8 @@ def __init__(
send_signals: Signals to send to indicate that the commit has completed.
receive_signals: Signals to wait for before proceeding with the commit.
timeout_secs: Timeout in seconds for waiting for signals.
operation_id: The operation id to use for the barrier keys. If None, the
current operation id is used.
"""
super().__init__()
send_signals = send_signals or []
Expand All @@ -216,6 +282,7 @@ def __init__(
send_signals=send_signals,
receive_signals=receive_signals,
timeout_secs=timeout_secs,
operation_id=operation_id,
target=lambda: asyncio_utils.run_sync(coro),
name=name,
)
Expand All @@ -224,3 +291,53 @@ def __init__(
def result(self, timeout: Optional[float] = None) -> Any:
"""Waits for the commit to complete."""
return self._t.join(timeout=timeout)


class CommitFutureAwaitingContractedSignals(Future):
"""Represents the result of a background commit.
May send signals to indicate that the commit has completed. Waits for all
awaitable signals in the `AWAITABLE_SIGNALS_CONTRACT` to be set before
proceeding with the commit.
"""

def __init__(
self,
coro: Coroutine[Any, Any, None],
*,
name: str | None = None,
send_signals: Sequence[HandlerAwaitableSignal] | None = None,
skip_if_not_initialized: bool = True,
timeout_secs: int = 600,
operation_id: str | None = None,
):
"""Constructor.
Synchronously gets all awaitable signals in the contract and waits to
receive them in background before proceeding with the commit.
Args:
coro: The coroutine to run.
name: The name of the thread.
send_signals: Signals to send to indicate that the commit has completed.
skip_if_not_initialized: If True, skip fetching signals if the
`HandlerAwaitableSignalOperationIdGenerator` is not initialized.
timeout_secs: Timeout in seconds for waiting for signals.
operation_id: The operation id to use for the barrier keys. If None, the
current operation id is used.
"""
super().__init__()
receive_signals = []
if is_intialized() or not skip_if_not_initialized:
receive_signals = get_awaitable_signals_from_contract()
self._f = CommitFuture(
coro,
name=name,
send_signals=send_signals,
receive_signals=receive_signals,
timeout_secs=timeout_secs,
operation_id=operation_id,
)

def result(self, timeout: Optional[float] = None) -> Any:
return self._f.result(timeout=timeout)
33 changes: 13 additions & 20 deletions checkpoint/orbax/checkpoint/_src/futures/synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import enum
import itertools
from orbax.checkpoint._src.multihost import multihost


class HandlerAwaitableSignal(enum.Enum):
Expand All @@ -26,6 +25,8 @@ class HandlerAwaitableSignal(enum.Enum):
`CheckpointHandler or below.`
Attributes:
AWAITABLE_SIGNALS_CONTRACT: Contract that contains a list of signals that
may be sent and can be awaited by the handlers.
STEP_DIRECTORY_CREATION: When recieved, indicates that the step directory
has been created. The handler should not attempt to write files before the
directory is created.
Expand All @@ -34,15 +35,16 @@ class HandlerAwaitableSignal(enum.Enum):
directory is created.
"""

AWAITABLE_SIGNALS_CONTRACT = "awaitable_signals_contract"
STEP_DIRECTORY_CREATION = "step_directory_creation"
ITEM_DIRECTORY_CREATION = "item_directory_creation"


class HandlerAwaitableSignalBarrierKeyGenerator:
"""A unique barrier key generator for a `HandlerAwaitableSignal`."""
class HandlerAwaitableSignalOperationIdGenerator:
"""A unique operation id generator for a `HandlerAwaitableSignal`."""

_operation_id_counter = itertools.count()
_operation_id = None
_operation_id = next(_operation_id_counter)

@classmethod
def next_operation_id(cls) -> int:
Expand All @@ -51,20 +53,11 @@ def next_operation_id(cls) -> int:
return cls._operation_id

@classmethod
def get_unique_barrier_key(cls, signal: HandlerAwaitableSignal) -> str:
"""Returns a unique barrier key for the signal.
def get_current_operation_id(cls) -> str:
"""Returns the current operation id."""
return str(cls._operation_id)

Args:
signal: The signal to generate a barrier key for.
Raises:
ValueError: If `_operation_id` is not initialized.
"""
if cls._operation_id is None:
raise ValueError(
"_operation_id is not initialized. Please call `next_operation_id()`"
" first."
)
return multihost.unique_barrier_key(
signal.value, suffix=str(cls._operation_id)
)
@classmethod
def is_intialized(cls) -> bool:
"""Returns whether the operation id counter is initialized by calling `next_operation_id`."""
return cls._operation_id > 0
55 changes: 16 additions & 39 deletions checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,34 @@

from absl.testing import absltest
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.multihost import multihost


HandlerAwaitableSignalBarrierKeyGenerator = (
synchronization.HandlerAwaitableSignalBarrierKeyGenerator
HandlerAwaitableSignalOperationIdGenerator = (
synchronization.HandlerAwaitableSignalOperationIdGenerator
)


class HandlerAwaitableSignalBarrierKeyGeneratorTest(absltest.TestCase):
class HandlerAwaitableSignalOperationIdGeneratorTest(absltest.TestCase):

def test_get_unique_barrier_key_without_operation_id_raises_error(self):
step_directory_creation_signal = (
synchronization.HandlerAwaitableSignal.STEP_DIRECTORY_CREATION
)
HandlerAwaitableSignalBarrierKeyGenerator._operation_id = None
def test_is_operation_id_initialized(self):
HandlerAwaitableSignalOperationIdGenerator._operation_id = 0

with self.assertRaisesWithLiteralMatch(
ValueError,
"_operation_id is not initialized. Please call `next_operation_id()`"
" first.",
):
HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key(
step_directory_creation_signal
)
self.assertFalse(HandlerAwaitableSignalOperationIdGenerator.is_intialized())

def test_get_unique_barrier_key(self):
step_directory_creation_signal = (
synchronization.HandlerAwaitableSignal.STEP_DIRECTORY_CREATION
)
expected_barrier_key_0 = multihost.unique_barrier_key(
step_directory_creation_signal.value, suffix="0"
)
expected_barrier_key_1 = multihost.unique_barrier_key(
step_directory_creation_signal.value, suffix="1"
def test_get_operation_id(self):
HandlerAwaitableSignalOperationIdGenerator.next_operation_id()
operation_id_1 = (
HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id()
)

HandlerAwaitableSignalBarrierKeyGenerator.next_operation_id()
barrier_key_0 = (
HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key(
step_directory_creation_signal
)
)
HandlerAwaitableSignalBarrierKeyGenerator.next_operation_id()
barrier_key_1 = (
HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key(
step_directory_creation_signal
)
HandlerAwaitableSignalOperationIdGenerator.next_operation_id()
operation_id_2 = (
HandlerAwaitableSignalOperationIdGenerator.get_current_operation_id()
)

self.assertEqual(barrier_key_0, expected_barrier_key_0)
self.assertEqual(barrier_key_1, expected_barrier_key_1)
self.assertTrue(HandlerAwaitableSignalOperationIdGenerator.is_intialized())
self.assertEqual(operation_id_1, "1")
self.assertEqual(operation_id_2, "2")


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions checkpoint/orbax/checkpoint/_src/serialization/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ py_library(
":serialization",
":tensorstore_utils",
":types",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
Expand All @@ -46,6 +44,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/multihost:multislice",
"//checkpoint/orbax/checkpoint/_src/path:async_utils",
"//checkpoint/orbax/checkpoint/_src/path:format_utils",
"//orbax/checkpoint/_src/futures:future",
"//orbax/checkpoint/_src/metadata:array_metadata_store",
],
)
Expand Down
Loading

0 comments on commit 97473a9

Please sign in to comment.