Skip to content

Commit

Permalink
Refactor internal methods of AsyncCheckpointer to facilitate reuse.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726209384
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Feb 13, 2025
1 parent e190177 commit 8925b41
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 147 deletions.
325 changes: 178 additions & 147 deletions checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,92 @@ def _on_commit_callback(
)


def _background_wait_for_commit_futures(
    directory: epath.Path,
    commit_futures: Sequence[future.Future],
    on_commit_callback: Callable[[], None],
    *,
    barrier_sync_key_prefix: str,
    sync_fn: Callable[[str], None],
    primary_host: int | None,
):
  """Waits for commit futures, synchronizes processes, and runs the callback.

  Meant to be executed in a background thread. The sequence is:
    1. Block on every future in `commit_futures`.
    2. If more than one process is running, wait at a barrier.
    3. Invoke `on_commit_callback` on the primary host only.
    4. If more than one process is running, wait at a second barrier so that
       all processes observe the completed callback.

  Args:
    directory: Checkpoint directory; its `name` is used as the barrier key
      suffix.
    commit_futures: Futures whose `result()` is awaited, in order.
    on_commit_callback: Called with no arguments on the primary host once the
      first barrier (if any) has been passed.
    barrier_sync_key_prefix: Prefix used when building barrier keys.
    sync_fn: Called with a barrier key to synchronize processes.
    primary_host: Passed to `utils.is_primary_host` to decide whether this
      process runs `on_commit_callback`.

  Raises:
    jax.errors.JaxRuntimeError: Re-raised if the first barrier fails. On
      Python 3.11+ debugging notes are attached when the message contains
      'DEADLINE_EXCEEDED'.
  """
  process_idx = multihost.process_index()
  thread_name = threading.current_thread().name
  is_multiprocess = jax.process_count() > 1

  def _wait_at_barrier(key: str) -> None:
    # Both barriers share the same prefix/suffix and differ only in `key`.
    sync_fn(
        multihost.unique_barrier_key(
            key,
            prefix=barrier_sync_key_prefix,
            suffix=f'{directory.name}',
        )
    )

  logging.info(
      '[process=%s][thread=%s] Background save thread started.',
      process_idx,
      thread_name,
  )
  started_at = time.time()

  # Block until every commit operation has finished.
  for pending in commit_futures:
    pending.result()
  num_futures = len(commit_futures)
  logging.info(
      '[process=%s][thread=%s] %d Handler Commit operations completed.',
      process_idx,
      thread_name,
      num_futures,
  )
  # The number of in-flight async writes is reported through a duration
  # metric, because jax.monitoring only offers events and durations.
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/write/async/commit_future_count',
      num_futures,
  )

  # Per-process storage commit latency, measured before any barrier wait.
  commit_elapsed = time.time() - started_at
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/write/async/commit_duration_sec',
      commit_elapsed,
  )
  logging.vlog(1, 'Async Commit duration: %s seconds', commit_elapsed)

  if is_multiprocess:
    # Every process waits here. The barrier is satisfied once all processes
    # have arrived; otherwise it times out.
    try:
      _wait_at_barrier('async_write_complete')
    except jax.errors.JaxRuntimeError as e:
      # `add_note` is only available from Python 3.11 onwards.
      if sys.version_info >= (3, 11) and 'DEADLINE_EXCEEDED' in str(e):
        _add_deadline_exceeded_notes(e)
      raise

  if utils.is_primary_host(primary_host):
    on_commit_callback()
  if is_multiprocess:
    # Hold all processes until the primary host has finished the callback.
    _wait_at_barrier('async_commit_complete')

  total_elapsed = time.time() - started_at
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/write/async/thread_duration_sec',
      total_elapsed,
  )
  logging.vlog(1, 'Async thread duration: %s seconds', total_elapsed)
  logging.info(
      '[process=%s][thread=%s] Background save thread done.',
      process_idx,
      thread_name,
  )


def _add_deadline_exceeded_notes(e: jax.errors.JaxRuntimeError):
"""Adds notes to the exception to help debug the deadline exceeded error."""
e.add_note('1. Make sure that the job and storage are colocated.')
Expand Down Expand Up @@ -117,85 +203,18 @@ def _thread_func(
on_commit_callback: Callable[[], None],
):
"""Awaits on commit futures and finalizes the checkpoint."""
current_process = multihost.process_index()
current_thread_id = threading.current_thread().name
try:
process_count = jax.process_count()
logging.info(
'[process=%s][thread=%s] Background save thread started.',
current_process,
current_thread_id,
)
thread_start_time = time.time()

# Wait for commit operations to complete.
for commit_future in commit_futures:
commit_future.result()
logging.info(
'[process=%s][thread=%s] %d Handler Commit operations completed.',
current_process,
current_thread_id,
len(commit_futures),
)
# Log the number of async writes that are in flight. Abuses a duration
# metric as a counter since jax.monitoring only has events and durations.
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/commit_future_count',
len(commit_futures),
)

# Log the per process storage commit latency excluding the barrier time.
commit_duration_secs = time.time() - thread_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/commit_duration_sec',
commit_duration_secs,
)
logging.vlog(1, 'Async Commit duration: %s seconds', commit_duration_secs)

if process_count > 1:
# All processes will wait at the barrier. When all processes are at the
# barrier, the barrier will be satisfied. If not, then it will timeout.
try:
self._sync_fn(
multihost.unique_barrier_key(
'async_write_complete',
prefix=self._barrier_sync_key_prefix,
suffix=f'{directory.name}',
)
)
except jax.errors.JaxRuntimeError as e:
if sys.version_info >= (3, 11):
if 'DEADLINE_EXCEEDED' in str(e):
_add_deadline_exceeded_notes(e)
raise

if utils.is_primary_host(self._primary_host):
on_commit_callback()
if process_count > 1:
# Block until process 0 completes on_commit_callback.
self._sync_fn(
multihost.unique_barrier_key(
'async_commit_complete',
prefix=self._barrier_sync_key_prefix,
suffix=f'{directory.name}',
)
)

thread_duration_secs = time.time() - thread_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/thread_duration_sec',
thread_duration_secs,
)
logging.vlog(1, 'Async thread duration: %s seconds', thread_duration_secs)
logging.info(
'[process=%s][thread=%s] Background save thread done.',
current_process,
current_thread_id,
_background_wait_for_commit_futures(
directory,
commit_futures,
on_commit_callback,
barrier_sync_key_prefix=self._barrier_sync_key_prefix,
sync_fn=self._sync_fn,
primary_host=self._primary_host,
)

except Exception as e: # pylint: disable=broad-exception-caught
msg = (
f'[process={current_process}] Failed to run'
f'[process={multihost.process_index()}] Failed to run'
f' {len(commit_futures)} Handler Commit operations or the Commit'
f' callback in background save thread, directory: {directory}'
)
Expand Down Expand Up @@ -338,63 +357,14 @@ def __init__(
)
self._multiprocessing_options = multiprocessing_options

async def _save(
def _make_on_commit_callback(
self,
directory: epath.PathLike,
*args,
force: bool = False,
custom_metadata: dict[str, Any] | None = None,
**kwargs
):
checkpoint_start_time = time.time()
directory = epath.Path(directory)
self.wait_until_finished()
self.synchronize_next_awaitable_signal_operation_id()

jax.monitoring.record_event('/jax/orbax/write/async/start')
logging.info(
'[process=%s] Started async saving checkpoint to %s.',
multihost.process_index(),
directory,
)

if await async_utils.async_exists(directory):
if force:
if utils.is_primary_host(self._primary_host):
logging.info(
'[process=%s] Specified `force`: removing existing directory.',
multihost.process_index(),
)
await async_utils.async_rmtree(
directory
) # Post-sync handled by create_tmp_directory.
else:
raise ValueError(f'Destination {directory} already exists.')

commit_ops = []
tmpdir = self.get_temporary_path(directory)
if self._create_directories_asynchronously:
commit_ops.append(
atomicity.create_all_async(
[tmpdir],
completion_signals=_DIRECTORY_CREATION_SIGNALS,
multiprocessing_options=self._multiprocessing_options,
)
)
else:
await self.create_temporary_path(tmpdir)
# Run copy ops.
# Try to save using new CheckpointArgs API if supported by the handler.
ckpt_args = checkpointer.construct_checkpoint_args(
self._handler, True, *args, **kwargs
)
commit_ops.extend(
await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
)
commit_ops, _ = jax.tree.flatten(commit_ops)
commit_ops = [op for op in commit_ops if op is not None]

tmpdir: atomicity_types.TemporaryPath,
custom_metadata: dict[str, Any] | None,
checkpoint_start_time: float,
) -> Callable[[], None]:
# Directory is the final directory.

def _callback() -> None:
if utils.is_primary_host(self._primary_host):
# Update StepMetadata after the handler save is complete.
Expand Down Expand Up @@ -438,25 +408,64 @@ def _callback() -> None:
'Finished asynchronous save (blocking + background) in %.2f seconds'
' to %s',
time.time() - checkpoint_start_time,
directory,
tmpdir.get_final(),
)

self._async_manager.start_async_commit(
return _callback

async def _save(
self,
tmpdir: atomicity_types.TemporaryPath,
*args,
force: bool = False,
**kwargs,
):
directory = tmpdir.get_final()
self.synchronize_next_awaitable_signal_operation_id()

jax.monitoring.record_event('/jax/orbax/write/async/start')
logging.info(
'[process=%s] Started async saving checkpoint to %s.',
multihost.process_index(),
directory,
commit_futures=commit_ops,
on_commit_callback=_callback,
)
blocking_duration_secs = time.time() - checkpoint_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/blocking_duration_secs',
blocking_duration_secs,

if await async_utils.async_exists(directory):
if force:
if utils.is_primary_host(self._primary_host):
logging.info(
'[process=%s] Specified `force`: removing existing directory.',
multihost.process_index(),
)
await async_utils.async_rmtree(
directory
) # Post-sync handled by create_tmp_directory.
else:
raise ValueError(f'Destination {directory} already exists.')

commit_ops = []
if self._create_directories_asynchronously:
commit_ops.append(
atomicity.create_all_async(
[tmpdir],
completion_signals=_DIRECTORY_CREATION_SIGNALS,
multiprocessing_options=self._multiprocessing_options,
)
)
else:
await self.create_temporary_path(tmpdir)
# Run copy ops.
# Try to save using new CheckpointArgs API if supported by the handler.
ckpt_args = checkpointer.construct_checkpoint_args(
self._handler, True, *args, **kwargs
)
logging.info(
'Finished blocking save in %.2f seconds. Continuing to save'
' asynchronously to %s.',
blocking_duration_secs,
directory,
commit_ops.extend(
await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
)
commit_ops, _ = jax.tree.flatten(commit_ops)
commit_ops = [op for op in commit_ops if op is not None]

return commit_ops

def save(
self,
Expand Down Expand Up @@ -488,15 +497,37 @@ def save(
Raises:
ValueError if the provided directory already exists.
"""
asyncio_utils.run_sync(
checkpoint_start_time = time.time()
directory = epath.Path(directory)
tmpdir = self.get_temporary_path(directory)
on_commit_callback = self._make_on_commit_callback(
tmpdir, custom_metadata, checkpoint_start_time
)
self.wait_until_finished()
commit_ops = asyncio_utils.run_sync(
self._save(
directory,
tmpdir,
*args,
force=force,
custom_metadata=custom_metadata,
**kwargs
**kwargs,
)
)
self._async_manager.start_async_commit(
directory,
commit_futures=commit_ops,
on_commit_callback=on_commit_callback,
)
blocking_duration_secs = time.time() - checkpoint_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/blocking_duration_secs',
blocking_duration_secs,
)
logging.info(
'Finished blocking save in %.2f seconds. Continuing to save'
' asynchronously to %s.',
blocking_duration_secs,
directory,
)

def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any:
"""See superclass documentation."""
Expand Down
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,10 @@ def _get_unique_barrier_key(key: str) -> str:
def barrier_compatible_test(cls):
"""A decorator to be used with a test class.
This is primarily needed when different processes in a multihost test may be
executing different code. This will cause operation IDs to get out of sync. If
all processes always execute the same code, this decorator is not needed.
E.g.
@barrier_compatible_test
Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
PyTreeOf,
PyTreeKey,
PyTreePath,
JsonType,
)
from orbax.checkpoint._src.tree.utils import (
get_param_names,
Expand Down

0 comments on commit 8925b41

Please sign in to comment.