diff --git a/docs/features/custom_pollers.md b/docs/features/custom_pollers.md index 06602d5..3fea24a 100644 --- a/docs/features/custom_pollers.md +++ b/docs/features/custom_pollers.md @@ -39,6 +39,10 @@ class ScrapliPoller(CustomPoller): return resp.result ``` +!!! note + Be aware that every poller class instance is usually responsible for interaction with multiple devices. Hence, do not use poller fields for storing device-specific parameters. + + ### Filling PollerInfo Poller Info is required to tell Validity about your custom poller. diff --git a/validity/pollers/base.py b/validity/pollers/base.py index 9e56ebc..2f8acf3 100644 --- a/validity/pollers/base.py +++ b/validity/pollers/base.py @@ -67,8 +67,8 @@ def poll(self, devices: Iterable["VDevice"]) -> Iterator[CommandResult | Polling class DriverMixin: driver_factory: Callable # Network driver class, e.g. netmiko.ConnectHandler - driver_connect_method: str - driver_disconnect_method: str + driver_connect_method: str = "" + driver_disconnect_method: str = "" def connect(self, credentials: dict[str, Any]): driver = type(self).driver_factory(**credentials) @@ -111,7 +111,7 @@ class CustomPoller(ConsecutivePoller): Base class for creating user-defined pollers To define your own poller override the following attributes: - driver_factory - class/function for creating connection to particular device - - host_param_name - name of the driver_factory parameter, which holds device IP address + - host_param_name - name of the driver parameter which holds device IP address - poll_one_command() - method for sending one particular command to device and retrieving the result - driver_connect_method - optional driver method name to initiate the connection - driver_disconnect_method - optional driver method name to gracefully terminate the connection diff --git a/validity/tests/test_pollers.py b/validity/tests/test_pollers.py index 898642b..15cc74d 100644 --- a/validity/tests/test_pollers.py +++ b/validity/tests/test_pollers.py @@ -3,12 +3,26 @@ import pytest -from validity.pollers import NetmikoPoller, RequestsPoller +from validity.models.polling import Command +from validity.pollers import CustomPoller, NetmikoPoller, RequestsPoller from validity.pollers.factory import PollerChoices from validity.pollers.http import HttpDriver from validity.settings import PollerInfo +@pytest.fixture +def custom_poller(): + class MyCustomPoller(CustomPoller): + driver_factory = Mock(name="driver_factory") + driver_connect_method = "con" + driver_disconnect_method = "dis" + + def poll_one_command(self, driver: time.Any, command: Command) -> str: + return super().poll_one_command(driver, command) + + return MyCustomPoller + + class TestNetmikoPoller: @pytest.fixture def get_mocked_poller(self, monkeypatch): @@ -27,13 +41,11 @@ def _get_device(primary_ip): return _get_device @pytest.mark.django_db - def test_get_driver(self, get_mocked_poller, get_mocked_device): + def test_get_credentials(self, get_mocked_poller, get_mocked_device): credentials = {"user": "admin", "password": "1234"} poller = get_mocked_poller(credentials, [], Mock()) device = get_mocked_device("1.1.1.1") assert poller.get_credentials(device) == credentials | {poller.host_param_name: "1.1.1.1"} - assert poller.get_driver(device) == poller.driver_factory.return_value - poller.driver_factory.assert_called_once_with(**credentials, **{poller.host_param_name: "1.1.1.1"}) def test_poll_one_command(self, get_mocked_poller): poller = get_mocked_poller({}, [], Mock())