Skip to content

Commit

Permalink
Handle FlattenedIndexKey following Jax tree_util fix/update.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 518109319
  • Loading branch information
laurentes authored and copybara-github committed Mar 20, 2023
1 parent 7010502 commit 089af73
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def get_key_name(key: Any) -> Union[int, str]:
return str(key.key)
elif isinstance(key, jax.tree_util.GetAttrKey):
return key.name
elif isinstance(key, jax.tree_util.FlattenedIndexKey):
return key.key
else:
raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"')

Expand All @@ -119,7 +121,9 @@ def _is_dict_key(key) -> bool:


def _is_sequence_key(key) -> bool:
return isinstance(key, jax.tree_util.SequenceKey)
return isinstance(
key, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey)
)


def _raise_unsupported_key_error(key):
Expand Down

0 comments on commit 089af73

Please sign in to comment.