Skip to content

Commit

Permalink
✨ model_validator & field_validator for pydantic compat
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt authored Jan 31, 2025
1 parent b16ddf3 commit 604fb7d
Showing 1 changed file with 48 additions and 1 deletion.
49 changes: 48 additions & 1 deletion nonebot/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...

__all__ = (
"DEFAULT_CONFIG",
"PYDANTIC_V2",
"ConfigDict",
"FieldInfo",
"ModelField",
Expand All @@ -54,9 +55,11 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...
"TypeAdapter",
"custom_validation",
"extract_field_info",
"field_validator",
"model_config",
"model_dump",
"model_fields",
"model_validator",
"type_validate_json",
"type_validate_python",
)
Expand All @@ -70,6 +73,8 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ...
if PYDANTIC_V2: # pragma: pydantic-v2
from pydantic import GetCoreSchemaHandler
from pydantic import TypeAdapter as TypeAdapter
from pydantic import field_validator as field_validator
from pydantic import model_validator as model_validator
from pydantic._internal._repr import display_as_type
from pydantic.fields import FieldInfo as BaseFieldInfo
from pydantic_core import CoreSchema, core_schema
Expand Down Expand Up @@ -254,7 +259,7 @@ def custom_validation(class_: type["CVC"]) -> type["CVC"]:

else: # pragma: pydantic-v1
from pydantic import BaseConfig as PydanticConfig
from pydantic import Extra, parse_obj_as, parse_raw_as
from pydantic import Extra, parse_obj_as, parse_raw_as, root_validator, validator
from pydantic.fields import FieldInfo as BaseFieldInfo
from pydantic.fields import ModelField as BaseModelField
from pydantic.schema import get_annotation_from_field_info
Expand Down Expand Up @@ -367,6 +372,36 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
kwargs.update(field_info.extra)
return kwargs

@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal["before"],

Check failure on line 380 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (F821)

nonebot/compat.py:380:15: F821 Undefined name `Literal`

Check failure on line 380 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (F821)

nonebot/compat.py:380:24: F821 Undefined name `before`
check_fields: Optional[bool] = None,
): ...

@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal["after"],

Check failure on line 389 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (F821)

nonebot/compat.py:389:15: F821 Undefined name `Literal`

Check failure on line 389 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (F821)

nonebot/compat.py:389:24: F821 Undefined name `after`
check_fields: Optional[bool] = None,
): ...

def field_validator(
field: str,
/,
*fields: str,
mode: Literal["before", "after"],

Check failure on line 397 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (F821)

nonebot/compat.py:397:15: F821 Undefined name `Literal`

Check failure on line 397 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (F821)

nonebot/compat.py:397:24: F821 Undefined name `before`

Check failure on line 397 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (F821)

nonebot/compat.py:397:34: F821 Undefined name `after`
check_fields: Optional[bool] = None,
):
if mode == "before":
return validator(field, *fields, pre=True, check_fields=check_fields or True, allow_reuse=True)

Check failure on line 401 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (E501)

nonebot/compat.py:401:89: E501 Line too long (107 > 88)
else:
return validator(field, *fields, check_fields=check_fields or True, allow_reuse=True)

Check failure on line 403 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (E501)

nonebot/compat.py:403:89: E501 Line too long (97 > 88)

def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model."""

Expand Down Expand Up @@ -404,6 +439,18 @@ def model_dump(
exclude_none=exclude_none,
)

@overload
def model_validator(*, mode: Literal["before"]): ...

Check failure on line 443 in nonebot/compat.py

View workflow job for this annotation

GitHub Actions / Ruff Lint

Ruff (F821)

nonebot/compat.py:443:34: F821 Undefined name `Literal`

@overload
def model_validator(*, mode: Literal["after"]): ...

def model_validator(*, mode: Literal["before", "after"]):
if mode == "before":
return root_validator(pre=True, allow_reuse=True)
else:
return root_validator(skip_on_failure=True, allow_reuse=True)

def type_validate_python(type_: type[T], data: Any) -> T:
"""Validate data with given type."""
return parse_obj_as(type_, data)
Expand Down

0 comments on commit 604fb7d

Please sign in to comment.