Skip to content

Commit

Permalink
Save StepMetadata in the Checkpointer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707670612
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Jan 10, 2025
1 parent 3ef736f commit 33b0051
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ when a custom `snapshot_dir` is specified.
`CompositeCheckpointHandler.metadata()` to retrieve item metadata by
default-constructing `CheckpointHandler`s when they're listed in the saved
`StepMetadata` but aren't found in the checkpoint.
- `FileOptions.format` to specify the underlying checkpointing file format.

### Fixed
- Ignore not-exists and not-dir errors while building step metadata in
Expand All @@ -45,6 +46,7 @@ default-constructing `CheckpointHandler`s when they're listed in the saved

### Changed
- Return `StepMetadata` from `CompositeCheckpointHandler.metadata()`.
- `Checkpointer.save()` also saves `StepMetadata`.

## [0.10.2] - 2024-12-04

Expand Down
58 changes: 56 additions & 2 deletions checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_defaults
Expand All @@ -42,6 +43,7 @@
get_legacy_handler_wrapper = (
composite_checkpoint_handler.get_legacy_handler_wrapper
)
StepMetadata = checkpoint.StepMetadata


def construct_checkpoint_args(
Expand Down Expand Up @@ -161,7 +163,12 @@ async def create_temporary_path(
return tmpdir

def save(
self, directory: epath.PathLike, *args, force: bool = False, **kwargs
self,
directory: epath.PathLike,
*args,
force: bool = False,
custom: dict[str, Any] | None = None,
**kwargs,
):
"""Saves the given item to the provided directory.
Expand All @@ -176,6 +183,8 @@ def save(
*args: additional args to provide to the CheckpointHandler's save method.
force: if True, allows overwriting an existing directory. May add overhead
due to the need to delete any existing files.
custom: a dictionary of custom metadata to be written to the
checkpoint directory via StepMetadata.
**kwargs: additional keyword args to provide to the CheckpointHandler's
save method.
Expand All @@ -200,8 +209,12 @@ def save(
else:
raise ValueError(f'Destination {directory} already exists.')
ckpt_args = construct_checkpoint_args(self._handler, True, *args, **kwargs)
# tmpdir creation also does an initial StepMetadata save.
tmpdir = asyncio_utils.run_sync(self.create_temporary_path(directory))
self._handler.save(tmpdir.get(), args=ckpt_args)
if utils.is_primary_host(self._primary_host):
# Update StepMetadata after the handler save is complete.
self._save_step_metadata(tmpdir.get(), custom=custom)
multihost.sync_global_processes(
multihost.unique_barrier_key(
'Checkpointer:save',
Expand All @@ -212,6 +225,7 @@ def save(

# Ensure save operation atomicity and record time saved by checkpoint.
if utils.is_primary_host(self._primary_host):
# finalize does a final StepMetadata update.
self._handler.finalize(tmpdir.get())
atomicity.on_commit_callback(
tmpdir,
Expand Down Expand Up @@ -251,11 +265,51 @@ def _restore(
) -> Any:
return self._handler.restore(directory, args=args)

def metadata(self, directory: epath.PathLike) -> Optional[Any]:
def metadata(self, directory: epath.PathLike) -> StepMetadata | Any | None:
"""See superclass documentation."""
directory = epath.Path(directory)
return self._handler.metadata(directory)

def _save_step_metadata(
    self, directory: epath.Path, custom: dict[str, Any] | None
):
  """Writes an update to the StepMetadata file under `directory`.

  The update always carries `custom`. It also carries `item_handlers`, which
  is determined as follows:

  * If the wrapped handler is a `CompositeCheckpointHandler`, the value is the
    `item_handlers` attribute of the metadata the handler returns for
    `directory`. If that lookup raises one of the caught exceptions, a
    warning is logged and `item_handlers` is left out of the update.
  * Otherwise the value is the handler's `typestr()`. If `typestr()` raises
    `NotImplementedError` or `AttributeError`, a warning is logged and the
    handler's `<module>.<qualified class name>` is used instead.

  Args:
    directory: the checkpoint directory whose StepMetadata file is updated.
    custom: a dictionary of custom metadata to store, or None.
  """
  handler = self._handler
  metadata_fields = {'custom': custom}

  # TODO(adamcogdell): Move this to CheckpointHandler._update_metadata().
  if not isinstance(
      handler, composite_checkpoint_handler.CompositeCheckpointHandler
  ):
    # Single-item handler: record one type string for the handler itself.
    try:
      handler_type = handler.typestr()
    except (NotImplementedError, AttributeError):
      logging.warning(
          'Failed to get item handler typestr from directory %s. Backup '
          'handler type will be saved.',
          directory,
      )
      handler_type = f'{handler.__module__}.{handler.__class__.__qualname__}'
    metadata_fields['item_handlers'] = handler_type
  else:
    # Composite handler: reuse the per-item handler info from its metadata.
    try:
      partial_metadata: StepMetadata = handler.metadata(directory)
    except (FileNotFoundError, NotImplementedError, ValueError, TypeError):
      # Best effort: the update is still written, just without item_handlers.
      logging.warning(
          'Failed to get per-item metadata from directory %s. Handler types '
          'will not be saved.',
          directory,
      )
    else:
      metadata_fields['item_handlers'] = partial_metadata.item_handlers

  self._metadata_store.update(
      file_path=checkpoint.step_metadata_file_path(directory),
      **step_metadata_serialization.serialize_for_update(**metadata_fields),
  )

def close(self):
"""Closes the underlying CheckpointHandler."""
self._handler.close()
Expand Down
7 changes: 7 additions & 0 deletions checkpoint/orbax/checkpoint/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@



_ORBAX_STANDARD_FORMAT = 'orbax-standard'


@dataclasses.dataclass
class AsyncOptions:
"""Options used to configure async behavior.
Expand Down Expand Up @@ -65,9 +68,13 @@ class FileOptions:
metadata files. e.g. 0o750. Please check
https://github.com/google/etils/blob/main/etils/epath/backend.py if your
path is supported. default=None.
format: The checkpoint file format. This is useful when differentiating
between Orbax and non-Orbax checkpoints, as well as checkpoints saved by
different APIs. Defaults to 'orbax-standard'.
"""

path_permission_mode: Optional[int] = None
format: str = _ORBAX_STANDARD_FORMAT



0 comments on commit 33b0051

Please sign in to comment.