diff --git a/validity/api/serializers.py b/validity/api/serializers.py index 6b34e35..c71a97e 100644 --- a/validity/api/serializers.py +++ b/validity/api/serializers.py @@ -341,7 +341,7 @@ class Meta: ) def validate(self, data): - models.Poller.validate_commands(data["commands"]) + models.Poller.validate_commands(data["connection_type"], data["commands"]) return super().validate(data) diff --git a/validity/forms/general.py b/validity/forms/general.py index 6633c29..0984995 100644 --- a/validity/forms/general.py +++ b/validity/forms/general.py @@ -5,6 +5,7 @@ from extras.models import Tag from netbox.forms import NetBoxModelForm from tenancy.models import Tenant +from utilities.forms import get_field_value from utilities.forms.fields import DynamicModelMultipleChoiceField from utilities.forms.widgets import HTMXSelect @@ -125,7 +126,8 @@ class Meta: } def clean(self): - models.Poller.validate_commands(self.cleaned_data["commands"]) + connection_type = self.cleaned_data.get("connection_type") or get_field_value(self, "connection_type") + models.Poller.validate_commands(connection_type, self.cleaned_data["commands"]) return super().clean() diff --git a/validity/models/data.py b/validity/models/data.py index 116ca60..b356b55 100644 --- a/validity/models/data.py +++ b/validity/models/data.py @@ -75,8 +75,8 @@ def _sync_status(self): def partial_sync(self, device_filter: Q, batch_size: int = 1000) -> None: def update_batch(batch): for datafile in self.datafiles.filter(path__in=batch).iterator(): - datafile.refresh_from_disk(local_path) - yield datafile + if datafile.refresh_from_disk(local_path): + yield datafile paths.discard(datafile.path) def new_data_file(path): diff --git a/validity/tests/conftest.py b/validity/tests/conftest.py index 4d2158d..84dd9ff 100644 --- a/validity/tests/conftest.py +++ b/validity/tests/conftest.py @@ -9,7 +9,7 @@ from tenancy.models import Tenant import validity -from validity.models import ConfigSerializer +from validity.models import ConfigSerializer, Poller pytest.register_assert_rewrite("base") @@ -52,6 +52,12 @@ def create_custom_fields(db): type="string", required=False, ), + CustomField( + name="poller", + type="object", + object_type=ContentType.objects.get_for_model(Poller), + required=False, + ), ] ) cfs[0].content_types.set( @@ -62,8 +68,15 @@ def create_custom_fields(db): ] ) cfs[1].content_types.set([ContentType.objects.get_for_model(Tenant)]) - for cf in cfs[2:]: + for cf in cfs[2:5]: cf.content_types.set([ContentType.objects.get_for_model(DataSource)]) + cfs[5].content_types.set( + [ + ContentType.objects.get_for_model(Device), + ContentType.objects.get_for_model(DeviceType), + ContentType.objects.get_for_model(Manufacturer), + ] + ) @pytest.fixture diff --git a/validity/tests/factories.py b/validity/tests/factories.py index 1b43f26..1a652ac 100644 --- a/validity/tests/factories.py +++ b/validity/tests/factories.py @@ -213,3 +213,21 @@ class CompTestResultFactory(DjangoModelFactory): class Meta: model = models.ComplianceTestResult + + +class CommandFactory(DjangoModelFactory): + name = factory.Sequence(lambda n: f"command-{n}") + label = factory.Sequence(lambda n: f"command_{n}") + type = "CLI" + parameters = {"cli_command": "show run"} + + class Meta: + model = models.Command + + +class PollerFactory(DjangoModelFactory): + name = factory.Sequence(lambda n: f"poller-{n}") + connection_type = "netmiko" + + class Meta: + model = models.Poller diff --git a/validity/tests/test_api.py b/validity/tests/test_api.py index bb8db6b..49dd55d 100644 --- a/validity/tests/test_api.py +++ b/validity/tests/test_api.py @@ -5,6 +5,7 @@ from base import ApiGetTest, ApiPostGetTest from django.utils import timezone from factories import ( + CommandFactory, CompTestDBFactory, CompTestResultFactory, ConfigFileFactory, @@ -126,6 +127,27 @@ class TestReport(ApiGetTest): entity = "reports" +class TestCommand(ApiPostGetTest): + entity = "commands" + post_body = { + "name": "command-1", + "label": "command_1", + "type": "CLI", + "parameters": {"cli_command": "show version"}, + } + + +class TestPoller(ApiPostGetTest): + entity = "pollers" + post_body = { + "name": "poller-1", + "connection_type": "netmiko", + "public_credentials": {"username": "admin"}, + "private_credentials": {"password": "1234"}, + "commands": [CommandFactory, CommandFactory], + } + + @pytest.mark.django_db def test_get_serialized_config(monkeypatch, admin_client): device = DeviceFactory() diff --git a/validity/tests/test_models/test_clean.py b/validity/tests/test_models/test_clean.py index 8c30fbd..c772558 100644 --- a/validity/tests/test_models/test_clean.py +++ b/validity/tests/test_models/test_clean.py @@ -1,8 +1,11 @@ import textwrap +from contextlib import nullcontext import pytest from django.core.exceptions import ValidationError -from factories import CompTestDSFactory, NameSetDSFactory, SelectorFactory, SerializerDSFactory +from factories import CommandFactory, CompTestDSFactory, NameSetDSFactory, SelectorFactory, SerializerDSFactory + +from validity.models import Poller class BaseTestClean: @@ -95,3 +98,33 @@ class TestCompTest(BaseTestClean): {"expression": "a = 10 + 15", "data_source": None, "data_file": None}, {"expression": "import itertools; a==b", "data_source": None, "data_file": None}, ] + + +class TestPoller: + @pytest.mark.parametrize( + "connection_type, command_type, is_valid", [("netmiko", "CLI", True), ("netmiko", "netconf", False)] + ) + @pytest.mark.django_db + def test_match_command_type(self, connection_type, command_type, is_valid): + command = CommandFactory(type=command_type) + ctx = nullcontext() if is_valid else pytest.raises(ValidationError) + with ctx: + Poller.validate_commands(connection_type=connection_type, commands=[command]) + + @pytest.mark.parametrize( + "retrive_config, is_valid", + [ + ([True], True), + ([False], True), + ([False, True], True), + ([False, False, False], True), + ([True, True], False), + ([True, False, True], False), + ], + ) + @pytest.mark.django_db + def only_one_config_command(self, retrive_config, is_valid): + commands = [CommandFactory(type=t) for t in retrive_config] + ctx = nullcontext() if is_valid else pytest.raises(ValidationError) + with ctx: + Poller.validate_commands(connection_type="CLI", commands=commands) diff --git a/validity/tests/test_models/test_vdatasource.py b/validity/tests/test_models/test_vdatasource.py new file mode 100644 index 0000000..6c7b667 --- /dev/null +++ b/validity/tests/test_models/test_vdatasource.py @@ -0,0 +1,48 @@ +from contextlib import suppress +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import MagicMock + +import pytest +from core.choices import DataSourceStatusChoices +from factories import DataFileFactory, DataSourceFactory + +from validity.models import VDataFile, VDataSource + + +@pytest.mark.django_db +def test_sync_status(): + data_source = DataSourceFactory() + assert data_source.status != DataSourceStatusChoices.SYNCING + with data_source._sync_status(): + assert data_source.status == DataSourceStatusChoices.SYNCING + assert VDataSource.objects.get(pk=data_source.pk).status == DataSourceStatusChoices.SYNCING + assert data_source.status == DataSourceStatusChoices.COMPLETED + assert VDataSource.objects.get(pk=data_source.pk).status == DataSourceStatusChoices.COMPLETED + + with suppress(Exception): + with data_source._sync_status(): + raise ValueError + assert data_source.status == DataSourceStatusChoices.FAILED + assert VDataSource.objects.get(pk=data_source.pk).status == DataSourceStatusChoices.FAILED + + +@pytest.mark.django_db +def test_partial_sync(monkeypatch): + ds = DataSourceFactory(type="device_polling") + DataFileFactory(source=ds, data="some_contents".encode(), path="file-0.txt") + DataFileFactory(source=ds, path="file-1.txt") + with TemporaryDirectory() as temp_dir: + existing = Path(temp_dir) / "file-1.txt" + new = Path(temp_dir) / "file_new.txt" + existing.write_text("qwe") + new.write_text("rty") + fetch_mock = MagicMock(**{"return_value.fetch.return_value.__enter__.return_value": temp_dir}) + monkeypatch.setattr(ds, "get_backend", fetch_mock) + ds.partial_sync("device_filter") + fetch_mock().fetch.assert_called_once_with("device_filter") + assert {*ds.datafiles.values_list("path", flat=True)} == {"file-0.txt", "file-1.txt", "file_new.txt"} + assert VDataFile.objects.get(path="file-0.txt").data_as_string == "some_contents" + assert VDataFile.objects.get(path="file-1.txt").data_as_string == "qwe" + assert VDataFile.objects.get(path="file_new.txt").data_as_string == "rty" + assert VDataSource.objects.get(pk=ds.pk).status == DataSourceStatusChoices.COMPLETED diff --git a/validity/tests/test_pollers.py b/validity/tests/test_pollers.py new file mode 100644 index 0000000..cdf72bb --- /dev/null +++ b/validity/tests/test_pollers.py @@ -0,0 +1,60 @@ +import time +from unittest.mock import Mock + +import pytest + +from validity.pollers import NetmikoPoller + + +class TestNetmikoPoller: + @pytest.fixture + def get_mocked_poller(self, monkeypatch): + def _get_poller(credentials, commands, mock): + monkeypatch.setattr(NetmikoPoller, "driver_cls", mock) + return NetmikoPoller(credentials, commands) + + return _get_poller + + @pytest.fixture + def get_mocked_device(self): + def _get_device(primary_ip): + db_ip = Mock(address=Mock(ip=primary_ip)) + return Mock(primary_ip=db_ip) + + return _get_device + + @pytest.mark.django_db + def test_get_driver(self, get_mocked_poller, get_mocked_device): + credentials = {"user": "admin", "password": "1234"} + poller = get_mocked_poller(credentials, [], Mock()) + device = get_mocked_device("1.1.1.1") + assert poller.get_credentials(device) == credentials | {poller.host_param_name: "1.1.1.1"} + assert poller.get_driver(device) == poller.driver_cls.return_value + poller.driver_cls.assert_called_once_with(**credentials, **{poller.host_param_name: "1.1.1.1"}) + + def test_poll_one_command(self, get_mocked_poller): + poller = get_mocked_poller({}, [], Mock()) + driver = Mock(**{"send_command.return_value": 1234}) + command = Mock(parameters={"cli_command": "show ver"}) + assert poller.poll_one_command(driver, command) == 1234 + driver.send_command.assert_called_once_with("show ver") + + @pytest.mark.parametrize("raise_exc", [True, False]) + def test_poll(self, get_mocked_poller, raise_exc, get_mocked_device): + def poll(arg): + time.sleep(0.1) + if raise_exc: + raise OSError + return arg + + commands = [Mock(parameters={"cli_command": "a"}), Mock(parameters={"cli_command": "b"})] + poller = get_mocked_poller({}, commands, Mock(**{"return_value.send_command": poll})) + devices = [get_mocked_device(f"1.1.1.{i}") for i in range(10)] + start = time.time() + results = list(poller.poll(devices)) + assert time.time() - start < 0.3 + assert len(results) == len(commands) * len(devices) + if raise_exc: + assert all(res.error.message.startswith("OSError") for res in results) + else: + assert all(res.result in {"a", "b"} for res in results) diff --git a/validity/tests/test_scripts/test_run_tests.py b/validity/tests/test_scripts/test_run_tests.py index 2def742..67370cc 100644 --- a/validity/tests/test_scripts/test_run_tests.py +++ b/validity/tests/test_scripts/test_run_tests.py @@ -144,7 +144,9 @@ def test_run_tests_for_selector(mock_script_logging, monkeypatch): name="selector", **{ "devices.select_related.return_value" - ".prefetch_datasource.return_value.prefetch_serializer.return_value": devices + ".prefetch_datasource.return_value" + ".prefetch_serializer.return_value" + ".prefetch_poller.return_value": devices } ) report = Mock() diff --git a/validity/tests/test_utils/test_dbfields.py b/validity/tests/test_utils/test_dbfields.py new file mode 100644 index 0000000..128365d --- /dev/null +++ b/validity/tests/test_utils/test_dbfields.py @@ -0,0 +1,24 @@ +import pytest + +from validity.utils.dbfields import EncryptedDict, EncryptedString + + +@pytest.fixture +def setup_private_key(monkeypatch): + monkeypatch.setattr(EncryptedString, "secret_key", b"1234567890") + + +@pytest.mark.parametrize( + "plain_value", + [ + {"param1": "val1", "param2": "val2"}, + {}, + {"param": ["some", "complex", {"val": "ue"}]}, + ], +) +def test_encrypted_dict(plain_value, setup_private_key): + enc_dict = EncryptedDict(plain_value) + assert enc_dict.decrypted == plain_value + assert enc_dict.keys() == enc_dict.encrypted.keys() == plain_value.keys() + assert all(val.startswith("$") and val.endswith("$") and val.count("$") == 3 for val in enc_dict.encrypted.values()) + assert EncryptedDict(enc_dict.encrypted).decrypted == plain_value diff --git a/validity/tests/test_utils/test_misc.py b/validity/tests/test_utils/test_misc.py index 1d9208f..6ee7c8a 100644 --- a/validity/tests/test_utils/test_misc.py +++ b/validity/tests/test_utils/test_misc.py @@ -15,6 +15,12 @@ class Error2(Exception): pass +class Error3(Exception): + def __init__(self, *args: object, orig_error) -> None: + self.orig_error = orig_error + super().__init__(*args) + + @pytest.mark.parametrize( "internal_exc, external_exc, msg", [ @@ -30,11 +36,21 @@ def test_reraise(internal_exc, external_exc, msg): else nullcontext() ) with ctx: - with reraise(type(internal_exc), type(external_exc), msg): + args = () if msg is None else (msg,) + with reraise(type(internal_exc), type(external_exc), *args): if internal_exc is not None: raise internal_exc +def test_reraise_orig_error(): + try: + with reraise(TypeError, Error3): + raise TypeError("message") + except Error3 as e: + assert isinstance(e.orig_error, TypeError) + assert e.orig_error.args == ("message",) + + @pytest.mark.parametrize( "obj1, obj2, compare_results", [ diff --git a/validity/tests/test_views.py b/validity/tests/test_views.py index 3f174dc..fb5046e 100644 --- a/validity/tests/test_views.py +++ b/validity/tests/test_views.py @@ -4,6 +4,7 @@ import pytest from base import ViewTest from factories import ( + CommandFactory, CompTestDBFactory, CompTestResultFactory, ConfigFileFactory, @@ -16,6 +17,7 @@ NameSetDBFactory, NameSetDSFactory, PlatformFactory, + PollerFactory, ReportFactory, SelectorFactory, SerializerDBFactory, @@ -153,3 +155,26 @@ def test_report_devices(admin_client): report = ReportFactory(passed_results=4, failed_results=2) resp = admin_client.get(f"/plugins/validity/reports/{report.pk}/devices/") assert resp.status_code == HTTPStatus.OK + + +class TestPoller(ViewTest): + factory_class = PollerFactory + model_class = models.Poller + post_body = { + "name": "poller-1", + "connection_type": "netmiko", + "public_credentials": '{"username": "admin"}', + "private_credentials": '{"password": "ADMIN"}', + "commands": [CommandFactory, CommandFactory], + } + + +class TestCommand(ViewTest): + factory_class = CommandFactory + model_class = models.Command + post_body = { + "name": "command-1", + "label": "command_1", + "type": "CLI", + "cli_command": "show run", + } diff --git a/validity/utils/dbfields.py b/validity/utils/dbfields.py index d041558..b73f3de 100644 --- a/validity/utils/dbfields.py +++ b/validity/utils/dbfields.py @@ -8,9 +8,10 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from django import forms from django.conf import settings from django.core.serializers.json import DjangoJSONEncoder -from django.db.models import JSONField, Model +from django.db.models import Field, JSONField @dataclass @@ -93,13 +94,13 @@ def default(self, o: Any) -> Any: class EncryptedDictField(JSONField): def __init__(self, *args: Any, **kwargs: Any) -> None: - kwargs.setdefault("default", dict) + kwargs.setdefault("default", EncryptedDict) kwargs["encoder"] = EncryptedFieldEncoder super().__init__(*args, **kwargs) def deconstruct(self) -> Any: name, path, args, kwargs = super().deconstruct() - if kwargs.get("default") == dict: + if kwargs.get("default") == EncryptedDict: del kwargs["default"] del kwargs["encoder"] return name, path, args, kwargs @@ -121,5 +122,21 @@ def to_python(self, value): return value return EncryptedDict(value) - def validate(self, value: Any, model_instance: Model) -> None: - pass # TODO: add validation + def formfield(self, **kwargs): + return Field.formfield( + self, + **{ + "form_class": EncryptedDictFormField, + "encoder": self.encoder, + "decoder": self.decoder, + **kwargs, + }, + ) + + +class EncryptedDictFormField(forms.JSONField): + def to_python(self, value: Any) -> Any: + value = super().to_python(value) + if isinstance(value, dict): + value = EncryptedDict(value) + return value diff --git a/validity/utils/misc.py b/validity/utils/misc.py index ab17194..09baece 100644 --- a/validity/utils/misc.py +++ b/validity/utils/misc.py @@ -53,7 +53,7 @@ def reraise( except catch as catched_err: if not args: args += (str(catched_err),) - with suppress(): + with suppress(Exception): if orig_error_param in inspect.signature(raise_).parameters: kwargs[orig_error_param] = catched_err raise raise_(*args, **kwargs) from catched_err @@ -81,7 +81,7 @@ def sync_func(datasource): def batched(iterable: Iterable, n: int, container: type = list): """ - Batch data into containers of length n + Batch data into containers of length n. Equal to python3.12 itertools.batched """ it = iter(iterable) while True: