Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: State ForwardRef 检测错误 #2698

Merged
merged 2 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions nonebot/internal/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@

from nonebot.dependencies import Param, Dependent
from nonebot.dependencies.utils import check_field_type
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
from nonebot.typing import (
_STATE_FLAG,
T_State,
T_Handler,
T_DependencyCache,
origin_is_annotated,
)
from nonebot.utils import (
get_name,
run_sync,
Expand Down Expand Up @@ -349,7 +355,9 @@ def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
# param type is T_State
if param.annotation is T_State:
if origin_is_annotated(
get_origin(param.annotation)
) and _STATE_FLAG in get_args(param.annotation):
return cls()
# legacy: param is named "state" and has no type annotation
elif param.annotation == param.empty and param.name == "state":
Expand Down
10 changes: 9 additions & 1 deletion nonebot/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,15 @@ def evaluate_forwardref(


# state
T_State: TypeAlias = dict[t.Any, t.Any]
# use annotated flag to avoid ForwardRef recreate generic type (py >= 3.11)
class StateFlag:
def __repr__(self) -> str:
return "StateFlag()"


_STATE_FLAG = StateFlag()

T_State: TypeAlias = t.Annotated[dict[t.Any, t.Any], _STATE_FLAG]
"""事件处理状态 State 类型"""

_DependentCallable: TypeAlias = t.Union[
Expand Down
4 changes: 4 additions & 0 deletions tests/plugins/param/param_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ async def get_bot(b: Bot) -> Bot:
return b


async def postpone_bot(b: "Bot") -> Bot:
return b


async def legacy_bot(bot):
return bot

Expand Down
4 changes: 4 additions & 0 deletions tests/plugins/param/param_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ async def event(e: Event) -> Event:
return e


async def postpone_event(e: "Event") -> Event:
return e


async def legacy_event(event):
return event

Expand Down
6 changes: 5 additions & 1 deletion tests/plugins/param/param_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ async def matcher(m: Matcher) -> Matcher:
return m


async def postpone_matcher(m: "Matcher") -> Matcher:
return m


async def legacy_matcher(matcher):
return matcher

Expand All @@ -27,7 +31,7 @@ class BarMatcher(Matcher): ...


async def union_matcher(
m: Union[FooMatcher, BarMatcher]
m: Union[FooMatcher, BarMatcher],
) -> Union[FooMatcher, BarMatcher]:
return m

Expand Down
4 changes: 4 additions & 0 deletions tests/plugins/param/param_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ async def state(x: T_State) -> T_State:
return x


async def postpone_state(x: "T_State") -> T_State:
return x


async def legacy_state(state):
return state

Expand Down
21 changes: 21 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ async def test_bot(app: App):
union_bot,
legacy_bot,
generic_bot,
postpone_bot,
not_legacy_bot,
generic_bot_none,
)
Expand All @@ -138,6 +139,11 @@ async def test_bot(app: App):
ctx.pass_params(bot=bot)
ctx.should_return(bot)

async with app.test_dependent(postpone_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
ctx.should_return(bot)

async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
Expand Down Expand Up @@ -188,6 +194,7 @@ async def test_event(app: App):
legacy_event,
event_message,
generic_event,
postpone_event,
event_plain_text,
not_legacy_event,
generic_event_none,
Expand All @@ -201,6 +208,10 @@ async def test_event(app: App):
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)

async with app.test_dependent(postpone_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)

async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)
Expand Down Expand Up @@ -273,6 +284,7 @@ async def test_state(app: App):
legacy_state,
command_start,
regex_matched,
postpone_state,
not_legacy_state,
command_whitespace,
shell_command_args,
Expand Down Expand Up @@ -302,6 +314,10 @@ async def test_state(app: App):
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)

async with app.test_dependent(postpone_state, allow_types=[StateParam]) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)

async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)
Expand Down Expand Up @@ -414,6 +430,7 @@ async def test_matcher(app: App):
union_matcher,
legacy_matcher,
generic_matcher,
postpone_matcher,
not_legacy_matcher,
generic_matcher_none,
)
Expand All @@ -425,6 +442,10 @@ async def test_matcher(app: App):
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(fake_matcher)

async with app.test_dependent(postpone_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(fake_matcher)

async with app.test_dependent(legacy_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(fake_matcher)
Expand Down
Loading