From 23d218a3f8c5a8189cb4b16def9d70172c0363ab Mon Sep 17 00:00:00 2001 From: Anton M Date: Sun, 20 Oct 2024 22:29:01 +0200 Subject: [PATCH] tests --- validity/dependencies.py | 1 - validity/forms/general.py | 2 +- validity/pollers/__init__.py | 2 +- validity/pollers/base.py | 4 ++-- validity/pollers/factory.py | 6 +++--- validity/settings.py | 4 ++-- validity/tests/test_models/test_poller.py | 14 ++++++++++++++ validity/tests/test_pollers.py | 19 ++++++++++++++++++- validity/tests/test_utils/test_misc.py | 13 ++++++++++++- 9 files changed, 53 insertions(+), 12 deletions(-) create mode 100644 validity/tests/test_models/test_poller.py diff --git a/validity/dependencies.py b/validity/dependencies.py index cf692da..b1f6a4e 100644 --- a/validity/dependencies.py +++ b/validity/dependencies.py @@ -42,7 +42,6 @@ def pollers_info(custom_pollers: Annotated[list[PollerInfo], "validity_settings. ] + custom_pollers -import validity.choices # noqa import validity.pollers.factory # noqa from validity.scripts import ApplyWorker, CombineWorker, Launcher, SplitWorker, Task # noqa diff --git a/validity/forms/general.py b/validity/forms/general.py index 9d37d65..b2628dd 100644 --- a/validity/forms/general.py +++ b/validity/forms/general.py @@ -13,7 +13,7 @@ from validity import di, models from validity.choices import ExplanationVerbosityChoices from validity.netbox_changes import FieldSet -from ..utils.misc import LazyIterator +from validity.utils.misc import LazyIterator from .fields import DynamicModelChoicePropertyField, DynamicModelMultipleChoicePropertyField from .mixins import PollerCleanMixin, SubformMixin from .widgets import PrettyJSONWidget diff --git a/validity/pollers/__init__.py b/validity/pollers/__init__.py index 23e53d6..e1dfe73 100644 --- a/validity/pollers/__init__.py +++ b/validity/pollers/__init__.py @@ -1,4 +1,4 @@ -from .base import CustomPoller, Poller +from .base import BasePoller, CustomPoller from .cli import NetmikoPoller from .http import RequestsPoller from .netconf import ScrapliNetconfPoller diff --git a/validity/pollers/base.py b/validity/pollers/base.py index 0b6122f..762d1de 100644 --- a/validity/pollers/base.py +++ b/validity/pollers/base.py @@ -11,7 +11,7 @@ from validity.models import Command, VDevice -class Poller(ABC): +class BasePoller(ABC): host_param_name: str def __init__(self, credentials: dict, commands: Collection["Command"]) -> None: @@ -28,7 +28,7 @@ def get_credentials(self, device: "VDevice"): return self.credentials | {self.host_param_name: str(ip.address.ip)} -class ThreadPoller(Poller): +class ThreadPoller(BasePoller): """ Polls devices one by one using threads """ diff --git a/validity/pollers/factory.py b/validity/pollers/factory.py index 6708868..09af03c 100644 --- a/validity/pollers/factory.py +++ b/validity/pollers/factory.py @@ -5,7 +5,7 @@ from validity import di from validity.settings import PollerInfo from validity.utils.misc import partialcls -from .base import Poller, ThreadPoller +from .base import BasePoller, ThreadPoller if TYPE_CHECKING: @@ -31,13 +31,13 @@ def __init__(self, pollers_info: Annotated[list[PollerInfo], "pollers_info"]): class PollerFactory: def __init__( self, - poller_map: Annotated[dict[str, type[Poller]], "PollerChoices.classes"], + poller_map: Annotated[dict[str, type[BasePoller]], "PollerChoices.classes"], max_threads: Annotated[int, "validity_settings.polling_threads"], ) -> None: self.poller_map = poller_map self.max_threads = max_threads - def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> Poller: + def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> BasePoller: if poller_cls := self.poller_map.get(connection_type): if issubclass(poller_cls, ThreadPoller): poller_cls = partialcls(poller_cls, thread_workers=self.max_threads) diff --git a/validity/settings.py b/validity/settings.py index 13a290d..de09316 100644 --- a/validity/settings.py +++ b/validity/settings.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from validity import di -from validity.pollers import Poller +from validity.pollers import BasePoller class ScriptTimeouts(BaseModel): @@ -19,7 +19,7 @@ class ScriptTimeouts(BaseModel): class PollerInfo(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - klass: type[Poller] + klass: type[BasePoller] name: str = Field(pattern="[a-z_]+") verbose_name: str = Field(default="", validate_default=True) color: str = Field(pattern="[a-z-]+") diff --git a/validity/tests/test_models/test_poller.py b/validity/tests/test_models/test_poller.py new file mode 100644 index 0000000..236c5dc --- /dev/null +++ b/validity/tests/test_models/test_poller.py @@ -0,0 +1,14 @@ +import pytest +from factories import PollerFactory + +from validity.pollers import NetmikoPoller, RequestsPoller, ScrapliNetconfPoller + + +@pytest.mark.parametrize( + "connection_type, poller_class", + [("netmiko", NetmikoPoller), ("requests", RequestsPoller), ("scrapli_netconf", ScrapliNetconfPoller)], +) +def test_get_backend(connection_type, poller_class): + poller = PollerFactory(connection_type=connection_type) + backend = poller.get_backend() + assert isinstance(backend, poller_class) diff --git a/validity/tests/test_pollers.py b/validity/tests/test_pollers.py index 9e2a85d..898642b 100644 --- a/validity/tests/test_pollers.py +++ b/validity/tests/test_pollers.py @@ -3,8 +3,10 @@ import pytest -from validity.pollers import NetmikoPoller +from validity.pollers import NetmikoPoller, RequestsPoller +from validity.pollers.factory import PollerChoices from validity.pollers.http import HttpDriver +from validity.settings import PollerInfo class TestNetmikoPoller: @@ -84,3 +86,18 @@ def test_http_driver(): auth=None, ) assert result == requests.request.return_value.content.decode.return_value + + +def test_poller_choices(): + poller_choices = PollerChoices( + pollers_info=[ + PollerInfo(klass=NetmikoPoller, name="some_poller", color="red", command_types=["CLI"]), + PollerInfo( + klass=RequestsPoller, name="p2", verbose_name="P2", color="green", command_types=["json_api", "custom"] + ), + ] + ) + assert poller_choices.choices == [("some_poller", "Some Poller"), ("p2", "P2")] + assert poller_choices.colors == {"some_poller": "red", "p2": "green"} + assert poller_choices.classes == {"some_poller": NetmikoPoller, "p2": RequestsPoller} + assert poller_choices.command_types == {"some_poller": ["CLI"], "p2": ["json_api", "custom"]} diff --git a/validity/tests/test_utils/test_misc.py b/validity/tests/test_utils/test_misc.py index 3a6853f..5ae0a2d 100644 --- a/validity/tests/test_utils/test_misc.py +++ b/validity/tests/test_utils/test_misc.py @@ -5,7 +5,7 @@ import pytest -from validity.utils.misc import log_exceptions, partialcls, reraise +from validity.utils.misc import LazyIterator, log_exceptions, partialcls, reraise from validity.utils.version import NetboxVersion @@ -89,3 +89,14 @@ def test_log_exceptions(): with log_exceptions(logger, "info", log_traceback=True): raise ValueError("qwerty") logger.info.assert_called_once_with(msg="qwerty", exc_info=True) + + +def test_lazy_iterator(): + part1 = [10, 20, 30] + part2 = lambda: [40, 50] # noqa + part3 = Mock(return_value=[60]) + part4 = (70,) + iterator = LazyIterator(part1, part2, part3, part4) + part3.assert_not_called() + assert list(iterator) == [10, 20, 30, 40, 50, 60, 70] + assert list(iterator) == [10, 20, 30, 40, 50, 60, 70] # checking iterator is not exhausted