Skip to content

Commit

Permalink
Change validation schema for PydanticObjectId
Browse files Browse the repository at this point in the history
  • Loading branch information
dantetemplar committed Jan 2, 2025
1 parent 81da389 commit 0f1ce78
Showing 1 changed file with 84 additions and 58 deletions.
142 changes: 84 additions & 58 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,8 @@
TypeAdapter,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core.core_schema import (
CoreSchema,
ValidationInfo,
any_schema,
dict_schema,
json_or_python_schema,
no_info_after_validator_function,
no_info_plain_validator_function,
plain_serializer_function_ser_schema,
simple_ser_schema,
str_schema,
typed_dict_field,
typed_dict_schema,
union_schema,
with_info_plain_validator_function,
)
from pydantic_core import core_schema
from pydantic_core.core_schema import CoreSchema, ValidationInfo
else:
from pydantic.fields import ModelField
from pydantic.json import ENCODERS_BY_TYPE
Expand Down Expand Up @@ -119,8 +105,8 @@ def __get_pydantic_core_schema__(
if custom_type is not None:
return custom_type(_source_type, _handler)

return no_info_after_validator_function(
lambda v: v, simple_ser_schema(typ.__name__)
return core_schema.no_info_after_validator_function(
lambda v: v, core_schema.simple_ser_schema(typ.__name__)
)

NewType.__name__ = f"Indexed {typ.__name__}"
Expand All @@ -147,34 +133,70 @@ def _validate(cls, v):
def __get_pydantic_core_schema__(
cls, source_type: Type[Any], handler: GetCoreSchemaHandler
) -> CoreSchema:
return json_or_python_schema(
python_schema=no_info_plain_validator_function(cls._validate),
json_schema=no_info_plain_validator_function(
cls._validate,
metadata={
"pydantic_js_input_core_schema": str_schema(
pattern="^[0-9a-f]{24}$",
min_length=24,
max_length=24,
)
},
),
serialization=plain_serializer_function_ser_schema(
lambda instance: str(instance),
return_schema=str_schema(),
when_used="json",
),
definition = core_schema.definition_reference_schema(
"PydanticObjectId"
) # used for deduplication

return core_schema.definitions_schema(
definition,
[
core_schema.json_or_python_schema(
python_schema=core_schema.no_info_plain_validator_function(
cls._validate
),
json_schema=core_schema.no_info_after_validator_function(
cls._validate,
core_schema.str_schema(
pattern="^[0-9a-f]{24}$",
min_length=24,
max_length=24,
),
),
serialization=core_schema.plain_serializer_function_ser_schema(
str, when_used="json"
),
ref=definition["schema_ref"],
)
],
)

@classmethod
def __get_pydantic_json_schema__(
cls, schema: CoreSchema, handler: GetJsonSchemaHandler
cls,
schema: core_schema.CoreSchema,
handler: GetJsonSchemaHandler, # type: ignore
) -> JsonSchemaValue:
"""
Results such schema:
```json
{
"components": {
"schemas": {
"Item": {
"properties": {
"id": {
"$ref": "#/components/schemas/PydanticObjectId"
}
},
"type": "object",
"title": "Item"
},
"PydanticObjectId": {
"type": "string",
"maxLength": 24,
"minLength": 24,
"pattern": "^[0-9a-f]{24}$",
"example": "5eb7cf5a86d9755df3a6c593"
}
}
}
}
```
"""

json_schema = handler(schema)
json_schema.update(
type="string",
example="5eb7cf5a86d9755df3a6c593",
)
schema_to_update = handler.resolve_ref_schema(json_schema)
schema_to_update.update(example="5eb7cf5a86d9755df3a6c593")
return json_schema

else:
Expand Down Expand Up @@ -412,25 +434,29 @@ def validate(
def __get_pydantic_core_schema__(
cls, source_type: Type[Any], handler: GetCoreSchemaHandler
) -> CoreSchema:
return json_or_python_schema(
python_schema=with_info_plain_validator_function(
return core_schema.json_or_python_schema(
python_schema=core_schema.with_info_plain_validator_function(
cls.wrapped_validate(source_type, handler)
),
json_schema=union_schema(
json_schema=core_schema.union_schema(
[
typed_dict_schema(
core_schema.typed_dict_schema(
{
"id": typed_dict_field(str_schema()),
"collection": typed_dict_field(str_schema()),
"id": core_schema.typed_dict_field(
core_schema.str_schema()
),
"collection": core_schema.typed_dict_field(
core_schema.str_schema()
),
}
),
dict_schema(
keys_schema=str_schema(),
values_schema=any_schema(),
core_schema.dict_schema(
keys_schema=core_schema.str_schema(),
values_schema=core_schema.any_schema(),
),
]
),
serialization=plain_serializer_function_ser_schema(
serialization=core_schema.plain_serializer_function_ser_schema(
function=lambda instance: cls.serialize(instance),
when_used="json-unless-none",
),
Expand Down Expand Up @@ -530,17 +556,17 @@ def __get_pydantic_core_schema__(
cls, source_type: Type[Any], handler: GetCoreSchemaHandler
) -> CoreSchema:
# NOTE: BackLinks are only virtual fields, they shouldn't be serialized nor appear in the schema.
return json_or_python_schema(
python_schema=with_info_plain_validator_function(
return core_schema.json_or_python_schema(
python_schema=core_schema.with_info_plain_validator_function(
cls.wrapped_validate(source_type, handler)
),
json_schema=dict_schema(
keys_schema=str_schema(),
values_schema=any_schema(),
json_schema=core_schema.dict_schema(
keys_schema=core_schema.str_schema(),
values_schema=core_schema.any_schema(),
),
serialization=plain_serializer_function_ser_schema(
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: cls.to_dict(instance),
return_schema=dict_schema(),
return_schema=core_schema.dict_schema(),
when_used="json-unless-none",
),
)
Expand Down Expand Up @@ -676,7 +702,7 @@ def _validate(cls, v: Any) -> "IndexModelField":
def __get_pydantic_core_schema__(
cls, source_type: Type[Any], handler: GetCoreSchemaHandler
) -> CoreSchema:
return no_info_plain_validator_function(cls._validate)
return core_schema.no_info_plain_validator_function(cls._validate)

else:

Expand Down

0 comments on commit 0f1ce78

Please sign in to comment.