diff --git a/validity/compliance/serialization/routeros.py b/validity/compliance/serialization/routeros.py index 720a124..73c11c7 100644 --- a/validity/compliance/serialization/routeros.py +++ b/validity/compliance/serialization/routeros.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from typing import Generator, Literal -from validity.utils.misc import reraise +from validity.utils.misc import log_exceptions, reraise from ..exceptions import SerializationError @@ -146,5 +146,6 @@ def parse_config(plain_config: str) -> dict: return result +@log_exceptions(logger, level="info", log_traceback=True) def serialize_ros(plain_data: str, template: str, parameters: dict): return parse_config(plain_data) diff --git a/validity/compliance/serialization/textfsm.py b/validity/compliance/serialization/textfsm.py index 814cd3b..35b2302 100644 --- a/validity/compliance/serialization/textfsm.py +++ b/validity/compliance/serialization/textfsm.py @@ -1,10 +1,16 @@ import io +import logging import textfsm +from validity.utils.misc import log_exceptions from .common import postprocess_jq +logger = logging.getLogger(__name__) + + +@log_exceptions(logger, "info", log_traceback=True) @postprocess_jq def serialize_textfsm(plain_data: str, template: str, parameters: dict) -> list[dict]: dict_results = [] diff --git a/validity/compliance/serialization/ttp.py b/validity/compliance/serialization/ttp.py index 0018a5a..da2fbfe 100644 --- a/validity/compliance/serialization/ttp.py +++ b/validity/compliance/serialization/ttp.py @@ -1,8 +1,15 @@ +import logging + from ttp import ttp +from validity.utils.misc import log_exceptions from .common import postprocess_jq +logger = logging.getLogger(__name__) + + +@log_exceptions(logger, "info", log_traceback=True) @postprocess_jq def serialize_ttp(plain_data: str, template: str, parameters: dict): parser = ttp(data=plain_data, template=template) diff --git a/validity/compliance/serialization/xml.py b/validity/compliance/serialization/xml.py index f1eff8c..2e6e3c7 100644 --- a/validity/compliance/serialization/xml.py +++ b/validity/compliance/serialization/xml.py @@ -1,16 +1,21 @@ +import logging from xml.parsers.expat import ExpatError import xmltodict from validity.utils.json import transform_json -from validity.utils.misc import reraise +from validity.utils.misc import log_exceptions, reraise from ..exceptions import SerializationError from .common import postprocess_jq +logger = logging.getLogger(__name__) + + +@log_exceptions(logger, "info", log_traceback=True) @postprocess_jq def serialize_xml(plain_data: str, template: str, parameters: dict): - with reraise(ExpatError, SerializationError, "Got invalid XML"): + with reraise(ExpatError, SerializationError, "Got invalid XML", orig_error_param=None): result = xmltodict.parse(plain_data) if parameters.get("drop_attributes"): result = transform_json( diff --git a/validity/compliance/serialization/yaml.py b/validity/compliance/serialization/yaml.py index 817aacd..ce30e6f 100644 --- a/validity/compliance/serialization/yaml.py +++ b/validity/compliance/serialization/yaml.py @@ -7,5 +7,5 @@ @postprocess_jq def serialize_yaml(plain_data: str, template: str, parameters: dict) -> dict: - with reraise(yaml.YAMLError, SerializationError, "Got invalid JSON/YAML"): + with reraise(yaml.YAMLError, SerializationError, "Got invalid JSON/YAML", orig_error_param=None): return yaml.safe_load(plain_data) diff --git a/validity/compliance/state.py b/validity/compliance/state.py index 27c6da5..43d7b69 100644 --- a/validity/compliance/state.py +++ b/validity/compliance/state.py @@ -67,8 +67,7 @@ def from_commands(cls, commands: Iterable["Command"]): def with_config(self, serializable: Serializable): state_item = StateItem(serializer=serializable.serializer, data_file=serializable.data_file, command=None) - with suppress(SerializationError): - state_item.serialized # noqa: B018 + if state_item.error is None or self.config_command_label is None: super().__setitem__("config", state_item) self.config_command_label = None return self diff --git a/validity/pollers/result.py b/validity/pollers/result.py index f908bbd..f02ca37 100644 --- a/validity/pollers/result.py +++ b/validity/pollers/result.py @@ -24,7 +24,7 @@ class CommandResult: error_header: ClassVar[str] = "POLLING ERROR\n" def __post_init__(self): - assert self.result or self.error is not None + assert self.result or self.error is not None, "Got empty result from device" foldername = property(lambda self: slugify(str(self.device))) filename = property(lambda self: self.command.label + ".txt") diff --git a/validity/tests/test_compliance/test_state.py b/validity/tests/test_compliance/test_state.py index b1fe966..f4e9c7d 100644 --- a/validity/tests/test_compliance/test_state.py +++ b/validity/tests/test_compliance/test_state.py @@ -93,3 +93,21 @@ def test_with_config(self): assert state.get_full_item("config").command is None cfg_item2 = StateItem(SerializerDBFactory(), DataFileFactory(), None) assert state.with_config(cfg_item2).get_full_item("config") == cfg_item2 + + @pytest.mark.django_db + def test_with_config_errored(self): + cfg_serializable = Serializable(SerializerDBFactory(), None) + state = State({}).with_config(cfg_serializable) + assert state.get_full_item("config") == StateItem( + serializer=cfg_serializable.serializer, data_file=None, command=None + ) + with pytest.raises(NoComponentError): + state["config"] + + @pytest.mark.django_db + def test_with_config_no_override(self): + cfg_serializable = Serializable(SerializerDBFactory(), None) + command_config = state_item("item1", {"some": "config"}) + state = State({"config": command_config}, config_command_label="item1").with_config(cfg_serializable) + assert state.get_full_item("config") == command_config + assert state["config"] == {"some": "config"} diff --git a/validity/tests/test_utils/test_misc.py b/validity/tests/test_utils/test_misc.py index d6a9a2d..3a6853f 100644 --- a/validity/tests/test_utils/test_misc.py +++ b/validity/tests/test_utils/test_misc.py @@ -1,10 +1,11 @@ import operator from contextlib import nullcontext from dataclasses import dataclass +from unittest.mock import Mock import pytest -from validity.utils.misc import partialcls, reraise +from validity.utils.misc import log_exceptions, partialcls, reraise from validity.utils.version import NetboxVersion @@ -77,3 +78,14 @@ class A: assert A2(5) == A(5, 10) assert A2(a=3, b=4) == A(3, 4) assert type(A2(1)) is A + + +def test_log_exceptions(): + logger = Mock() + with log_exceptions(logger, "info", log_traceback=True): + pass + logger.info.assert_not_called() + with pytest.raises(ValueError): + with log_exceptions(logger, "info", log_traceback=True): + raise ValueError("qwerty") + logger.info.assert_called_once_with(msg="qwerty", exc_info=True) diff --git a/validity/utils/misc.py b/validity/utils/misc.py index 834d54a..880faa2 100644 --- a/validity/utils/misc.py +++ b/validity/utils/misc.py @@ -2,6 +2,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, suppress from itertools import islice +from logging import Logger from typing import TYPE_CHECKING, Any, Callable, Iterable from core.exceptions import SyncError @@ -31,7 +32,7 @@ def reraise( catch: type[Exception] | tuple[type[Exception], ...], raise_: type[Exception], *args, - orig_error_param="orig_error", + orig_error_param: str | None = "orig_error", **kwargs, ): """ @@ -98,3 +99,16 @@ def __new__(_, *new_args, **new_kwargs): return cls(*new_args, **new_kwargs) return type(cls.__name__, (cls,), {"__new__": __new__}) + + +@contextmanager +def log_exceptions(logger: Logger, level: str, log_traceback=True): + """ + Log exceptions of a function/method/codeblock + """ + try: + yield + except Exception as exc: + log_method = getattr(logger, level) + log_method(msg=str(exc), exc_info=log_traceback) + raise