Skip to content

Commit

Permalink
remove two-phase and rollback
Browse files Browse the repository at this point in the history
  • Loading branch information
amyasnikov committed Aug 17, 2024
1 parent 337f552 commit c17df0e
Show file tree
Hide file tree
Showing 19 changed files with 123 additions and 178 deletions.
4 changes: 1 addition & 3 deletions validity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import logging

from dimi import Container
from django.conf import settings as django_settings
from netbox.settings import VERSION
from pydantic import BaseModel, Field

from validity.utils.version import NetboxVersion

Expand Down Expand Up @@ -32,7 +30,7 @@ class NetBoxValidityConfig(PluginConfig):
netbox_version = NetboxVersion(VERSION)

def ready(self):
from validity import data_backends, dependencies
from validity import data_backends, dependencies, signals

return super().ready()

Expand Down
10 changes: 1 addition & 9 deletions validity/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import django_rq
from dimi.scopes import Singleton
from django.conf import LazySettings, settings
from rq import Callback

from validity import di
from validity.choices import ConnectionTypeChoices
Expand All @@ -30,12 +29,7 @@ def poller_map():
}


@di.dependency(scope=Singleton)
def runtests_transaction_template():
return "ApplyWorker_{job}_{worker}"


from validity.scripts import ApplyWorker, CombineWorker, Launcher, RollbackWorker, SplitWorker, Task # noqa
from validity.scripts import ApplyWorker, CombineWorker, Launcher, SplitWorker, Task # noqa


@di.dependency(scope=Singleton)
Expand All @@ -44,7 +38,6 @@ def runtests_launcher(
split_worker: Annotated[SplitWorker, ...],
apply_worker: Annotated[ApplyWorker, ...],
combine_worker: Annotated[CombineWorker, ...],
rollback_worker: Annotated[RollbackWorker, ...],
):
from validity.models import ComplianceReport

Expand All @@ -57,7 +50,6 @@ def runtests_launcher(
Task(
apply_worker,
job_timeout=vsettings.script_timeouts.runtests_apply,
on_failure=Callback(rollback_worker.as_func(), timeout=vsettings.script_timeouts.runtests_rollback),
multi_workers=True,
),
Task(combine_worker, job_timeout=vsettings.script_timeouts.runtests_combine),
Expand Down
21 changes: 14 additions & 7 deletions validity/managers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from functools import partialmethod
from itertools import chain

from core.models import Job
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.aggregates import ArrayAgg
from django.db.models import (
BigIntegerField,
Expand Down Expand Up @@ -62,9 +64,8 @@ def last_more_than(self, than: int) -> "ComplianceTestResultQS":
def count_devices_and_tests(self):
return self.aggregate(device_count=Count("devices", distinct=True), test_count=Count("tests", distinct=True))

def delete_old(self):
del_count = self.filter(report=None).last_more_than(self.v_settings.store_last_results)._raw_delete(self.db)
return (del_count, {"validity.ComplianceTestResult": del_count})
def raw_delete(self):
    # Issue a single low-level SQL DELETE for this queryset via the private
    # QuerySet._raw_delete API, returning the number of rows deleted.
    # NOTE(review): presumably chosen for speed on large result sets; it bypasses
    # Django's deletion collector (signals, cascades) — confirm callers only use
    # it on rows with no dependent objects to collect.
    return self._raw_delete(self.db)


def percentage(field1: str, field2: str) -> Case:
Expand Down Expand Up @@ -119,14 +120,20 @@ def count_devices_and_tests(self):
)

def delete_old(self):
from validity.models import ComplianceTestResult
from validity.models import ComplianceReport, ComplianceTestResult

old_reports = list(self.order_by("-created").values_list("pk", flat=True)[self.v_settings.store_reports :])
deleted_results = ComplianceTestResult.objects.filter(report__pk__in=old_reports)._raw_delete(self.db)
deleted_results = ComplianceTestResult.objects.filter(report__pk__in=old_reports).raw_delete()
report_content_type = ContentType.objects.get_for_model(ComplianceReport)
deleted_jobs = Job.objects.filter(object_id__in=old_reports, object_type=report_content_type).delete()
deleted_reports, _ = self.filter(pk__in=old_reports).delete()
return (
deleted_results + deleted_reports,
{"validity.ComplianceTestResult": deleted_results, "validity.ComplianceReport": deleted_reports},
deleted_results + deleted_reports + deleted_reports,
{
"validity.ComplianceTestResult": deleted_results,
"validity.ComplianceReport": deleted_reports,
"core.Job": deleted_jobs,
},
)


Expand Down
4 changes: 4 additions & 0 deletions validity/models/report.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from core.models import Job
from django.contrib.contenttypes.fields import GenericRelation
from netbox.models import ChangeLoggingMixin

from validity.managers import ComplianceReportQS
from .base import BaseReadOnlyModel


class ComplianceReport(ChangeLoggingMixin, BaseReadOnlyModel):
jobs = GenericRelation(Job, content_type_field="object_type")

objects = ComplianceReportQS.as_manager()

class Meta:
Expand Down
2 changes: 1 addition & 1 deletion validity/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .data_models import Task
from .launch import Launcher
from .logger import Logger
from .steps import ApplyWorker, CombineWorker, RollbackWorker, SplitWorker
from .runtests import ApplyWorker, CombineWorker, SplitWorker
1 change: 1 addition & 0 deletions validity/scripts/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def serialized(self):
class ExecutionResult:
test_stat: TestResultRatio
log: list[Message]
errored: bool = False


@dataclass
Expand Down
7 changes: 7 additions & 0 deletions validity/scripts/logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import traceback as tb
from functools import partialmethod

from extras.choices import LogLevelChoices
Expand Down Expand Up @@ -28,3 +29,9 @@ def _log(self, message: str, level: LogLevelChoices):
info = partialmethod(_log, level=LogLevelChoices.LOG_INFO)
warning = partialmethod(_log, level=LogLevelChoices.LOG_WARNING)
failure = partialmethod(_log, level=LogLevelChoices.LOG_FAILURE)

def log_exception(self, exc_value, exc_type=None, exc_traceback=None):
exc_traceback = exc_traceback or exc_value.__traceback__
exc_type = exc_type or type(exc_value)
stacktrace = "".join(tb.format_tb(exc_traceback))
self.failure(f"Unhandled error occured: `{exc_type}: {exc_value}`\n```\n{stacktrace}\n```")
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .apply import ApplyWorker
from .combine import CombineWorker
from .rollback import RollbackWorker
from .split import SplitWorker
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from dataclasses import dataclass
from functools import cached_property
from itertools import chain
from typing import Annotated, Any, Callable, ContextManager, Iterable, Iterator
from typing import Annotated, Any, Callable, Iterable, Iterator

from dimi import Singleton
from django.db.models import Prefetch, QuerySet

from validity import di
from validity.compliance.exceptions import EvalError, SerializationError
from validity.models import ComplianceSelector, ComplianceTest, ComplianceTestResult, NameSet, VDataSource, VDevice
from validity.utils.orm import TwoPhaseTransaction
from ..data_models import ExecutionResult, FullRunTestsParams, TestResultRatio
from ..logger import Logger
from ..parent_jobs import JobExtractor
Expand Down Expand Up @@ -126,10 +125,6 @@ def _get_device_qs(self, selector: ComplianceSelector, device_ids: list[int]) ->
return device_qs


def prepare_transaction(transaction):
return TwoPhaseTransaction(transaction).prepare()


@di.dependency(scope=Singleton)
@dataclass(repr=False, kw_only=True)
class ApplyWorker:
Expand All @@ -138,28 +133,37 @@ class ApplyWorker:
"""

test_executor_cls: type[TestExecutor] = TestExecutor
logger_factory: Callable[[str], Logger] = Logger
device_test_gen: type[DeviceTestIterator] = DeviceTestIterator
result_batch_size: Annotated[int, "validity_settings.result_batch_size"]
job_extractor_factory: Callable[[], JobExtractor] = JobExtractor
prepare_transaction: Callable[[str], ContextManager] = prepare_transaction
transaction_template: Annotated[str, "runtests_transaction_template"]

def __call__(self, *, params: FullRunTestsParams, worker_id: int) -> ExecutionResult:
try:
executor = self.test_executor_cls(worker_id, params.explanation_verbosity, params.report_id)
test_results = self.get_test_results(params, worker_id, executor)
self.save_results_to_db(test_results)
return ExecutionResult(
TestResultRatio(executor.results_passed, executor.results_count), executor.log.messages
)
except Exception as err:
logger = self.logger_factory(f"Worker {worker_id}")
logger.log_exception(err)
return ExecutionResult(test_stat=TestResultRatio(0, 0), log=logger.messages, errored=True)

def get_test_results(
self, params: FullRunTestsParams, worker_id: int, executor: TestExecutor
) -> Iterator[ComplianceTestResult]:
selector_devices = self.get_selector_devices(worker_id)
executor = self.test_executor_cls(worker_id, params.explanation_verbosity, params.report_id)
test_results = (
executor(devices, tests)
for devices, tests in self.device_test_gen(selector_devices, params.test_tags, params.override_datasource)
)
chained_results = chain.from_iterable(test_results)
self.save_results_to_db(chained_results, params.job_id, worker_id)
return ExecutionResult(TestResultRatio(executor.results_passed, executor.results_count), executor.log.messages)
return chain.from_iterable(test_results)

def get_selector_devices(self, worker_id: int) -> dict[int, list[int]]:
job_extractor = self.job_extractor_factory()
return job_extractor.parent.job.result[worker_id]
return job_extractor.parent.job.result.slices[worker_id]

def save_results_to_db(self, results: Iterable[ComplianceTestResult], job_id: int, worker_id: int) -> None:
transaction_id = self.transaction_template.format(job=job_id, worker=worker_id)
with self.prepare_transaction(transaction_id):
ComplianceTestResult.objects.bulk_create(results, batch_size=self.result_batch_size)
def save_results_to_db(self, results: Iterable[ComplianceTestResult]) -> None:
ComplianceTestResult.objects.bulk_create(results, batch_size=self.result_batch_size)
25 changes: 25 additions & 0 deletions validity/scripts/runtests/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from contextlib import contextmanager

from core.choices import JobStatusChoices
from core.models import Job


class TerminateMixin:
    """Helpers for workers that must finalize a NetBox Job record, including on failure."""

    def terminate_job(self, job: Job, status: str, error: str | None = None, logs=None, output=None):
        """Attach serialized logs and output to *job*, then terminate it with *status*."""
        serialized_logs = [record.serialized for record in logs] if logs else []
        job.data = {"log": serialized_logs, "output": output}
        job.terminate(status, error)

    def terminate_errored_job(self, job: Job, type, value, traceback):
        """Log the exception plus a revert notice and terminate *job* as errored."""
        logger = self.log_factory()  # log_factory is supplied by the subclass
        logger.log_exception(value, type, traceback)
        logger.info("Database changes have been reverted")
        self.terminate_job(job, status=JobStatusChoices.STATUS_ERRORED, error=repr(value), logs=logger.messages)

    @contextmanager
    def terminate_job_on_error(self, job: Job):
        """Context manager that marks *job* errored on any exception, then re-raises it."""
        try:
            yield
        except Exception as exc:
            self.terminate_errored_job(job, type(exc), exc, exc.__traceback__)
            raise
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from itertools import chain
from typing import Annotated, Any, Callable

from core.choices import JobStatusChoices
from core.models import Job
from dimi import Singleton
from django.db.models import QuerySet
Expand All @@ -13,40 +14,29 @@
from extras.choices import ObjectChangeActionChoices

from validity import di
from validity.models import ComplianceReport
from validity.models import ComplianceReport, ComplianceTestResult
from validity.netbox_changes import enqueue_object, events_queue
from validity.utils.orm import TwoPhaseTransaction
from ..data_models import FullRunTestsParams, Message, TestResultRatio
from ..launch import Launcher
from ..logger import Logger
from ..parent_jobs import JobExtractor
from .base import TracebackMixin
from .base import TerminateMixin


def enqueue(report, request, action):
    """Push *report* with *action* onto the process-wide events queue via enqueue_object."""
    queue = events_queue.get()
    user = request.get_user()
    return enqueue_object(queue, report, user, request.id, action)


def commit(transaction_id):
TwoPhaseTransaction(transaction_id).commit()


@di.dependency(scope=Singleton)
@dataclass(repr=False, kw_only=True)
class CombineWorker(TracebackMixin):
class CombineWorker(TerminateMixin):
log_factory: Callable[[], Logger] = Logger
job_extractor_factory: Callable[[], JobExtractor] = JobExtractor
enqueue_func: Callable[[ComplianceReport, HttpRequest, str], None] = enqueue
report_queryset: QuerySet[ComplianceReport] = field(
default_factory=ComplianceReport.objects.annotate_result_stats().count_devices_and_tests
)
commit_func: Callable[[str], None] = commit
transaction_template: Annotated[str, "runtests_transaction_template"]

def commit_transactions(self, workers_num: int, job_id: int) -> None:
for worker_id in range(workers_num):
transaction_id = self.transaction_template.format(job=job_id, worker=worker_id)
self.commit_func(transaction_id)
testresult_queryset: QuerySet = field(default_factory=ComplianceTestResult.objects.all)

def fire_report_webhook(self, report_id: int, request: HttpRequest) -> None:
report = self.report_queryset.get(pk=report_id)
Expand All @@ -57,12 +47,11 @@ def count_test_stats(self, job_extractor: JobExtractor) -> TestResultRatio:
return reduce(operator.add, result_ratios)

def collect_logs(self, logger: Logger, job_extractor: JobExtractor) -> list[Message]:
assert job_extractor.parents, "Combine must have parents"
parent_logs = chain.from_iterable(extractor.job.result.log for extractor in job_extractor.parents)
grandparent_logs = job_extractor.parent.parent.job.result.log
return [*grandparent_logs, *parent_logs, *logger.messages]

def terminate_job(self, job: Job, test_stats: TestResultRatio, logs: list[Message]):
def terminate_succeeded_job(self, job: Job, test_stats: TestResultRatio, logs: list[Message]):
    """Store serialized logs and the test statistics on *job*, then terminate it as succeeded."""
    log_payload = [message.serialized for message in logs]
    stats_payload = {"statistics": test_stats.serialized}
    job.data = {"log": log_payload, "output": stats_payload}
    job.terminate()

Expand All @@ -74,16 +63,27 @@ def schedule_next_job(
params.schedule_at = job.started + datetime.timedelta(params.schedule_interval)
launcher(params)

def get_previous_errors(self, job_extractor: JobExtractor) -> list[Message]:
    """Collect log messages from every parent job whose result is flagged as errored."""
    collected: list[Message] = []
    for parent in job_extractor.parents:
        result = parent.job.result
        if result.errored:
            collected.extend(result.log)
    return collected

def __call__(self, params: FullRunTestsParams) -> Any:
netbox_job = params.get_job()
with self.terminate_job_on_error(netbox_job):
logger = self.log_factory()
job_extractor = self.job_extractor_factory()
self.commit_transactions(params.workers_num, params.job_id)
if err_logs := self.get_previous_errors(job_extractor):
self.terminate_job(netbox_job, JobStatusChoices.STATUS_ERRORED, error="ApplyWorkerError", logs=err_logs)
self.testresult_queryset.filter(report=params.report_id).raw_delete()
return
logger = self.log_factory()
self.fire_report_webhook(params.report_id, params.request)
test_stats = self.count_test_stats(job_extractor)
report_url = reverse("plugins:validity:compliancereport", kwargs={"pk": params.report_id})
logger.success(f"Job succeeded. See [Compliance Report]({report_url}) for detailed statistics")
logs = self.collect_logs(logger, job_extractor)
self.terminate_job(netbox_job, test_stats, logs)
self.schedule_next_job(params, netbox_job)
self.terminate_job(
netbox_job, JobStatusChoices.STATUS_COMPLETED, logs=logs, output={"statistics": test_stats.serialized}
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from validity.utils.misc import batched, datasource_sync
from ..data_models import FullRunTestsParams, SplitResult
from ..logger import Logger
from .base import TracebackMixin
from .base import TerminateMixin


@di.dependency(scope=Singleton)
@dataclass(repr=False)
class SplitWorker(TracebackMixin):
class SplitWorker(TerminateMixin):
log_factory: Callable[[], Logger] = Logger
datasource_sync_fn: Callable[[Iterable[VDataSource], Q], None] = datasource_sync
device_batch_size: int = 2000
Expand All @@ -40,7 +40,7 @@ def sync_datasources(self, override_datasource: int | None, device_filter: Q):
def _work_slices(self, selector_qs: QuerySet[ComplianceSelector], devices_per_worker: int):
def device_ids(selector):
return (
selector.devices.iterator(chunk_size=self.device_batch_size).order_by("pk").values_list("pk", flat=True)
selector.devices.order_by("pk").values_list("pk", flat=True).iterator(chunk_size=self.device_batch_size)
)

selector_device = chain.from_iterable(
Expand Down Expand Up @@ -83,6 +83,7 @@ def __call__(self, params: FullRunTestsParams) -> SplitResult:
job = params.get_job()
with self.terminate_job_on_error(job):
job.start()
job.object_type.model_class().objects.delete_old()
logger = self.log_factory()
device_filter = params.get_device_filter()
if params.sync_datasources:
Expand Down
Loading

0 comments on commit c17df0e

Please sign in to comment.