diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 47045d4..fe95013 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - netbox_version: [v3.4.10, v3.5.9, v3.6.2] + netbox_version: [v3.5.9, v3.6.5] steps: - name: Checkout uses: actions/checkout@v3 diff --git a/validity/api/serializers.py b/validity/api/serializers.py index 5863b10..8c3f47e 100644 --- a/validity/api/serializers.py +++ b/validity/api/serializers.py @@ -262,13 +262,14 @@ def run_validation(self, data=...): class SerializedConfigSerializer(serializers.Serializer): serializer = NestedConfigSerializerSerializer(read_only=True, source="device.serializer") - data_source = NestedDataSourceSerializer(read_only=True, source="device.datasource") + data_source = NestedDataSourceSerializer(read_only=True, source="device.data_source") + data_file = NestedDataFileSerializer(read_only=True, source="device.data_file") local_copy_last_updated = serializers.DateTimeField(allow_null=True, source="last_modified") config_web_link = serializers.SerializerMethodField() serialized_config = serializers.JSONField(source="serialized") def get_config_web_link(self, obj): - return urljoin(obj.device.datasource.web_url, obj.config_path.as_posix()) + return urljoin(obj.device.data_source.web_url, obj.device.config_path) class DeviceReportSerializer(NestedDeviceSerializer): diff --git a/validity/tests/base.py b/validity/tests/base.py index d46edcd..75fb10a 100644 --- a/validity/tests/base.py +++ b/validity/tests/base.py @@ -32,6 +32,11 @@ class PostMixin: @classmethod def resolve_post_body(cls): + # making data_file point to the same data_source + if "data_source" in cls.post_body and "data_file" in cls.post_body: + data_source = cls.post_body["data_source"]() + cls.post_body["data_source"] = data_source.pk + cls.post_body["data_file"] = cls.post_body["data_file"](source=data_source).pk for k, v in cls.post_body.items(): if isinstance(v, type): cls.post_body[k] = v().pk diff --git a/validity/tests/conftest.py b/validity/tests/conftest.py index ab2b266..4d2158d 100644 --- a/validity/tests/conftest.py +++ b/validity/tests/conftest.py @@ -1,8 +1,7 @@ -import os -import shutil from pathlib import Path import pytest +from core.models import DataSource from dcim.models import Device, DeviceType, Manufacturer from django.contrib.contenttypes.models import ContentType from extras.models import CustomField @@ -10,7 +9,7 @@ from tenancy.models import Tenant import validity -from validity.models import ConfigSerializer, GitRepo +from validity.models import ConfigSerializer pytest.register_assert_rewrite("base") @@ -21,42 +20,6 @@ def tests_root(): return Path(validity.__file__).parent.absolute() / "tests" -@pytest.fixture -def temp_file(): - file_paths = [] - - def _temp_file(path, content): - file_paths.append(str(path)) - with open(path, "w") as file: - file.write(content) - - yield _temp_file - for path in file_paths: - os.remove(path) - - -@pytest.fixture -def temp_folder(): - folder_paths = [] - - def _temp_folder(path): - folder_paths.append(path) - os.mkdir(path) - - yield _temp_folder - for folder in folder_paths: - shutil.rmtree(folder) - - -@pytest.fixture -def temp_file_and_folder(temp_folder, temp_file): - def _temp_file_and_folder(base_dir, dirname, filename, file_content): - temp_folder(base_dir / dirname) - temp_file(base_dir / dirname / filename, file_content) - - return _temp_file_and_folder - - @pytest.fixture def create_custom_fields(db): cfs = CustomField.objects.bulk_create( @@ -68,9 +31,25 @@ def create_custom_fields(db): required=False, ), CustomField( - name="repo", + name="config_data_source", type="object", - object_type=ContentType.objects.get_for_model(GitRepo), + object_type=ContentType.objects.get_for_model(DataSource), + required=False, + ), + CustomField( + name="device_config_default", + type="boolean", + required=False, + default=False, + ), + CustomField( + name="device_config_path", + type="string", + required=False, + ), + CustomField( + name="web_url", + type="string", required=False, ), ] @@ -83,6 +62,8 @@ def create_custom_fields(db): ] ) cfs[1].content_types.set([ContentType.objects.get_for_model(Tenant)]) + for cf in cfs[2:]: + cf.content_types.set([ContentType.objects.get_for_model(DataSource)]) @pytest.fixture diff --git a/validity/tests/factories.py b/validity/tests/factories.py index def2b11..1b43f26 100644 --- a/validity/tests/factories.py +++ b/validity/tests/factories.py @@ -1,3 +1,5 @@ +import datetime + import factory from dcim.models import DeviceRole, DeviceType, Location, Manufacturer, Platform, Site from extras.models import Tag @@ -7,25 +9,46 @@ from validity import models -class GitRepoFactory(DjangoModelFactory): - name = factory.Sequence(lambda n: f"repo-{n}") - git_url = "http://some.url/repo" - web_url = "http://some.url/repo/{{branch}}" - device_config_path = "some/path/{{device.name}}.txt" - username = "" +class DataSourceFactory(DjangoModelFactory): + name = factory.Sequence(lambda n: f"datasource-{n}") + type = "local" + source_url = "file:///some_path" class Meta: - model = models.GitRepo + model = models.VDataSource + + +class DataFileFactory(DjangoModelFactory): + source = factory.SubFactory(DataSourceFactory) + path = factory.Sequence(lambda n: f"file-{n}.txt") + data = "some contents".encode() + size = len(data) + last_updated = datetime.datetime.utcnow() + hash = "1" * 64 + + class Meta: + model = models.VDataFile + + +class ConfigFileFactory(DataFileFactory): + path = "file-1.txt" + source = factory.SubFactory( + DataSourceFactory, + custom_field_data={ + "device_config_default": True, + "device_config_path": path, + "web_url": "http://some_url.com/", + }, + ) - @factory.post_generation - def password(self, create, extracted, **kwargs): - self.password = extracted - self.save() +class DataSourceLinkFactory(DjangoModelFactory): + data_source = factory.SubFactory(DataSourceFactory) + data_file = factory.SubFactory(DataFileFactory, source=data_source, data=factory.SelfAttribute("..contents_bin")) -class GitRepoLinkFactory(DjangoModelFactory): - repo = factory.SubFactory(GitRepoFactory) - file_path = "some/file.txt" + class Params: + contents = "some_contents" + contents_bin = factory.LazyAttribute(lambda o: o.contents.encode()) class NameSetDBFactory(DjangoModelFactory): @@ -37,7 +60,7 @@ class Meta: model = models.NameSet -class NameSetGitFactory(GitRepoLinkFactory, NameSetDBFactory): +class NameSetDSFactory(DataSourceLinkFactory, NameSetDBFactory): definitions = "" class Meta: @@ -63,7 +86,7 @@ class Meta: model = models.ConfigSerializer -class SerializerGitFactory(GitRepoLinkFactory, SerializerDBFactory): +class SerializerDSFactory(DataSourceLinkFactory, SerializerDBFactory): ttp_template = "" class Meta: @@ -78,7 +101,7 @@ class Meta: model = models.ComplianceTest -class CompTestGitFactory(GitRepoLinkFactory, CompTestDBFactory): +class CompTestDSFactory(DataSourceLinkFactory, CompTestDBFactory): expression = "" class Meta: diff --git a/validity/tests/test_api.py b/validity/tests/test_api.py index 644d18a..bb8db6b 100644 --- a/validity/tests/test_api.py +++ b/validity/tests/test_api.py @@ -1,5 +1,4 @@ from http import HTTPStatus -from pathlib import Path from unittest.mock import Mock import pytest @@ -8,9 +7,11 @@ from factories import ( CompTestDBFactory, CompTestResultFactory, + ConfigFileFactory, + DataFileFactory, + DataSourceFactory, DeviceFactory, DeviceTypeFactory, - GitRepoFactory, LocationFactory, ManufacturerFactory, PlatformFactory, @@ -39,29 +40,15 @@ def get_extra_checks(self, resp_json, pk): assert resp_json["effective_definitions"] -class TestGitNameSet(ApiPostGetTest): +class TestDSNameSet(ApiPostGetTest): entity = "namesets" post_body = { "name": "nameset-1", "description": "nameset description", "global": False, "tests": [CompTestDBFactory, CompTestDBFactory], - "repo": GitRepoFactory, - "file_path": "some/file.txt", - } - - -class TestGitRepo(ApiPostGetTest): - entity = "git-repositories" - post_body = { - "name": "repo-1", - "git_url": "http://some.url/path", - "web_url": "http://some.url/webpath", - "device_config_path": "some/path/{{device.name}}.txt", - "default": True, - "username": "admin", - "password": "1234", - "branch": "main", + "data_source": DataSourceFactory, + "data_file": DataFileFactory, } @@ -92,13 +79,13 @@ def get_extra_checks(self, resp_json, pk): assert resp_json["effective_template"] -class TestGitSerializer(ApiPostGetTest): +class TestDSSerializer(ApiPostGetTest): entity = "serializers" post_body = { "name": "serializer-1", "extraction_method": "TTP", - "repo": GitRepoFactory, - "file_path": "some_file.txt", + "data_source": DataSourceFactory, + "data_file": DataFileFactory, } @@ -117,15 +104,15 @@ def get_extra_checks(self, resp_json, pk): assert resp_json["effective_expression"] -class TestGitTest(ApiPostGetTest): +class TestDSTest(ApiPostGetTest): entity = "tests" post_body = { "name": "test-1", "description": "some description", "severity": "LOW", "selectors": [SelectorFactory], - "repo": GitRepoFactory, - "file_path": "some/file.txt", + "data_source": DataSourceFactory, + "data_file": DataFileFactory, } @@ -142,18 +129,18 @@ class TestReport(ApiGetTest): @pytest.mark.django_db def test_get_serialized_config(monkeypatch, admin_client): device = DeviceFactory() - device.repo = GitRepoFactory(web_url="http://github.com/reponame") - device.serializer = SerializerDBFactory() + config_file = ConfigFileFactory() + device.custom_field_data["serializer"] = SerializerDBFactory().pk + device.save() + device.data_source = config_file.source lm = timezone.now() - config = DeviceConfig( - device=device, config_path=Path("some/file.txt"), last_modified=lm, serialized={"key1": "value1"} - ) + config = DeviceConfig(device=device, plain_config="", last_modified=lm, serialized={"key1": "value1"}) monkeypatch.setattr(DeviceConfig, "from_device", Mock(return_value=config)) resp = admin_client.get(f"/api/dcim/devices/{device.pk}/serialized_config/") assert resp.status_code == HTTPStatus.OK assert resp.json().keys() == { - "serializer", - "repo", + "data_source", + "data_file", "local_copy_last_updated", "config_web_link", "serialized_config", diff --git a/validity/tests/test_config_compliance/test_device_config.py b/validity/tests/test_config_compliance/test_device_config.py index 6aa60d4..c4bd061 100644 --- a/validity/tests/test_config_compliance/test_device_config.py +++ b/validity/tests/test_config_compliance/test_device_config.py @@ -3,16 +3,11 @@ import pytest import yaml -from factories import DeviceFactory +from factories import DataFileFactory, DeviceFactory from validity.config_compliance.device_config import DeviceConfig -@pytest.fixture -def set_git_folder(tests_root): - DeviceConfig._git_folder = tests_root - - JSON_CONFIG = """ { "ntp_servers": [ @@ -105,7 +100,7 @@ def set_git_folder(tests_root): @pytest.mark.parametrize( - "extraction_method, file_content, serialized", + "extraction_method, contents, serialized", [ pytest.param("YAML", JSON_CONFIG, json.loads(JSON_CONFIG), id="YAML-JSON"), pytest.param("YAML", YAML_CONFIG, yaml.safe_load(YAML_CONFIG), id="YAML"), @@ -114,14 +109,12 @@ def set_git_folder(tests_root): ], ) @pytest.mark.django_db -def test_device_congig(temp_file_and_folder, set_git_folder, tests_root, extraction_method, file_content, serialized): +def test_device_config(extraction_method, contents, serialized): device = DeviceFactory() device.serializer = Mock(name="some_serializer", extraction_method=extraction_method) if extraction_method == "TTP": device.serializer.effective_template = TTP_TEMPLATE - temp_file_and_folder(tests_root, "some_repo", "config_file.txt", file_content) - device.repo = Mock(rendered_device_path=Mock(return_value="config_file.txt")) - device.repo.name = "some_repo" + device.data_file = DataFileFactory(data=contents.encode()) device_config = DeviceConfig.from_device(device) assert extraction_method.lower() in type(device_config).__name__.lower() assert device_config.serialized == serialized diff --git a/validity/tests/test_models/test_clean.py b/validity/tests/test_models/_test_clean.py similarity index 100% rename from validity/tests/test_models/test_clean.py rename to validity/tests/test_models/_test_clean.py diff --git a/validity/tests/test_models/test_git_link.py b/validity/tests/test_models/test_git_link.py index 8dbcbb9..bfc6ea2 100644 --- a/validity/tests/test_models/test_git_link.py +++ b/validity/tests/test_models/test_git_link.py @@ -1,39 +1,28 @@ from functools import partial as p -from unittest.mock import MagicMock, Mock import pytest from factories import ( CompTestDBFactory, - CompTestGitFactory, + CompTestDSFactory, NameSetDBFactory, - NameSetGitFactory, + NameSetDSFactory, SerializerDBFactory, - SerializerGitFactory, + SerializerDSFactory, ) -from validity.models import base - @pytest.mark.parametrize( "factory, prop_name, expected_value", [ (p(SerializerDBFactory, ttp_template="template"), "effective_template", "template"), - (SerializerGitFactory, "effective_template", ""), + (p(SerializerDSFactory, contents="template2"), "effective_template", "template2"), (p(NameSetDBFactory, definitions="def f(): pass"), "effective_definitions", "def f(): pass"), - (NameSetGitFactory, "effective_definitions", ""), + (p(NameSetDSFactory, contents="def f2(): pass"), "effective_definitions", "def f2(): pass"), (p(CompTestDBFactory, expression="1==2"), "effective_expression", "1==2"), - (CompTestGitFactory, "effective_expression", ""), + (p(CompTestDSFactory, contents="1==3"), "effective_expression", "1==3"), ], ) @pytest.mark.django_db -def test_git_link_model(factory, prop_name, expected_value, monkeypatch): +def test_git_link_model(factory, prop_name, expected_value): model = factory() - mock_git = Mock(GitRepo=Mock(from_db=MagicMock())) - monkeypatch.setattr(base, "git", mock_git) - value = getattr(model, prop_name) - if isinstance(value, Mock): - assert value._extract_mock_name() == "mock.GitRepo.from_db().local_path.__truediv__().open().__enter__().read()" - mock_git.GitRepo.from_db.assert_called_once_with(model.repo) - mock_git.GitRepo.from_db.return_value.local_path.__truediv__.assert_called_once_with(model.file_path) - else: - assert value == expected_value + assert getattr(model, prop_name) == expected_value diff --git a/validity/tests/test_models/test_git_repo.py b/validity/tests/test_models/test_git_repo.py deleted file mode 100644 index 734ca18..0000000 --- a/validity/tests/test_models/test_git_repo.py +++ /dev/null @@ -1,37 +0,0 @@ -from functools import partial - -import pytest -from factories import GitRepoFactory - -from validity.models import GitRepo - - -@pytest.mark.django_db -def test_password(): - repo = GitRepo( - name="r1", - git_url="http://repo.url/path", - web_url="http://repo.url/path", - device_config_path="somepath", - username="adm", - ) - repo.password = "some_password" - repo.save() - db_repo = GitRepo.objects.get(pk=repo.pk) - assert db_repo.password == repo.password - - -@pytest.mark.parametrize( - "factory, expected_url", - [ - (GitRepoFactory, GitRepoFactory.git_url), - ( - partial(GitRepoFactory, username="admin", password="1234", git_url="http://some.url/path"), - "http://admin:1234@some.url/path", - ), - ], -) -@pytest.mark.django_db -def test_full_git_url(factory, expected_url): - repo = factory() - assert repo.full_git_url == expected_url diff --git a/validity/tests/test_models/test_vdevice.py b/validity/tests/test_models/test_vdevice.py index b2e1fd4..4378ab3 100644 --- a/validity/tests/test_models/test_vdevice.py +++ b/validity/tests/test_models/test_vdevice.py @@ -1,76 +1,74 @@ +from operator import attrgetter + import pytest -from django.db import connection -from factories import DeviceFactory, GitRepoFactory, SelectorFactory, SerializerDBFactory, TenantFactory +from factories import ( + ConfigFileFactory, + DataSourceFactory, + DeviceFactory, + SelectorFactory, + SerializerDBFactory, + TenantFactory, +) from validity.models import VDevice @pytest.fixture -def setup_device_and_serializer(create_custom_fields): - serializer = SerializerDBFactory() - device = DeviceFactory() - device.device_type.custom_field_data["serializer"] = serializer.pk - device.device_type.save() - return device, serializer - - -@pytest.fixture -def setup_device_and_repo(create_custom_fields): - repo = GitRepoFactory() - tenant = TenantFactory() - tenant.custom_field_data["repo"] = repo.id - tenant.save() - device = DeviceFactory(tenant=tenant) - return device, repo - - -@pytest.mark.django_db -def test_adhoc_tenant_repo(setup_device_and_repo): - device, repo = setup_device_and_repo - vdevice = VDevice.objects.get(pk=device.pk) - assert vdevice.repo == repo +def setup_serializers(create_custom_fields): + serializers = [SerializerDBFactory() for _ in range(3)] + devices = [DeviceFactory() for _ in range(3)] + attrs = ["custom_field_data", "device_type.custom_field_data", "device_type.manufacturer.custom_field_data"] + for device, serializer, attr in zip(devices, serializers, attrs): + cf_dict = attrgetter(attr)(device) + cf_dict["serializer"] = serializer.pk + device.device_type.manufacturer.save() + device.device_type.save() + device.save() + return {d.pk: s.pk for d, s in zip(devices, serializers)} @pytest.mark.django_db -def test_annotated_tenant_repo(setup_device_and_repo): - device, repo = setup_device_and_repo - vdevice = VDevice.objects.annotate_json_repo().get(pk=device.pk) - queries_count = len(connection.queries) - assert vdevice.repo == repo - assert len(connection.queries) == queries_count +def test_datasource_tenant(create_custom_fields): + datasource = DataSourceFactory() + tenant = TenantFactory(custom_field_data={"config_data_source": datasource.pk}) + DeviceFactory(tenant=tenant) + device = VDevice.objects.prefetch_datasource().first() + assert device.data_source == datasource @pytest.mark.django_db -def test_adhoc_default_repo(): - repo = GitRepoFactory(default=True) - device = DeviceFactory() - assert device.repo == repo +def test_datasource_default(create_custom_fields): + datasource = DataSourceFactory(custom_field_data={"device_config_default": True}) + DeviceFactory() + device = VDevice.objects.prefetch_datasource().first() + assert device.data_source == datasource @pytest.mark.django_db -def test_annotated_default_repo(): - repo = GitRepoFactory(default=True) - device_pk = DeviceFactory().pk - device = VDevice.objects.annotate_json_repo().get(pk=device_pk) - queries_count = len(connection.queries) - assert device.repo == repo - assert len(connection.queries) == queries_count +def test_data_file(create_custom_fields): + DeviceFactory() + data_file = ConfigFileFactory() + device = VDevice.objects.prefetch_datasource().first() + assert device.data_file == data_file @pytest.mark.django_db -def test_adhoc_serializer(setup_device_and_serializer): - device, serializer = setup_device_and_serializer - vdevice = VDevice.objects.get(pk=device.pk) - assert vdevice.serializer == serializer +def test_serializer(setup_serializers, subtests): + device_serializer_map = setup_serializers + devices = VDevice.objects.prefetch_serializer() + for d in devices: + with subtests.test(id=d.name): + assert d.serializer.pk == device_serializer_map[d.pk] @pytest.mark.django_db -def test_annotated_serializer(setup_device_and_serializer): - device, serializer = setup_device_and_serializer - vdevice = VDevice.objects.annotate_json_serializer().get(pk=device.pk) - queries_count = len(connection.queries) - assert vdevice.serializer == serializer - assert len(connection.queries) == queries_count +def test_config_path(create_custom_fields): + DeviceFactory(name="device1") + DataSourceFactory( + custom_field_data={"device_config_path": "path/{{device.name}}.cfg", "device_config_default": True} + ) + device = VDevice.objects.prefetch_datasource().first() + assert device.config_path == "path/device1.cfg" @pytest.mark.parametrize("qs", [VDevice.objects.all(), VDevice.objects.filter(name__in=["d1", "d2"])]) diff --git a/validity/tests/test_scripts/test_run_tests.py b/validity/tests/test_scripts/test_run_tests.py index 938516b..2def742 100644 --- a/validity/tests/test_scripts/test_run_tests.py +++ b/validity/tests/test_scripts/test_run_tests.py @@ -3,16 +3,21 @@ from uuid import uuid4 import pytest +from extras.scripts import Script from factories import CompTestDBFactory, DeviceFactory, NameSetDBFactory, ReportFactory, SelectorFactory from simpleeval import InvalidExpression from validity.config_compliance.exceptions import EvalError from validity.models import ComplianceReport, ComplianceTestResult, VDevice from validity.scripts import run_tests -from validity.scripts.run_tests import RunTestsScript +from validity.scripts.run_tests import RunTestsScript as RunTestsMixin from validity.utils.misc import null_request +class RunTestsScript(RunTestsMixin, Script): + pass + + NS_1 = """ __all__ = ["func1", "var", "func2"] @@ -139,7 +144,7 @@ def test_run_tests_for_selector(mock_script_logging, monkeypatch): name="selector", **{ "devices.select_related.return_value" - ".annotate_json_serializer.return_value.annotate_json_repo.return_value": devices + ".prefetch_datasource.return_value.prefetch_serializer.return_value": devices } ) report = Mock() diff --git a/validity/tests/test_utils/test_git.py b/validity/tests/test_utils/test_git.py deleted file mode 100644 index 9b66eff..0000000 --- a/validity/tests/test_utils/test_git.py +++ /dev/null @@ -1,101 +0,0 @@ -import os -import shutil -from pathlib import Path -from unittest.mock import Mock - -import pytest -from factories import GitRepoFactory - -from validity import models -from validity.utils.git import GitRepo, SyncReposMixin - - -@pytest.fixture -def set_git_folder(tests_root): - GitRepo.git_folder = tests_root - - -@pytest.fixture -def create_repo_via_shell(tests_root): - repo_paths = [] - - def _create_repository(repo_name, origin=None): - repo_path = tests_root / repo_name - repo_paths.append(repo_path) - init_content = "init_content" - commands = ( - f"mkdir -p {repo_path}; cd {repo_path}; " - f"git init; echo -n {init_content} > file.txt; " - "git config user.name q; git config user.email 'q@q.q'; " - "git add -A; git commit -m init; " - ) - if origin: - commands += f"git remote add origin {origin}" - os.system(commands) - return repo_path - - yield _create_repository - for path in repo_paths: - shutil.rmtree(path) - - -@pytest.fixture -def create_repo_via_gitrepo(): - repo_paths = [] - - def _create_repository(repo_name, origin): - repo = GitRepo(repo_name, origin, "master") - repo_paths.append(repo.git_folder / repo.name) - return repo - - yield _create_repository - for path in repo_paths: - shutil.rmtree(path) - - -@pytest.fixture -def make_commit(tests_root): - def _make_commit(repo_name): - repo_path = tests_root / repo_name - os.system(f"cd {repo_path}; echo $RANDOM > file.txt; git add -A; git commit -m commit") - return repo_path / ".git" - - return _make_commit - - -def file_content(repo_path: Path) -> str: - with (repo_path / "file.txt").open("r") as file: - return file.read() - - -def test_clone(create_repo_via_shell, set_git_folder, tests_root, create_repo_via_gitrepo): - remote_repo_path = create_repo_via_shell("remote_repo") - repo: GitRepo = create_repo_via_gitrepo("some_repo", f"file://{remote_repo_path}") - repo.clone() - assert (tests_root / "some_repo").is_dir() - file_path = tests_root / "some_repo" / "file.txt" - assert file_path.is_file() - assert file_content(file_path.parent) == "init_content" - - -def test_force_pull(create_repo_via_shell, make_commit, set_git_folder): - remote_repo_path = create_repo_via_shell("remote_repo") - local_repo_path = create_repo_via_shell("local_repo", origin=remote_repo_path) - make_commit("local_repo") - local_repo = GitRepo("local_repo", f"file://{remote_repo_path}", "master") - local_repo.force_pull() - assert file_content(local_repo_path) == file_content(remote_repo_path) - assert GitRepo("remote_repo", "", "master").head_hash == local_repo.head_hash - - -@pytest.mark.django_db -def test_update_git_repos(create_repo_via_shell, make_commit, set_git_folder, tests_root, monkeypatch): - remote_repo_path = create_repo_via_shell("remote_repo") - GitRepoFactory(name="local_repo", git_url=f"file://{remote_repo_path}") - monkeypatch.setattr(SyncReposMixin, "log_success", Mock(), raising=False) - sync_script = SyncReposMixin() - sync_script.update_git_repos(models.GitRepo.objects.all()) - remote_repo = GitRepo("remote_repo", "", "master") - local_repo = GitRepo("local_repo", "", "master") - assert remote_repo.head_hash == local_repo.head_hash - assert file_content(remote_repo_path) == file_content(tests_root / "local_repo") diff --git a/validity/tests/test_utils/test_misc.py b/validity/tests/test_utils/test_misc.py index 1bf5c41..1d9208f 100644 --- a/validity/tests/test_utils/test_misc.py +++ b/validity/tests/test_utils/test_misc.py @@ -3,7 +3,8 @@ import pytest -from validity.utils.misc import NetboxVersion, reraise +from validity.utils.misc import reraise +from validity.utils.version import NetboxVersion class Error1(Exception): diff --git a/validity/tests/test_utils/test_orm.py b/validity/tests/test_utils/test_orm.py new file mode 100644 index 0000000..b5f2f3a --- /dev/null +++ b/validity/tests/test_utils/test_orm.py @@ -0,0 +1,40 @@ +import pytest +from dcim.models import Device +from django.db import connection +from django.db.models import BigIntegerField +from django.db.models.fields.json import KeyTextTransform +from django.db.models.functions import Cast +from factories import DeviceFactory, SerializerDBFactory + +from validity.models.serializer import ConfigSerializer +from validity.utils.orm import CustomPrefetchMixin, QuerySetMap + + +@pytest.mark.parametrize("attrib", ["pk", "name"]) +@pytest.mark.django_db +def test_qsmap(attrib): + devices = [DeviceFactory(), DeviceFactory(), DeviceFactory()] + qs_map = QuerySetMap(Device.objects.all(), attrib) + assert len(connection.queries) == 0 + for device in devices: + key = getattr(device, attrib) + assert key in qs_map + assert qs_map[key] == device + + +@pytest.mark.django_db +def test_custom_prefetch(): + devices = [DeviceFactory(), DeviceFactory(), DeviceFactory()] + device_qs = Device.objects.all() + custom_qs = CustomPrefetchMixin(device_qs.model, device_qs._query, device_qs._db, device_qs._hints) + serializers = [SerializerDBFactory(), SerializerDBFactory(), SerializerDBFactory()] + device_serializer_map = {} + for d, s in zip(devices, serializers): + d.custom_field_data["serializer"] = s.pk + d.save() + device_serializer_map[d.pk] = s.pk + + for device in custom_qs.annotate( + serializer_id=Cast(KeyTextTransform("serializer", "custom_field_data"), BigIntegerField()) + ).custom_prefetch("serializer", ConfigSerializer.objects.all()): + assert device_serializer_map[device.pk] == device.serializer.pk diff --git a/validity/tests/test_views.py b/validity/tests/test_views.py index abf096f..e225b31 100644 --- a/validity/tests/test_views.py +++ b/validity/tests/test_views.py @@ -6,18 +6,20 @@ from factories import ( CompTestDBFactory, CompTestResultFactory, + ConfigFileFactory, + DataFileFactory, + DataSourceFactory, DeviceFactory, DeviceTypeFactory, - GitRepoFactory, LocationFactory, ManufacturerFactory, NameSetDBFactory, - NameSetGitFactory, + NameSetDSFactory, PlatformFactory, ReportFactory, SelectorFactory, SerializerDBFactory, - SerializerGitFactory, + SerializerDSFactory, SiteFactory, TagFactory, TenantFactory, @@ -42,8 +44,8 @@ def f(): pass } -class TestGitNameSet(ViewTest): - factory_class = NameSetGitFactory +class TestDSNameSet(ViewTest): + factory_class = NameSetDSFactory model_class = models.NameSet post_body = { "name": "nameset-1", @@ -51,23 +53,8 @@ class TestGitNameSet(ViewTest): "_global": False, "tests": [CompTestDBFactory, CompTestDBFactory], "definitions": "", - "repo": GitRepoFactory, - "file_path": "some/file.txt", - } - - -class TestGitRepo(ViewTest): - factory_class = GitRepoFactory - model_class = models.GitRepo - post_body = { - "name": "repo-1", - "git_url": "http://some.url/path", - "web_url": "http://some.url/path", - "device_config_path": "device/path", - "default": True, - "username": "admin", - "password": "1234", - "branch": "master", + "data_source": DataSourceFactory, + "data_file": DataFileFactory, } @@ -104,15 +91,15 @@ class TestDBSerializer(ViewTest): post_body = {"name": "serializer-1", "extraction_method": "TTP", "ttp_template": "interface {{interface}}"} -class TestGitSerializer(ViewTest): - factory_class = SerializerGitFactory +class TestDSSerializer(ViewTest): + factory_class = SerializerDSFactory model_class = models.ConfigSerializer post_body = { "name": "serializer-1", "extraction_method": "TTP", "ttp_template": "", - "repo": GitRepoFactory, - "file_path": "some_file.txt", + "data_source": DataSourceFactory, + "data_file": DataFileFactory, } @@ -136,7 +123,7 @@ class TestDBTest(ViewTest): } -class TestGitTest(ViewTest): +class TestDSTest(ViewTest): factory_class = CompTestDBFactory model_class = models.ComplianceTest post_body = { @@ -145,14 +132,18 @@ class TestGitTest(ViewTest): "severity": "LOW", "expression": "", "selectors": [SelectorFactory], - "repo": GitRepoFactory, - "file_path": "some/file.txt", + "data_source": DataSourceFactory, + "data_file": DataFileFactory, } @pytest.mark.django_db def test_device_results(admin_client): device = DeviceFactory() + ConfigFileFactory() + serializer = SerializerDBFactory() + device.custom_field_data["serializer"] = serializer.pk + device.save() resp = admin_client.get(f"/dcim/devices/{device.pk}/serialized_config/") assert resp.status_code == HTTPStatus.OK diff --git a/validity/utils/orm.py b/validity/utils/orm.py index a1376b7..1e19474 100644 --- a/validity/utils/orm.py +++ b/validity/utils/orm.py @@ -28,7 +28,7 @@ def __init__(self, qs: QuerySet, attribute: str = "pk"): def _evaluate(self): if not self._evaluated: - for model in self._qs: + for model in self._qs.iterator(chunk_size=2000): self._map[getattr(model, self._attribute)] = model self._evaluated = True