From 6ccb3f085a6f4fab33b79ef8585f2e16f5f9739d Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Wed, 5 Feb 2025 23:00:40 -0800 Subject: [PATCH] select only doc_id (#3920) * select only doc_id * select more doc ids * fix user group --------- Co-authored-by: Richard Kuo (Danswer) --- backend/ee/onyx/db/user_group.py | 4 +-- .../celery/tasks/connector_deletion/tasks.py | 13 ++++++---- backend/onyx/db/document.py | 26 +++++++++++++++++++ backend/onyx/db/document_set.py | 4 +-- .../redis/redis_connector_credential_pair.py | 26 +++++++++---------- backend/onyx/redis/redis_document_set.py | 20 +++++++------- backend/onyx/redis/redis_usergroup.py | 22 +++++++--------- 7 files changed, 71 insertions(+), 44 deletions(-) diff --git a/backend/ee/onyx/db/user_group.py b/backend/ee/onyx/db/user_group.py index 827cdcae559..c2a36d33086 100644 --- a/backend/ee/onyx/db/user_group.py +++ b/backend/ee/onyx/db/user_group.py @@ -218,14 +218,14 @@ def fetch_user_groups_for_user( return db_session.scalars(stmt).all() -def construct_document_select_by_usergroup( +def construct_document_id_select_by_usergroup( user_group_id: int, ) -> Select: """This returns a statement that should be executed using .yield_per() to minimize overhead. The primary consumers of this function are background processing task generators.""" stmt = ( - select(Document) + select(Document.id) .join( DocumentByConnectorCredentialPair, Document.id == DocumentByConnectorCredentialPair.id, diff --git a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py index 2b8b3f16070..003742f6b11 100644 --- a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py @@ -179,11 +179,14 @@ def try_generate_document_cc_pair_cleanup_tasks( if tasks_generated is None: raise ValueError("RedisConnectorDeletion.generate_tasks returned None") - insert_sync_record( - db_session=db_session, - entity_id=cc_pair_id, - sync_type=SyncType.CONNECTOR_DELETION, - ) + try: + insert_sync_record( + db_session=db_session, + entity_id=cc_pair_id, + sync_type=SyncType.CONNECTOR_DELETION, + ) + except Exception: + pass except TaskDependencyError: redis_connector.delete.set_fence(None) diff --git a/backend/onyx/db/document.py b/backend/onyx/db/document.py index 082ee3f94a9..6cad5868e7a 100644 --- a/backend/onyx/db/document.py +++ b/backend/onyx/db/document.py @@ -105,6 +105,32 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync( return stmt +def construct_document_id_select_for_connector_credential_pair_by_needs_sync( + connector_id: int, credential_id: int +) -> Select: + initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( + and_( + DocumentByConnectorCredentialPair.connector_id == connector_id, + DocumentByConnectorCredentialPair.credential_id == credential_id, + ) + ) + + stmt = ( + select(DbDocument.id) + .where( + DbDocument.id.in_(initial_doc_ids_stmt), + or_( + DbDocument.last_modified + > DbDocument.last_synced, # last_modified is newer than last_synced + DbDocument.last_synced.is_(None), # never synced + ), + ) + .distinct() + ) + + return stmt + + def get_all_documents_needing_vespa_sync_for_cc_pair( db_session: Session, cc_pair_id: int ) -> list[DbDocument]: diff --git a/backend/onyx/db/document_set.py b/backend/onyx/db/document_set.py index 0229682fddd..0f91cbc71b8 100644 --- a/backend/onyx/db/document_set.py +++ b/backend/onyx/db/document_set.py @@ -545,7 +545,7 @@ def fetch_documents_for_document_set_paginated( return documents, documents[-1].id if documents else None -def construct_document_select_by_docset( +def construct_document_id_select_by_docset( document_set_id: int, current_only: bool = True, ) -> Select: @@ -554,7 +554,7 @@ def construct_document_select_by_docset( are background processing task generators.""" stmt = ( - select(Document) + select(Document.id) .join( DocumentByConnectorCredentialPair, DocumentByConnectorCredentialPair.id == Document.id, diff --git a/backend/onyx/redis/redis_connector_credential_pair.py b/backend/onyx/redis/redis_connector_credential_pair.py index db8d526c0dc..3738b10b8a9 100644 --- a/backend/onyx/redis/redis_connector_credential_pair.py +++ b/backend/onyx/redis/redis_connector_credential_pair.py @@ -16,9 +16,8 @@ from onyx.configs.constants import OnyxRedisConstants from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.document import ( - construct_document_select_for_connector_credential_pair_by_needs_sync, + construct_document_id_select_for_connector_credential_pair_by_needs_sync, ) -from onyx.db.models import Document from onyx.redis.redis_object_helper import RedisObjectHelper @@ -72,7 +71,8 @@ def generate_tasks( last_lock_time = time.monotonic() - async_results = [] + num_tasks_sent = 0 + cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=int(self._id), @@ -80,14 +80,14 @@ def generate_tasks( if not cc_pair: return None - stmt = construct_document_select_for_connector_credential_pair_by_needs_sync( + stmt = construct_document_id_select_for_connector_credential_pair_by_needs_sync( cc_pair.connector_id, cc_pair.credential_id ) num_docs = 0 - for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): - doc = cast(Document, doc) + for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): + doc_id = cast(str, doc_id) current_time = time.monotonic() if current_time - last_lock_time >= ( CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 @@ -98,7 +98,7 @@ def generate_tasks( num_docs += 1 # check if we should skip the document (typically because it's already syncing) - if doc.id in self.skip_docs: + if doc_id in self.skip_docs: continue # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" @@ -114,21 +114,21 @@ def generate_tasks( ) # Priority on sync's triggered by new indexing should be medium - result = celery_app.send_task( + celery_app.send_task( OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, - kwargs=dict(document_id=doc.id, tenant_id=tenant_id), + kwargs=dict(document_id=doc_id, tenant_id=tenant_id), queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, priority=OnyxCeleryPriority.MEDIUM, ) - async_results.append(result) - self.skip_docs.add(doc.id) + num_tasks_sent += 1 + self.skip_docs.add(doc_id) - if len(async_results) >= max_tasks: + if num_tasks_sent >= max_tasks: break - return len(async_results), num_docs + return num_tasks_sent, num_docs class RedisGlobalConnectorCredentialPair: diff --git a/backend/onyx/redis/redis_document_set.py b/backend/onyx/redis/redis_document_set.py index c0c3ce2a0f3..6fd5a453b01 100644 --- a/backend/onyx/redis/redis_document_set.py +++ b/backend/onyx/redis/redis_document_set.py @@ -14,8 +14,7 @@ from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants -from onyx.db.document_set import construct_document_select_by_docset -from onyx.db.models import Document +from onyx.db.document_set import construct_document_id_select_by_docset from onyx.redis.redis_object_helper import RedisObjectHelper @@ -66,10 +65,11 @@ def generate_tasks( """ last_lock_time = time.monotonic() - async_results = [] - stmt = construct_document_select_by_docset(int(self._id), current_only=False) - for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): - doc = cast(Document, doc) + num_tasks_sent = 0 + + stmt = construct_document_id_select_by_docset(int(self._id), current_only=False) + for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): + doc_id = cast(str, doc_id) current_time = time.monotonic() if current_time - last_lock_time >= ( CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 @@ -86,17 +86,17 @@ def generate_tasks( # add to the set BEFORE creating the task. redis_client.sadd(self.taskset_key, custom_task_id) - result = celery_app.send_task( + celery_app.send_task( OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, - kwargs=dict(document_id=doc.id, tenant_id=tenant_id), + kwargs=dict(document_id=doc_id, tenant_id=tenant_id), queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, priority=OnyxCeleryPriority.LOW, ) - async_results.append(result) + num_tasks_sent += 1 - return len(async_results), len(async_results) + return num_tasks_sent, num_tasks_sent def reset(self) -> None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) diff --git a/backend/onyx/redis/redis_usergroup.py b/backend/onyx/redis/redis_usergroup.py index 88080685031..92ff5548c73 100644 --- a/backend/onyx/redis/redis_usergroup.py +++ b/backend/onyx/redis/redis_usergroup.py @@ -14,7 +14,6 @@ from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants -from onyx.db.models import Document from onyx.redis.redis_object_helper import RedisObjectHelper from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import global_version @@ -66,23 +65,22 @@ def generate_tasks( user group up to date over multiple batches. """ last_lock_time = time.monotonic() - - async_results = [] + num_tasks_sent = 0 if not global_version.is_ee_version(): return 0, 0 try: - construct_document_select_by_usergroup = fetch_versioned_implementation( + construct_document_id_select_by_usergroup = fetch_versioned_implementation( "onyx.db.user_group", - "construct_document_select_by_usergroup", + "construct_document_id_select_by_usergroup", ) except ModuleNotFoundError: return 0, 0 - stmt = construct_document_select_by_usergroup(int(self._id)) - for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): - doc = cast(Document, doc) + stmt = construct_document_id_select_by_usergroup(int(self._id)) + for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): + doc_id = cast(str, doc_id) current_time = time.monotonic() if current_time - last_lock_time >= ( CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 @@ -99,17 +97,17 @@ def generate_tasks( # add to the set BEFORE creating the task. redis_client.sadd(self.taskset_key, custom_task_id) - result = celery_app.send_task( + celery_app.send_task( OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, - kwargs=dict(document_id=doc.id, tenant_id=tenant_id), + kwargs=dict(document_id=doc_id, tenant_id=tenant_id), queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, priority=OnyxCeleryPriority.LOW, ) - async_results.append(result) + num_tasks_sent += 1 - return len(async_results), len(async_results) + return num_tasks_sent, num_tasks_sent def reset(self) -> None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)