From 359ea514c0e4df2f0ecdd36773893e98b00e736c Mon Sep 17 00:00:00 2001 From: TheMisterSenpai Date: Mon, 18 Apr 2022 19:34:33 +0300 Subject: [PATCH] BREAKING CHANGE: fix my stupidity --- melisa/client.py | 25 +---- melisa/core/http.py | 2 +- melisa/core/ratelimiter.py | 4 +- melisa/listeners/channel_create.py | 8 +- melisa/models/__init__.py | 2 +- melisa/models/app/__init__.py | 4 +- melisa/models/message/embed.py | 145 ++++++++++++++++++++++------ melisa/models/message/message.py | 146 ++++++++++++++++++++++------- melisa/models/user/user.py | 21 +---- melisa/rest.py | 70 ++++++++++++++ melisa/utils/__init__.py | 12 ++- packages/speed.txt | 2 +- tests/test_embeds.py | 6 -- 13 files changed, 318 insertions(+), 129 deletions(-) create mode 100644 melisa/rest.py diff --git a/melisa/client.py b/melisa/client.py index 4a9c941..836b9d0 100644 --- a/melisa/client.py +++ b/melisa/client.py @@ -1,6 +1,3 @@ -# Copyright MelisaDev 2022 - Present -# Full MIT License can be found in `LICENSE.txt` at the project root. - import logging import asyncio import signal @@ -26,9 +23,7 @@ _logger = logging.getLogger("melisa") class Client: """ This is the main instance which is between the programmer and the Discord API. - This Client represents your bot. - Parameters ---------- token: :class:`str` @@ -50,7 +45,6 @@ class Client: If you pass a :class:`str` or a :class:`int`, it is interpreted as the global logging level to use, and should match one of **DEBUG**, **INFO**, **WARNING**, **ERROR** or **CRITICAL**, if :class:`str`. - Attributes ---------- user: :class:`~models.user.user.User` @@ -126,7 +120,6 @@ class Client: def listen(self, callback: Coro): """Method or Decorator to set the listener. - Parameters ---------- callback : :class:`melisa.utils.types.Coro` @@ -142,7 +135,6 @@ class Client: async def dispatch(self, name: str, *args): """ Dispatches an event - Parameters ---------- name: :class:`str` @@ -183,7 +175,6 @@ class Client: def run_shards(self, num_shards: int, *, shard_ids: List[int] = None): """ Run Bot with shards specified by the user. - Parameters ---------- num_shards : :class:`int` @@ -226,7 +217,6 @@ class Client: async def fetch_user(self, user_id: Union[Snowflake, str, int]): """ Fetch User from the Discord API (by id). - Parameters ---------- user_id : :class:`Union[Snowflake, str, int]` @@ -242,7 +232,6 @@ class Client: async def fetch_guild(self, guild_id: Union[Snowflake, str, int]): """ Fetch Guild from the Discord API (by id). - Parameters ---------- guild_id : :class:`Union[Snowflake, str, int]` @@ -262,7 +251,6 @@ class Client: Fetch Channel from the Discord API (by id). If type of channel is unknown: it will return just :class:`melisa.models.guild.channel.Channel` object. - Parameters ---------- channel_id : :class:`Union[Snowflake, str, int]` @@ -286,38 +274,28 @@ class Client: timeout: Optional[float] = None, ): """|coro| - Waits for a WebSocket event to be dispatched. - This could be used to wait for a user to reply to a message, or to react to a message. - The ``timeout`` parameter is passed onto :func:`asyncio.wait_for`. By default, it does not timeout. Note that this does propagate the :exc:`asyncio.TimeoutError` for you in case of timeout and is provided for ease of use. - In case the event returns multiple arguments, a :class:`tuple` containing those arguments is returned instead. - This function returns the **first event that meets the requirements**. - Examples -------- Waiting for a user reply: :: - @client.listen async def on_message_create(message): if message.content.startswith('$greet'): channel = await client.fetch_channel(message.channel_id) await channel.send('Say hello!') - def check(m): return m.content == "hello" and channel.id == message.channel_id - msg = await client.wait_for('on_message_create', check=check, timeout=10.0) await channel.send(f'Hello man!') - Parameters ---------- event_name: :class:`str` @@ -328,7 +306,6 @@ class Client: timeout: Optional[:class:`float`] The number of seconds to wait before timing out and raising :exc:`asyncio.TimeoutError`. - Returns ------ Any @@ -338,4 +315,4 @@ class Client: return await self._waiter_mgr.wait_for(event_name, check, timeout) -Bot = Client +Bot = Client \ No newline at end of file diff --git a/melisa/core/http.py b/melisa/core/http.py index d189e96..365aa08 100644 --- a/melisa/core/http.py +++ b/melisa/core/http.py @@ -9,7 +9,7 @@ from typing import Dict, Optional, Any from aiohttp import ClientSession, ClientResponse -from melisa.exceptions import ( +from ..exceptions import ( NotModifiedError, BadRequestError, ForbiddenError, diff --git a/melisa/core/ratelimiter.py b/melisa/core/ratelimiter.py index 456b717..6dc971e 100644 --- a/melisa/core/ratelimiter.py +++ b/melisa/core/ratelimiter.py @@ -63,10 +63,10 @@ class RateLimiter: bucket = self.buckets[bucket_id] if bucket.remaining == 0: - sleep_time = time() - bucket.since_timestamp + bucket.reset_after_timestamp + sleep_time = time() - bucket.since_timestamp + bucket.reset_after _logger.info( - "Waiting until rate limit for bucket %s is over.", sleep_time, bucket_id + "Waiting until rate limit for bucket %s is over.", bucket_id ) await sleep(sleep_time) diff --git a/melisa/listeners/channel_create.py b/melisa/listeners/channel_create.py index fc3190d..81fe5f2 100644 --- a/melisa/listeners/channel_create.py +++ b/melisa/listeners/channel_create.py @@ -3,16 +3,12 @@ from __future__ import annotations +from ..models.guild.channel import _choose_channel_type from ..utils.types import Coro -from ..models.guild import Channel, ChannelType, channel_types_for_converting async def channel_create_listener(self, gateway, payload: dict): - payload.update({"type": ChannelType(payload.pop("type"))}) - - channel_cls = channel_types_for_converting.get(payload["type"], Channel) - - channel = channel_cls.from_dict(payload) + channel = _choose_channel_type(payload) await self.dispatch("on_channel_create", channel) diff --git a/melisa/models/__init__.py b/melisa/models/__init__.py index 42e98ae..0e58bb0 100644 --- a/melisa/models/__init__.py +++ b/melisa/models/__init__.py @@ -1,7 +1,7 @@ # Copyright MelisaDev 2022 - Present # Full MIT License can be found in `LICENSE.txt` at the project root. -from .app import * +from .app import Shard, Intents from .guild import * from .user import * from .message import * diff --git a/melisa/models/app/__init__.py b/melisa/models/app/__init__.py index ac32f4f..db2ec03 100644 --- a/melisa/models/app/__init__.py +++ b/melisa/models/app/__init__.py @@ -1,5 +1,5 @@ # Copyright MelisaDev 2022 - Present # Full MIT License can be found in `LICENSE.txt` at the project root. -from .intents import * -from .shard import * +from .intents import Intents +from .shard import Shard diff --git a/melisa/models/message/embed.py b/melisa/models/message/embed.py index 5a1a6fd..3074f4a 100644 --- a/melisa/models/message/embed.py +++ b/melisa/models/message/embed.py @@ -7,10 +7,11 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import List, Union, Optional +from typing import List, Union, Optional, Dict, Any from .colors import Color from melisa.exceptions import EmbedFieldError +from ...utils.conversion import try_enum from ...utils.api_model import APIModelBase from ...utils.types import APINullable, UNDEFINED from melisa.utils.timestamp import Timestamp @@ -64,9 +65,9 @@ class EmbedThumbnail: """ url: str - proxy_url: APINullable[str] = UNDEFINED - height: APINullable[int] = UNDEFINED - width: APINullable[int] = UNDEFINED + proxy_url: APINullable[str] = None + height: APINullable[int] = None + width: APINullable[int] = None @dataclass(repr=False) @@ -86,9 +87,9 @@ class EmbedVideo: """ url: str - proxy_url: APINullable[str] = UNDEFINED - height: APINullable[int] = UNDEFINED - width: APINullable[int] = UNDEFINED + proxy_url: APINullable[str] = None + height: APINullable[int] = None + width: APINullable[int] = None @dataclass(repr=False) @@ -108,9 +109,9 @@ class EmbedImage: """ url: str - proxy_url: APINullable[str] = UNDEFINED - height: APINullable[int] = UNDEFINED - width: APINullable[int] = UNDEFINED + proxy_url: APINullable[str] = None + height: APINullable[int] = None + width: APINullable[int] = None @dataclass(repr=False) @@ -125,8 +126,8 @@ class EmbedProvider: Url of provider """ - name: APINullable[str] = UNDEFINED - url: APINullable[str] = UNDEFINED + name: APINullable[str] = None + url: APINullable[str] = None @dataclass(repr=False) @@ -146,9 +147,9 @@ class EmbedAuthor: """ name: str - url: APINullable[str] = UNDEFINED - icon_url: APINullable[str] = UNDEFINED - proxy_icon_url: APINullable[str] = UNDEFINED + url: APINullable[str] = None + icon_url: APINullable[str] = None + proxy_icon_url: APINullable[str] = None @dataclass(repr=False) @@ -166,8 +167,8 @@ class EmbedFooter: """ text: str - icon_url: APINullable[str] = UNDEFINED - proxy_icon_url: APINullable[str] = UNDEFINED + icon_url: APINullable[str] = None + proxy_icon_url: APINullable[str] = None @dataclass(repr=False) @@ -227,19 +228,101 @@ class Embed(APIModelBase): Video information. """ - title: APINullable[str] = UNDEFINED - type: APINullable[EmbedType] = UNDEFINED - description: APINullable[str] = UNDEFINED - url: APINullable[str] = UNDEFINED - timestamp: APINullable[Timestamp] = UNDEFINED - color: APINullable[Color] = UNDEFINED - footer: APINullable[EmbedFooter] = UNDEFINED - image: APINullable[EmbedImage] = UNDEFINED - thumbnail: APINullable[EmbedThumbnail] = UNDEFINED - video: APINullable[EmbedVideo] = UNDEFINED - provider: APINullable[EmbedProvider] = UNDEFINED - author: APINullable[EmbedAuthor] = UNDEFINED - fields: APINullable[List[EmbedField]] = UNDEFINED + title: APINullable[str] = None + type: APINullable[EmbedType] = None + description: APINullable[str] = None + url: APINullable[str] = None + timestamp: APINullable[Timestamp] = None + color: APINullable[Color] = None + footer: APINullable[EmbedFooter] = None + image: APINullable[EmbedImage] = None + thumbnail: APINullable[EmbedThumbnail] = None + video: APINullable[EmbedVideo] = None + provider: APINullable[EmbedProvider] = None + author: APINullable[EmbedAuthor] = None + fields: APINullable[List[EmbedField]] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]): + """Generate a message from the given data. + + Parameters + ---------- + data: :class:`dict` + The dictionary to convert into an unknown channel. + """ + self: Embed = super().__new__(cls) + + self.title = data.get("title") + self.type = ( + try_enum(EmbedType, data["type"]) if data.get("type") is not None else None + ) + self.description = data.get("description") + self.url = data.get("url") + self.timestamp = ( + Timestamp.parse(data["timestamp"]) + if data.get("timestamp") is not None + else None + ) + self.color = Color(data["color"]) if data.get("color") is not None else None + + self.footer = None + self.image = None + self.thumbnail = None + self.video = None + self.provider = None + self.author = None + self.fields = [] + + if data.get("footer") is not None: + self.footer = EmbedFooter( + text=data["footer"]["text"], + icon_url=data["footer"].get("icon_url"), + proxy_icon_url=data["footer"].get("proxy_icon_url"), + ) + + if data.get("image") is not None: + self.image = EmbedImage( + url=data["image"]["url"], + proxy_url=data["image"].get("proxy_url"), + height=data["image"].get("height"), + width=data["image"].get("width"), + ) + + if data.get("video") is not None: + self.video = EmbedVideo( + url=data["video"]["url"], + proxy_url=data["video"].get("proxy_url"), + height=data["video"].get("height"), + width=data["video"].get("width"), + ) + + if data.get("provider") is not None: + self.provider = EmbedProvider( + name=data["provider"].get("name"), url=data["provider"].get("url") + ) + + if data.get("author") is not None: + self.author = EmbedAuthor( + name=data["author"]["name"], + url=data["author"].get("url"), + icon_url=data["author"].get("icon_url"), + proxy_icon_url=data["author"].get("proxy_icon_url"), + ) + + if data.get("fields") is not None: + for field in data["fields"]: + self.fields.append( + EmbedField( + name=field["name"], + value=field["value"], + inline=field["inline"] + if field.get("inline") is not None + else False, + ) + ) + + return self def __post_init__(self): if self.title and len(self.title) > 256: @@ -420,7 +503,7 @@ class Embed(APIModelBase): This embed. """ - if self.fields is UNDEFINED: + if self.fields is None: self.fields = [] self.fields.append(EmbedField(name=name, value=value, inline=inline)) diff --git a/melisa/models/message/message.py b/melisa/models/message/message.py index 11b63c1..b5a2a70 100644 --- a/melisa/models/message/message.py +++ b/melisa/models/message/message.py @@ -5,14 +5,14 @@ from __future__ import annotations from dataclasses import dataclass from enum import IntEnum -from typing import List, TYPE_CHECKING, Optional, Dict +from typing import List, TYPE_CHECKING, Optional, Dict, Any -from ...utils import Snowflake, Timestamp -from ...utils import APIModelBase +from .embed import Embed +from ...utils import Snowflake, Timestamp, try_enum, APIModelBase from ...utils.types import APINullable, UNDEFINED if TYPE_CHECKING: - from ..guild.channel import Thread + from ..guild.channel import Thread, _choose_channel_type class MessageType(IntEnum): @@ -144,7 +144,7 @@ class Message(APIModelBase): Whether this message is pinned webhook_id: :class:`~melisa.utils.types.snowflake.Snowflake` If the message is generated by a webhook, this is the webhook's id - type: :class:`int` + type: :class:`MessageType` Type of message activity: :class:`typing.Any` Sent with Rich Presence-related chat embeds @@ -170,36 +170,112 @@ class Message(APIModelBase): Deprecated the stickers sent with the message """ - id: APINullable[Snowflake] = UNDEFINED - channel_id: APINullable[Snowflake] = UNDEFINED - guild_id: APINullable[Snowflake] = UNDEFINED - author: APINullable[Dict] = UNDEFINED - member: APINullable[Dict] = UNDEFINED - content: APINullable[str] = UNDEFINED - timestamp: APINullable[Timestamp] = UNDEFINED - edited_timestamp: APINullable[Timestamp] = UNDEFINED - tts: APINullable[bool] = UNDEFINED - mention_everyone: APINullable[bool] = UNDEFINED - mentions: APINullable[List] = UNDEFINED - mention_roles: APINullable[List] = UNDEFINED - mention_channels: APINullable[List] = UNDEFINED - attachments: APINullable[List] = UNDEFINED - embeds: APINullable[List] = UNDEFINED - reactions: APINullable[List] = UNDEFINED - nonce: APINullable[int] or APINullable[str] = UNDEFINED - pinned: APINullable[bool] = UNDEFINED - webhook_id: APINullable[Snowflake] = UNDEFINED - type: APINullable[int] = UNDEFINED - activity: APINullable[Dict] = UNDEFINED - application: APINullable[Dict] = UNDEFINED - application_id: APINullable[Snowflake] = UNDEFINED - message_reference: APINullable[Dict] = UNDEFINED - flags: APINullable[int] = UNDEFINED - interaction: APINullable[Dict] = UNDEFINED - thread: APINullable[Thread] = UNDEFINED - components: APINullable[List] = UNDEFINED - sticker_items: APINullable[List] = UNDEFINED - stickers: APINullable[List] = UNDEFINED + id: APINullable[Snowflake] = None + channel_id: APINullable[Snowflake] = None + guild_id: APINullable[Snowflake] = None + author: APINullable[Dict] = None + member: APINullable[Dict] = None + content: APINullable[str] = None + timestamp: APINullable[Timestamp] = None + edited_timestamp: APINullable[Timestamp] = None + tts: APINullable[bool] = None + mention_everyone: APINullable[bool] = None + mentions: APINullable[List] = None + mention_roles: APINullable[List] = None + mention_channels: APINullable[List] = None + attachments: APINullable[List] = None + embeds: APINullable[List] = None + reactions: APINullable[List] = None + nonce: APINullable[int] or APINullable[str] = None + pinned: APINullable[bool] = None + webhook_id: APINullable[Snowflake] = None + type: APINullable[MessageType] = None + activity: APINullable[Dict] = None # ToDo Set model here + application: APINullable[Dict] = None + application_id: APINullable[Snowflake] = None + message_reference: APINullable[Dict] = None + flags: APINullable[int] = None + referenced_message: APINullable[Message] = None + interaction: APINullable[Dict] = None + thread: APINullable[Thread] = None + components: APINullable[List] = None + sticker_items: APINullable[List] = None + stickers: APINullable[List] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]): + """Generate a message from the given data. + + Parameters + ---------- + data: :class:`dict` + The dictionary to convert into an unknown channel. + """ + self: Message = super().__new__(cls) + + self.id = data["id"] + self.channel_id = Snowflake(data["channel_id"]) + self.guild_id = ( + Snowflake(data["guild_id"]) if data.get("guild_id") is not None else None + ) + self.author = data.get("author") # ToDo: User object + self.member = data.get("member") + self.content = data.get("content", "") + self.timestamp = Timestamp.parse(data["timestamp"]) + self.edited_timestamp = ( + Timestamp.parse(data["edited_timestamp"]) + if data.get("edited_timestamp") is not None + else None + ) + self.tts = data["tts"] + self.mention_everyone = data["mention_everyone"] + self.mentions = data["mentions"] # ToDo: Convert to models + self.mention_roles = data.get("mention_roles") + self.attachments = data.get("attachments", []) + self.reactions = data.get("reactions", []) + self.nonce = data.get("nonce") + self.pinned = data.get("pinned", False) + self.webhook_id = ( + Snowflake(data["webhook_id"]) + if data.get("webhook_id") is not None + else None + ) + self.type = try_enum(MessageType, data.get("type", 0)) + self.activity = data.get("activity") + self.application = data.get("application") + self.application_id = ( + Snowflake(data["application_id"]) + if data.get("application_id") is not None + else None + ) + self.message_reference = data.get( + "message_reference" + ) # ToDo: message reference object + self.flags = try_enum(MessageFlags, data.get("flags", 0)) + self.referenced_message = ( + Message.from_dict(data["referenced_message"]) + if data.get("referenced_message") is not None + else None + ) + self.interaction = data.get("interaction") + self.thread = ( + Thread.from_dict(data["thread"]) if data.get("thread") is not None else None + ) + self.components = data.get("components") + self.sticker_items = data.get("sticker_items") + self.stickers = data.get("stickers") + + self.mention_channels = [] + self.embeds = [] + + for channel in data.get("mention_channels", []): + channel = _choose_channel_type(channel) + self.mention_channels.append(channel) + + for embed in data.get("embeds", []): + self.embeds.append(Embed.from_dict(embed)) + + return self async def pin(self, *, reason: Optional[str] = None): """|coro| diff --git a/melisa/models/user/user.py b/melisa/models/user/user.py index 8936964..5cde982 100644 --- a/melisa/models/user/user.py +++ b/melisa/models/user/user.py @@ -5,11 +5,7 @@ from __future__ import annotations from enum import IntEnum from dataclasses import dataclass -from typing import ( - Optional, - Dict, - Any -) +from typing import Optional, Dict, Any from ...utils.conversion import try_enum from ...utils.api_model import APIModelBase @@ -201,13 +197,6 @@ class User(APIModelBase): @classmethod def from_dict(cls, data: Dict[str, Any]) -> User: - """Generate a user from the given data. - - Parameters - ---------- - data: :class:`dict` - The dictionary to convert into a user. - """ self: User = super().__new__(cls) self.id = int(data["id"]) @@ -222,11 +211,7 @@ class User(APIModelBase): self.local = data.get("local") self.verified = data.get("verified", False) self.email = data.get("email") - self.premium_type = try_enum( - PremiumTypes, data.get("premium_type") - ) - self.public_flags = try_enum( - UserFlags, data.get("public_flags") - ) + self.premium_type = try_enum(PremiumTypes, data.get("premium_type")) + self.public_flags = try_enum(UserFlags, data.get("public_flags")) return self diff --git a/melisa/rest.py b/melisa/rest.py new file mode 100644 index 0000000..f5e0578 --- /dev/null +++ b/melisa/rest.py @@ -0,0 +1,70 @@ +# Copyright MelisaDev 2022 - Present +# Full MIT License can be found in `LICENSE.txt` at the project root. + +from typing import Union + +from .core.http import HTTPClient +from .utils.snowflake import Snowflake +from .models.guild.guild import Guild +from .models.user.user import User +from .models.guild.channel import _choose_channel_type, Channel + + +class RESTApp: + """ + This instance may be used to send http requests to the Discord REST API. + + **It will not cache anything.** + + Parameters + ---------- + token: :class:`str` + The token to authorize (you can found it in the developer portal) + """ + + def __init__(self, token: str): + self.http: HTTPClient = HTTPClient(token) + + async def fetch_user(self, user_id: Union[Snowflake, int, str]) -> User: + """ + Fetch User from the Discord API (by id). + + Parameters + ---------- + user_id: Union[:class:`~melisa.utils.snowflake.Snowflake`, str, int] + Id of user to fetch + """ + + data = await self.http.get(f"users/{user_id}") + + return User.from_dict(data) + + async def fetch_guild(self, guild_id: Union[Snowflake, int, str]) -> Guild: + """ + Fetch Guild from the Discord API (by id). + + Parameters + ---------- + guild_id : Union[:class:`~melisa.utils.snowflake.Snowflake`, str, int] + Id of guild to fetch + """ + + data = await self.http.get(f"guilds/{guild_id}") + + return Guild.from_dict(data) + + async def fetch_channel(self, channel_id: Union[Snowflake, str, int]) -> Channel: + """ + Fetch Channel from the Discord API (by id). + + Parameters + ---------- + channel_id : Union[:class:`~melisa.utils.snowflake.Snowflake`, str, int] + Id of channel to fetch + """ + + # ToDo: Update cache if CHANNEL_CACHE enabled. + + data = await self.http.get(f"channels/{channel_id}") + + return _choose_channel_type(data) diff --git a/melisa/utils/__init__.py b/melisa/utils/__init__.py index d0ea4a2..67dd610 100644 --- a/melisa/utils/__init__.py +++ b/melisa/utils/__init__.py @@ -5,6 +5,14 @@ from .types import Coro, UNDEFINED from .timestamp import Timestamp from .snowflake import Snowflake from .api_model import APIModelBase -from .conversion import remove_none +from .conversion import remove_none, try_enum -__all__ = ("Coro", "Snowflake", "APIModelBase", "remove_none", "Timestamp", "UNDEFINED") +__all__ = ( + "Coro", + "Snowflake", + "APIModelBase", + "remove_none", + "Timestamp", + "UNDEFINED", + "try_enum", +) diff --git a/packages/speed.txt b/packages/speed.txt index d56bbbd..da80dc5 100644 --- a/packages/speed.txt +++ b/packages/speed.txt @@ -1 +1 @@ -orjson==3.6.7 \ No newline at end of file +orjson==3.6.8 \ No newline at end of file diff --git a/tests/test_embeds.py b/tests/test_embeds.py index 2e0b8cc..b6406d8 100644 --- a/tests/test_embeds.py +++ b/tests/test_embeds.py @@ -71,9 +71,3 @@ class TestEmbed: is correct. """ assert has_key_vals(EMBED.to_dict(), dict_embed) - - def test_embed_from_dict(self): - assert has_key_vals( - Embed.from_dict(dict_embed).to_dict(), - dict_embed - )