Skip to content

Commit

Permalink
fix: pass engine URL instead of engine to thread function
Browse files Browse the repository at this point in the history
Co-Authored-By: Chris Weaver <chris@onyx.app>
  • Loading branch information
devin-ai-integration[bot] and Chris Weaver committed Jan 30, 2025
1 parent 2ef78ec commit c41c21b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 22 deletions.
12 changes: 2 additions & 10 deletions backend/tests/unit/onyx/db/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,24 +85,16 @@ def db_session() -> Generator[Session, None, None]:
# Create all tables after type adaptation
Base.metadata.create_all(bind=engine)

# For SQLite, we need to create a new connection for each session
# to avoid threading issues
connection = engine.connect()
SessionLocal = sessionmaker(
bind=connection,
expire_on_commit=False, # Prevent detached instance errors
autoflush=True,
)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
session = SessionLocal()

try:
yield session
session.flush() # Make sure all SQL is executed
session.commit()
except:
session.rollback()
raise
finally:
session.close()
connection.close()
engine.dispose()
Base.metadata.drop_all(bind=engine)
21 changes: 9 additions & 12 deletions backend/tests/unit/onyx/db/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from typing import List

import pytest
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session
from sqlalchemy.orm.session import sessionmaker

Expand All @@ -12,23 +11,20 @@
from onyx.db.users import batch_add_ext_perm_user_if_not_exists


def _call_parallel(engine, email_list: List[str]) -> None:
# Create a new connection and session for each thread to handle SQLite's threading restrictions
connection = engine.connect()
SessionLocal = sessionmaker(
bind=connection,
expire_on_commit=False,
autoflush=True,
)
def _call_parallel(engine_url: str, email_list: List[str]) -> None:
# Create a new engine for each thread to handle SQLite's threading restrictions
thread_engine = create_engine(engine_url)
SessionLocal = sessionmaker(bind=thread_engine, expire_on_commit=False)
session = SessionLocal()
try:
batch_add_ext_perm_user_if_not_exists(session, email_list)
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
connection.close()
thread_engine.dispose()


@pytest.mark.parametrize(
Expand All @@ -49,8 +45,9 @@ def test_batch_add_ext_perm_user_if_not_exists_concurrent(
engine = db_session.get_bind()

# Create and start multiple threads that all try to add the same users
engine_url = str(engine.url)
for _ in range(thread_count):
t = threading.Thread(target=_call_parallel, args=(engine, emails))
t = threading.Thread(target=_call_parallel, args=(engine_url, emails))
threads.append(t)
t.start()

Expand Down

0 comments on commit c41c21b

Please sign in to comment.