Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CompositeCheckpointHandler._get_metadata_from_temporary_paths() for metadata access during saving (before finalize). #1528

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,9 @@ def _save_step_metadata(
):
try:
# get item_handlers from handler
partial_metadata: StepMetadata = self._handler.metadata(directory)
partial_metadata: StepMetadata = (
self._handler.metadata_from_temporary_paths(directory)
)
except (FileNotFoundError, NotImplementedError, ValueError, TypeError):
logging.warning(
'Failed to get per-item metadata from directory %s. Handler types '
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,125 @@ def restore(
)
return CompositeResults(**restored)

def _get_item_handlers(
self,
saved_metadata: checkpoint.StepMetadata,
item_names: list[str],
items_to_handlers: dict[str, CheckpointHandler | None],
) -> dict[str, checkpoint.CheckpointHandlerTypeStr | None]:
if saved_metadata.item_handlers is not None:
assert isinstance(saved_metadata.item_handlers, dict)
# Keep relevant non-None handler typestrs from on-disk metadata.
item_handlers: dict[str, checkpoint.CheckpointHandlerTypeStr] = (
saved_metadata.item_handlers
)
else:
item_handlers: dict[str, checkpoint.CheckpointHandlerTypeStr] = {}
for item_name in item_names:
if items_to_handlers.get(item_name) is None:
logging.warning(
'Item "%s" was found in the checkpoint, but could not'
' be restored. Please provide a `CheckpointHandlerRegistry`, or'
' call `restore` with an appropriate `CheckpointArgs` subclass.',
item_name,
)
# Don't overwrite if it was already set on disk.
if item_name not in item_handlers:
item_handlers[item_name] = None
continue

handler = items_to_handlers[item_name]
assert handler is not None
# If already set on disk (to non-None), don't overwrite.
if item_handlers.get(item_name) is None:
item_handlers[item_name] = handler.typestr()

return item_handlers

def _get_item_metadata(
self,
directory: epath.Path,
item_names: list[str],
items_to_handlers: dict[str, CheckpointHandler | None],
) -> dict[str, Any]:
# item_metadata is not saved in StepMetadata, so we don't need to worry
# about reading it from disk.
item_metadata = {}
for item_name in item_names:
if items_to_handlers.get(item_name) is None:
logging.warning(
'Item "%s" was found in the checkpoint, but could not'
' be restored. Please provide a `CheckpointHandlerRegistry`, or'
' call `restore` with an appropriate `CheckpointArgs` subclass.',
item_name,
)
item_metadata[item_name] = None
continue

handler = items_to_handlers[item_name]
assert handler is not None
item_metadata[item_name] = handler.metadata(
self._get_item_directory(directory, item_name)
)
return item_metadata

def _get_metadata_base(
self,
directory: epath.Path,
item_names: list[str],
get_item_metadata: bool = True,
) -> StepMetadata:
"""Base implementation for metadata handling.

Args:
directory: Path to the checkpoint.
item_names: List of item names to process.
get_item_metadata: whether to get item metadata,

Returns:
A tuple of item handlers and item metadata.
"""
items_to_handlers = dict(
self._get_all_registered_and_unregistered_items_and_handlers()
)

serialized_metadata = self._metadata_store.read(
checkpoint.step_metadata_file_path(directory)
)
saved_metadata = step_metadata_serialization.deserialize(
serialized_metadata or {}
)

if get_item_metadata:
item_metadata = CompositeItemMetadata(
**self._get_item_metadata(
directory, item_names, items_to_handlers
)
)
else:
item_metadata = None

return dataclasses.replace(
saved_metadata,
item_handlers=self._get_item_handlers(
saved_metadata, item_names, items_to_handlers
),
item_metadata=item_metadata,
)

def metadata_from_temporary_paths(
self, directory: epath.Path
) -> StepMetadata:
"""Metadata for each item in the temporary checkpoint."""
if not directory.exists():
raise FileNotFoundError(f'Directory does not exist: {directory}')

return self._get_metadata_base(
directory=directory,
item_names=list(self._current_temporary_paths.keys()),
get_item_metadata=False,
)

def metadata(self, directory: epath.Path) -> StepMetadata:
"""Metadata for each item in the checkpoint.

Expand All @@ -834,63 +953,26 @@ def metadata(self, directory: epath.Path) -> StepMetadata:
if not directory.exists():
raise FileNotFoundError(f'Directory does not exist: {directory}')

items_to_handlers = dict(
self._get_all_registered_and_unregistered_items_and_handlers()
)
try:
existing_items = self._existing_items(directory)
except OSError:
existing_items = []
logging.warning(
'Failed to get existing items from directory %s. Will use items '
'provided during initialization: %s.',
directory, list(items_to_handlers.keys()),
)

serialized_metadata = self._metadata_store.read(
checkpoint.step_metadata_file_path(directory)
)
saved_metadata = step_metadata_serialization.deserialize(
serialized_metadata or {}
)
if saved_metadata.item_handlers is not None:
assert isinstance(saved_metadata.item_handlers, dict)
item_handlers: dict[str, checkpoint.CheckpointHandlerTypeStr] = (
saved_metadata.item_handlers
'provided during initialization.', directory,
)
else:
item_handlers: dict[str, checkpoint.CheckpointHandlerTypeStr] = {}
item_metadata = {}

for item_name in existing_items:
if (
item_name not in items_to_handlers
or items_to_handlers[item_name] is None
):
logging.warning(
'Item "%s" was found in the checkpoint, but could not'
' be restored. Please provide a `CheckpointHandlerRegistry`, or'
' call `restore` with an appropriate `CheckpointArgs` subclass.',
item_name,
)
if item_name not in item_handlers:
item_handlers[item_name] = None
if item_name not in item_metadata:
item_metadata[item_name] = None
continue
handler = items_to_handlers[item_name]
assert handler is not None
if item_handlers.get(item_name) is None:
item_handlers[item_name] = handler.typestr()
if item_metadata.get(item_name) is None:
item_metadata[item_name] = handler.metadata(
self._get_item_directory(directory, item_name)
for item_tmp_dir in self._current_temporary_paths.values():
tmp_dir_name = item_tmp_dir.get().name
if tmp_dir_name in existing_items:
raise ValueError(
f'Item "{tmp_dir_name}" was found in the checkpoint, but it is a '
'temporary item. Please call finalize() before metadata().'
)

return dataclasses.replace(
saved_metadata,
item_handlers=item_handlers,
item_metadata=CompositeItemMetadata(**item_metadata),
return self._get_metadata_base(
directory=directory,
item_names=existing_items,
)

def finalize(self, directory: epath.Path):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,68 @@ def test_metadata_existing_items_updates_step_metadata(self):
)
self.assertIsNotNone(step_metadata.item_metadata['state'])

def test_metadata_with_temporary_directories(self):
handler = CompositeCheckpointHandler()
handler.save(
self.directory,
CompositeArgs(
state=args_lib.StandardSave({'a': 1, 'b': 2}),
),
)

# Make sure 'state' tmp dir is created.
self.assertFalse((self.directory / 'state').exists())
existing_items = handler._existing_items(self.directory)
item_names = [
item_dir.split(step.TMP_DIR_SUFFIX, 1)[0]
for item_dir in existing_items
]
self.assertIn('state', item_names)

with self.assertRaises(ValueError):
handler.metadata(self.directory)

handler.finalize(self.directory)
step_metadata = handler.metadata(self.directory)
self.assertDictEqual(
step_metadata.item_handlers,
{
'state': StandardCheckpointHandler().typestr(),
}
)
self.assertIsNotNone(step_metadata.item_metadata['state'])

def test_metadata_from_temporary_paths(self):
handler = CompositeCheckpointHandler()
handler.save(
self.directory,
CompositeArgs(
state=args_lib.StandardSave({'a': 1, 'b': 2}),
),
)

# Make sure 'state' temp dir is created.
self.assertFalse((self.directory / 'state').exists())
existing_items = handler._existing_items(self.directory)
item_names = [
item_dir.split(step.TMP_DIR_SUFFIX, 1)[0]
for item_dir in existing_items
]
self.assertIn('state', item_names)

step_metadata = handler.metadata_from_temporary_paths(self.directory)
self.assertDictEqual(
step_metadata.item_handlers,
{
'state': StandardCheckpointHandler().typestr(),
},
)

handler.finalize(self.directory)
# Temporary files are absent after finalize.
step_metadata = handler.metadata_from_temporary_paths(self.directory)
self.assertEmpty(step_metadata.item_handlers)

def test_finalize(self):
state_handler = mock.create_autospec(StandardCheckpointHandler)
metadata_handler = mock.create_autospec(JsonCheckpointHandler)
Expand Down
Loading