diff --git a/nonebot/compat.py b/nonebot/compat.py index a13a88692e29..8783d9caf090 100644 --- a/nonebot/compat.py +++ b/nonebot/compat.py @@ -18,6 +18,7 @@ Any, Callable, Generic, + Literal, Optional, Protocol, TypeVar, @@ -45,6 +46,7 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... __all__ = ( "DEFAULT_CONFIG", + "PYDANTIC_V2", "ConfigDict", "FieldInfo", "ModelField", @@ -54,9 +56,11 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... "TypeAdapter", "custom_validation", "extract_field_info", + "field_validator", "model_config", "model_dump", "model_fields", + "model_validator", "type_validate_json", "type_validate_python", ) @@ -70,6 +74,8 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... if PYDANTIC_V2: # pragma: pydantic-v2 from pydantic import GetCoreSchemaHandler from pydantic import TypeAdapter as TypeAdapter + from pydantic import field_validator as field_validator + from pydantic import model_validator as model_validator from pydantic._internal._repr import display_as_type from pydantic.fields import FieldInfo as BaseFieldInfo from pydantic_core import CoreSchema, core_schema @@ -254,7 +260,7 @@ def custom_validation(class_: type["CVC"]) -> type["CVC"]: else: # pragma: pydantic-v1 from pydantic import BaseConfig as PydanticConfig - from pydantic import Extra, parse_obj_as, parse_raw_as + from pydantic import Extra, parse_obj_as, parse_raw_as, root_validator, validator from pydantic.fields import FieldInfo as BaseFieldInfo from pydantic.fields import ModelField as BaseModelField from pydantic.schema import get_annotation_from_field_info @@ -367,6 +373,44 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: kwargs.update(field_info.extra) return kwargs + @overload + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["before"], + check_fields: Optional[bool] = None, + ): ... + + @overload + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["after"] = ..., + check_fields: Optional[bool] = None, + ): ... + + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["before", "after"] = "after", + check_fields: Optional[bool] = None, + ): + if mode == "before": + return validator( + field, + *fields, + pre=True, + check_fields=check_fields or True, + allow_reuse=True, + ) + else: + return validator( + field, *fields, check_fields=check_fields or True, allow_reuse=True + ) + def model_fields(model: type[BaseModel]) -> list[ModelField]: """Get field list of a model.""" @@ -404,6 +448,18 @@ def model_dump( exclude_none=exclude_none, ) + @overload + def model_validator(*, mode: Literal["before"]): ... + + @overload + def model_validator(*, mode: Literal["after"]): ... + + def model_validator(*, mode: Literal["before", "after"]): + if mode == "before": + return root_validator(pre=True, allow_reuse=True) + else: + return root_validator(skip_on_failure=True, allow_reuse=True) + def type_validate_python(type_: type[T], data: Any) -> T: """Validate data with given type.""" return parse_obj_as(type_, data) diff --git a/tests/test_compat.py b/tests/test_compat.py index ddb032db7c60..04b6d1ba6106 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -11,7 +11,9 @@ Required, TypeAdapter, custom_validation, + field_validator, model_dump, + model_validator, type_validate_json, type_validate_python, ) @@ -30,6 +32,32 @@ def test_field_info(): assert FieldInfo(test="test").extra["test"] == "test" +def test_field_validator(): + class TestModel(BaseModel): + foo: int + bar: str + + @field_validator("foo") + @classmethod + def test_validator(cls, v: Any) -> Any: + if v > 0: + return v + raise ValueError("test must be greater than 0") + + @field_validator("bar", mode="before") + @classmethod + def test_validator_before(cls, v: Any) -> Any: + if not isinstance(v, str): + v = str(v) + return v + + assert type_validate_python(TestModel, {"foo": 1, "bar": "test"}).foo == 1 + assert type_validate_python(TestModel, {"foo": 1, "bar": 123}).bar == "123" + + with pytest.raises(ValidationError): + TestModel(foo=0, bar="test") + + def test_type_adapter(): t = TypeAdapter(Annotated[int, FieldInfo(ge=1)]) @@ -53,6 +81,35 @@ class TestModel(BaseModel): assert model_dump(TestModel(test1=1, test2=2), exclude={"test1"}) == {"test2": 2} +def test_model_validator(): + class TestModel(BaseModel): + foo: int + bar: str + + @model_validator(mode="before") + @classmethod + def test_validator_before(cls, data: Any) -> Any: + if isinstance(data, dict): + if "foo" not in data: + data["foo"] = 1 + return data + + @model_validator(mode="after") + @classmethod + def test_validator_after(cls, data: Any) -> Any: + if isinstance(data, dict): + if data["bar"] == "test": + raise ValueError("bar should not be test") + elif data.bar == "test": + raise ValueError("bar should not be test") + return data + + assert type_validate_python(TestModel, {"bar": "aaa"}).foo == 1 + + with pytest.raises(ValidationError): + type_validate_python(TestModel, {"foo": 1, "bar": "test"}) + + def test_custom_validation(): called = []