Skip to content

Commit 2d58ded

Browse files
authored
Enhancement: Add Support for Custom Sortable Field Mapping in SQLAlchemy ModelView (#328)
* Enhancement: Add Support for Custom Sortable Field Mapping in SQLAlchemy ModelView * Fix linting * Add changelog
1 parent ccdf8b5 commit 2d58ded

File tree

3 files changed

+65
-15
lines changed

3 files changed

+65
-15
lines changed

docs/changelog/index.md

+20
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,26 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
99

1010
### Added
1111

12+
* Add Support for Custom Sortable Field Mapping in SQLAlchemy ModelView by [@jowilf](https://github.com/jowilf)
13+
in [#328](https://github.com/jowilf/starlette-admin/pull/328)
14+
15+
!!! usage
16+
```python
17+
class Post(Base):
18+
__tablename__ = "post"
19+
20+
id: Mapped[int] = mapped_column(primary_key=True)
21+
title: Mapped[str] = mapped_column()
22+
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
23+
user: Mapped[User] = relationship(back_populates="posts")
24+
25+
class PostView(ModelView):
26+
sortable_field = ["id", "title", "user"]
27+
sortable_field_mapping = {
28+
"user": User.age, # Sort by the age of the related user
29+
}
30+
```
31+
1232
* Add support for datatables [state saving](https://datatables.net/examples/basic_init/state_save.html)
1333

1434
!!! usage

starlette_admin/contrib/sqla/helpers.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Dict, List, Optional, Sequence
1+
from typing import Any, Callable, Dict, Optional, Sequence
22

33
from sqlalchemy import Column, String, and_, cast, false, not_, or_, true
44
from sqlalchemy.orm import (
@@ -58,16 +58,6 @@ def build_query(
5858
return and_(True)
5959

6060

61-
def build_order_clauses(order_list: List[str], model: Any) -> Any:
62-
clauses = []
63-
for value in order_list:
64-
attr_key, order = value.strip().split(maxsplit=1)
65-
attr = getattr(model, attr_key, None)
66-
if attr is not None:
67-
clauses.append(attr.desc() if order.lower() == "desc" else attr)
68-
return clauses
69-
70-
7161
def normalize_list(
7262
arr: Optional[Sequence[Any]], is_default_sort_list: bool = False
7363
) -> Optional[Sequence[str]]:

starlette_admin/contrib/sqla/view.py

+44-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Sequence, Type, Union
1+
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Type, Union
22

33
import anyio.to_thread
44
from sqlalchemy import Column, String, cast, func, inspect, or_, select
@@ -20,7 +20,6 @@
2020
)
2121
from starlette_admin.contrib.sqla.exceptions import InvalidModelError
2222
from starlette_admin.contrib.sqla.helpers import (
23-
build_order_clauses,
2423
build_query,
2524
extract_column_python_type,
2625
normalize_list,
@@ -41,6 +40,30 @@
4140

4241

4342
class ModelView(BaseModelView):
43+
"""A view for managing SQLAlchemy models."""
44+
45+
sortable_field_mapping: ClassVar[Dict[str, InstrumentedAttribute]] = {}
46+
"""A dictionary for overriding the default model attribute used for sorting.
47+
48+
Example:
49+
```python
50+
class Post(Base):
51+
__tablename__ = "post"
52+
53+
id: Mapped[int] = mapped_column(primary_key=True)
54+
title: Mapped[str] = mapped_column()
55+
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
56+
user: Mapped[User] = relationship(back_populates="posts")
57+
58+
59+
class PostView(ModelView):
60+
sortable_field = ["id", "title", "user"]
61+
sortable_field_mapping = {
62+
"user": User.age, # Sort by the age of the related user
63+
}
64+
```
65+
"""
66+
4467
def __init__(
4568
self,
4669
model: Type[Any],
@@ -231,10 +254,14 @@ async def find_all(
231254
request, where, self.model
232255
)
233256
stmt = stmt.where(where) # type: ignore
234-
stmt = stmt.order_by(*build_order_clauses(order_by or [], self.model))
257+
stmt = stmt.order_by(
258+
*self.build_order_clauses(request, order_by or [], self.model)
259+
)
235260
for field in self.get_fields_list(request, RequestAction.LIST):
236261
if isinstance(field, RelationField):
237-
stmt = stmt.options(joinedload(getattr(self.model, field.name)))
262+
stmt = stmt.outerjoin(getattr(self.model, field.name)).options(
263+
joinedload(getattr(self.model, field.name))
264+
)
238265
if isinstance(session, AsyncSession):
239266
return (await session.execute(stmt)).scalars().unique().all()
240267
return (
@@ -417,6 +444,19 @@ async def build_full_text_search_query(
417444
) -> Any:
418445
return self.get_search_query(request, term)
419446

447+
def build_order_clauses(
448+
self, request: Request, order_list: List[str], model: Any
449+
) -> Any:
450+
clauses = []
451+
for value in order_list:
452+
attr_key, order = value.strip().split(maxsplit=1)
453+
attr = self.sortable_field_mapping.get(
454+
attr_key, getattr(model, attr_key, None)
455+
)
456+
if attr is not None:
457+
clauses.append(attr.desc() if order.lower() == "desc" else attr)
458+
return clauses
459+
420460
def handle_exception(self, exc: Exception) -> None:
421461
try:
422462
"""Automatically handle sqlalchemy_file error"""

0 commit comments

Comments
 (0)