Skip to content

Commit

Permalink
✨ version 0.4.0
Browse files Browse the repository at this point in the history
dispatcher.waiter & session.prompt
  • Loading branch information
RF-Tar-Railt committed May 3, 2024
1 parent 7fbdc9b commit 2e70d6a
Show file tree
Hide file tree
Showing 71 changed files with 964 additions and 872 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ command = EntariCommands()

@command.on("add {a} {b}")
async def add(a: int, b: int, session: ContextSession):
await session.send_message(f"{a + b =}")
await session.send(f"{a + b =}")


app = Entari()
Expand Down
3 changes: 3 additions & 0 deletions arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from .main import EntariCommands as EntariCommands
from .model import CommandResult as CommandResult
from .model import Match as Match
from .model import Query as Query
from .plugin import AlconnaDispatcher as AlconnaDispatcher
8 changes: 7 additions & 1 deletion arclet/entari/command/plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from arclet.alconna import Alconna

from ..event import MessageEvent
from ..plugin import Plugin, PluginDispatcher, PluginDispatcherFactory
from ..plugin import Plugin, PluginDispatcher, PluginDispatcherFactory, register_factory
from .provider import AlconnaProviderFactory, AlconnaSuppiler, MessageJudger


Expand All @@ -22,3 +22,9 @@ def dispatch(self, plugin: Plugin) -> PluginDispatcher:
disp.bind(MessageJudger(), AlconnaSuppiler(self.command, self.need_tome, self.remove_tome))
disp.bind(AlconnaProviderFactory())
return disp


register_factory(
Alconna,
lambda cmd, *args, **kwargs: AlconnaDispatcher(cmd, *args, **kwargs),
)
1 change: 1 addition & 0 deletions arclet/entari/command/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ async def __call__(self, scope: Scope, context: Contexts) -> Optional[Union[bool
elif may_help_text:
await account.send(context["$event"], MessageChain(may_help_text))
return False
return False

@property
def scopes(self) -> set[Scope]:
Expand Down
4 changes: 3 additions & 1 deletion arclet/entari/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def __call__(self, context: Contexts):

class ContextSessionProvider(Provider[ContextSession]):
async def __call__(self, context: Contexts):
if "$origin_event" and "$account" in context:
if "$origin_event" in context and "$account" in context:
return ContextSession(context["$account"], context["$origin_event"])


Expand All @@ -48,6 +48,8 @@ async def event_parse_task(connection: Account, raw: Event):
ev = event_parse(connection, raw)
self.event_system.publish(ev)
for disp in dispatchers.values():
if not disp.validate(ev):
continue
task = loop.create_task(disp.publish(ev))
self._ref_tasks.add(task)
task.add_done_callback(self._ref_tasks.discard)
Expand Down
50 changes: 45 additions & 5 deletions arclet/entari/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
import inspect
from os import PathLike
from pathlib import Path
from typing import Callable
from typing import Any, Callable, TypeVar, overload
from typing_extensions import Unpack

from arclet.letoderea import BaseEvent, Publisher, system_ctx
from arclet.letoderea import BaseAuxiliary, BaseEvent, Provider, Publisher, StepOut, system_ctx
from arclet.letoderea.builtin.breakpoint import R
from arclet.letoderea.typing import TTarget
from loguru import logger

dispatchers = {}
dispatchers: dict[str, PluginDispatcher] = {}


class PluginDispatcher(Publisher):
Expand All @@ -28,6 +31,23 @@ def __init__(
es.register(self)
else:
dispatchers[self.id] = self
self._events = events

def waiter(
self,
*events: type[BaseEvent],
providers: list[Provider | type[Provider]] | None = None,
auxiliaries: list[BaseAuxiliary] | None = None,
priority: int = 15,
block: bool = False,
) -> Callable[[TTarget[R]], StepOut[R]]:
def wrapper(func: TTarget[R]):
nonlocal events
if not events:
events = self._events
return StepOut(list(events), func, providers, auxiliaries, priority, block) # type: ignore

return wrapper

on = Publisher.register
handle = Publisher.register
Expand All @@ -38,6 +58,15 @@ class PluginDispatcherFactory(ABC):
def dispatch(self, plugin: Plugin) -> PluginDispatcher: ...


MAPPING: dict[type, Callable[..., PluginDispatcherFactory]] = {}

T = TypeVar("T")


def register_factory(cls: type[T], factory: Callable[[T, Unpack[tuple[Any, ...]]], PluginDispatcherFactory]):
MAPPING[cls] = factory


@dataclass
class Plugin:
author: list[str] = field(default_factory=list)
Expand Down Expand Up @@ -65,8 +94,19 @@ def dispatch(self, *events: type[BaseEvent], predicate: Callable[[BaseEvent], bo
self._dispatchers[disp.id] = disp
return disp

def mount(self, factory: PluginDispatcherFactory):
disp = factory.dispatch(self)
@overload
def mount(self, factory: PluginDispatcherFactory) -> PluginDispatcher: ...

@overload
def mount(self, factory: object, *args, **kwargs) -> PluginDispatcher: ...

def mount(self, factory: Any, *args, **kwargs):
if isinstance(factory, PluginDispatcherFactory):
disp = factory.dispatch(self)
elif factory_cls := MAPPING.get(factory.__class__):
disp = factory_cls(factory, *args, **kwargs).dispatch(self)
else:
raise TypeError(f"unsupported factory {factory!r}")
self._dispatchers[disp.id] = disp
return disp

Expand Down
76 changes: 75 additions & 1 deletion arclet/entari/session.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import NoReturn

from arclet.letoderea import ParsingStop, StepOut
from satori.client.account import Account
from satori.const import EventType
from satori.element import Element
from satori.model import Channel, Event, Member, MessageObject, PageResult, Role, User
from satori.model import Channel, Event, Guild, Member, MessageObject, PageResult, Role, User

from .event import MessageEvent
from .message import MessageChain


class ContextSession:
Expand All @@ -15,6 +20,75 @@ def __init__(self, account: Account, event: Event):
self.account = account
self.context = event

async def prompt(
self,
message: str | Iterable[str | Element],
timeout: float = 120,
timeout_message: str | Iterable[str | Element] = "等待超时",
) -> MessageChain:
"""发送提示消息, 并等待回复
参数:
message: 要发送的消息
"""
if self.context.type != EventType.MESSAGE_CREATED:
raise RuntimeError("Event cannot be prompted!")

await self.send(message)

async def waiter(content: MessageChain, session: ContextSession):
if (
self.context.channel
and session.context.channel
and self.context.channel.id == session.context.channel.id
):
return content
if self.context.user and session.context.user and self.context.user.id == session.context.user.id:
return content

waiter.__annotations__ = {"content": MessageChain, "session": self.__class__}

step = StepOut([MessageEvent], waiter)

result = await step.wait(timeout=timeout)
if not result:
await self.send(timeout_message)
raise ParsingStop()
return result

def stop(self) -> NoReturn:
raise ParsingStop()

@property
def user(self) -> User:
if not self.context.user:
raise RuntimeError(f"Event {self.context.type!r} has no User")
return self.context.user

@property
def guild(self) -> Guild:
if not self.context.guild:
raise RuntimeError(f"Event {self.context.type!r} has no Guild")
return self.context.guild

@property
def channel(self) -> Channel:
if not self.context.channel:
raise RuntimeError(f"Event {self.context.type!r} has no Channel")
return self.context.channel

@property
def member(self) -> Member:
if not self.context.member:
raise RuntimeError(f"Event {self.context.type!r} has no Member")
return self.context.member

@property
def content(self) -> str:
if not self.context.message:
raise RuntimeError(f"Event {self.context.type!r} has no Content")
return self.context.message.content

def __getattr__(self, item):
return getattr(self.account.session, item)

Expand Down
25 changes: 19 additions & 6 deletions example_plugin.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
from arclet.entari import MessageCreatedEvent, Plugin, EntariCommands, ContextSession, AlconnaDispatcher, is_direct_message
from arclet.alconna import Alconna, Args, AllParam
from arclet.alconna import Alconna, AllParam, Args

from arclet.entari import (
ContextSession,
EntariCommands,
MessageChain,
MessageCreatedEvent,
Plugin,
is_direct_message,
)
from arclet.entari.command import Match

plug = Plugin()

disp_message = plug.dispatch(MessageCreatedEvent)


@disp_message.on(auxiliaries=[is_direct_message])
@disp_message.on(auxiliaries=[])
async def _(event: MessageCreatedEvent):
print(event.content)


on_alconna = plug.mount(AlconnaDispatcher(Alconna("chat", Args["content", AllParam])))
on_alconna = plug.mount(Alconna("echo", Args["content?", AllParam]))


@on_alconna.on()
async def _(event: MessageCreatedEvent):
print("matched:", event.content)
async def _(content: Match[MessageChain], session: ContextSession):
if content.available:
await session.send(content.result)
return

await session.send(await session.prompt("请输入内容"))


commands = EntariCommands.current()
Expand Down
10 changes: 3 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from satori import Image
from arclet.entari import ContextSession, Entari, WebsocketsInfo, EntariCommands, load_plugin

from arclet.entari import ContextSession, Entari, EntariCommands, WebsocketsInfo, load_plugin

commands = EntariCommands()

Expand All @@ -11,11 +12,6 @@ async def echoimg(img: Image, session: ContextSession):

load_plugin("example_plugin")

app = Entari(
WebsocketsInfo(
port=12345,
path="foo"
)
)
app = Entari(WebsocketsInfo(host="127.0.0.1", port=5140, path="satori"))

app.run()
1 change: 1 addition & 0 deletions old/arclet/edoves/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING

if not TYPE_CHECKING:
import arclet.edoves.builtin.event

Expand Down
23 changes: 7 additions & 16 deletions old/arclet/edoves/builtin/actions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from typing import Optional, Union, Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, Union

from arclet.edoves.main.action import ExecutiveAction

from .medium import Message, Request
Expand All @@ -22,9 +23,7 @@ def __init__(self, target: Union[int, str], relationship: str, whole: bool = Fal
async def execute(self) -> "Monomer":
entity = self.target.protocol.current_scene.monomer_map.get(self.mono_id)
if not entity:
return await self.target.action(self.action)(
self.mono_id, self.rs, **self.rest
)
return await self.target.action(self.action)(self.mono_id, self.rs, **self.rest)
return entity


Expand Down Expand Up @@ -53,9 +52,7 @@ def __init__(self, action: str, message: Message):
self.data = message

async def execute(self):
return await self.target.action(self.action)(
self.data, target=self.target
)
return await self.target.action(self.action)(self.data, target=self.target)


class MessageRevoke(MessageAction):
Expand All @@ -66,9 +63,7 @@ def __init__(self, message: Message, target: int = None):
self.message_id = target

async def execute(self):
return await self.target.action(self.action)(
self.data, target=self.message_id
)
return await self.target.action(self.action)(self.data, target=self.message_id)


class MessageSend(MessageAction):
Expand All @@ -90,9 +85,7 @@ def __init__(self, message: Optional[Message] = None):

class MessageSendDirectly(MessageSend):
async def execute(self):
return await self.target.action(self.action)(
self.data, type=self.data.type
)
return await self.target.action(self.action)(self.data, type=self.data.type)


class RequestAction(ExecutiveAction):
Expand All @@ -105,9 +98,7 @@ def __init__(self, action: str, request: Request, msg: str):
self.data = request

async def execute(self):
return await self.target.action(self.action)(
self.data, msg=self.msg
)
return await self.target.action(self.action)(self.data, msg=self.msg)


class RequestAccept(RequestAction):
Expand Down
9 changes: 5 additions & 4 deletions old/arclet/edoves/builtin/alconna.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from arclet.alconna import Alconna, compile, Arpamar
from arclet.letoderea.utils import ArgumentPackage
from arclet.letoderea.exceptions import ParsingStop
from arclet.alconna import Alconna, Arpamar, compile
from arclet.letoderea.entities.auxiliary import BaseAuxiliary
from arclet.letoderea.exceptions import ParsingStop
from arclet.letoderea.utils import ArgumentPackage

from .message.chain import MessageChain


Expand All @@ -18,4 +19,4 @@ def supply(target_argument: ArgumentPackage) -> Arpamar:
res = self.analyser.analyse(target_argument.value)
if not res.matched:
raise ParsingStop
return res
return res
Loading

0 comments on commit 2e70d6a

Please sign in to comment.