Skip to content

Commit

Permalink
Add step utility to get item name from item directory.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718511442
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Jan 22, 2025
1 parent 0500787 commit f44afe0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
10 changes: 10 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_GCS_PATH_PREFIX = ('gs://',)
_COMMIT_SUCCESS_FILE = 'commit_success.txt'
TMP_DIR_SUFFIX = '.orbax-checkpoint-tmp-'
TMP_DIR_NAME_PATTERN = r'^(.+?)\.orbax-checkpoint-tmp-\d+$'
# prefix_1000.orbax-checkpoint-tmp-1010101
# OR
# 1000.orbax-checkpoint-tmp-1010101
Expand Down Expand Up @@ -517,6 +518,15 @@ def is_tmp_checkpoint(path: epath.PathLike) -> bool:
return False


def item_name_from_item_dir(item_dir: epath.PathLike) -> str:
"""Returns the item name from a item's directory (which may be temporary)."""
name = epath.Path(item_dir).name
if tmp_match := re.match(TMP_DIR_NAME_PATTERN, name):
return tmp_match.group(1)
else:
return name


def is_checkpoint_finalized(path: epath.PathLike) -> bool:
"""Determines if the given path represents a finalized checkpoint.
Expand Down
21 changes: 21 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/step_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,27 @@ def test_step_metadata_of_checkpoint_path(self):
)


@parameterized.parameters(
('0.orbax-checkpoint-tmp-1010101', '0'),
('foobar.orbax-checkpoint-tmp-124244', 'foobar'),
('foobar_000505.orbax-checkpoint-tmp-13124', 'foobar_000505'),
(epath.Path('foobar_000505.orbax-checkpoint-tmp-13124'), 'foobar_000505'),
)
def test_item_name_from_item_dir(self, item_dir, expected_item_name):
self.assertEqual(
step_lib.item_name_from_item_dir(item_dir), expected_item_name
)

@parameterized.parameters(
('abc',),
('.orbax-checkpoint-tmp-191913',),
('0.orbax-checkpoint-tmp-',),
(epath.Path('0.orbax-checkpoint-tmp-'),),
)
def test_item_name_from_item_dir_invalid(self, item_dir):
with self.assertRaises(ValueError):
step_lib.item_name_from_item_dir(item_dir)


if __name__ == '__main__':
absltest.main()

0 comments on commit f44afe0

Please sign in to comment.