|
1 |
| -from typing import Any, Dict, List, Optional, Sequence, Type, Union |
| 1 | +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Type, Union |
2 | 2 |
|
3 | 3 | import anyio.to_thread
|
4 | 4 | from sqlalchemy import Column, String, cast, func, inspect, or_, select
|
|
20 | 20 | )
|
21 | 21 | from starlette_admin.contrib.sqla.exceptions import InvalidModelError
|
22 | 22 | from starlette_admin.contrib.sqla.helpers import (
|
23 |
| - build_order_clauses, |
24 | 23 | build_query,
|
25 | 24 | extract_column_python_type,
|
26 | 25 | normalize_list,
|
|
41 | 40 |
|
42 | 41 |
|
43 | 42 | class ModelView(BaseModelView):
|
| 43 | + """A view for managing SQLAlchemy models.""" |
| 44 | + |
| 45 | + sortable_field_mapping: ClassVar[Dict[str, InstrumentedAttribute]] = {} |
| 46 | + """A dictionary for overriding the default model attribute used for sorting. |
| 47 | +
|
| 48 | + Example: |
| 49 | + ```python |
| 50 | + class Post(Base): |
| 51 | + __tablename__ = "post" |
| 52 | +
|
| 53 | + id: Mapped[int] = mapped_column(primary_key=True) |
| 54 | + title: Mapped[str] = mapped_column() |
| 55 | + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) |
| 56 | + user: Mapped[User] = relationship(back_populates="posts") |
| 57 | +
|
| 58 | +
|
| 59 | + class PostView(ModelView): |
| 60 | + sortable_field = ["id", "title", "user"] |
| 61 | + sortable_field_mapping = { |
| 62 | + "user": User.age, # Sort by the age of the related user |
| 63 | + } |
| 64 | + ``` |
| 65 | + """ |
| 66 | + |
44 | 67 | def __init__(
|
45 | 68 | self,
|
46 | 69 | model: Type[Any],
|
@@ -231,10 +254,14 @@ async def find_all(
|
231 | 254 | request, where, self.model
|
232 | 255 | )
|
233 | 256 | stmt = stmt.where(where) # type: ignore
|
234 |
| - stmt = stmt.order_by(*build_order_clauses(order_by or [], self.model)) |
| 257 | + stmt = stmt.order_by( |
| 258 | + *self.build_order_clauses(request, order_by or [], self.model) |
| 259 | + ) |
235 | 260 | for field in self.get_fields_list(request, RequestAction.LIST):
|
236 | 261 | if isinstance(field, RelationField):
|
237 |
| - stmt = stmt.options(joinedload(getattr(self.model, field.name))) |
| 262 | + stmt = stmt.outerjoin(getattr(self.model, field.name)).options( |
| 263 | + joinedload(getattr(self.model, field.name)) |
| 264 | + ) |
238 | 265 | if isinstance(session, AsyncSession):
|
239 | 266 | return (await session.execute(stmt)).scalars().unique().all()
|
240 | 267 | return (
|
@@ -417,6 +444,19 @@ async def build_full_text_search_query(
|
417 | 444 | ) -> Any:
|
418 | 445 | return self.get_search_query(request, term)
|
419 | 446 |
|
| 447 | + def build_order_clauses( |
| 448 | + self, request: Request, order_list: List[str], model: Any |
| 449 | + ) -> Any: |
| 450 | + clauses = [] |
| 451 | + for value in order_list: |
| 452 | + attr_key, order = value.strip().split(maxsplit=1) |
| 453 | + attr = self.sortable_field_mapping.get( |
| 454 | + attr_key, getattr(model, attr_key, None) |
| 455 | + ) |
| 456 | + if attr is not None: |
| 457 | + clauses.append(attr.desc() if order.lower() == "desc" else attr) |
| 458 | + return clauses |
| 459 | + |
420 | 460 | def handle_exception(self, exc: Exception) -> None:
|
421 | 461 | try:
|
422 | 462 | """Automatically handle sqlalchemy_file error"""
|
|
0 commit comments