Skip to content

Commit

Permalink
better serialization error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
amyasnikov committed Sep 28, 2024
1 parent 6112967 commit 9900d85
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 9 deletions.
3 changes: 2 additions & 1 deletion validity/compliance/serialization/routeros.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from typing import Generator, Literal

from validity.utils.misc import reraise
from validity.utils.misc import log_exceptions, reraise
from ..exceptions import SerializationError


Expand Down Expand Up @@ -146,5 +146,6 @@ def parse_config(plain_config: str) -> dict:
return result


@log_exceptions(logger, level="info", log_traceback=True)
def serialize_ros(plain_data: str, template: str, parameters: dict):
return parse_config(plain_data)
6 changes: 6 additions & 0 deletions validity/compliance/serialization/textfsm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import io
import logging

import textfsm

from validity.utils.misc import log_exceptions
from .common import postprocess_jq


logger = logging.getLogger(__name__)


@log_exceptions(logger, "info", log_traceback=True)
@postprocess_jq
def serialize_textfsm(plain_data: str, template: str, parameters: dict) -> list[dict]:
dict_results = []
Expand Down
7 changes: 7 additions & 0 deletions validity/compliance/serialization/ttp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import logging

from ttp import ttp

from validity.utils.misc import log_exceptions
from .common import postprocess_jq


logger = logging.getLogger(__name__)


@log_exceptions(logger, "info", log_traceback=True)
@postprocess_jq
def serialize_ttp(plain_data: str, template: str, parameters: dict):
parser = ttp(data=plain_data, template=template)
Expand Down
9 changes: 7 additions & 2 deletions validity/compliance/serialization/xml.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import logging
from xml.parsers.expat import ExpatError

import xmltodict

from validity.utils.json import transform_json
from validity.utils.misc import reraise
from validity.utils.misc import log_exceptions, reraise
from ..exceptions import SerializationError
from .common import postprocess_jq


logger = logging.getLogger(__name__)


@log_exceptions(logger, "info", log_traceback=True)
@postprocess_jq
def serialize_xml(plain_data: str, template: str, parameters: dict):
with reraise(ExpatError, SerializationError, "Got invalid XML"):
with reraise(ExpatError, SerializationError, "Got invalid XML", orig_error_param=None):
result = xmltodict.parse(plain_data)
if parameters.get("drop_attributes"):
result = transform_json(
Expand Down
2 changes: 1 addition & 1 deletion validity/compliance/serialization/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@

@postprocess_jq
def serialize_yaml(plain_data: str, template: str, parameters: dict) -> dict:
with reraise(yaml.YAMLError, SerializationError, "Got invalid JSON/YAML"):
with reraise(yaml.YAMLError, SerializationError, "Got invalid JSON/YAML", orig_error_param=None):
return yaml.safe_load(plain_data)
3 changes: 1 addition & 2 deletions validity/compliance/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def from_commands(cls, commands: Iterable["Command"]):

def with_config(self, serializable: Serializable):
state_item = StateItem(serializer=serializable.serializer, data_file=serializable.data_file, command=None)
with suppress(SerializationError):
state_item.serialized # noqa: B018
if state_item.error is None or self.config_command_label is None:
super().__setitem__("config", state_item)
self.config_command_label = None
return self
Expand Down
2 changes: 1 addition & 1 deletion validity/pollers/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CommandResult:
error_header: ClassVar[str] = "POLLING ERROR\n"

def __post_init__(self):
assert self.result or self.error is not None
assert self.result or self.error is not None, "Got empty result from device"

foldername = property(lambda self: slugify(str(self.device)))
filename = property(lambda self: self.command.label + ".txt")
Expand Down
18 changes: 18 additions & 0 deletions validity/tests/test_compliance/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,21 @@ def test_with_config(self):
assert state.get_full_item("config").command is None
cfg_item2 = StateItem(SerializerDBFactory(), DataFileFactory(), None)
assert state.with_config(cfg_item2).get_full_item("config") == cfg_item2

@pytest.mark.django_db
def test_with_config_errored(self):
cfg_serializable = Serializable(SerializerDBFactory(), None)
state = State({}).with_config(cfg_serializable)
assert state.get_full_item("config") == StateItem(
serializer=cfg_serializable.serializer, data_file=None, command=None
)
with pytest.raises(NoComponentError):
state["config"]

@pytest.mark.django_db
def test_with_config_no_override(self):
cfg_serializable = Serializable(SerializerDBFactory(), None)
command_config = state_item("item1", {"some": "config"})
state = State({"config": command_config}, config_command_label="item1").with_config(cfg_serializable)
assert state.get_full_item("config") == command_config
assert state["config"] == {"some": "config"}
14 changes: 13 additions & 1 deletion validity/tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import operator
from contextlib import nullcontext
from dataclasses import dataclass
from unittest.mock import Mock

import pytest

from validity.utils.misc import partialcls, reraise
from validity.utils.misc import log_exceptions, partialcls, reraise
from validity.utils.version import NetboxVersion


Expand Down Expand Up @@ -77,3 +78,14 @@ class A:
assert A2(5) == A(5, 10)
assert A2(a=3, b=4) == A(3, 4)
assert type(A2(1)) is A


def test_log_exceptions():
logger = Mock()
with log_exceptions(logger, "info", log_traceback=True):
pass
logger.info.assert_not_called()
with pytest.raises(ValueError):
with log_exceptions(logger, "info", log_traceback=True):
raise ValueError("qwerty")
logger.info.assert_called_once_with(msg="qwerty", exc_info=True)
16 changes: 15 additions & 1 deletion validity/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, suppress
from itertools import islice
from logging import Logger
from typing import TYPE_CHECKING, Any, Callable, Iterable

from core.exceptions import SyncError
Expand Down Expand Up @@ -31,7 +32,7 @@ def reraise(
catch: type[Exception] | tuple[type[Exception], ...],
raise_: type[Exception],
*args,
orig_error_param="orig_error",
orig_error_param: str | None = "orig_error",
**kwargs,
):
"""
Expand Down Expand Up @@ -98,3 +99,16 @@ def __new__(_, *new_args, **new_kwargs):
return cls(*new_args, **new_kwargs)

return type(cls.__name__, (cls,), {"__new__": __new__})


@contextmanager
def log_exceptions(logger: Logger, level: str, log_traceback=True):
"""
Log exceptions of a function/method/codeblock
"""
try:
yield
except Exception as exc:
log_method = getattr(logger, level)
log_method(msg=str(exc), exc_info=log_traceback)
raise

0 comments on commit 9900d85

Please sign in to comment.