mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-11-11 19:07:28 +03:00
add guild channel cache (set/get)
This commit is contained in:
parent
78745c3334
commit
23c0ebbefb
4 changed files with 90 additions and 9 deletions
|
@ -11,7 +11,7 @@ from typing import Dict, List, Union, Any, Iterable, Optional, Callable
|
|||
from .models.app.cache import CacheManager
|
||||
from .rest import RESTApp
|
||||
from .core.gateway import GatewayBotInfo
|
||||
from .models.guild.channel import Channel
|
||||
from .models.guild.channel import Channel, ChannelType
|
||||
from .models import Activity, AllowedMentions
|
||||
from .models.app.shard import Shard
|
||||
from .models.app.intents import Intents
|
||||
|
@ -266,7 +266,12 @@ class Client:
|
|||
|
||||
# ToDo: Update cache if CHANNEL_CACHE enabled.
|
||||
|
||||
return await self.rest.fetch_channel(channel_id)
|
||||
data = await self.rest.fetch_channel(channel_id)
|
||||
|
||||
if data.type not in [ChannelType.DM, ChannelType.GROUP_DM]:
|
||||
self.cache.set_guild_channel(data)
|
||||
|
||||
return data
|
||||
|
||||
async def wait_for(
|
||||
self,
|
||||
|
|
|
@ -7,7 +7,7 @@ from enum import Enum
|
|||
from typing import List, Dict, Optional, Any, Union
|
||||
|
||||
from melisa.utils.types import UNDEFINED
|
||||
from melisa.models.guild import Guild, ChannelType, UnavailableGuild
|
||||
from melisa.models.guild import Guild, ChannelType, UnavailableGuild, Channel
|
||||
from melisa.utils.snowflake import Snowflake
|
||||
|
||||
|
||||
|
@ -16,6 +16,7 @@ class AutoCacheModels(Enum):
|
|||
|
||||
""" """
|
||||
|
||||
FULL_GUILDS = "FULL_GUILDS"
|
||||
GUILD_ROLES = "GUILD_ROLES"
|
||||
GUILD_THREADS = "GUILD_THREADS"
|
||||
GUILD_EMOJIS = "GUILD_EMOJIS"
|
||||
|
@ -30,6 +31,7 @@ class CacheManager:
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
disabled: bool = False,
|
||||
auto_models: Optional[List[AutoCacheModels]] = None,
|
||||
auto_unused_attributes: Optional[Dict[Any, List[str]]] = None,
|
||||
):
|
||||
|
@ -44,6 +46,8 @@ class CacheManager:
|
|||
self._raw_users: Dict[Snowflake, Any] = {}
|
||||
self._raw_dm_channels: Dict[Snowflake, Any] = {}
|
||||
|
||||
self._disabled = disabled
|
||||
|
||||
# We use symlinks to cache guild channels
|
||||
# like we save channel in Guild and save it here
|
||||
# and if you need channel, and you don't know its guild
|
||||
|
@ -90,6 +94,9 @@ class CacheManager:
|
|||
Guild to save into cache
|
||||
"""
|
||||
|
||||
if self._disabled:
|
||||
return
|
||||
|
||||
if guild is None:
|
||||
return None
|
||||
|
||||
|
@ -104,15 +111,75 @@ class CacheManager:
|
|||
)
|
||||
|
||||
for sym in channels:
|
||||
if self._channel_symlinks.get(sym.id, UNDEFINED) is not UNDEFINED:
|
||||
self._channel_symlinks.pop(sym.id)
|
||||
sym_id = Snowflake(int(sym.id))
|
||||
if self._channel_symlinks.get(sym_id, UNDEFINED) is not UNDEFINED:
|
||||
self._channel_symlinks.pop(sym_id)
|
||||
|
||||
self._channel_symlinks[sym.id] = guild.id
|
||||
self._channel_symlinks[sym_id] = guild.id
|
||||
|
||||
self._raw_guilds.update({guild.id: guild})
|
||||
|
||||
return guild
|
||||
|
||||
def set_guild_channel(self, channel: Optional[Channel] = None):
|
||||
"""
|
||||
Save Guild into cache
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channel: Optional[`~melisa.models.guild.Channel`]
|
||||
Guild Channel to save into cache
|
||||
"""
|
||||
|
||||
if self._disabled:
|
||||
return
|
||||
|
||||
if channel is None:
|
||||
return None
|
||||
|
||||
channel = self.__remove_unused_attributes(channel, Channel) # ToDo: add channel type
|
||||
|
||||
guild = self._raw_guilds.get(channel.guild_id, UNDEFINED)
|
||||
|
||||
channel_id = Snowflake(int(channel.id))
|
||||
|
||||
if guild != UNDEFINED:
|
||||
if hasattr(guild, "channels"):
|
||||
self._raw_guilds[guild.id].channels.update({channel_id: channel})
|
||||
|
||||
self._channel_symlinks.update({channel_id: channel.guild_id})
|
||||
|
||||
return channel
|
||||
|
||||
def get_guild_channel(self, channel_id: Union[Snowflake, str, int]):
|
||||
"""
|
||||
Get guild channel from cache
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channel_id: Optional[:class:`~melisa.utils.snowflake.Snowflake`, `str`, `int`]
|
||||
ID of guild channel to get from cache.
|
||||
"""
|
||||
|
||||
if self._disabled:
|
||||
return None
|
||||
|
||||
channel_id = Snowflake(int(channel_id))
|
||||
guild_id = self._channel_symlinks.get(channel_id, UNDEFINED)
|
||||
|
||||
if guild_id == UNDEFINED:
|
||||
return None
|
||||
|
||||
guild = self.get_guild(guild_id)
|
||||
|
||||
if guild is None:
|
||||
return None
|
||||
|
||||
if hasattr(guild, "channels") is False:
|
||||
return None
|
||||
|
||||
return guild.channels.get(channel_id)
|
||||
|
||||
def get_guild(self, guild_id: Union[Snowflake, str, int]):
|
||||
"""
|
||||
Get guild from cache
|
||||
|
@ -123,6 +190,9 @@ class CacheManager:
|
|||
ID of guild to get from cache.
|
||||
"""
|
||||
|
||||
if self._disabled:
|
||||
return None
|
||||
|
||||
if not isinstance(guild_id, Snowflake):
|
||||
guild_id = Snowflake(int(guild_id))
|
||||
return self._raw_guilds.get(guild_id, None)
|
||||
|
@ -137,8 +207,11 @@ class CacheManager:
|
|||
Data of guilds tso insert to the cache
|
||||
"""
|
||||
|
||||
if self._disabled:
|
||||
return
|
||||
|
||||
guilds_dict = dict(
|
||||
map(lambda i: (i["id"], UnavailableGuild.from_dict(i)), guilds)
|
||||
map(lambda i: (Snowflake(int(i["id"])), UnavailableGuild.from_dict(i)), guilds)
|
||||
)
|
||||
|
||||
self._raw_guilds.update(guilds_dict)
|
||||
|
@ -154,6 +227,9 @@ class CacheManager:
|
|||
ID of guild to remove from cache.
|
||||
"""
|
||||
|
||||
if self._disabled:
|
||||
return
|
||||
|
||||
if not isinstance(guild_id, Snowflake):
|
||||
guild_id = Snowflake(int(guild_id))
|
||||
|
||||
|
|
|
@ -769,7 +769,7 @@ class TextChannel(MessageableChannel):
|
|||
self: TextChannel = super().__new__(cls)
|
||||
|
||||
self.id = data["id"]
|
||||
self.type = data["type"]
|
||||
self.type = ChannelType(data["type"])
|
||||
self.position = data.get("position")
|
||||
self.permission_overwrites = data["permission_overwrites"]
|
||||
self.name = data.get("name")
|
||||
|
|
|
@ -470,7 +470,7 @@ class Guild(APIModelBase):
|
|||
|
||||
for channel in data.get("channels", []):
|
||||
channel = _choose_channel_type(channel)
|
||||
self.channels[channel.id] = channel
|
||||
self.channels[Snowflake(int(channel.id))] = channel
|
||||
|
||||
return self
|
||||
|
||||
|
|
Loading…
Reference in a new issue