From 4f955f5655e4ebd31a74cc2177c9924026175e06 Mon Sep 17 00:00:00 2001 From: Anton M Date: Thu, 4 Jan 2024 02:37:15 +0400 Subject: [PATCH] working run_tests --- validity/api/helpers.py | 4 +- validity/api/serializers.py | 9 ++-- validity/choices.py | 8 +++- validity/compliance/eval/eval.py | 31 +++++++++----- validity/compliance/exceptions.py | 6 +++ validity/compliance/state.py | 6 ++- validity/filtersets.py | 2 + validity/forms/filterset.py | 4 +- validity/managers.py | 2 +- validity/models/nameset.py | 18 ++++++++ validity/scripts/run_tests.py | 70 +++++++++++++++++++------------ validity/utils/orm.py | 3 +- 12 files changed, 114 insertions(+), 49 deletions(-) diff --git a/validity/api/helpers.py b/validity/api/helpers.py index bc7da36..897bdd2 100644 --- a/validity/api/helpers.py +++ b/validity/api/helpers.py @@ -42,11 +42,11 @@ class ListQPMixin: """ def get_list_param(self, param: str) -> list[str] | None: - if "request" not in self.context or param not in self.context['request'].query_params: + if "request" not in self.context or param not in self.context["request"].query_params: return None param_value = self.context["request"].query_params.getlist(param) if len(param_value) == 1: - return param_value[0].split(',') + return param_value[0].split(",") return param_value diff --git a/validity/api/serializers.py b/validity/api/serializers.py index 0ccc3fc..2bdc371 100644 --- a/validity/api/serializers.py +++ b/validity/api/serializers.py @@ -1,4 +1,3 @@ -from rest_framework.fields import empty from core.api.nested_serializers import NestedDataFileSerializer, NestedDataSourceSerializer from dcim.api.nested_serializers import ( NestedDeviceSerializer, @@ -19,8 +18,8 @@ from tenancy.models import Tenant from validity import models -from .helpers import EncryptedDictField, FieldsMixin, nested_factory, ListQPMixin -from rest_framework.fields import empty +from .helpers import EncryptedDictField, FieldsMixin, ListQPMixin, nested_factory + class ComplianceSelectorSerializer(NetBoxModelSerializer): url = serializers.HyperlinkedIdentityField(view_name="plugins-api:validity-api:complianceselector-detail") @@ -355,12 +354,12 @@ def get_serialized(self, state_item): class SerializedStateSerializer(ListQPMixin, serializers.Serializer): count = serializers.SerializerMethodField() - results = SerializedStateItemSerializer(many=True, read_only=True, source='*') + results = SerializedStateItemSerializer(many=True, read_only=True, source="*") def get_count(self, state): return len(state) def to_representation(self, instance): - if name_filter := self.get_list_param('name'): + if name_filter := self.get_list_param("name"): instance = [item for item in instance if item.name in set(name_filter)] return super().to_representation(instance) diff --git a/validity/choices.py b/validity/choices.py index 0a567a6..505d3f6 100644 --- a/validity/choices.py +++ b/validity/choices.py @@ -1,6 +1,6 @@ from typing import Any, Optional, TypeVar -from django.db.models import TextChoices +from django.db.models import IntegerChoices, TextChoices from django.db.models.enums import ChoicesMeta from django.utils.translation import gettext_lazy as _ @@ -121,3 +121,9 @@ def acceptable_command_type(self) -> "CommandTypeChoices": class CommandTypeChoices(TextChoices, metaclass=ColoredChoiceMeta): CLI = "CLI", "CLI", "blue" + + +class ExplanationVerbosityChoices(IntegerChoices): + disabled = 0, _("0 - Disabled") + medium = 1, _("1 - Medium") + maximum = 2, _("2 - Maximum") diff --git a/validity/compliance/eval/eval.py b/validity/compliance/eval/eval.py index 3fb5029..f0d472e 100644 --- a/validity/compliance/eval/eval.py +++ b/validity/compliance/eval/eval.py @@ -1,20 +1,33 @@ import ast import re +from typing import Literal import deepdiff -from simpleeval import EvalWithCompoundTypes, InvalidExpression +from simpleeval import EvalWithCompoundTypes +from validity.utils.misc import reraise from ..exceptions import EvalError from . import default_nameset, eval_defaults class ExplanationalEval(EvalWithCompoundTypes): - do_not_explain = (ast.Constant, ast.Name, ast.Attribute, ast.Expr) - def __init__(self, operators=None, functions=None, names=None, deepdiff_types=None, *, load_defaults=False): - if deepdiff_types is None: - deepdiff_types = (list, dict, set, frozenset, tuple) + def __init__( + self, + operators=None, + functions=None, + names=None, + deepdiff_types=None, + *, + load_defaults=False, + verbosity: Literal[0, 1, 2] = 2, + ): + self.verbosity = verbosity + deepdiff_types = deepdiff_types or (list, dict, set, frozenset, tuple) + if verbosity < 2: + # disable deepdiff explanation + deepdiff_types = () self.deepdiff_types = deepdiff_types self.explanation = [] self._deepdiff = [] @@ -33,6 +46,8 @@ def _load_defaults(self, /, **kwargs): def _eval(self, node): result = super()._eval(node) + if self.verbosity < 1: + return result unparsed = ast.unparse(node) if not isinstance(node, self.do_not_explain) and str(result) != unparsed and unparsed: self.explanation.append((self._format_unparsed(unparsed), result)) @@ -61,9 +76,5 @@ def _eval_compare(self, node): def eval(self, expr): self.explanation = [] - try: + with reraise(Exception, EvalError): return super().eval(expr) - except InvalidExpression: - raise - except Exception as e: - raise EvalError(e) from e diff --git a/validity/compliance/exceptions.py b/validity/compliance/exceptions.py index b6b12a7..2f93a0d 100644 --- a/validity/compliance/exceptions.py +++ b/validity/compliance/exceptions.py @@ -32,3 +32,9 @@ def __str__(self) -> str: class BadDataFileContentsError(SerializationError): pass + + +class StateKeyError(KeyError): + def __str__(self) -> str: + key = str(self.args[0]).strip("\"'") + return f"State has no '{key}' item" diff --git a/validity/compliance/state.py b/validity/compliance/state.py index f2d861e..650597b 100644 --- a/validity/compliance/state.py +++ b/validity/compliance/state.py @@ -5,7 +5,8 @@ from django.utils.translation import gettext_lazy as _ from validity.compliance.serialization import Serializable -from .exceptions import SerializationError +from ..utils.misc import reraise +from .exceptions import SerializationError, StateKeyError if TYPE_CHECKING: @@ -73,7 +74,8 @@ def __getattr__(self, key): return self[key] def __getitem__(self, key): - state_item = super().__getitem__(key) + with reraise(KeyError, StateKeyError): + state_item = super().__getitem__(key) return state_item.serialized def get(self, key, default=None, ignore_errors=False): diff --git a/validity/filtersets.py b/validity/filtersets.py index 11653ac..b0ca6ab 100644 --- a/validity/filtersets.py +++ b/validity/filtersets.py @@ -6,6 +6,7 @@ from dcim.models import Device, DeviceRole, DeviceType, Location, Manufacturer, Platform, Site from django.db.models import Q from django_filters import BooleanFilter, ChoiceFilter, ModelMultipleChoiceFilter +from extras.models import Tag from netbox.filtersets import NetBoxModelFilterSet from tenancy.models import Tenant @@ -60,6 +61,7 @@ class ComplianceTestResultFilterSet(SearchMixin, NetBoxModelFilterSet): platform_id = ModelMultipleChoiceFilter(field_name="device__platform", queryset=Platform.objects.all()) location_id = ModelMultipleChoiceFilter(field_name="device__location", queryset=Location.objects.all()) site_id = ModelMultipleChoiceFilter(field_name="device__site", queryset=Site.objects.all()) + test_tag_id = ModelMultipleChoiceFilter(field_name="test__tags", queryset=Tag.objects.all()) tag = None class Meta: diff --git a/validity/forms/filterset.py b/validity/forms/filterset.py index ea1cc5c..f733f33 100644 --- a/validity/forms/filterset.py +++ b/validity/forms/filterset.py @@ -2,6 +2,7 @@ from dcim.models import Device, DeviceRole, DeviceType, Location, Manufacturer, Platform, Site from django.forms import CharField, Form, NullBooleanField, Select from django.utils.translation import gettext_lazy as _ +from extras.models import Tag from netbox.forms import NetBoxModelFilterSetForm from tenancy.models import Tenant from utilities.forms import BOOLEAN_WITH_BLANK_CHOICES @@ -67,13 +68,14 @@ class TestResultFilterForm(ExcludeMixin, Form): platform_id = DynamicModelMultipleChoiceField(required=False, label=_("Platform"), queryset=Platform.objects.all()) location_id = DynamicModelMultipleChoiceField(required=False, label=_("Location"), queryset=Location.objects.all()) site_id = DynamicModelMultipleChoiceField(required=False, label=_("Site"), queryset=Site.objects.all()) + test_tag_id = DynamicModelMultipleChoiceField(required=False, label=_("Test Tags"), queryset=Tag.objects.all()) class ComplianceTestResultFilterForm(TestResultFilterForm, NetBoxModelFilterSetForm): model = models.ComplianceTestResult fieldsets = ( [_("Common"), ("latest", "passed", "selector_id")], - [_("Test"), ("severity", "test_id", "report_id")], + [_("Test"), ("severity", "test_id", "report_id", "test_tag_id")], [ _("Device"), ( diff --git a/validity/managers.py b/validity/managers.py index 07a178f..d525745 100644 --- a/validity/managers.py +++ b/validity/managers.py @@ -132,7 +132,7 @@ def delete_old(self, _settings=settings): class VDeviceQS(CustomPrefetchMixin, SetAttributesMixin, RestrictedQuerySet): def set_selector(self, selector): - self.set_attribute("selector", selector) + return self.set_attribute("selector", selector) def annotate_datasource_id(self): from validity.models import VDataSource diff --git a/validity/models/nameset.py b/validity/models/nameset.py index 01ced2e..e227932 100644 --- a/validity/models/nameset.py +++ b/validity/models/nameset.py @@ -1,9 +1,14 @@ import ast +import builtins +from functools import cached_property +from inspect import getmembers +from typing import Any, Callable from django.core.exceptions import ValidationError from django.db import models from django.utils.translation import gettext_lazy as _ +import validity.compliance.eval.default_nameset as default_nameset from .base import BaseModel, DataSourceMixin from .test import ComplianceTest @@ -50,3 +55,16 @@ def clean(self): @property def effective_definitions(self): return self.effective_text_field() + + @cached_property + def _globals(self): + return dict(getmembers(builtins)) | {name: getattr(default_nameset, name) for name in default_nameset.__all__} + + def extract(self, extra_globals: dict[str, Any] | None = None) -> dict[str, Callable]: + all_globals = self._globals + if extra_globals: + all_globals |= extra_globals + locs = {} + exec(self.effective_definitions, all_globals, locs) + __all__ = set(locs.get("__all__", [])) + return {k: v for k, v in locs.items() if k in __all__ and callable(v)} diff --git a/validity/scripts/run_tests.py b/validity/scripts/run_tests.py index 6357513..0cbeb7f 100644 --- a/validity/scripts/run_tests.py +++ b/validity/scripts/run_tests.py @@ -1,21 +1,19 @@ -import builtins import time -from inspect import getmembers from itertools import chain from typing import Any, Callable, Generator, Iterable import yaml from dcim.models import Device -from django.db.models import QuerySet +from django.db.models import Prefetch, QuerySet from django.utils.translation import gettext as __ from extras.choices import ObjectChangeActionChoices -from extras.scripts import BooleanVar, MultiObjectVar +from extras.models import Tag +from extras.scripts import BooleanVar, ChoiceVar, MultiObjectVar from extras.webhooks import enqueue_object from netbox.context import webhooks_queue -from simpleeval import InvalidExpression import validity -import validity.compliance.eval.default_nameset as default_nameset +from validity.choices import ExplanationVerbosityChoices from validity.compliance.eval import ExplanationalEval from validity.compliance.exceptions import EvalError, SerializationError from validity.models import ( @@ -30,6 +28,12 @@ from validity.utils.misc import datasource_sync, null_request +class RequiredChoiceVar(ChoiceVar): + def __init__(self, choices, *args, **kwargs): + super().__init__(choices, *args, **kwargs) + self.field_attrs["choices"] = choices + + class RunTestsScript: _sleep_between_tests = validity.settings.sleep_between_tests _result_batch_size = validity.settings.result_batch_size @@ -44,15 +48,27 @@ class RunTestsScript: selectors = MultiObjectVar( model=ComplianceSelector, required=False, - label=__("Specific selectors"), + label=__("Specific Selectors"), description=__("Run the tests only for specific selectors"), ) devices = MultiObjectVar( model=Device, required=False, - label=__("Specific devices"), + label=__("Specific Devices"), description=__("Run the tests only for specific devices"), ) + test_tags = MultiObjectVar( + model=Tag, + required=False, + label=__("Specific Test Tags"), + description=__("Run the tests which contain specific tags only"), + ) + explanation_verbosity = RequiredChoiceVar( + choices=ExplanationVerbosityChoices.choices, + default=ExplanationVerbosityChoices.maximum, + label=__("Explanation Verbosity Level"), + required=False, + ) class Meta: name = __("Run Compliance Tests") @@ -62,24 +78,16 @@ def __init__(self): super().__init__() self._nameset_functions = {} self.global_namesets = NameSet.objects.filter(_global=True) + self.verbosity = 2 self.results_count = 0 self.results_passed = 0 def nameset_functions(self, namesets: Iterable[NameSet]) -> dict[str, Callable]: - def extract_nameset(nameset, globals_): - locs = {} - exec(nameset.effective_definitions, globals_, locs) - __all__ = set(locs.get("__all__", [])) - return {k: v for k, v in locs.items() if k in __all__ and callable(v)} - result = {} - globals_ = dict(getmembers(builtins)) | { - name: getattr(default_nameset, name) for name in default_nameset.__all__ - } for nameset in chain(namesets, self.global_namesets): if nameset.name not in self._nameset_functions: try: - new_functions = extract_nameset(nameset, globals_) + new_functions = nameset.extract() except Exception as e: self.log_warning(f"Cannot extract code from nameset {nameset}, {type(e).__name__}: {e}") new_functions = {} @@ -89,7 +97,9 @@ def extract_nameset(nameset, globals_): def run_test(self, device: VDevice, test: ComplianceTest) -> tuple[bool, list[tuple[Any, Any]]]: functions = self.nameset_functions(test.namesets.all()) - evaluator = ExplanationalEval(functions=functions, names={"device": device}, load_defaults=True) + evaluator = ExplanationalEval( + functions=functions, names={"device": device}, load_defaults=True, verbosity=self.verbosity + ) passed = bool(evaluator.eval(test.effective_expression)) return passed, evaluator.explanation @@ -102,12 +112,11 @@ def run_tests_for_device( for test in tests_qs: explanation = [] try: - device.config passed, explanation = self.run_test(device, test) - except (InvalidExpression, EvalError) as e: - self.log_failure(f"Failed to execute test *{test}* for device *{device}*, `{type(e).__name__}: {e}`") + except EvalError as exc: + self.log_failure(f"Failed to execute test **{test}** for device **{device}**, `{exc}`") passed = False - explanation.append((f"{type(e).__name__}: {e}", None)) + explanation.append((str(exc), None)) self.results_count += 1 self.results_passed += int(passed) yield ComplianceTestResult( @@ -147,15 +156,24 @@ def save_to_db(self, results: Iterable[ComplianceTestResult], report: Compliance if report: ComplianceReport.objects.delete_old() + def get_selectors(self, data: dict) -> QuerySet[ComplianceSelector]: + selectors = ComplianceSelector.objects.all() + if specific_selectors := data.get("selectors"): + selectors = selectors.filter(pk__in=specific_selectors) + test_qs = ComplianceTest.objects.all() + if test_tags := data.get("test_tags"): + test_qs = test_qs.filter(tags__pk__in=test_tags).distinct() + selectors = selectors.filter(tests__tags__pk__in=test_tags).distinct() + return selectors.prefetch_related(Prefetch("tests", test_qs.prefetch_related("namesets"))) + def run(self, data, commit): + self.verbosity = int(data.get("explanation_verbosity", self.verbosity)) if data.get("sync_datasources"): datasource_sync(VDataSource.objects.exclude(custom_field_data__device_config_path=None)) with null_request(): report = ComplianceReport.objects.create() if data.get("make_report") else None - selectors = ComplianceSelector.objects.prefetch_related("tests", "tests__namesets") + selectors = self.get_selectors(data) device_ids = data.get("devices", []) - if specific_selectors := data.get("selectors"): - selectors = selectors.filter(pk__in=specific_selectors) results = chain.from_iterable( self.run_tests_for_selector(selector, report, device_ids) for selector in selectors ) diff --git a/validity/utils/orm.py b/validity/utils/orm.py index a5c5840..5fab36b 100644 --- a/validity/utils/orm.py +++ b/validity/utils/orm.py @@ -173,7 +173,7 @@ def _clone(self, *args, **kwargs): return c def bind_attributes(self, instance): - for attr, attr_value in self._aux_attributes: + for attr, attr_value in self._aux_attributes.items(): setattr(instance, attr, attr_value) def _fetch_all(self): @@ -184,3 +184,4 @@ def _fetch_all(self): def set_attribute(self, name, value): self._aux_attributes[name] = value + return self