mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-11-11 19:07:28 +03:00
wait_for method
This commit is contained in:
parent
8c7f4a8aec
commit
c9aead9ac2
12 changed files with 216 additions and 24 deletions
|
@ -7,7 +7,7 @@ import signal
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
from typing import Dict, List, Union, Any, Iterable, Optional
|
||||
from typing import Dict, List, Union, Any, Iterable, Optional, Callable
|
||||
|
||||
from .models import User, Guild, Activity
|
||||
from .models.app import Shard
|
||||
|
@ -18,6 +18,7 @@ from .core.gateway import GatewayBotInfo
|
|||
from .models.guild.channel import Channel, ChannelType, channel_types_for_converting
|
||||
from .utils.logging import init_logging
|
||||
from .models.app.intents import Intents
|
||||
from .utils.waiters import WaiterMgr
|
||||
|
||||
_logger = logging.getLogger("melisa")
|
||||
|
||||
|
@ -71,16 +72,17 @@ class Client:
|
|||
mobile: bool = False,
|
||||
logs: Union[None, int, str, Dict[str, Any]] = "INFO",
|
||||
):
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
self.shards: Dict[int, Shard] = {}
|
||||
self.http: HTTPClient = HTTPClient(token)
|
||||
self._events: Dict[str, Coro] = {}
|
||||
self._waiter_mgr = WaiterMgr(self._loop)
|
||||
|
||||
# ToDo: Transfer guilds in to the cache manager
|
||||
self.guilds = {}
|
||||
self.user = None
|
||||
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
self._gateway_info = self._loop.run_until_complete(self._get_gateway())
|
||||
|
||||
if isinstance(intents, Iterable):
|
||||
|
@ -164,6 +166,8 @@ class Client:
|
|||
print(f"Ignoring exception in {name}", file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
self._waiter_mgr.process_events(name, *args)
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run Bot without shards (only 0 shard)
|
||||
|
@ -278,5 +282,63 @@ class Client:
|
|||
channel_cls = channel_types_for_converting.get(data["type"], Channel)
|
||||
return channel_cls.from_dict(data)
|
||||
|
||||
async def wait_for(
|
||||
self,
|
||||
event_name: str,
|
||||
*,
|
||||
check: Optional[Callable[..., bool]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
"""|coro|
|
||||
|
||||
Waits for a WebSocket event to be dispatched.
|
||||
|
||||
This could be used to wait for a user to reply to a message,
|
||||
or to react to a message.
|
||||
|
||||
The ``timeout`` parameter is passed onto :func:`asyncio.wait_for`. By default,
|
||||
it does not timeout. Note that this does propagate the
|
||||
:exc:`asyncio.TimeoutError` for you in case of timeout and is provided for
|
||||
ease of use.
|
||||
|
||||
In case the event returns multiple arguments, a :class:`tuple` containing those
|
||||
arguments is returned instead.
|
||||
|
||||
This function returns the **first event that meets the requirements**.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Waiting for a user reply: ::
|
||||
@client.listen
|
||||
async def on_message_create(message):
|
||||
if message.content.startswith('$greet'):
|
||||
channel = await client.fetch_channel(message.channel_id)
|
||||
await channel.send('Say hello!')
|
||||
|
||||
def check(m):
|
||||
return m.content == "hello" and channel.id == message.channel_id
|
||||
|
||||
msg = await client.wait_for('on_message_create', check=check, timeout=10.0)
|
||||
await channel.send(f'Hello man!')
|
||||
|
||||
Parameters
|
||||
----------
|
||||
event_name: :class:`str`
|
||||
The type of event. It should starts with `on_`.
|
||||
check: Optional[Callable[[Any], :class:`bool`]]
|
||||
A predicate to check what to wait for. The arguments must meet the
|
||||
parameters of the event being waited for.
|
||||
timeout: Optional[float]
|
||||
The number of seconds to wait before timing out and raising
|
||||
:exc:`asyncio.TimeoutError`.
|
||||
|
||||
Returns
|
||||
------
|
||||
Any
|
||||
Returns no arguments, a single argument, or a :class:`tuple` of multiple
|
||||
arguments that mirrors the parameters passed in the event.
|
||||
"""
|
||||
return await self._waiter_mgr.wait_for(event_name, check, timeout)
|
||||
|
||||
|
||||
Bot = Client
|
||||
|
|
|
@ -14,6 +14,11 @@ class ClientException(MelisaException):
|
|||
pass
|
||||
|
||||
|
||||
class MelisaTimeoutError(MelisaException):
|
||||
"""Exception raised when `wait_for` method timed out
|
||||
"""
|
||||
|
||||
|
||||
class LoginFailure(ClientException):
|
||||
"""Fails to log you in from improper credentials or some other misc."""
|
||||
|
||||
|
|
|
@ -8,8 +8,6 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
|
|||
|
||||
|
||||
async def channel_create_listener(self, gateway, payload: dict):
|
||||
gateway.session_id = payload.get("session_id")
|
||||
|
||||
payload.update({"type": ChannelType(payload.pop("type"))})
|
||||
|
||||
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
||||
|
|
|
@ -8,8 +8,6 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
|
|||
|
||||
|
||||
async def channel_delete_listener(self, gateway, payload: dict):
|
||||
gateway.session_id = payload.get("session_id")
|
||||
|
||||
payload.update({"type": ChannelType(payload.pop("type"))})
|
||||
|
||||
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
||||
|
|
|
@ -9,15 +9,13 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
|
|||
|
||||
async def channel_update_listener(self, gateway, payload: dict):
|
||||
# ToDo: Replace None to the old channel object (so it requires cache manager)
|
||||
gateway.session_id = payload.get("session_id")
|
||||
|
||||
payload.update({"type": ChannelType(payload.pop("type"))})
|
||||
|
||||
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
||||
|
||||
channel = channel_cls.from_dict(payload)
|
||||
|
||||
await self.dispatch("on_channel_update", None, channel)
|
||||
await self.dispatch("on_channel_update", (None, channel))
|
||||
|
||||
return
|
||||
|
||||
|
|
|
@ -8,8 +8,6 @@ from ..models.guild import Guild
|
|||
|
||||
|
||||
async def guild_create_listener(self, gateway, payload: dict):
|
||||
gateway.session_id = payload.get("session_id")
|
||||
|
||||
guild_was_cached_as_none = False
|
||||
|
||||
guild = Guild.from_dict(payload)
|
||||
|
|
|
@ -8,8 +8,6 @@ from ..models.guild import UnavailableGuild
|
|||
|
||||
|
||||
async def guild_delete_listener(self, gateway, payload: dict):
|
||||
gateway.session_id = payload.get("session_id")
|
||||
|
||||
guild = UnavailableGuild.from_dict(payload)
|
||||
|
||||
self.guilds.pop(guild.id, None)
|
||||
|
|
|
@ -8,14 +8,12 @@ from ..models.guild import Guild
|
|||
|
||||
|
||||
async def guild_update_listener(self, gateway, payload: dict):
|
||||
gateway.session_id = payload.get("session_id")
|
||||
|
||||
new_guild = Guild.from_dict(payload)
|
||||
old_guild = self.guilds.get(new_guild.id)
|
||||
|
||||
self.guilds[new_guild.id] = new_guild
|
||||
|
||||
await self.dispatch("on_channel_create", old_guild, new_guild)
|
||||
await self.dispatch("on_guild_update", (old_guild, new_guild))
|
||||
|
||||
return
|
||||
|
||||
|
|
18
melisa/listeners/message_create.py
Normal file
18
melisa/listeners/message_create.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Copyright MelisaDev 2022 - Present
|
||||
# Full MIT License can be found in `LICENSE.txt` at the project root.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..models.message.message import Message
|
||||
from ..utils.types import Coro
|
||||
|
||||
|
||||
async def message_create_listener(self, gateway, payload: dict):
|
||||
message = Message.from_dict(payload)
|
||||
await self.dispatch("on_message_create", message)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def export() -> Coro:
|
||||
return message_create_listener
|
|
@ -23,6 +23,8 @@ class Intents(IntFlag):
|
|||
DIRECT_MESSAGES = 1 << 12
|
||||
DIRECT_MESSAGE_REACTIONS = 1 << 13
|
||||
DIRECT_MESSAGE_TYPING = 1 << 14
|
||||
MESSAGE_CONTENT = 1 << 15
|
||||
GUILD_SCHEDULED_EVENTS = 1 << 16
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Intents:
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import List, TYPE_CHECKING, Optional
|
||||
from typing import List, TYPE_CHECKING, Optional, Dict
|
||||
|
||||
from ...utils import Snowflake, Timestamp
|
||||
from ...utils import APIModelBase
|
||||
|
@ -173,8 +173,8 @@ class Message(APIModelBase):
|
|||
id: APINullable[Snowflake] = None
|
||||
channel_id: APINullable[Snowflake] = None
|
||||
guild_id: APINullable[Snowflake] = None
|
||||
author: APINullable[List] = None
|
||||
member: APINullable[List] = None
|
||||
author: APINullable[Dict] = None
|
||||
member: APINullable[Dict] = None
|
||||
content: APINullable[str] = None
|
||||
timestamp: APINullable[Timestamp] = None
|
||||
edited_timestamp: APINullable[Timestamp] = None
|
||||
|
@ -190,12 +190,12 @@ class Message(APIModelBase):
|
|||
pinned: APINullable[bool] = None
|
||||
webhook_id: APINullable[Snowflake] = None
|
||||
type: APINullable[int] = None
|
||||
activity: APINullable[List] = None
|
||||
application: APINullable[List] = None
|
||||
activity: APINullable[Dict] = None
|
||||
application: APINullable[Dict] = None
|
||||
application_id: APINullable[Snowflake] = None
|
||||
message_reference: APINullable[List] = None
|
||||
message_reference: APINullable[Dict] = None
|
||||
flags: APINullable[int] = None
|
||||
interaction: APINullable[List] = None
|
||||
interaction: APINullable[Dict] = None
|
||||
thread: APINullable[Thread] = None
|
||||
components: APINullable[List] = None
|
||||
sticker_items: APINullable[List] = None
|
||||
|
|
117
melisa/utils/waiters.py
Normal file
117
melisa/utils/waiters.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
# Copyright MelisaDev 2022 - Present
|
||||
# Full MIT License can be found in `LICENSE.txt` at the project root.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import AbstractEventLoop, Event, wait_for as async_wait, TimeoutError as AsyncTimeOut
|
||||
from typing import List, Callable, Optional
|
||||
|
||||
from ..exceptions import MelisaTimeoutError
|
||||
|
||||
|
||||
class _Waiter:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
event_name : str
|
||||
The name of the event.
|
||||
check : Optional[Callable[[Any], :class:`bool`]]
|
||||
``can_be_set`` only returns true if this function returns true.
|
||||
Will be ignored if set to None.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
event: :class:`asyncio.Event`
|
||||
Even that is used to wait until the next valid discord event.
|
||||
return_value : Optional[str]
|
||||
Used to store the arguments from ``can_be_set`` so they can be
|
||||
returned later.
|
||||
"""
|
||||
|
||||
def __init__(self, event_name: str, check: Optional[Callable] = None):
|
||||
self.event_name = event_name
|
||||
self.check = check
|
||||
self.event = Event()
|
||||
self.return_value = None
|
||||
super().__init__()
|
||||
|
||||
async def wait(self):
|
||||
"""Waits until ``self.event`` is set."""
|
||||
await self.event.wait()
|
||||
|
||||
def process(self, event_name: str, event_value):
|
||||
if self.event_name != event_name:
|
||||
return False
|
||||
|
||||
if self.check:
|
||||
if event_value is not None:
|
||||
if not self.check(event_value):
|
||||
return
|
||||
else:
|
||||
if not self.check():
|
||||
return
|
||||
|
||||
self.return_value = event_value
|
||||
self.event.set()
|
||||
|
||||
|
||||
class WaiterMgr:
|
||||
"""
|
||||
Attributes
|
||||
----------
|
||||
waiter_list : List
|
||||
The List of events that need to be processed.
|
||||
"""
|
||||
|
||||
def __init__(self, loop: AbstractEventLoop):
|
||||
self.waiter_list: List[_Waiter] = []
|
||||
self.loop = loop
|
||||
|
||||
def process_events(self, event_name, event_value):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
event_name : str
|
||||
The name of the event to be processed.
|
||||
event_value : Any
|
||||
The object returned from the middleware for this event.
|
||||
"""
|
||||
for waiter in self.waiter_list:
|
||||
waiter.process(event_name, event_value)
|
||||
|
||||
async def wait_for(
|
||||
self,
|
||||
event_name: str,
|
||||
check: Optional[Callable[..., bool]] = None,
|
||||
timeout: Optional[float] = None
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
event_name: :class:`str`
|
||||
The type of event. It should start with `on_`.
|
||||
check: Optional[Callable[[Any], :class:`bool`]]
|
||||
This function only returns a value if this return true.
|
||||
timeout: Optional[float]
|
||||
Amount of seconds before timeout. Use None for no timeout.
|
||||
|
||||
Returns
|
||||
------
|
||||
Any
|
||||
What the Discord API returns for this event.
|
||||
"""
|
||||
|
||||
waiter = _Waiter(event_name, check)
|
||||
self.waiter_list.append(waiter)
|
||||
|
||||
try:
|
||||
await async_wait(waiter.wait(), timeout=timeout)
|
||||
self.waiter_list.remove(waiter)
|
||||
except AsyncTimeOut:
|
||||
self.waiter_list.remove(waiter)
|
||||
raise MelisaTimeoutError(
|
||||
"wait_for() timed out while waiting for an event."
|
||||
)
|
||||
|
||||
return waiter.return_value
|
||||
|
Loading…
Reference in a new issue