Skip to content

Commit

Permalink
Add serialize_for_update() method for StepMetadata.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713417911
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Jan 9, 2025
1 parent 95459f7 commit 429326d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
52 changes: 52 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,58 @@ def test_validate_dict_entry_wrong_value_type(
):
self.deserialize_metadata(StepMetadata, step_metadata)

@parameterized.parameters(
({'item_handlers': {'a': 'b'}},),
({'performance_metrics': {'a': 1.0}},),
({'user_metadata': {'a': 1}, 'init_timestamp_nsecs': 1},),
)
def test_serialize_for_update_valid_kwargs(
self, kwargs: dict[str, Any]
):
self.assertEqual(
step_metadata_serialization.serialize_for_update(**kwargs),
kwargs,
)

@parameterized.parameters(
({'item_handlers': list()},),
({'item_handlers': {int(): None}},),
({'metrics': list()},),
({'metrics': {int(): None}},),
({'performance_metrics': list()},),
({'init_timestamp_nsecs': float()},),
({'commit_timestamp_nsecs': float()},),
({'user_metadata': list()},),
({'user_metadata': {int(): None}},),
)
def test_serialize_for_update_wrong_types(
self, kwargs: dict[str, Any]
):
with self.assertRaises(ValueError):
step_metadata_serialization.serialize_for_update(**kwargs)

def test_serialize_for_update_with_unknown_kwargs(self):
with self.assertRaisesRegex(
ValueError, 'Provided metadata contains unknown key blah'
):
step_metadata_serialization.serialize_for_update(
user_metadata={'a': 1},
blah=123,
)

def test_serialize_for_update_performance_metrics_only_float(self):
self.assertEqual(
step_metadata_serialization.serialize_for_update(
performance_metrics=StepStatistics(
step=1,
event_type='save',
reached_preemption=False,
preemption_received_at=1.0,
)
),
{'performance_metrics': {'preemption_received_at': 1.0}},
)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,73 @@ def serialize(metadata: StepMetadata) -> SerializedMetadata:
}


# TODO(adamcogdell): Reduce code duplication with deserialize().
def serialize_for_update(**kwargs) -> SerializedMetadata:
"""Validates and serializes `kwargs` to a dictionary.
To be used with MetadataStore.update().
Args:
**kwargs: The kwargs to be serialized.
Returns:
A dictionary of the serialized kwargs.
"""
validated_kwargs = {}

if 'item_handlers' in kwargs:
utils.validate_field(kwargs, 'item_handlers', [dict, str])
item_handlers = kwargs.get('item_handlers')
if isinstance(item_handlers, CompositeCheckpointHandlerTypeStrs):
for k in kwargs.get('item_handlers'):
utils.validate_dict_entry(kwargs, 'item_handlers', k, str)
validated_kwargs['item_handlers'] = item_handlers
elif isinstance(item_handlers, CheckpointHandlerTypeStr):
validated_kwargs['item_handlers'] = item_handlers

if 'metrics' in kwargs:
utils.validate_field(kwargs, 'metrics', dict)
for k in kwargs.get('metrics', {}) or {}:
utils.validate_dict_entry(kwargs, 'metrics', k, str)
validated_kwargs['metrics'] = kwargs.get('metrics', {})

if 'performance_metrics' in kwargs:
utils.validate_field(kwargs, 'performance_metrics', [dict, StepStatistics])
performance_metrics = kwargs.get('performance_metrics', {})
if isinstance(performance_metrics, StepStatistics):
performance_metrics = dataclasses.asdict(performance_metrics)
float_metrics = {
metric: val
for metric, val in performance_metrics.items()
if isinstance(val, float)
}
validated_kwargs['performance_metrics'] = float_metrics

if 'init_timestamp_nsecs' in kwargs:
utils.validate_field(kwargs, 'init_timestamp_nsecs', int)
validated_kwargs['init_timestamp_nsecs'] = (
kwargs.get('init_timestamp_nsecs', None)
)

if 'commit_timestamp_nsecs' in kwargs:
utils.validate_field(kwargs, 'commit_timestamp_nsecs', int)
validated_kwargs['commit_timestamp_nsecs'] = (
kwargs.get('commit_timestamp_nsecs', None)
)

if 'user_metadata' in kwargs:
utils.validate_field(kwargs, 'user_metadata', dict)
for k in kwargs.get('user_metadata', {}) or {}:
utils.validate_dict_entry(kwargs, 'user_metadata', k, str)
validated_kwargs['user_metadata'] = kwargs.get('user_metadata', {})

for k in kwargs:
if k not in validated_kwargs:
raise ValueError('Provided metadata contains unknown key %s.' % k)

return validated_kwargs


def deserialize(
metadata_dict: SerializedMetadata,
item_metadata: CompositeItemMetadata | SingleItemMetadata | None = None,
Expand Down

0 comments on commit 429326d

Please sign in to comment.