wait_for method

This commit is contained in:
grey-cat-1908 2022-04-11 14:26:49 +03:00
parent 8c7f4a8aec
commit c9aead9ac2
12 changed files with 216 additions and 24 deletions

View file

@ -7,7 +7,7 @@ import signal
import sys
import traceback
from typing import Dict, List, Union, Any, Iterable, Optional
from typing import Dict, List, Union, Any, Iterable, Optional, Callable
from .models import User, Guild, Activity
from .models.app import Shard
@ -18,6 +18,7 @@ from .core.gateway import GatewayBotInfo
from .models.guild.channel import Channel, ChannelType, channel_types_for_converting
from .utils.logging import init_logging
from .models.app.intents import Intents
from .utils.waiters import WaiterMgr
_logger = logging.getLogger("melisa")
@ -71,16 +72,17 @@ class Client:
mobile: bool = False,
logs: Union[None, int, str, Dict[str, Any]] = "INFO",
):
self._loop = asyncio.get_event_loop()
self.shards: Dict[int, Shard] = {}
self.http: HTTPClient = HTTPClient(token)
self._events: Dict[str, Coro] = {}
self._waiter_mgr = WaiterMgr(self._loop)
# ToDo: Transfer guilds in to the cache manager
self.guilds = {}
self.user = None
self._loop = asyncio.get_event_loop()
self._gateway_info = self._loop.run_until_complete(self._get_gateway())
if isinstance(intents, Iterable):
@ -164,6 +166,8 @@ class Client:
print(f"Ignoring exception in {name}", file=sys.stderr)
traceback.print_exc()
self._waiter_mgr.process_events(name, *args)
def run(self) -> None:
"""
Run Bot without shards (only 0 shard)
@ -278,5 +282,63 @@ class Client:
channel_cls = channel_types_for_converting.get(data["type"], Channel)
return channel_cls.from_dict(data)
async def wait_for(
self,
event_name: str,
*,
check: Optional[Callable[..., bool]] = None,
timeout: Optional[float] = None,
):
"""|coro|
Waits for a WebSocket event to be dispatched.
This could be used to wait for a user to reply to a message,
or to react to a message.
The ``timeout`` parameter is passed onto :func:`asyncio.wait_for`. By default,
it does not timeout. Note that this does propagate the
:exc:`asyncio.TimeoutError` for you in case of timeout and is provided for
ease of use.
In case the event returns multiple arguments, a :class:`tuple` containing those
arguments is returned instead.
This function returns the **first event that meets the requirements**.
Examples
--------
Waiting for a user reply: ::
@client.listen
async def on_message_create(message):
if message.content.startswith('$greet'):
channel = await client.fetch_channel(message.channel_id)
await channel.send('Say hello!')
def check(m):
return m.content == "hello" and channel.id == message.channel_id
msg = await client.wait_for('on_message_create', check=check, timeout=10.0)
await channel.send(f'Hello man!')
Parameters
----------
event_name: :class:`str`
The type of event. It should starts with `on_`.
check: Optional[Callable[[Any], :class:`bool`]]
A predicate to check what to wait for. The arguments must meet the
parameters of the event being waited for.
timeout: Optional[float]
The number of seconds to wait before timing out and raising
:exc:`asyncio.TimeoutError`.
Returns
------
Any
Returns no arguments, a single argument, or a :class:`tuple` of multiple
arguments that mirrors the parameters passed in the event.
"""
return await self._waiter_mgr.wait_for(event_name, check, timeout)
Bot = Client

View file

@ -14,6 +14,11 @@ class ClientException(MelisaException):
pass
class MelisaTimeoutError(MelisaException):
"""Exception raised when `wait_for` method timed out
"""
class LoginFailure(ClientException):
"""Fails to log you in from improper credentials or some other misc."""

View file

@ -8,8 +8,6 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
async def channel_create_listener(self, gateway, payload: dict):
gateway.session_id = payload.get("session_id")
payload.update({"type": ChannelType(payload.pop("type"))})
channel_cls = channel_types_for_converting.get(payload["type"], Channel)

View file

@ -8,8 +8,6 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
async def channel_delete_listener(self, gateway, payload: dict):
gateway.session_id = payload.get("session_id")
payload.update({"type": ChannelType(payload.pop("type"))})
channel_cls = channel_types_for_converting.get(payload["type"], Channel)

View file

@ -9,15 +9,13 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
async def channel_update_listener(self, gateway, payload: dict):
# ToDo: Replace None to the old channel object (so it requires cache manager)
gateway.session_id = payload.get("session_id")
payload.update({"type": ChannelType(payload.pop("type"))})
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
channel = channel_cls.from_dict(payload)
await self.dispatch("on_channel_update", None, channel)
await self.dispatch("on_channel_update", (None, channel))
return

View file

@ -8,8 +8,6 @@ from ..models.guild import Guild
async def guild_create_listener(self, gateway, payload: dict):
gateway.session_id = payload.get("session_id")
guild_was_cached_as_none = False
guild = Guild.from_dict(payload)

View file

@ -8,8 +8,6 @@ from ..models.guild import UnavailableGuild
async def guild_delete_listener(self, gateway, payload: dict):
gateway.session_id = payload.get("session_id")
guild = UnavailableGuild.from_dict(payload)
self.guilds.pop(guild.id, None)

View file

@ -8,14 +8,12 @@ from ..models.guild import Guild
async def guild_update_listener(self, gateway, payload: dict):
gateway.session_id = payload.get("session_id")
new_guild = Guild.from_dict(payload)
old_guild = self.guilds.get(new_guild.id)
self.guilds[new_guild.id] = new_guild
await self.dispatch("on_channel_create", old_guild, new_guild)
await self.dispatch("on_guild_update", (old_guild, new_guild))
return

View file

@ -0,0 +1,18 @@
# Copyright MelisaDev 2022 - Present
# Full MIT License can be found in `LICENSE.txt` at the project root.
from __future__ import annotations
from ..models.message.message import Message
from ..utils.types import Coro
async def message_create_listener(self, gateway, payload: dict):
message = Message.from_dict(payload)
await self.dispatch("on_message_create", message)
return
def export() -> Coro:
return message_create_listener

View file

@ -23,6 +23,8 @@ class Intents(IntFlag):
DIRECT_MESSAGES = 1 << 12
DIRECT_MESSAGE_REACTIONS = 1 << 13
DIRECT_MESSAGE_TYPING = 1 << 14
MESSAGE_CONTENT = 1 << 15
GUILD_SCHEDULED_EVENTS = 1 << 16
@classmethod
def all(cls) -> Intents:

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import List, TYPE_CHECKING, Optional
from typing import List, TYPE_CHECKING, Optional, Dict
from ...utils import Snowflake, Timestamp
from ...utils import APIModelBase
@ -173,8 +173,8 @@ class Message(APIModelBase):
id: APINullable[Snowflake] = None
channel_id: APINullable[Snowflake] = None
guild_id: APINullable[Snowflake] = None
author: APINullable[List] = None
member: APINullable[List] = None
author: APINullable[Dict] = None
member: APINullable[Dict] = None
content: APINullable[str] = None
timestamp: APINullable[Timestamp] = None
edited_timestamp: APINullable[Timestamp] = None
@ -190,12 +190,12 @@ class Message(APIModelBase):
pinned: APINullable[bool] = None
webhook_id: APINullable[Snowflake] = None
type: APINullable[int] = None
activity: APINullable[List] = None
application: APINullable[List] = None
activity: APINullable[Dict] = None
application: APINullable[Dict] = None
application_id: APINullable[Snowflake] = None
message_reference: APINullable[List] = None
message_reference: APINullable[Dict] = None
flags: APINullable[int] = None
interaction: APINullable[List] = None
interaction: APINullable[Dict] = None
thread: APINullable[Thread] = None
components: APINullable[List] = None
sticker_items: APINullable[List] = None

117
melisa/utils/waiters.py Normal file
View file

@ -0,0 +1,117 @@
# Copyright MelisaDev 2022 - Present
# Full MIT License can be found in `LICENSE.txt` at the project root.
from __future__ import annotations
from asyncio import AbstractEventLoop, Event, wait_for as async_wait, TimeoutError as AsyncTimeOut
from typing import List, Callable, Optional
from ..exceptions import MelisaTimeoutError
class _Waiter:
"""
Parameters
----------
event_name : str
The name of the event.
check : Optional[Callable[[Any], :class:`bool`]]
``can_be_set`` only returns true if this function returns true.
Will be ignored if set to None.
Attributes
----------
event: :class:`asyncio.Event`
Even that is used to wait until the next valid discord event.
return_value : Optional[str]
Used to store the arguments from ``can_be_set`` so they can be
returned later.
"""
def __init__(self, event_name: str, check: Optional[Callable] = None):
self.event_name = event_name
self.check = check
self.event = Event()
self.return_value = None
super().__init__()
async def wait(self):
"""Waits until ``self.event`` is set."""
await self.event.wait()
def process(self, event_name: str, event_value):
if self.event_name != event_name:
return False
if self.check:
if event_value is not None:
if not self.check(event_value):
return
else:
if not self.check():
return
self.return_value = event_value
self.event.set()
class WaiterMgr:
"""
Attributes
----------
waiter_list : List
The List of events that need to be processed.
"""
def __init__(self, loop: AbstractEventLoop):
self.waiter_list: List[_Waiter] = []
self.loop = loop
def process_events(self, event_name, event_value):
"""
Parameters
----------
event_name : str
The name of the event to be processed.
event_value : Any
The object returned from the middleware for this event.
"""
for waiter in self.waiter_list:
waiter.process(event_name, event_value)
async def wait_for(
self,
event_name: str,
check: Optional[Callable[..., bool]] = None,
timeout: Optional[float] = None
):
"""
Parameters
----------
event_name: :class:`str`
The type of event. It should start with `on_`.
check: Optional[Callable[[Any], :class:`bool`]]
This function only returns a value if this return true.
timeout: Optional[float]
Amount of seconds before timeout. Use None for no timeout.
Returns
------
Any
What the Discord API returns for this event.
"""
waiter = _Waiter(event_name, check)
self.waiter_list.append(waiter)
try:
await async_wait(waiter.wait(), timeout=timeout)
self.waiter_list.remove(waiter)
except AsyncTimeOut:
self.waiter_list.remove(waiter)
raise MelisaTimeoutError(
"wait_for() timed out while waiting for an event."
)
return waiter.return_value