diff --git a/melisa/client.py b/melisa/client.py index 7f476fc..f657c9e 100644 --- a/melisa/client.py +++ b/melisa/client.py @@ -83,24 +83,26 @@ class Client: def run(self) -> None: """ - Run Bot without shards (only 0 shard) + Run Bot without shards (only 0 shard) """ inited_shard = Shard(self, 0, 1) - asyncio.ensure_future(inited_shard.launch(activity=self._activity, - status=self._status), loop=self._loop) + asyncio.ensure_future( + inited_shard.launch(activity=self._activity, status=self._status), + loop=self._loop, + ) self._loop.run_forever() def run_shards(self, num_shards: int, *, shard_ids: List[int] = None): """ - Run Bot with shards specified by the user. + Run Bot with shards specified by the user. - Parameters - ---------- - num_shards : :class:`int` - The endpoint to send the request to. - shard_ids: Optional[:class:`List[int]`] - List of Ids of shards to start. + Parameters + ---------- + num_shards : :class:`int` + The endpoint to send the request to. + shard_ids: Optional[:class:`List[int]`] + List of Ids of shards to start. """ if not shard_ids: shard_ids = range(num_shards) @@ -108,13 +110,15 @@ class Client: for shard_id in shard_ids: inited_shard = Shard(self, shard_id, num_shards) - asyncio.ensure_future(inited_shard.launch(activity=self._activity, - status=self._status), loop=self._loop) + asyncio.ensure_future( + inited_shard.launch(activity=self._activity, status=self._status), + loop=self._loop, + ) self._loop.run_forever() def run_autosharded(self): """ - Runs the bot with the amount of shards specified by the Discord gateway. + Runs the bot with the amount of shards specified by the Discord gateway. """ num_shards = self._gateway_info.shards shard_ids = range(num_shards) @@ -122,18 +126,20 @@ class Client: for shard_id in shard_ids: inited_shard = Shard(self, shard_id, num_shards) - asyncio.ensure_future(inited_shard.launch(activity=self._activity, - status=self._status), loop=self._loop) + asyncio.ensure_future( + inited_shard.launch(activity=self._activity, status=self._status), + loop=self._loop, + ) self._loop.run_forever() async def fetch_user(self, user_id: Union[Snowflake, str, int]): """ - Fetch User from the Discord API (by id). + Fetch User from the Discord API (by id). - Parameters - ---------- - user_id : :class:`Union[Snowflake, str, int]` - Id of user to fetch + Parameters + ---------- + user_id : :class:`Union[Snowflake, str, int]` + Id of user to fetch """ # ToDo: Update cache if USER_CACHING enabled. @@ -144,12 +150,12 @@ class Client: async def fetch_guild(self, guild_id: Union[Snowflake, str, int]): """ - Fetch Guild from the Discord API (by id). + Fetch Guild from the Discord API (by id). - Parameters - ---------- - guild_id : :class:`Union[Snowflake, str, int]` - Id of guild to fetch + Parameters + ---------- + guild_id : :class:`Union[Snowflake, str, int]` + Id of guild to fetch """ # ToDo: Update cache if GUILD_CACHE enabled. diff --git a/melisa/core/gateway.py b/melisa/core/gateway.py index 664781d..d736052 100644 --- a/melisa/core/gateway.py +++ b/melisa/core/gateway.py @@ -20,6 +20,7 @@ from ..utils import APIModelBase, json @dataclass class GatewayBotInfo(APIModelBase): """Gateway info from the `gateway/bot` endpoint""" + url: str shards: int session_start_limit: dict @@ -39,11 +40,7 @@ class Gateway: HELLO = 10 HEARTBEAT_ACK = 11 - def __init__(self, - client, - shard_id: int = 0, - num_shards: int = 1, - **kwargs): + def __init__(self, client, shard_id: int = 0, num_shards: int = 1, **kwargs): self.GATEWAY_VERSION = "9" self.interval = None @@ -52,7 +49,7 @@ class Gateway: self.__session = aiohttp.ClientSession() self.session_id = None self.client = client - self.latency = float('inf') + self.latency = float("inf") self.ws = None self.loop = asyncio.get_event_loop() self.shard_id = shard_id @@ -63,7 +60,7 @@ class Gateway: 4011: GatewayError("Sharding required"), 4012: GatewayError("Invalid API version"), 4013: GatewayError("Invalid intents"), - 4014: PrivilegedIntentsRequired("Disallowed intents") + 4014: PrivilegedIntentsRequired("Disallowed intents"), } self.listeners = listeners @@ -76,19 +73,21 @@ class Gateway: "properties": { "$os": sys.platform, "$browser": "Melisa Python Library", - "$device": "Melisa Python Library" + "$device": "Melisa Python Library", }, "compress": True, "shard": [shard_id, num_shards], - "presence": self.generate_presence(kwargs.get("start_activity"), - kwargs.get("start_status"))} + "presence": self.generate_presence( + kwargs.get("start_activity"), kwargs.get("start_status") + ), + } self._zlib: zlib._Decompress = zlib.decompressobj() self._buffer: bytearray = bytearray() async def connect(self) -> None: self.ws = await self.__session.ws_connect( - f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream' + f"wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream" ) if self.session_id is None: @@ -114,10 +113,10 @@ class Gateway: if type(msg) is bytes: self._buffer.extend(msg) - if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': + if len(msg) < 4 or msg[-4:] != b"\x00\x00\xff\xff": return None msg = self._zlib.decompress(self._buffer) - msg = msg.decode('utf-8') + msg = msg.decode("utf-8") self._buffer = bytearray() return json.loads(msg) @@ -125,7 +124,7 @@ class Gateway: return None async def handle_data(self, data): - if data['op'] == self.DISPATCH: + if data["op"] == self.DISPATCH: self.sequence = int(data["s"]) event_type = data["t"].lower() @@ -134,12 +133,12 @@ class Gateway: if event_to_call is not None: ensure_future(event_to_call(self.client, self, data["d"])) - elif data['op'] == self.INVALID_SESSION: + elif data["op"] == self.INVALID_SESSION: await self.ws.close(code=4000) await self.handle_close(4000) - elif data['op'] == self.HELLO: + elif data["op"] == self.HELLO: await self.send_hello(data) - elif data['op'] == self.HEARTBEAT_ACK: + elif data["op"] == self.HEARTBEAT_ACK: self.latency = time.perf_counter() - self._last_send async def receive(self) -> None: @@ -184,34 +183,28 @@ class Gateway: self._buffer.clear() async def send_hello(self, data: Dict) -> None: - interval = data['d']['heartbeat_interval'] / 1000 + interval = data["d"]["heartbeat_interval"] / 1000 await asyncio.sleep((interval - 2000) / 1000) self.loop.create_task(self.send_heartbeat(interval)) async def send_identify(self) -> None: - await self.send(self.opcode( - self.IDENTIFY, - self.auth - )) + await self.send(self.opcode(self.IDENTIFY, self.auth)) async def resume(self) -> None: await self.send( self.opcode( self.RESUME, { - 'token': self.client._token, - 'session_id': self.session_id, - 'seq': self.sequence, - } + "token": self.client._token, + "session_id": self.session_id, + "seq": self.sequence, + }, ) ) @staticmethod def generate_presence(activity: BotActivity = None, status: str = None): - data = { - "since": time.time() * 1000, - "afk": False - } + data = {"since": time.time() * 1000, "afk": False} if activity is not None: data["activities"] = activity.to_dict() @@ -226,8 +219,5 @@ class Gateway: @staticmethod def opcode(opcode: int, payload) -> str: - data = { - "op": opcode, - "d": payload - } + data = {"op": opcode, "d": payload} return json.dumps(data) diff --git a/melisa/core/http.py b/melisa/core/http.py index bd985dc..9b702fe 100644 --- a/melisa/core/http.py +++ b/melisa/core/http.py @@ -8,15 +8,17 @@ from typing import Dict, Optional from aiohttp import ClientSession, ClientResponse -from melisa.exceptions import (NotModifiedError, - BadRequestError, - ForbiddenError, - UnauthorizedError, - HTTPException, - NotFoundError, - MethodNotAllowedError, - ServerError, - RateLimitError) +from melisa.exceptions import ( + NotModifiedError, + BadRequestError, + ForbiddenError, + UnauthorizedError, + HTTPException, + NotFoundError, + MethodNotAllowedError, + ServerError, + RateLimitError, +) from .ratelimiter import RateLimiter @@ -31,7 +33,7 @@ class HTTPClient: headers: Dict[str, str] = { "Content-Type": "application/json", "Authorization": f"Bot {token}", - "User-Agent": "Melisa Python Library" + "User-Agent": "Melisa Python Library", } self.__http_exceptions: Dict[int, HTTPException] = { @@ -41,7 +43,7 @@ class HTTPClient: 403: ForbiddenError(), 404: NotFoundError(), 405: MethodNotAllowedError(), - 429: RateLimitError() + 429: RateLimitError(), } self.__aiohttp_session: ClientSession = ClientSession(headers=headers) @@ -57,12 +59,7 @@ class HTTPClient: await self.__aiohttp_session.close() async def __send( - self, - method: str, - endpoint: str, - *, - _ttl: int = None, - **kwargs + self, method: str, endpoint: str, *, _ttl: int = None, **kwargs ) -> Optional[Dict]: """Send an API request to the Discord API.""" @@ -71,30 +68,27 @@ class HTTPClient: if ttl == 0: raise ServerError(f"Maximum amount of retries for `{endpoint}`.") - await self.__rate_limiter.wait_until_not_ratelimited( - endpoint, - method - ) + await self.__rate_limiter.wait_until_not_ratelimited(endpoint, method) url = f"{self.url}/{endpoint}" async with self.__aiohttp_session.request(method, url, **kwargs) as response: - return await self.__handle_response(response, method, endpoint, _ttl=ttl, **kwargs) + return await self.__handle_response( + response, method, endpoint, _ttl=ttl, **kwargs + ) async def __handle_response( - self, - res: ClientResponse, - method: str, - endpoint: str, - *, - _ttl: int = None, - **kwargs + self, + res: ClientResponse, + method: str, + endpoint: str, + *, + _ttl: int = None, + **kwargs, ) -> Optional[Dict]: """Handle responses from the Discord API.""" - self.__rate_limiter.save_response_bucket( - endpoint, method, res.headers - ) + self.__rate_limiter.save_response_bucket(endpoint, method, res.headers) if res.ok: return await res.json() @@ -106,11 +100,7 @@ class HTTPClient: timeout = (await res.json()).get("retry_after", 40) await asyncio.sleep(timeout) - return await self.__send( - method, - endpoint, - **kwargs - ) + return await self.__send(method, endpoint, **kwargs) exception.__init__(res.reason) raise exception @@ -119,18 +109,9 @@ class HTTPClient: await asyncio.sleep(retry_in) - return await self.__send( - method, - endpoint, - _ttl=_ttl - 1, - **kwargs - ) + return await self.__send(method, endpoint, _ttl=_ttl - 1, **kwargs) - async def get( - self, - route: str, - params: Optional[Dict] = None - ) -> Optional[Dict]: + async def get(self, route: str, params: Optional[Dict] = None) -> Optional[Dict]: """|coro| Sends a GET request to a Discord REST API endpoint. @@ -146,17 +127,9 @@ class HTTPClient: Optional[:class:`Dict`] The response from Discord. """ - return await self.__send( - "GET", - route, - params=params - ) + return await self.__send("GET", route, params=params) - async def post( - self, - route: str, - data: Optional[Dict] = None - ) -> Optional[Dict]: + async def post(self, route: str, data: Optional[Dict] = None) -> Optional[Dict]: """|coro| Sends a POST request to a Discord REST API endpoint. @@ -178,11 +151,7 @@ class HTTPClient: json=data, ) - async def delete( - self, - route: str, - headers: dict = None - ) -> Optional[Dict]: + async def delete(self, route: str, headers: dict = None) -> Optional[Dict]: """|coro| Sends a DELETE request to a Discord REST API endpoint. @@ -198,8 +167,4 @@ class HTTPClient: Optional[:class:`Dict`] JSON response from the Discord API. """ - return await self.__send( - "DELETE", - route, - headers=headers - ) + return await self.__send("DELETE", route, headers=headers) diff --git a/melisa/core/ratelimiter.py b/melisa/core/ratelimiter.py index 2a35492..2099e56 100644 --- a/melisa/core/ratelimiter.py +++ b/melisa/core/ratelimiter.py @@ -12,6 +12,7 @@ from typing import Dict, Tuple, Any @dataclass class RateLimitBucket: """Represents a rate limit bucket""" + limit: int remaining: int reset: float @@ -23,15 +24,12 @@ class RateLimiter: """Prevents ``user`` rate limits""" def __init__(self) -> None: - self.bucket_map: Dict[Tuple[str, str], str] = {} # Dict[Tuple[endpoint, method], bucket_id] + self.bucket_map: Dict[ + Tuple[str, str], str + ] = {} # Dict[Tuple[endpoint, method], bucket_id] self.buckets: Dict[str, RateLimitBucket] = {} - def save_response_bucket( - self, - endpoint: str, - method: str, - header: Any - ): + def save_response_bucket(self, endpoint: str, method: str, header: Any): ratelimit_bucket_id = header.get("X-RateLimit-Bucket") if not ratelimit_bucket_id: @@ -44,14 +42,10 @@ class RateLimiter: remaining=int(header["X-RateLimit-Remaining"]), reset=float(header["X-RateLimit-Reset"]), reset_after_timestamp=float(header["X-RateLimit-Reset-After"]), - since_timestamp=time() + since_timestamp=time(), ) - async def wait_until_not_ratelimited( - self, - endpoint: str, - method: str - ): + async def wait_until_not_ratelimited(self, endpoint: str, method: str): bucket_id = self.bucket_map.get((endpoint, method)) if not bucket_id: diff --git a/melisa/exceptions.py b/melisa/exceptions.py index bd7b134..885e1cf 100644 --- a/melisa/exceptions.py +++ b/melisa/exceptions.py @@ -1,25 +1,29 @@ # Copyright MelisaDev 2022 - Present # Full MIT License can be found in `LICENSE.txt` at the project root. + class MelisaException(Exception): """Base exception""" + pass class ClientException(MelisaException): """Handling user errors""" + pass class LoginFailure(ClientException): """Fails to log you in from improper credentials or some other misc.""" + pass class ConnectionClosed(ClientException): """Exception that's thrown when the gateway connection is closed for reasons that could not be handled - internally. """ + internally.""" def __init__(self, socket, *, shard_id, code=None): message = "Websocket with shard ID {} closed with code {}" @@ -38,9 +42,11 @@ class PrivilegedIntentsRequired(ClientException): def __init__(self, shard_id): self.shard_id = shard_id - message = "Shard ID {} is requesting privileged intents " \ - "that have not been explicitly enabled in the " \ - "developer portal. Please visit to https://discord.com/developers/applications/ " + message = ( + "Shard ID {} is requesting privileged intents " + "that have not been explicitly enabled in the " + "developer portal. Please visit to https://discord.com/developers/applications/ " + ) super().__init__(message.format(self.shard_id)) diff --git a/melisa/models/app/shard.py b/melisa/models/app/shard.py index be07ed5..f6db724 100644 --- a/melisa/models/app/shard.py +++ b/melisa/models/app/shard.py @@ -10,10 +10,7 @@ from ..user import BotActivity class Shard: - def __init__(self, - client, - shard_id: int, - num_shards: int): + def __init__(self, client, shard_id: int, num_shards: int): self._client = client self._shard_id: int = shard_id @@ -38,11 +35,13 @@ class Shard: """|coro| Launches new shard""" - self._gateway = Gateway(self._client, - self._shard_id, - self._num_shards, - start_activity=kwargs.get("activity"), - start_status=kwargs.get("status")) + self._gateway = Gateway( + self._client, + self._shard_id, + self._num_shards, + start_activity=kwargs.get("activity"), + start_status=kwargs.get("status"), + ) self._client.shards[self._shard_id] = self @@ -54,11 +53,13 @@ class Shard: async def close(self): """|coro| - Disconnect shard + Disconnect shard """ create_task(self._gateway.close()) - async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard: + async def update_presence( + self, activity: BotActivity = None, status: str = None + ) -> Shard: """ |coro| diff --git a/melisa/models/user/presence.py b/melisa/models/user/presence.py index 09f7402..317663b 100644 --- a/melisa/models/user/presence.py +++ b/melisa/models/user/presence.py @@ -12,10 +12,10 @@ from ...utils.types import APINullable class BasePresence: """ - All the information about activities here is from the Discord API docs. - Read more here: https://discord.com/developers/docs/topics/gateway#activity-object + All the information about activities here is from the Discord API docs. + Read more here: https://discord.com/developers/docs/topics/gateway#activity-object - Unknown data will be returned as None. + Unknown data will be returned as None. """ @@ -39,6 +39,7 @@ class ActivityType(IntEnum): COMPETING: Competing in {name} (Competing in Arena World Champions) """ + GAME = 0 STREAMING = 1 LISTENING = 2 @@ -61,6 +62,7 @@ class ActivityTimestamp(BasePresence, APIModelBase): end: Optional[:class:`int`] Unix time (in milliseconds) of when the activity ends """ + start: APINullable[int] = None end: APINullable[int] = None @@ -78,6 +80,7 @@ class ActivityEmoji(BasePresence, APIModelBase): animated: Optional[:class:`bool`] Whether this emoji is animated """ + name: str id: APINullable[Snowflake] = None animated: APINullable[bool] = None @@ -94,6 +97,7 @@ class ActivityParty(BasePresence, APIModelBase): size: Optional[Tuple[:class:`int`, :class:`int`]] Array of two integers (current_size, max_size) """ + id: APINullable[str] = None size: APINullable[Tuple[int, int]] = None @@ -115,6 +119,7 @@ class ActivityAssets(BasePresence, APIModelBase): small_text: Optional[:class:`str`] text displayed when hovering over the small image of the activity """ + large_image: APINullable[str] = None large_text: APINullable[str] = None small_image: APINullable[str] = None @@ -134,6 +139,7 @@ class ActivitySecrets(BasePresence, APIModelBase): match: Optional[:class:`str`] The secret for a specific instanced match """ + join: APINullable[str] = None spectate: APINullable[str] = None match_: APINullable[str] = None @@ -141,9 +147,9 @@ class ActivitySecrets(BasePresence, APIModelBase): class ActivityFlags(BasePresence, APIModelBase): """ - Just Activity Flags (From Discord API). + Just Activity Flags (From Discord API). - Everything returns :class:`bool` value. + Everything returns :class:`bool` value. """ def __init__(self, flags) -> None: @@ -172,6 +178,7 @@ class ActivityButton(BasePresence, APIModelBase): url: :class:`str` The url opened when clicking the button (1-512 characters) """ + label: str url: str @@ -251,11 +258,11 @@ class BotActivity(BasePresence, APIModelBase): class StatusType(Enum): - ONLINE = 'online' - OFFLINE = 'offline' - IDLE = 'idle' - DND = 'dnd' - INVISIBLE = 'invisible' + ONLINE = "online" + OFFLINE = "offline" + IDLE = "idle" + DND = "dnd" + INVISIBLE = "invisible" def __str__(self): return self.value diff --git a/melisa/models/user/user.py b/melisa/models/user/user.py index e11abd1..2f91e7e 100644 --- a/melisa/models/user/user.py +++ b/melisa/models/user/user.py @@ -169,20 +169,12 @@ class User(APIModelBase): """APINullable[:class:`~melisa.models.user.user.PremiumTypes`]: The user their premium type in a usable enum. """ - return ( - None - if self.premium_type is None - else PremiumTypes(self.premium_type) - ) + return None if self.premium_type is None else PremiumTypes(self.premium_type) @property def flags(self) -> Optional[UserFlags]: """Flags of user""" - return( - None - if self.flags is None - else UserFlags(self.flags) - ) + return None if self.flags is None else UserFlags(self.flags) def __str__(self): """String representation of the User object""" @@ -195,10 +187,13 @@ class User(APIModelBase): def avatar_url(self) -> str: """Avatar url (from the Discord CDN server)""" - return "https://cdn.discordapp.com/avatars/{}/{}.png?size=1024".format(self.id, self.avatar) + return "https://cdn.discordapp.com/avatars/{}/{}.png?size=1024".format( + self.id, self.avatar + ) async def create_dm_channel(self): # ToDo: Add docstrings # ToDo: Add checking this channel in cache return await self._http.post( - "/users/@me/channels", data={"recipient_id": self.id}) + "/users/@me/channels", data={"recipient_id": self.id} + ) diff --git a/melisa/utils/__init__.py b/melisa/utils/__init__.py index cc9c10f..6f2a81a 100644 --- a/melisa/utils/__init__.py +++ b/melisa/utils/__init__.py @@ -1,17 +1,11 @@ # Copyright MelisaDev 2022 - Present # Full MIT License can be found in `LICENSE.txt` at the project root. -from .types import ( - Coro -) +from .types import Coro from .snowflake import Snowflake from .api_model import APIModelBase -__all__ = ( - "Coro", - "Snowflake", - "APIModelBase" -) +__all__ = ("Coro", "Snowflake", "APIModelBase") diff --git a/melisa/utils/api_model.py b/melisa/utils/api_model.py index 005e4c1..94d3980 100644 --- a/melisa/utils/api_model.py +++ b/melisa/utils/api_model.py @@ -66,9 +66,7 @@ class APIModelBase: cls._client = client @classmethod - def from_dict( - cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]] - ) -> T: + def from_dict(cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]]) -> T: """ Parse an API object from a dictionary. """ @@ -81,13 +79,10 @@ class APIModelBase: map( lambda key: ( key, - data[key].value - if isinstance(data[key], Enum) - else data[key], + data[key].value if isinstance(data[key], Enum) else data[key], ), filter( - lambda object_argument: data.get(object_argument) - is not None, + lambda object_argument: data.get(object_argument) is not None, getfullargspec(cls.__init__).args, ), ) diff --git a/melisa/utils/json.py b/melisa/utils/json.py index aa0e8d6..92cf1ba 100644 --- a/melisa/utils/json.py +++ b/melisa/utils/json.py @@ -15,12 +15,14 @@ else: HAS_ORJSON = True if HAS_ORJSON: + def dumps(obj: Any) -> str: - return orjson.dumps(obj).decode('utf-8') + return orjson.dumps(obj).decode("utf-8") loads = orjson.loads else: + def dumps(obj: Any) -> str: - return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) + return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) loads = json.loads diff --git a/melisa/utils/snowflake.py b/melisa/utils/snowflake.py index 7265371..6a83639 100644 --- a/melisa/utils/snowflake.py +++ b/melisa/utils/snowflake.py @@ -24,9 +24,7 @@ class Snowflake(int): super().__init__() if self < self._MIN_VALUE: - raise ValueError( - "snowflake value should be greater than or equal to 0." - ) + raise ValueError("snowflake value should be greater than or equal to 0.") if self > self._MAX_VALUE: raise ValueError( @@ -66,7 +64,7 @@ class Snowflake(int): @property def increment(self) -> int: - """ For every ID that is generated on that process, this number is incremented""" + """For every ID that is generated on that process, this number is incremented""" return self % 2048 @property