Skip to content

Commit 33abffb

Browse files
authored
Add support for sqlalchemy collection_class property (#625)
* Add support for sqlalchemy collection_class property * fix ci
1 parent 99a1425 commit 33abffb

File tree

4 files changed

+14
-8
lines changed

4 files changed

+14
-8
lines changed

starlette_admin/contrib/sqla/converters.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,13 @@ def convert_fields_list(
106106
):
107107
converted_fields.append(HasOne(attr.key, identity=identity))
108108
else:
109-
converted_fields.append(HasMany(attr.key, identity=identity))
109+
converted_fields.append(
110+
HasMany(
111+
attr.key,
112+
identity=identity,
113+
collection_class=attr.collection_class or list,
114+
)
115+
)
110116
elif isinstance(attr, ColumnProperty):
111117
assert (
112118
len(attr.columns) == 1

starlette_admin/contrib/sqla/view.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sqlalchemy.sql import Select
1515
from starlette.requests import Request
1616
from starlette.responses import Response
17-
from starlette_admin import BaseField
17+
from starlette_admin import BaseField, HasMany
1818
from starlette_admin._types import RequestAction
1919
from starlette_admin.contrib.sqla.converters import (
2020
BaseSQLAModelConverter,
@@ -521,12 +521,10 @@ async def _arrange_data(
521521
for field in self.get_fields_list(request, request.state.action):
522522
if isinstance(field, RelationField) and data[field.name] is not None:
523523
foreign_model = self._find_foreign_model(field.identity) # type: ignore
524-
if not field.multiple:
525-
arranged_data[field.name] = await foreign_model.find_by_pk(
526-
request, data[field.name]
527-
)
524+
if isinstance(field, HasMany):
525+
arranged_data[field.name] = field.collection_class(await foreign_model.find_by_pks(request, data[field.name])) # type: ignore[call-arg]
528526
else:
529-
arranged_data[field.name] = await foreign_model.find_by_pks(
527+
arranged_data[field.name] = await foreign_model.find_by_pk(
530528
request, data[field.name]
531529
)
532530
else:

starlette_admin/fields.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import (
1010
Any,
1111
Callable,
12+
Collection,
1213
Dict,
1314
List,
1415
Optional,
@@ -1094,6 +1095,7 @@ class HasMany(RelationField):
10941095
"""A field representing a "has-many" relationship between two models."""
10951096

10961097
multiple: bool = True
1098+
collection_class: Union[Type[Collection[Any]], Callable[[], Collection[Any]]] = list
10971099

10981100

10991101
@dataclass(init=False)

tests/sqla/test_sqla_and_pydantic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class User(Base, IDMixin):
3838

3939
name = Column(String(100))
4040

41-
todos = relationship("Todo", back_populates="user")
41+
todos = relationship("Todo", back_populates="user", collection_class=set)
4242

4343

4444
class Todo(Base, IDMixin):

0 commit comments

Comments
 (0)