Skip to content

Commit

Permalink
tests fix
Browse files Browse the repository at this point in the history
  • Loading branch information
amyasnikov committed Nov 5, 2024
1 parent 0f74224 commit 8a93eb7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/features/custom_pollers.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class ScrapliPoller(CustomPoller):
return resp.result
```

!!! note
Be aware that every poller class instance is usually responsible for interaction with multiple devices. Hence, do not use poller fields for storing device-specific parameters.


### Filling PollerInfo

Poller Info is required to tell Validity about your custom poller.
Expand Down
6 changes: 3 additions & 3 deletions validity/pollers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def poll(self, devices: Iterable["VDevice"]) -> Iterator[CommandResult | Polling

class DriverMixin:
driver_factory: Callable # Network driver class, e.g. netmiko.ConnectHandler
driver_connect_method: str
driver_disconnect_method: str
driver_connect_method: str = ""
driver_disconnect_method: str = ""

def connect(self, credentials: dict[str, Any]):
driver = type(self).driver_factory(**credentials)
Expand Down Expand Up @@ -111,7 +111,7 @@ class CustomPoller(ConsecutivePoller):
Base class for creating user-defined pollers
To define your own poller override the following attributes:
- driver_factory - class/function for creating connection to particular device
- host_param_name - name of the driver_factory parameter, which holds device IP address
- host_param_name - name of the driver parameter which holds device IP address
- poll_one_command() - method for sending one particular command to device and retrieving the result
- driver_connect_method - optional driver method name to initiate the connection
- driver_disconnect_method - optional driver method name to gracefully terminate the connection
Expand Down
20 changes: 16 additions & 4 deletions validity/tests/test_pollers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,26 @@

import pytest

from validity.pollers import NetmikoPoller, RequestsPoller
from validity.models.polling import Command
from validity.pollers import CustomPoller, NetmikoPoller, RequestsPoller
from validity.pollers.factory import PollerChoices
from validity.pollers.http import HttpDriver
from validity.settings import PollerInfo


@pytest.fixture
def custom_poller():
class MyCustomPoller(CustomPoller):
driver_factory = Mock(name="driver_factory")
driver_connect_method = "con"
driver_disconnect_method = "dis"

def poll_one_command(self, driver: time.Any, command: Command) -> str:
return super().poll_one_command(driver, command)

return MyCustomPoller


class TestNetmikoPoller:
@pytest.fixture
def get_mocked_poller(self, monkeypatch):
Expand All @@ -27,13 +41,11 @@ def _get_device(primary_ip):
return _get_device

@pytest.mark.django_db
def test_get_driver(self, get_mocked_poller, get_mocked_device):
def test_get_credentials(self, get_mocked_poller, get_mocked_device):
credentials = {"user": "admin", "password": "1234"}
poller = get_mocked_poller(credentials, [], Mock())
device = get_mocked_device("1.1.1.1")
assert poller.get_credentials(device) == credentials | {poller.host_param_name: "1.1.1.1"}
assert poller.get_driver(device) == poller.driver_factory.return_value
poller.driver_factory.assert_called_once_with(**credentials, **{poller.host_param_name: "1.1.1.1"})

def test_poll_one_command(self, get_mocked_poller):
poller = get_mocked_poller({}, [], Mock())
Expand Down

0 comments on commit 8a93eb7

Please sign in to comment.