Skip to content

Commit

Permalink
⚡ improve pydantic v2 performance
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Aug 11, 2024
1 parent 26eabfa commit 92dff90
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 35 deletions.
98 changes: 67 additions & 31 deletions nonebot/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@
"""

from collections.abc import Generator
from functools import cached_property
from dataclasses import dataclass, is_dataclass
from typing_extensions import Self, get_args, get_origin, is_typeddict
from typing import (
TYPE_CHECKING,
Any,
Union,
Generic,
TypeVar,
Callable,
Optional,
Protocol,
Annotated,
overload,
)

from pydantic import VERSION, BaseModel
Expand Down Expand Up @@ -46,8 +49,8 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...
"DEFAULT_CONFIG",
"FieldInfo",
"ModelField",
"TypeAdapter",
"extract_field_info",
"model_field_validate",
"model_fields",
"model_config",
"model_dump",
Expand All @@ -63,9 +66,10 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...


if PYDANTIC_V2: # pragma: pydantic-v2
from pydantic import GetCoreSchemaHandler
from pydantic import TypeAdapter as TypeAdapter
from pydantic_core import CoreSchema, core_schema
from pydantic._internal._repr import display_as_type
from pydantic import TypeAdapter, GetCoreSchemaHandler
from pydantic.fields import FieldInfo as BaseFieldInfo

Required = Ellipsis
Expand Down Expand Up @@ -125,6 +129,25 @@ def construct(
"""Construct a ModelField from given infos."""
return cls._construct(name, annotation, field_info or FieldInfo())

def __hash__(self) -> int:
# Each ModelField is unique for our purposes,
# to allow store them in a set.
return id(self)

@cached_property
def type_adapter(self) -> TypeAdapter:
"""TypeAdapter of the field.
Cache the TypeAdapter to avoid creating it multiple times.
Pydantic v2 uses too much cpu time to create TypeAdapter.
See: https://github.com/pydantic/pydantic/issues/9834
"""
return TypeAdapter(
Annotated[self.annotation, self.field_info],
config=None if self._annotation_has_config() else DEFAULT_CONFIG,
)

def _annotation_has_config(self) -> bool:
"""Check if the annotation has config.
Expand Down Expand Up @@ -152,10 +175,9 @@ def _type_display(self):
"""Get the display of the type of the field."""
return display_as_type(self.annotation)

def __hash__(self) -> int:
# Each ModelField is unique for our purposes,
# to allow store them in a set.
return id(self)
def validate(self, value: Any) -> Any:
"""Validate the value pass to the field."""
return self.type_adapter.validate_python(value)

def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
"""Get FieldInfo init kwargs from a FieldInfo instance."""
Expand All @@ -164,15 +186,6 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
kwargs["annotation"] = field_info.rebuild_annotation()
return kwargs

def model_field_validate(
model_field: ModelField, value: Any, config: Optional[ConfigDict] = None
) -> Any:
"""Validate the value pass to the field."""
type: Any = Annotated[model_field.annotation, model_field.field_info]
return TypeAdapter(
type, config=None if model_field._annotation_has_config() else config
).validate_python(value)

def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model."""

Expand Down Expand Up @@ -305,6 +318,45 @@ def construct(
)
return cls._construct(name, annotation, field_info or FieldInfo())

def validate(self, value: Any) -> Any:

Check failure on line 321 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Pyright Lint (pydantic-v1)

Method "validate" overrides class "ModelField" in an incompatible manner   Positional parameter count mismatch; base method has 5, but override has 2   Parameter 2 name mismatch: base parameter is named "v", override parameter is named "value"   Parameter 3 mismatch: base parameter "values" is keyword parameter, override parameter is position-only   Parameter "loc" is missing in override   Parameter "cls" is missing in override (reportIncompatibleMethodOverride)
"""Validate the value pass to the field."""
v, errs_ = super().validate(value, {}, loc=())
if errs_:
raise ValueError(value, self)
return v

class TypeAdapter(Generic[T]):
@overload
def __init__(
self,
type: type[T],
*,
config: Optional[ConfigDict] = ...,
) -> None: ...

@overload
def __init__(
self,
type: Any,
*,
config: Optional[ConfigDict] = ...,
) -> None: ...

def __init__(
self,
type: Any,
*,
config: Optional[ConfigDict] = None,
) -> None:
self.type = type
self.config = config

def validate_python(self, value: Any) -> T:
return type_validate_python(self.type, value)

def validate_json(self, value: Union[str, bytes]) -> T:
return type_validate_json(self.type, value)

def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
"""Get FieldInfo init kwargs from a FieldInfo instance."""

Expand All @@ -314,22 +366,6 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
kwargs.update(field_info.extra)
return kwargs

def model_field_validate(
model_field: ModelField, value: Any, config: Optional[type[ConfigDict]] = None
) -> Any:
"""Validate the value pass to the field.
Set config before validate to ensure validate correctly.
"""

if model_field.model_config is not config:
model_field.set_config(config or ConfigDict)

v, errs_ = model_field.validate(value, {}, loc=())
if errs_:
raise ValueError(value, model_field)
return v

def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model."""

Expand Down
4 changes: 2 additions & 2 deletions nonebot/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

from loguru import logger

from nonebot.compat import ModelField
from nonebot.exception import TypeMisMatch
from nonebot.typing import evaluate_forwardref
from nonebot.compat import DEFAULT_CONFIG, ModelField, model_field_validate


def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
Expand Down Expand Up @@ -51,6 +51,6 @@ def check_field_type(field: ModelField, value: Any) -> Any:
"""检查字段类型是否匹配"""

try:
return model_field_validate(field, value, DEFAULT_CONFIG)
return field.validate(value)
except ValueError:
raise TypeMisMatch(field, value)
20 changes: 18 additions & 2 deletions tests/test_compat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Any, Optional
from dataclasses import dataclass
from typing import Any, Optional, Annotated

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from nonebot.compat import (
DEFAULT_CONFIG,
Required,
FieldInfo,
TypeAdapter,
PydanticUndefined,
model_dump,
custom_validation,
Expand All @@ -31,6 +32,21 @@ async def test_field_info():
assert FieldInfo(test="test").extra["test"] == "test"


@pytest.mark.asyncio
async def test_type_adapter():
t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])

assert t.validate_python(2) == 2

with pytest.raises(ValidationError):
t.validate_python(0)

assert t.validate_json("2") == 2

with pytest.raises(ValidationError):
t.validate_json("0")


@pytest.mark.asyncio
async def test_model_dump():
class TestModel(BaseModel):
Expand Down

0 comments on commit 92dff90

Please sign in to comment.