Commit
Add multithreading
Weves committed Jan 31, 2025
1 parent f929603 commit 8a8c048
Showing 4 changed files with 86 additions and 20 deletions.
6 changes: 6 additions & 0 deletions backend/onyx/configs/app_configs.py
@@ -478,6 +478,12 @@
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)

# Enable multi-threaded embedding model calls for parallel processing
# Note: only applies for API-based embedding models
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
)

# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
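The new setting follows the same pattern as the other knobs in this file: an environment variable parsed with an `or <default>` fallback. A minimal sketch (illustrative only, not part of the commit) of how the value resolves: unset or empty falls back to 1, i.e. sequential embedding calls, while e.g. "4" allows four concurrent requests.

import os

# Unset (or empty) variable falls back to 1, i.e. sequential embedding calls.
os.environ.pop("INDEXING_EMBEDDING_MODEL_NUM_THREADS", None)
num_threads = int(os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1)
assert num_threads == 1

# Setting it to "4" would allow up to four embedding requests in flight at once.
os.environ["INDEXING_EMBEDDING_MODEL_NUM_THREADS"] = "4"
num_threads = int(os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1)
assert num_threads == 4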
48 changes: 36 additions & 12 deletions backend/onyx/connectors/airtable/airtable_connector.py
@@ -1,3 +1,5 @@
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any

@@ -312,7 +314,7 @@ def _process_record(

def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from the table.
Fetch all records from the table in parallel batches.
NOTE: Airtable does not support filtering by time updated, so
we have to fetch all records every time.
@@ -334,21 +336,43 @@ def load_from_state(self) -> GenerateDocumentsOutput:

logger.info(f"Starting to process Airtable records for {table.name}.")

record_documents: list[Document] = []
for record in records:
logger.info(f"Processing record {record['id']} of {table.name}.")

document = self._process_record(
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
if document:
record_documents.append(document)
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 16
max_workers = min(PARALLEL_BATCH_SIZE, len(records))

# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents: list[Document] = []

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record = {
executor.submit(
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
): record
for record in batch_records
}

# Wait for all tasks in this batch to complete
for future in as_completed(future_to_record):
record = future_to_record[future]
try:
document = future.result()
if document:
record_documents.append(document)
except Exception as e:
logger.exception(f"Failed to process record {record['id']}")
raise e

# After batch is complete, yield if we've hit the batch size
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []

# Yield any remaining records
if record_documents:
yield record_documents
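Stripped of the connector specifics, the change above introduces a batch-and-fan-out pattern: slice the records into fixed-size groups, submit each group to a ThreadPoolExecutor, collect results as the futures complete, and yield whenever enough documents have accumulated. A minimal, self-contained sketch of that pattern with hypothetical names (process_in_batches and process_one are not the connector's actual API):

from concurrent.futures import ThreadPoolExecutor, as_completed

def process_in_batches(records, process_one, parallel_batch_size=16, yield_size=8):
    # Illustrative only: fan each slice of records out to worker threads,
    # collect results as futures finish, and yield once enough have accumulated.
    buffered = []
    for i in range(0, len(records), parallel_batch_size):
        batch = records[i : i + parallel_batch_size]
        with ThreadPoolExecutor(max_workers=min(parallel_batch_size, len(batch))) as executor:
            futures = {executor.submit(process_one, record): record for record in batch}
            for future in as_completed(futures):
                result = future.result()  # re-raises any worker exception
                if result is not None:
                    buffered.append(result)
        if len(buffered) >= yield_size:
            yield buffered
            buffered = []
    if buffered:
        yield buffered

# Usage sketch: square the even numbers in parallel, in batches of four.
for docs in process_in_batches(list(range(10)), lambda r: r * r if r % 2 == 0 else None, 4, 3):
    print(docs)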
6 changes: 3 additions & 3 deletions backend/onyx/connectors/connector_runner.py
@@ -49,14 +49,14 @@ def run(self) -> GenerateDocumentsOutput:
start = time.monotonic()
for batch in self.doc_batch_generator:
# to know how long connector is taking
end = time.monotonic()
logger.debug(
f"Connector tool in {end - start} seconds to build a batch."
f"Connector took {time.monotonic() - start} seconds to build a batch."
)
start = end

yield batch

start = time.monotonic()

except Exception:
exc_type, _, exc_traceback = sys.exc_info()

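The change above is a timing fix rather than a behavior change: a generator is suspended at yield, so if the stopwatch were restarted before yielding, the elapsed time for the next batch would also include whatever the consumer did with the previous one. Restarting the clock after the yield confines the measurement to the connector's own work. A minimal illustration with hypothetical names (timed_batches, fake_connector), not the actual runner:

import time

def timed_batches(doc_batch_generator):
    # Illustrative only: a generator is suspended at `yield`, so restarting the
    # clock after the yield keeps consumer-side work out of the measurement.
    start = time.monotonic()
    for batch in doc_batch_generator:
        print(f"Connector took {time.monotonic() - start:.3f} seconds to build a batch.")
        yield batch
        start = time.monotonic()  # restarted only once the consumer asks for more

def fake_connector():
    for i in range(2):
        time.sleep(0.1)  # pretend batch construction takes ~100 ms
        yield [f"doc-{i}"]

for batch in timed_batches(fake_connector()):
    time.sleep(0.5)  # consumer-side work: not attributed to the connector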
46 changes: 41 additions & 5 deletions backend/onyx/natural_language_processing/search_nlp_models.py
@@ -1,6 +1,8 @@
import threading
import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any

@@ -11,6 +13,7 @@
from requests import Response
from retry import retry

from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
@@ -155,6 +158,7 @@ def _batch_encode_texts(
text_type: EmbedTextType,
batch_size: int,
max_seq_length: int,
num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
) -> list[Embedding]:
text_batches = batch_list(texts, batch_size)

@@ -163,12 +167,15 @@
)

embeddings: list[Embedding] = []
for idx, text_batch in enumerate(text_batches, start=1):

def process_batch(
batch_idx: int, text_batch: list[str]
) -> tuple[int, list[Embedding]]:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")

logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
@@ -185,10 +192,39 @@ def _batch_encode_texts(
)

response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
return batch_idx, response.embeddings

if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
future_to_batch = {
executor.submit(process_batch, idx, batch): idx
for idx, batch in enumerate(text_batches, start=1)
}

# Collect results in order
batch_results: list[tuple[int, list[Embedding]]] = []
for future in as_completed(future_to_batch):
try:
result = future.result()
batch_results.append(result)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
except Exception as e:
logger.exception("Embedding model failed to process batch")
raise e

# Sort by batch index and extend embeddings
batch_results.sort(key=lambda x: x[0])
for _, batch_embeddings in batch_results:
embeddings.extend(batch_embeddings)
else:
# Original sequential processing
for idx, text_batch in enumerate(text_batches, start=1):
_, batch_embeddings = process_batch(idx, text_batch)
embeddings.extend(batch_embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)

if self.callback:
self.callback.progress("_batch_encode_texts", 1)
return embeddings

def encode(
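The key detail in the threaded path above is ordering: as_completed yields futures in completion order, so each result is tagged with its batch index and the results are sorted before the embeddings are flattened, keeping them aligned with the input texts. A minimal sketch of that submit, collect, and sort pattern with a hypothetical encode_batch callable (not the real model-server client):

from concurrent.futures import ThreadPoolExecutor, as_completed

def encode_batches_in_parallel(text_batches, encode_batch, num_threads=4):
    # Illustrative only: run each batch in a worker thread, tag results with the
    # batch index, then sort so the flattened embeddings keep the input order.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        future_to_idx = {
            executor.submit(encode_batch, batch): idx
            for idx, batch in enumerate(text_batches, start=1)
        }
        indexed_results = []
        for future in as_completed(future_to_idx):
            indexed_results.append((future_to_idx[future], future.result()))

    indexed_results.sort(key=lambda pair: pair[0])  # completion order != submission order
    embeddings = []
    for _, batch_embeddings in indexed_results:
        embeddings.extend(batch_embeddings)
    return embeddings

# Usage sketch with a fake "embedder" that maps each text to a one-dimensional vector.
print(encode_batches_in_parallel([["a", "bb"], ["ccc"]], lambda batch: [[float(len(t))] for t in batch]))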
