From bbab69c8007a82aadf702424acce0a605096fd63 Mon Sep 17 00:00:00 2001 From: grey-cat-1908 Date: Sat, 14 May 2022 17:35:49 +0300 Subject: [PATCH] feat(cache): initial cache manager --- melisa/client.py | 13 ++- melisa/listeners/channel_create.py | 2 +- melisa/listeners/channel_delete.py | 2 +- melisa/listeners/guild_create.py | 6 +- melisa/listeners/guild_remove.py | 9 +- melisa/listeners/guild_update.py | 4 +- melisa/listeners/message_create.py | 2 +- melisa/listeners/ready.py | 6 +- melisa/models/app/cache.py | 158 +++++++++++++++++++++++++++++ melisa/models/guild/emoji.py | 2 +- melisa/models/guild/guild.py | 16 +++ melisa/models/guild/webhook.py | 4 +- 12 files changed, 199 insertions(+), 25 deletions(-) create mode 100644 melisa/models/app/cache.py diff --git a/melisa/client.py b/melisa/client.py index 8065797..5273db2 100644 --- a/melisa/client.py +++ b/melisa/client.py @@ -8,6 +8,7 @@ import sys import traceback from typing import Dict, List, Union, Any, Iterable, Optional, Callable +from .models.app.cache import CacheManager from .rest import RESTApp from .core.gateway import GatewayBotInfo from .models.guild.channel import Channel @@ -73,6 +74,7 @@ class Client: mobile: bool = False, allowed_mentions: Optional[AllowedMentions] = None, logs: Union[None, int, str, Dict[str, Any]] = "INFO", + cache: CacheManager = None ): self._loop = asyncio.get_event_loop() @@ -82,8 +84,7 @@ class Client: self._events: Dict[str, Coro] = {} self._waiter_mgr = WaiterMgr(self._loop) - # ToDo: Transfer guilds in to the cache manager - self.guilds = {} + self.cache = cache if cache is not None else CacheManager() self.user = None self._gateway_info = self._loop.run_until_complete(self._get_gateway()) @@ -104,8 +105,6 @@ class Client: self._mobile: bool = mobile self.allowed_mentions: AllowedMentions = allowed_mentions - self._none_guilds_cached = False - APIModelBase.set_client(self) init_logging(logs) @@ -143,7 +142,7 @@ class Client: _logger.debug(f"Listener {callback.__qualname__} added successfully!") return self - async def dispatch(self, name: str, *args): + async def dispatch(self, name: str, args): """ Dispatches an event Parameters @@ -165,7 +164,7 @@ class Client: print(f"Ignoring exception in {name}", file=sys.stderr) traceback.print_exc() - self._waiter_mgr.process_events(name, *args) + self._waiter_mgr.process_events(name, args) def run(self): """ @@ -249,7 +248,7 @@ class Client: Id of guild to fetch """ - # ToDo: Update cache if GUILD_CACHE enabled. + # ToDo: Update cache if FULL_GUILD_CACHE enabled. return await self.rest.fetch_guild(guild_id) diff --git a/melisa/listeners/channel_create.py b/melisa/listeners/channel_create.py index 81fe5f2..bc96bf5 100644 --- a/melisa/listeners/channel_create.py +++ b/melisa/listeners/channel_create.py @@ -10,7 +10,7 @@ from ..utils.types import Coro async def channel_create_listener(self, gateway, payload: dict): channel = _choose_channel_type(payload) - await self.dispatch("on_channel_create", channel) + await self.dispatch("on_channel_create", (channel, )) return diff --git a/melisa/listeners/channel_delete.py b/melisa/listeners/channel_delete.py index c2e57ff..e80afab 100644 --- a/melisa/listeners/channel_delete.py +++ b/melisa/listeners/channel_delete.py @@ -14,7 +14,7 @@ async def channel_delete_listener(self, gateway, payload: dict): channel = channel_cls.from_dict(payload) - await self.dispatch("on_channel_delete", channel) + await self.dispatch("on_channel_delete", (channel, )) return diff --git a/melisa/listeners/guild_create.py b/melisa/listeners/guild_create.py index 410cbd0..6cee280 100644 --- a/melisa/listeners/guild_create.py +++ b/melisa/listeners/guild_create.py @@ -12,13 +12,13 @@ async def guild_create_listener(self, gateway, payload: dict): guild = Guild.from_dict(payload) - if self.guilds.get(guild.id, "empty") != "empty": + if self.cache.get_guild(guild.id) is not None: guild_was_cached_as_none = True - self.guilds[str(guild.id)] = guild + self.cache.set_guild(guild) if guild_was_cached_as_none is False: - await self.dispatch("on_guild_create", guild) + await self.dispatch("on_guild_create", (guild, )) return diff --git a/melisa/listeners/guild_remove.py b/melisa/listeners/guild_remove.py index ba351b8..6bac71a 100644 --- a/melisa/listeners/guild_remove.py +++ b/melisa/listeners/guild_remove.py @@ -8,11 +8,14 @@ from ..models.guild import UnavailableGuild async def guild_delete_listener(self, gateway, payload: dict): - guild = UnavailableGuild.from_dict(payload) + guild = self.cache.get_guild(payload["data"]["id"]) - self.guilds.pop(guild.id, None) + if guild is None: + guild = UnavailableGuild.from_dict(payload) - await self.dispatch("on_guild_remove", guild) + self.cache.remove_guild(guild.id) + + await self.dispatch("on_guild_remove", (guild, )) return diff --git a/melisa/listeners/guild_update.py b/melisa/listeners/guild_update.py index f90cce4..a3328f2 100644 --- a/melisa/listeners/guild_update.py +++ b/melisa/listeners/guild_update.py @@ -9,9 +9,9 @@ from ..models.guild import Guild async def guild_update_listener(self, gateway, payload: dict): new_guild = Guild.from_dict(payload) - old_guild = self.guilds.get(new_guild.id) + old_guild = self.cache.get_guild(payload["id"]) - self.guilds[new_guild.id] = new_guild + self.cache.set_guild(new_guild) await self.dispatch("on_guild_update", (old_guild, new_guild)) diff --git a/melisa/listeners/message_create.py b/melisa/listeners/message_create.py index 1b4881b..6a492c7 100644 --- a/melisa/listeners/message_create.py +++ b/melisa/listeners/message_create.py @@ -9,7 +9,7 @@ from ..utils.types import Coro async def message_create_listener(self, gateway, payload: dict): message = Message.from_dict(payload) - await self.dispatch("on_message_create", message) + await self.dispatch("on_message_create", (message, )) return diff --git a/melisa/listeners/ready.py b/melisa/listeners/ready.py index b5b40b2..7c20f61 100644 --- a/melisa/listeners/ready.py +++ b/melisa/listeners/ready.py @@ -12,13 +12,11 @@ async def on_ready_listener(self, gateway, payload: dict): guilds = payload.get("guilds") - if self._none_guilds_cached is False: - self.guilds = dict(map(lambda i: (i["id"], None), guilds)) - self._none_guilds_cached = True + self.cache._set_none_guilds(guilds) self.user = User.from_dict(payload.get("user")) - await self.dispatch("on_shard_ready", gateway.shard_id) + await self.dispatch("on_shard_ready", (gateway.shard_id, )) return diff --git a/melisa/models/app/cache.py b/melisa/models/app/cache.py new file mode 100644 index 0000000..88252d2 --- /dev/null +++ b/melisa/models/app/cache.py @@ -0,0 +1,158 @@ +# Copyright MelisaDev 2022 - Present +# Full MIT License can be found in `LICENSE.txt` at the project root. + +from __future__ import annotations + +from enum import Enum +from typing import List, Dict, Optional, Any, Union + +from melisa.utils.types import UNDEFINED +from melisa.models.guild import Guild, ChannelType, UnavailableGuild +from melisa.utils.snowflake import Snowflake + + +class AutoCacheModels(Enum): + # ToDo: Add FULL_GUILD auto cache model + + """ """ + + GUILD_ROLES = "GUILD_ROLES" + GUILD_THREADS = "GUILD_THREADS" + GUILD_EMOJIS = "GUILD_EMOJIS" + GUILD_WEBHOOKS = "GUILD_WEBHOOKS" + GUILD_MEMBERS = "GUILD_MEMBERS" + TEXT_CHANNELS = "TEXT_CHANNELS" + + +class CacheManager: + """ """ + + def __init__( + self, + *, + auto_models: Optional[List[AutoCacheModels]] = None, + auto_unused_attributes: Optional[Dict[Any, List[str]]] = None + ): + self._auto_models: List[AutoCacheModels] = ( + [] if auto_models is None else auto_models + ) + self.auto_unused_attributes: Dict[Any, List[str]] = ( + {} if auto_unused_attributes is not None else auto_unused_attributes + ) + + self._raw_guilds: Dict[Snowflake, Any] = {} + self._raw_users: Dict[Snowflake, Any] = {} + self._raw_dm_channels: Dict[Snowflake, Any] = {} + + # We use symlinks to cache guild channels + # like we save channel in Guild and save it here + # and if you need channel, and you don't know its guild + # you can use special method, and it will find it in guild + self._channel_symlinks: Dict[Snowflake, Snowflake] = {} + + def guilds_count(self) -> int: + """Cached Guilds Count""" + return len(self._raw_guilds) + + def users_count(self) -> int: + """Cached Users Count""" + return len(self._raw_users) + + def guild_channels_count(self) -> int: + """Cached Guild Channels Count""" + return len(self._channel_symlinks) + + def total_channels_count(self) -> int: + """Total Cached Channel Count""" + return len(self._raw_dm_channels) + len(self._channel_symlinks) + + def __remove_unused_attributes(self, model, _type): + if self.auto_unused_attributes is None: + self.auto_unused_attributes = {} + + unused_attributes = self.auto_unused_attributes.get(_type) + + if unused_attributes and unused_attributes is not None: + unused_attributes = unused_attributes.__dict__.keys() + + for attr in unused_attributes: + model.__delattr__(attr) + + return model + + def set_guild(self, guild: Optional[Guild] = None): + """ + Save Guild into cache + + Parameters + ---------- + guild: Optional[`~melisa.models.guild.Guild`] + Guild to save into cache + """ + + if guild is None: + return None + + guild = self.__remove_unused_attributes(guild, Guild) + + if hasattr(guild, "channels"): + channels = guild.channels.values() + + if not AutoCacheModels.TEXT_CHANNELS in self._auto_models: + channels = filter( + lambda channel: channel.type != ChannelType.GUILD_TEXT, channels + ) + + for sym in channels: + if self._channel_symlinks.get(sym.id, UNDEFINED) is not UNDEFINED: + self._channel_symlinks.pop(sym.id) + + self._channel_symlinks[sym.id] = guild.id + + self._raw_guilds.update({guild.id: guild}) + + return guild + + def get_guild(self, guild_id: Union[Snowflake, str, int]): + """ + Get guild from cache + + Parameters + ---------- + guild_id: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`] + ID of guild to get from cache. + """ + + if not isinstance(guild_id, Snowflake): + guild_id = Snowflake(int(guild_id)) + return self._raw_guilds.get(guild_id, None) + + def _set_none_guilds(self, guilds: List[Dict[str, Any]]) -> None: + """ + Insert None-Guilds to cache + + Parameters + ---------- + guilds: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`] + Data of guilds tso insert to the cache + """ + + guilds_dict = dict(map(lambda i: (i["id"], UnavailableGuild.from_dict(i)), guilds)) + + self._raw_guilds.update(guilds_dict) + return None + + def remove_guild(self, guild_id: Union[Snowflake, str, int]): + """ + Remove guild from cache + + Parameters + ---------- + guild_id: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`] + ID of guild to remove from cache. + """ + + if not isinstance(guild_id, Snowflake): + guild_id = Snowflake(int(guild_id)) + + return self._raw_guilds.pop(guild_id, None) diff --git a/melisa/models/guild/emoji.py b/melisa/models/guild/emoji.py index b49debb..f2d177a 100644 --- a/melisa/models/guild/emoji.py +++ b/melisa/models/guild/emoji.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: @dataclass(repr=False) class Emoji(APIModelBase): """Emoji Structure - + Attributes ---------- id: :class:`~melisa.utils.types.Snowflake` diff --git a/melisa/models/guild/guild.py b/melisa/models/guild/guild.py index 7cae9d0..42c95fe 100644 --- a/melisa/models/guild/guild.py +++ b/melisa/models/guild/guild.py @@ -586,3 +586,19 @@ class UnavailableGuild(APIModelBase): id: APINullable[Snowflake] = None unavailable: APINullable[bool] = True + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> UnavailableGuild: + """Generate a unavailable guild from the given data. + + Parameters + ---------- + data: :class:`dict` + The dictionary to convert into an unavailable guild. + """ + self: UnavailableGuild = super().__new__(cls) + + self.id = Snowflake(int(data["id"])) + self.unavailable = data["unavailable"] + + return self diff --git a/melisa/models/guild/webhook.py b/melisa/models/guild/webhook.py index 5687e29..93ffa9b 100644 --- a/melisa/models/guild/webhook.py +++ b/melisa/models/guild/webhook.py @@ -111,12 +111,12 @@ class Webhook(APIModelBase): self.source_guild = data.get("source_guild", {}) else: self.source_guild = None - + if data.get("source_channel") is not None: self.source_channel = data.get("source_channel", {}) else: self.source_channel = None - + self.url = data.get("url") return self