add guild channel cache (set/get)

This commit is contained in:
grey-cat-1908 2022-05-19 17:58:19 +03:00
parent 78745c3334
commit 23c0ebbefb
4 changed files with 90 additions and 9 deletions

View file

@ -11,7 +11,7 @@ from typing import Dict, List, Union, Any, Iterable, Optional, Callable
from .models.app.cache import CacheManager
from .rest import RESTApp
from .core.gateway import GatewayBotInfo
from .models.guild.channel import Channel
from .models.guild.channel import Channel, ChannelType
from .models import Activity, AllowedMentions
from .models.app.shard import Shard
from .models.app.intents import Intents
@ -266,7 +266,12 @@ class Client:
# ToDo: Update cache if CHANNEL_CACHE enabled.
return await self.rest.fetch_channel(channel_id)
data = await self.rest.fetch_channel(channel_id)
if data.type not in [ChannelType.DM, ChannelType.GROUP_DM]:
self.cache.set_guild_channel(data)
return data
async def wait_for(
self,

View file

@ -7,7 +7,7 @@ from enum import Enum
from typing import List, Dict, Optional, Any, Union
from melisa.utils.types import UNDEFINED
from melisa.models.guild import Guild, ChannelType, UnavailableGuild
from melisa.models.guild import Guild, ChannelType, UnavailableGuild, Channel
from melisa.utils.snowflake import Snowflake
@ -16,6 +16,7 @@ class AutoCacheModels(Enum):
""" """
FULL_GUILDS = "FULL_GUILDS"
GUILD_ROLES = "GUILD_ROLES"
GUILD_THREADS = "GUILD_THREADS"
GUILD_EMOJIS = "GUILD_EMOJIS"
@ -30,6 +31,7 @@ class CacheManager:
def __init__(
self,
*,
disabled: bool = False,
auto_models: Optional[List[AutoCacheModels]] = None,
auto_unused_attributes: Optional[Dict[Any, List[str]]] = None,
):
@ -44,6 +46,8 @@ class CacheManager:
self._raw_users: Dict[Snowflake, Any] = {}
self._raw_dm_channels: Dict[Snowflake, Any] = {}
self._disabled = disabled
# We use symlinks to cache guild channels
# like we save channel in Guild and save it here
# and if you need channel, and you don't know its guild
@ -90,6 +94,9 @@ class CacheManager:
Guild to save into cache
"""
if self._disabled:
return
if guild is None:
return None
@ -104,15 +111,75 @@ class CacheManager:
)
for sym in channels:
if self._channel_symlinks.get(sym.id, UNDEFINED) is not UNDEFINED:
self._channel_symlinks.pop(sym.id)
sym_id = Snowflake(int(sym.id))
if self._channel_symlinks.get(sym_id, UNDEFINED) is not UNDEFINED:
self._channel_symlinks.pop(sym_id)
self._channel_symlinks[sym.id] = guild.id
self._channel_symlinks[sym_id] = guild.id
self._raw_guilds.update({guild.id: guild})
return guild
def set_guild_channel(self, channel: Optional[Channel] = None):
"""
Save Guild into cache
Parameters
----------
channel: Optional[`~melisa.models.guild.Channel`]
Guild Channel to save into cache
"""
if self._disabled:
return
if channel is None:
return None
channel = self.__remove_unused_attributes(channel, Channel) # ToDo: add channel type
guild = self._raw_guilds.get(channel.guild_id, UNDEFINED)
channel_id = Snowflake(int(channel.id))
if guild != UNDEFINED:
if hasattr(guild, "channels"):
self._raw_guilds[guild.id].channels.update({channel_id: channel})
self._channel_symlinks.update({channel_id: channel.guild_id})
return channel
def get_guild_channel(self, channel_id: Union[Snowflake, str, int]):
"""
Get guild channel from cache
Parameters
----------
channel_id: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`]
ID of guild channel to get from cache.
"""
if self._disabled:
return None
channel_id = Snowflake(int(channel_id))
guild_id = self._channel_symlinks.get(channel_id, UNDEFINED)
if guild_id == UNDEFINED:
return None
guild = self.get_guild(guild_id)
if guild is None:
return None
if hasattr(guild, "channels") is False:
return None
return guild.channels.get(channel_id)
def get_guild(self, guild_id: Union[Snowflake, str, int]):
"""
Get guild from cache
@ -123,6 +190,9 @@ class CacheManager:
ID of guild to get from cache.
"""
if self._disabled:
return None
if not isinstance(guild_id, Snowflake):
guild_id = Snowflake(int(guild_id))
return self._raw_guilds.get(guild_id, None)
@ -137,8 +207,11 @@ class CacheManager:
Data of guilds tso insert to the cache
"""
if self._disabled:
return
guilds_dict = dict(
map(lambda i: (i["id"], UnavailableGuild.from_dict(i)), guilds)
map(lambda i: (Snowflake(int(i["id"])), UnavailableGuild.from_dict(i)), guilds)
)
self._raw_guilds.update(guilds_dict)
@ -154,6 +227,9 @@ class CacheManager:
ID of guild to remove from cache.
"""
if self._disabled:
return
if not isinstance(guild_id, Snowflake):
guild_id = Snowflake(int(guild_id))

View file

@ -769,7 +769,7 @@ class TextChannel(MessageableChannel):
self: TextChannel = super().__new__(cls)
self.id = data["id"]
self.type = data["type"]
self.type = ChannelType(data["type"])
self.position = data.get("position")
self.permission_overwrites = data["permission_overwrites"]
self.name = data.get("name")

View file

@ -470,7 +470,7 @@ class Guild(APIModelBase):
for channel in data.get("channels", []):
channel = _choose_channel_type(channel)
self.channels[channel.id] = channel
self.channels[Snowflake(int(channel.id))] = channel
return self