diff --git a/validity/compliance/device_config/__init__.py b/validity/compliance/device_config/__init__.py deleted file mode 100644 index 8230d54..0000000 --- a/validity/compliance/device_config/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base import DeviceConfig -from .routeros import RouterOSDeviceConfig -from .ttp import TTPDeviceConfig -from .yaml import YAMLDeviceConfig diff --git a/validity/compliance/device_config/base.py b/validity/compliance/device_config/base.py deleted file mode 100644 index 3c4d8fe..0000000 --- a/validity/compliance/device_config/base.py +++ /dev/null @@ -1,56 +0,0 @@ -from abc import abstractmethod -from dataclasses import dataclass -from datetime import datetime -from typing import TYPE_CHECKING, ClassVar - -from validity.utils.misc import reraise -from ..exceptions import SerializationError - - -if TYPE_CHECKING: - from validity.models import VDevice - - -@dataclass -class BaseDeviceConfig: - device: "VDevice" - plain_config: str - last_modified: datetime | None = None - serialized: dict | list | None = None - - _config_classes: ClassVar[dict[str, type]] = {} - - @classmethod - def from_device(cls, device: "VDevice") -> "BaseDeviceConfig": - """ - Get DeviceConfig from dcim.models.Device - Device MUST be annotated with ".data_file" - Device MUST be annotated with ".serializer" pointing to appropriate config serializer instance - """ - with reraise((AssertionError, FileNotFoundError, AttributeError), SerializationError): - assert getattr( - device, "data_file", None - ), f"{device} has no bound data file. Either there is no data source attached or the file does not exist" - assert getattr(device, "serializer", None), f"{device} has no bound serializer" - return cls._config_classes[device.serializer.extraction_method]._from_device(device) - - @classmethod - def _from_device(cls, device: "VDevice") -> "BaseDeviceConfig": - instance = cls(device, device.data_file.data_as_string, device.data_file.last_updated) - instance.serialize() - return instance - - @abstractmethod - def serialize(self, override: bool = False) -> None: - pass - - -class DeviceConfigMeta(type): - def __init__(cls, name, bases, dct): - if name != "DeviceConfig": - BaseDeviceConfig._config_classes[dct["extract_method"]] = cls - super().__init__(name, bases, dct) - - -class DeviceConfig(BaseDeviceConfig, metaclass=DeviceConfigMeta): - pass diff --git a/validity/compliance/device_config/routeros.py b/validity/compliance/device_config/routeros.py deleted file mode 100644 index 8d88804..0000000 --- a/validity/compliance/device_config/routeros.py +++ /dev/null @@ -1,156 +0,0 @@ -import io -import logging -import re -from dataclasses import dataclass, field -from typing import ClassVar, Generator, Literal - -from validity.utils.misc import reraise -from ..exceptions import SerializationError -from .base import DeviceConfig - - -logger = logging.getLogger(__name__) - - -class LineParsingError(SerializationError): - pass - - -def non_quoted_characters(line: str) -> Generator[tuple[int, str], None, None]: - """ - Generator returns pairs (char_position, char) for each char in line not placed inside the quotes - Quoted substring will be returned as 1 single character with char_position equal to first character - """ - - quote_open = False - quote_start = -1 - for i, char in enumerate(line): - if char == '"' and (not i or line[i - 1] != "\\"): - quote_open = not quote_open - if quote_open: - quote_start = i - else: - yield i, line[quote_start : i + 1] - continue - if quote_open: - continue - yield i, char - - -@dataclass -class ParsedLine: - method: Literal["add", "set"] - find_by: tuple[str, str] | tuple[()] = () - properties: dict[str, str] = field(default_factory=dict) - implicit_name: bool = False - - @classmethod - def _extract_find(cls, line: str) -> tuple[tuple[str, str | int | bool], str]: - find, line = line.split("]", maxsplit=1) - find = find[1:].replace("find", "", 1).strip() - find_key, find_value = find.split("=", maxsplit=1) - find_value = cls._transform_value(find_key, find_value) - return (find_key, find_value), line - - @staticmethod - def _replace_line_breaks(line: str) -> str: - drop_match = re.compile(r"\\\n +") - new_line = [] - backslash_seq = [] - for _, char in non_quoted_characters(line): - if char == "\\" or char in {"\n", " "} and backslash_seq: - backslash_seq.append(char) - continue - if not drop_match.fullmatch("".join(backslash_seq)): - new_line.extend(backslash_seq) - backslash_seq = [] - new_line.append(char) - return "".join(new_line) - - @staticmethod - def _transform_value(key: str, value: str) -> str | int | bool: - if value and len(value) > 2 and value[0] == '"' and value[-1] == '"': - value = value[1:-1] - if key in {"name", "comment"}: - return value - if value.isdigit(): - return int(value) - booleans = {"yes": True, "no": False} - if value in booleans: - return booleans[value] - return value - - @classmethod - def from_plain_text(cls, line: str) -> "ParsedLine": - method, line = line.split(maxsplit=1) - if method not in {"add", "set"}: - raise LineParsingError("Unknown line") - find = () - if line.startswith("["): - find, line = cls._extract_find(line) - properties = {} - sub_start = 0 - implicit_name = False - line = cls._replace_line_breaks(line).strip() - for char_num, char in non_quoted_characters(line + " "): - if char == " ": - kvline = line[sub_start:char_num].strip(" \n") - if kvline and "=" not in kvline: - kvline = "name=" + kvline - implicit_name = True - with reraise(ValueError, LineParsingError, f'"{kvline}" cannot be split into key/value'): - key, value = kvline.split("=", maxsplit=1) - properties[key] = cls._transform_value(key, value) - sub_start = char_num + 1 - return cls(method=method, find_by=find, properties=properties, implicit_name=implicit_name) - - -def parse_config(plain_config: str) -> dict: - result = {} - context_path = [] - prevlines = [] - cfgfile = io.StringIO(plain_config) - for line_num, line in enumerate(cfgfile, start=1): - if line.startswith(("#", ":")) or line == "\n": - continue - if line.startswith("/"): - context_path = line[1:-1].split() - continue - if line.endswith("\\\n"): - prevlines.append(line) - continue - if prevlines: - line = "".join(prevlines) + line - prevlines = [] - try: - parsed_line = ParsedLine.from_plain_text(line) - except LineParsingError as e: - e.args = (e.args[0] + f", config line {line_num}",) + e.args[1:] - raise - current_context = result - for key in context_path: - try: - current_context = current_context[key] - except KeyError: - current_context[key] = {} - current_context = current_context[key] - if parsed_line.find_by or parsed_line.method == "add" or parsed_line.implicit_name: - if "values" not in current_context: - current_context["values"] = [] - current_context["values"].append(parsed_line.properties) - if parsed_line.find_by: - current_context["values"][-1]["find_by"] = [ - {"key": parsed_line.find_by[0], "value": parsed_line.find_by[1]} - ] - else: - current_context["properties"] = parsed_line.properties - return result - - -class RouterOSDeviceConfig(DeviceConfig): - extract_method: ClassVar[str] = "ROUTEROS" - - def serialize(self, override: bool = False) -> None: - if not self.serialized or override: - with reraise(Exception, SerializationError): - self.serialized = parse_config(self.plain_config) diff --git a/validity/compliance/device_config/ttp.py b/validity/compliance/device_config/ttp.py deleted file mode 100644 index 382159a..0000000 --- a/validity/compliance/device_config/ttp.py +++ /dev/null @@ -1,32 +0,0 @@ -from dataclasses import dataclass, field -from typing import ClassVar - -from ttp import ttp - -from validity.utils.misc import reraise -from ..exceptions import SerializationError -from .base import DeviceConfig - - -@dataclass -class TTPTemplate: - name: str - template: str = "" # according to TTP API this may contain a template itself or a filepath - - -@dataclass -class TTPDeviceConfig(DeviceConfig): - extract_method: ClassVar[str] = "TTP" - _template: TTPTemplate = field(init=False) - - def __post_init__(self): - self._template = TTPTemplate( - name=self.device.serializer.name, template=self.device.serializer.effective_template - ) - - def serialize(self, override: bool = False) -> None: - if not self.serialized or override: - parser = ttp(data=self.plain_config, template=self._template.template) - parser.parse() - with reraise(IndexError, SerializationError, f"Invalid parsed config for {self.device}: {parser.result()}"): - self.serialized = parser.result()[0][0] diff --git a/validity/compliance/device_config/yaml.py b/validity/compliance/device_config/yaml.py deleted file mode 100644 index f2867de..0000000 --- a/validity/compliance/device_config/yaml.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import ClassVar - -import yaml - -from validity.utils.misc import reraise -from ..exceptions import SerializationError -from .base import DeviceConfig - - -class YAMLDeviceConfig(DeviceConfig): - extract_method: ClassVar[str] = "YAML" - - def serialize(self, override: bool = False) -> None: - if not self.serialized or override: - with reraise( - yaml.YAMLError, - SerializationError, - f"Trying to parse invalid YAML as device config for {self.device}", - ): - self.serialized = yaml.safe_load(self.plain_config) diff --git a/validity/compliance/state.py b/validity/compliance/state.py index fdc6091..1348dd7 100644 --- a/validity/compliance/state.py +++ b/validity/compliance/state.py @@ -76,7 +76,7 @@ def with_config(self, serializable: Serializable): def _blocked_op(self, *_): raise AttributeError("State is read only") - __setitem__ = __delitem__ = __ior__ = pop = popitem = update = setdefault = clear = _blocked_op + __setitem__ = __delitem__ = pop = popitem = update = setdefault = clear = _blocked_op def __getattr__(self, key): return self[key] diff --git a/validity/managers.py b/validity/managers.py index 25d6491..7994f36 100644 --- a/validity/managers.py +++ b/validity/managers.py @@ -141,7 +141,7 @@ def annotate_datasource_id(self): from validity.models import VDataSource return self.annotate( - bound_source=Cast(KeyTextTransform("config_data_source", "tenant__custom_field_data"), BigIntegerField()) + bound_source=Cast(KeyTextTransform("data_source", "tenant__custom_field_data"), BigIntegerField()) ).annotate( data_source_id=Case( When(bound_source__isnull=False, then=F("bound_source")), diff --git a/validity/migrations/0005_datasources.py b/validity/migrations/0005_datasources.py index 0f112b1..a734264 100644 --- a/validity/migrations/0005_datasources.py +++ b/validity/migrations/0005_datasources.py @@ -51,8 +51,8 @@ def setup_datasource_cf(apps, schema_editor): for cf in datasource_cfs: cf.content_types.set([ContentType.objects.get_for_model(DataSource)]) tenant_cf = CustomField.objects.using(db).create( - name="config_data_source", - label=_("Config Data Source"), + name="data_source", + label=_("Data Source"), description=_("Required by Validity"), type="object", object_type=ContentType.objects.get_for_model(DataSource), diff --git a/validity/models/polling.py b/validity/models/polling.py index 5b5024c..380fad0 100644 --- a/validity/models/polling.py +++ b/validity/models/polling.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from functools import cached_property from typing import Collection @@ -102,19 +101,6 @@ def config_command(self) -> Command | None: def get_backend(self): return get_poller(self.connection_type, self.credentials, self.commands.all()) - def serialize_object(self): - with self.serializable_credentials(): - return super().serialize_object() - - @contextmanager - def serializable_credentials(self): - private_creds = self.private_credentials - try: - self.private_credentials = self.private_credentials.encrypted - yield - finally: - self.private_credentials = private_creds - @staticmethod def validate_commands(connection_type: str, commands: Collection[Command]): # All the commands must be of the matching type diff --git a/validity/tests/conftest.py b/validity/tests/conftest.py index 8a5bb0f..9d730e0 100644 --- a/validity/tests/conftest.py +++ b/validity/tests/conftest.py @@ -31,7 +31,7 @@ def create_custom_fields(db): required=False, ), CustomField( - name="config_data_source", + name="data_source", type="object", object_type=ContentType.objects.get_for_model(DataSource), required=False, @@ -52,6 +52,11 @@ def create_custom_fields(db): type="string", required=False, ), + CustomField( + name="device_command_path", + type="string", + required=False, + ), CustomField( name="poller", type="object", @@ -68,9 +73,9 @@ def create_custom_fields(db): ] ) cfs[1].content_types.set([ContentType.objects.get_for_model(Tenant)]) - for cf in cfs[2:5]: + for cf in cfs[2:6]: cf.content_types.set([ContentType.objects.get_for_model(DataSource)]) - cfs[5].content_types.set( + cfs[6].content_types.set( [ ContentType.objects.get_for_model(Device), ContentType.objects.get_for_model(DeviceType), diff --git a/validity/tests/factories.py b/validity/tests/factories.py index 6b5e7b0..88142ac 100644 --- a/validity/tests/factories.py +++ b/validity/tests/factories.py @@ -7,6 +7,7 @@ from tenancy.models import Tenant from validity import models +from validity.compliance.state import StateItem class DataSourceFactory(DjangoModelFactory): @@ -30,18 +31,6 @@ class Meta: model = models.VDataFile -class ConfigFileFactory(DataFileFactory): - path = "file-1.txt" - source = factory.SubFactory( - DataSourceFactory, - custom_field_data={ - "default": True, - "device_config_path": path, - "web_url": "http://some_url.com/", - }, - ) - - class DataSourceLinkFactory(DjangoModelFactory): data_source = factory.SubFactory(DataSourceFactory) data_file = factory.SubFactory(DataFileFactory, source=data_source, data=factory.SelfAttribute("..contents_bin")) @@ -231,3 +220,18 @@ class PollerFactory(DjangoModelFactory): class Meta: model = models.Poller + + +_NOT_DEFINED = object() + + +def state_item(name, serialized, data_file=_NOT_DEFINED, command=_NOT_DEFINED): + if data_file == _NOT_DEFINED: + data_file = DataFileFactory() + if command == _NOT_DEFINED: + command = CommandFactory() + command.label = name + serializer = SerializerDBFactory() + item = StateItem(serializer, data_file, command) + item.__dict__["serialized"] = serialized + return item diff --git a/validity/tests/test_api.py b/validity/tests/test_api.py index c656118..2b20def 100644 --- a/validity/tests/test_api.py +++ b/validity/tests/test_api.py @@ -1,14 +1,11 @@ from http import HTTPStatus -from unittest.mock import Mock import pytest from base import ApiGetTest, ApiPostGetTest -from django.utils import timezone from factories import ( CommandFactory, CompTestDBFactory, CompTestResultFactory, - ConfigFileFactory, DataFileFactory, DataSourceFactory, DeviceFactory, @@ -18,13 +15,13 @@ PlatformFactory, ReportFactory, SelectorFactory, - SerializerDBFactory, SiteFactory, TagFactory, TenantFactory, + state_item, ) -from validity.compliance.device_config import DeviceConfig +from validity.models import VDevice class TestDBNameSet(ApiPostGetTest): @@ -148,27 +145,26 @@ class TestPoller(ApiPostGetTest): } +@pytest.mark.parametrize("params", [{}, {"fields": ["name", "value"]}, {"name": ["config", "bad_cmd"]}]) @pytest.mark.django_db -def test_get_serialized_config(monkeypatch, admin_client): +def test_get_serialized_state(admin_client, params, monkeypatch): device = DeviceFactory() - config_file = ConfigFileFactory() - device.custom_field_data["serializer"] = SerializerDBFactory().pk - device.save() - device.data_source = config_file.source - lm = timezone.now() - config = DeviceConfig(device=device, plain_config="", last_modified=lm, serialized={"key1": "value1"}) - monkeypatch.setattr(DeviceConfig, "from_device", Mock(return_value=config)) - resp = admin_client.get(f"/api/dcim/devices/{device.pk}/serialized_config/") - assert resp.status_code == HTTPStatus.OK - assert resp.json().keys() == { - "data_source", - "data_file", - "local_copy_last_updated", - "config_web_link", - "serialized_config", + state = { + "config": state_item("config", {"vlans": [1, 2, 3]}), + "show_ver": state_item("show_ver", {"version": "v1.2.3"}), + "bad_cmd": state_item("bad_cmd", {}, data_file=None), } - assert resp.json()["serialized_config"] == {"key1": "value1"} - assert resp.json()["local_copy_last_updated"] == lm.isoformat().replace("+00:00", "Z") + monkeypatch.setattr(VDevice, "state", state) + resp = admin_client.get(f"/api/dcim/devices/{device.pk}/serialized_state/", params) + assert resp.status_code == HTTPStatus.OK + answer = resp.json() + expected_result_count = len(params.get("name", [])) or 3 + assert len(answer["results"]) == expected_result_count + assert answer["count"] == expected_result_count + for api_item in answer["results"]: + if "fields" in params: + assert api_item.keys() == set(params["fields"]) + assert state[api_item["name"]].serialized == api_item["value"] @pytest.mark.django_db diff --git a/validity/tests/test_compliance/test_eval.py b/validity/tests/test_compliance/test_eval.py index 168bfdf..aad52e8 100644 --- a/validity/tests/test_compliance/test_eval.py +++ b/validity/tests/test_compliance/test_eval.py @@ -5,7 +5,6 @@ import pytest from deepdiff.serialization import json_dumps -from simpleeval import InvalidExpression from validity.compliance.eval import ExplanationalEval, default_nameset, eval_defaults from validity.compliance.exceptions import EvalError @@ -31,7 +30,7 @@ pytest.param(EXPR_1, EXPLANATION_1, None, id="EXPR_1"), pytest.param(EXPR_2, EXPLANATION_2, None, id="EXPR_2"), pytest.param("some invalid syntax", [], EvalError, id="invalid syntax"), - pytest.param("def f(): pass", [], InvalidExpression, id="invalif expression"), + pytest.param("def f(): pass", [], EvalError, id="invalid expression"), ], ) def test_explanation(expression, explanation, error): diff --git a/validity/tests/test_compliance/test_device_config.py b/validity/tests/test_compliance/test_serialization.py similarity index 69% rename from validity/tests/test_compliance/test_device_config.py rename to validity/tests/test_compliance/test_serialization.py index 2cf1051..144a630 100644 --- a/validity/tests/test_compliance/test_device_config.py +++ b/validity/tests/test_compliance/test_serialization.py @@ -1,12 +1,9 @@ import json -from unittest.mock import Mock import pytest import yaml -from factories import DataFileFactory, DeviceFactory -from validity.compliance.device_config import DeviceConfig -from validity.models.data import VDataFile +from validity.compliance.serialization import serialize JSON_CONFIG = """ @@ -101,22 +98,15 @@ @pytest.mark.parametrize( - "extraction_method, contents, serialized", + "extraction_method, contents, template, serialized", [ - pytest.param("YAML", JSON_CONFIG, json.loads(JSON_CONFIG), id="YAML-JSON"), - pytest.param("YAML", YAML_CONFIG, yaml.safe_load(YAML_CONFIG), id="YAML"), - pytest.param("TTP", TTP_CONFIG, TTP_SERIALIZED, id="TTP"), - pytest.param("ROUTEROS", ROUTEROS_CONFIG, ROUTEROS_SERIALIZED, id="ROUTEROS"), + pytest.param("YAML", JSON_CONFIG, "", json.loads(JSON_CONFIG), id="YAML-JSON"), + pytest.param("YAML", YAML_CONFIG, "", yaml.safe_load(YAML_CONFIG), id="YAML"), + pytest.param("TTP", TTP_CONFIG, TTP_TEMPLATE, TTP_SERIALIZED, id="TTP"), + pytest.param("ROUTEROS", ROUTEROS_CONFIG, "", ROUTEROS_SERIALIZED, id="ROUTEROS"), ], ) @pytest.mark.django_db -def test_device_config(extraction_method, contents, serialized): - device = DeviceFactory() - device.serializer = Mock(name="some_serializer", extraction_method=extraction_method) - if extraction_method == "TTP": - device.serializer.effective_template = TTP_TEMPLATE - DataFileFactory(data=contents.encode()) - device.data_file = VDataFile.objects.first() - device_config = DeviceConfig.from_device(device) - assert extraction_method.lower() in type(device_config).__name__.lower() - assert device_config.serialized == serialized +def test_serialization(extraction_method, contents, template, serialized): + serialize_result = serialize(extraction_method, contents, template) + assert serialize_result == serialized diff --git a/validity/tests/test_compliance/test_state.py b/validity/tests/test_compliance/test_state.py new file mode 100644 index 0000000..887d647 --- /dev/null +++ b/validity/tests/test_compliance/test_state.py @@ -0,0 +1,95 @@ +from contextlib import nullcontext +from unittest.mock import Mock + +import pytest +from factories import CommandFactory, DataFileFactory, SerializerDBFactory, state_item + +from validity.compliance.exceptions import NoComponentError, SerializationError +from validity.compliance.serialization import Serializable +from validity.compliance.state import State, StateItem + + +class TestStateItem: + @pytest.mark.parametrize( + "command_kwargs, contains_config, name, verbose_name", + [ + (None, True, "config", "Config"), + ({"retrieves_config": True}, True, "config", "Config"), + ({"retrieves_config": False, "name": "Cmd1", "label": "cmd1"}, False, "cmd1", "Cmd1"), + ], + ) + @pytest.mark.django_db + def test_contains_config(self, command_kwargs, contains_config, name, verbose_name): + command = CommandFactory(**command_kwargs) if command_kwargs is not None else None + serializer = SerializerDBFactory() + data_file = DataFileFactory() + item = StateItem(serializer, data_file, command) + assert item.contains_config == contains_config + assert item.name == name + assert item.verbose_name == verbose_name + + @pytest.mark.parametrize( + "has_datafile, has_serializer, expected_error, serialized", + [ + (True, True, None, {"some": ["serialized", "data"]}), + (False, True, NoComponentError, None), + (True, False, NoComponentError, None), + (False, False, NoComponentError, None), + ], + ) + def test_serialized(self, has_datafile, has_serializer, expected_error, serialized): + serializer = Mock(**{"serialize.return_value": serialized}) if has_serializer else None + data_file = Mock() if has_datafile else None + item = StateItem(serializer, data_file, None) + ctx = pytest.raises(expected_error) if expected_error is not None else nullcontext() + if expected_error is not None: + assert isinstance(item.error, expected_error) + with ctx: + assert item.serialized == serialized + if has_serializer and has_datafile: + serializer.serialize.assert_called_once_with(data_file.data_as_string) + + +class TestState: + @pytest.mark.django_db + def test_get_item(self): + item1 = state_item("item1", {"k1": "v1"}) + item2 = state_item("item2", {"k2": "v2"}) + item_err = state_item("item_err", {}, data_file=None) + del item_err.__dict__["serialized"] + state = State({item1.name: item1, item2.name: item2, item_err.name: item_err}) + assert state["item1"] == state.item1 == {"k1": "v1"} + assert state["item2"] == state.item2 == {"k2": "v2"} + assert state.get_full_item("item1") == item1 + assert state.get("item3") is None + assert state.get("item_err", ignore_errors=True) is None + with pytest.raises(SerializationError): + state.get("item_err") + + @pytest.mark.django_db + def test_from_commands(self): + s1 = SerializerDBFactory() + s2 = SerializerDBFactory() + f1 = DataFileFactory() + f2 = DataFileFactory() + cmd1 = CommandFactory(serializer=s1) + cmd2 = CommandFactory(serializer=s2) + cmd1.data_file = f1 + cmd2.data_file = f2 + cmd2.retrieves_config = True + state = State.from_commands([cmd1, cmd2]) + assert state.keys() == {cmd1.label, "config"} + assert state.config_command_label == cmd2.label + assert state.get_full_item(cmd1.label) == StateItem(s1, f1, cmd1) + assert state.get_full_item("config") == StateItem(s2, f2, cmd2) + + @pytest.mark.django_db + def test_with_config(self): + items = [state_item("item1", {}), state_item("item2", {})] + cfg_item = Serializable(SerializerDBFactory(), DataFileFactory()) + state = State({i.name: i for i in items}).with_config(cfg_item) + assert state.get_full_item("config").serializer == cfg_item.serializer + assert state.get_full_item("config").data_file == cfg_item.data_file + assert state.get_full_item("config").command is None + cfg_item2 = StateItem(SerializerDBFactory(), DataFileFactory(), None) + assert state.with_config(cfg_item2).get_full_item("config") == cfg_item2 diff --git a/validity/tests/test_managers.py b/validity/tests/test_managers.py index 739357e..731d9b9 100644 --- a/validity/tests/test_managers.py +++ b/validity/tests/test_managers.py @@ -2,9 +2,9 @@ from unittest.mock import Mock import pytest -from factories import CompTestDBFactory, DeviceFactory +from factories import CommandFactory, CompTestDBFactory, DataSourceFactory, DeviceFactory -from validity.models import ComplianceReport, ComplianceTestResult +from validity.models import Command, ComplianceReport, ComplianceTestResult @pytest.mark.parametrize("store_results", [3, 2, 1]) @@ -42,3 +42,16 @@ def test_delete_old_reports(store_reports): reports = [ComplianceReport.objects.create() for _ in range(10)] ComplianceReport.objects.delete_old(_settings=Mock(store_reports=store_reports)) assert list(ComplianceReport.objects.order_by("created")) == reports[-store_reports:] + + +@pytest.mark.django_db +def test_set_file_paths(create_custom_fields): + CommandFactory(label="cmd1") + CommandFactory(label="cmd2") + device = DeviceFactory(name="d1") + ds = DataSourceFactory( + name="ds1", custom_field_data={"device_command_path": "path/{{device.name}}/{{command.label}}"} + ) + commands = Command.objects.set_file_paths(device=device, data_source=ds) + for cmd in commands: + assert cmd.path == f"path/d1/{cmd.label}" diff --git a/validity/tests/test_models/test_compliancetest.py b/validity/tests/test_models/test_compliancetest.py new file mode 100644 index 0000000..b5c0744 --- /dev/null +++ b/validity/tests/test_models/test_compliancetest.py @@ -0,0 +1,29 @@ +from unittest.mock import MagicMock + +import pytest +from factories import CompTestDBFactory, DataSourceFactory, DeviceFactory, PollerFactory + + +@pytest.mark.django_db +def test_run_test(monkeypatch): + test = CompTestDBFactory(expression="1==1") + device = DeviceFactory() + device.data_source = DataSourceFactory() + device.poller = PollerFactory() + evaluator_cls = MagicMock() + monkeypatch.setattr(test, "evaluator_cls", evaluator_cls) + functions = {"f1": lambda x: x * 10} + names = {"name1": 10} + verbosity = object() + passed, explanation = test.run(device, functions, names, verbosity) + assert passed == evaluator_cls.return_value.eval.return_value.__bool__.return_value + evaluator_cls.return_value.eval.assert_called_once_with("1==1") + assert explanation == evaluator_cls.return_value.explanation + evaluator_cls.assert_called_once() + assert evaluator_cls.call_args.kwargs["functions"] == functions + assert evaluator_cls.call_args.kwargs["names"] == names | { + "device": device, + "_poller": device.poller, + "_data_source": device.data_source, + } + assert evaluator_cls.call_args.kwargs["verbosity"] == verbosity diff --git a/validity/tests/test_models/test_vdatasource.py b/validity/tests/test_models/test_vdatasource.py index 6c7b667..ca8f340 100644 --- a/validity/tests/test_models/test_vdatasource.py +++ b/validity/tests/test_models/test_vdatasource.py @@ -1,11 +1,12 @@ from contextlib import suppress from pathlib import Path from tempfile import TemporaryDirectory -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest from core.choices import DataSourceStatusChoices -from factories import DataFileFactory, DataSourceFactory +from core.models import DataSource +from factories import CommandFactory, DataFileFactory, DataSourceFactory, DeviceFactory from validity.models import VDataFile, VDataSource @@ -46,3 +47,33 @@ def test_partial_sync(monkeypatch): assert VDataFile.objects.get(path="file-1.txt").data_as_string == "qwe" assert VDataFile.objects.get(path="file_new.txt").data_as_string == "rty" assert VDataSource.objects.get(pk=ds.pk).status == DataSourceStatusChoices.COMPLETED + + +@pytest.mark.django_db +def test_sync_with_param(monkeypatch): + ds = DataSourceFactory(type="device_polling") + monkeypatch.setattr(DataSource, "sync", Mock()) + monkeypatch.setattr(VDataSource, "partial_sync", Mock()) + ds.sync() + DataSource.sync.assert_called_once_with() + assert VDataSource.partial_sync.call_count == 0 + filtr = object() + ds.sync(filtr) + VDataSource.partial_sync.assert_called_once_with(filtr) + + +@pytest.mark.django_db +def test_get_path(create_custom_fields): + device = DeviceFactory(name="device 1") + command = CommandFactory(label="cmd1") + ds = DataSourceFactory( + custom_field_data={ + "web_url": "http://ex.com/{{branch}}", + "device_config_path": "cfg/{{device.name}}.cfg", + "device_command_path": "state/{{device | slugify}}/{{command.label}}.txt", + }, + parameters={"branch": "b1"}, + ) + assert ds.web_url == "http://ex.com/b1" + assert ds.get_config_path(device) == "cfg/device 1.cfg" + assert ds.get_command_path(device, command) == "state/device-1/cmd1.txt" diff --git a/validity/tests/test_models/test_vdevice.py b/validity/tests/test_models/test_vdevice.py index ba32db3..302d93f 100644 --- a/validity/tests/test_models/test_vdevice.py +++ b/validity/tests/test_models/test_vdevice.py @@ -2,7 +2,7 @@ import pytest from factories import ( - ConfigFileFactory, + DataFileFactory, DataSourceFactory, DeviceFactory, SelectorFactory, @@ -10,6 +10,7 @@ TenantFactory, ) +from validity.compliance.serialization import Serializable from validity.models import VDevice @@ -30,7 +31,7 @@ def setup_serializers(create_custom_fields): @pytest.mark.django_db def test_datasource_tenant(create_custom_fields): datasource = DataSourceFactory() - tenant = TenantFactory(custom_field_data={"config_data_source": datasource.pk}) + tenant = TenantFactory(custom_field_data={"data_source": datasource.pk}) DeviceFactory(tenant=tenant) device = VDevice.objects.prefetch_datasource().first() assert device.data_source == datasource @@ -44,14 +45,6 @@ def test_datasource_default(create_custom_fields): assert device.data_source == datasource -@pytest.mark.django_db -def test_data_file(create_custom_fields): - DeviceFactory() - data_file = ConfigFileFactory() - device = VDevice.objects.prefetch_datasource().first() - assert device.data_file == data_file - - @pytest.mark.django_db def test_serializer(setup_serializers, subtests): device_serializer_map = setup_serializers @@ -61,24 +54,25 @@ def test_serializer(setup_serializers, subtests): assert d.serializer.pk == device_serializer_map[d.pk] -@pytest.mark.django_db -def test_config_path(create_custom_fields): - DeviceFactory(name="device1") - DataSourceFactory(custom_field_data={"device_config_path": "path/{{device.name}}.cfg", "default": True}) - device = VDevice.objects.prefetch_datasource().first() - assert device.config_path == "path/device1.cfg" - - @pytest.mark.parametrize("qs", [VDevice.objects.all(), VDevice.objects.filter(name__in=["d1", "d2"])]) @pytest.mark.django_db def test_set_selector(qs, subtests): for name in ["d1", "d2", "d3"]: DeviceFactory(name=name) selector = SelectorFactory() - some_model = qs.first() - assert some_model.selector is None qs = qs.set_selector(selector) for i, queryset in enumerate([qs, qs.select_related(), qs.filter(name="d1")]): with subtests.test(id=f"qs-{i}"): for model in queryset: assert model.selector == selector + + +def test_config_item(create_custom_fields): + ds = DataSourceFactory(name="ds1", custom_field_data={"device_config_path": "path/{{device.name}}.txt"}) + data_file = DataFileFactory(source=ds, path="path/d1.txt") + device = DeviceFactory(name="d1") + device.serializer = SerializerDBFactory() + device.data_source = ds + assert device._config_item() == Serializable(device.serializer, data_file) + device.data_source = None + assert device._config_item() == Serializable(device.serializer, None) diff --git a/validity/tests/test_scripts/test_run_tests.py b/validity/tests/test_scripts/test_run_tests.py index 57e2818..ce94c1c 100644 --- a/validity/tests/test_scripts/test_run_tests.py +++ b/validity/tests/test_scripts/test_run_tests.py @@ -84,31 +84,12 @@ def test_builtins_are_available_in_nameset(definitions): functions["func"]() -def test_run_test(monkeypatch): - script = RunTestsScript() - nm_functions = Mock() - evaluator_cls = Mock(return_value=Mock(explanation=[("var1", "val1")])) - monkeypatch.setattr(script, "nameset_functions", nm_functions) - monkeypatch.setattr(run_tests, "ExplanationalEval", evaluator_cls) - device = Mock() - test = Mock() - passed, explanation = script.run_test(device, test) - assert passed # bool(Mock()) is True - assert explanation - nm_functions.assert_called_once_with(test.namesets.all()) - evaluator_cls.assert_called_once_with( - functions=nm_functions.return_value, names={"device": device}, load_defaults=True - ) - evaluator_cls.return_value.eval.assert_called_once_with(test.effective_expression) - - @pytest.mark.parametrize( "run_test_mock", [ Mock(return_value=(True, [("expla", "nation")])), Mock(return_value=(False, [("1", "2"), ("3", "4")])), - Mock(side_effect=InvalidExpression()), - Mock(side_effect=EvalError(InvalidExpression())), + Mock(side_effect=EvalError(orig_error=InvalidExpression())), ], ) def test_run_tests_for_device(mock_script_logging, run_test_mock, monkeypatch): @@ -140,20 +121,14 @@ def test_run_tests_for_selector(mock_script_logging, monkeypatch): script = RunTestsScript() devices = [Mock(name="device1"), Mock(name="device2")] monkeypatch.setattr(script, "run_tests_for_device", Mock(return_value=range(3))) - selector = Mock( - name="selector", - **{ - "devices.select_related.return_value" - ".prefetch_datasource.return_value" - ".prefetch_serializer.return_value" - ".prefetch_poller.return_value": devices - } - ) + monkeypatch.setattr(script, "get_device_qs", Mock(return_value=devices)) + selector = Mock() report = Mock() - list(script.run_tests_for_selector(selector, report, [])) + list(script.run_tests_for_selector(selector, report)) assert script.run_tests_for_device.call_count == len(devices) script.run_tests_for_device.assert_any_call(selector.tests.all(), devices[0], report) script.run_tests_for_device.assert_any_call(selector.tests.all(), devices[1], report) + script.get_device_qs.assert_called_once_with(selector) @pytest.mark.django_db diff --git a/validity/tests/test_views.py b/validity/tests/test_views.py index 9e2cdc8..233a0c0 100644 --- a/validity/tests/test_views.py +++ b/validity/tests/test_views.py @@ -7,7 +7,6 @@ CommandFactory, CompTestDBFactory, CompTestResultFactory, - ConfigFileFactory, DataFileFactory, DataSourceFactory, DeviceFactory, @@ -25,9 +24,11 @@ SiteFactory, TagFactory, TenantFactory, + state_item, ) from validity import models +from validity.compliance.state import State class TestDBNameSet(ViewTest): @@ -139,14 +140,20 @@ class TestDSTest(ViewTest): } +@pytest.mark.parametrize("item", [None, "config", "show_ver", "bad_cmd", "non-existent"]) @pytest.mark.django_db -def test_device_results(admin_client): +def test_get_serialized_state(admin_client, item, monkeypatch): device = DeviceFactory() - ConfigFileFactory() - serializer = SerializerDBFactory() - device.custom_field_data["serializer"] = serializer.pk - device.save() - resp = admin_client.get(f"/dcim/devices/{device.pk}/serialized_config/") + state = State( + { + "config": state_item("config", {"vlans": [1, 2, 3]}), + "show_ver": state_item("show_ver", {"version": "v1.2.3"}), + "bad_cmd": state_item("bad_cmd", {}, data_file=None), + } + ) + monkeypatch.setattr(models.VDevice, "state", state) + params = {"state_item": item} if item is not None else {} + resp = admin_client.get(f"/dcim/devices/{device.pk}/serialized_state/", params) assert resp.status_code == HTTPStatus.OK diff --git a/validity/utils/dbfields.py b/validity/utils/dbfields.py index b73f3de..2ccc031 100644 --- a/validity/utils/dbfields.py +++ b/validity/utils/dbfields.py @@ -133,6 +133,9 @@ def formfield(self, **kwargs): }, ) + def value_to_string(self, obj: Any) -> Any: + return super().value_to_string(obj).encrypted + class EncryptedDictFormField(forms.JSONField): def to_python(self, value: Any) -> Any: