Skip to content

Commit

Permalink
Allow concurrent CheckpointManager.save requests and wait_until_finished.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 726555960
  • Loading branch information
niketkumar authored and Orbax Authors committed Feb 13, 2025
1 parent 7fc68b8 commit 478b102
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 84 deletions.
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Allow concurrent CheckpointManager.save requests and wait_until_finished.

## [0.11.5] - 2025-02-10

### Fixed
Expand Down
18 changes: 18 additions & 0 deletions checkpoint/orbax/checkpoint/_src/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

from collections.abc import Callable
import threading
import time
from typing import Generic, TypeVar
Expand Down Expand Up @@ -80,6 +81,12 @@ def set(self, value: _T | None) -> OptionalRef[_T]:
self._value = value
return self

def set_from(self, get_value: Callable[[], _T | None]) -> OptionalRef[_T]:
  """Stores the result of calling `get_value` and returns self.

  `get_value` is invoked while this reference's lock is held, so computing
  the new value and storing it happen as one step with respect to the other
  lock-guarded accessors.

  NOTE(review): because the callback runs under the lock, a callback that
  calls back into this same reference relies on the lock being re-entrant —
  confirm the lock type where it is created.

  Args:
    get_value: Zero-argument callable producing the new value (may be None).

  Returns:
    This reference, to allow call chaining.
  """
  with self._lock:
    produced = get_value()
    self._value = produced
  return self

def set_if_none(self, value: _T) -> OptionalRef[_T]:
"""Sets `value` if current value is None and returns self."""
with self._lock:
Expand All @@ -92,6 +99,17 @@ def get(self) -> _T | None:
with self._lock:
return self._value

def get_not_none(self) -> _T:
  """Returns the value, or raises AssertionError if the value is None.

  The check is an explicit `raise` rather than an `assert` statement so that
  the documented contract still holds when Python runs with `-O`, which
  strips `assert` statements.

  Returns:
    The current, non-None value.

  Raises:
    AssertionError: If the current value is None.
  """
  with self._lock:
    current = self._value
  if current is None:
    raise AssertionError('Expected a non-None value, but the value is None.')
  return current

def is_none(self) -> bool:
  """Reports whether the stored value is currently None.

  The value is read while holding the lock; the result is a snapshot and may
  be outdated as soon as the lock is released.

  Returns:
    True if the value is None, False otherwise.
  """
  with self._lock:
    snapshot = self._value
  return snapshot is None


class Ref(Generic[_T]):
"""A thread-safe reference to a value that should never be None."""
Expand Down
169 changes: 85 additions & 84 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class _FinalizeThread(threading.Thread):
def __init__(
self,
step: int,
save_thread: threading.Thread,
saver_thread: threading.Thread,
target: Callable[..., object],
name: str,
args=(),
Expand All @@ -146,13 +146,13 @@ def __init__(
daemon=daemon,
)
self._step = step
self._save_thread = save_thread
self._saver_thread = saver_thread

def step(self) -> int:
  """Returns the step that was passed to the constructor."""
  return self._step

def save_thread(self) -> threading.Thread:
return self._save_thread
def saver_thread(self) -> threading.Thread:
return self._saver_thread

def run(self):
try:
Expand Down Expand Up @@ -766,13 +766,7 @@ def __init__(
self._maybe_save_root_metadata(metadata)

# TODO: b/359854428 - Move Finalize biz logic to a separate class/module.
self._finalize_thread_lock = threading.Lock()
with self._finalize_thread_lock:
self._finalize_thread = None

self._is_saving_in_progress_lock = threading.Lock()
with self._is_saving_in_progress_lock:
self._is_saving_in_progress = False
self._finalize_thread = threading_lib.OptionalRef[_FinalizeThread]()

self._checkpoint_deleter: deleter.CheckpointDeleter = (
deleter.create_checkpoint_deleter(
Expand Down Expand Up @@ -1311,34 +1305,41 @@ def save(
processes=self._multiprocessing_options.active_processes,
)

assert self._finalize_thread is None
current_thread = threading.current_thread()
if is_async_checkpointer(self._checkpointer):
with self._is_saving_in_progress_lock:
self._is_saving_in_progress = True
with self._finalize_thread_lock:

def launch_finalize_thread() -> _FinalizeThread:
assert self._finalize_thread.is_none(), (
'Save finalization already in progress for'
f' step={self._finalize_thread.get_not_none().step()}'
)
finalize_thread_name = 'save_finalize'
logging.info(
'[process=%s][thread=%s][step=%s] Starting CheckpointManager Save'
' Finalize thread=%s',
process_index,
threading.current_thread().name,
current_thread.name,
step,
finalize_thread_name,
)
self._finalize_thread = _FinalizeThread(
finalize_thread = _FinalizeThread(
step=step,
save_thread=threading.current_thread(),
saver_thread=current_thread,
name=finalize_thread_name,
target=self._finalize,
args=(step, steps_to_remove),
)
self._finalize_thread.start()
finalize_thread.start()
return finalize_thread

self._finalize_thread.set_from(launch_finalize_thread)

else:
self._finalize(step, steps_to_remove)
logging.info(
'[process=%s][thread=%s][step=%s] Finished synchronous save.',
process_index,
threading.current_thread().name,
current_thread.name,
step,
)

Expand Down Expand Up @@ -1796,80 +1797,81 @@ def _wait_for_checkpointers(self):
def wait_until_finished(self):
  """See superclass documentation.

  Joins the current Save Finalize thread, if there is one. Any exception
  raised while joining is logged and re-raised. Only the thread that
  requested the save deletes the failed step's checkpoint info and resets
  the Save Finalize thread reference; other callers just wait.

  Raises:
    BaseException: Whatever is raised while joining the Save Finalize thread.
  """
  process_index = multihost.process_index()
  current_thread = threading.current_thread()
  logging.info(
      '[process=%s][thread=%s][wait_until_finished] Initiating wait for Save'
      ' Finalize thread.',
      process_index,
      current_thread.name,
  )
  # Read the reference exactly once. Checking `is_none()` and then calling
  # `get_not_none()` repeatedly is a check-then-act race: the saver thread may
  # reset the reference to None in between, which would make a concurrent
  # waiter fail with an AssertionError instead of simply waiting or returning.
  finalize_thread = self._finalize_thread.get()
  if finalize_thread is None:
    logging.info(
        '[process=%s][thread=%s][wait_until_finished] No Save Finalize'
        ' thread to wait for. Returning.',
        process_index,
        current_thread.name,
    )
    return

  step = finalize_thread.step()
  finalize_thread_name = finalize_thread.name
  saver_thread = finalize_thread.saver_thread()
  try:
    logging.info(
        '[process=%s][thread=%s][step=%s][wait_until_finished] Waiting for'
        ' Save Finalize thread (%s) to complete.',
        process_index,
        current_thread.name,
        step,
        finalize_thread_name,
    )
    # Joined through the local snapshot, so the reference's lock is not held
    # while blocking.
    finalize_thread.join()
    logging.info(
        '[process=%s][thread=%s][step=%s][wait_until_finished] Done'
        ' waiting for Save Finalize thread (%s) running at step=%d.',
        process_index,
        current_thread.name,
        step,
        finalize_thread_name,
        step,
    )
  except BaseException:  # pylint:disable=broad-exception-caught
    logging.exception(
        '[process=%s][thread=%s][step=%s][wait_until_finished] Save'
        ' Finalize thread (%s) failed.',
        process_index,
        current_thread.name,
        step,
        finalize_thread_name,
    )
    # Only thread which requested save is allowed to clean up.
    if current_thread is saver_thread:
      # If an exception occurred in the finalization of the previous
      # save, we clean up since that checkpoint was never actually saved.
      latest = self._checkpoints.latest()
      assert latest is not None
      assert latest.step == step
      self._checkpoints.delete_if(lambda info: info.step == step)
    raise
  finally:
    # Only thread which requested save is allowed to reset Save Finalize
    # thread.
    if current_thread is saver_thread:
      logging.info(
          '[process=%s][thread=%s][step=%s][wait_until_finished] Resetting'
          ' Save Finalize thread (%s) running at step=%d, also errors if'
          ' any.',
          process_index,
          current_thread.name,
          step,
          finalize_thread_name,
          step,
      )
      self._finalize_thread.set(None)

def is_saving_in_progress(self) -> bool:
  """Returns whether a checkpoint save is in progress.

  A save is reported as in progress for as long as a Save Finalize thread
  reference is held, i.e. until that reference is reset to None.

  Returns:
    True if a Save Finalize thread reference is currently set, else False.
  """
  # The separate `_is_saving_in_progress` flag and its lock are no longer
  # consulted; the Save Finalize thread reference is the only state read here.
  return not self._finalize_thread.is_none()

def check_for_errors(self):
"""See superclass documentation."""
Expand Down Expand Up @@ -1916,6 +1918,7 @@ def _finalize_checkpoint(self, step: int):
def _finalize(self, step: int, steps_to_remove: List[int]):
"""Finalizes individual items and starts garbage collection."""
process_index = multihost.process_index()
current_thread = threading.current_thread()
self._non_blocking_metadata_store.wait_until_finished()
self._wait_for_checkpointers()
# If an error is encountered while waiting for commit futures to complete,
Expand All @@ -1931,7 +1934,7 @@ def _finalize(self, step: int, steps_to_remove: List[int]):
'[process=%s][thread=%s][step=%s] CheckpointManager Save Finalize is'
' syncing with other hosts...',
process_index,
threading.current_thread().name,
current_thread.name,
step,
)
barrier_sync_fn = self._create_thread_safe_barrier_sync_fn()
Expand All @@ -1946,11 +1949,9 @@ def _finalize(self, step: int, steps_to_remove: List[int]):
'[process=%s][thread=%s][step=%s] CheckpointManager Save Finalize is'
' done on all hosts.',
process_index,
threading.current_thread().name,
current_thread.name,
step,
)
with self._is_saving_in_progress_lock:
self._is_saving_in_progress = False

def close(self):
"""See superclass documentation."""
Expand Down

0 comments on commit 478b102

Please sign in to comment.