Skip to content

Commit

Permalink
Removed calls to function causing deprecation warning where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
07pepa authored and 07pepa committed Aug 15, 2024
1 parent 0b8d50b commit 86cbc28
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 25 deletions.
8 changes: 6 additions & 2 deletions beanie/migrations/controllers/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from beanie.migrations.controllers.base import BaseMigrationController
from beanie.migrations.utils import update_dict
from beanie.odm.documents import Document
from beanie.odm.utils.pydantic import parse_model
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2, parse_model


class DummyOutput:
Expand Down Expand Up @@ -104,7 +104,11 @@ async def run(self, session):
if "self" in self.function_signature.parameters:
function_kwargs["self"] = None
await self.function(**function_kwargs)
output_dict = input_document.dict()
output_dict = (
input_document.dict()
if not IS_PYDANTIC_V2
else input_document.model_dump()
)
update_dict(output_dict, output.dict())
output_document = parse_model(
self.output_document_model, output_dict
Expand Down
7 changes: 5 additions & 2 deletions beanie/migrations/runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import types
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import List, Optional, Type
Expand Down Expand Up @@ -218,9 +219,11 @@ async def build(cls, path: Path):
prev_migration_node = root_migration_node

for name in names:
module = SourceFileLoader(
loader = SourceFileLoader(
(path / name).stem, str((path / name).absolute())
).load_module((path / name).stem)
)
module = types.ModuleType(loader.name)
loader.exec_module(module)
forward_class = getattr(module, "Forward", None)
backward_class = getattr(module, "Backward", None)
migration_node = cls(
Expand Down
26 changes: 15 additions & 11 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@
if TYPE_CHECKING:
from beanie.odm.documents import DocType

if IS_PYDANTIC_V2:
plain_validator = (
core_schema.with_info_plain_validator_function
if hasattr(core_schema, "with_info_plain_validator_function")
else core_schema.general_plain_validator_function
)
else:

def plain_validator(v):
return v


@dataclass(frozen=True)
class IndexedAnnotation:
Expand Down Expand Up @@ -147,9 +158,7 @@ def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema: # type: ignore
return core_schema.json_or_python_schema(
python_schema=core_schema.with_info_plain_validator_function(
cls.validate
),
python_schema=plain_validator(cls.validate),
json_schema=str_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: str(instance), when_used="json"
Expand Down Expand Up @@ -401,7 +410,7 @@ def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema: # type: ignore
return core_schema.json_or_python_schema(
python_schema=core_schema.with_info_plain_validator_function(
python_schema=plain_validator(
cls.build_validation(handler, source_type)
),
json_schema=core_schema.typed_dict_schema(
Expand All @@ -419,9 +428,6 @@ def __get_pydantic_core_schema__(
when_used="json", # type: ignore
),
)
return core_schema.with_info_plain_validator_function(
cls.build_validation(handler, source_type)
)

else:

Expand Down Expand Up @@ -481,9 +487,7 @@ def validate(v: Union[DBRef, T], field):
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema: # type: ignore
return core_schema.with_info_plain_validator_function(
cls.build_validation(handler, source_type)
)
return plain_validator(cls.build_validation(handler, source_type))

else:

Expand Down Expand Up @@ -588,7 +592,7 @@ def validate(v, _):
else:
return IndexModelField(IndexModel(v))

return core_schema.with_info_plain_validator_function(validate)
return plain_validator(validate)

else:

Expand Down
16 changes: 10 additions & 6 deletions tests/fastapi/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from contextlib import asynccontextmanager

import motor.motor_asyncio
from fastapi import FastAPI

Expand All @@ -6,11 +8,9 @@
from tests.fastapi.models import DoorAPI, HouseAPI, RoofAPI, WindowAPI
from tests.fastapi.routes import house_router

app = FastAPI()


@app.on_event("startup")
async def app_init():
@asynccontextmanager
async def live_span(_: FastAPI):
# CREATE MOTOR CLIENT
client = motor.motor_asyncio.AsyncIOMotorClient(Settings().mongodb_dsn)

Expand All @@ -19,6 +19,10 @@ async def app_init():
client.beanie_db,
document_models=[HouseAPI, WindowAPI, DoorAPI, RoofAPI],
)
yield


app = FastAPI(lifespan=live_span)

# ADD ROUTES
app.include_router(house_router, prefix="/v1", tags=["house"])
# ADD ROUTES
app.include_router(house_router, prefix="/v1", tags=["house"])
6 changes: 4 additions & 2 deletions tests/fastapi/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
from httpx import ASGITransport, AsyncClient

from tests.fastapi.app import app
from tests.fastapi.models import DoorAPI, HouseAPI, RoofAPI, WindowAPI
Expand All @@ -11,7 +11,9 @@ async def api_client(clean_db):
"""api client fixture."""
async with LifespanManager(app, startup_timeout=100, shutdown_timeout=100):
server_name = "https://localhost"
async with AsyncClient(app=app, base_url=server_name) as ac:
async with AsyncClient(
transport=ASGITransport(app=app), base_url=server_name
) as ac:
yield ac


Expand Down
5 changes: 4 additions & 1 deletion tests/fastapi/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ async def create_house(window: WindowAPI):

@house_router.post("/houses_with_window_link/", response_model=HouseAPI)
async def create_houses_with_window_link(window: WindowInput):
house = HouseAPI.parse_obj(
validator = (
HouseAPI.model_validate if IS_PYDANTIC_V2 else HouseAPI.parse_obj
)
house = validator(
dict(name="test_name", windows=[WindowAPI.link_from_id(window.id)])
)
await house.insert(link_rule=WriteRules.WRITE)
Expand Down
7 changes: 6 additions & 1 deletion tests/odm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ def validate(value, _: FieldInfo) -> Color:
return Color(value["value"])
return Color(value)

python_schema = core_schema.general_plain_validator_function(validate)
vf = (
core_schema.with_info_plain_validator_function
if hasattr(core_schema, "with_info_plain_validator_function")
else core_schema.general_plain_validator_function
)
python_schema = vf(validate)

return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
Expand Down

0 comments on commit 86cbc28

Please sign in to comment.