Add more airtable logging #3862

Merged
merged 4 commits into from
Jan 31, 2025
6 changes: 6 additions & 0 deletions backend/onyx/configs/app_configs.py
@@ -478,6 +478,12 @@
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)

# Enable multi-threaded embedding model calls for parallel processing
# Note: only applies for API-based embedding models
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
)

# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
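Editor's note: a minimal sketch (not part of the diff; the values are hypothetical) of how this setting resolves — an unset or empty variable falls back to 1, which keeps embedding calls sequential:

import os

# Hypothetical value for illustration only.
os.environ["INDEXING_EMBEDDING_MODEL_NUM_THREADS"] = "8"
num_threads = int(os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1)
assert num_threads == 8

# Unset or empty falls back to the default of 1 (sequential embedding calls).
os.environ["INDEXING_EMBEDDING_MODEL_NUM_THREADS"] = ""
num_threads = int(os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1)
assert num_threads == 1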
51 changes: 42 additions & 9 deletions backend/onyx/connectors/airtable/airtable_connector.py
@@ -1,3 +1,5 @@
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any

@@ -274,6 +276,11 @@ def _process_record(
field_val = fields.get(field_name)
field_type = field_schema.type

logger.debug(
f"Processing field '{field_name}' of type '{field_type}' "
f"for record '{record_id}'."
)

field_sections, field_metadata = self._process_field(
field_id=field_schema.id,
field_name=field_name,
@@ -327,19 +334,45 @@ def load_from_state(self) -> GenerateDocumentsOutput:
primary_field_name = field.name
break

record_documents: list[Document] = []
for record in records:
document = self._process_record(
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
if document:
record_documents.append(document)
logger.info(f"Starting to process Airtable records for {table.name}.")

# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 16
max_workers = min(PARALLEL_BATCH_SIZE, len(records))

# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents: list[Document] = []

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record = {
executor.submit(
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
): record
for record in batch_records
}

# Wait for all tasks in this batch to complete
for future in as_completed(future_to_record):
record = future_to_record[future]
try:
document = future.result()
if document:
record_documents.append(document)
except Exception as e:
logger.exception(f"Failed to process record {record['id']}")
raise e

# After batch is complete, yield if we've hit the batch size
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []

# Yield any remaining records
if record_documents:
yield record_documents
13 changes: 12 additions & 1 deletion backend/onyx/connectors/connector_runner.py
@@ -1,4 +1,5 @@
import sys
import time
from datetime import datetime

from onyx.connectors.interfaces import BaseConnector
@@ -45,7 +46,17 @@ def __init__(
def run(self) -> GenerateDocumentsOutput:
"""Adds additional exception logging to the connector."""
try:
yield from self.doc_batch_generator
start = time.monotonic()
Contributor: In the past I've seen that items yielded from generators in a loop are lazily evaluated (i.e., not processed until first accessed in the loop). Just wanted to double check that the timings printed here are indeed working correctly.

Contributor Author: It does work, though I'm not sure exactly what the standard behavior is here.
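Editor's note: a minimal standalone sketch (not part of the PR) of why the timing holds — a generator's body runs lazily, only when the for loop requests the next item, so the elapsed time measured in the loop covers the work of building that batch:

import time

def slow_batches():
    # Stand-in for a connector's document batch generator.
    for i in range(3):
        time.sleep(0.5)  # simulate the work of building one batch
        yield [f"doc-{i}"]

start = time.monotonic()
for batch in slow_batches():
    # Each line prints roughly 0.5 seconds, because the sleep only runs
    # when the loop pulls the next batch from the generator.
    print(f"Batch took {time.monotonic() - start:.2f} seconds to build: {batch}")
    start = time.monotonic()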

for batch in self.doc_batch_generator:
# to know how long connector is taking
logger.debug(
f"Connector took {time.monotonic() - start} seconds to build a batch."
)

yield batch

start = time.monotonic()

except Exception:
exc_type, _, exc_traceback = sys.exc_info()

10 changes: 10 additions & 0 deletions backend/onyx/connectors/models.py
@@ -150,6 +150,16 @@ class Document(DocumentBase):
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
source: DocumentSource

def get_total_char_length(self) -> int:
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
section_length = sum(len(section.text) for section in self.sections)
identifier_length = len(self.semantic_identifier) + len(self.title or "")
metadata_length = sum(
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
for k, v in self.metadata.items()
)
return section_length + identifier_length + metadata_length

def to_short_descriptor(self) -> str:
"""Used when logging the identity of a document"""
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
9 changes: 9 additions & 0 deletions backend/onyx/indexing/indexing_pipeline.py
@@ -380,6 +380,15 @@ def index_doc_batch(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
)

doc_descriptors = [
{
"doc_id": doc.id,
"doc_length": doc.get_total_char_length(),
}
for doc in ctx.updatable_docs
]
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")

Contributor: Does this make the line below ("Starting chunking") irrelevant?

Contributor Author: IMO it's fine to keep it in; it lets you know that chunking is the first step, although I have no strong opinion.

logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)

50 changes: 45 additions & 5 deletions backend/onyx/natural_language_processing/search_nlp_models.py
@@ -1,6 +1,8 @@
import threading
import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any

@@ -11,6 +13,7 @@
from requests import Response
from retry import retry

from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
@@ -155,6 +158,7 @@ def _batch_encode_texts(
text_type: EmbedTextType,
batch_size: int,
max_seq_length: int,
num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
) -> list[Embedding]:
text_batches = batch_list(texts, batch_size)

@@ -163,12 +167,15 @@
)

embeddings: list[Embedding] = []
for idx, text_batch in enumerate(text_batches, start=1):

def process_batch(
batch_idx: int, text_batch: list[str]
) -> tuple[int, list[Embedding]]:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")

logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
@@ -185,10 +192,43 @@ def _batch_encode_texts(
)

response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
return batch_idx, response.embeddings

# only multi thread if:
# 1. num_threads is greater than 1
# 2. we are using an API-based embedding model (provider_type is not None)
# 3. there are more than 1 batch (no point in threading if only 1)
if num_threads > 1 and self.provider_type and len(text_batches) > 1:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
future_to_batch = {
executor.submit(process_batch, idx, batch): idx
for idx, batch in enumerate(text_batches, start=1)
}

# Collect results in order
batch_results: list[tuple[int, list[Embedding]]] = []
for future in as_completed(future_to_batch):
try:
result = future.result()
batch_results.append(result)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
except Exception as e:
logger.exception("Embedding model failed to process batch")
raise e

# Sort by batch index and extend embeddings
batch_results.sort(key=lambda x: x[0])
for _, batch_embeddings in batch_results:
embeddings.extend(batch_embeddings)
else:
# Original sequential processing
for idx, text_batch in enumerate(text_batches, start=1):
_, batch_embeddings = process_batch(idx, text_batch)
embeddings.extend(batch_embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)

if self.callback:
self.callback.progress("_batch_encode_texts", 1)
return embeddings

def encode(