BREAKING CHANGE: new thread model parsing method

This commit is contained in:
grey-cat-1908 2022-04-16 14:54:34 +03:00
parent 003d6e7bcc
commit 9a25d93786
7 changed files with 97 additions and 29 deletions

View file

@ -1007,5 +1007,5 @@ channel_types_for_converting: Dict[ChannelType, Channel] = {
ChannelType.GUILD_TEXT: TextChannel,
ChannelType.GUILD_NEWS_THREAD: Thread,
ChannelType.GUILD_PUBLIC_THREAD: Thread,
ChannelType.GUILD_PRIVATE_THREAD: Thread
ChannelType.GUILD_PRIVATE_THREAD: Thread,
}

View file

@ -7,7 +7,15 @@ from dataclasses import dataclass
from enum import IntEnum, Enum
from typing import List, Any, Optional, overload, Dict
from .channel import Channel, ChannelType, channel_types_for_converting, ThreadsList, Thread, _choose_channel_type
from .channel import (
Channel,
ChannelType,
channel_types_for_converting,
ThreadsList,
Thread,
_choose_channel_type,
NoneTypedChannel,
)
from ...utils import Snowflake, Timestamp
from ...utils.api_model import APIModelBase
from ...utils.conversion import try_enum
@ -356,6 +364,13 @@ class Guild(APIModelBase):
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Guild:
"""Generate a guild from the given data.
Parameters
----------
data: :class:`dict`
The dictionary to convert into a guild.
"""
self: Guild = super().__new__(cls)
self.id = int(data["id"])
@ -389,11 +404,15 @@ class Guild(APIModelBase):
self.max_video_channel_users = data.get("max_video_channel_users")
self.premium_tier = try_enum(PremiumTier, data.get("premium_tier", 0))
self.premium_subscription_count = data.get("premium_subscription_count", 0)
self.system_channel_flags = try_enum(SystemChannelFlags, data.get("system_channel_flags", 0))
self.system_channel_flags = try_enum(
SystemChannelFlags, data.get("system_channel_flags", 0)
)
self.preferred_locale = data.get("preferred_locale")
self.discovery_splash = data.get("discovery_splash")
self.nsfw_level = data.get("nsfw_level", 0)
self.premium_progress_bar_enabled = data.get("premium_progress_bar_enabled", False)
self.premium_progress_bar_enabled = data.get(
"premium_progress_bar_enabled", False
)
self.approximate_presence_count = data.get("approximate_presence_count")
self.approximate_member_count = data.get("approximate_member_count")
self.widget_enabled = data.get("widget_enabled")
@ -414,7 +433,9 @@ class Guild(APIModelBase):
self.rules_channel_id = None
if data.get("public_updates_channel_id") is not None:
self.public_updates_channel_id = Snowflake(data["public_updates_channel_id"])
self.public_updates_channel_id = Snowflake(
data["public_updates_channel_id"]
)
else:
self.public_updates_channel_id = None
@ -524,10 +545,7 @@ class Guild(APIModelBase):
headers={"X-Audit-Log-Reason": reason},
)
data.update({"type": ChannelType(data.pop("type"))})
channel_cls = channel_types_for_converting.get(data["type"], Channel)
return channel_cls.from_dict(data)
return _choose_channel_type(data)
async def active_threads(self) -> ThreadsList:
"""|coro|

View file

@ -4,6 +4,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Any
from ...utils.api_model import APIModelBase
from ...utils.types import APINullable, UNDEFINED
@ -40,8 +41,33 @@ class ThreadMetadata(APIModelBase):
auto_archive_duration: int
archive_timestamp: Timestamp
locked: bool
invitable: APINullable[bool] = UNDEFINED
create_timestamp: APINullable[Timestamp] = UNDEFINED
invitable: APINullable[bool] = None
create_timestamp: APINullable[Timestamp] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]):
"""Generate a thread metadata object from the given data.
Parameters
----------
data: :class:`dict`
The dictionary to convert into thread metadata.
"""
self: ThreadMetadata = super().__new__(cls)
self.archived = data["archived"]
self.auto_archive_duration = data["auto_archive_duration"]
self.archive_timestamp = Timestamp.parse(data["archive_timestamp"])
self.locked = data["locked"]
self.invitable = data.get("invitable", None)
if data.get("create_timestamp"):
self.create_timestamp = Timestamp.parse(data["create_timestamp"])
else:
self.create_timestamp = None
return self
@dataclass(repr=False)
@ -62,5 +88,26 @@ class ThreadMember(APIModelBase):
join_timestamp: Timestamp
flags: int
id: APINullable[Snowflake] = UNDEFINED
user_id: APINullable[Snowflake] = UNDEFINED
id: APINullable[Snowflake] = None
user_id: APINullable[Snowflake] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]):
"""Generate a thread member object from the given data.
Parameters
----------
data: :class:`dict`
The dictionary to convert into thread member.
"""
self: ThreadMember = super().__new__(cls)
self.archived = data["flags"]
self.archive_timestamp = Timestamp.parse(data["join_timestamp"])
self.id = Snowflake(data["id"]) if data.get("id") is not None else None
self.user_id = (
Snowflake(data["user_id"]) if data.get("user_id") is not None else None
)
return self

View file

@ -86,9 +86,7 @@ class Webhook(APIModelBase):
source_channel: APINullable[Channel] = UNDEFINED
url: APINullable[str] = UNDEFINED
async def delete(
self, *, reason: Optional[str] = None
):
async def delete(self, *, reason: Optional[str] = None):
"""|coro|
Delete a webhook permanently. Requires the ``MANAGE_WEBHOOKS`` permission.
Returns a ``204 No Content`` response on success.
@ -104,7 +102,11 @@ class Webhook(APIModelBase):
)
async def modify(
self, *, name: Optional[str] = None, channel_id: Optional[Snowflake] = None, reason: Optional[str] = None
self,
*,
name: Optional[str] = None,
channel_id: Optional[Snowflake] = None,
reason: Optional[str] = None,
):
"""|coro|
Modify a webhook. Requires the ``MANAGE_WEBHOOKS permission``. Returns the updated webhook object on success.
@ -121,5 +123,9 @@ class Webhook(APIModelBase):
await self._http.patch(
f"/webhooks/{self.id}",
headers={"name": name, "channel_id": channel_id, "X-Audit-Log-Reason": reason},
headers={
"name": name,
"channel_id": channel_id,
"X-Audit-Log-Reason": reason,
},
)

View file

@ -348,7 +348,9 @@ class Embed(APIModelBase):
return self
def set_thumbnail(self, url: str, *, proxy_url: APINullable[str] = UNDEFINED) -> Embed:
def set_thumbnail(
self, url: str, *, proxy_url: APINullable[str] = UNDEFINED
) -> Embed:
"""Set the thumbnail for the embed.
Parameters

View file

@ -18,7 +18,8 @@ from typing import (
Any,
get_origin,
Tuple,
get_args, Optional,
get_args,
Optional,
)
from typing_extensions import get_type_hints
@ -50,9 +51,7 @@ def _asdict_ignore_none(obj: Generic[T]) -> Union[Tuple, Dict, T]:
if isinstance(value, Enum):
result.append((f.name, value.value))
# This if statement was added to the function
elif not isinstance(value, UndefinedType) and not f.name.startswith(
"_"
):
elif not isinstance(value, UndefinedType) and not f.name.startswith("_"):
result.append((f.name, value))
return dict(result)
@ -65,8 +64,7 @@ def _asdict_ignore_none(obj: Generic[T]) -> Union[Tuple, Dict, T]:
elif isinstance(obj, dict):
return type(obj)(
(_asdict_ignore_none(k), _asdict_ignore_none(v))
for k, v in obj.items()
(_asdict_ignore_none(k), _asdict_ignore_none(v)) for k, v in obj.items()
)
else:
return copy.deepcopy(obj)
@ -142,9 +140,7 @@ class APIModelBase:
types = self.__get_types(attr, attr_type)
types = tuple(
filter(
lambda tpe: tpe is not None and tpe is not UNDEFINED, types
)
filter(lambda tpe: tpe is not None and tpe is not UNDEFINED, types)
)
if not types:

View file

@ -25,4 +25,3 @@ def try_enum(cls: Type[T], val: Any) -> T:
return cls(val)
except (KeyError, TypeError, AttributeError, ValueError):
return val