Skip to content

Commit

Permalink
Support custom PyTree metadata. Standardize naming of the "custom met…
Browse files Browse the repository at this point in the history
…adata" field (user-supplied metadata) as `custom_metadata`.

PiperOrigin-RevId: 718050751
  • Loading branch information
cpgaffney1 authored and pax authors committed Jan 21, 2025
1 parent 5827e24 commit 3bbff47
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions paxml/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,19 @@ class PaxCheckpointHandlerImpl(ocp.BasePyTreeCheckpointHandler):
def _write_metadata_file(
self,
directory: epath.Path,
item: PyTree,
*,
param_infos: PyTree,
save_args: PyTree,
custom_metadata: Any | None = None,
use_zarr3: bool = False,
):
if self._use_ocdbt:
return super()._write_metadata_file(directory, item, save_args, use_zarr3)
return super()._write_metadata_file(
directory,
param_infos=param_infos,
save_args=save_args,
use_zarr3=use_zarr3,
)
return ocp.future.NoopFuture()

def _read_metadata_file(self, directory: epath.Path) -> Any:
Expand Down Expand Up @@ -782,8 +789,10 @@ class FlaxCheckpointHandlerImpl(ocp.BasePyTreeCheckpointHandler):
def _write_metadata_file(
self,
directory: epath.Path,
item: PyTree,
*,
param_infos: PyTree,
save_args: PyTree,
custom_metadata: Any | None = None,
use_zarr3: bool = False,
):
pass
Expand Down

0 comments on commit 3bbff47

Please sign in to comment.