feat(cache): initial cache manager

This commit is contained in:
grey-cat-1908 2022-05-14 17:35:49 +03:00
parent c4ce2de1a4
commit bbab69c800
12 changed files with 199 additions and 25 deletions

View file

@ -8,6 +8,7 @@ import sys
import traceback import traceback
from typing import Dict, List, Union, Any, Iterable, Optional, Callable from typing import Dict, List, Union, Any, Iterable, Optional, Callable
from .models.app.cache import CacheManager
from .rest import RESTApp from .rest import RESTApp
from .core.gateway import GatewayBotInfo from .core.gateway import GatewayBotInfo
from .models.guild.channel import Channel from .models.guild.channel import Channel
@ -73,6 +74,7 @@ class Client:
mobile: bool = False, mobile: bool = False,
allowed_mentions: Optional[AllowedMentions] = None, allowed_mentions: Optional[AllowedMentions] = None,
logs: Union[None, int, str, Dict[str, Any]] = "INFO", logs: Union[None, int, str, Dict[str, Any]] = "INFO",
cache: CacheManager = None
): ):
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
@ -82,8 +84,7 @@ class Client:
self._events: Dict[str, Coro] = {} self._events: Dict[str, Coro] = {}
self._waiter_mgr = WaiterMgr(self._loop) self._waiter_mgr = WaiterMgr(self._loop)
# ToDo: Transfer guilds in to the cache manager self.cache = cache if cache is not None else CacheManager()
self.guilds = {}
self.user = None self.user = None
self._gateway_info = self._loop.run_until_complete(self._get_gateway()) self._gateway_info = self._loop.run_until_complete(self._get_gateway())
@ -104,8 +105,6 @@ class Client:
self._mobile: bool = mobile self._mobile: bool = mobile
self.allowed_mentions: AllowedMentions = allowed_mentions self.allowed_mentions: AllowedMentions = allowed_mentions
self._none_guilds_cached = False
APIModelBase.set_client(self) APIModelBase.set_client(self)
init_logging(logs) init_logging(logs)
@ -143,7 +142,7 @@ class Client:
_logger.debug(f"Listener {callback.__qualname__} added successfully!") _logger.debug(f"Listener {callback.__qualname__} added successfully!")
return self return self
async def dispatch(self, name: str, *args): async def dispatch(self, name: str, args):
""" """
Dispatches an event Dispatches an event
Parameters Parameters
@ -165,7 +164,7 @@ class Client:
print(f"Ignoring exception in {name}", file=sys.stderr) print(f"Ignoring exception in {name}", file=sys.stderr)
traceback.print_exc() traceback.print_exc()
self._waiter_mgr.process_events(name, *args) self._waiter_mgr.process_events(name, args)
def run(self): def run(self):
""" """
@ -249,7 +248,7 @@ class Client:
Id of guild to fetch Id of guild to fetch
""" """
# ToDo: Update cache if GUILD_CACHE enabled. # ToDo: Update cache if FULL_GUILD_CACHE enabled.
return await self.rest.fetch_guild(guild_id) return await self.rest.fetch_guild(guild_id)

View file

@ -10,7 +10,7 @@ from ..utils.types import Coro
async def channel_create_listener(self, gateway, payload: dict): async def channel_create_listener(self, gateway, payload: dict):
channel = _choose_channel_type(payload) channel = _choose_channel_type(payload)
await self.dispatch("on_channel_create", channel) await self.dispatch("on_channel_create", (channel, ))
return return

View file

@ -14,7 +14,7 @@ async def channel_delete_listener(self, gateway, payload: dict):
channel = channel_cls.from_dict(payload) channel = channel_cls.from_dict(payload)
await self.dispatch("on_channel_delete", channel) await self.dispatch("on_channel_delete", (channel, ))
return return

View file

@ -12,13 +12,13 @@ async def guild_create_listener(self, gateway, payload: dict):
guild = Guild.from_dict(payload) guild = Guild.from_dict(payload)
if self.guilds.get(guild.id, "empty") != "empty": if self.cache.get_guild(guild.id) is not None:
guild_was_cached_as_none = True guild_was_cached_as_none = True
self.guilds[str(guild.id)] = guild self.cache.set_guild(guild)
if guild_was_cached_as_none is False: if guild_was_cached_as_none is False:
await self.dispatch("on_guild_create", guild) await self.dispatch("on_guild_create", (guild, ))
return return

View file

@ -8,11 +8,14 @@ from ..models.guild import UnavailableGuild
async def guild_delete_listener(self, gateway, payload: dict): async def guild_delete_listener(self, gateway, payload: dict):
guild = UnavailableGuild.from_dict(payload) guild = self.cache.get_guild(payload["data"]["id"])
self.guilds.pop(guild.id, None) if guild is None:
guild = UnavailableGuild.from_dict(payload)
await self.dispatch("on_guild_remove", guild) self.cache.remove_guild(guild.id)
await self.dispatch("on_guild_remove", (guild, ))
return return

View file

@ -9,9 +9,9 @@ from ..models.guild import Guild
async def guild_update_listener(self, gateway, payload: dict): async def guild_update_listener(self, gateway, payload: dict):
new_guild = Guild.from_dict(payload) new_guild = Guild.from_dict(payload)
old_guild = self.guilds.get(new_guild.id) old_guild = self.cache.get_guild(payload["id"])
self.guilds[new_guild.id] = new_guild self.cache.set_guild(new_guild)
await self.dispatch("on_guild_update", (old_guild, new_guild)) await self.dispatch("on_guild_update", (old_guild, new_guild))

View file

@ -9,7 +9,7 @@ from ..utils.types import Coro
async def message_create_listener(self, gateway, payload: dict): async def message_create_listener(self, gateway, payload: dict):
message = Message.from_dict(payload) message = Message.from_dict(payload)
await self.dispatch("on_message_create", message) await self.dispatch("on_message_create", (message, ))
return return

View file

@ -12,13 +12,11 @@ async def on_ready_listener(self, gateway, payload: dict):
guilds = payload.get("guilds") guilds = payload.get("guilds")
if self._none_guilds_cached is False: self.cache._set_none_guilds(guilds)
self.guilds = dict(map(lambda i: (i["id"], None), guilds))
self._none_guilds_cached = True
self.user = User.from_dict(payload.get("user")) self.user = User.from_dict(payload.get("user"))
await self.dispatch("on_shard_ready", gateway.shard_id) await self.dispatch("on_shard_ready", (gateway.shard_id, ))
return return

158
melisa/models/app/cache.py Normal file
View file

@ -0,0 +1,158 @@
# Copyright MelisaDev 2022 - Present
# Full MIT License can be found in `LICENSE.txt` at the project root.
from __future__ import annotations
from enum import Enum
from typing import List, Dict, Optional, Any, Union
from melisa.utils.types import UNDEFINED
from melisa.models.guild import Guild, ChannelType, UnavailableGuild
from melisa.utils.snowflake import Snowflake
class AutoCacheModels(Enum):
# ToDo: Add FULL_GUILD auto cache model
""" """
GUILD_ROLES = "GUILD_ROLES"
GUILD_THREADS = "GUILD_THREADS"
GUILD_EMOJIS = "GUILD_EMOJIS"
GUILD_WEBHOOKS = "GUILD_WEBHOOKS"
GUILD_MEMBERS = "GUILD_MEMBERS"
TEXT_CHANNELS = "TEXT_CHANNELS"
class CacheManager:
""" """
def __init__(
self,
*,
auto_models: Optional[List[AutoCacheModels]] = None,
auto_unused_attributes: Optional[Dict[Any, List[str]]] = None
):
self._auto_models: List[AutoCacheModels] = (
[] if auto_models is None else auto_models
)
self.auto_unused_attributes: Dict[Any, List[str]] = (
{} if auto_unused_attributes is not None else auto_unused_attributes
)
self._raw_guilds: Dict[Snowflake, Any] = {}
self._raw_users: Dict[Snowflake, Any] = {}
self._raw_dm_channels: Dict[Snowflake, Any] = {}
# We use symlinks to cache guild channels
# like we save channel in Guild and save it here
# and if you need channel, and you don't know its guild
# you can use special method, and it will find it in guild
self._channel_symlinks: Dict[Snowflake, Snowflake] = {}
def guilds_count(self) -> int:
"""Cached Guilds Count"""
return len(self._raw_guilds)
def users_count(self) -> int:
"""Cached Users Count"""
return len(self._raw_users)
def guild_channels_count(self) -> int:
"""Cached Guild Channels Count"""
return len(self._channel_symlinks)
def total_channels_count(self) -> int:
"""Total Cached Channel Count"""
return len(self._raw_dm_channels) + len(self._channel_symlinks)
def __remove_unused_attributes(self, model, _type):
if self.auto_unused_attributes is None:
self.auto_unused_attributes = {}
unused_attributes = self.auto_unused_attributes.get(_type)
if unused_attributes and unused_attributes is not None:
unused_attributes = unused_attributes.__dict__.keys()
for attr in unused_attributes:
model.__delattr__(attr)
return model
def set_guild(self, guild: Optional[Guild] = None):
"""
Save Guild into cache
Parameters
----------
guild: Optional[`~melisa.models.guild.Guild`]
Guild to save into cache
"""
if guild is None:
return None
guild = self.__remove_unused_attributes(guild, Guild)
if hasattr(guild, "channels"):
channels = guild.channels.values()
if not AutoCacheModels.TEXT_CHANNELS in self._auto_models:
channels = filter(
lambda channel: channel.type != ChannelType.GUILD_TEXT, channels
)
for sym in channels:
if self._channel_symlinks.get(sym.id, UNDEFINED) is not UNDEFINED:
self._channel_symlinks.pop(sym.id)
self._channel_symlinks[sym.id] = guild.id
self._raw_guilds.update({guild.id: guild})
return guild
def get_guild(self, guild_id: Union[Snowflake, str, int]):
"""
Get guild from cache
Parameters
----------
guild_id: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`]
ID of guild to get from cache.
"""
if not isinstance(guild_id, Snowflake):
guild_id = Snowflake(int(guild_id))
return self._raw_guilds.get(guild_id, None)
def _set_none_guilds(self, guilds: List[Dict[str, Any]]) -> None:
"""
Insert None-Guilds to cache
Parameters
----------
guilds: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`]
Data of guilds tso insert to the cache
"""
guilds_dict = dict(map(lambda i: (i["id"], UnavailableGuild.from_dict(i)), guilds))
self._raw_guilds.update(guilds_dict)
return None
def remove_guild(self, guild_id: Union[Snowflake, str, int]):
"""
Remove guild from cache
Parameters
----------
guild_id: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`]
ID of guild to remove from cache.
"""
if not isinstance(guild_id, Snowflake):
guild_id = Snowflake(int(guild_id))
return self._raw_guilds.pop(guild_id, None)

View file

@ -18,7 +18,7 @@ if TYPE_CHECKING:
@dataclass(repr=False) @dataclass(repr=False)
class Emoji(APIModelBase): class Emoji(APIModelBase):
"""Emoji Structure """Emoji Structure
Attributes Attributes
---------- ----------
id: :class:`~melisa.utils.types.Snowflake` id: :class:`~melisa.utils.types.Snowflake`

View file

@ -586,3 +586,19 @@ class UnavailableGuild(APIModelBase):
id: APINullable[Snowflake] = None id: APINullable[Snowflake] = None
unavailable: APINullable[bool] = True unavailable: APINullable[bool] = True
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> UnavailableGuild:
"""Generate a unavailable guild from the given data.
Parameters
----------
data: :class:`dict`
The dictionary to convert into an unavailable guild.
"""
self: UnavailableGuild = super().__new__(cls)
self.id = Snowflake(int(data["id"]))
self.unavailable = data["unavailable"]
return self

View file

@ -111,12 +111,12 @@ class Webhook(APIModelBase):
self.source_guild = data.get("source_guild", {}) self.source_guild = data.get("source_guild", {})
else: else:
self.source_guild = None self.source_guild = None
if data.get("source_channel") is not None: if data.get("source_channel") is not None:
self.source_channel = data.get("source_channel", {}) self.source_channel = data.get("source_channel", {})
else: else:
self.source_channel = None self.source_channel = None
self.url = data.get("url") self.url = data.get("url")
return self return self