-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
69 lines (51 loc) · 2.14 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Database models
Fields of the models are annotated by their python type
so typecheckers can really help with code that uses models
"""
from __future__ import annotations
from datetime import datetime
from typing import List, Optional, Set
from uuid import UUID
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String
from sqlalchemy.orm import DeclarativeMeta, declarative_base, relationship
from sqlalchemy_utils import UUIDType
from .schemas import ShopUnitType
# Possible future direction: SQLModel (github.com/tiangolo/sqlmodel)
# may be a better way to get type checking while reusing code.

# Declarative base that the mapped model classes in this module derive from.
Base: DeclarativeMeta = declarative_base()
class BaseUnit:
    """Mixin holding the columns and helpers shared by the unit models.

    It is combined with ``Base`` by the concrete models (``ShopUnit`` and
    ``StatUnit``) and also supplies a generic ``__repr__``.
    """

    name: str = Column(String)  # type: ignore
    date: datetime = Column(DateTime)  # type: ignore
    type: ShopUnitType = Column(Enum(ShopUnitType))  # type: ignore
    price: int = Column(Integer)  # type: ignore

    @classmethod
    def _fields(cls, *, exclude: Optional[Set[str]] = None) -> List[str]:
        """Return the sorted field names known for ``cls``.

        The names come from ``cls.__annotations__``; when ``cls`` derives
        from ``Base`` the keys of its SQLAlchemy class manager are added.

        Args:
            exclude: Names to leave out of the result; ``None`` means
                nothing is excluded.

        Returns:
            The remaining names, sorted.
        """
        names = cls.__annotations__.keys()
        if issubclass(cls, Base):
            # ``|`` between key views produces a plain set.
            names = names | cls._sa_class_manager.keys()  # type: ignore
        skipped = set() if exclude is None else exclude
        return sorted(names - skipped)

    def __repr__(self) -> str:
        """Render the instance as ``ClassName(field=value, ...)``.

        A field whose lookup raises is shown with a placeholder message
        instead of aborting the whole representation.
        """

        def read(attr: str) -> object:
            # Best-effort read: any failure becomes a descriptive string.
            try:
                return getattr(self, attr)
            except Exception as err:
                return f"[Error while getting - {err}]"

        # Collect every value first, then format, so all lookups happen
        # before any value is rendered.
        values = {attr: read(attr) for attr in self._fields()}
        body = ", ".join(f"{attr}={val!r}" for attr, val in values.items())
        return f"{type(self).__name__}({body})"
class ShopUnit(Base, BaseUnit):
    """Model mapped to the ``shop`` table.

    Adds a UUID primary key, an optional self-referencing parent key and a
    ``sub_offers_count`` integer to the columns inherited from ``BaseUnit``.
    """

    __tablename__ = "shop"

    id: UUID = Column(UUIDType(), primary_key=True)  # type: ignore
    # Optional reference to another row of this same table.
    parentId: Optional[UUID] = Column(  # type: ignore
        UUIDType(),
        ForeignKey("shop.id"),
        nullable=True,
    )
    # Annotation only: no Column or relationship is assigned here, so the
    # value is presumably attached by calling code - TODO confirm.
    children: List[ShopUnit]
    sub_offers_count: int = Column(Integer)  # type: ignore
class StatUnit(Base, BaseUnit):
    """Model mapped to the ``stat`` table.

    Rows are keyed by an auto-incrementing integer; ``id`` and ``parentId``
    are plain UUID columns with no primary-key or foreign-key constraint.
    """

    __tablename__ = "stat"

    # Surrogate primary key; ``id`` below is not declared unique, which
    # presumably lets several rows share one unit id - TODO confirm.
    _unique_id = Column(
        Integer,
        primary_key=True,
        autoincrement=True,
    )
    id: UUID = Column(UUIDType())  # type: ignore
    parentId: Optional[UUID] = Column(  # type: ignore
        UUIDType(),
        nullable=True,
    )