From e1b1bce04e477893d7a6caa07a46c7bd7bfc4c3c Mon Sep 17 00:00:00 2001
From: "Richard Kuo (Danswer)"
Date: Mon, 3 Feb 2025 01:40:23 -0800
Subject: [PATCH 1/2] add validation for pruning

---
 .../background/celery/tasks/indexing/tasks.py |   4 +-
 .../background/celery/tasks/indexing/utils.py |  16 +-
 .../background/celery/tasks/pruning/tasks.py  | 288 +++++++++++++++++-
 backend/onyx/configs/constants.py             |   1 +
 backend/onyx/redis/redis_connector_prune.py   |  55 +++-
 backend/onyx/server/documents/cc_pair.py      |   6 +-
 6 files changed, 345 insertions(+), 25 deletions(-)

diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py
index 2e3c0e3aecf..d6068efb2a1 100644
--- a/backend/onyx/background/celery/tasks/indexing/tasks.py
+++ b/backend/onyx/background/celery/tasks/indexing/tasks.py
@@ -423,8 +423,8 @@ def connector_indexing_task(
     # define a callback class
     callback = IndexingCallback(
         os.getppid(),
-        redis_connector.stop.fence_key,
-        redis_connector_index.generator_progress_key,
+        redis_connector,
+        redis_connector_index,
         lock,
         r,
     )
diff --git a/backend/onyx/background/celery/tasks/indexing/utils.py b/backend/onyx/background/celery/tasks/indexing/utils.py
index e14e79b5ff7..f5f6851d459 100644
--- a/backend/onyx/background/celery/tasks/indexing/utils.py
+++ b/backend/onyx/background/celery/tasks/indexing/utils.py
@@ -97,16 +97,16 @@ class IndexingCallback(IndexingHeartbeatInterface):
     def __init__(
         self,
         parent_pid: int,
-        stop_key: str,
-        generator_progress_key: str,
+        redis_connector: RedisConnector,
+        redis_connector_index: RedisConnectorIndex,
         redis_lock: RedisLock,
         redis_client: Redis,
     ):
         super().__init__()
         self.parent_pid = parent_pid
+        self.redis_connector: RedisConnector = redis_connector
+        self.redis_connector_index: RedisConnectorIndex = redis_connector_index
         self.redis_lock: RedisLock = redis_lock
-        self.stop_key: str = stop_key
-        self.generator_progress_key: str = generator_progress_key
         self.redis_client = redis_client
         self.started: datetime = datetime.now(timezone.utc)
         self.redis_lock.reacquire()
@@ -118,7 +118,7 @@ def __init__(
         self.last_parent_check = time.monotonic()

     def should_stop(self) -> bool:
-        if self.redis_client.exists(self.stop_key):
+        if self.redis_connector.stop.fenced:
             return True

         return False
@@ -141,6 +141,8 @@ def progress(self, tag: str, amount: int) -> None:
         #     self.last_parent_check = now

         try:
+            self.redis_connector.prune.set_active()
+
             current_time = time.monotonic()
             if current_time - self.last_lock_monotonic >= (
                 CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -163,7 +165,9 @@ def progress(self, tag: str, amount: int) -> None:
             redis_lock_dump(self.redis_lock, self.redis_client)
             raise

-        self.redis_client.incrby(self.generator_progress_key, amount)
+        self.redis_client.incrby(
+            self.redis_connector_index.generator_progress_key, amount
+        )


 def validate_indexing_fence(
diff --git a/backend/onyx/background/celery/tasks/pruning/tasks.py b/backend/onyx/background/celery/tasks/pruning/tasks.py
index 99a37ddd017..3ba306b8f5d 100644
--- a/backend/onyx/background/celery/tasks/pruning/tasks.py
+++ b/backend/onyx/background/celery/tasks/pruning/tasks.py
@@ -1,28 +1,37 @@
+import time
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
+from typing import cast
 from uuid import uuid4

 from celery import Celery
 from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
+from pydantic import ValidationError
 from redis import Redis
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session

 from onyx.background.celery.apps.app_base import task_logger
+from onyx.background.celery.celery_redis import celery_find_task
+from onyx.background.celery.celery_redis import celery_get_queue_length
+from onyx.background.celery.celery_redis import celery_get_queued_task_ids
+from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
 from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
 from onyx.background.celery.tasks.indexing.utils import IndexingCallback
 from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
+from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
 from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisLocks
+from onyx.configs.constants import OnyxRedisSignals
 from onyx.connectors.factory import instantiate_connector
 from onyx.connectors.models import InputType
 from onyx.db.connector import mark_ccpair_as_pruned
@@ -35,10 +44,15 @@
 from onyx.db.enums import SyncStatus
 from onyx.db.enums import SyncType
 from onyx.db.models import ConnectorCredentialPair
+from onyx.db.search_settings import get_current_search_settings
 from onyx.db.sync_record import insert_sync_record
 from onyx.db.sync_record import update_sync_record_status
 from onyx.redis.redis_connector import RedisConnector
+from onyx.redis.redis_connector_prune import RedisConnectorPrune
+from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
 from onyx.redis.redis_pool import get_redis_client
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
+from onyx.server.utils import make_short_id
 from onyx.utils.logger import LoggerContextVars
 from onyx.utils.logger import pruning_ctx
 from onyx.utils.logger import setup_logger
@@ -93,6 +107,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
 )
 def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
     r = get_redis_client(tenant_id=tenant_id)
+    r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

     lock_beat: RedisLock = r.lock(
         OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -123,13 +138,28 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
             if not _is_pruning_due(cc_pair):
                 continue

-            tasks_created = try_creating_prune_generator_task(
+            celery_task_id = try_creating_prune_generator_task(
                 self.app, cc_pair, db_session, r, tenant_id
             )
-            if not tasks_created:
+            if not celery_task_id:
                 continue

-            task_logger.info(f"Pruning queued: cc_pair={cc_pair.id}")
+            task_logger.info(
+                f"Pruning queued: cc_pair={cc_pair.id} task={celery_task_id}"
+            )
+
+        # we want to run this less frequently than the overall task
+        lock_beat.reacquire()
+        if not r.exists(OnyxRedisSignals.VALIDATE_PRUNING_FENCES):
+            # clear any pruning fences that don't have associated celery tasks in progress
+            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
+            # or be currently executing
+            try:
+                validate_pruning_fences(tenant_id, r, r_celery, lock_beat)
+            except Exception:
+                task_logger.exception("Exception while validating pruning fences")
+
+            r.set(OnyxRedisSignals.VALIDATE_PRUNING_FENCES, 1, ex=300)
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
@@ -149,7 +179,7 @@ def try_creating_prune_generator_task(
     db_session: Session,
     r: Redis,
     tenant_id: str | None,
-) -> int | None:
+) -> str | None:
     """Checks for any conditions that should block the pruning generator task from being
     created, then creates the task.

@@ -168,7 +198,7 @@ def try_creating_prune_generator_task(

     # we need to serialize starting pruning since it can be triggered either via
     # celery beat or manually (API call)
-    lock = r.lock(
+    lock: RedisLock = r.lock(
         DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
         timeout=LOCK_TIMEOUT,
     )
@@ -200,7 +230,17 @@ def try_creating_prune_generator_task(

         custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"

-        celery_app.send_task(
+        # set a basic fence to start
+        redis_connector.prune.set_active()
+        payload = RedisConnectorPrunePayload(
+            id=make_short_id(),
+            submitted=datetime.now(timezone.utc),
+            started=None,
+            celery_task_id=None,
+        )
+        redis_connector.prune.set_fence(payload)
+
+        result = celery_app.send_task(
             OnyxCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
             kwargs=dict(
                 cc_pair_id=cc_pair.id,
@@ -221,8 +261,12 @@ def try_creating_prune_generator_task(
             sync_type=SyncType.PRUNING,
         )

-        # set this only after all tasks have been added
-        redis_connector.prune.set_fence(True)
+        # fill in the celery task id
+        redis_connector.prune.set_active()
+        payload.celery_task_id = result.id
+        redis_connector.prune.set_fence(payload)
+
+        payload_id = payload.celery_task_id
     except Exception:
         task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
         return None
@@ -230,7 +274,7 @@ def try_creating_prune_generator_task(
         if lock.owned():
             lock.release()

-    return 1
+    return payload_id


 @shared_task(
@@ -265,6 +309,43 @@ def connector_pruning_generator_task(

     r = get_redis_client(tenant_id=tenant_id)

+    # this wait is needed to avoid a race condition where
+    # the primary worker sends the task and it is immediately executed
+    # before the primary worker can finalize the fence
+    start = time.monotonic()
+    while True:
+        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
+            raise ValueError(
+                f"connector_pruning_generator_task - timed out waiting for fence to be ready: "
+                f"fence={redis_connector.prune.fence_key}"
+            )
+
+        if not redis_connector.prune.fenced:  # The fence must exist
+            raise ValueError(
+                f"connector_pruning_generator_task - fence not found: "
+                f"fence={redis_connector.prune.fence_key}"
+            )
+
+        payload = redis_connector.prune.payload  # The payload must exist
+        if not payload:
+            raise ValueError(
+                "connector_pruning_generator_task: payload invalid or not found"
+            )
+
+        if payload.celery_task_id is None:
+            logger.info(
+                f"connector_pruning_generator_task - Waiting for fence: "
+                f"fence={redis_connector.prune.fence_key}"
+            )
+            time.sleep(1)
+            continue
+
+        logger.info(
+            f"connector_pruning_generator_task - Fence found, continuing...: "
+            f"fence={redis_connector.prune.fence_key}"
+        )
+        break
+
     # set thread_local=False since we don't control what thread the indexing/pruning
     # might run our callback with
     lock: RedisLock = r.lock(
@@ -294,6 +375,18 @@ def connector_pruning_generator_task(
         )
         return

+    payload = redis_connector.prune.payload
+    if not payload:
+        raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
+
+    new_payload = RedisConnectorPrunePayload(
+        id=payload.id,
+        submitted=payload.submitted,
+        started=datetime.now(timezone.utc),
+        celery_task_id=payload.celery_task_id,
+    )
+    redis_connector.prune.set_fence(new_payload)
+
     task_logger.info(
         f"Pruning generator running connector: "
         f"cc_pair={cc_pair_id} "
@@ -307,10 +400,13 @@ def connector_pruning_generator_task(
             cc_pair.credential,
         )

+        search_settings = get_current_search_settings(db_session)
+        redis_connector_index = redis_connector.new_index(search_settings.id)
+
         callback = IndexingCallback(
             0,
-            redis_connector.stop.fence_key,
-            redis_connector.prune.generator_progress_key,
+            redis_connector,
+            redis_connector_index,
             lock,
             r,
         )
@@ -415,4 +511,172 @@ def monitor_ccpair_pruning_taskset(

     redis_connector.prune.taskset_clear()
     redis_connector.prune.generator_clear()
-    redis_connector.prune.set_fence(False)
+    redis_connector.prune.set_fence(None)
+
+
+def validate_pruning_fences(
+    tenant_id: str | None,
+    r: Redis,
+    r_celery: Redis,
+    lock_beat: RedisLock,
+) -> None:
+    # building the lookup table can be expensive, so we won't bother
+    # validating until the queue is small
+    PRUNING_VALIDATION_MAX_QUEUE_LEN = 1024
+
+    queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
+    if queue_len > PRUNING_VALIDATION_MAX_QUEUE_LEN:
+        return
+
+    queued_deletion_tasks = celery_get_queued_task_ids(
+        OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
+    )
+    reserved_generator_tasks = celery_get_unacked_task_ids(
+        OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
+    )
+
+    # validate all existing pruning jobs
+    for key_bytes in r.scan_iter(
+        RedisConnectorPrune.FENCE_PREFIX + "*",
+        count=SCAN_ITER_COUNT_DEFAULT,
+    ):
+        lock_beat.reacquire()
+        validate_pruning_fence(
+            tenant_id,
+            key_bytes,
+            queued_deletion_tasks,
+            reserved_generator_tasks,
+            r,
+            r_celery,
+        )
+
+    return
+
+
+def validate_pruning_fence(
+    tenant_id: str | None,
+    key_bytes: bytes,
+    queued_tasks: set[str],
+    reserved_tasks: set[str],
+    r: Redis,
+    r_celery: Redis,
+) -> None:
+    """See validate_indexing_fence for an overall idea of validation flows.
+
+    queued_tasks: the celery queue of lightweight document deletion tasks generated by pruning
+    reserved_tasks: prefetched tasks for the pruning generator
+    """
+    # if the fence doesn't exist, there's nothing to do
+    fence_key = key_bytes.decode("utf-8")
+    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
+    if cc_pair_id_str is None:
+        task_logger.warning(
+            f"validate_pruning_fence - could not parse id from {fence_key}"
+        )
+        return
+
+    cc_pair_id = int(cc_pair_id_str)
+    # parse out metadata and initialize the helper class with it
+    redis_connector = RedisConnector(tenant_id, cc_pair_id)
+
+    # check to see if the fence/payload exists
+    if not redis_connector.prune.fenced:
+        return
+
+    # in the cloud, the payload format may have changed ...
+    # it's a little sloppy, but just reset the fence for now if that happens
+    # TODO: add intentional cleanup/abort logic
+    try:
+        payload = redis_connector.prune.payload
+    except ValidationError:
+        task_logger.exception(
+            "validate_pruning_fence - "
+            "Resetting fence because fence schema is out of date: "
+            f"cc_pair={cc_pair_id} "
+            f"fence={fence_key}"
+        )
+
+        redis_connector.prune.reset()
+        return
+
+    if not payload:
+        return
+
+    if not payload.celery_task_id:
+        return
+
+    # OK, there's actually something for us to validate
+
+    # either the generator task must be in flight or its subtasks must be
+    found = celery_find_task(
+        payload.celery_task_id,
+        OnyxCeleryQueues.CONNECTOR_PRUNING,
+        r_celery,
+    )
+    if found:
+        # the celery task exists in the redis queue
+        redis_connector.prune.set_active()
+        return
+
+    if payload.celery_task_id in reserved_tasks:
+        # the celery task was prefetched and is reserved within a worker
+        redis_connector.prune.set_active()
+        return
+
+    # look up every task in the current taskset in the celery queue
+    # every entry in the taskset should have an associated entry in the celery task queue
+    # because we get the celery tasks first, the entries in our own pruning taskset
+    # should be roughly a subset of the tasks in celery
+
+    # this check isn't very exact, but should be sufficient over a period of time
+    # A single successful check over some number of attempts is sufficient.
+
+    # TODO: if the number of tasks in celery is much lower than the taskset length
+    # we might be able to shortcut the lookup since by definition some of the tasks
+    # must not exist in celery.
+
+    tasks_scanned = 0
+    tasks_not_in_celery = 0  # a non-zero number after completing our check is bad
+
+    for member in r.sscan_iter(redis_connector.prune.taskset_key):
+        tasks_scanned += 1
+
+        member_bytes = cast(bytes, member)
+        member_str = member_bytes.decode("utf-8")
+        if member_str in queued_tasks:
+            continue
+
+        if member_str in reserved_tasks:
+            continue
+
+        tasks_not_in_celery += 1
+
+    task_logger.info(
+        "validate_pruning_fence task check: "
+        f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
+    )
+
+    if tasks_not_in_celery == 0:
+        redis_connector.prune.set_active()
+        return
+
+    # we may want to enable this check if using the active task list somehow isn't good enough
+    # if redis_connector_index.generator_locked():
+    #     logger.info(f"{payload.celery_task_id} is currently executing.")
+
+    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
+    # but they still might be there due to gaps in our ability to check states during transitions
+    # Checking the active signal safeguards us against these transition periods
+    # (its TTL gives us time to bridge those gaps)
+    if redis_connector.prune.active():
+        return
+
+    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
+    task_logger.warning(
+        "validate_pruning_fence - "
+        "Resetting fence because no associated celery tasks were found: "
+        f"cc_pair={cc_pair_id} "
+        f"fence={fence_key}"
+    )
+
+    redis_connector.prune.reset()
+    return
diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py
index e18a5ee3e7a..6447d1e95bd 100644
--- a/backend/onyx/configs/constants.py
+++ b/backend/onyx/configs/constants.py
@@ -302,6 +302,7 @@ class OnyxRedisSignals:
     VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
     VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
     VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"
+    VALIDATE_PRUNING_FENCES = "signal:validate_pruning_fences"


 class OnyxCeleryPriority(int, Enum):
diff --git a/backend/onyx/redis/redis_connector_prune.py b/backend/onyx/redis/redis_connector_prune.py
index ea4a923eb6d..10a4d3750fd 100644
--- a/backend/onyx/redis/redis_connector_prune.py
+++ b/backend/onyx/redis/redis_connector_prune.py
@@ -1,9 +1,11 @@
 import time
+from datetime import datetime
 from typing import cast
 from uuid import uuid4

 import redis
 from celery import Celery
+from pydantic import BaseModel
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session

@@ -15,6 +17,13 @@
 from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT


+class RedisConnectorPrunePayload(BaseModel):
+    id: str
+    submitted: datetime
+    started: datetime | None
+    celery_task_id: str | None
+
+
 class RedisConnectorPrune:
     """Manages interactions with redis for pruning tasks. Should only be accessed
     through RedisConnector."""
@@ -35,6 +44,12 @@ class RedisConnectorPrune:
     TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorpruning_taskset
     SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorpruning+sub

+    # used to signal that the overall workflow is still active
+    # it's impossible to get the exact state of the system at a single point in time
+    # so we need a signal with a TTL to bridge gaps in our checks
+    ACTIVE_PREFIX = PREFIX + "_active"
+    ACTIVE_TTL = 3600
+
     def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
         self.tenant_id: str | None = tenant_id
         self.id = id
@@ -48,6 +63,7 @@ def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:

         self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
         self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
+        self.active_key = f"{self.ACTIVE_PREFIX}_{id}"

     def taskset_clear(self) -> None:
         self.redis.delete(self.taskset_key)
@@ -77,12 +93,41 @@ def fenced(self) -> bool:

         return False

-    def set_fence(self, value: bool) -> None:
-        if not value:
+    @property
+    def payload(self) -> RedisConnectorPrunePayload | None:
+        # read related data and evaluate/print task progress
+        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
+        if fence_bytes is None:
+            return None
+
+        fence_str = fence_bytes.decode("utf-8")
+        payload = RedisConnectorPrunePayload.model_validate_json(fence_str)
+
+        return payload
+
+    def set_fence(
+        self,
+        payload: RedisConnectorPrunePayload | None,
+    ) -> None:
+        if not payload:
             self.redis.delete(self.fence_key)
             return

-        self.redis.set(self.fence_key, 0)
+        self.redis.set(self.fence_key, payload.model_dump_json())
+
+    def set_active(self) -> None:
+        """This sets a signal to keep the pruning flow from getting cleaned up within
+        the expiration time.
+
+        The slack in timing is needed because simply checking the celery queue
+        and task status at a single point in time is subject to races."""
+        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
+
+    def active(self) -> bool:
+        if self.redis.exists(self.active_key):
+            return True
+
+        return False

     @property
     def generator_complete(self) -> int | None:
@@ -158,6 +203,7 @@ def generate_tasks(
         return len(async_results)

     def reset(self) -> None:
+        self.redis.delete(self.active_key)
         self.redis.delete(self.generator_progress_key)
         self.redis.delete(self.generator_complete_key)
         self.redis.delete(self.taskset_key)
@@ -172,6 +218,9 @@ def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
     @staticmethod
     def reset_all(r: redis.Redis) -> None:
         """Deletes all redis values for all connectors"""
+        for key in r.scan_iter(RedisConnectorPrune.ACTIVE_PREFIX + "*"):
+            r.delete(key)
+
         for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
             r.delete(key)

diff --git a/backend/onyx/server/documents/cc_pair.py b/backend/onyx/server/documents/cc_pair.py
index 3ba5984df38..c3f2e7a6654 100644
--- a/backend/onyx/server/documents/cc_pair.py
+++ b/backend/onyx/server/documents/cc_pair.py
@@ -359,15 +359,17 @@ def prune_cc_pair(
         f"credential={cc_pair.credential_id} "
         f"{cc_pair.connector.name} connector."
     )
-    tasks_created = try_creating_prune_generator_task(
+    celery_task_id = try_creating_prune_generator_task(
         primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
     )
-    if not tasks_created:
+    if not celery_task_id:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Pruning task creation failed.",
        )

+    logger.info(f"Pruning queued: cc_pair={cc_pair.id} task={celery_task_id}")
+
     return StatusResponse(
         success=True,
         message="Successfully created the pruning task.",

From 06245fa7424e865b055991d93edec11b2adc1e51 Mon Sep 17 00:00:00 2001
From: "Richard Kuo (Danswer)"
Date: Tue, 4 Feb 2025 16:28:14 -0800
Subject: [PATCH 2/2] fix missing class

---
 backend/onyx/configs/constants.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py
index b910923c008..9b7c66aea4b 100644
--- a/backend/onyx/configs/constants.py
+++ b/backend/onyx/configs/constants.py
@@ -315,6 +315,9 @@ class OnyxRedisSignals:
     )
     BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
     BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
+
+
+class OnyxRedisConstants:
     ACTIVE_FENCES = "active_fences"
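
Note on the fence lifecycle these patches introduce: the sketch below is illustrative only, not code from the patch. PrunePayload mirrors RedisConnectorPrunePayload, the key name and the fake celery task id are stand-ins, and it assumes pydantic v2 plus a Redis instance reachable on localhost.

# sketch_fence_lifecycle.py -- illustrative only; assumes a local Redis and pydantic v2
from datetime import datetime, timezone
from uuid import uuid4

import redis
from pydantic import BaseModel


class PrunePayload(BaseModel):
    # mirrors RedisConnectorPrunePayload from the patch
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None


r = redis.Redis()
fence_key = "connectorpruning_fence_42"  # hypothetical key for cc_pair 42

# producer side: set a "basic" fence BEFORE dispatching the celery task, so that
# a fence with celery_task_id=None means "dispatch in progress", never "no fence"
payload = PrunePayload(
    id=uuid4().hex[:8],
    submitted=datetime.now(timezone.utc),
    started=None,
    celery_task_id=None,
)
r.set(fence_key, payload.model_dump_json())

# ... celery_app.send_task(...) happens here; afterwards the real task id is filled in
payload.celery_task_id = "fake-celery-task-id"
r.set(fence_key, payload.model_dump_json())

# consumer side: the worker polls until celery_task_id is populated, mirroring the
# wait loop at the top of connector_pruning_generator_task
raw = r.get(fence_key)
assert raw is not None
fence = PrunePayload.model_validate_json(raw)
if fence.celery_task_id is None:
    print("fence not finalized yet; sleep and retry")
else:
    print(f"fence ready: task={fence.celery_task_id}")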
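Similarly, a sketch of the staleness decision validate_pruning_fence makes with the TTL'd active signal. touch_active and fence_looks_stale are hypothetical helpers, and the boolean inputs stand in for the real queue and reserved-task lookups against celery.

# sketch_active_signal.py -- illustrative only; assumes a local Redis
import redis

r = redis.Redis()

ACTIVE_TTL = 3600  # seconds, as in RedisConnectorPrune.ACTIVE_TTL
active_key = "connectorpruning_active_42"  # hypothetical key for cc_pair 42


def touch_active() -> None:
    # any observed progress refreshes the TTL'd liveness signal
    r.set(active_key, 0, ex=ACTIVE_TTL)


def fence_looks_stale(task_in_queue: bool, task_reserved: bool) -> bool:
    # decision order used by validate_pruning_fence: direct evidence first,
    # then the active signal as a grace period covering transition gaps
    if task_in_queue or task_reserved:
        touch_active()
        return False
    # no direct evidence; trust the signal until its TTL lapses
    return not bool(r.exists(active_key))


touch_active()
print(fence_looks_stale(task_in_queue=False, task_reserved=False))  # False: signal alive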