diff --git a/docs/entities/commands.md b/docs/entities/commands.md index 1a47502..b36e337 100644 --- a/docs/entities/commands.md +++ b/docs/entities/commands.md @@ -39,11 +39,11 @@ This field defines [Serializer](serializers.md) for Command output. ## Parameters This block contains type-specific parameters. -### Type:CLI +### Type: CLI #### CLI Command This field must contain text string which is going to be sent to device when polling occurs. -### Type:NETCONF +### Type: NETCONF #### RPC This field must contain an XML RPC which is going to be sent to device via Netconf. @@ -80,3 +80,7 @@ Example: } } ``` + +### Type: Custom + +This type has been introduced especially for [Custom Pollers](../features/custom_pollers.md) support. For this type you can define arbitrary parameters (in a form of JSON object) and then use them inside your custom poller. diff --git a/docs/features/custom_pollers.md b/docs/features/custom_pollers.md new file mode 100644 index 0000000..22e04b3 --- /dev/null +++ b/docs/features/custom_pollers.md @@ -0,0 +1,73 @@ +# User-defined Pollers + +Validity is able to perform device polling via custom user-defined pollers. This feature may be useful when: + +* existing polling methods must be adjusted to work with specific network equipment (e.g. slightly modify `netmiko` to interact with some ancient switch); +* some completely new polling method must be introduced (e.g. gNMI-based). + +## Defining custom Poller + +To define your own Poller, two steps must be performed: + +* Inherit from `CustomPoller` class to implement your custom polling logic +* Fill out `PollerInfo` structure with Poller meta info + +### Implementing Poller class + +Here is the minimal viable example of a custom poller class. It uses `scrapli` library to connect to devices via SSH. + +```python +from scrapli import Scrapli +from validity.pollers import CustomPoller +from validity.models import Command + + +class ScrapliPoller(CustomPoller): + driver_factory = Scrapli + host_param_name = 'host' # Scrapli expects "host" param containing ip address of the device + driver_connect_method = 'open' # This driver method (if defined) will be called to open the connection. + driver_disconnect_method = 'close' # This driver method (if defined) will be called to gracefully close the connection. + + def poll_one_command(self, driver, command) -> str: + """ + Arguments: + driver - object returned by calling driver_factory, usually represents connection to a particular device + command - Django model instance of the Command + Returns: + A string containing particular command execution result + """ + resp = driver.send_command(command.parameters["cli_command"]) + return resp.result +``` + +!!! note + Be aware that every poller class instance is usually responsible for interaction with multiple devices. Hence, do not use poller fields for storing device-specific parameters. + + +### Filling PollerInfo + +Poller Info is required to tell Validity about your custom poller. +Here is the example of the plugin settings: + +```python +# configuration.py + +from validity.settings import PollerInfo +from my_awesome_poller import ScrapliPoller + +PLUGIN_SETTINGS = { + 'validity': { + 'custom_pollers' : [ + PollerInfo(klass=ScrapliPoller, name='scrapli', color='pink', command_types=['CLI']) + ] + } +} +``` + +PollerInfo parameters: + +* **klass** - class inherited from `CustomPoller` +* **name** - system name of the poller, must contain lowercase letters only +* **verbose_name** - optional verbose name of the poller. Will be used in NetBox GUI +* **color** - badge color used for "Connection Type" field in the GUI +* **command_types** - list of acceptable [Command](../entities/commands.md) types for this kind of Poller. Available choices are `CLI`, `netconf`, `json_api` and `custom` diff --git a/requirements/base.txt b/requirements/base.txt index 03a42f6..007573e 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,5 +1,5 @@ deepdiff>=6.2.0,<7 -dimi >=1.2.0,< 2 +dimi >=1.3.0,< 2 django-bootstrap5 >=24.2,<25 dulwich # Core NetBox "optional" requirement jq>=1.4.0,<2 diff --git a/requirements/docs.txt b/requirements/docs.txt index 8cdd4ce..fcfa074 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,3 +1,3 @@ mkdocs==1.6.1 -mkdocs-include-markdown-plugin==4.0.4 +mkdocs-include-markdown-plugin==7.0.0 mkdocs-material==9.5.34 diff --git a/validity/api/serializers.py b/validity/api/serializers.py index 5e83768..4216072 100644 --- a/validity/api/serializers.py +++ b/validity/api/serializers.py @@ -1,3 +1,5 @@ +from typing import Annotated + from core.api.nested_serializers import ( NestedDataFileSerializer as _NestedDataFileSerializer, ) @@ -27,7 +29,7 @@ from rest_framework.reverse import reverse from tenancy.models import Tenant -from validity import config, models +from validity import config, di, models from validity.choices import ExplanationVerbosityChoices from validity.netbox_changes import NestedTenantSerializer from .helpers import ( @@ -366,8 +368,9 @@ class Meta: ) brief_fields = ("id", "url", "display", "name") - def validate(self, data): - models.Poller.validate_commands(data["connection_type"], data["commands"]) + @di.inject + def validate(self, data, command_types: Annotated[dict[str, list[str]], "PollerChoices.command_types"]): + models.Poller.validate_commands(data["commands"], command_types, data["connection_type"]) return super().validate(data) diff --git a/validity/choices.py b/validity/choices.py index 08877e8..f530610 100644 --- a/validity/choices.py +++ b/validity/choices.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, TypeVar +from typing import Optional, TypeVar from django.db.models import IntegerChoices, TextChoices from django.db.models.enums import ChoicesMeta @@ -39,9 +39,13 @@ def colors(self): class MemberMixin: @classmethod - def member(cls: type[_Type], value: Any) -> Optional[_Type]: + def member(cls: type[_Type], value: str) -> Optional[_Type]: return cls._value2member_map_.get(value) # type: ignore + @classmethod + def contains(cls, value: str) -> bool: + return value in cls._value2member_map_ + class BoolOperationChoices(TextChoices, metaclass=ColoredChoiceMeta): OR = "OR", _("OR"), "purple" @@ -94,10 +98,6 @@ class DeviceGroupByChoices(MemberMixin, TextChoices): SITE = "device__site__slug", _("Site") TEST = "test__name", _("Test") - @classmethod - def contains(cls, value: str) -> bool: - return value in cls._value2member_map_ - def viewname(self) -> str: view_prefixes = {self.TENANT: "tenancy:", self.TEST: "plugins:validity:compliance"} default_prefix = "dcim:" @@ -109,22 +109,11 @@ def pk_field(self): return "__".join(pk_path) -class ConnectionTypeChoices(TextChoices, metaclass=ColoredChoiceMeta): - netmiko = "netmiko", "netmiko", "blue" - requests = "requests", "requests", "info" - scrapli_netconf = "scrapli_netconf", "scrapli_netconf", "orange" - - __command_types__ = {"netmiko": "CLI", "scrapli_netconf": "netconf", "requests": "json_api"} - - @property - def acceptable_command_type(self) -> "CommandTypeChoices": - return CommandTypeChoices[self.__command_types__[self.name]] - - class CommandTypeChoices(TextChoices, metaclass=ColoredChoiceMeta): CLI = "CLI", "CLI", "blue" netconf = "netconf", "orange" json_api = "json_api", "JSON API", "info" + custom = "custom", _("Custom"), "gray" class ExplanationVerbosityChoices(IntegerChoices): diff --git a/validity/dependencies.py b/validity/dependencies.py index c46593e..58144cb 100644 --- a/validity/dependencies.py +++ b/validity/dependencies.py @@ -8,9 +8,8 @@ from rq.job import Job from validity import di -from validity.choices import ConnectionTypeChoices from validity.pollers import NetmikoPoller, RequestsPoller, ScrapliNetconfPoller -from validity.settings import ValiditySettings +from validity.settings import PollerInfo, ValiditySettings from validity.utils.misc import null_request @@ -25,12 +24,20 @@ def validity_settings(django_settings: Annotated[LazySettings, django_settings]) @di.dependency(scope=Singleton) -def poller_map(): - return { - ConnectionTypeChoices.netmiko: NetmikoPoller, - ConnectionTypeChoices.requests: RequestsPoller, - ConnectionTypeChoices.scrapli_netconf: ScrapliNetconfPoller, - } +def pollers_info(custom_pollers: Annotated[list[PollerInfo], "validity_settings.custom_pollers"]) -> list[PollerInfo]: + return [ + PollerInfo(klass=NetmikoPoller, name="netmiko", verbose_name="netmiko", color="blue", command_types=["CLI"]), + PollerInfo( + klass=RequestsPoller, name="requests", verbose_name="requests", color="info", command_types=["json_api"] + ), + PollerInfo( + klass=ScrapliNetconfPoller, + name="scrapli_netconf", + verbose_name="scrapli_netconf", + color="orange", + command_types=["netconf"], + ), + ] + custom_pollers import validity.pollers.factory # noqa diff --git a/validity/forms/bulk_import.py b/validity/forms/bulk_import.py index a238ea1..50de098 100644 --- a/validity/forms/bulk_import.py +++ b/validity/forms/bulk_import.py @@ -8,8 +8,9 @@ from tenancy.models import Tenant from utilities.forms.fields import CSVChoiceField, CSVModelChoiceField, CSVModelMultipleChoiceField, JSONField -from validity import choices, models +from validity import choices, di, models from validity.api.helpers import SubformValidationMixin +from ..utils.misc import LazyIterator from .mixins import PollerCleanMixin @@ -77,7 +78,7 @@ def __init__(self, *args, headers=None, **kwargs): self.base_fields = base_fields super().__init__(*args, headers=headers, **kwargs) - def save(self, commit=True) -> choices.Any: + def save(self, commit=True): if (_global := self.cleaned_data.get("global")) is not None: self.instance._global = _global return super().save(commit) @@ -186,7 +187,9 @@ class Meta: class PollerImportForm(PollerCleanMixin, NetBoxModelImportForm): - connection_type = CSVChoiceField(choices=choices.ConnectionTypeChoices.choices, help_text=_("Connection Type")) + connection_type = CSVChoiceField( + choices=LazyIterator(lambda: di["PollerChoices"].choices), help_text=_("Connection Type") + ) commands = CSVModelMultipleChoiceField( queryset=models.Command.objects.all(), to_field_name="label", @@ -201,9 +204,6 @@ class PollerImportForm(PollerCleanMixin, NetBoxModelImportForm): required=False, ) - def full_clean(self) -> None: - return super().full_clean() - class Meta: model = models.Poller fields = ("name", "connection_type", "commands", "public_credentials", "private_credentials") diff --git a/validity/forms/filterset.py b/validity/forms/filterset.py index 3370a04..6099719 100644 --- a/validity/forms/filterset.py +++ b/validity/forms/filterset.py @@ -11,17 +11,17 @@ from utilities.forms.fields import DynamicModelMultipleChoiceField from utilities.forms.widgets import DateTimePicker -from validity import models +from validity import di, models from validity.choices import ( BoolOperationChoices, CommandTypeChoices, - ConnectionTypeChoices, DeviceGroupByChoices, DynamicPairsChoices, ExtractionMethodChoices, SeverityChoices, ) from validity.netbox_changes import FieldSet +from validity.utils.misc import LazyIterator from .fields import PlaceholderChoiceField from .mixins import AddM2MPlaceholderFormMixin, ExcludeMixin @@ -175,7 +175,7 @@ class PollerFilterForm(NetBoxModelFilterSetForm): model = models.Poller name = CharField(required=False) connection_type = PlaceholderChoiceField( - required=False, label=_("Connection Type"), choices=ConnectionTypeChoices.choices + required=False, label=_("Connection Type"), choices=LazyIterator(lambda: di["PollerChoices"].choices) ) diff --git a/validity/forms/general.py b/validity/forms/general.py index 3d85025..b2628dd 100644 --- a/validity/forms/general.py +++ b/validity/forms/general.py @@ -7,13 +7,13 @@ from extras.models import Tag from netbox.forms import NetBoxModelForm from tenancy.models import Tenant -from utilities.forms import add_blank_choice from utilities.forms.fields import DynamicModelChoiceField, DynamicModelMultipleChoiceField from utilities.forms.widgets import HTMXSelect -from validity import models -from validity.choices import ConnectionTypeChoices, ExplanationVerbosityChoices +from validity import di, models +from validity.choices import ExplanationVerbosityChoices from validity.netbox_changes import FieldSet +from validity.utils.misc import LazyIterator from .fields import DynamicModelChoicePropertyField, DynamicModelMultipleChoicePropertyField from .mixins import PollerCleanMixin, SubformMixin from .widgets import PrettyJSONWidget @@ -137,7 +137,8 @@ class Meta: class PollerForm(PollerCleanMixin, NetBoxModelForm): connection_type = ChoiceField( - choices=add_blank_choice(ConnectionTypeChoices.choices), widget=Select(attrs={"id": "connection_type_select"}) + choices=LazyIterator([(None, "---------")], lambda: di["PollerChoices"].choices), + widget=Select(attrs={"id": "connection_type_select"}), ) commands = DynamicModelMultipleChoiceField(queryset=models.Command.objects.all()) diff --git a/validity/forms/mixins.py b/validity/forms/mixins.py index 77a97b9..5a5d819 100644 --- a/validity/forms/mixins.py +++ b/validity/forms/mixins.py @@ -1,10 +1,10 @@ import json -from typing import Literal, Sequence +from typing import Annotated, Literal, Sequence from utilities.forms import get_field_value from utilities.forms.fields import DynamicModelMultipleChoiceField -from validity.models import Poller +from validity import di, models from validity.netbox_changes import FieldSet @@ -26,9 +26,10 @@ def __init__(self, *args, exclude: Sequence[str] = (), **kwargs) -> None: class PollerCleanMixin: - def clean(self): + @di.inject + def clean(self, command_types: Annotated[dict[str, list[str]], "PollerChoices.command_types"]): connection_type = self.cleaned_data.get("connection_type") or get_field_value(self, "connection_type") - Poller.validate_commands(connection_type, self.cleaned_data.get("commands", [])) + models.Poller.validate_commands(self.cleaned_data.get("commands", []), command_types, connection_type) return super().clean() diff --git a/validity/model_validators.py b/validity/model_validators.py new file mode 100644 index 0000000..dd6cc4a --- /dev/null +++ b/validity/model_validators.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING, Collection + +from django.core.exceptions import ValidationError +from django.utils.translation import gettext_lazy as _ + + +if TYPE_CHECKING: + from validity.models import Command + + +def only_one_config_command(commands: Collection["Command"]) -> None: + config_commands_count = sum(1 for cmd in commands if cmd.retrieves_config) + if config_commands_count > 1: + raise ValidationError( + { + "commands": _("No more than 1 command to retrieve config is allowed, but %(cnt)s were specified") + % {"cnt": config_commands_count} + } + ) + + +def commands_with_appropriate_type( + commands: Collection["Command"], command_types: dict[str, list[str]], connection_type: str +): + acceptable_command_types = command_types.get(connection_type, []) + if invalid_cmds := [cmd.label for cmd in commands if cmd.type not in acceptable_command_types]: + raise ValidationError( + { + "commands": _( + "The following commands have inappropriate type and cannot be bound to this Poller: %(cmds)s" + ) + % {"cmds": ", ".join(label for label in invalid_cmds)} + } + ) diff --git a/validity/models/polling.py b/validity/models/polling.py index d9a21e6..65c7acd 100644 --- a/validity/models/polling.py +++ b/validity/models/polling.py @@ -2,16 +2,17 @@ from typing import TYPE_CHECKING, Annotated, Collection from dcim.models import Device -from django.core.exceptions import ValidationError from django.core.validators import RegexValidator from django.db import models from django.utils.translation import gettext_lazy as _ from validity import di -from validity.choices import CommandTypeChoices, ConnectionTypeChoices +from validity.choices import CommandTypeChoices from validity.fields import EncryptedDictField from validity.managers import CommandQS, PollerQS -from validity.subforms import CLICommandForm, JSONAPICommandForm, NetconfCommandForm +from validity.model_validators import commands_with_appropriate_type, only_one_config_command +from validity.subforms import CLICommandForm, CustomCommandForm, JSONAPICommandForm, NetconfCommandForm +from validity.utils.misc import LazyIterator from .base import BaseModel, SubformMixin from .serializer import Serializer @@ -55,7 +56,12 @@ class Command(SubformMixin, BaseModel): subform_type_field = "type" subform_json_field = "parameters" - subforms = {"CLI": CLICommandForm, "json_api": JSONAPICommandForm, "netconf": NetconfCommandForm} + subforms = { + "CLI": CLICommandForm, + "json_api": JSONAPICommandForm, + "netconf": NetconfCommandForm, + "custom": CustomCommandForm, + } class Meta: ordering = ("name",) @@ -69,7 +75,9 @@ def get_type_color(self): class Poller(BaseModel): name = models.CharField(_("Name"), max_length=255, unique=True) - connection_type = models.CharField(_("Connection Type"), max_length=50, choices=ConnectionTypeChoices.choices) + connection_type = models.CharField( + _("Connection Type"), max_length=50, choices=LazyIterator(lambda: di["PollerChoices"].choices) + ) public_credentials = models.JSONField( _("Public Credentials"), default=dict, @@ -99,7 +107,7 @@ def credentials(self) -> dict: return self.public_credentials | self.private_credentials.decrypted def get_connection_type_color(self): - return ConnectionTypeChoices.colors.get(self.connection_type) + return di["PollerChoices"].colors.get(self.connection_type) @property def bound_devices(self) -> models.QuerySet[Device]: @@ -119,23 +127,6 @@ def get_backend(self, poller_factory: Annotated["PollerFactory", ...]): return poller_factory(self.connection_type, self.credentials, self.commands.all()) @staticmethod - def validate_commands(connection_type: str, commands: Collection[Command]): - # All the commands must be of the matching type - conn_type = ConnectionTypeChoices[connection_type] - if any(cmd.type != conn_type.acceptable_command_type for cmd in commands): - raise ValidationError( - { - "commands": _("%(conntype)s accepts only %(cmdtype)s commands") - % {"conntype": conn_type.label, "cmdtype": conn_type.acceptable_command_type.label} - } - ) - - # Only one bound "retrives config" command may exist - config_commands_count = sum(1 for cmd in commands if cmd.retrieves_config) - if config_commands_count > 1: - raise ValidationError( - { - "commands": _("No more than 1 command to retrieve config is allowed, but %(cnt)s were specified") - % {"cnt": config_commands_count} - } - ) + def validate_commands(commands: Collection[Command], command_types: dict[str, list[str]], connection_type: str): + commands_with_appropriate_type(commands, command_types, connection_type) + only_one_config_command(commands) diff --git a/validity/pollers/__init__.py b/validity/pollers/__init__.py index 890f508..e1dfe73 100644 --- a/validity/pollers/__init__.py +++ b/validity/pollers/__init__.py @@ -1,3 +1,4 @@ +from .base import BasePoller, CustomPoller from .cli import NetmikoPoller from .http import RequestsPoller from .netconf import ScrapliNetconfPoller diff --git a/validity/pollers/base.py b/validity/pollers/base.py index fc17fa9..2f8acf3 100644 --- a/validity/pollers/base.py +++ b/validity/pollers/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import TYPE_CHECKING, Any, Collection, Iterable, Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator from validity.utils.misc import reraise from .exceptions import PollingError @@ -11,7 +12,7 @@ from validity.models import Command, VDevice -class DevicePoller(ABC): +class BasePoller(ABC): host_param_name: str def __init__(self, credentials: dict, commands: Collection["Command"]) -> None: @@ -28,7 +29,7 @@ def get_credentials(self, device: "VDevice"): return self.credentials | {self.host_param_name: str(ip.address.ip)} -class ThreadPoller(DevicePoller): +class ThreadPoller(BasePoller): """ Polls devices one by one using threads """ @@ -65,11 +66,28 @@ def poll(self, devices: Iterable["VDevice"]) -> Iterator[CommandResult | Polling class DriverMixin: - driver_cls: type # Network driver class, e.g. netmiko.ConnectHandler - - def get_driver(self, device: "VDevice"): + driver_factory: Callable # Network driver class, e.g. netmiko.ConnectHandler + driver_connect_method: str = "" + driver_disconnect_method: str = "" + + def connect(self, credentials: dict[str, Any]): + driver = type(self).driver_factory(**credentials) + if self.driver_connect_method: + getattr(driver, self.driver_connect_method)() + return driver + + def disconnect(self, driver): + if self.driver_disconnect_method: + getattr(driver, self.driver_disconnect_method)() + + @contextmanager + def connection(self, device: "VDevice"): creds = self.get_credentials(device) - return self.driver_cls(**creds) + driver = self.connect(creds) + try: + yield driver + finally: + self.disconnect(driver) class ConsecutivePoller(DriverMixin, ThreadPoller): @@ -78,11 +96,23 @@ def poll_one_command(self, driver: Any, command: "Command") -> str: pass def poll_one_device(self, device: "VDevice") -> Iterator[CommandResult]: - driver = self.get_driver(device) - for command in self.commands: - try: - with reraise(Exception, PollingError, device_wide=False): - output = self.poll_one_command(driver, command) - yield CommandResult(device=device, command=command, result=output) - except PollingError as err: - yield CommandResult(device=device, command=command, error=err) + with self.connection(device) as driver: + for command in self.commands: + try: + with reraise(Exception, PollingError, device_wide=False): + output = self.poll_one_command(driver, command) + yield CommandResult(device=device, command=command, result=output) + except PollingError as err: + yield CommandResult(device=device, command=command, error=err) + + +class CustomPoller(ConsecutivePoller): + """ + Base class for creating user-defined pollers + To define your own poller override the following attributes: + - driver_factory - class/function for creating connection to particular device + - host_param_name - name of the driver parameter which holds device IP address + - poll_one_command() - method for sending one particular command to device and retrieving the result + - driver_connect_method - optional driver method name to initiate the connection + - driver_disconnect_method - optional driver method name to gracefully terminate the connection + """ diff --git a/validity/pollers/cli.py b/validity/pollers/cli.py index 1c5bef6..0bb2c01 100644 --- a/validity/pollers/cli.py +++ b/validity/pollers/cli.py @@ -11,7 +11,8 @@ class NetmikoPoller(ConsecutivePoller): host_param_name = "host" - driver_cls = staticmethod(ConnectHandler) # ConnectHandler is a function + driver_disconnect_method = "disconnect" + driver_factory = ConnectHandler def poll_one_command(self, driver: BaseConnection, command: "Command") -> str: return driver.send_command(command.parameters["cli_command"]) diff --git a/validity/pollers/factory.py b/validity/pollers/factory.py index 005d98f..09af03c 100644 --- a/validity/pollers/factory.py +++ b/validity/pollers/factory.py @@ -3,25 +3,41 @@ from dimi import Singleton from validity import di +from validity.settings import PollerInfo from validity.utils.misc import partialcls -from .base import DevicePoller, ThreadPoller +from .base import BasePoller, ThreadPoller if TYPE_CHECKING: from validity.models import Command +@di.dependency(scope=Singleton) +class PollerChoices: + def __init__(self, pollers_info: Annotated[list[PollerInfo], "pollers_info"]): + self.choices: list[tuple[str, str]] = [] + self.classes: dict[str, type] = {} + self.colors: dict[str, str] = {} + self.command_types: dict[str, Sequence[str]] = {} + + for info in pollers_info: + self.choices.append((info.name, info.verbose_name)) + self.classes[info.name] = info.klass + self.colors[info.name] = info.color + self.command_types[info.name] = info.command_types + + @di.dependency(scope=Singleton) class PollerFactory: def __init__( self, - poller_map: Annotated[dict, "poller_map"], + poller_map: Annotated[dict[str, type[BasePoller]], "PollerChoices.classes"], max_threads: Annotated[int, "validity_settings.polling_threads"], ) -> None: self.poller_map = poller_map self.max_threads = max_threads - def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> DevicePoller: + def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> BasePoller: if poller_cls := self.poller_map.get(connection_type): if issubclass(poller_cls, ThreadPoller): poller_cls = partialcls(poller_cls, thread_workers=self.max_threads) diff --git a/validity/pollers/http.py b/validity/pollers/http.py index 0dfc6ba..db9f1aa 100644 --- a/validity/pollers/http.py +++ b/validity/pollers/http.py @@ -49,7 +49,7 @@ def request(self, command: "Command", *, requests=requests) -> str: class RequestsPoller(ConsecutivePoller): - driver_cls = HttpDriver + driver_factory = HttpDriver def get_credentials(self, device: "VDevice"): return self.credentials | {"device": device} diff --git a/validity/pollers/netconf.py b/validity/pollers/netconf.py index b264781..7269f5b 100644 --- a/validity/pollers/netconf.py +++ b/validity/pollers/netconf.py @@ -10,10 +10,11 @@ class ScrapliNetconfPoller(ConsecutivePoller): - driver_cls = NetconfDriver + driver_factory = NetconfDriver + driver_connect_method = "open" + driver_disconnect_method = "close" host_param_name = "host" def poll_one_command(self, driver: NetconfDriver, command: "Command") -> str: - with driver: - response = driver.rpc(command.parameters["rpc"]) - return response.result + response = driver.rpc(command.parameters["rpc"]) + return response.result diff --git a/validity/settings.py b/validity/settings.py index d4339fc..c2602d9 100644 --- a/validity/settings.py +++ b/validity/settings.py @@ -1,8 +1,9 @@ -from typing import Annotated +from typing import Annotated, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from validity import di +from validity.pollers import BasePoller class ScriptTimeouts(BaseModel): @@ -15,12 +16,30 @@ class ScriptTimeouts(BaseModel): runtests_combine: int | str = "10m" +class PollerInfo(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + klass: type[BasePoller] + name: str = Field(pattern="[a-z_]+") + verbose_name: str = Field(default="", validate_default=True) + color: str = Field(pattern="[a-z-]+") + command_types: list[Literal["CLI", "netconf", "json_api", "custom"]] + + @field_validator("verbose_name") + @classmethod + def validate_verbose_name(cls, value, info): + if value: + return value + return " ".join(part.title() for part in info.data["name"].split("_")) + + class ValiditySettings(BaseModel): store_reports: int = Field(default=5, gt=0, lt=1001) result_batch_size: int = Field(default=500, ge=1) polling_threads: int = Field(default=500, ge=1) runtests_queue: str = "default" script_timeouts: ScriptTimeouts = ScriptTimeouts() + custom_pollers: list[PollerInfo] = [] class ValiditySettingsMixin: diff --git a/validity/static/validity/connection-type-select.js b/validity/static/validity/connection-type-select.js index 69ced43..aa942f0 100644 --- a/validity/static/validity/connection-type-select.js +++ b/validity/static/validity/connection-type-select.js @@ -6,14 +6,13 @@ function fillTextArea(public_creds, private_creds) { function fillCredentials(valueExtracter, connectionTypeInfo) { try { const connectionType = valueExtracter(connectionTypeInfo) - if (connectionType == "") - return; const defaultCredentials = JSON.parse(document.getElementById('default_credentials').textContent)[connectionType]; - fillTextArea(defaultCredentials.public, defaultCredentials.private); + if (defaultCredentials !== undefined) { + fillTextArea(defaultCredentials.public, defaultCredentials.private); + } } catch(e) { console.log(e.name, e.message) } - } window.onload = () => { diff --git a/validity/subforms.py b/validity/subforms.py index 146555a..afa1d61 100644 --- a/validity/subforms.py +++ b/validity/subforms.py @@ -60,6 +60,13 @@ def clean_rpc(self): return rpc +class CustomCommandForm(BaseSubform): + params = forms.JSONField(label=_("Command Parameters")) + + def clean(self): + return self.cleaned_data["params"] + + # Serializer Subforms diff --git a/validity/tests/test_custom_pollers.py b/validity/tests/test_custom_pollers.py new file mode 100644 index 0000000..be026ad --- /dev/null +++ b/validity/tests/test_custom_pollers.py @@ -0,0 +1,52 @@ +from http import HTTPStatus +from typing import Any +from unittest.mock import Mock + +import pytest +from factories import CommandFactory, PollerFactory + +from validity import config +from validity.dependencies import validity_settings +from validity.forms import PollerForm +from validity.models.polling import Command +from validity.pollers import CustomPoller +from validity.settings import PollerInfo, ValiditySettings + + +class MyCustomPoller(CustomPoller): + host_param_name = "ip_address" + driver_factory = Mock + + def poll_one_command(self, driver: Any, command: Command) -> str: + return "output" + + +@pytest.fixture +def custom_poller(db, di): + settings = ValiditySettings( + custom_pollers=[PollerInfo(klass=MyCustomPoller, name="cupo", color="red", command_types=["custom"])] + ) + with di.override({validity_settings: lambda: settings}): + yield PollerFactory(connection_type="cupo") + + +def test_custom_poller_model(custom_poller, di): + poller = PollerFactory(connection_type="cupo") + poller.commands.set([CommandFactory(type="custom")]) + backend = poller.get_backend() + assert isinstance(backend, MyCustomPoller) + assert poller.get_connection_type_color() == "red" + poller.validate_commands(poller.commands.all(), di["PollerChoices"].command_types, poller.connection_type) + + +def test_custom_poller_api(custom_poller, admin_client): + resp = admin_client.get(f"/api/plugins/validity/pollers/{custom_poller.pk}/") + assert resp.status_code == HTTPStatus.OK + assert resp.json()["connection_type"] == "cupo" + + +@pytest.mark.skipif(condition=config.version < "4", reason="netbox < 4.0") +def test_custom_poller_form(custom_poller): + form = PollerForm() + form_choices = {choice[0] for choice in form["connection_type"].field.choices} + assert "cupo" in form_choices diff --git a/validity/tests/test_models/test_clean.py b/validity/tests/test_models/test_clean.py index 38f83bb..36a2890 100644 --- a/validity/tests/test_models/test_clean.py +++ b/validity/tests/test_models/test_clean.py @@ -105,11 +105,13 @@ class TestPoller: "connection_type, command_type, is_valid", [("netmiko", "CLI", True), ("netmiko", "netconf", False)] ) @pytest.mark.django_db - def test_match_command_type(self, connection_type, command_type, is_valid): + def test_match_command_type(self, connection_type, command_type, is_valid, di): command = CommandFactory(type=command_type) ctx = nullcontext() if is_valid else pytest.raises(ValidationError) with ctx: - Poller.validate_commands(connection_type=connection_type, commands=[command]) + Poller.validate_commands( + connection_type=connection_type, commands=[command], command_types=di["PollerChoices"].command_types + ) @pytest.mark.parametrize( "retrive_config, is_valid", @@ -127,4 +129,4 @@ def only_one_config_command(self, retrive_config, is_valid): commands = [CommandFactory(type=t) for t in retrive_config] ctx = nullcontext() if is_valid else pytest.raises(ValidationError) with ctx: - Poller.validate_commands(connection_type="CLI", commands=commands) + Poller.validate_commands(connection_type="CLI", commands=commands, command_types={}) diff --git a/validity/tests/test_models/test_poller.py b/validity/tests/test_models/test_poller.py new file mode 100644 index 0000000..a9d5b06 --- /dev/null +++ b/validity/tests/test_models/test_poller.py @@ -0,0 +1,15 @@ +import pytest +from factories import PollerFactory + +from validity.pollers import NetmikoPoller, RequestsPoller, ScrapliNetconfPoller + + +@pytest.mark.parametrize( + "connection_type, poller_class", + [("netmiko", NetmikoPoller), ("requests", RequestsPoller), ("scrapli_netconf", ScrapliNetconfPoller)], +) +@pytest.mark.django_db +def test_get_backend(connection_type, poller_class): + poller = PollerFactory(connection_type=connection_type) + backend = poller.get_backend() + assert isinstance(backend, poller_class) diff --git a/validity/tests/test_pollers.py b/validity/tests/test_pollers.py index 6da47a4..15cc74d 100644 --- a/validity/tests/test_pollers.py +++ b/validity/tests/test_pollers.py @@ -3,15 +3,31 @@ import pytest -from validity.pollers import NetmikoPoller +from validity.models.polling import Command +from validity.pollers import CustomPoller, NetmikoPoller, RequestsPoller +from validity.pollers.factory import PollerChoices from validity.pollers.http import HttpDriver +from validity.settings import PollerInfo + + +@pytest.fixture +def custom_poller(): + class MyCustomPoller(CustomPoller): + driver_factory = Mock(name="driver_factory") + driver_connect_method = "con" + driver_disconnect_method = "dis" + + def poll_one_command(self, driver: time.Any, command: Command) -> str: + return super().poll_one_command(driver, command) + + return MyCustomPoller class TestNetmikoPoller: @pytest.fixture def get_mocked_poller(self, monkeypatch): def _get_poller(credentials, commands, mock): - monkeypatch.setattr(NetmikoPoller, "driver_cls", mock) + monkeypatch.setattr(NetmikoPoller, "driver_factory", mock) return NetmikoPoller(credentials, commands) return _get_poller @@ -25,13 +41,11 @@ def _get_device(primary_ip): return _get_device @pytest.mark.django_db - def test_get_driver(self, get_mocked_poller, get_mocked_device): + def test_get_credentials(self, get_mocked_poller, get_mocked_device): credentials = {"user": "admin", "password": "1234"} poller = get_mocked_poller(credentials, [], Mock()) device = get_mocked_device("1.1.1.1") assert poller.get_credentials(device) == credentials | {poller.host_param_name: "1.1.1.1"} - assert poller.get_driver(device) == poller.driver_cls.return_value - poller.driver_cls.assert_called_once_with(**credentials, **{poller.host_param_name: "1.1.1.1"}) def test_poll_one_command(self, get_mocked_poller): poller = get_mocked_poller({}, [], Mock()) @@ -84,3 +98,18 @@ def test_http_driver(): auth=None, ) assert result == requests.request.return_value.content.decode.return_value + + +def test_poller_choices(): + poller_choices = PollerChoices( + pollers_info=[ + PollerInfo(klass=NetmikoPoller, name="some_poller", color="red", command_types=["CLI"]), + PollerInfo( + klass=RequestsPoller, name="p2", verbose_name="P2", color="green", command_types=["json_api", "custom"] + ), + ] + ) + assert poller_choices.choices == [("some_poller", "Some Poller"), ("p2", "P2")] + assert poller_choices.colors == {"some_poller": "red", "p2": "green"} + assert poller_choices.classes == {"some_poller": NetmikoPoller, "p2": RequestsPoller} + assert poller_choices.command_types == {"some_poller": ["CLI"], "p2": ["json_api", "custom"]} diff --git a/validity/tests/test_utils/test_misc.py b/validity/tests/test_utils/test_misc.py index 3a6853f..5ae0a2d 100644 --- a/validity/tests/test_utils/test_misc.py +++ b/validity/tests/test_utils/test_misc.py @@ -5,7 +5,7 @@ import pytest -from validity.utils.misc import log_exceptions, partialcls, reraise +from validity.utils.misc import LazyIterator, log_exceptions, partialcls, reraise from validity.utils.version import NetboxVersion @@ -89,3 +89,14 @@ def test_log_exceptions(): with log_exceptions(logger, "info", log_traceback=True): raise ValueError("qwerty") logger.info.assert_called_once_with(msg="qwerty", exc_info=True) + + +def test_lazy_iterator(): + part1 = [10, 20, 30] + part2 = lambda: [40, 50] # noqa + part3 = Mock(return_value=[60]) + part4 = (70,) + iterator = LazyIterator(part1, part2, part3, part4) + part3.assert_not_called() + assert list(iterator) == [10, 20, 30, 40, 50, 60, 70] + assert list(iterator) == [10, 20, 30, 40, 50, 60, 70] # checking iterator is not exhausted diff --git a/validity/utils/misc.py b/validity/utils/misc.py index 880faa2..adf4831 100644 --- a/validity/utils/misc.py +++ b/validity/utils/misc.py @@ -1,12 +1,13 @@ import inspect from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, suppress -from itertools import islice +from itertools import chain, islice from logging import Logger -from typing import TYPE_CHECKING, Any, Callable, Iterable +from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable from core.exceptions import SyncError from django.db.models import Q +from django.utils.functional import Promise from netbox.context import current_request @@ -112,3 +113,11 @@ def log_exceptions(logger: Logger, level: str, log_traceback=True): log_method = getattr(logger, level) log_method(msg=str(exc), exc_info=log_traceback) raise + + +class LazyIterator(Promise): + def __init__(self, *parts: Callable[[], Collection] | Collection): + self._parts = parts + + def __iter__(self): + yield from chain.from_iterable(part() if callable(part) else part for part in self._parts)