From 604fb7d507b5a4a8bd194b0fb546ad90f894ed60 Mon Sep 17 00:00:00 2001 From: Tarrailt <3165388245@qq.com> Date: Fri, 31 Jan 2025 23:05:47 +0800 Subject: [PATCH 1/7] :sparkles: model_validator & field_validator for pydantic compat --- nonebot/compat.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/nonebot/compat.py b/nonebot/compat.py index a13a88692e29..b5e20fc36dbd 100644 --- a/nonebot/compat.py +++ b/nonebot/compat.py @@ -45,6 +45,7 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... __all__ = ( "DEFAULT_CONFIG", + "PYDANTIC_V2", "ConfigDict", "FieldInfo", "ModelField", @@ -54,9 +55,11 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... "TypeAdapter", "custom_validation", "extract_field_info", + "field_validator", "model_config", "model_dump", "model_fields", + "model_validator", "type_validate_json", "type_validate_python", ) @@ -70,6 +73,8 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... if PYDANTIC_V2: # pragma: pydantic-v2 from pydantic import GetCoreSchemaHandler from pydantic import TypeAdapter as TypeAdapter + from pydantic import field_validator as field_validator + from pydantic import model_validator as model_validator from pydantic._internal._repr import display_as_type from pydantic.fields import FieldInfo as BaseFieldInfo from pydantic_core import CoreSchema, core_schema @@ -254,7 +259,7 @@ def custom_validation(class_: type["CVC"]) -> type["CVC"]: else: # pragma: pydantic-v1 from pydantic import BaseConfig as PydanticConfig - from pydantic import Extra, parse_obj_as, parse_raw_as + from pydantic import Extra, parse_obj_as, parse_raw_as, root_validator, validator from pydantic.fields import FieldInfo as BaseFieldInfo from pydantic.fields import ModelField as BaseModelField from pydantic.schema import get_annotation_from_field_info @@ -367,6 +372,36 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: kwargs.update(field_info.extra) return kwargs + @overload + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["before"], + check_fields: Optional[bool] = None, + ): ... + + @overload + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["after"], + check_fields: Optional[bool] = None, + ): ... + + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["before", "after"], + check_fields: Optional[bool] = None, + ): + if mode == "before": + return validator(field, *fields, pre=True, check_fields=check_fields or True, allow_reuse=True) + else: + return validator(field, *fields, check_fields=check_fields or True, allow_reuse=True) + def model_fields(model: type[BaseModel]) -> list[ModelField]: """Get field list of a model.""" @@ -404,6 +439,18 @@ def model_dump( exclude_none=exclude_none, ) + @overload + def model_validator(*, mode: Literal["before"]): ... + + @overload + def model_validator(*, mode: Literal["after"]): ... + + def model_validator(*, mode: Literal["before", "after"]): + if mode == "before": + return root_validator(pre=True, allow_reuse=True) + else: + return root_validator(skip_on_failure=True, allow_reuse=True) + def type_validate_python(type_: type[T], data: Any) -> T: """Validate data with given type.""" return parse_obj_as(type_, data) From 8725d75780611c39ed1aa8b06c319abe4c9942d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Jan 2025 15:06:00 +0000 Subject: [PATCH 2/7] :rotating_light: auto fix by pre-commit hooks --- nonebot/compat.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nonebot/compat.py b/nonebot/compat.py index b5e20fc36dbd..f963c0448e67 100644 --- a/nonebot/compat.py +++ b/nonebot/compat.py @@ -398,9 +398,17 @@ def field_validator( check_fields: Optional[bool] = None, ): if mode == "before": - return validator(field, *fields, pre=True, check_fields=check_fields or True, allow_reuse=True) + return validator( + field, + *fields, + pre=True, + check_fields=check_fields or True, + allow_reuse=True, + ) else: - return validator(field, *fields, check_fields=check_fields or True, allow_reuse=True) + return validator( + field, *fields, check_fields=check_fields or True, allow_reuse=True + ) def model_fields(model: type[BaseModel]) -> list[ModelField]: """Get field list of a model.""" From bc802df7d4d6dc1a4d7dac69cdab58584fc41816 Mon Sep 17 00:00:00 2001 From: Tarrailt <3165388245@qq.com> Date: Fri, 31 Jan 2025 23:08:16 +0800 Subject: [PATCH 3/7] :bug: import missing typing --- nonebot/compat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nonebot/compat.py b/nonebot/compat.py index f963c0448e67..58cdbcd36887 100644 --- a/nonebot/compat.py +++ b/nonebot/compat.py @@ -18,6 +18,7 @@ Any, Callable, Generic, + Literal, Optional, Protocol, TypeVar, From 225dbbf5d57ce773a48ae969b1df520c28bf74b8 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Fri, 31 Jan 2025 23:27:52 +0800 Subject: [PATCH 4/7] :white_check_mark: add tests for validator --- nonebot/compat.py | 4 ++-- tests/test_compat.py | 52 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/nonebot/compat.py b/nonebot/compat.py index 58cdbcd36887..8783d9caf090 100644 --- a/nonebot/compat.py +++ b/nonebot/compat.py @@ -387,7 +387,7 @@ def field_validator( field: str, /, *fields: str, - mode: Literal["after"], + mode: Literal["after"] = ..., check_fields: Optional[bool] = None, ): ... @@ -395,7 +395,7 @@ def field_validator( field: str, /, *fields: str, - mode: Literal["before", "after"], + mode: Literal["before", "after"] = "after", check_fields: Optional[bool] = None, ): if mode == "before": diff --git a/tests/test_compat.py b/tests/test_compat.py index ddb032db7c60..11f282c88eae 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -7,11 +7,13 @@ from nonebot.compat import ( DEFAULT_CONFIG, FieldInfo, + field_validator, PydanticUndefined, Required, TypeAdapter, custom_validation, model_dump, + model_validator, type_validate_json, type_validate_python, ) @@ -30,6 +32,32 @@ def test_field_info(): assert FieldInfo(test="test").extra["test"] == "test" +def test_field_validator(): + class TestModel(BaseModel): + foo: int + bar: str + + @field_validator("foo") + @classmethod + def test_validator(cls, v: Any) -> Any: + if v > 0: + return v + raise ValueError("test must be greater than 0") + + @field_validator("bar", mode="before") + @classmethod + def test_validator_before(cls, v: Any) -> Any: + if not isinstance(v, str): + v = str(v) + return v + + assert TestModel(foo=1, bar="test").foo == 1 + assert TestModel(foo=1, bar=123).bar == "123" + + with pytest.raises(ValidationError): + TestModel(foo=0, bar="test") + + def test_type_adapter(): t = TypeAdapter(Annotated[int, FieldInfo(ge=1)]) @@ -53,6 +81,30 @@ class TestModel(BaseModel): assert model_dump(TestModel(test1=1, test2=2), exclude={"test1"}) == {"test2": 2} +def test_model_validator(): + class TestModel(BaseModel): + foo: int + bar: str + + @model_validator(mode="before") + @classmethod + def test_validator_before(cls, data: Any) -> Any: + if isinstance(data, dict): + if "foo" not in data: + data["foo"] = 1 + return data + + @model_validator(mode="after") + def test_validator_after(self): + if self.bar == "test": + raise ValueError("bar should not be test") + + assert type_validate_python(TestModel, {"bar": "aaa"}).foo == 1 + + with pytest.raises(ValidationError): + type_validate_python(TestModel, {"foo": 1, "bar": "test"}) + + def test_custom_validation(): called = [] From d8247fd3028db652a9ddb0ee1ec03314c27c9f4d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Jan 2025 15:29:49 +0000 Subject: [PATCH 5/7] :rotating_light: auto fix by pre-commit hooks --- tests/test_compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_compat.py b/tests/test_compat.py index 11f282c88eae..46e322b4da43 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -7,11 +7,11 @@ from nonebot.compat import ( DEFAULT_CONFIG, FieldInfo, - field_validator, PydanticUndefined, Required, TypeAdapter, custom_validation, + field_validator, model_dump, model_validator, type_validate_json, From b5d1ab99ce6bc2134b7332aacb30db46beaadb77 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Fri, 31 Jan 2025 23:38:05 +0800 Subject: [PATCH 6/7] :white_check_mark: fix tests for validator --- tests/test_compat.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_compat.py b/tests/test_compat.py index 46e322b4da43..ba60d6f84e1d 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -51,8 +51,8 @@ def test_validator_before(cls, v: Any) -> Any: v = str(v) return v - assert TestModel(foo=1, bar="test").foo == 1 - assert TestModel(foo=1, bar=123).bar == "123" + assert type_validate_python(TestModel, {"foo": 1, "bar": "test"}).foo == 1 + assert type_validate_python(TestModel, {"foo": 1, "bar": 123}).bar == "123" with pytest.raises(ValidationError): TestModel(foo=0, bar="test") @@ -95,9 +95,11 @@ def test_validator_before(cls, data: Any) -> Any: return data @model_validator(mode="after") - def test_validator_after(self): - if self.bar == "test": + @classmethod + def test_validator_after(cls, data: Any) -> Any: + if data.bar == "test": raise ValueError("bar should not be test") + return data assert type_validate_python(TestModel, {"bar": "aaa"}).foo == 1 From 24994ed032d07346486377ddb4df657af6ac187d Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Fri, 31 Jan 2025 23:43:54 +0800 Subject: [PATCH 7/7] :white_check_mark: fix test_model_validator --- tests/test_compat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_compat.py b/tests/test_compat.py index ba60d6f84e1d..04b6d1ba6106 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -97,7 +97,10 @@ def test_validator_before(cls, data: Any) -> Any: @model_validator(mode="after") @classmethod def test_validator_after(cls, data: Any) -> Any: - if data.bar == "test": + if isinstance(data, dict): + if data["bar"] == "test": + raise ValueError("bar should not be test") + elif data.bar == "test": raise ValueError("bar should not be test") return data