Strict Tenant ID Enforcement #3871

Open: wants to merge 14 commits into base: main

Contributor

@pablonyx pablonyx commented Feb 1, 2025

Description

fixes https://linear.app/danswer/issue/DAN-1390/clean-up-of-tenant-id-assumption

  1. New default behavior / assumptions

    • When a request comes in, the HTTP middleware sets CURRENT_TENANT_ID_CONTEXTVAR to the correct tenant ID.
    • In Celery tasks, we override the Task class with TenantAwareTask. When you pass tenant_id=... in kwargs, the task automatically sets the contextvar before running.
    • This means that, most of the time, you do not need to pass tenant_id around yourself. Just call:
      with get_session_with_current_tenant() as db_session:
          ...
      or
      r = get_redis_client()
      and it will pick up the ID from the contextvar.
  2. Shared utils

    • For truly global logic not tied to any tenant, use these dedicated helpers:
      • get_session_with_shared_schema(): forces usage of the “public” schema in Postgres.
      • get_shared_redis_client() / get_shared_redis_replica_client(): skip per-tenant prefixes in Redis.
      • get_shared_kv_store(): global K/V store that ignores the per-tenant context.
  3. Overriding Tenant ID

    • If you’re doing multi-tenant “manager” work, you can explicitly pass a tenant ID:
      with get_session_with_current_tenant("some_tenant") as db_session:
          ...
    • This directly sets the contextvar to some_tenant for the duration of that block.
  4. Background processes

    • Celery uses TenantAwareTask:
      @shared_task(base=TenantAwareTask)
      def my_task(*, tenant_id: str):
          # tenant_id is picked up and set into the contextvar automatically
          ...
    • When you call my_task.delay(tenant_id="..."), TenantAwareTask sets that ID in the contextvar before running the task body.
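
For readers who want to see the mechanism in isolation, here is a minimal, stdlib-only sketch of the pattern described above. It is not the code in this PR: TENANT_ID, tenant_scope, tenant_aware, and describe_current_tenant are hypothetical names, and the decorator is only a stand-in for what a Celery Task subclass would do before running the task body.

```python
import contextvars
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Optional

# Hypothetical stand-in for CURRENT_TENANT_ID_CONTEXTVAR.
TENANT_ID: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "tenant_id", default=None
)


@contextmanager
def tenant_scope(tenant_id: str) -> Iterator[None]:
    """Set the tenant ID for the duration of the block, then restore the old value."""
    token = TENANT_ID.set(tenant_id)
    try:
        yield
    finally:
        TENANT_ID.reset(token)


def tenant_aware(func: Callable[..., Any]) -> Callable[..., Any]:
    """Stand-in for a tenant-aware task base: set the contextvar from the
    tenant_id keyword argument before the wrapped function runs."""

    def wrapper(*args: Any, tenant_id: str, **kwargs: Any) -> Any:
        with tenant_scope(tenant_id):
            return func(*args, tenant_id=tenant_id, **kwargs)

    return wrapper


def describe_current_tenant() -> str:
    """A helper that reads the tenant from context instead of taking an argument."""
    return f"tenant={TENANT_ID.get()}"


@tenant_aware
def my_task(*, tenant_id: str) -> str:
    # No tenant_id is passed down; the helper picks it up from the contextvar.
    return describe_current_tenant()
```

Resetting with the token in a finally block means the previous value comes back even if the body raises, which is what keeps one tenant's ID from leaking into later work on the same worker.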

How Has This Been Tested?

[Describe the tests you ran to verify your changes]

Backporting (check the box to trigger backport action)

Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.

  • This PR should be backported (make sure to check that the backport attempt succeeds)
  • [Optional] Override Linear Check


vercel bot commented Feb 1, 2025

The latest updates on your projects. Learn more about Vercel for Git ↗︎

Name Status Preview Comments Updated (UTC)
internal-search ✅ Ready (Inspect) Visit Preview 💬 Add feedback Feb 6, 2025 6:35am

Contributor

@Weves Weves left a comment

As discussed, let's have a separate function for accessing a Postgres / Redis session with the public tenant, so that we don't overload get_session_with_tenant and similar helpers. Let's also remove the tenant_id override option to enforce consistency.

)


"""Utils related to contextvars"""


-def get_current_tenant_id() -> str:
+def current_tenant_id(strict: bool = True) -> str:
Contributor

nit: prefer get_current_tenant_id for consistency. We usually give functions a get / fetch prefix (i.e. make them verbs). For example, get_llm / get_session / get_redis_client.

Contributor

also is there ever a reason to pass in strict=False?

Contributor Author

Definitely agree in the abstract. I didn't name it that originally since we already have a dependency called get_current_tenant_id, but with this refactor I think it makes sense to remove that function anyway.

CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
    str | None
] = contextvars.ContextVar(
    "current_tenant_id", default=None if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
Contributor

thoughts on this always being None and enforcing that it be set on both multi and single tenant to bring behavior closer together? E.g. middleware always sets this to POSTGRES_DEFAULT_SCHEMA.

Not sure if this is better, but just a thought.
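
To make the suggestion concrete, a rough sketch of what "always None, always set" could look like is below. This is an illustration only, not code from the PR: CURRENT_TENANT, DEFAULT_SCHEMA, and run_request are hypothetical names, and run_request stands in for whatever the real middleware does.

```python
import contextvars
from typing import Optional

# Hypothetical names; the default is None in both single- and multi-tenant mode.
CURRENT_TENANT: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "current_tenant", default=None
)
DEFAULT_SCHEMA = "public"  # assumption: stands in for POSTGRES_DEFAULT_SCHEMA


def get_current_tenant_id() -> str:
    """Strict getter: fail loudly if nothing upstream set the tenant ID."""
    tenant_id = CURRENT_TENANT.get()
    if tenant_id is None:
        raise RuntimeError("Tenant ID is not set in the current context")
    return tenant_id


def run_request(multi_tenant: bool, tenant_from_request: Optional[str]) -> str:
    """Hypothetical middleware entry point that always sets the contextvar."""
    tenant_id = tenant_from_request if multi_tenant else DEFAULT_SCHEMA
    if tenant_id is None:
        raise ValueError("Multi-tenant request without a tenant ID")
    token = CURRENT_TENANT.set(tenant_id)
    try:
        # The request handler would run here; we just read the value back.
        return get_current_tenant_id()
    finally:
        CURRENT_TENANT.reset(token)
```

With this shape, a code path that forgets to set the tenant fails the same way in single-tenant and multi-tenant deployments, instead of silently falling back to the default schema in one of them.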

@@ -21,7 +21,7 @@
 def perform_ttl_management_task(
     retention_limit_days: int, *, tenant_id: str | None
 ) -> None:
-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_current_tenant(tenant_id) as db_session:
Contributor

I don't believe "current tenant" adds any semantic meaning to the function name when we're passing in a tenant ID. If the function takes a tenant ID, it doesn't care where that ID came from.

@@ -19,6 +19,7 @@

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.beat")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
Contributor

Monkey-patching Task seems too invasive, as there are already built-in ways to subclass Task behavior. Wouldn't something like https://docs.celeryq.dev/en/stable/userguide/tasks.html#task-inheritance accomplish this just fine?

@@ -68,12 +69,15 @@
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
    """a lightweight task used to kick off indexing tasks.
    Occasionally does some validation of existing state to clear up error conditions"""
    print(f"TENANT ID, {tenant_id}")
Contributor

debug logging?


with Session(bind=connection, expire_on_commit=False) as session:
    try:
        yield session
Contributor

I've seen some InFailedSqlTransaction errors involving these search_path operations, and I'm wondering if we might need a rollback somewhere. I don't have an exact direction on this, but it might be worth investigating.
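
As a starting point for that investigation, here is the general shape of a session wrapper that rolls back before the connection is reused. This is only a sketch of the pattern: it uses the stdlib sqlite3 module as a stand-in for Postgres / SQLAlchemy, session_scope is a hypothetical name, and whether a missing rollback is actually the cause of the errors above is an open question.

```python
import sqlite3
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def session_scope(conn: sqlite3.Connection) -> Iterator[sqlite3.Connection]:
    """Yield the connection; on any error, roll back before re-raising so the
    next user of the connection does not inherit a half-finished transaction."""
    try:
        yield conn
    except Exception:
        conn.rollback()
        raise


def demo() -> int:
    """Insert a row, fail inside the block, and return the surviving row count."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (x INTEGER)")
    conn.commit()
    try:
        with session_scope(conn) as session:
            session.execute("INSERT INTO t VALUES (1)")
            raise ValueError("simulated failure mid-transaction")
    except ValueError:
        pass
    return conn.execute("SELECT COUNT(*) FROM t").fetchone()[0]
```

In the Postgres case, the equivalent rollback would also need to cover the statement that sets search_path, since a failure there leaves the transaction in an aborted state until it is rolled back.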

3 participants