diff --git a/src/ffmpeg/common/serialize.py b/src/ffmpeg/common/serialize.py index 28a04a38..5daf7599 100644 --- a/src/ffmpeg/common/serialize.py +++ b/src/ffmpeg/common/serialize.py @@ -1,99 +1,76 @@ from __future__ import absolute_import, annotations -import datetime import importlib import json from dataclasses import fields, is_dataclass from enum import Enum +from functools import partial from pathlib import Path from typing import Any -def load_class(path: str) -> Any: +def load_class(path: str, strict: bool = True) -> Any: """ Load a class from a string path Args: path: The path to the class. + strict: If True, raise an error if the class is not in ffmpeg package. Returns: The class. """ + if strict: + assert path.startswith("ffmpeg."), f"Only support loading class from ffmpeg package: {path}" + module_path, class_name = path.rsplit(".", 1) module = importlib.import_module(module_path) return getattr(module, class_name) -class Encoder(json.JSONEncoder): +def frozen(v: Any) -> Any: """ - Extend JSON encoder to support more type + Convert the instance to a frozen instance - Note: - This encoder supports: - - Enum - - datetime.datetime - - dataclass - """ + Args: + v: The instance to convert. - def default(self, obj: Any) -> Any: - if isinstance(obj, Enum): - return { - "__class__": f"{obj.__class__.__module__}.{obj.__class__.__name__}", - "value": obj.value, - } - elif isinstance(obj, datetime.datetime): - return obj.strftime("%Y-%m-%d %H:%M:%S.%f%z") - elif is_dataclass(obj): - output = {} - for field in fields(obj): - v = getattr(obj, field.name) - output[field.name] = self.default(v) - - return { - "__class__": f"{obj.__class__.__module__}.{obj.__class__.__name__}", - **output, - } - return super().default(obj) - - -class Decoder(json.JSONDecoder): + Returns: + The frozen instance. """ - Extend JSON decoder to support more type + if isinstance(v, list): + return tuple(frozen(i) for i in v) - Note: - This decoder supports: - - Enum - - datetime.datetime - - dataclass - """ + if isinstance(v, dict): + return tuple((key, frozen(value)) for key, value in v.items()) - def __init__(self, *args: Any, **kwargs: Any): - json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + return v - def frozen(self, v: Any) -> Any: - if isinstance(v, list): - return tuple(self.frozen(i) for i in v) - if isinstance(v, dict): - return tuple((key, self.frozen(value)) for key, value in v.items()) +def object_hook(obj: Any, strict: bool = True) -> Any: + """ + Convert the dictionary to an instance - return v + Args: + obj: The dictionary to convert. - def object_hook(self, obj: Any) -> Any: # pylint: disable=method-hidden - if isinstance(obj, dict): - if obj.get("__class__"): - cls = load_class(obj.pop("__class__")) + Returns: + The instance. + """ + if isinstance(obj, dict): + if obj.get("__class__"): + cls = load_class(obj.pop("__class__"), strict=strict) - if is_dataclass(cls): - # NOTE: in our application, the dataclass is always frozen - return cls(**{k: self.frozen(v) for k, v in obj.items()}) + if is_dataclass(cls): + # NOTE: in our application, the dataclass is always frozen + return cls(**{k: frozen(v) for k, v in obj.items()}) - return cls(**{k: v for k, v in obj.items()}) + return cls(**{k: v for k, v in obj.items()}) - return obj + return obj -def loads(raw: str) -> Any: +def loads(raw: str, strict: bool = True) -> Any: """ Deserialize the JSON string to an instance @@ -103,8 +80,9 @@ def loads(raw: str) -> Any: Returns: The deserialized instance. """ + object_hook_strict = partial(object_hook, strict=strict) - return json.loads(raw, cls=Decoder) + return json.loads(raw, object_hook=object_hook_strict) def to_dict_with_class_info(instance: Any) -> Any: diff --git a/src/scripts/cache.py b/src/scripts/cache.py index 63f11f4a..178f8857 100644 --- a/src/scripts/cache.py +++ b/src/scripts/cache.py @@ -13,7 +13,7 @@ def load(cls: type[T], id: str) -> T: path = cache_path / f"{cls.__name__}/{id}.json" with path.open() as ifile: - obj = loads(ifile.read()) + obj = loads(ifile.read(), strict=False) return obj