Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amyasnikov committed Dec 23, 2023
1 parent 309bc52 commit 8323582
Show file tree
Hide file tree
Showing 15 changed files with 296 additions and 16 deletions.
2 changes: 1 addition & 1 deletion validity/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ class Meta:
)

def validate(self, data):
models.Poller.validate_commands(data["commands"])
models.Poller.validate_commands(data["connection_type"], data["commands"])
return super().validate(data)


Expand Down
4 changes: 3 additions & 1 deletion validity/forms/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from extras.models import Tag
from netbox.forms import NetBoxModelForm
from tenancy.models import Tenant
from utilities.forms import get_field_value
from utilities.forms.fields import DynamicModelMultipleChoiceField
from utilities.forms.widgets import HTMXSelect

Expand Down Expand Up @@ -125,7 +126,8 @@ class Meta:
}

def clean(self):
models.Poller.validate_commands(self.cleaned_data["commands"])
connection_type = self.cleaned_data.get("connection_type") or get_field_value(self, "connection_type")
models.Poller.validate_commands(connection_type, self.cleaned_data["commands"])
return super().clean()


Expand Down
4 changes: 2 additions & 2 deletions validity/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def _sync_status(self):
def partial_sync(self, device_filter: Q, batch_size: int = 1000) -> None:
def update_batch(batch):
for datafile in self.datafiles.filter(path__in=batch).iterator():
datafile.refresh_from_disk(local_path)
yield datafile
if datafile.refresh_from_disk(local_path):
yield datafile
paths.discard(datafile.path)

def new_data_file(path):
Expand Down
17 changes: 15 additions & 2 deletions validity/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tenancy.models import Tenant

import validity
from validity.models import ConfigSerializer
from validity.models import ConfigSerializer, Poller


pytest.register_assert_rewrite("base")
Expand Down Expand Up @@ -52,6 +52,12 @@ def create_custom_fields(db):
type="string",
required=False,
),
CustomField(
name="poller",
type="object",
object_type=ContentType.objects.get_for_model(Poller),
required=False,
),
]
)
cfs[0].content_types.set(
Expand All @@ -62,8 +68,15 @@ def create_custom_fields(db):
]
)
cfs[1].content_types.set([ContentType.objects.get_for_model(Tenant)])
for cf in cfs[2:]:
for cf in cfs[2:5]:
cf.content_types.set([ContentType.objects.get_for_model(DataSource)])
cfs[5].content_types.set(
[
ContentType.objects.get_for_model(Device),
ContentType.objects.get_for_model(DeviceType),
ContentType.objects.get_for_model(Manufacturer),
]
)


@pytest.fixture
Expand Down
18 changes: 18 additions & 0 deletions validity/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,21 @@ class CompTestResultFactory(DjangoModelFactory):

class Meta:
model = models.ComplianceTestResult


class CommandFactory(DjangoModelFactory):
name = factory.Sequence(lambda n: f"command-{n}")
label = factory.Sequence(lambda n: f"command_{n}")
type = "CLI"
parameters = {"cli_command": "show run"}

class Meta:
model = models.Command


class PollerFactory(DjangoModelFactory):
name = factory.Sequence(lambda n: f"poller-{n}")
connection_type = "netmiko"

class Meta:
model = models.Poller
22 changes: 22 additions & 0 deletions validity/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from base import ApiGetTest, ApiPostGetTest
from django.utils import timezone
from factories import (
CommandFactory,
CompTestDBFactory,
CompTestResultFactory,
ConfigFileFactory,
Expand Down Expand Up @@ -126,6 +127,27 @@ class TestReport(ApiGetTest):
entity = "reports"


class TestCommand(ApiPostGetTest):
entity = "commands"
post_body = {
"name": "command-1",
"label": "command_1",
"type": "CLI",
"parameters": {"cli_command": "show version"},
}


class TestPoller(ApiPostGetTest):
entity = "pollers"
post_body = {
"name": "poller-1",
"connection_type": "netmiko",
"public_credentials": {"username": "admin"},
"private_credentials": {"password": "1234"},
"commands": [CommandFactory, CommandFactory],
}


@pytest.mark.django_db
def test_get_serialized_config(monkeypatch, admin_client):
device = DeviceFactory()
Expand Down
35 changes: 34 additions & 1 deletion validity/tests/test_models/test_clean.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import textwrap
from contextlib import nullcontext

import pytest
from django.core.exceptions import ValidationError
from factories import CompTestDSFactory, NameSetDSFactory, SelectorFactory, SerializerDSFactory
from factories import CommandFactory, CompTestDSFactory, NameSetDSFactory, SelectorFactory, SerializerDSFactory

from validity.models import Poller


class BaseTestClean:
Expand Down Expand Up @@ -95,3 +98,33 @@ class TestCompTest(BaseTestClean):
{"expression": "a = 10 + 15", "data_source": None, "data_file": None},
{"expression": "import itertools; a==b", "data_source": None, "data_file": None},
]


class TestPoller:
@pytest.mark.parametrize(
"connection_type, command_type, is_valid", [("netmiko", "CLI", True), ("netmiko", "netconf", False)]
)
@pytest.mark.django_db
def test_match_command_type(self, connection_type, command_type, is_valid):
command = CommandFactory(type=command_type)
ctx = nullcontext() if is_valid else pytest.raises(ValidationError)
with ctx:
Poller.validate_commands(connection_type=connection_type, commands=[command])

@pytest.mark.parametrize(
"retrive_config, is_valid",
[
([True], True),
([False], True),
([False, True], True),
([False, False, False], True),
([True, True], False),
([True, False, True], False),
],
)
@pytest.mark.django_db
def only_one_config_command(self, retrive_config, is_valid):
commands = [CommandFactory(type=t) for t in retrive_config]
ctx = nullcontext() if is_valid else pytest.raises(ValidationError)
with ctx:
Poller.validate_commands(connection_type="CLI", commands=commands)
48 changes: 48 additions & 0 deletions validity/tests/test_models/test_vdatasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from contextlib import suppress
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock

import pytest
from core.choices import DataSourceStatusChoices
from factories import DataFileFactory, DataSourceFactory

from validity.models import VDataFile, VDataSource


@pytest.mark.django_db
def test_sync_status():
data_source = DataSourceFactory()
assert data_source.status != DataSourceStatusChoices.SYNCING
with data_source._sync_status():
assert data_source.status == DataSourceStatusChoices.SYNCING
assert VDataSource.objects.get(pk=data_source.pk).status == DataSourceStatusChoices.SYNCING
assert data_source.status == DataSourceStatusChoices.COMPLETED
assert VDataSource.objects.get(pk=data_source.pk).status == DataSourceStatusChoices.COMPLETED

with suppress(Exception):
with data_source._sync_status():
raise ValueError
assert data_source.status == DataSourceStatusChoices.FAILED
assert VDataSource.objects.get(pk=data_source.pk).status == DataSourceStatusChoices.FAILED


@pytest.mark.django_db
def test_partial_sync(monkeypatch):
ds = DataSourceFactory(type="device_polling")
DataFileFactory(source=ds, data="some_contents".encode(), path="file-0.txt")
DataFileFactory(source=ds, path="file-1.txt")
with TemporaryDirectory() as temp_dir:
existing = Path(temp_dir) / "file-1.txt"
new = Path(temp_dir) / "file_new.txt"
existing.write_text("qwe")
new.write_text("rty")
fetch_mock = MagicMock(**{"return_value.fetch.return_value.__enter__.return_value": temp_dir})
monkeypatch.setattr(ds, "get_backend", fetch_mock)
ds.partial_sync("device_filter")
fetch_mock().fetch.assert_called_once_with("device_filter")
assert {*ds.datafiles.values_list("path", flat=True)} == {"file-0.txt", "file-1.txt", "file_new.txt"}
assert VDataFile.objects.get(path="file-0.txt").data_as_string == "some_contents"
assert VDataFile.objects.get(path="file-1.txt").data_as_string == "qwe"
assert VDataFile.objects.get(path="file_new.txt").data_as_string == "rty"
assert VDataSource.objects.get(pk=ds.pk).status == DataSourceStatusChoices.COMPLETED
60 changes: 60 additions & 0 deletions validity/tests/test_pollers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import time
from unittest.mock import Mock

import pytest

from validity.pollers import NetmikoPoller


class TestNetmikoPoller:
@pytest.fixture
def get_mocked_poller(self, monkeypatch):
def _get_poller(credentials, commands, mock):
monkeypatch.setattr(NetmikoPoller, "driver_cls", mock)
return NetmikoPoller(credentials, commands)

return _get_poller

@pytest.fixture
def get_mocked_device(self):
def _get_device(primary_ip):
db_ip = Mock(address=Mock(ip=primary_ip))
return Mock(primary_ip=db_ip)

return _get_device

@pytest.mark.django_db
def test_get_driver(self, get_mocked_poller, get_mocked_device):
credentials = {"user": "admin", "password": "1234"}
poller = get_mocked_poller(credentials, [], Mock())
device = get_mocked_device("1.1.1.1")
assert poller.get_credentials(device) == credentials | {poller.host_param_name: "1.1.1.1"}
assert poller.get_driver(device) == poller.driver_cls.return_value
poller.driver_cls.assert_called_once_with(**credentials, **{poller.host_param_name: "1.1.1.1"})

def test_poll_one_command(self, get_mocked_poller):
poller = get_mocked_poller({}, [], Mock())
driver = Mock(**{"send_command.return_value": 1234})
command = Mock(parameters={"cli_command": "show ver"})
assert poller.poll_one_command(driver, command) == 1234
driver.send_command.assert_called_once_with("show ver")

@pytest.mark.parametrize("raise_exc", [True, False])
def test_poll(self, get_mocked_poller, raise_exc, get_mocked_device):
def poll(arg):
time.sleep(0.1)
if raise_exc:
raise OSError
return arg

commands = [Mock(parameters={"cli_command": "a"}), Mock(parameters={"cli_command": "b"})]
poller = get_mocked_poller({}, commands, Mock(**{"return_value.send_command": poll}))
devices = [get_mocked_device(f"1.1.1.{i}") for i in range(10)]
start = time.time()
results = list(poller.poll(devices))
assert time.time() - start < 0.3
assert len(results) == len(commands) * len(devices)
if raise_exc:
assert all(res.error.message.startswith("OSError") for res in results)
else:
assert all(res.result in {"a", "b"} for res in results)
4 changes: 3 additions & 1 deletion validity/tests/test_scripts/test_run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def test_run_tests_for_selector(mock_script_logging, monkeypatch):
name="selector",
**{
"devices.select_related.return_value"
".prefetch_datasource.return_value.prefetch_serializer.return_value": devices
".prefetch_datasource.return_value"
".prefetch_serializer.return_value"
".prefetch_poller.return_value": devices
}
)
report = Mock()
Expand Down
24 changes: 24 additions & 0 deletions validity/tests/test_utils/test_dbfields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from validity.utils.dbfields import EncryptedDict, EncryptedString


@pytest.fixture
def setup_private_key(monkeypatch):
monkeypatch.setattr(EncryptedString, "secret_key", b"1234567890")


@pytest.mark.parametrize(
"plain_value",
[
{"param1": "val1", "param2": "val2"},
{},
{"param": ["some", "complex", {"val": "ue"}]},
],
)
def test_encrypted_dict(plain_value, setup_private_key):
enc_dict = EncryptedDict(plain_value)
assert enc_dict.decrypted == plain_value
assert enc_dict.keys() == enc_dict.encrypted.keys() == plain_value.keys()
assert all(val.startswith("$") and val.endswith("$") and val.count("$") == 3 for val in enc_dict.encrypted.values())
assert EncryptedDict(enc_dict.encrypted).decrypted == plain_value
18 changes: 17 additions & 1 deletion validity/tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ class Error2(Exception):
pass


class Error3(Exception):
def __init__(self, *args: object, orig_error) -> None:
self.orig_error = orig_error
super().__init__(*args)


@pytest.mark.parametrize(
"internal_exc, external_exc, msg",
[
Expand All @@ -30,11 +36,21 @@ def test_reraise(internal_exc, external_exc, msg):
else nullcontext()
)
with ctx:
with reraise(type(internal_exc), type(external_exc), msg):
args = () if msg is None else (msg,)
with reraise(type(internal_exc), type(external_exc), *args):
if internal_exc is not None:
raise internal_exc


def test_reraise_orig_error():
try:
with reraise(TypeError, Error3):
raise TypeError("message")
except Error3 as e:
assert isinstance(e.orig_error, TypeError)
assert e.orig_error.args == ("message",)


@pytest.mark.parametrize(
"obj1, obj2, compare_results",
[
Expand Down
Loading

0 comments on commit 8323582

Please sign in to comment.