[DOP-19784] - add search_vector to Transfer model

maxim-lixakov committed Oct 1, 2024
1 parent 2967c7f commit 9372298
Showing 7 changed files with 170 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/changelog/next_release/92.feature.rst
@@ -0,0 +1 @@
Add full-text search for **transfers**
6 changes: 6 additions & 0 deletions syncmaster/backend/api/v1/transfers.py
@@ -45,6 +45,11 @@ async def read_transfers(
group_id: int,
page: int = Query(gt=0, default=1),
page_size: int = Query(gt=0, le=200, default=20),
search_query: str | None = Query(
None,
title="Search Query",
description="full-text search for transfers",
),
current_user: User = Depends(get_user(is_active=True)),
unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
) -> TransferPageSchema:
@@ -61,6 +66,7 @@ async def read_transfers(
page=page,
page_size=page_size,
group_id=group_id,
search_query=search_query,
)

return TransferPageSchema.from_pagination(pagination=pagination)
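With this change, clients can pass search_query alongside the existing pagination parameters. A minimal sketch of calling the updated endpoint (the base URL, token, group_id, and search string below are placeholders, not part of this commit):

import asyncio

import httpx


async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        response = await client.get(
            "v1/transfers",
            headers={"Authorization": "Bearer <access-token>"},
            params={"group_id": 1, "search_query": "daily_sales"},
        )
        response.raise_for_status()
        # Results matching the full-text query, ranked by relevance
        for transfer in response.json()["items"]:
            print(transfer["id"], transfer["name"])


asyncio.run(main())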
55 changes: 55 additions & 0 deletions syncmaster/db/migrations/versions/2024-09-30_b9f5c4315bb2_.py
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
"""add search_vector with GIN indexed to Transfer table
Revision ID: b9f5c4315bb2
Revises: 478240cdad4b
Create Date: 2024-09-30 14:25:37.264273
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "b9f5c4315bb2"
down_revision = "478240cdad4b"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
sql_expression = (
"to_tsvector('english'::regconfig, "
"name || ' ' || "
"translate(name, './', ' ') || ' ' || "
"COALESCE(json_extract_path_text(source_params, 'table_name'), '') || ' ' || "
"COALESCE(json_extract_path_text(target_params, 'table_name'), '') || ' ' || "
"COALESCE(json_extract_path_text(source_params, 'directory_path'), '') || ' ' || "
"COALESCE(json_extract_path_text(target_params, 'directory_path'), '') || ' ' || "
"COALESCE(translate(json_extract_path_text(source_params, 'table_name'), './', ' '), '') || ' ' || "
"COALESCE(translate(json_extract_path_text(target_params, 'table_name'), './', ' '), '') || ' ' || "
"COALESCE(translate(json_extract_path_text(source_params, 'directory_path'), './', ' '), '') || ' ' || "
"COALESCE(translate(json_extract_path_text(target_params, 'directory_path'), './', ' '), '')"
")"
)

op.add_column(
"transfer",
sa.Column(
"search_vector",
postgresql.TSVECTOR(),
sa.Computed(sql_expression, persisted=True),
nullable=False,
),
)
op.create_index("idx_transfer_search_vector", "transfer", ["search_vector"], unique=False, postgresql_using="gin")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("idx_transfer_search_vector", table_name="transfer", postgresql_using="gin")
op.drop_column("transfer", "search_vector")
# ### end Alembic commands ###
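Once the migration is applied (e.g. alembic upgrade head), the generated column is populated for existing rows and the GIN index serves @@ matches. A quick verification sketch with SQLAlchemy, assuming a reachable database (the connection URL and query string are placeholders):

import sqlalchemy as sa

engine = sa.create_engine("postgresql+psycopg2://user:password@localhost/syncmaster")  # placeholder URL

with engine.connect() as conn:
    # plainto_tsquery ANDs the words together; rows whose computed
    # search_vector contains all of them are returned via the GIN index.
    rows = conn.execute(
        sa.text(
            "SELECT id, name FROM transfer "
            "WHERE search_vector @@ plainto_tsquery('english', :q)"
        ),
        {"q": "daily sales"},
    )
    for row in rows:
        print(row.id, row.name)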
32 changes: 31 additions & 1 deletion syncmaster/db/models.py
@@ -10,12 +10,15 @@
JSON,
BigInteger,
Boolean,
Computed,
DateTime,
ForeignKey,
Index,
PrimaryKeyConstraint,
String,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
from sqlalchemy_utils import ChoiceType

@@ -129,9 +132,36 @@ class Transfer(
target_connection: Mapped[Connection] = relationship(foreign_keys=target_connection_id)
queue: Mapped[Queue] = relationship(back_populates="transfers")

search_vector: Mapped[str] = mapped_column(
TSVECTOR,
Computed(
"""
to_tsvector(
'english'::regconfig,
name || ' ' ||
translate(name, './', ' ') || ' ' ||
COALESCE(json_extract_path_text(source_params, 'table_name'), '') || ' ' ||
COALESCE(json_extract_path_text(target_params, 'table_name'), '') || ' ' ||
COALESCE(json_extract_path_text(source_params, 'directory_path'), '') || ' ' ||
COALESCE(json_extract_path_text(target_params, 'directory_path'), '') || ' ' ||
COALESCE(translate(json_extract_path_text(source_params, 'table_name'), './', ' '), '') || ' ' ||
COALESCE(translate(json_extract_path_text(target_params, 'table_name'), './', ' '), '') || ' ' ||
COALESCE(translate(json_extract_path_text(source_params, 'directory_path'), './', ' '), '') || ' ' ||
COALESCE(translate(json_extract_path_text(target_params, 'directory_path'), './', ' '), '')
)
""",
persisted=True,
),
nullable=False,
deferred=True,
)

@declared_attr
def __table_args__(cls) -> tuple:
return (UniqueConstraint("name", "group_id"),)
return (
UniqueConstraint("name", "group_id"),
Index("idx_transfer_search_vector", "search_vector", postgresql_using="gin"),
)


class Status(enum.StrEnum):
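The expression concatenates both the raw values and translate()d copies (with '.' and '/' replaced by spaces) because PostgreSQL's text-search parser keeps dotted and slashed strings as single lexemes (host and file tokens), so a search for 'table' alone would not match 'schema.table'. A small illustration, runnable against any PostgreSQL connection (the engine URL is a placeholder):

from sqlalchemy import create_engine, func, select

engine = create_engine("postgresql+psycopg2://user:password@localhost/postgres")  # placeholder URL

with engine.connect() as conn:
    compound = conn.execute(select(func.to_tsvector("english", "schema.table"))).scalar_one()
    split = conn.execute(select(func.to_tsvector("english", "schema table"))).scalar_one()
    print(compound)  # 'schema.table':1     -- one compound lexeme; 'table' alone won't match it
    print(split)     # 'schema':1 'tabl':2  -- separate stemmed lexemes; 'table' now matches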
3 changes: 2 additions & 1 deletion syncmaster/db/repositories/base.py
@@ -33,7 +33,7 @@ async def _read_by_id(self, id: int, **kwargs: Any) -> Model:
def _model_as_dict(model: Model) -> dict[str, Any]:
d = []
for c in model.__table__.columns:
if c.name == "id": # 'id' is PK autoincrement
if c.name in ("id", "search_vector"):
continue
d.append(c.name)

@@ -51,6 +51,7 @@ async def _copy(self, *args: Any, **kwargs: Any) -> Model:
kwargs.update({k: getattr(origin_model, k)})

d.update(kwargs) # Process kwargs in order to keep only what needs to be updated
d.pop("search_vector", None) # 'search_vector' is computed field
query_insert_new_row = insert(self._model).values(**d).returning(self._model)
try:
new_row = await self._session.scalars(query_insert_new_row)
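The name-based exclusion above works for the current schema; if more computed columns are added later, the same check could key off SQLAlchemy's Column.computed attribute instead of hard-coding names. A hypothetical variant:

from typing import Any


def copyable_column_names(model: Any) -> list[str]:
    # Skip the autoincrement PK and any generated column: PostgreSQL rejects
    # explicit INSERT values for GENERATED ALWAYS columns and rebuilds them itself.
    return [
        c.name
        for c in model.__table__.columns
        if c.name != "id" and c.computed is None
    ]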
12 changes: 11 additions & 1 deletion syncmaster/db/repositories/transfer.py
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Sequence
from typing import Any, NoReturn

from sqlalchemy import ScalarResult, insert, or_, select
from sqlalchemy import ScalarResult, func, insert, or_, select
from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -32,9 +32,19 @@ async def paginate(
page: int,
page_size: int,
group_id: int | None = None,
search_query: str | None = None,
) -> Pagination:
stmt = select(Transfer).where(Transfer.is_deleted.is_(False))

if search_query:
processed_query = search_query.replace("/", " ").replace(".", " ")
combined_query = f"{search_query} {processed_query}"
ts_query = func.plainto_tsquery("english", combined_query)
stmt = stmt.where(Transfer.search_vector.op("@@")(ts_query))
stmt = stmt.add_columns(func.ts_rank(Transfer.search_vector, ts_query).label("rank"))
# sort by ts_rank relevance
stmt = stmt.order_by(func.ts_rank(Transfer.search_vector, ts_query).desc())

return await self._paginate_scalar_result(
query=stmt.where(Transfer.group_id == group_id).order_by(Transfer.name),
page=page,
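Why combine the raw and translated forms of the query: plainto_tsquery ANDs every lexeme together, and the computed search_vector deliberately contains both the raw and the translated variants of each field, so the stricter combined query still matches. A plain-Python trace of the transformation done in paginate (the sample string is illustrative):

search_query = "schema.sales"
processed_query = search_query.replace("/", " ").replace(".", " ")
combined_query = f"{search_query} {processed_query}"
print(combined_query)  # "schema.sales schema sales"
# plainto_tsquery('english', combined_query) requires 'schema.sales',
# 'schema', and the stemmed 'sale' to all be present -- which they are,
# because the column indexes both raw and translated text.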
64 changes: 64 additions & 0 deletions tests/test_unit/test_transfers/test_read_transfer.py
@@ -1,3 +1,6 @@
import random
import string

import pytest
from httpx import AsyncClient

@@ -111,6 +114,67 @@ async def test_superuser_can_read_transfer(
assert result.status_code == 200


@pytest.mark.parametrize(
"search_value_extractor",
[
lambda transfer: transfer.name,
lambda transfer: transfer.source_params.get("table_name"),
lambda transfer: transfer.target_params.get("table_name"),
lambda transfer: transfer.source_params.get("directory_path"),
lambda transfer: transfer.target_params.get("directory_path"),
],
ids=["name", "source_table_name", "target_table_name", "source_directory_path", "target_directory_path"],
)
async def test_search_transfers_with_query(
client: AsyncClient,
superuser: MockUser,
group_transfer: MockTransfer,
search_value_extractor,
):
transfer = group_transfer.transfer
search_query = search_value_extractor(transfer)

result = await client.get(
"v1/transfers",
headers={"Authorization": f"Bearer {superuser.token}"},
params={"group_id": group_transfer.group_id, "search_query": search_query},
)

transfer_data = result.json()["items"][0]

assert transfer_data == {
"id": group_transfer.id,
"group_id": group_transfer.group_id,
"name": group_transfer.name,
"description": group_transfer.description,
"schedule": group_transfer.schedule,
"is_scheduled": group_transfer.is_scheduled,
"source_connection_id": group_transfer.source_connection_id,
"target_connection_id": group_transfer.target_connection_id,
"source_params": group_transfer.source_params,
"target_params": group_transfer.target_params,
"strategy_params": group_transfer.strategy_params,
"queue_id": group_transfer.transfer.queue_id,
}


async def test_search_transfers_with_nonexistent_query(
client: AsyncClient,
superuser: MockUser,
group_transfer: MockTransfer,
):
random_search_query = "".join(random.choices(string.ascii_lowercase + string.digits, k=12))

result = await client.get(
"v1/transfers",
headers={"Authorization": f"Bearer {superuser.token}"},
params={"group_id": group_transfer.group_id, "search_query": random_search_query},
)

assert result.status_code == 200
assert result.json()["items"] == []


async def test_unauthorized_user_cannot_read_transfer(
client: AsyncClient,
group_transfer: MockTransfer,