-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpatches.py
80 lines (56 loc) · 2.2 KB
/
patches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import logging
from datetime import datetime
from typing import Any, Callable
from ciso8601 import parse_datetime, parse_rfc3339
def strict_datetime(value: Any) -> datetime:
    """Validate ``value`` as a datetime, without pydantic's lenient coercions.

    ``datetime`` instances (including subclasses) are passed through unchanged.
    Anything else is handed to ``ciso8601.parse_datetime``. A non-string input
    makes it raise ``TypeError`` and a malformed string makes it raise
    ``ValueError``. Any exception from the parser is converted into pydantic's
    ``DateTimeError``, chained to the original cause.

    Raises:
        pydantic.errors.DateTimeError: if ``value`` cannot be parsed.
    """
    if not isinstance(value, datetime):
        try:
            value = parse_datetime(value)
        except Exception as exc:
            # Imported lazily so this module can be imported without pydantic.
            from pydantic.errors import DateTimeError
            raise DateTimeError() from exc
    return value
def patch_datetime_validation() -> None:
    """Swap the validators registered for ``datetime`` for ``strict_datetime``.

    Only the first entry of ``pydantic.validators._VALIDATORS`` whose type
    compares equal to ``datetime`` is replaced. If no such entry exists, the
    call is a silent no-op.

    NOTE(review): ``_VALIDATORS`` is a private name and presumably exists only
    in pydantic v1. Confirm it is still present when upgrading pydantic.
    """
    from pydantic.validators import _VALIDATORS

    for index, (target_type, _old_validators) in enumerate(_VALIDATORS):
        if target_type == datetime:
            # Mutate the shared list in place so pydantic sees the new entry.
            _VALIDATORS[index] = (target_type, [strict_datetime])
            return
def serialize_datetime(dt: datetime) -> str:
    """Serialize ``dt`` as an ISO 8601 UTC string with millisecond precision.

    Naive datetimes are treated as already being in UTC and are formatted
    as-is. Aware datetimes are first converted to UTC and stripped of their
    tzinfo. Otherwise ``isoformat`` would emit an offset (e.g. ``+00:00``)
    and the appended ``Z`` would produce an invalid string such as
    ``2020-01-01T00:00:00.000+00:00Z``.

    Digits below the millisecond are truncated, not rounded, by ``isoformat``.

    Returns:
        A string like ``"2020-01-02T03:04:05.678Z"``.
    """
    if dt.utcoffset() is not None:
        from datetime import timezone

        dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
    return dt.isoformat(timespec="milliseconds") + "Z"
def patch_datetime_serialization() -> None:
    """Register ``serialize_datetime`` as the ``datetime`` encoder.

    This overwrites (or adds) the ``datetime`` key in
    ``pydantic.json.ENCODERS_BY_TYPE``, which is mutated in place.
    """
    from pydantic.json import ENCODERS_BY_TYPE

    ENCODERS_BY_TYPE.update({datetime: serialize_datetime})
def request_response(func: Callable):
    """
    Takes a function or coroutine `func(request) -> response`,
    and returns an ASGI application.

    Debugging variant of ``starlette.routing.request_response``. Before
    dispatching to ``func``, it reads the entire request body and, if the
    body is non-empty, logs its ``repr`` at INFO level. The logger is named
    after the parent package.

    Sync callables run via ``run_in_threadpool``. Coroutine functions and
    partials wrapping them are awaited directly.
    """
    from starlette import routing
    from starlette.concurrency import run_in_threadpool
    from starlette.requests import Request
    from starlette.types import Receive, Scope, Send

    # `from . import __name__` yields the parent package's name.
    from . import __name__ as mod_name

    logger = logging.getLogger(mod_name)
    # Decided once, when the route is built, rather than per request.
    is_coroutine = routing.iscoroutinefunction_or_partial(func)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive=receive, send=send)
        # Buffers the whole body in memory. The same Request object is passed
        # to `func`, presumably so starlette's cached body is still readable
        # there; verify against the starlette version in use.
        body = await request.body()
        if body:
            # NOTE(review): request bodies may contain credentials or PII;
            # keep this patch disabled outside debugging sessions.
            logger.info("%r", body)
        if is_coroutine:
            response = await func(request)
        else:
            response = await run_in_threadpool(func, request)
        await response(scope, receive, send)

    return app
def patch_request_response() -> None:
    """Replace ``starlette.routing.request_response`` with the logging variant.

    NOTE(review): this rebinds the module attribute only. Presumably, routes
    built before this call, and code that imported ``request_response``
    directly, keep the original function. Confirm against the framework in
    use.
    """
    import starlette.routing as starlette_routing

    starlette_routing.request_response = request_response
def patch() -> None:
    """Apply the default monkey-patches: datetime validation, then serialization.

    The request-body logging patch is deliberately left out of the default set.
    """
    for apply_patch in (patch_datetime_validation, patch_datetime_serialization):
        apply_patch()
    # Debugging aid only: uncomment to log every non-empty request body.
    # patch_request_response()