Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 添加 pydantic validator 兼容函数 #3291

Merged
merged 7 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion nonebot/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Any,
Callable,
Generic,
Literal,
Optional,
Protocol,
TypeVar,
Expand Down Expand Up @@ -45,6 +46,7 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...

__all__ = (
"DEFAULT_CONFIG",
"PYDANTIC_V2",
"ConfigDict",
"FieldInfo",
"ModelField",
Expand All @@ -54,9 +56,11 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...
"TypeAdapter",
"custom_validation",
"extract_field_info",
"field_validator",
"model_config",
"model_dump",
"model_fields",
"model_validator",
"type_validate_json",
"type_validate_python",
)
Expand All @@ -70,6 +74,8 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...
if PYDANTIC_V2: # pragma: pydantic-v2
from pydantic import GetCoreSchemaHandler
from pydantic import TypeAdapter as TypeAdapter
from pydantic import field_validator as field_validator
from pydantic import model_validator as model_validator
from pydantic._internal._repr import display_as_type
from pydantic.fields import FieldInfo as BaseFieldInfo
from pydantic_core import CoreSchema, core_schema
Expand Down Expand Up @@ -254,7 +260,7 @@ def custom_validation(class_: type["CVC"]) -> type["CVC"]:

else: # pragma: pydantic-v1
from pydantic import BaseConfig as PydanticConfig
from pydantic import Extra, parse_obj_as, parse_raw_as
from pydantic import Extra, parse_obj_as, parse_raw_as, root_validator, validator
from pydantic.fields import FieldInfo as BaseFieldInfo
from pydantic.fields import ModelField as BaseModelField
from pydantic.schema import get_annotation_from_field_info
Expand Down Expand Up @@ -367,6 +373,44 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
kwargs.update(field_info.extra)
return kwargs

@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal["before"],
check_fields: Optional[bool] = None,
): ...

@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal["after"] = ...,
check_fields: Optional[bool] = None,
): ...

def field_validator(
field: str,
/,
*fields: str,
mode: Literal["before", "after"] = "after",
check_fields: Optional[bool] = None,
):
if mode == "before":
return validator(
field,
*fields,
pre=True,
check_fields=check_fields or True,
allow_reuse=True,
)
else:
return validator(
field, *fields, check_fields=check_fields or True, allow_reuse=True
)

def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model."""

Expand Down Expand Up @@ -404,6 +448,18 @@ def model_dump(
exclude_none=exclude_none,
)

@overload
def model_validator(*, mode: Literal["before"]): ...

@overload
def model_validator(*, mode: Literal["after"]): ...

def model_validator(*, mode: Literal["before", "after"]):
if mode == "before":
return root_validator(pre=True, allow_reuse=True)
else:
return root_validator(skip_on_failure=True, allow_reuse=True)

def type_validate_python(type_: type[T], data: Any) -> T:
"""Validate data with given type."""
return parse_obj_as(type_, data)
Expand Down
57 changes: 57 additions & 0 deletions tests/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
Required,
TypeAdapter,
custom_validation,
field_validator,
model_dump,
model_validator,
type_validate_json,
type_validate_python,
)
Expand All @@ -30,6 +32,32 @@ def test_field_info():
assert FieldInfo(test="test").extra["test"] == "test"


def test_field_validator():
class TestModel(BaseModel):
foo: int
bar: str

@field_validator("foo")
@classmethod
def test_validator(cls, v: Any) -> Any:
if v > 0:
return v
raise ValueError("test must be greater than 0")

@field_validator("bar", mode="before")
@classmethod
def test_validator_before(cls, v: Any) -> Any:
if not isinstance(v, str):
v = str(v)
return v

assert type_validate_python(TestModel, {"foo": 1, "bar": "test"}).foo == 1
assert type_validate_python(TestModel, {"foo": 1, "bar": 123}).bar == "123"

with pytest.raises(ValidationError):
TestModel(foo=0, bar="test")


def test_type_adapter():
t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])

Expand All @@ -53,6 +81,35 @@ class TestModel(BaseModel):
assert model_dump(TestModel(test1=1, test2=2), exclude={"test1"}) == {"test2": 2}


def test_model_validator():
class TestModel(BaseModel):
foo: int
bar: str

@model_validator(mode="before")
@classmethod
def test_validator_before(cls, data: Any) -> Any:
if isinstance(data, dict):
if "foo" not in data:
data["foo"] = 1
return data

@model_validator(mode="after")
@classmethod
def test_validator_after(cls, data: Any) -> Any:
if isinstance(data, dict):
if data["bar"] == "test":
raise ValueError("bar should not be test")
elif data.bar == "test":
raise ValueError("bar should not be test")
return data

assert type_validate_python(TestModel, {"bar": "aaa"}).foo == 1

with pytest.raises(ValidationError):
type_validate_python(TestModel, {"foo": 1, "bar": "test"})


def test_custom_validation():
called = []

Expand Down