From 089af73aa7b0d89dc38e1d0c5f8f05e6437a1e7b Mon Sep 17 00:00:00 2001 From: Laurent El Shafey Date: Mon, 20 Mar 2023 16:31:12 -0700 Subject: [PATCH] Handle FlattenedIndexKey following Jax tree_util fix/update. PiperOrigin-RevId: 518109319 --- orbax/checkpoint/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/orbax/checkpoint/utils.py b/orbax/checkpoint/utils.py index 5be91cbd9..38b920178 100644 --- a/orbax/checkpoint/utils.py +++ b/orbax/checkpoint/utils.py @@ -110,6 +110,8 @@ def get_key_name(key: Any) -> Union[int, str]: return str(key.key) elif isinstance(key, jax.tree_util.GetAttrKey): return key.name + elif isinstance(key, jax.tree_util.FlattenedIndexKey): + return key.key else: raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"') @@ -119,7 +121,9 @@ def _is_dict_key(key) -> bool: def _is_sequence_key(key) -> bool: - return isinstance(key, jax.tree_util.SequenceKey) + return isinstance( + key, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey) + ) def _raise_unsupported_key_error(key):