Skip to content

Commit

Permalink
fix: improve thread safety in SQLite test sessions
Browse files Browse the repository at this point in the history
Co-Authored-By: Chris Weaver <chris@onyx.app>
  • Loading branch information
devin-ai-integration[bot] and Chris Weaver committed Jan 30, 2025
1 parent 22a10f7 commit 3d90b79
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions backend/tests/unit/onyx/db/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import sessionmaker

from onyx.auth.schemas import UserRole
from onyx.db.models import User
Expand All @@ -14,10 +14,17 @@

def _call_parallel(engine, email_list: List[str]) -> None:
# Create a new session for each thread
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
SessionLocal = sessionmaker(
bind=engine,
expire_on_commit=False,
autoflush=True,
)
session = SessionLocal()
try:
batch_add_ext_perm_user_if_not_exists(session, email_list)
except Exception as e:
session.rollback()
raise
finally:
session.close()

Expand Down Expand Up @@ -49,20 +56,26 @@ def test_batch_add_ext_perm_user_if_not_exists_concurrent(
for t in threads:
t.join()

# Verify results - should have exactly one user per unique email (case insensitive)
stmt = select(User).filter(
func.lower(User.email).in_([email.lower() for email in emails])
)
created_users = db_session.scalars(stmt).unique().all()
# Create a new session for verification
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
verify_session = SessionLocal()
try:
# Verify results - should have exactly one user per unique email (case insensitive)
stmt = select(User).filter(
func.lower(User.email).in_([email.lower() for email in emails])
)
created_users = verify_session.scalars(stmt).unique().all()

# Check total number of users (should be 2 since one email is a duplicate with different case)
assert len(created_users) == 2
# Check total number of users (should be 2 since one email is a duplicate with different case)
assert len(created_users) == 2

# Verify all users have the correct role
for user in created_users:
assert user.role == UserRole.EXT_PERM_USER
# Verify all users have the correct role
for user in created_users:
assert user.role == UserRole.EXT_PERM_USER

# Verify emails are present (case insensitive)
created_emails = [user.email.lower() for user in created_users]
assert "user1@example.com" in created_emails
assert "user2@example.com" in created_emails
# Verify emails are present (case insensitive)
created_emails = [user.email.lower() for user in created_users]
assert "user1@example.com" in created_emails
assert "user2@example.com" in created_emails
finally:
verify_session.close()

0 comments on commit 3d90b79

Please sign in to comment.