diff --git a/docs/changelog/next_release/92.feature.rst b/docs/changelog/next_release/92.feature.rst
new file mode 100644
index 00000000..eb6702e3
--- /dev/null
+++ b/docs/changelog/next_release/92.feature.rst
@@ -0,0 +1 @@
+Add full-text search for **transfers**
\ No newline at end of file
diff --git a/syncmaster/backend/api/v1/transfers.py b/syncmaster/backend/api/v1/transfers.py
index 0af69dc8..cc08ab08 100644
--- a/syncmaster/backend/api/v1/transfers.py
+++ b/syncmaster/backend/api/v1/transfers.py
@@ -45,6 +45,11 @@ async def read_transfers(
     group_id: int,
     page: int = Query(gt=0, default=1),
     page_size: int = Query(gt=0, le=200, default=20),
+    search_query: str | None = Query(
+        None,
+        title="Search Query",
+        description="Full-text search for transfers",
+    ),
     current_user: User = Depends(get_user(is_active=True)),
     unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
 ) -> TransferPageSchema:
@@ -61,6 +66,7 @@ async def read_transfers(
         page=page,
         page_size=page_size,
         group_id=group_id,
+        search_query=search_query,
     )
     return TransferPageSchema.from_pagination(pagination=pagination)
diff --git a/syncmaster/db/migrations/versions/2024-09-30_b9f5c4315bb2_.py b/syncmaster/db/migrations/versions/2024-09-30_b9f5c4315bb2_.py
new file mode 100644
index 00000000..8c5b44bf
--- /dev/null
+++ b/syncmaster/db/migrations/versions/2024-09-30_b9f5c4315bb2_.py
@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
+# SPDX-License-Identifier: Apache-2.0
+"""add search_vector column with GIN index to Transfer table
+
+Revision ID: b9f5c4315bb2
+Revises: 478240cdad4b
+Create Date: 2024-09-30 14:25:37.264273
+
+"""
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "b9f5c4315bb2"
+down_revision = "478240cdad4b"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "transfer",
+        sa.Column(
+            "search_vector",
+            postgresql.TSVECTOR(),
+            sa.Computed(
+                "\n to_tsvector(\n 'english'::regconfig,\n COALESCE(name, '') || ' ' ||\n COALESCE(translate(json_extract_path_text(source_params, 'table_name'), './', '  '), '') || ' ' ||\n COALESCE(translate(json_extract_path_text(target_params, 'table_name'), './', '  '), '') || ' ' ||\n COALESCE(translate(json_extract_path_text(source_params, 'directory_path'), './', '  '), '') || ' ' ||\n COALESCE(translate(json_extract_path_text(target_params, 'directory_path'), './', '  '), '')\n )\n ",
+                persisted=True,
+            ),
+            nullable=False,
+        ),
+    )
+    op.create_index("idx_transfer_search_vector", "transfer", ["search_vector"], unique=False, postgresql_using="gin")
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index("idx_transfer_search_vector", table_name="transfer", postgresql_using="gin")
+    op.drop_column("transfer", "search_vector")
+    # ### end Alembic commands ###
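A note on the generated column above: `translate(..., './', '  ')` maps both `.` and `/` to spaces before `to_tsvector`, so values like `schema.orders` or `/data/raw/orders` are indexed as separate lexemes rather than one opaque token. Below is a minimal sketch of that behavior, runnable against any PostgreSQL instance; the DSN is a placeholder, not part of this changeset:

```python
from sqlalchemy import create_engine, text

# Hypothetical DSN; point it at any reachable PostgreSQL instance.
engine = create_engine("postgresql+psycopg2://user:password@localhost/syncmaster")

with engine.connect() as conn:
    # Same expression as the Computed column: '.' and '/' become spaces,
    # then to_tsvector stems each word and records its position.
    vector = conn.execute(
        text("SELECT to_tsvector('english', translate(:val, './', '  '))"),
        {"val": "schema.orders"},
    ).scalar_one()
    print(vector)  # 'order':2 'schema':1
```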
diff --git a/syncmaster/db/models.py b/syncmaster/db/models.py
index 8e6980e3..b5d59e1f 100644
--- a/syncmaster/db/models.py
+++ b/syncmaster/db/models.py
@@ -10,12 +10,15 @@
     JSON,
     BigInteger,
     Boolean,
+    Computed,
     DateTime,
     ForeignKey,
+    Index,
     PrimaryKeyConstraint,
     String,
     UniqueConstraint,
 )
+from sqlalchemy.dialects.postgresql import TSVECTOR
 from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
 from sqlalchemy_utils import ChoiceType
 
@@ -129,9 +132,30 @@ class Transfer(
     target_connection: Mapped[Connection] = relationship(foreign_keys=target_connection_id)
     queue: Mapped[Queue] = relationship(back_populates="transfers")
 
+    search_vector: Mapped[str] = mapped_column(
+        TSVECTOR,
+        Computed(
+            """
+            to_tsvector(
+                'english'::regconfig,
+                COALESCE(name, '') || ' ' ||
+                COALESCE(translate(json_extract_path_text(source_params, 'table_name'), './', '  '), '') || ' ' ||
+                COALESCE(translate(json_extract_path_text(target_params, 'table_name'), './', '  '), '') || ' ' ||
+                COALESCE(translate(json_extract_path_text(source_params, 'directory_path'), './', '  '), '') || ' ' ||
+                COALESCE(translate(json_extract_path_text(target_params, 'directory_path'), './', '  '), '')
+            )
+            """,
+            persisted=True,
+        ),
+        nullable=False,
+    )
+
     @declared_attr
     def __table_args__(cls) -> tuple:
-        return (UniqueConstraint("name", "group_id"),)
+        return (
+            UniqueConstraint("name", "group_id"),
+            Index("idx_transfer_search_vector", "search_vector", postgresql_using="gin"),
+        )
 
 
 class Status(enum.StrEnum):
diff --git a/syncmaster/db/repositories/base.py b/syncmaster/db/repositories/base.py
index a4b4e118..59bfbb89 100644
--- a/syncmaster/db/repositories/base.py
+++ b/syncmaster/db/repositories/base.py
@@ -51,6 +51,7 @@ async def _copy(self, *args: Any, **kwargs: Any) -> Model:
                 kwargs.update({k: getattr(origin_model, k)})
         d.update(kwargs)  # Process kwargs in order to keep only what needs to be updated
+        d.pop("search_vector", None)  # 'search_vector' is a computed column, skip it on insert
         query_insert_new_row = insert(self._model).values(**d).returning(self._model)
         try:
             new_row = await self._session.scalars(query_insert_new_row)
diff --git a/syncmaster/db/repositories/transfer.py b/syncmaster/db/repositories/transfer.py
index aba546dd..23b22eaa 100644
--- a/syncmaster/db/repositories/transfer.py
+++ b/syncmaster/db/repositories/transfer.py
@@ -3,7 +3,7 @@
 from collections.abc import Sequence
 from typing import Any, NoReturn
 
-from sqlalchemy import ScalarResult, insert, or_, select
+from sqlalchemy import ScalarResult, func, insert, or_, select
 from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import selectinload
@@ -32,9 +32,15 @@ async def paginate(
         page: int,
         page_size: int,
         group_id: int | None = None,
+        search_query: str | None = None,
     ) -> Pagination:
         stmt = select(Transfer).where(Transfer.is_deleted.is_(False))
 
+        if search_query:
+            processed_query = search_query.replace("/", " ").replace(".", " ")
+            ts_query = func.plainto_tsquery("english", processed_query)
+            stmt = stmt.where(Transfer.search_vector.op("@@")(ts_query))
+
         return await self._paginate_scalar_result(
             query=stmt.where(Transfer.group_id == group_id).order_by(Transfer.name),
             page=page,
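The `paginate()` filter above is the query-side half of the feature: user input gets the same `.`/`/` normalization as the indexed columns, and `plainto_tsquery` ANDs the remaining words together. A standalone sketch of how the statement compiles, assuming `syncmaster.db.models` is importable; no database connection is needed:

```python
from sqlalchemy import func, select
from sqlalchemy.dialects import postgresql

from syncmaster.db.models import Transfer

search_query = "stage.orders"  # raw user input from ?search_query=
processed_query = search_query.replace("/", " ").replace(".", " ")

# Mirrors the filter added in paginate(): the '@@' match against the
# computed tsvector is exactly what the GIN index accelerates.
stmt = (
    select(Transfer)
    .where(Transfer.is_deleted.is_(False))
    .where(Transfer.search_vector.op("@@")(func.plainto_tsquery("english", processed_query)))
)

print(stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
```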
diff --git a/tests/test_unit/test_transfers/test_read_transfer.py b/tests/test_unit/test_transfers/test_read_transfer.py
index ca15ae4c..b3c2e45e 100644
--- a/tests/test_unit/test_transfers/test_read_transfer.py
+++ b/tests/test_unit/test_transfers/test_read_transfer.py
@@ -1,3 +1,6 @@
+import random
+import string
+
 import pytest
 from httpx import AsyncClient
 
@@ -111,6 +114,67 @@ async def test_superuser_can_read_transfer(
     assert result.status_code == 200
 
 
+@pytest.mark.parametrize(
+    "search_value_extractor",
+    [
+        lambda transfer: transfer.name,
+        lambda transfer: transfer.source_params.get("table_name"),
+        lambda transfer: transfer.target_params.get("table_name"),
+        lambda transfer: transfer.source_params.get("directory_path"),
+        lambda transfer: transfer.target_params.get("directory_path"),
+    ],
+)
+async def test_search_transfers_with_query(
+    client: AsyncClient,
+    superuser: MockUser,
+    group_transfer: MockTransfer,
+    search_value_extractor,
+):
+    transfer = group_transfer.transfer
+    search_query = search_value_extractor(transfer)
+
+    result = await client.get(
+        "v1/transfers",
+        headers={"Authorization": f"Bearer {superuser.token}"},
+        params={"group_id": group_transfer.group_id, "search_query": search_query},
+    )
+
+    assert result.status_code == 200
+    transfer_data = result.json()["items"][0]
+
+    assert transfer_data == {
+        "id": group_transfer.id,
+        "group_id": group_transfer.group_id,
+        "name": group_transfer.name,
+        "description": group_transfer.description,
+        "schedule": group_transfer.schedule,
+        "is_scheduled": group_transfer.is_scheduled,
+        "source_connection_id": group_transfer.source_connection_id,
+        "target_connection_id": group_transfer.target_connection_id,
+        "source_params": group_transfer.source_params,
+        "target_params": group_transfer.target_params,
+        "strategy_params": group_transfer.strategy_params,
+        "queue_id": group_transfer.transfer.queue_id,
+    }
+
+
+async def test_search_transfers_with_nonexistent_query(
+    client: AsyncClient,
+    superuser: MockUser,
+    group_transfer: MockTransfer,
+):
+    random_search_query = "".join(random.choices(string.ascii_lowercase + string.digits, k=12))
+
+    result = await client.get(
+        "v1/transfers",
+        headers={"Authorization": f"Bearer {superuser.token}"},
+        params={"group_id": group_transfer.group_id, "search_query": random_search_query},
+    )
+
+    assert result.status_code == 200
+    assert result.json()["items"] == []
+
+
 async def test_unauthorized_user_cannot_read_transfer(
     client: AsyncClient,
     group_transfer: MockTransfer,
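For completeness, a usage sketch of the new parameter against a running backend; the base URL, token, and `group_id` below are placeholders, not values from this changeset:

```python
import asyncio

import httpx


async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        response = await client.get(
            "v1/transfers",
            headers={"Authorization": "Bearer <access-token>"},
            params={"group_id": 1, "search_query": "stage.orders"},
        )
        response.raise_for_status()
        # Matching transfers come back paginated, same as an unfiltered listing.
        for transfer in response.json()["items"]:
            print(transfer["id"], transfer["name"])


asyncio.run(main())
```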