diff --git a/melisa/__init__.py b/melisa/__init__.py index e69de29..3ff722b 100644 --- a/melisa/__init__.py +++ b/melisa/__init__.py @@ -0,0 +1 @@ +from .client import Client diff --git a/melisa/client.py b/melisa/client.py index e8a3a3f..a324950 100644 --- a/melisa/client.py +++ b/melisa/client.py @@ -2,9 +2,10 @@ from .models.app import Shard from .utils.types import Coro from .core.http import HTTPClient +from .core.gateway import GatewayBotInfo import asyncio -from typing import Dict +from typing import Dict, List class Client: @@ -17,17 +18,25 @@ class Client: self.loop = asyncio.get_event_loop() + self._gateway_info = self.loop.run_until_complete(self._get_gateway()) + self.intents = intents self._token = token self._activity = kwargs.get("activity") self._status = kwargs.get("status") + async def _get_gateway(self): + """Get Gateway information""" + return GatewayBotInfo.from_dict(await self.http.get("gateway/bot")) + def listen(self, callback: Coro): """Method to set the listener. - Args: - callback (:obj:`function`) - Coroutine Callback Function + + Parameters + ---------- + callback (:obj:`function`) + Coroutine Callback Function """ if not asyncio.iscoroutinefunction(callback): raise TypeError(f"<{callback.__qualname__}> must be a coroutine function") @@ -43,3 +52,41 @@ class Client: asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop) self.loop.run_forever() + + def run_shards(self, num_shards: int, *, shard_ids: List[int] = None): + """ + Run Bot with shards specified by the user. + + Parameters + ---------- + num_shards : :class:`int` + The endpoint to send the request to. + + Keyword Arguments: + + shard_ids: Optional[:class:`List[int]`] + List of Ids of shards to start. + """ + if not shard_ids: + shard_ids = range(num_shards) + + for shard_id in shard_ids: + inited_shard = Shard(self, shard_id, num_shards) + + asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop) + self.loop.run_forever() + + def run_autosharded(self): + """ + Runs the bot with the amount of shards specified by the Discord gateway. + """ + num_shards = self._gateway_info.shards + shard_ids = range(num_shards) + + for shard_id in shard_ids: + inited_shard = Shard(self, shard_id, num_shards) + + asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop) + self.loop.run_forever() + + diff --git a/melisa/core/gateway.py b/melisa/core/gateway.py index 4972869..3f303ca 100644 --- a/melisa/core/gateway.py +++ b/melisa/core/gateway.py @@ -2,11 +2,21 @@ import json import asyncio import zlib import time +from dataclasses import dataclass import websockets from ..listeners import listeners from ..models.user import BotActivity +from ..utils import APIObjectBase + + +@dataclass +class GatewayBotInfo(APIObjectBase): + """Gateway info from the `gateway/bot` endpoint""" + url: str + shards: int + session_start_limit: dict class Gateway: @@ -79,7 +89,7 @@ class Gateway: return await asyncio.gather(self.heartbeat(), self.receive()) - async def close(self, code: int = 4000): + async def close(self, code: int = 1000): await self.websocket.close(code=code) async def resume(self): diff --git a/melisa/utils/__init__.py b/melisa/utils/__init__.py index 5f23682..afc82c3 100644 --- a/melisa/utils/__init__.py +++ b/melisa/utils/__init__.py @@ -4,7 +4,11 @@ from .types import ( from .snowflake import Snowflake + +from .api_object import APIObjectBase + __all__ = ( "Coro", - "Snowflake" + "Snowflake", + "APIObjectBase" ) diff --git a/melisa/utils/api_object.py b/melisa/utils/api_object.py new file mode 100644 index 0000000..85de592 --- /dev/null +++ b/melisa/utils/api_object.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from enum import Enum +from inspect import getfullargspec +from typing import ( + Dict, + Union, + Generic, + TypeVar, + Any, +) + +T = TypeVar("T") + + +class APIObjectBase: + """ + Represents an object which has been fetched from the Discord API. + """ + + _client = None + + @property + def _http(self): + if not self._client: + raise AttributeError("Object is not yet linked to a main client") + + return self._client.http + + @classmethod + def from_dict( + cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]] + ) -> T: + """ + Parse an API object from a dictionary. + """ + if isinstance(data, cls): + return data + + # noinspection PyArgumentList + return cls( + **dict( + map( + lambda key: ( + key, + data[key].value + if isinstance(data[key], Enum) + else data[key], + ), + filter( + lambda object_argument: data.get(object_argument) + is not None, + getfullargspec(cls.__init__).args, + ), + ) + ) + )