select only doc_id (#3920)
* select only doc_id

* select more doc ids

* fix user group

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
rkuo-danswer and Richard Kuo (Danswer) authored Feb 6, 2025
1 parent a0a1b43 commit 6ccb3f0
Showing 7 changed files with 71 additions and 44 deletions.
4 changes: 2 additions & 2 deletions backend/ee/onyx/db/user_group.py
@@ -218,14 +218,14 @@ def fetch_user_groups_for_user(
     return db_session.scalars(stmt).all()


-def construct_document_select_by_usergroup(
+def construct_document_id_select_by_usergroup(
     user_group_id: int,
 ) -> Select:
     """This returns a statement that should be executed using
     .yield_per() to minimize overhead. The primary consumers of this function
     are background processing task generators."""
     stmt = (
-        select(Document)
+        select(Document.id)
         .join(
             DocumentByConnectorCredentialPair,
             Document.id == DocumentByConnectorCredentialPair.id,
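
The docstring above is the heart of the change: these statements are streamed with .yield_per(), so selecting Document.id yields plain string primary keys instead of hydrating a full ORM Document per row. A minimal consumption sketch, assuming an open SQLAlchemy Session; the import path, batch size, and handle() consumer are illustrative, not part of this commit:

from ee.onyx.db.user_group import construct_document_id_select_by_usergroup  # path assumed

BATCH_SIZE = 64  # stand-in for DB_YIELD_PER_DEFAULT

stmt = construct_document_id_select_by_usergroup(user_group_id=1)
# scalars() + yield_per streams the single selected column in batches;
# each row is now just a str id, and no Document instance is constructed
for doc_id in db_session.scalars(stmt).yield_per(BATCH_SIZE):
    handle(doc_id)  # hypothetical consumer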
13 changes: 8 additions & 5 deletions backend/onyx/background/celery/tasks/connector_deletion/tasks.py
@@ -179,11 +179,14 @@ def try_generate_document_cc_pair_cleanup_tasks(
         if tasks_generated is None:
             raise ValueError("RedisConnectorDeletion.generate_tasks returned None")

-        insert_sync_record(
-            db_session=db_session,
-            entity_id=cc_pair_id,
-            sync_type=SyncType.CONNECTOR_DELETION,
-        )
+        try:
+            insert_sync_record(
+                db_session=db_session,
+                entity_id=cc_pair_id,
+                sync_type=SyncType.CONNECTOR_DELETION,
+            )
+        except Exception:
+            pass

     except TaskDependencyError:
         redis_connector.delete.set_fence(None)
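
Wrapping insert_sync_record in try/except makes the sync-record bookkeeping best-effort: a failure there no longer aborts cleanup task generation, while the outer TaskDependencyError handler stays unchanged. The commit swallows the exception silently; a sketch of the same pattern with a log line, purely as an illustration (logger setup assumed, not from this commit):

import logging

logger = logging.getLogger(__name__)

try:
    insert_sync_record(
        db_session=db_session,
        entity_id=cc_pair_id,
        sync_type=SyncType.CONNECTOR_DELETION,
    )
except Exception:
    # bookkeeping only; task generation should proceed regardless
    logger.exception("insert_sync_record failed; continuing cleanup")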
26 changes: 26 additions & 0 deletions backend/onyx/db/document.py
@@ -105,6 +105,32 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
     return stmt


+def construct_document_id_select_for_connector_credential_pair_by_needs_sync(
+    connector_id: int, credential_id: int
+) -> Select:
+    initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
+        and_(
+            DocumentByConnectorCredentialPair.connector_id == connector_id,
+            DocumentByConnectorCredentialPair.credential_id == credential_id,
+        )
+    )
+
+    stmt = (
+        select(DbDocument.id)
+        .where(
+            DbDocument.id.in_(initial_doc_ids_stmt),
+            or_(
+                DbDocument.last_modified
+                > DbDocument.last_synced,  # last_modified is newer than last_synced
+                DbDocument.last_synced.is_(None),  # never synced
+            ),
+        )
+        .distinct()
+    )
+
+    return stmt
+
+
 def get_all_documents_needing_vespa_sync_for_cc_pair(
     db_session: Session, cc_pair_id: int
 ) -> list[DbDocument]:
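
For intuition, the new statement should compile to roughly a DISTINCT id select with an IN subquery plus the needs-sync predicate. A quick sanity-check sketch; the table names in the comment are inferred from the model names and the exact SQL varies by dialect and SQLAlchemy version, so treat all of it as an assumption:

from onyx.db.document import (
    construct_document_id_select_for_connector_credential_pair_by_needs_sync,
)

stmt = construct_document_id_select_for_connector_credential_pair_by_needs_sync(
    connector_id=1, credential_id=1
)
# Printing the statement renders something shaped like:
#   SELECT DISTINCT document.id FROM document
#   WHERE document.id IN (SELECT ... WHERE connector_id = :p AND credential_id = :p)
#     AND (document.last_modified > document.last_synced
#          OR document.last_synced IS NULL)
print(stmt)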
4 changes: 2 additions & 2 deletions backend/onyx/db/document_set.py
@@ -545,7 +545,7 @@ def fetch_documents_for_document_set_paginated(
     return documents, documents[-1].id if documents else None


-def construct_document_select_by_docset(
+def construct_document_id_select_by_docset(
     document_set_id: int,
     current_only: bool = True,
 ) -> Select:
@@ -554,7 +554,7 @@ def construct_document_select_by_docset(
     are background processing task generators."""

     stmt = (
-        select(Document)
+        select(Document.id)
         .join(
             DocumentByConnectorCredentialPair,
             DocumentByConnectorCredentialPair.id == Document.id,
26 changes: 13 additions & 13 deletions backend/onyx/redis/redis_connector_credential_pair.py
@@ -16,9 +16,8 @@
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
 from onyx.db.document import (
-    construct_document_select_for_connector_credential_pair_by_needs_sync,
+    construct_document_id_select_for_connector_credential_pair_by_needs_sync,
 )
-from onyx.db.models import Document
 from onyx.redis.redis_object_helper import RedisObjectHelper
@@ -72,22 +71,23 @@ def generate_tasks(

         last_lock_time = time.monotonic()

-        async_results = []
+        num_tasks_sent = 0

         cc_pair = get_connector_credential_pair_from_id(
             db_session=db_session,
             cc_pair_id=int(self._id),
         )
         if not cc_pair:
             return None

-        stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
+        stmt = construct_document_id_select_for_connector_credential_pair_by_needs_sync(
             cc_pair.connector_id, cc_pair.credential_id
         )

         num_docs = 0

-        for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
-            doc = cast(Document, doc)
+        for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
+            doc_id = cast(str, doc_id)
             current_time = time.monotonic()
             if current_time - last_lock_time >= (
                 CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@@ -98,7 +98,7 @@
             num_docs += 1

             # check if we should skip the document (typically because it's already syncing)
-            if doc.id in self.skip_docs:
+            if doc_id in self.skip_docs:
                 continue
@@ -114,21 +114,21 @@
             )

             # Priority on sync's triggered by new indexing should be medium
-            result = celery_app.send_task(
+            celery_app.send_task(
                 OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
-                kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
+                kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
                 queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
                 task_id=custom_task_id,
                 priority=OnyxCeleryPriority.MEDIUM,
             )

-            async_results.append(result)
-            self.skip_docs.add(doc.id)
+            num_tasks_sent += 1
+            self.skip_docs.add(doc_id)

-            if len(async_results) >= max_tasks:
+            if num_tasks_sent >= max_tasks:
                 break

-        return len(async_results), num_docs
+        return num_tasks_sent, num_docs


 class RedisGlobalConnectorCredentialPair:
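
The async_results list was only ever used for its length, but it kept every celery AsyncResult handle alive for the whole generation pass; a counter returns the same information with constant memory. The dispatch pattern, reduced to a sketch (the task name and the doc_ids iterable are placeholders, not the commit's identifiers):

num_tasks_sent = 0
for doc_id in doc_ids:  # any iterable of document id strings
    celery_app.send_task(
        "vespa_metadata_sync",  # placeholder task name
        kwargs=dict(document_id=doc_id),
    )
    num_tasks_sent += 1  # count sends; never retain the AsyncResult
    if num_tasks_sent >= max_tasks:
        break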
20 changes: 10 additions & 10 deletions backend/onyx/redis/redis_document_set.py
@@ -14,8 +14,7 @@
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
-from onyx.db.document_set import construct_document_select_by_docset
-from onyx.db.models import Document
+from onyx.db.document_set import construct_document_id_select_by_docset
 from onyx.redis.redis_object_helper import RedisObjectHelper
@@ -66,10 +65,11 @@ def generate_tasks(
         """
         last_lock_time = time.monotonic()

-        async_results = []
-        stmt = construct_document_select_by_docset(int(self._id), current_only=False)
-        for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
-            doc = cast(Document, doc)
+        num_tasks_sent = 0
+
+        stmt = construct_document_id_select_by_docset(int(self._id), current_only=False)
+        for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
+            doc_id = cast(str, doc_id)
             current_time = time.monotonic()
             if current_time - last_lock_time >= (
                 CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@@ -86,17 +86,17 @@
             # add to the set BEFORE creating the task.
             redis_client.sadd(self.taskset_key, custom_task_id)

-            result = celery_app.send_task(
+            celery_app.send_task(
                 OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
-                kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
+                kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
                 queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
                 task_id=custom_task_id,
                 priority=OnyxCeleryPriority.LOW,
             )

-            async_results.append(result)
+            num_tasks_sent += 1

-        return len(async_results), len(async_results)
+        return num_tasks_sent, num_tasks_sent

     def reset(self) -> None:
         self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
22 changes: 10 additions & 12 deletions backend/onyx/redis/redis_usergroup.py
@@ -14,7 +14,6 @@
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
-from onyx.db.models import Document
 from onyx.redis.redis_object_helper import RedisObjectHelper
 from onyx.utils.variable_functionality import fetch_versioned_implementation
 from onyx.utils.variable_functionality import global_version
@@ -66,23 +65,22 @@ def generate_tasks(
         user group up to date over multiple batches.
         """
         last_lock_time = time.monotonic()
-
-        async_results = []
+        num_tasks_sent = 0

         if not global_version.is_ee_version():
             return 0, 0

         try:
-            construct_document_select_by_usergroup = fetch_versioned_implementation(
+            construct_document_id_select_by_usergroup = fetch_versioned_implementation(
                 "onyx.db.user_group",
-                "construct_document_select_by_usergroup",
+                "construct_document_id_select_by_usergroup",
             )
         except ModuleNotFoundError:
             return 0, 0

-        stmt = construct_document_select_by_usergroup(int(self._id))
-        for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
-            doc = cast(Document, doc)
+        stmt = construct_document_id_select_by_usergroup(int(self._id))
+        for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
+            doc_id = cast(str, doc_id)
             current_time = time.monotonic()
             if current_time - last_lock_time >= (
                 CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@@ -99,17 +97,17 @@
             # add to the set BEFORE creating the task.
             redis_client.sadd(self.taskset_key, custom_task_id)

-            result = celery_app.send_task(
+            celery_app.send_task(
                 OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
-                kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
+                kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
                 queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
                 task_id=custom_task_id,
                 priority=OnyxCeleryPriority.LOW,
             )

-            async_results.append(result)
+            num_tasks_sent += 1

-        return len(async_results), len(async_results)
+        return num_tasks_sent, num_tasks_sent

     def reset(self) -> None:
         self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
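
The usergroup path has to look its select helper up through fetch_versioned_implementation because the helper lives in the EE-only module ee/onyx/db/user_group.py (the first diff in this commit). A rough sketch of the lookup semantics the caller relies on, resolving an attribute from the EE variant of a module and surfacing ModuleNotFoundError when the EE package is absent; this is a guess at the behavior, not the actual Onyx helper:

import importlib
from typing import Any

def fetch_versioned_implementation_sketch(module: str, attribute: str) -> Any:
    # e.g. module="onyx.db.user_group" resolves to ee.onyx.db.user_group;
    # raises ModuleNotFoundError if the EE tree is not installed, which the
    # calling code above catches and turns into a (0, 0) no-op
    ee_module = importlib.import_module("ee." + module)
    return getattr(ee_module, attribute)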
