Skip to content

Commit

Permalink
Replace deprecated functions with new pydantic v2 functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisLovering committed Aug 11, 2023
1 parent 38ad2c4 commit 292fba7
Show file tree
Hide file tree
Showing 15 changed files with 53 additions and 42 deletions.
4 changes: 2 additions & 2 deletions bot/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
from enum import Enum

from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from pydantic_settings import BaseSettings


Expand Down Expand Up @@ -311,7 +311,7 @@ class _Colours(EnvConfig, env_prefix="colours_"):
white: int = 0xfffffe
yellow: int = 0xffd241

@root_validator(pre=True)
@model_validator(mode="before")
def parse_hex_values(cls, values: dict) -> dict: # noqa: N805
"""Convert hex strings to ints."""
for key, value in values.items():
Expand Down
2 changes: 1 addition & 1 deletion bot/exts/filtering/_filter_lists/antispam.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def actions_for(
current_actions.pop("ping", None)
current_actions.pop("send_alert", None)

new_infraction = current_actions[InfractionAndNotification.name].copy()
new_infraction = current_actions[InfractionAndNotification.name].model_copy()
# Smaller infraction value => higher in hierarchy.
if not current_infraction or new_infraction.infraction_type.value < current_infraction.value:
# Pick the first triggered filter for the reason, there's no good way to decide between them.
Expand Down
4 changes: 2 additions & 2 deletions bot/exts/filtering/_filters/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, filter_data: dict, defaults: Defaults | None = None):
self.updated_at = arrow.get(filter_data["updated_at"])
self.actions, self.validations = create_settings(filter_data["settings"], defaults=defaults)
if self.extra_fields_type:
self.extra_fields = self.extra_fields_type.parse_obj(filter_data["additional_settings"])
self.extra_fields = self.extra_fields_type.model_validate(filter_data["additional_settings"])
else:
self.extra_fields = None

Expand All @@ -46,7 +46,7 @@ def overrides(self) -> tuple[dict[str, Any], dict[str, Any]]:

filter_settings = {}
if self.extra_fields:
filter_settings = self.extra_fields.dict(exclude_unset=True)
filter_settings = self.extra_fields.model_dump(exclude_unset=True)

return settings, filter_settings

Expand Down
2 changes: 1 addition & 1 deletion bot/exts/filtering/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,5 +227,5 @@ def dict(self) -> dict[str, Any]:
"""Return a dict representation of the stored fields across all entries."""
dict_ = {}
for settings in self:
dict_ = reduce(operator.or_, (entry.dict() for entry in settings.values()), dict_)
dict_ = reduce(operator.or_, (entry.model_dump() for entry in settings.values()), dict_)
return dict_
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dateutil.relativedelta import relativedelta
from discord import Colour, Embed, Member, User
from discord.errors import Forbidden
from pydantic import validator
from pydantic import field_validator
from pydis_core.utils.logging import get_logger
from pydis_core.utils.members import get_or_fetch_member

Expand Down Expand Up @@ -151,7 +151,7 @@ class InfractionAndNotification(ActionEntry):
infraction_duration: InfractionDuration
infraction_channel: int

@validator("infraction_type", pre=True)
@field_validator("infraction_type", mode="before")
@classmethod
def convert_infraction_name(cls, infr_type: str | Infraction) -> Infraction:
"""Convert the string to an Infraction by name."""
Expand Down Expand Up @@ -221,24 +221,24 @@ def union(self, other: Self) -> Self:
"""
# Lower number -> higher in the hierarchy
if self.infraction_type is None:
return other.copy()
return other.model_copy()
if other.infraction_type is None:
return self.copy()
return self.model_copy()

if self.infraction_type.value < other.infraction_type.value:
result = self.copy()
result = self.model_copy()
elif self.infraction_type.value > other.infraction_type.value:
result = other.copy()
result = other.model_copy()
other = self
else:
now = arrow.utcnow().datetime
if self.infraction_duration is None or (
other.infraction_duration is not None
and now + self.infraction_duration.value > now + other.infraction_duration.value
):
result = self.copy()
result = self.model_copy()
else:
result = other.copy()
result = other.model_copy()
other = self

# If the winner has no message but the loser does, copy the message to the winner.
Expand Down
4 changes: 2 additions & 2 deletions bot/exts/filtering/_settings_types/actions/ping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import ClassVar, Self

from pydantic import validator
from pydantic import field_validator

from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ActionEntry
Expand All @@ -25,7 +25,7 @@ class Ping(ActionEntry):
guild_pings: set[str]
dm_pings: set[str]

@validator("*", pre=True)
@field_validator("*", mode="before")
@classmethod
def init_sequence_if_none(cls, pings: list[str] | None) -> list[str]:
"""Initialize an empty sequence if the value is None."""
Expand Down
2 changes: 1 addition & 1 deletion bot/exts/filtering/_settings_types/settings_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SettingsEntry(BaseModel, FieldRequiring):
def __init__(self, defaults: SettingsEntry | None = None, /, **data):
overrides = set()
if defaults:
defaults_dict = defaults.dict()
defaults_dict = defaults.model_dump()
for field_name, field_value in list(data.items()):
if field_value is None:
data[field_name] = defaults_dict[field_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import ClassVar, Union

from pydantic import validator
from pydantic import field_validator

from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._settings_types.settings_entry import ValidationEntry
Expand Down Expand Up @@ -36,7 +36,7 @@ class ChannelScope(ValidationEntry):
enabled_channels: set[Union[int, str]] # noqa: UP007
enabled_categories: set[Union[int, str]] # noqa: UP007

@validator("*", pre=True)
@field_validator("*", mode="before")
@classmethod
def init_if_sequence_none(cls, sequence: list[str] | None) -> list[str]:
"""Initialize an empty sequence if the value is None."""
Expand Down
6 changes: 3 additions & 3 deletions bot/exts/filtering/_ui/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def build_filter_repr_dict(
default_setting_values = {}
for settings_group in filter_list[list_type].defaults:
for _, setting in settings_group.items():
default_setting_values.update(to_serializable(setting.dict(), ui_repr=True))
default_setting_values.update(to_serializable(setting.model_dump(), ui_repr=True))

# Add overrides. It's done in this way to preserve field order, since the filter won't have all settings.
total_values = {}
Expand All @@ -47,7 +47,7 @@ def build_filter_repr_dict(
# Add the filter-specific settings.
if filter_type.extra_fields_type:
# This iterates over the default values of the extra fields model.
for name, value in filter_type.extra_fields_type().dict().items():
for name, value in filter_type.extra_fields_type().model_dump().items():
if name not in extra_fields_overrides or repr_equals(extra_fields_overrides[name], value):
total_values[f"{filter_type.name}/{name}"] = value
else:
Expand Down Expand Up @@ -287,7 +287,7 @@ async def update_embed(
if "/" in setting_name:
filter_name, setting_name = setting_name.split("/", maxsplit=1)
dict_to_edit = self.filter_settings_overrides
default_value = self.filter_type.extra_fields_type().dict()[setting_name]
default_value = self.filter_type.extra_fields_type().model_dump()[setting_name]
else:
dict_to_edit = self.settings_overrides
default_value = self.filter_list[self.list_type].default(setting_name)
Expand Down
2 changes: 1 addition & 1 deletion bot/exts/filtering/_ui/filter_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def build_filterlist_repr_dict(filter_list: FilterList, list_type: ListType, new
default_setting_values = {}
for settings_group in filter_list[list_type].defaults:
for _, setting in settings_group.items():
default_setting_values.update(to_serializable(setting.dict(), ui_repr=True))
default_setting_values.update(to_serializable(setting.model_dump(), ui_repr=True))

# Add new values. It's done in this way to preserve field order, since the new_values won't have all settings.
total_values = {}
Expand Down
13 changes: 9 additions & 4 deletions bot/exts/filtering/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import types
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from functools import cache
from typing import Any, Self, TypeVar, Union, get_args, get_origin

import discord
import regex
from discord.ext.commands import Command
from pydantic_core import core_schema

import bot
from bot.bot import Bot
Expand Down Expand Up @@ -252,12 +253,16 @@ def __init__(self, value: Any):
self.value = self.process_value(value)

@classmethod
def __get_validators__(cls):
def __get_pydantic_core_schema__(
cls,
_source: type[Any],
_handler: Callable[[Any], core_schema.CoreSchema],
) -> core_schema.CoreSchema:
"""Boilerplate for Pydantic."""
yield cls.validate
return core_schema.general_plain_validator_function(cls.validate)

@classmethod
def validate(cls, v: Any) -> Self:
def validate(cls, v: Any, _info: core_schema.ValidationInfo) -> Self:
"""Takes the given value and returns a class instance with that value."""
if isinstance(v, CustomIOField):
return cls(v.value)
Expand Down
14 changes: 10 additions & 4 deletions bot/exts/filtering/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def collect_loaded_types(self, example_list: AtomicList) -> None:
extra_fields_type,
type_hints[field_name]
)
for field_name in extra_fields_type.__fields__
for field_name in extra_fields_type.model_fields
}

async def schedule_offending_messages_deletion(self) -> None:
Expand Down Expand Up @@ -754,7 +754,7 @@ async def fl_describe(
setting_values = {}
for settings_group in filter_list[list_type].defaults:
for _, setting in settings_group.items():
setting_values.update(to_serializable(setting.dict(), ui_repr=True))
setting_values.update(to_serializable(setting.model_dump(), ui_repr=True))

embed = Embed(colour=Colour.blue())
populate_embed_from_dict(embed, setting_values)
Expand Down Expand Up @@ -1239,7 +1239,13 @@ async def _patch_filter(
for current_settings in (filter_.actions, filter_.validations):
if current_settings:
for setting_entry in current_settings.values():
settings.update({setting: None for setting in setting_entry.dict() if setting not in settings})
settings.update(
{
setting: None
for setting in setting_entry.model_dump()
if setting not in settings
}
)

# Even though the list ID remains unchanged, it still needs to be provided for correct serializer validation.
list_id = filter_list[list_type].id
Expand Down Expand Up @@ -1295,7 +1301,7 @@ def _filter_match_query(
if not (differ_by_default <= override_matches): # The overrides didn't cover for the default mismatches.
return False

filter_settings = filter_.extra_fields.dict() if filter_.extra_fields else {}
filter_settings = filter_.extra_fields.model_dump() if filter_.extra_fields else {}
# If the dict changes then some fields were not the same.
return (filter_settings | filter_settings_query) == filter_settings

Expand Down
12 changes: 6 additions & 6 deletions bot/exts/recruitment/talentpool/_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime

from pydantic import BaseModel, Field, parse_obj_as
from pydantic import BaseModel, Field, TypeAdapter
from pydis_core.site_api import APIClient


Expand Down Expand Up @@ -50,13 +50,13 @@ async def get_nominations(
params["user__id"] = str(user_id)

data = await self.site_api.get("bot/nominations", params=params)
nominations = parse_obj_as(list[Nomination], data)
nominations = TypeAdapter(list[Nomination]).validate_python(data)
return nominations

async def get_nomination(self, nomination_id: int) -> Nomination:
"""Fetch a nomination by ID."""
data = await self.site_api.get(f"bot/nominations/{nomination_id}")
nomination = Nomination.parse_obj(data)
nomination = Nomination.model_validate(data)
return nomination

async def edit_nomination(
Expand Down Expand Up @@ -84,7 +84,7 @@ async def edit_nomination(
data["thread_id"] = thread_id

result = await self.site_api.patch(f"bot/nominations/{nomination_id}", json=data)
return Nomination.parse_obj(result)
return Nomination.model_validate(result)

async def edit_nomination_entry(
self,
Expand All @@ -96,7 +96,7 @@ async def edit_nomination_entry(
"""Edit a nomination entry."""
data = {"actor": actor_id, "reason": reason}
result = await self.site_api.patch(f"bot/nominations/{nomination_id}", json=data)
return Nomination.parse_obj(result)
return Nomination.model_validate(result)

async def post_nomination(
self,
Expand All @@ -111,7 +111,7 @@ async def post_nomination(
"user": user_id,
}
result = await self.site_api.post("bot/nominations", json=data)
return Nomination.parse_obj(result)
return Nomination.model_validate(result)

async def get_activity(
self,
Expand Down
6 changes: 3 additions & 3 deletions botstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def create_webhook(self, name: str, channel_id_: int) -> str:

all_roles = discord_client.get_all_roles()

for role_name in _Roles.__fields__:
for role_name in _Roles.model_fields:

role_id = all_roles.get(role_name, None)
if not role_id:
Expand Down Expand Up @@ -209,7 +209,7 @@ def create_webhook(self, name: str, channel_id_: int) -> str:
python_help_channel_id = discord_client.create_forum_channel(python_help_channel_name, python_help_category_id)
all_channels[PYTHON_HELP_CHANNEL_NAME] = python_help_channel_id

for channel_name in _Channels.__fields__:
for channel_name in _Channels.model_fields:
channel_id = all_channels.get(channel_name, None)
if not channel_id:
log.warning(
Expand All @@ -222,7 +222,7 @@ def create_webhook(self, name: str, channel_id_: int) -> str:

config_str += "\n#Categories\n"

for category_name in _Categories.__fields__:
for category_name in _Categories.model_fields:
category_id = all_categories.get(category_name, None)
if not category_id:
log.warning(
Expand Down
4 changes: 2 additions & 2 deletions tests/bot/exts/filtering/test_settings_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_infraction_merge_of_same_infraction_type(self):
result = infraction1.union(infraction2)

self.assertDictEqual(
result.dict(),
result.model_dump(),
{
"infraction_type": Infraction.TIMEOUT,
"infraction_reason": "there",
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_infraction_merge_of_different_infraction_types(self):
result = infraction1.union(infraction2)

self.assertDictEqual(
result.dict(),
result.model_dump(),
{
"infraction_type": Infraction.BAN,
"infraction_reason": "",
Expand Down

0 comments on commit 292fba7

Please sign in to comment.