mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-11-11 19:07:28 +03:00
wait_for method
This commit is contained in:
parent
8c7f4a8aec
commit
c9aead9ac2
12 changed files with 216 additions and 24 deletions
|
@ -7,7 +7,7 @@ import signal
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from typing import Dict, List, Union, Any, Iterable, Optional
|
from typing import Dict, List, Union, Any, Iterable, Optional, Callable
|
||||||
|
|
||||||
from .models import User, Guild, Activity
|
from .models import User, Guild, Activity
|
||||||
from .models.app import Shard
|
from .models.app import Shard
|
||||||
|
@ -18,6 +18,7 @@ from .core.gateway import GatewayBotInfo
|
||||||
from .models.guild.channel import Channel, ChannelType, channel_types_for_converting
|
from .models.guild.channel import Channel, ChannelType, channel_types_for_converting
|
||||||
from .utils.logging import init_logging
|
from .utils.logging import init_logging
|
||||||
from .models.app.intents import Intents
|
from .models.app.intents import Intents
|
||||||
|
from .utils.waiters import WaiterMgr
|
||||||
|
|
||||||
_logger = logging.getLogger("melisa")
|
_logger = logging.getLogger("melisa")
|
||||||
|
|
||||||
|
@ -71,16 +72,17 @@ class Client:
|
||||||
mobile: bool = False,
|
mobile: bool = False,
|
||||||
logs: Union[None, int, str, Dict[str, Any]] = "INFO",
|
logs: Union[None, int, str, Dict[str, Any]] = "INFO",
|
||||||
):
|
):
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
self.shards: Dict[int, Shard] = {}
|
self.shards: Dict[int, Shard] = {}
|
||||||
self.http: HTTPClient = HTTPClient(token)
|
self.http: HTTPClient = HTTPClient(token)
|
||||||
self._events: Dict[str, Coro] = {}
|
self._events: Dict[str, Coro] = {}
|
||||||
|
self._waiter_mgr = WaiterMgr(self._loop)
|
||||||
|
|
||||||
# ToDo: Transfer guilds in to the cache manager
|
# ToDo: Transfer guilds in to the cache manager
|
||||||
self.guilds = {}
|
self.guilds = {}
|
||||||
self.user = None
|
self.user = None
|
||||||
|
|
||||||
self._loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
self._gateway_info = self._loop.run_until_complete(self._get_gateway())
|
self._gateway_info = self._loop.run_until_complete(self._get_gateway())
|
||||||
|
|
||||||
if isinstance(intents, Iterable):
|
if isinstance(intents, Iterable):
|
||||||
|
@ -164,6 +166,8 @@ class Client:
|
||||||
print(f"Ignoring exception in {name}", file=sys.stderr)
|
print(f"Ignoring exception in {name}", file=sys.stderr)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
self._waiter_mgr.process_events(name, *args)
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""
|
"""
|
||||||
Run Bot without shards (only 0 shard)
|
Run Bot without shards (only 0 shard)
|
||||||
|
@ -278,5 +282,63 @@ class Client:
|
||||||
channel_cls = channel_types_for_converting.get(data["type"], Channel)
|
channel_cls = channel_types_for_converting.get(data["type"], Channel)
|
||||||
return channel_cls.from_dict(data)
|
return channel_cls.from_dict(data)
|
||||||
|
|
||||||
|
async def wait_for(
|
||||||
|
self,
|
||||||
|
event_name: str,
|
||||||
|
*,
|
||||||
|
check: Optional[Callable[..., bool]] = None,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
):
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
Waits for a WebSocket event to be dispatched.
|
||||||
|
|
||||||
|
This could be used to wait for a user to reply to a message,
|
||||||
|
or to react to a message.
|
||||||
|
|
||||||
|
The ``timeout`` parameter is passed onto :func:`asyncio.wait_for`. By default,
|
||||||
|
it does not timeout. Note that this does propagate the
|
||||||
|
:exc:`asyncio.TimeoutError` for you in case of timeout and is provided for
|
||||||
|
ease of use.
|
||||||
|
|
||||||
|
In case the event returns multiple arguments, a :class:`tuple` containing those
|
||||||
|
arguments is returned instead.
|
||||||
|
|
||||||
|
This function returns the **first event that meets the requirements**.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
Waiting for a user reply: ::
|
||||||
|
@client.listen
|
||||||
|
async def on_message_create(message):
|
||||||
|
if message.content.startswith('$greet'):
|
||||||
|
channel = await client.fetch_channel(message.channel_id)
|
||||||
|
await channel.send('Say hello!')
|
||||||
|
|
||||||
|
def check(m):
|
||||||
|
return m.content == "hello" and channel.id == message.channel_id
|
||||||
|
|
||||||
|
msg = await client.wait_for('on_message_create', check=check, timeout=10.0)
|
||||||
|
await channel.send(f'Hello man!')
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
event_name: :class:`str`
|
||||||
|
The type of event. It should starts with `on_`.
|
||||||
|
check: Optional[Callable[[Any], :class:`bool`]]
|
||||||
|
A predicate to check what to wait for. The arguments must meet the
|
||||||
|
parameters of the event being waited for.
|
||||||
|
timeout: Optional[float]
|
||||||
|
The number of seconds to wait before timing out and raising
|
||||||
|
:exc:`asyncio.TimeoutError`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
------
|
||||||
|
Any
|
||||||
|
Returns no arguments, a single argument, or a :class:`tuple` of multiple
|
||||||
|
arguments that mirrors the parameters passed in the event.
|
||||||
|
"""
|
||||||
|
return await self._waiter_mgr.wait_for(event_name, check, timeout)
|
||||||
|
|
||||||
|
|
||||||
Bot = Client
|
Bot = Client
|
||||||
|
|
|
@ -14,6 +14,11 @@ class ClientException(MelisaException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MelisaTimeoutError(MelisaException):
|
||||||
|
"""Exception raised when `wait_for` method timed out
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class LoginFailure(ClientException):
|
class LoginFailure(ClientException):
|
||||||
"""Fails to log you in from improper credentials or some other misc."""
|
"""Fails to log you in from improper credentials or some other misc."""
|
||||||
|
|
||||||
|
|
|
@ -8,8 +8,6 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
|
||||||
|
|
||||||
|
|
||||||
async def channel_create_listener(self, gateway, payload: dict):
|
async def channel_create_listener(self, gateway, payload: dict):
|
||||||
gateway.session_id = payload.get("session_id")
|
|
||||||
|
|
||||||
payload.update({"type": ChannelType(payload.pop("type"))})
|
payload.update({"type": ChannelType(payload.pop("type"))})
|
||||||
|
|
||||||
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
||||||
|
|
|
@ -8,8 +8,6 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
|
||||||
|
|
||||||
|
|
||||||
async def channel_delete_listener(self, gateway, payload: dict):
|
async def channel_delete_listener(self, gateway, payload: dict):
|
||||||
gateway.session_id = payload.get("session_id")
|
|
||||||
|
|
||||||
payload.update({"type": ChannelType(payload.pop("type"))})
|
payload.update({"type": ChannelType(payload.pop("type"))})
|
||||||
|
|
||||||
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
||||||
|
|
|
@ -9,15 +9,13 @@ from ..models.guild import Channel, ChannelType, channel_types_for_converting
|
||||||
|
|
||||||
async def channel_update_listener(self, gateway, payload: dict):
|
async def channel_update_listener(self, gateway, payload: dict):
|
||||||
# ToDo: Replace None to the old channel object (so it requires cache manager)
|
# ToDo: Replace None to the old channel object (so it requires cache manager)
|
||||||
gateway.session_id = payload.get("session_id")
|
|
||||||
|
|
||||||
payload.update({"type": ChannelType(payload.pop("type"))})
|
payload.update({"type": ChannelType(payload.pop("type"))})
|
||||||
|
|
||||||
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
channel_cls = channel_types_for_converting.get(payload["type"], Channel)
|
||||||
|
|
||||||
channel = channel_cls.from_dict(payload)
|
channel = channel_cls.from_dict(payload)
|
||||||
|
|
||||||
await self.dispatch("on_channel_update", None, channel)
|
await self.dispatch("on_channel_update", (None, channel))
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -8,8 +8,6 @@ from ..models.guild import Guild
|
||||||
|
|
||||||
|
|
||||||
async def guild_create_listener(self, gateway, payload: dict):
|
async def guild_create_listener(self, gateway, payload: dict):
|
||||||
gateway.session_id = payload.get("session_id")
|
|
||||||
|
|
||||||
guild_was_cached_as_none = False
|
guild_was_cached_as_none = False
|
||||||
|
|
||||||
guild = Guild.from_dict(payload)
|
guild = Guild.from_dict(payload)
|
||||||
|
|
|
@ -8,8 +8,6 @@ from ..models.guild import UnavailableGuild
|
||||||
|
|
||||||
|
|
||||||
async def guild_delete_listener(self, gateway, payload: dict):
|
async def guild_delete_listener(self, gateway, payload: dict):
|
||||||
gateway.session_id = payload.get("session_id")
|
|
||||||
|
|
||||||
guild = UnavailableGuild.from_dict(payload)
|
guild = UnavailableGuild.from_dict(payload)
|
||||||
|
|
||||||
self.guilds.pop(guild.id, None)
|
self.guilds.pop(guild.id, None)
|
||||||
|
|
|
@ -8,14 +8,12 @@ from ..models.guild import Guild
|
||||||
|
|
||||||
|
|
||||||
async def guild_update_listener(self, gateway, payload: dict):
|
async def guild_update_listener(self, gateway, payload: dict):
|
||||||
gateway.session_id = payload.get("session_id")
|
|
||||||
|
|
||||||
new_guild = Guild.from_dict(payload)
|
new_guild = Guild.from_dict(payload)
|
||||||
old_guild = self.guilds.get(new_guild.id)
|
old_guild = self.guilds.get(new_guild.id)
|
||||||
|
|
||||||
self.guilds[new_guild.id] = new_guild
|
self.guilds[new_guild.id] = new_guild
|
||||||
|
|
||||||
await self.dispatch("on_channel_create", old_guild, new_guild)
|
await self.dispatch("on_guild_update", (old_guild, new_guild))
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
18
melisa/listeners/message_create.py
Normal file
18
melisa/listeners/message_create.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
# Copyright MelisaDev 2022 - Present
|
||||||
|
# Full MIT License can be found in `LICENSE.txt` at the project root.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ..models.message.message import Message
|
||||||
|
from ..utils.types import Coro
|
||||||
|
|
||||||
|
|
||||||
|
async def message_create_listener(self, gateway, payload: dict):
|
||||||
|
message = Message.from_dict(payload)
|
||||||
|
await self.dispatch("on_message_create", message)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def export() -> Coro:
|
||||||
|
return message_create_listener
|
|
@ -23,6 +23,8 @@ class Intents(IntFlag):
|
||||||
DIRECT_MESSAGES = 1 << 12
|
DIRECT_MESSAGES = 1 << 12
|
||||||
DIRECT_MESSAGE_REACTIONS = 1 << 13
|
DIRECT_MESSAGE_REACTIONS = 1 << 13
|
||||||
DIRECT_MESSAGE_TYPING = 1 << 14
|
DIRECT_MESSAGE_TYPING = 1 << 14
|
||||||
|
MESSAGE_CONTENT = 1 << 15
|
||||||
|
GUILD_SCHEDULED_EVENTS = 1 << 16
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def all(cls) -> Intents:
|
def all(cls) -> Intents:
|
||||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import List, TYPE_CHECKING, Optional
|
from typing import List, TYPE_CHECKING, Optional, Dict
|
||||||
|
|
||||||
from ...utils import Snowflake, Timestamp
|
from ...utils import Snowflake, Timestamp
|
||||||
from ...utils import APIModelBase
|
from ...utils import APIModelBase
|
||||||
|
@ -173,8 +173,8 @@ class Message(APIModelBase):
|
||||||
id: APINullable[Snowflake] = None
|
id: APINullable[Snowflake] = None
|
||||||
channel_id: APINullable[Snowflake] = None
|
channel_id: APINullable[Snowflake] = None
|
||||||
guild_id: APINullable[Snowflake] = None
|
guild_id: APINullable[Snowflake] = None
|
||||||
author: APINullable[List] = None
|
author: APINullable[Dict] = None
|
||||||
member: APINullable[List] = None
|
member: APINullable[Dict] = None
|
||||||
content: APINullable[str] = None
|
content: APINullable[str] = None
|
||||||
timestamp: APINullable[Timestamp] = None
|
timestamp: APINullable[Timestamp] = None
|
||||||
edited_timestamp: APINullable[Timestamp] = None
|
edited_timestamp: APINullable[Timestamp] = None
|
||||||
|
@ -190,12 +190,12 @@ class Message(APIModelBase):
|
||||||
pinned: APINullable[bool] = None
|
pinned: APINullable[bool] = None
|
||||||
webhook_id: APINullable[Snowflake] = None
|
webhook_id: APINullable[Snowflake] = None
|
||||||
type: APINullable[int] = None
|
type: APINullable[int] = None
|
||||||
activity: APINullable[List] = None
|
activity: APINullable[Dict] = None
|
||||||
application: APINullable[List] = None
|
application: APINullable[Dict] = None
|
||||||
application_id: APINullable[Snowflake] = None
|
application_id: APINullable[Snowflake] = None
|
||||||
message_reference: APINullable[List] = None
|
message_reference: APINullable[Dict] = None
|
||||||
flags: APINullable[int] = None
|
flags: APINullable[int] = None
|
||||||
interaction: APINullable[List] = None
|
interaction: APINullable[Dict] = None
|
||||||
thread: APINullable[Thread] = None
|
thread: APINullable[Thread] = None
|
||||||
components: APINullable[List] = None
|
components: APINullable[List] = None
|
||||||
sticker_items: APINullable[List] = None
|
sticker_items: APINullable[List] = None
|
||||||
|
|
117
melisa/utils/waiters.py
Normal file
117
melisa/utils/waiters.py
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
# Copyright MelisaDev 2022 - Present
|
||||||
|
# Full MIT License can be found in `LICENSE.txt` at the project root.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from asyncio import AbstractEventLoop, Event, wait_for as async_wait, TimeoutError as AsyncTimeOut
|
||||||
|
from typing import List, Callable, Optional
|
||||||
|
|
||||||
|
from ..exceptions import MelisaTimeoutError
|
||||||
|
|
||||||
|
|
||||||
|
class _Waiter:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
event_name : str
|
||||||
|
The name of the event.
|
||||||
|
check : Optional[Callable[[Any], :class:`bool`]]
|
||||||
|
``can_be_set`` only returns true if this function returns true.
|
||||||
|
Will be ignored if set to None.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
event: :class:`asyncio.Event`
|
||||||
|
Even that is used to wait until the next valid discord event.
|
||||||
|
return_value : Optional[str]
|
||||||
|
Used to store the arguments from ``can_be_set`` so they can be
|
||||||
|
returned later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, event_name: str, check: Optional[Callable] = None):
|
||||||
|
self.event_name = event_name
|
||||||
|
self.check = check
|
||||||
|
self.event = Event()
|
||||||
|
self.return_value = None
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def wait(self):
|
||||||
|
"""Waits until ``self.event`` is set."""
|
||||||
|
await self.event.wait()
|
||||||
|
|
||||||
|
def process(self, event_name: str, event_value):
|
||||||
|
if self.event_name != event_name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.check:
|
||||||
|
if event_value is not None:
|
||||||
|
if not self.check(event_value):
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if not self.check():
|
||||||
|
return
|
||||||
|
|
||||||
|
self.return_value = event_value
|
||||||
|
self.event.set()
|
||||||
|
|
||||||
|
|
||||||
|
class WaiterMgr:
|
||||||
|
"""
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
waiter_list : List
|
||||||
|
The List of events that need to be processed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, loop: AbstractEventLoop):
|
||||||
|
self.waiter_list: List[_Waiter] = []
|
||||||
|
self.loop = loop
|
||||||
|
|
||||||
|
def process_events(self, event_name, event_value):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
event_name : str
|
||||||
|
The name of the event to be processed.
|
||||||
|
event_value : Any
|
||||||
|
The object returned from the middleware for this event.
|
||||||
|
"""
|
||||||
|
for waiter in self.waiter_list:
|
||||||
|
waiter.process(event_name, event_value)
|
||||||
|
|
||||||
|
async def wait_for(
|
||||||
|
self,
|
||||||
|
event_name: str,
|
||||||
|
check: Optional[Callable[..., bool]] = None,
|
||||||
|
timeout: Optional[float] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
event_name: :class:`str`
|
||||||
|
The type of event. It should start with `on_`.
|
||||||
|
check: Optional[Callable[[Any], :class:`bool`]]
|
||||||
|
This function only returns a value if this return true.
|
||||||
|
timeout: Optional[float]
|
||||||
|
Amount of seconds before timeout. Use None for no timeout.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
------
|
||||||
|
Any
|
||||||
|
What the Discord API returns for this event.
|
||||||
|
"""
|
||||||
|
|
||||||
|
waiter = _Waiter(event_name, check)
|
||||||
|
self.waiter_list.append(waiter)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await async_wait(waiter.wait(), timeout=timeout)
|
||||||
|
self.waiter_list.remove(waiter)
|
||||||
|
except AsyncTimeOut:
|
||||||
|
self.waiter_list.remove(waiter)
|
||||||
|
raise MelisaTimeoutError(
|
||||||
|
"wait_for() timed out while waiting for an event."
|
||||||
|
)
|
||||||
|
|
||||||
|
return waiter.return_value
|
||||||
|
|
Loading…
Reference in a new issue