Skip to content

Commit

Permalink
add serialization and deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
biagiodistefano committed Nov 28, 2024
1 parent d90f713 commit 56b1e6e
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ dmypy.json
*.sqlite3
/src/static
/src/deployments.info
/src/playground.py
/src/media/
server_setup/

Expand Down
46 changes: 41 additions & 5 deletions src/rule_engine/rule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import re
import typing as t
from enum import Enum
Expand All @@ -24,7 +25,7 @@ class Operator(str, Enum):
NE = "ne"
EQ = "eq"
REGEX = "regex"
FUNC = "func"
# FUNC = "func"


AND, OR = "AND", "OR"
Expand Down Expand Up @@ -59,7 +60,7 @@ def _regex(field_value: t.Any, pattern: t.Any) -> bool:
raise ValueError("The value for the `REGEX` operator must be a string or a compiled regex pattern.")


def _func(field_value: t.Any, func: t.Callable[[t.Any], bool]) -> bool:
def _func(field_value: t.Any, func: t.Callable[[t.Any], bool]) -> bool: # pragma: no cover
if callable(func):
return func(field_value)
raise ValueError("The value for the `FUNC` operator must be a callable.")
Expand All @@ -84,7 +85,7 @@ def _func(field_value: t.Any, func: t.Callable[[t.Any], bool]) -> bool:
Operator.NE: lambda fv, cv: fv != cv,
Operator.EQ: lambda fv, cv: fv == cv,
Operator.REGEX: _regex,
Operator.FUNC: _func,
# Operator.FUNC: _func,
}


Expand Down Expand Up @@ -151,13 +152,12 @@ def _eval() -> bool:
return condition.evaluate(example)
else:
for key, value in condition.items():
print(key, value)
if "__" in key:
field, op = key.split("__", 1)
if not self._evaluate_operator(op, example.get(field, None), value):
return False
else:
if key not in example or example[key] != value:
if not self._evaluate_operator("eq", example.get(key, None), value):
return False
return True

Expand Down Expand Up @@ -189,6 +189,42 @@ def evaluate(self, example: t.Dict[str, t.Any]) -> bool:

return result if result is not None else False

def to_dict(self) -> dict[str, t.Any]:
return {
"$rule": True,
"id": self.id,
"negated": self.negated,
"conditions": [
{"operator": op, "condition": cond.to_dict() if isinstance(cond, Rule) else cond}
for op, cond in self.conditions
],
}

@classmethod
def from_dict(cls, data: dict[str, t.Any]) -> "Rule":
rule = cls()
if not data.get("$rule"):
raise ValueError("Invalid rule data")
rule._id = data["id"]
rule._negated = data["negated"]
for cond in data["conditions"]:
operator = cond["operator"]
condition = cond["condition"]
if isinstance(condition, dict) and condition.get("$rule"):
condition = cls.from_dict(condition)
rule.conditions.append((operator, condition))
return rule

def to_json(self, *args: t.Any, **kwargs: t.Any) -> str:
"""Serialize the Rule to a JSON string."""
return json.dumps(self.to_dict(), *args, **kwargs)

@classmethod
def from_json(cls, json_str: str) -> "Rule":
"""Deserialize a Rule from a JSON string."""
data = json.loads(json_str)
return cls.from_dict(data)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(conditions={self.conditions}, negated={self.negated})"

Expand Down
27 changes: 24 additions & 3 deletions src/rule_engine/tests/test_rule_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
(Operator.EQ, 3, 4, False),
(Operator.REGEX, "hello123", r"\w+\d+", True),
(Operator.REGEX, "hello", r"\d+", False),
(Operator.FUNC, "hello", lambda x: x.startswith("he"), True),
(Operator.FUNC, "hello", lambda x: x.endswith("lo"), True),
# (Operator.FUNC, "hello", lambda x: x.startswith("he"), True),
# (Operator.FUNC, "hello", lambda x: x.endswith("lo"), True),
],
)
def test_operator_evaluation(operator: str, field_value: t.Any, condition_value: t.Any, expected: bool) -> None:
Expand All @@ -56,7 +56,7 @@ def test_operator_evaluation(operator: str, field_value: t.Any, condition_value:
(Operator.STARTSWITH, 5, "hello"),
(Operator.ENDSWITH, "hello", 5),
(Operator.REGEX, 5, "hello"),
(Operator.FUNC, 5, "hello"),
# (Operator.FUNC, 5, "hello"),
),
)
def test_operator_evaluation_value_error(operator: str, field_value: t.Any, condition_value: t.Any) -> None:
Expand Down Expand Up @@ -168,3 +168,24 @@ def test_and_value_error() -> None:
def test_or_value_error() -> None:
with pytest.raises(ValueError):
Rule() | "invalid_rule" # type: ignore[operator]


def test_to_json_and_from_json() -> None:
rule = Rule(Rule(foo="bar") | Rule(foo="baz"), name="John", age__gte=21)
rule_json = rule.to_json()
loaded_rule = Rule.from_json(rule_json)
assert rule.to_dict() == loaded_rule.to_dict()
example_true = {"foo": "bar", "name": "John", "age": 22}
example_false = {"foo": "qux", "name": "Jane", "age": 19}
assert evaluate(rule, example_true)
assert not evaluate(rule, example_false)
assert evaluate(loaded_rule, example_true)
assert not evaluate(loaded_rule, example_false)


def test_to_load_rule_invalid() -> None:
rule = Rule(Rule(foo="bar"))
rule_json = rule.to_dict()
rule_json.pop("$rule")
with pytest.raises(ValueError):
Rule.from_dict(rule_json)

0 comments on commit 56b1e6e

Please sign in to comment.