From 56e76141f12dd4ad4482d26a72322ce84cd6dd88 Mon Sep 17 00:00:00 2001 From: Anton M Date: Mon, 18 Dec 2023 00:13:12 +0400 Subject: [PATCH] workging pollers --- requirements/base.txt | 3 +- validity/__init__.py | 9 +- validity/api/serializers.py | 11 ++- validity/api/views.py | 2 +- .../config_compliance/device_config/base.py | 2 +- .../config_compliance/device_config/ttp.py | 4 +- .../config_compliance/device_config/yaml.py | 2 +- validity/data_backends.py | 84 ++++++++++++++++++ validity/filtersets.py | 4 +- validity/forms/filterset.py | 1 + validity/forms/general.py | 12 +-- validity/j2_env.py | 14 +++ validity/managers.py | 8 +- validity/migrations/0007_polling.py | 41 ++++++++- validity/models/data.py | 6 +- validity/models/device.py | 4 +- validity/models/polling.py | 61 +++++++++++-- validity/pollers/__init__.py | 2 + validity/pollers/base.py | 86 +++++++++++++++++++ validity/pollers/cli.py | 17 ++++ validity/pollers/exceptions.py | 8 ++ validity/pollers/factory.py | 22 +++++ validity/pollers/result.py | 60 +++++++++++++ validity/scripts/run_tests.py | 2 +- validity/search.py | 4 +- validity/tables.py | 2 +- validity/templates/validity/command.html | 4 +- .../templates/validity/report_devices.html | 11 ++- validity/urls.py | 2 +- validity/utils/misc.py | 28 ++++-- validity/utils/orm.py | 39 ++++++--- validity/views/device.py | 2 +- 32 files changed, 492 insertions(+), 65 deletions(-) create mode 100644 validity/data_backends.py create mode 100644 validity/j2_env.py create mode 100644 validity/pollers/__init__.py create mode 100644 validity/pollers/base.py create mode 100644 validity/pollers/cli.py create mode 100644 validity/pollers/exceptions.py create mode 100644 validity/pollers/factory.py create mode 100644 validity/pollers/result.py diff --git a/requirements/base.txt b/requirements/base.txt index 9dc22b1..8177644 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,8 +1,9 @@ django-bootstrap-v5==1.0.* -pydantic==1.10.* +pydantic >=2.0.0,<3 ttp==0.9.* jq==1.4.* deepdiff==6.2.* simpleeval==0.9.* +netmiko >=4.0.0,<5 dulwich # Core NetBox "optional" requirement diff --git a/validity/__init__.py b/validity/__init__.py index 6b83de7..2c7219e 100644 --- a/validity/__init__.py +++ b/validity/__init__.py @@ -1,10 +1,9 @@ import logging -from pathlib import Path from django.conf import settings as django_settings from extras.plugins import PluginConfig from netbox.settings import VERSION -from pydantic import BaseModel, DirectoryPath, Field +from pydantic import BaseModel, Field from validity.utils.version import NetboxVersion @@ -26,6 +25,11 @@ class NetBoxValidityConfig(PluginConfig): # custom field netbox_version = NetboxVersion(VERSION) + def ready(self): + import validity.data_backends + + return super().ready() + config = NetBoxValidityConfig @@ -33,7 +37,6 @@ class NetBoxValidityConfig(PluginConfig): class ValiditySettings(BaseModel): store_last_results: int = Field(default=5, gt=0, lt=1001) store_reports: int = Field(default=5, gt=0, lt=1001) - git_folder: DirectoryPath = Path("/opt/git_repos") sleep_between_tests: float = 0 result_batch_size: int = 500 diff --git a/validity/api/serializers.py b/validity/api/serializers.py index b0b761f..6b34e35 100644 --- a/validity/api/serializers.py +++ b/validity/api/serializers.py @@ -299,7 +299,7 @@ class Meta: "url", "display", "name", - "slug", + "label", "retrieves_config", "type", "parameters", @@ -317,7 +317,10 @@ class PollerSerializer(NetBoxModelSerializer): url = serializers.HyperlinkedIdentityField(view_name="plugins-api:validity-api:poller-detail") private_credentials = EncryptedDictField() commands = SerializedPKRelatedField( - serializer=NestedCommandSerializer, many=True, required=False, queryset=models.Command.objects.all() + serializer=NestedCommandSerializer, + many=True, + queryset=models.Command.objects.all(), + allow_empty=False, ) class Meta: @@ -337,5 +340,9 @@ class Meta: "last_updated", ) + def validate(self, data): + models.Poller.validate_commands(data["commands"]) + return super().validate(data) + NestedPollerSerializer = nested_factory(PollerSerializer, ("id", "url", "display", "name")) diff --git a/validity/api/views.py b/validity/api/views.py index 8dfe485..b966a52 100644 --- a/validity/api/views.py +++ b/validity/api/views.py @@ -98,7 +98,7 @@ def get_queryset(self): class SerializedConfigView(APIView): - queryset = models.VDevice.objects.prefetch_datasource().prefetch_serializer() + queryset = models.VDevice.objects.prefetch_datasource().prefetch_serializer().prefetch_poller() def get_object(self, pk): try: diff --git a/validity/config_compliance/device_config/base.py b/validity/config_compliance/device_config/base.py index 14858e1..3fa9663 100644 --- a/validity/config_compliance/device_config/base.py +++ b/validity/config_compliance/device_config/base.py @@ -30,7 +30,7 @@ def from_device(cls, device: "VDevice") -> "BaseDeviceConfig": with reraise((AssertionError, FileNotFoundError, AttributeError), DeviceConfigError): assert getattr( device, "data_file", None - ), f"{device} has no bound data file. Either no data source bound or the file does not exist" + ), f"{device} has no bound data file. Either there is no data source attached or the file does not exist" assert getattr(device, "serializer", None), f"{device} has no bound serializer" return cls._config_classes[device.serializer.extraction_method]._from_device(device) diff --git a/validity/config_compliance/device_config/ttp.py b/validity/config_compliance/device_config/ttp.py index 6a6e2fb..23f4e7f 100644 --- a/validity/config_compliance/device_config/ttp.py +++ b/validity/config_compliance/device_config/ttp.py @@ -28,7 +28,5 @@ def serialize(self, override: bool = False) -> None: if not self.serialized or override: parser = ttp(data=self.plain_config, template=self._template.template) parser.parse() - with reraise( - IndexError, DeviceConfigError, msg=f"Invalid parsed config for {self.device}: {parser.result()}" - ): + with reraise(IndexError, DeviceConfigError, f"Invalid parsed config for {self.device}: {parser.result()}"): self.serialized = parser.result()[0][0] diff --git a/validity/config_compliance/device_config/yaml.py b/validity/config_compliance/device_config/yaml.py index df63b33..5094e93 100644 --- a/validity/config_compliance/device_config/yaml.py +++ b/validity/config_compliance/device_config/yaml.py @@ -15,6 +15,6 @@ def serialize(self, override: bool = False) -> None: with reraise( yaml.YAMLError, DeviceConfigError, - msg=f"Trying to parse invalid YAML as device config for {self.device}", + f"Trying to parse invalid YAML as device config for {self.device}", ): self.serialized = yaml.safe_load(self.plain_config) diff --git a/validity/data_backends.py b/validity/data_backends.py new file mode 100644 index 0000000..e06a90d --- /dev/null +++ b/validity/data_backends.py @@ -0,0 +1,84 @@ +from contextlib import contextmanager +from itertools import chain, groupby +from pathlib import Path +from tempfile import TemporaryDirectory + +import yaml +from django import forms +from django.utils import timezone +from django.utils.translation import gettext_lazy as _ +from netbox.registry import registry + +from validity import config +from validity.models import VDevice +from .pollers.result import DescriptiveError + + +if config.netbox_version >= 3.7: + from netbox.data_backends import DataBackend +else: + from core.data_backends import DataBackend + + +class PollingBackend(DataBackend): + """ + Custom Data Source Backend to poll devices + """ + + name = "device_polling" + label = _("Device Polling") + + parameters = { + "datasource_id": forms.CharField( + label=_("Data Source ID"), + widget=forms.TextInput(attrs={"class": "form-control"}), + ) + } + + devices_qs = VDevice.objects.prefetch_poller().annotate_datasource_id().order_by("poller_id") + metainfo_file = Path("polling_info.yaml") + + def bound_devices_qs(self): + datasource_id = self.params.get("datasource_id") + assert datasource_id, 'Data Source parameters must contain "datasource_id"' + return self.devices_qs.filter(data_source_id=datasource_id) + + def write_metainfo(self, dir_name: str, errors: set[DescriptiveError]) -> None: + # NetBox does not provide an opportunity for a backend to return any info/errors to the user + # Hence, it will be written into "polling_info.yaml" file + info = { + "polled_at": timezone.now().isoformat(timespec="seconds"), + "devices_polled": self.bound_devices_qs().count(), + "errors": [err.serialized for err in sorted(errors, key=lambda e: e.device)], + } + path = dir_name / self.metainfo_file + path.write_text(yaml.safe_dump(info, sort_keys=False)) + + @contextmanager + def fetch(self): + with TemporaryDirectory() as dir_name: + devices = self.bound_devices_qs() + result_generators = [ + poller.get_backend().poll(device_group) + for poller, device_group in groupby(devices, key=lambda device: device.poller) + ] + errors = set() + for cmd_result in chain.from_iterable(result_generators): + if cmd_result.errored: + errors.add(cmd_result.descriptive_error) + cmd_result.write_on_disk(dir_name) + self.write_metainfo(dir_name, errors) + yield dir_name + + +backends = [PollingBackend] + +if config.netbox_version < 3.7: + # "register" DS backend manually via monkeypatch + from core.choices import DataSourceTypeChoices + from core.forms import DataSourceForm + from core.models import DataSource + + registry["data_backends"][PollingBackend.name] = PollingBackend + DataSourceTypeChoices._choices += [(PollingBackend.name, PollingBackend.label)] + DataSourceForm.base_fields["type"] = DataSource._meta.get_field("type").formfield() diff --git a/validity/filtersets.py b/validity/filtersets.py index fb34dca..9c6e6e2 100644 --- a/validity/filtersets.py +++ b/validity/filtersets.py @@ -123,5 +123,5 @@ class Meta: class CommandFilterSet(SearchMixin, NetBoxModelFilterSet): class Meta: model = models.Command - fields = ("id", "name", "slug", "type", "retrieves_config") - search_fields = ("name", "slug") + fields = ("id", "name", "label", "type", "retrieves_config") + search_fields = ("name", "label") diff --git a/validity/forms/filterset.py b/validity/forms/filterset.py index 96e184d..0c858cb 100644 --- a/validity/forms/filterset.py +++ b/validity/forms/filterset.py @@ -153,6 +153,7 @@ class PollerFilterForm(NetBoxModelFilterSetForm): class CommandFilterForm(NetBoxModelFilterSetForm): model = models.Command name = CharField(required=False) + label = CharField(required=False) type = PlaceholderChoiceField(required=False, placeholder=_("Type"), choices=CommandTypeChoices.choices) retrieves_config = NullBooleanField( label=_("Global"), required=False, widget=Select(choices=BOOLEAN_WITH_BLANK_CHOICES) diff --git a/validity/forms/general.py b/validity/forms/general.py index 2ccfb8e..6633c29 100644 --- a/validity/forms/general.py +++ b/validity/forms/general.py @@ -5,7 +5,7 @@ from extras.models import Tag from netbox.forms import NetBoxModelForm from tenancy.models import Tenant -from utilities.forms.fields import DynamicModelMultipleChoiceField, SlugField +from utilities.forms.fields import DynamicModelMultipleChoiceField from utilities.forms.widgets import HTMXSelect from validity import models @@ -124,15 +124,17 @@ class Meta: "private_credentials": Textarea(attrs={"style": "font-family:monospace"}), } + def clean(self): + models.Poller.validate_commands(self.cleaned_data["commands"]) + return super().clean() -class CommandForm(SubformMixin, NetBoxModelForm): - slug = SlugField() +class CommandForm(SubformMixin, NetBoxModelForm): main_fieldsets = [ - (_("Command"), ("name", "slug", "type", "retrieves_config", "tags")), + (_("Command"), ("name", "label", "type", "retrieves_config", "tags")), ] class Meta: model = models.Command - fields = ("name", "slug", "type", "retrieves_config", "tags") + fields = ("name", "label", "type", "retrieves_config", "tags") widgets = {"type": HTMXSelect()} diff --git a/validity/j2_env.py b/validity/j2_env.py new file mode 100644 index 0000000..78847c5 --- /dev/null +++ b/validity/j2_env.py @@ -0,0 +1,14 @@ +from django.utils.text import slugify +from jinja2 import BaseLoader +from jinja2 import Environment as Jinja2Environment + + +def slug(obj, allow_unicode=False): + return slugify(str(obj), allow_unicode) + + +class Environment(Jinja2Environment): + def __init__(self, *args, **kwargs): + kwargs.setdefault("loader", BaseLoader()) + super().__init__(*args, **kwargs) + self.filters["slugify"] = slug diff --git a/validity/managers.py b/validity/managers.py index 491a22c..a82d9d3 100644 --- a/validity/managers.py +++ b/validity/managers.py @@ -219,7 +219,7 @@ def prefetch_serializer(self): def prefetch_poller(self): from validity.models import Poller - return self.annotate_poller_id().custom_prefetch("poller", Poller.objects.prefetch_related("commands")) + return self.annotate_poller_id().custom_prefetch("poller", Poller.objects.prefetch_commands()) def _count_per_something(self, field: str, annotate_method: str) -> dict[int | None, int]: qs = getattr(self, annotate_method)().values(field).annotate(cnt=Count("id", distinct=True)) @@ -258,3 +258,9 @@ def prefetch_results(self, report_id: int, severity_ge: SeverityChoices = Severi .order_by("test__name"), ) ) + + +class PollerQS(RestrictedQuerySet): + def prefetch_commands(self): + Command = self.model._meta.get_field("commands").remote_field.model + return self.prefetch_related(Prefetch("commands", Command.objects.order_by("-retrieves_config"))) diff --git a/validity/migrations/0007_polling.py b/validity/migrations/0007_polling.py index d10d766..6f35c6c 100644 --- a/validity/migrations/0007_polling.py +++ b/validity/migrations/0007_polling.py @@ -6,6 +6,7 @@ import validity.models.base import validity.utils.dbfields from django.utils.translation import gettext_lazy as _ +from django.core.validators import RegexValidator def create_cf(apps, schema_editor): @@ -40,8 +41,31 @@ def delete_cf(apps, schema_editor): CustomField.objects.using(db_alias).filter(name="poller").delete() -class Migration(migrations.Migration): +def create_polling_datasource(apps, schema_editor): + DataSource = apps.get_model("core", "DataSource") + db = schema_editor.connection.alias + ds = DataSource.objects.using(db).create( + name="Validity Polling", + type="device_polling", + source_url="/", + description=_("Required by Validity. Polls bound devices and stores the results"), + custom_field_data={ + "device_config_path": "{{device | slugify}}/{{ device.poller.config_command.label }}.txt", + "device_config_default": False, + "web_url": "", + }, + ) + ds.parameters = {"datasource_id": ds.pk} + ds.save() + + +def delete_polling_datasource(apps, schema_editor): + DataSource = apps.get_model("core", "DataSource") + db = schema_editor.connection.alias + DataSource.objects.using(db).filter(type="Validity Polling").delete() + +class Migration(migrations.Migration): dependencies = [ ("extras", "0098_webhook_custom_field_data_webhook_tags"), ("validity", "0006_script_change"), @@ -59,7 +83,19 @@ class Migration(migrations.Migration): models.JSONField(blank=True, default=dict, encoder=utilities.json.CustomFieldJSONEncoder), ), ("name", models.CharField(max_length=255, unique=True)), - ("slug", models.SlugField(max_length=100, unique=True)), + ( + "label", + models.CharField( + max_length=100, + unique=True, + validators=[ + RegexValidator( + regex="^[a-z][a-z0-9_]*$", + message=_("Only lowercase ASCII letters, numbers and underscores are allowed"), + ) + ], + ), + ), ("retrieves_config", models.BooleanField(default=False)), ("type", models.CharField(max_length=50)), ("parameters", models.JSONField()), @@ -93,4 +129,5 @@ class Migration(migrations.Migration): bases=(validity.models.base.URLMixin, models.Model), ), migrations.RunPython(create_cf, delete_cf), + migrations.RunPython(create_polling_datasource, delete_polling_datasource), ] diff --git a/validity/models/data.py b/validity/models/data.py index 5fafa25..6fea36c 100644 --- a/validity/models/data.py +++ b/validity/models/data.py @@ -1,8 +1,8 @@ from functools import cached_property from core.models import DataFile, DataSource -from jinja2 import BaseLoader, Environment +from validity.j2_env import Environment from validity.managers import VDataFileQS, VDataSourceQS from validity.utils.orm import QuerySetMap @@ -34,8 +34,8 @@ def is_default(self): @property def web_url(self) -> str: - template_text = self.cf.get("web_url", "") - template = Environment(loader=BaseLoader()).from_string(template_text) + template_text = self.cf.get("web_url") or "" + template = Environment().from_string(template_text) return template.render(**self.parameters or {}) @property diff --git a/validity/models/device.py b/validity/models/device.py index 6a6f17f..5a0aae6 100644 --- a/validity/models/device.py +++ b/validity/models/device.py @@ -2,9 +2,9 @@ from typing import Any, Optional from dcim.models import Device -from jinja2 import BaseLoader, Environment from validity.config_compliance.device_config import DeviceConfig +from validity.j2_env import Environment from validity.managers import VDeviceQS from .data import VDataFile, VDataSource @@ -23,7 +23,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @property def config_path(self) -> str: assert hasattr(self, "data_source"), "You must prefetch data_source first" - template = Environment(loader=BaseLoader()).from_string(self.data_source.config_path_template) + template = Environment().from_string(self.data_source.config_path_template) return template.render(device=self) @cached_property diff --git a/validity/models/polling.py b/validity/models/polling.py index 75e7046..ead693b 100644 --- a/validity/models/polling.py +++ b/validity/models/polling.py @@ -1,8 +1,16 @@ +from contextlib import contextmanager +from functools import cached_property +from typing import Iterable + from dcim.models import Device +from django.core.exceptions import ValidationError +from django.core.validators import RegexValidator from django.db import models from django.utils.translation import gettext_lazy as _ from validity.choices import CommandTypeChoices, ConnectionTypeChoices +from validity.managers import PollerQS +from validity.pollers import get_poller from validity.subforms import CLICommandForm from validity.utils.dbfields import EncryptedDictField from .base import BaseModel, SubformMixin @@ -10,7 +18,18 @@ class Command(SubformMixin, BaseModel): name = models.CharField(_("Name"), max_length=255, unique=True) - slug = models.SlugField(_("Slug"), max_length=100, unique=True) + label = models.CharField( + _("Label"), + max_length=100, + unique=True, + help_text=_("String key to access command output inside Tests"), + validators=[ + RegexValidator( + regex="^[a-z][a-z0-9_]*$", + message=_("Only lowercase ASCII letters, numbers and underscores are allowed"), + ) + ], + ) retrieves_config = models.BooleanField( _("Retrieves Configuration"), default=False, @@ -19,7 +38,6 @@ class Command(SubformMixin, BaseModel): type = models.CharField(_("Type"), max_length=50, choices=CommandTypeChoices.choices) parameters = models.JSONField(_("Parameters")) - clone_fields = ("retrieves_config", "type", "parameters") subform_type_field = "type" subform_json_field = "parameters" subforms = {"CLI": CLICommandForm} @@ -41,7 +59,7 @@ class Poller(BaseModel): private_credentials = EncryptedDictField(_("Private Credentials"), blank=True) commands = models.ManyToManyField(Command, verbose_name=_("Commands"), related_name="pollers") - clone_fields = ("connection_type", "public_credentials", "private_credentials") + objects = PollerQS.as_manager() class Meta: ordering = ("name",) @@ -62,9 +80,38 @@ def bound_devices(self) -> models.QuerySet[Device]: return VDevice.objects.annotate_poller_id().filter(poller_id=self.pk) + @cached_property + def config_command(self) -> Command | None: + """ + Bound command which is responsible for retrieving configuration + """ + return next((cmd for cmd in self.commands.all() if cmd.retrieves_config), None) + + def get_backend(self): + return get_poller(self.connection_type, self.credentials, self.commands.all()) + def serialize_object(self): + with self.serializable_credentials(): + return super().serialize_object() + + @contextmanager + def serializable_credentials(self): private_creds = self.private_credentials - self.private_credentials = self.private_credentials.encrypted - result = super().serialize_object() - self.private_credentials = private_creds - return result + try: + self.private_credentials = self.private_credentials.encrypted + yield + finally: + self.private_credentials = private_creds + + @staticmethod + def validate_commands(commands: Iterable[Command]): + config_commands_count = sum(1 for cmd in commands if cmd.retrieves_config) + if config_commands_count > 1: + raise ValidationError( + { + "commands": _( + "No more than 1 command to retrieve config is allowed, " + f"but {config_commands_count} were specified" + ) + } + ) diff --git a/validity/pollers/__init__.py b/validity/pollers/__init__.py new file mode 100644 index 0000000..5eb0350 --- /dev/null +++ b/validity/pollers/__init__.py @@ -0,0 +1,2 @@ +from .cli import NetmikoPoller +from .factory import get_poller diff --git a/validity/pollers/base.py b/validity/pollers/base.py new file mode 100644 index 0000000..16ded61 --- /dev/null +++ b/validity/pollers/base.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Any, Collection, Iterable, Iterator + +from validity.utils.misc import reraise +from .exceptions import PollingError +from .result import CommandResult + + +if TYPE_CHECKING: + from validity.models import Command, VDevice + + +class DevicePoller(ABC): + host_param_name: str + + def __init__(self, credentials: dict, commands: Collection["Command"]) -> None: + self.credentials = credentials + self.commands = commands + + @abstractmethod + def poll(self, devices: Iterable["VDevice"]) -> Iterator[CommandResult]: + pass + + def get_credentials(self, device: "VDevice"): + if (ip := device.primary_ip) is None: + raise PollingError(message="Device has no primary IP") + return self.credentials | {self.host_param_name: str(ip.address.ip)} + + +class ThreadPoller(DevicePoller): + """ + Polls devices one by one using threads + """ + + thread_workers: int = 500 + + def _poll_one_device(self, device: "VDevice") -> Iterator[CommandResult]: + """ + Handles device-wide errors + """ + try: + with reraise(Exception, PollingError): + yield from self.poll_one_device(device) + except PollingError as err: + yield from (CommandResult(device, c, error=err) for c in self.commands) + + @abstractmethod + def poll_one_device(self, device: "VDevice") -> Iterator[CommandResult]: + pass + + def _poll(self, devices: Iterable["VDevice"]) -> Iterator[CommandResult]: + with ThreadPoolExecutor(max_workers=self.thread_workers) as executor: + results = [executor.submit(self._poll_one_device, d) for d in devices] + yield # start threadpool and release the generator + for result in as_completed(results): + yield from result.result() + + def poll(self, devices: Iterable["VDevice"]) -> Iterator[CommandResult | PollingError]: + poll_gen = self._poll(devices) + next(poll_gen) + return poll_gen + + +class DriverMixin: + driver_cls: type # Network driver class, e.g. netmiko.ConnectHandler + + def get_driver(self, device: "VDevice"): + creds = self.get_credentials(device) + return self.driver_cls(**creds) + + +class ConsecutivePoller(DriverMixin, ThreadPoller): + @abstractmethod + def poll_one_command(self, driver: Any, command: "Command") -> str: + pass + + def poll_one_device(self, device: "VDevice") -> Iterator[CommandResult]: + driver = self.get_driver(device) + for command in self.commands: + try: + with reraise(Exception, PollingError, device_wide=False): + output = self.poll_one_command(driver, command) + yield CommandResult(device=device, command=command, result=output) + except PollingError as err: + yield CommandResult(device=device, command=command, error=err) diff --git a/validity/pollers/cli.py b/validity/pollers/cli.py new file mode 100644 index 0000000..1c5bef6 --- /dev/null +++ b/validity/pollers/cli.py @@ -0,0 +1,17 @@ +from typing import TYPE_CHECKING + +from netmiko import BaseConnection, ConnectHandler + +from .base import ConsecutivePoller + + +if TYPE_CHECKING: + from validity.models import Command + + +class NetmikoPoller(ConsecutivePoller): + host_param_name = "host" + driver_cls = staticmethod(ConnectHandler) # ConnectHandler is a function + + def poll_one_command(self, driver: BaseConnection, command: "Command") -> str: + return driver.send_command(command.parameters["cli_command"]) diff --git a/validity/pollers/exceptions.py b/validity/pollers/exceptions.py new file mode 100644 index 0000000..2ddf9f2 --- /dev/null +++ b/validity/pollers/exceptions.py @@ -0,0 +1,8 @@ +class PollingError(Exception): + def __init__(self, message, *, device_wide=True, orig_error=None) -> None: + self.device_wide = device_wide + if orig_error: + message = f"{type(orig_error).__name__}: {orig_error}" + super().__init__(message) + + message = property(lambda self: self.args[0]) diff --git a/validity/pollers/factory.py b/validity/pollers/factory.py new file mode 100644 index 0000000..96bc9c8 --- /dev/null +++ b/validity/pollers/factory.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING, Sequence + +from validity.choices import ConnectionTypeChoices +from .base import DevicePoller +from .cli import NetmikoPoller + + +if TYPE_CHECKING: + from validity.models import Command + + +class PollerFactory: + def __init__(self, poller_map: dict) -> None: + self.poller_map = poller_map + + def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> DevicePoller: + if poller_cls := self.poller_map.get(connection_type): + return poller_cls(credentials=credentials, commands=commands) + raise KeyError("No poller exist for this connection type", connection_type) + + +get_poller = PollerFactory(poller_map={ConnectionTypeChoices.netmiko: NetmikoPoller}) diff --git a/validity/pollers/result.py b/validity/pollers/result.py new file mode 100644 index 0000000..cbb6bbb --- /dev/null +++ b/validity/pollers/result.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar + +from django.utils.text import slugify + +from .exceptions import PollingError + + +if TYPE_CHECKING: + from validity.models import Command, VDevice + + +@dataclass(frozen=True) +class DescriptiveError: + """ + This info will be added to polling_info.yaml + """ + + device: str + command: str | None + error: str + + @property + def serialized(self): + result = {"device": self.device, "error": self.error} + if self.command: + result["command"] = self.command + return result + + +@dataclass +class CommandResult: + device: "VDevice" + command: "Command" + result: str = "" + error: PollingError | None = None + + error_header: ClassVar[str] = "POLLING ERROR\n" + + def __post_init__(self): + assert self.result or self.error is not None + + foldername = property(lambda self: slugify(str(self.device))) + filename = property(lambda self: self.command.label + ".txt") + errored = property(lambda self: self.error is not None) + contents = property(lambda self: self.error_header + str(self.error) if self.errored else self.result) + + @property + def descriptive_error(self): + assert self.errored + command = "" if self.error.device_wide else self.command.label + return DescriptiveError(device=str(self.device), command=command, error=self.error.message) + + def write_on_disk(self, base_dir: str) -> None: + device_folder = Path(base_dir) / self.foldername + if not device_folder.is_dir(): + device_folder.mkdir() + full_path = device_folder / self.filename + full_path.write_text(self.contents, encoding="utf-8") diff --git a/validity/scripts/run_tests.py b/validity/scripts/run_tests.py index cb43b9d..ceae31a 100644 --- a/validity/scripts/run_tests.py +++ b/validity/scripts/run_tests.py @@ -126,7 +126,7 @@ def run_tests_for_selector( report: ComplianceReport | None, device_ids: list[int], ) -> Generator[ComplianceTestResult, None, None]: - qs = selector.devices.select_related().prefetch_datasource().prefetch_serializer() + qs = selector.devices.select_related().prefetch_datasource().prefetch_serializer().prefetch_poller() if device_ids: qs = qs.filter(pk__in=device_ids) for device in qs: diff --git a/validity/search.py b/validity/search.py index 653b026..cda804a 100644 --- a/validity/search.py +++ b/validity/search.py @@ -41,5 +41,5 @@ class PollerIndex(SearchIndex): @register_search class CommandIndex(SearchIndex): - model = models.Poller - fields = (("name", 100), ("slug", 110)) + model = models.Command + fields = (("name", 100), ("label", 110)) diff --git a/validity/tables.py b/validity/tables.py index 83bab2a..44cea63 100644 --- a/validity/tables.py +++ b/validity/tables.py @@ -118,7 +118,7 @@ class CommandTable(NetBoxTable): class Meta(NetBoxTable.Meta): model = models.Command - fields = ("name", "type", "retrieves_config", "bound_pollers", "slug") + fields = ("name", "type", "retrieves_config", "bound_pollers", "label") class ExplanationColumn(Column): diff --git a/validity/templates/validity/command.html b/validity/templates/validity/command.html index f3e081e..7ad33d7 100644 --- a/validity/templates/validity/command.html +++ b/validity/templates/validity/command.html @@ -13,8 +13,8 @@
Command
{{ object.name }} - Slug - {{ object.slug }} + Label + {{ object.label }} Retrieves Configuration diff --git a/validity/templates/validity/report_devices.html b/validity/templates/validity/report_devices.html index 3a85de7..24101ec 100644 --- a/validity/templates/validity/report_devices.html +++ b/validity/templates/validity/report_devices.html @@ -1,8 +1,13 @@ {% extends 'validity/aux_tab_table.html' %} - +{% load buttons %} +{% load perms %} {% block title %}{{ object }}: Devices{% endblock %} - -{% block extra_button %} +{% block controls %} +
+{% if request.user|can_delete:object %} + {% delete_button object %} +{% endif %} +
{% endblock %} {% block table_title %}Devices tested within {{ object }}{% endblock %} diff --git a/validity/urls.py b/validity/urls.py index 413e0cd..118f907 100644 --- a/validity/urls.py +++ b/validity/urls.py @@ -28,7 +28,7 @@ path("pollers/", views.PollerListView.as_view(), name="poller_list"), path("pollers/add/", views.PollerEditView.as_view(), name="poller_add"), path("pollers/delete/", views.PollerBulkDeleteView.as_view(), name="poller_bulk_delete"), - path("pollers//", include(get_model_urls("validity", "Poller"))), + path("pollers//", include(get_model_urls("validity", "poller"))), path("commands/", views.CommandListView.as_view(), name="command_list"), path("commands/add/", views.CommandEditView.as_view(), name="command_add"), path("commands/delete/", views.CommandBulkDeleteView.as_view(), name="command_bulk_delete"), diff --git a/validity/utils/misc.py b/validity/utils/misc.py index 43e4a57..de3a297 100644 --- a/validity/utils/misc.py +++ b/validity/utils/misc.py @@ -1,5 +1,6 @@ +import inspect from concurrent.futures import ThreadPoolExecutor -from contextlib import contextmanager +from contextlib import contextmanager, suppress from typing import TYPE_CHECKING, Any, Callable, Iterable from core.exceptions import SyncError @@ -33,17 +34,28 @@ def null_request(): @contextmanager -def reraise(catch: type[Exception] | tuple[type[Exception], ...], raise_: type[Exception], msg: Any = None): +def reraise( + catch: type[Exception] | tuple[type[Exception], ...], + raise_: type[Exception], + *args, + orig_error_param="orig_error", + **kwargs, +): + """ + Catch one exception and raise another exception of different type, + args and kwargs will be passed to the newly generated exception + """ try: yield except raise_: raise - except catch as e: - if msg and isinstance(msg, str): - msg = msg.format(str(e)) - if not msg: - msg = str(e) - raise raise_(msg) from e + except catch as catched_err: + if not args: + args += (str(catched_err),) + with suppress(): + if orig_error_param in inspect.signature(raise_).parameters: + kwargs[orig_error_param] = catched_err + raise raise_(*args, **kwargs) from catched_err def datasource_sync( diff --git a/validity/utils/orm.py b/validity/utils/orm.py index 1e19474..26bfb42 100644 --- a/validity/utils/orm.py +++ b/validity/utils/orm.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from itertools import chain from typing import Any, Generic, Iterable, Iterator, TypeVar @@ -67,6 +68,21 @@ def all(self) -> Iterator[M]: yield from self.cache +@dataclass +class CustomPrefetch: + field: str + qs: QuerySet + many: bool + + pk_field = property(lambda self: self.field + "_id") + + def get_qs_map(self, main_queryset: QuerySet) -> QuerySetMap: + pk_values = main_queryset.values_list(self.pk_field, flat=True) + if self.many: + pk_values = chain.from_iterable(pk_values) + return QuerySetMap(self.qs.filter(pk__in=pk_values)) + + class CustomPrefetchMixin(QuerySet): """ Allows to prefetch objects without direct relations @@ -75,31 +91,30 @@ class CustomPrefetchMixin(QuerySet): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.custom_prefetches = {} + self.custom_prefetches = [] def custom_prefetch(self, field: str, prefetch_qs: QuerySet, many: bool = False): - pk_field = field + "_id" - pk_values = self.values_list(pk_field, flat=True) - if many: - pk_values = chain.from_iterable(pk_values) - prefetched_objects = prefetch_qs.filter(pk__in=pk_values) - self.custom_prefetches[field] = (many, QuerySetMap(prefetched_objects)) + self.custom_prefetches.append(CustomPrefetch(field, prefetch_qs, many)) return self + custom_prefetch.queryset_only = True + def _clone(self, *args, **kwargs): c = super()._clone(*args, **kwargs) - c.custom_prefetches = self.custom_prefetches + c.custom_prefetches = self.custom_prefetches.copy() return c def _fetch_all(self): super()._fetch_all() + qs_dicts = {custom_pf.field: custom_pf.get_qs_map(self) for custom_pf in self.custom_prefetches} for item in self._result_cache: if not isinstance(item, self.model): continue - for prefetched_field, (many, qs_dict) in self.custom_prefetches.items(): - prefetch_pk_values = getattr(item, prefetched_field + "_id") - if many: + for custom_prefetch in self.custom_prefetches: + prefetch_pk_values = getattr(item, custom_prefetch.pk_field) + qs_dict = qs_dicts[custom_prefetch.field] + if custom_prefetch.many: prefetch_values = M2MIterator(qs_dict[pk] for pk in prefetch_pk_values) else: prefetch_values = qs_dict.get(prefetch_pk_values) - setattr(item, prefetched_field, prefetch_values) + setattr(item, custom_prefetch.field, prefetch_values) diff --git a/validity/views/device.py b/validity/views/device.py index e730acf..e49d7b5 100644 --- a/validity/views/device.py +++ b/validity/views/device.py @@ -23,7 +23,7 @@ class TestResultView(TestResultBaseView): class DeviceSerializedConfigView(generic.ObjectView): template_name = "validity/device_config.html" tab = ViewTab("Serialized Config", permission="dcim.view_device") - queryset = VDevice.objects.prefetch_datasource().prefetch_serializer() + queryset = VDevice.objects.prefetch_datasource().prefetch_serializer().prefetch_poller() def get_extra_context(self, request, instance): try: