From dbac934e289fd48e3e470e4d57fc4b9a6b117c83 Mon Sep 17 00:00:00 2001 From: Anton M Date: Sun, 20 Oct 2024 20:52:26 +0200 Subject: [PATCH] fix existing tests --- validity/tests/test_models/test_clean.py | 8 +++++--- validity/tests/test_pollers.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/validity/tests/test_models/test_clean.py b/validity/tests/test_models/test_clean.py index 38f83bb..36a2890 100644 --- a/validity/tests/test_models/test_clean.py +++ b/validity/tests/test_models/test_clean.py @@ -105,11 +105,13 @@ class TestPoller: "connection_type, command_type, is_valid", [("netmiko", "CLI", True), ("netmiko", "netconf", False)] ) @pytest.mark.django_db - def test_match_command_type(self, connection_type, command_type, is_valid): + def test_match_command_type(self, connection_type, command_type, is_valid, di): command = CommandFactory(type=command_type) ctx = nullcontext() if is_valid else pytest.raises(ValidationError) with ctx: - Poller.validate_commands(connection_type=connection_type, commands=[command]) + Poller.validate_commands( + connection_type=connection_type, commands=[command], command_types=di["PollerChoices"].command_types + ) @pytest.mark.parametrize( "retrive_config, is_valid", @@ -127,4 +129,4 @@ def only_one_config_command(self, retrive_config, is_valid): commands = [CommandFactory(type=t) for t in retrive_config] ctx = nullcontext() if is_valid else pytest.raises(ValidationError) with ctx: - Poller.validate_commands(connection_type="CLI", commands=commands) + Poller.validate_commands(connection_type="CLI", commands=commands, command_types={}) diff --git a/validity/tests/test_pollers.py b/validity/tests/test_pollers.py index 6da47a4..9e2a85d 100644 --- a/validity/tests/test_pollers.py +++ b/validity/tests/test_pollers.py @@ -11,7 +11,7 @@ class TestNetmikoPoller: @pytest.fixture def get_mocked_poller(self, monkeypatch): def _get_poller(credentials, commands, mock): - monkeypatch.setattr(NetmikoPoller, "driver_cls", mock) + monkeypatch.setattr(NetmikoPoller, "driver_factory", mock) return NetmikoPoller(credentials, commands) return _get_poller @@ -30,8 +30,8 @@ def test_get_driver(self, get_mocked_poller, get_mocked_device): poller = get_mocked_poller(credentials, [], Mock()) device = get_mocked_device("1.1.1.1") assert poller.get_credentials(device) == credentials | {poller.host_param_name: "1.1.1.1"} - assert poller.get_driver(device) == poller.driver_cls.return_value - poller.driver_cls.assert_called_once_with(**credentials, **{poller.host_param_name: "1.1.1.1"}) + assert poller.get_driver(device) == poller.driver_factory.return_value + poller.driver_factory.assert_called_once_with(**credentials, **{poller.host_param_name: "1.1.1.1"}) def test_poll_one_command(self, get_mocked_poller): poller = get_mocked_poller({}, [], Mock())