Skip to content

Commit

Permalink
only allow load class in ffmpeg module (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucemia authored Mar 18, 2024
1 parent c7d2f27 commit f5eead2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 59 deletions.
94 changes: 36 additions & 58 deletions src/ffmpeg/common/serialize.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,76 @@
from __future__ import absolute_import, annotations

import datetime
import importlib
import json
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any


def load_class(path: str) -> Any:
def load_class(path: str, strict: bool = True) -> Any:
"""
Load a class from a string path
Args:
path: The path to the class.
strict: If True, raise an error if the class is not in ffmpeg package.
Returns:
The class.
"""
if strict:
assert path.startswith("ffmpeg."), f"Only support loading class from ffmpeg package: {path}"

module_path, class_name = path.rsplit(".", 1)
module = importlib.import_module(module_path)
return getattr(module, class_name)


class Encoder(json.JSONEncoder):
def frozen(v: Any) -> Any:
"""
Extend JSON encoder to support more type
Convert the instance to a frozen instance
Note:
This encoder supports:
- Enum
- datetime.datetime
- dataclass
"""
Args:
v: The instance to convert.
def default(self, obj: Any) -> Any:
if isinstance(obj, Enum):
return {
"__class__": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
"value": obj.value,
}
elif isinstance(obj, datetime.datetime):
return obj.strftime("%Y-%m-%d %H:%M:%S.%f%z")
elif is_dataclass(obj):
output = {}
for field in fields(obj):
v = getattr(obj, field.name)
output[field.name] = self.default(v)

return {
"__class__": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
**output,
}
return super().default(obj)


class Decoder(json.JSONDecoder):
Returns:
The frozen instance.
"""
Extend JSON decoder to support more type
if isinstance(v, list):
return tuple(frozen(i) for i in v)

Note:
This decoder supports:
- Enum
- datetime.datetime
- dataclass
"""
if isinstance(v, dict):
return tuple((key, frozen(value)) for key, value in v.items())

def __init__(self, *args: Any, **kwargs: Any):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
return v

def frozen(self, v: Any) -> Any:
if isinstance(v, list):
return tuple(self.frozen(i) for i in v)

if isinstance(v, dict):
return tuple((key, self.frozen(value)) for key, value in v.items())
def object_hook(obj: Any, strict: bool = True) -> Any:
"""
Convert the dictionary to an instance
return v
Args:
obj: The dictionary to convert.
def object_hook(self, obj: Any) -> Any: # pylint: disable=method-hidden
if isinstance(obj, dict):
if obj.get("__class__"):
cls = load_class(obj.pop("__class__"))
Returns:
The instance.
"""
if isinstance(obj, dict):
if obj.get("__class__"):
cls = load_class(obj.pop("__class__"), strict=strict)

if is_dataclass(cls):
# NOTE: in our application, the dataclass is always frozen
return cls(**{k: self.frozen(v) for k, v in obj.items()})
if is_dataclass(cls):
# NOTE: in our application, the dataclass is always frozen
return cls(**{k: frozen(v) for k, v in obj.items()})

return cls(**{k: v for k, v in obj.items()})
return cls(**{k: v for k, v in obj.items()})

return obj
return obj


def loads(raw: str) -> Any:
def loads(raw: str, strict: bool = True) -> Any:
"""
Deserialize the JSON string to an instance
Expand All @@ -103,8 +80,9 @@ def loads(raw: str) -> Any:
Returns:
The deserialized instance.
"""
object_hook_strict = partial(object_hook, strict=strict)

return json.loads(raw, cls=Decoder)
return json.loads(raw, object_hook=object_hook_strict)


def to_dict_with_class_info(instance: Any) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def load(cls: type[T], id: str) -> T:
path = cache_path / f"{cls.__name__}/{id}.json"

with path.open() as ifile:
obj = loads(ifile.read())
obj = loads(ifile.read(), strict=False)
return obj


Expand Down

0 comments on commit f5eead2

Please sign in to comment.