add guild channel cache (set/get)

This commit is contained in:
grey-cat-1908 2022-05-19 17:58:19 +03:00
parent 78745c3334
commit 23c0ebbefb
4 changed files with 90 additions and 9 deletions

View file

@ -11,7 +11,7 @@ from typing import Dict, List, Union, Any, Iterable, Optional, Callable
from .models.app.cache import CacheManager from .models.app.cache import CacheManager
from .rest import RESTApp from .rest import RESTApp
from .core.gateway import GatewayBotInfo from .core.gateway import GatewayBotInfo
from .models.guild.channel import Channel from .models.guild.channel import Channel, ChannelType
from .models import Activity, AllowedMentions from .models import Activity, AllowedMentions
from .models.app.shard import Shard from .models.app.shard import Shard
from .models.app.intents import Intents from .models.app.intents import Intents
@ -266,7 +266,12 @@ class Client:
# ToDo: Update cache if CHANNEL_CACHE enabled. # ToDo: Update cache if CHANNEL_CACHE enabled.
return await self.rest.fetch_channel(channel_id) data = await self.rest.fetch_channel(channel_id)
if data.type not in [ChannelType.DM, ChannelType.GROUP_DM]:
self.cache.set_guild_channel(data)
return data
async def wait_for( async def wait_for(
self, self,

View file

@ -7,7 +7,7 @@ from enum import Enum
from typing import List, Dict, Optional, Any, Union from typing import List, Dict, Optional, Any, Union
from melisa.utils.types import UNDEFINED from melisa.utils.types import UNDEFINED
from melisa.models.guild import Guild, ChannelType, UnavailableGuild from melisa.models.guild import Guild, ChannelType, UnavailableGuild, Channel
from melisa.utils.snowflake import Snowflake from melisa.utils.snowflake import Snowflake
@ -16,6 +16,7 @@ class AutoCacheModels(Enum):
""" """ """ """
FULL_GUILDS = "FULL_GUILDS"
GUILD_ROLES = "GUILD_ROLES" GUILD_ROLES = "GUILD_ROLES"
GUILD_THREADS = "GUILD_THREADS" GUILD_THREADS = "GUILD_THREADS"
GUILD_EMOJIS = "GUILD_EMOJIS" GUILD_EMOJIS = "GUILD_EMOJIS"
@ -30,6 +31,7 @@ class CacheManager:
def __init__( def __init__(
self, self,
*, *,
disabled: bool = False,
auto_models: Optional[List[AutoCacheModels]] = None, auto_models: Optional[List[AutoCacheModels]] = None,
auto_unused_attributes: Optional[Dict[Any, List[str]]] = None, auto_unused_attributes: Optional[Dict[Any, List[str]]] = None,
): ):
@ -44,6 +46,8 @@ class CacheManager:
self._raw_users: Dict[Snowflake, Any] = {} self._raw_users: Dict[Snowflake, Any] = {}
self._raw_dm_channels: Dict[Snowflake, Any] = {} self._raw_dm_channels: Dict[Snowflake, Any] = {}
self._disabled = disabled
# We use symlinks to cache guild channels # We use symlinks to cache guild channels
# like we save channel in Guild and save it here # like we save channel in Guild and save it here
# and if you need channel, and you don't know its guild # and if you need channel, and you don't know its guild
@ -90,6 +94,9 @@ class CacheManager:
Guild to save into cache Guild to save into cache
""" """
if self._disabled:
return
if guild is None: if guild is None:
return None return None
@ -104,15 +111,75 @@ class CacheManager:
) )
for sym in channels: for sym in channels:
if self._channel_symlinks.get(sym.id, UNDEFINED) is not UNDEFINED: sym_id = Snowflake(int(sym.id))
self._channel_symlinks.pop(sym.id) if self._channel_symlinks.get(sym_id, UNDEFINED) is not UNDEFINED:
self._channel_symlinks.pop(sym_id)
self._channel_symlinks[sym.id] = guild.id self._channel_symlinks[sym_id] = guild.id
self._raw_guilds.update({guild.id: guild}) self._raw_guilds.update({guild.id: guild})
return guild return guild
def set_guild_channel(self, channel: Optional[Channel] = None):
"""
Save Guild into cache
Parameters
----------
channel: Optional[`~melisa.models.guild.Channel`]
Guild Channel to save into cache
"""
if self._disabled:
return
if channel is None:
return None
channel = self.__remove_unused_attributes(channel, Channel) # ToDo: add channel type
guild = self._raw_guilds.get(channel.guild_id, UNDEFINED)
channel_id = Snowflake(int(channel.id))
if guild != UNDEFINED:
if hasattr(guild, "channels"):
self._raw_guilds[guild.id].channels.update({channel_id: channel})
self._channel_symlinks.update({channel_id: channel.guild_id})
return channel
def get_guild_channel(self, channel_id: Union[Snowflake, str, int]):
"""
Get guild channel from cache
Parameters
----------
channel_id: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`]
ID of guild channel to get from cache.
"""
if self._disabled:
return None
channel_id = Snowflake(int(channel_id))
guild_id = self._channel_symlinks.get(channel_id, UNDEFINED)
if guild_id == UNDEFINED:
return None
guild = self.get_guild(guild_id)
if guild is None:
return None
if hasattr(guild, "channels") is False:
return None
return guild.channels.get(channel_id)
def get_guild(self, guild_id: Union[Snowflake, str, int]): def get_guild(self, guild_id: Union[Snowflake, str, int]):
""" """
Get guild from cache Get guild from cache
@ -123,6 +190,9 @@ class CacheManager:
ID of guild to get from cache. ID of guild to get from cache.
""" """
if self._disabled:
return None
if not isinstance(guild_id, Snowflake): if not isinstance(guild_id, Snowflake):
guild_id = Snowflake(int(guild_id)) guild_id = Snowflake(int(guild_id))
return self._raw_guilds.get(guild_id, None) return self._raw_guilds.get(guild_id, None)
@ -137,8 +207,11 @@ class CacheManager:
Data of guilds tso insert to the cache Data of guilds tso insert to the cache
""" """
if self._disabled:
return
guilds_dict = dict( guilds_dict = dict(
map(lambda i: (i["id"], UnavailableGuild.from_dict(i)), guilds) map(lambda i: (Snowflake(int(i["id"])), UnavailableGuild.from_dict(i)), guilds)
) )
self._raw_guilds.update(guilds_dict) self._raw_guilds.update(guilds_dict)
@ -154,6 +227,9 @@ class CacheManager:
ID of guild to remove from cache. ID of guild to remove from cache.
""" """
if self._disabled:
return
if not isinstance(guild_id, Snowflake): if not isinstance(guild_id, Snowflake):
guild_id = Snowflake(int(guild_id)) guild_id = Snowflake(int(guild_id))

View file

@ -769,7 +769,7 @@ class TextChannel(MessageableChannel):
self: TextChannel = super().__new__(cls) self: TextChannel = super().__new__(cls)
self.id = data["id"] self.id = data["id"]
self.type = data["type"] self.type = ChannelType(data["type"])
self.position = data.get("position") self.position = data.get("position")
self.permission_overwrites = data["permission_overwrites"] self.permission_overwrites = data["permission_overwrites"]
self.name = data.get("name") self.name = data.get("name")

View file

@ -470,7 +470,7 @@ class Guild(APIModelBase):
for channel in data.get("channels", []): for channel in data.get("channels", []):
channel = _choose_channel_type(channel) channel = _choose_channel_type(channel)
self.channels[channel.id] = channel self.channels[Snowflake(int(channel.id))] = channel
return self return self