diff --git a/melisa/models/app/shard.py b/melisa/models/app/shard.py index 799288a..4706a25 100644 --- a/melisa/models/app/shard.py +++ b/melisa/models/app/shard.py @@ -10,6 +10,10 @@ from ..user import Activity class Shard: + _num_shards: int + _shard_id: int + _gateway: Gateway + def __init__(self, client, shard_id: int, num_shards: int): self._client = client diff --git a/melisa/models/guild/channel.py b/melisa/models/guild/channel.py index 93cc0da..8c9cf07 100644 --- a/melisa/models/guild/channel.py +++ b/melisa/models/guild/channel.py @@ -779,13 +779,13 @@ class TextChannel(MessageableChannel): self.id = data["id"] self.type = ChannelType(data["type"]) self.position = data.get("position") - self.permission_overwrites = data["permission_overwrites"] + self.permission_overwrites = data.get("permission_overwrites") self.name = data.get("name") self.topic = data.get("topic") self.nsfw = data.get("nsfw") if data.get("last_message_id") is not None: - self.last_message_id = Snowflake(data["last_message_id"]) + self.last_message_id = Snowflake(data.get("last_message_id", 0)) else: self.last_message_id = None @@ -802,7 +802,7 @@ class TextChannel(MessageableChannel): self.parent_id = None if data.get("last_pin_timestamp") is not None: - self.last_pin_timestamp = Timestamp.parse(data["last_pin_timestamp"]) + self.last_pin_timestamp = Timestamp.parse(data.get("last_pin_timestamp", 0)) else: self.last_pin_timestamp = None diff --git a/melisa/models/guild/guild.py b/melisa/models/guild/guild.py index 4580821..0e4495b 100644 --- a/melisa/models/guild/guild.py +++ b/melisa/models/guild/guild.py @@ -10,16 +10,14 @@ from typing import List, Any, Optional, overload, Dict from .channel import ( Channel, ChannelType, - channel_types_for_converting, ThreadsList, Thread, _choose_channel_type, - NoneTypedChannel, ) from ...utils import Snowflake, Timestamp from ...utils.api_model import APIModelBase from ...utils.conversion import try_enum -from ...utils.types import APINullable, UNDEFINED +from ...utils.types import APINullable class DefaultMessageNotificationLevel(IntEnum): @@ -302,11 +300,11 @@ class Guild(APIModelBase): """ id: Snowflake - roles: APINullable[List] - emojis: APINullable[List] - members: APINullable[List] + roles: APINullable[Dict] + emojis: APINullable[Dict] + members: APINullable[Dict] threads: APINullable[Dict] - presences: APINullable[List] + presences: APINullable[Dict] channels: APINullable[Dict] name: APINullable[str] = None icon: APINullable[str] = None @@ -601,4 +599,4 @@ class UnavailableGuild(APIModelBase): self.id = Snowflake(int(data["id"])) self.unavailable = data["unavailable"] - return self + return self \ No newline at end of file diff --git a/tests/test_caching.py b/tests/test_caching.py new file mode 100644 index 0000000..63c6fd0 --- /dev/null +++ b/tests/test_caching.py @@ -0,0 +1,54 @@ +import pytest + +from melisa import CacheManager, Snowflake, Guild, TextChannel + + +class TestCache: + @pytest.fixture() + def cache(self, cache_params: dict = None): + if cache_params is None: + cache_params = {} + + return CacheManager(**cache_params) + + def test_count(self, cache): + cache._raw_guilds = { + Snowflake(123): 123, + Snowflake(1234): 1111 + } + cache._raw_users = { + Snowflake(123): 123, + Snowflake(1234): 1111 + } + cache._raw_dm_channels = { + Snowflake(123): 123, + Snowflake(1234): 1111 + } + cache._channel_symlinks = { + Snowflake(123): 123, + Snowflake(1234): 1111 + } + + assert cache.guilds_count() == 2 + assert cache.users_count() == 2 + assert cache.guild_channels_count() == 2 + assert cache.total_channels_count() == 4 + + def test_set_and_get_guild(self, cache): + cache.set_guild( + Guild.from_dict({"id": "123", "name": "test"}) + ) + + assert cache.get_guild(123).name == "test" + + def test_set_and_get_guild_channel(self, cache): + channel = {"id": "456", "name": "test", "type": 0, "guild_id": 123} + + cache.set_guild( + Guild.from_dict({"id": "123", "name": "test", "channels": [channel]}) + ) + cache.set_guild_channel( + TextChannel.from_dict(channel) + ) + + assert cache.get_guild_channel(456).name == "test"