diff --git a/validity/compliance/eval/default_nameset.py b/validity/compliance/eval/default_nameset.py index 4c4e010..2a52c72 100644 --- a/validity/compliance/eval/default_nameset.py +++ b/validity/compliance/eval/default_nameset.py @@ -2,7 +2,7 @@ import jq as pyjq -from validity.utils.config import config # noqa +from validity.models import VDevice builtins = [ @@ -14,6 +14,7 @@ "bool", "bytes", "callable", + "classmethod", "chr", "complex", "dict", @@ -36,12 +37,14 @@ "oct", "ord", "pow", + "property", "range", "reversed", "round", "set", "slice", "sorted", + "staticmethod", "str", "sum", "tuple", @@ -49,7 +52,7 @@ ] -__all__ = ["jq", "config"] + builtins +__all__ = ["jq", "config", "state"] + builtins class jq: @@ -58,3 +61,17 @@ class jq: def __init__(self, *args, **kwargs) -> None: raise TypeError("jq is not callable") + + +def state(device): + # state() implies presence of "_data_source" and "_poller" global variables + # which are gonna be set by RunTests script + vdevice = VDevice() + vdevice.__dict__ = device.__dict__.copy() + vdevice.data_source = _data_source # noqa + vdevice._poller = _poller # noqa + return vdevice.state + + +def config(device): + return state(device).config diff --git a/validity/compliance/exceptions.py b/validity/compliance/exceptions.py index 2f93a0d..3413de1 100644 --- a/validity/compliance/exceptions.py +++ b/validity/compliance/exceptions.py @@ -22,12 +22,15 @@ class NoComponentError(SerializationError): Indicates lack of the required component (e.g. serializer) to do serialization """ - def __init__(self, missing_component: str, orig_error: Exception | None = None) -> None: + def __init__(self, missing_component: str, parent: str | None = None) -> None: self.missing_component = missing_component - super().__init__(orig_error) + self.parent = parent def __str__(self) -> str: - return f"There is no bound {self.missing_component}" + result = f"There is no bound {self.missing_component}" + if self.parent: + result += f' for "{self.parent}"' + return result class BadDataFileContentsError(SerializationError): diff --git a/validity/compliance/state.py b/validity/compliance/state.py index 650597b..fdc6091 100644 --- a/validity/compliance/state.py +++ b/validity/compliance/state.py @@ -6,7 +6,7 @@ from validity.compliance.serialization import Serializable from ..utils.misc import reraise -from .exceptions import SerializationError, StateKeyError +from .exceptions import NoComponentError, SerializationError, StateKeyError if TYPE_CHECKING: @@ -41,6 +41,14 @@ def error(self) -> SerializationError | None: except SerializationError as exc: return exc + @property + def serialized(self): + try: + return super().serialized + except NoComponentError as exc: + exc.parent = self.name + raise + class State(dict): def __init__(self, items, config_command_label: str | None = None): diff --git a/validity/forms/helpers.py b/validity/forms/helpers.py index 706d2f5..c82e9f8 100644 --- a/validity/forms/helpers.py +++ b/validity/forms/helpers.py @@ -1,10 +1,17 @@ import json -from typing import Sequence +from typing import Any, Sequence from django.forms import ChoiceField, Select from utilities.forms import get_field_value +class IntegerChoiceField(ChoiceField): + def to_python(self, value: Any | None) -> Any | None: + if value is not None: + value = int(value) + return value + + class SelectWithPlaceholder(Select): def __init__(self, attrs=None, choices=()) -> None: super().__init__(attrs, choices) diff --git a/validity/managers.py b/validity/managers.py index d525745..25d6491 100644 --- a/validity/managers.py +++ b/validity/managers.py @@ -134,6 +134,9 @@ class VDeviceQS(CustomPrefetchMixin, SetAttributesMixin, RestrictedQuerySet): def set_selector(self, selector): return self.set_attribute("selector", selector) + def set_datasource(self, data_source): + return self.set_attribute("data_source", data_source) + def annotate_datasource_id(self): from validity.models import VDataSource diff --git a/validity/models/data.py b/validity/models/data.py index ba2fdbf..9396a94 100644 --- a/validity/models/data.py +++ b/validity/models/data.py @@ -99,3 +99,8 @@ def new_data_file(path): new_datafiles = (new_data_file(path) for path in paths) created = len(DataFile.objects.bulk_create(new_datafiles, batch_size=batch_size)) logger.debug("%s new files were created and %s existing files were updated during sync", created, updated) + + def sync(self, device_filter: Q | None = None): + if device_filter is not None and self.type == "device_polling": + return self.partial_sync(device_filter) + return super().sync() diff --git a/validity/models/device.py b/validity/models/device.py index b2b16bf..d132d26 100644 --- a/validity/models/device.py +++ b/validity/models/device.py @@ -63,4 +63,8 @@ def dynamic_pair(self) -> Optional["VDevice"]: filter_ = self.selector.dynamic_pair_filter(self) if filter_ is None: return - return type(self).objects.filter(filter_).first() + pair = type(self).objects.filter(filter_).first() + if pair: + pair.data_source = self.data_source + pair.poller = self.poller + return pair diff --git a/validity/models/nameset.py b/validity/models/nameset.py index e227932..7b88427 100644 --- a/validity/models/nameset.py +++ b/validity/models/nameset.py @@ -1,6 +1,5 @@ import ast import builtins -from functools import cached_property from inspect import getmembers from typing import Any, Callable @@ -56,7 +55,7 @@ def clean(self): def effective_definitions(self): return self.effective_text_field() - @cached_property + @property def _globals(self): return dict(getmembers(builtins)) | {name: getattr(default_nameset, name) for name in default_nameset.__all__} diff --git a/validity/models/test.py b/validity/models/test.py index eb5db2d..e124842 100644 --- a/validity/models/test.py +++ b/validity/models/test.py @@ -1,10 +1,13 @@ import ast +from functools import partial +from typing import Any, Callable from django.core.exceptions import ValidationError from django.db import models from django.utils.translation import gettext_lazy as _ from validity.choices import SeverityChoices +from validity.compliance.eval import ExplanationalEval from validity.managers import ComplianceTestQS from .base import BaseModel, DataSourceMixin @@ -20,6 +23,7 @@ class ComplianceTest(DataSourceMixin, BaseModel): clone_fields = ("expression", "selectors", "severity", "data_source", "data_file") text_db_field_name = "expression" + evaluator_cls = partial(ExplanationalEval, load_defaults=True) objects = ComplianceTestQS.as_manager() @@ -46,3 +50,13 @@ def get_severity_color(self): @property def effective_expression(self): return self.effective_text_field() + + def run( + self, device, functions: dict[str, Callable], extra_names: dict[str, Any] | None = None, verbosity: int = 2 + ) -> tuple[bool, list]: + names = {"device": device, "_poller": device.poller, "_data_source": device.data_source} + if extra_names: + names |= extra_names + evaluator = self.evaluator_cls(names=names, functions=functions, verbosity=verbosity) + passed = bool(evaluator.eval(self.effective_expression)) + return passed, evaluator.explanation diff --git a/validity/scripts/run_tests.py b/validity/scripts/run_tests.py index 0cbeb7f..2f5850c 100644 --- a/validity/scripts/run_tests.py +++ b/validity/scripts/run_tests.py @@ -1,20 +1,22 @@ +import operator import time +from functools import reduce from itertools import chain from typing import Any, Callable, Generator, Iterable import yaml +from core.models import DataSource from dcim.models import Device -from django.db.models import Prefetch, QuerySet +from django.db.models import Prefetch, Q, QuerySet from django.utils.translation import gettext as __ from extras.choices import ObjectChangeActionChoices from extras.models import Tag -from extras.scripts import BooleanVar, ChoiceVar, MultiObjectVar +from extras.scripts import BooleanVar, MultiObjectVar, ObjectVar from extras.webhooks import enqueue_object from netbox.context import webhooks_queue import validity from validity.choices import ExplanationVerbosityChoices -from validity.compliance.eval import ExplanationalEval from validity.compliance.exceptions import EvalError, SerializationError from validity.models import ( ComplianceReport, @@ -26,15 +28,11 @@ VDevice, ) from validity.utils.misc import datasource_sync, null_request +from .script_data import RunTestsScriptData, ScriptDataMixin +from .variables import VerbosityVar -class RequiredChoiceVar(ChoiceVar): - def __init__(self, choices, *args, **kwargs): - super().__init__(choices, *args, **kwargs) - self.field_attrs["choices"] = choices - - -class RunTestsScript: +class RunTestsScript(ScriptDataMixin[RunTestsScriptData]): _sleep_between_tests = validity.settings.sleep_between_tests _result_batch_size = validity.settings.result_batch_size @@ -42,7 +40,7 @@ class RunTestsScript: required=False, default=False, label=__("Sync Data Sources"), - description=__('Sync all Data Source instances which have "device_config_path" defined'), + description=__("Sync all referenced Data Sources"), ) make_report = BooleanVar(default=True, label=__("Make Compliance Report")) selectors = MultiObjectVar( @@ -63,12 +61,18 @@ class RunTestsScript: label=__("Specific Test Tags"), description=__("Run the tests which contain specific tags only"), ) - explanation_verbosity = RequiredChoiceVar( + explanation_verbosity = VerbosityVar( choices=ExplanationVerbosityChoices.choices, default=ExplanationVerbosityChoices.maximum, label=__("Explanation Verbosity Level"), required=False, ) + override_datasource = ObjectVar( + model=DataSource, + required=False, + label=__("Override DataSource"), + description=__("Find all devices state/config data in this Data Source instead of bound ones"), + ) class Meta: name = __("Run Compliance Tests") @@ -78,7 +82,6 @@ def __init__(self): super().__init__() self._nameset_functions = {} self.global_namesets = NameSet.objects.filter(_global=True) - self.verbosity = 2 self.results_count = 0 self.results_passed = 0 @@ -97,11 +100,7 @@ def nameset_functions(self, namesets: Iterable[NameSet]) -> dict[str, Callable]: def run_test(self, device: VDevice, test: ComplianceTest) -> tuple[bool, list[tuple[Any, Any]]]: functions = self.nameset_functions(test.namesets.all()) - evaluator = ExplanationalEval( - functions=functions, names={"device": device}, load_defaults=True, verbosity=self.verbosity - ) - passed = bool(evaluator.eval(test.effective_expression)) - return passed, evaluator.explanation + return test.run(device, functions, verbosity=self.script_data.explanation_verbosity) def run_tests_for_device( self, @@ -112,6 +111,7 @@ def run_tests_for_device( for test in tests_qs: explanation = [] try: + device.state passed, explanation = self.run_test(device, test) except EvalError as exc: self.log_failure(f"Failed to execute test **{test}** for device **{device}**, `{exc}`") @@ -129,16 +129,20 @@ def run_tests_for_device( ) time.sleep(self._sleep_between_tests) + def get_device_qs(self, selector: ComplianceSelector) -> QuerySet[VDevice]: + device_qs = selector.devices.select_related().prefetch_serializer().prefetch_poller() + if self.script_data.override_datasource: + device_qs = device_qs.set_datasource(self.script_data.override_datasource.obj) + else: + device_qs = device_qs.prefetch_datasource() + if self.script_data.devices: + device_qs = device_qs.filter(pk__in=self.script_data.devices) + return device_qs + def run_tests_for_selector( - self, - selector: ComplianceSelector, - report: ComplianceReport | None, - device_ids: list[int], + self, selector: ComplianceSelector, report: ComplianceReport | None ) -> Generator[ComplianceTestResult, None, None]: - qs = selector.devices.select_related().prefetch_datasource().prefetch_serializer().prefetch_poller() - if device_ids: - qs = qs.filter(pk__in=device_ids) - for device in qs: + for device in self.get_device_qs(selector): try: yield from self.run_tests_for_device(selector.tests.all(), device, report) except SerializationError as e: @@ -156,27 +160,37 @@ def save_to_db(self, results: Iterable[ComplianceTestResult], report: Compliance if report: ComplianceReport.objects.delete_old() - def get_selectors(self, data: dict) -> QuerySet[ComplianceSelector]: - selectors = ComplianceSelector.objects.all() - if specific_selectors := data.get("selectors"): - selectors = selectors.filter(pk__in=specific_selectors) + def get_selectors(self) -> QuerySet[ComplianceSelector]: + selectors = self.script_data.selectors.queryset test_qs = ComplianceTest.objects.all() - if test_tags := data.get("test_tags"): - test_qs = test_qs.filter(tags__pk__in=test_tags).distinct() - selectors = selectors.filter(tests__tags__pk__in=test_tags).distinct() + if self.script_data.test_tags: + test_qs = test_qs.filter(tags__pk__in=self.script_data.test_tags).distinct() + selectors = selectors.filter(tests__tags__pk__in=self.script_data.test_tags).distinct() return selectors.prefetch_related(Prefetch("tests", test_qs.prefetch_related("namesets"))) + def perform_datasource_sync(self) -> None: + device_filter = reduce(operator.or_, (selector.filter for selector in self.script_data.selectors.queryset)) + if self.script_data.devices: + device_filter |= Q(pk__in=self.script_data.devices) + if self.script_data.override_datasource: + self.script_data.override_datasource.obj.sync(device_filter) + return + datasource_ids = ( + VDevice.objects.filter(device_filter) + .annotate_datasource_id() + .values_list("data_source_id", flat=True) + .distinct() + ) + datasource_sync(VDataSource.objects.filter(pk__in=datasource_ids)) + def run(self, data, commit): - self.verbosity = int(data.get("explanation_verbosity", self.verbosity)) - if data.get("sync_datasources"): - datasource_sync(VDataSource.objects.exclude(custom_field_data__device_config_path=None)) + self.script_data = self.script_data_cls(data) + selectors = self.get_selectors() + if self.script_data.sync_datasources: + self.perform_datasource_sync() with null_request(): - report = ComplianceReport.objects.create() if data.get("make_report") else None - selectors = self.get_selectors(data) - device_ids = data.get("devices", []) - results = chain.from_iterable( - self.run_tests_for_selector(selector, report, device_ids) for selector in selectors - ) + report = ComplianceReport.objects.create() if self.script_data.make_report else None + results = chain.from_iterable(self.run_tests_for_selector(selector, report) for selector in selectors) self.save_to_db(results, report) output = {"results": {"all": self.results_count, "passed": self.results_passed}} if report: diff --git a/validity/scripts/script_data.py b/validity/scripts/script_data.py new file mode 100644 index 0000000..e592ef1 --- /dev/null +++ b/validity/scripts/script_data.py @@ -0,0 +1,118 @@ +from functools import cached_property +from typing import Generic, TypeVar, get_args + +from django.db.models import Model, QuerySet +from django.utils.functional import classproperty +from extras.models import Tag + +from validity import models + + +class DBObject(int): + def __new__(cls, value, model): + return super().__new__(cls, value) + + def __init__(self, value, model): + self.model = model + super().__init__() + + @cached_property + def obj(self): + return self.model.objects.filter(pk=self).first() + + +class QuerySetObject(list): + def __init__(self, iterable, model=None): + self.model = model + super().__init__(iterable) + + +class AllQuerySetObject(QuerySetObject): + """ + Defaults to "all" if empty + """ + + @property + def queryset(self): + if not self: + return self.model.objects.all() + return self.model.objects.filter(pk__in=iter(self)) + + +class EmptyQuerySetObject(QuerySetObject): + """ + Defaults to "none" if empty + """ + + @property + def queryset(self): + if not self: + return self.model.objects.none() + return self.model.objects.filter(pk__in=iter(self)) + + +class DBField: + def __init__(self, model, object_cls, default=None) -> None: + self.model = model + self.object_cls = object_cls + self.attr_name = None + if default is not None and not isinstance(default, object_cls): + default = object_cls(default, model) + self.default = default + + def __set_name__(self, parent_cls, attr_name): + self.attr_name = attr_name + + def __get__(self, instance, type_): + return instance.__dict__.get(self.attr_name, self.default) + + def __set__(self, instance, value): + if value is not None: + value = self.object_cls(value, self.model) + instance.__dict__[self.attr_name] = value + + +class ScriptData: + def from_queryset(self, queryset: QuerySet) -> list[int]: + """ + Extract primary keys from queryset + """ + return list(queryset.values_list("pk", flat=True)) + + def __init__(self, data) -> None: + for k, v in data.items(): + if isinstance(v, QuerySet): + v = self.from_queryset(v) + elif isinstance(v, Model): + v = v.pk + setattr(self, k, v) + + +_ScriptData = TypeVar("_ScriptData", bound=ScriptData) + + +class ScriptDataMixin(Generic[_ScriptData]): + """ + Mixin for Script. Allows to define script data cls in class definition and later use it. + Example: + self.script_data = self.script_data_cls(data) + """ + + script_data: _ScriptData + + @classproperty + def script_data_cls(cls) -> type[_ScriptData]: + for base_classes in cls.__orig_bases__: + if (args := get_args(base_classes)) and issubclass(args[0], ScriptData): + return args[0] + raise AttributeError(f"No ScriptData definition found for {cls.__name__}") + + +class RunTestsScriptData(ScriptData): + sync_datasources = False + make_report = True + selectors = DBField(models.ComplianceSelector, AllQuerySetObject, default=[]) + devices = DBField(models.VDevice, AllQuerySetObject, default=[]) + test_tags = DBField(Tag, EmptyQuerySetObject, default=[]) + explanation_verbosity = 2 + override_datasource = DBField(models.VDataSource, DBObject, default=None) diff --git a/validity/scripts/variables.py b/validity/scripts/variables.py new file mode 100644 index 0000000..b6bea88 --- /dev/null +++ b/validity/scripts/variables.py @@ -0,0 +1,13 @@ +from extras.scripts import ChoiceVar + +from validity.forms.helpers import IntegerChoiceField + + +class NoNullChoiceVar(ChoiceVar): + def __init__(self, choices, *args, **kwargs): + super().__init__(choices, *args, **kwargs) + self.field_attrs["choices"] = choices + + +class VerbosityVar(NoNullChoiceVar): + form_field = IntegerChoiceField diff --git a/validity/utils/config.py b/validity/utils/config.py deleted file mode 100644 index 7025bb3..0000000 --- a/validity/utils/config.py +++ /dev/null @@ -1,9 +0,0 @@ -from dcim.models import Device - -from validity.models import VDevice - - -def config(device: Device) -> dict | list | None: - vdevice = VDevice() - vdevice.__dict__ = device.__dict__.copy() - return vdevice.config diff --git a/validity/utils/misc.py b/validity/utils/misc.py index 09baece..d2aba8f 100644 --- a/validity/utils/misc.py +++ b/validity/utils/misc.py @@ -74,6 +74,8 @@ def sync_func(datasource): except SyncError as e: if fail_handler: fail_handler(datasource, e) + else: + raise with ThreadPoolExecutor(max_workers=threads) as tp: any(tp.map(sync_func, datasources))