Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amyasnikov committed Oct 20, 2024
1 parent dbac934 commit 22f1757
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 14 deletions.
1 change: 0 additions & 1 deletion validity/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def pollers_info(custom_pollers: Annotated[list[PollerInfo], "validity_settings.
] + custom_pollers


import validity.choices # noqa
import validity.pollers.factory # noqa
from validity.scripts import ApplyWorker, CombineWorker, Launcher, SplitWorker, Task # noqa

Expand Down
2 changes: 1 addition & 1 deletion validity/forms/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from validity import di, models
from validity.choices import ExplanationVerbosityChoices
from validity.netbox_changes import FieldSet
from ..utils.misc import LazyIterator
from validity.utils.misc import LazyIterator
from .fields import DynamicModelChoicePropertyField, DynamicModelMultipleChoicePropertyField
from .mixins import PollerCleanMixin, SubformMixin
from .widgets import PrettyJSONWidget
Expand Down
2 changes: 1 addition & 1 deletion validity/pollers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import CustomPoller, Poller
from .base import BasePoller, CustomPoller
from .cli import NetmikoPoller
from .http import RequestsPoller
from .netconf import ScrapliNetconfPoller
4 changes: 2 additions & 2 deletions validity/pollers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from validity.models import Command, VDevice


class Poller(ABC):
class BasePoller(ABC):
host_param_name: str

def __init__(self, credentials: dict, commands: Collection["Command"]) -> None:
Expand All @@ -28,7 +28,7 @@ def get_credentials(self, device: "VDevice"):
return self.credentials | {self.host_param_name: str(ip.address.ip)}


class ThreadPoller(Poller):
class ThreadPoller(BasePoller):
"""
Polls devices one by one using threads
"""
Expand Down
6 changes: 3 additions & 3 deletions validity/pollers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from validity import di
from validity.settings import PollerInfo
from validity.utils.misc import partialcls
from .base import Poller, ThreadPoller
from .base import BasePoller, ThreadPoller


if TYPE_CHECKING:
Expand All @@ -31,13 +31,13 @@ def __init__(self, pollers_info: Annotated[list[PollerInfo], "pollers_info"]):
class PollerFactory:
def __init__(
self,
poller_map: Annotated[dict[str, type[Poller]], "PollerChoices.classes"],
poller_map: Annotated[dict[str, type[BasePoller]], "PollerChoices.classes"],
max_threads: Annotated[int, "validity_settings.polling_threads"],
) -> None:
self.poller_map = poller_map
self.max_threads = max_threads

def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> Poller:
def __call__(self, connection_type: str, credentials: dict, commands: Sequence["Command"]) -> BasePoller:
if poller_cls := self.poller_map.get(connection_type):
if issubclass(poller_cls, ThreadPoller):
poller_cls = partialcls(poller_cls, thread_workers=self.max_threads)
Expand Down
8 changes: 4 additions & 4 deletions validity/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator

from validity import di
from validity.pollers import Poller
from validity.pollers import BasePoller


class ScriptTimeouts(BaseModel):
Expand All @@ -19,18 +19,18 @@ class ScriptTimeouts(BaseModel):
class PollerInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

klass: type[Poller]
klass: type[BasePoller]
name: str = Field(pattern="[a-z_]+")
verbose_name: str = Field(default="", validate_default=True)
color: str = Field(pattern="[a-z-]+")
command_types: list[Literal["CLI", "netconf", "json_api", "custom"]]

@field_validator("verbose_name")
@classmethod
def validate_verbose_name(cls, value):
def validate_verbose_name(cls, value, info):
if value:
return value
return " ".join(part.title() for part in value.split("_"))
return " ".join(part.title() for part in info.data["name"].split("_"))


class ValiditySettings(BaseModel):
Expand Down
15 changes: 15 additions & 0 deletions validity/tests/test_models/test_poller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
from factories import PollerFactory

from validity.pollers import NetmikoPoller, RequestsPoller, ScrapliNetconfPoller


@pytest.mark.parametrize(
"connection_type, poller_class",
[("netmiko", NetmikoPoller), ("requests", RequestsPoller), ("scrapli_netconf", ScrapliNetconfPoller)],
)
@pytest.mark.django_db
def test_get_backend(connection_type, poller_class):
poller = PollerFactory(connection_type=connection_type)
backend = poller.get_backend()
assert isinstance(backend, poller_class)
19 changes: 18 additions & 1 deletion validity/tests/test_pollers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import pytest

from validity.pollers import NetmikoPoller
from validity.pollers import NetmikoPoller, RequestsPoller
from validity.pollers.factory import PollerChoices
from validity.pollers.http import HttpDriver
from validity.settings import PollerInfo


class TestNetmikoPoller:
Expand Down Expand Up @@ -84,3 +86,18 @@ def test_http_driver():
auth=None,
)
assert result == requests.request.return_value.content.decode.return_value


def test_poller_choices():
poller_choices = PollerChoices(
pollers_info=[
PollerInfo(klass=NetmikoPoller, name="some_poller", color="red", command_types=["CLI"]),
PollerInfo(
klass=RequestsPoller, name="p2", verbose_name="P2", color="green", command_types=["json_api", "custom"]
),
]
)
assert poller_choices.choices == [("some_poller", "Some Poller"), ("p2", "P2")]
assert poller_choices.colors == {"some_poller": "red", "p2": "green"}
assert poller_choices.classes == {"some_poller": NetmikoPoller, "p2": RequestsPoller}
assert poller_choices.command_types == {"some_poller": ["CLI"], "p2": ["json_api", "custom"]}
13 changes: 12 additions & 1 deletion validity/tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from validity.utils.misc import log_exceptions, partialcls, reraise
from validity.utils.misc import LazyIterator, log_exceptions, partialcls, reraise
from validity.utils.version import NetboxVersion


Expand Down Expand Up @@ -89,3 +89,14 @@ def test_log_exceptions():
with log_exceptions(logger, "info", log_traceback=True):
raise ValueError("qwerty")
logger.info.assert_called_once_with(msg="qwerty", exc_info=True)


def test_lazy_iterator():
part1 = [10, 20, 30]
part2 = lambda: [40, 50] # noqa
part3 = Mock(return_value=[60])
part4 = (70,)
iterator = LazyIterator(part1, part2, part3, part4)
part3.assert_not_called()
assert list(iterator) == [10, 20, 30, 40, 50, 60, 70]
assert list(iterator) == [10, 20, 30, 40, 50, 60, 70] # checking iterator is not exhausted

0 comments on commit 22f1757

Please sign in to comment.