add textchannel (not ready) and fetch_channel method

This commit is contained in:
grey-cat-1908 2022-03-26 20:46:42 +03:00
parent ccb66efe2b
commit cabc848cb8
7 changed files with 142 additions and 20 deletions

View file

@ -8,9 +8,10 @@ from .utils.types import Coro
from .core.http import HTTPClient from .core.http import HTTPClient
from .core.gateway import GatewayBotInfo from .core.gateway import GatewayBotInfo
from .models.guild.channel import Channel, ChannelType, channel_types_for_converting
import asyncio import asyncio
from typing import Dict, List, Union from typing import Dict, List, Union, Any
class Client: class Client:
@ -163,3 +164,26 @@ class Client:
data = await self.http.get(f"guilds/{guild_id}") data = await self.http.get(f"guilds/{guild_id}")
return Guild.from_dict(data) return Guild.from_dict(data)
async def fetch_channel(
self, channel_id: Union[Snowflake, str, int]
) -> Union[Channel, Any]:
"""
Fetch Channel from the Discord API (by id).
If type of channel is unknown:
it will return just :class:`melisa.models.guild.channel.Channel` object.
Parameters
----------
channel_id : :class:`Union[Snowflake, str, int]`
Id of channel to fetch
"""
# ToDo: Update cache if CHANNEL_CACHE enabled.
data = (await self.http.get(f"channels/{channel_id}")) or {}
data.update({"type": ChannelType(data.pop("type"))})
channel_cls = channel_types_for_converting.get(data["type"], Channel)
return channel_cls.from_dict(data)

View file

@ -20,6 +20,7 @@ from melisa.exceptions import (
RateLimitError, RateLimitError,
) )
from .ratelimiter import RateLimiter from .ratelimiter import RateLimiter
from ..utils import remove_none
class HTTPClient: class HTTPClient:
@ -59,7 +60,13 @@ class HTTPClient:
await self.__aiohttp_session.close() await self.__aiohttp_session.close()
async def __send( async def __send(
self, method: str, endpoint: str, *, _ttl: int = None, **kwargs self,
method: str,
endpoint: str,
*,
_ttl: int = None,
params: Optional[Dict] = None,
**kwargs,
) -> Optional[Dict]: ) -> Optional[Dict]:
"""Send an API request to the Discord API.""" """Send an API request to the Discord API."""
@ -72,7 +79,9 @@ class HTTPClient:
url = f"{self.url}/{endpoint}" url = f"{self.url}/{endpoint}"
async with self.__aiohttp_session.request(method, url, **kwargs) as response: async with self.__aiohttp_session.request(
method, url, params=remove_none(params), **kwargs
) as response:
return await self.__handle_response( return await self.__handle_response(
response, method, endpoint, _ttl=ttl, **kwargs response, method, endpoint, _ttl=ttl, **kwargs
) )
@ -111,7 +120,7 @@ class HTTPClient:
return await self.__send(method, endpoint, _ttl=_ttl - 1, **kwargs) return await self.__send(method, endpoint, _ttl=_ttl - 1, **kwargs)
async def get(self, route: str, params: Optional[Dict] = None) -> Optional[Dict]: async def get(self, route: str, *, params: Optional[Dict] = None) -> Optional[Dict]:
"""|coro| """|coro|
Sends a GET request to a Discord REST API endpoint. Sends a GET request to a Discord REST API endpoint.

View file

@ -57,9 +57,9 @@ class Shard:
""" """
create_task(self._gateway.close()) create_task(self._gateway.close())
async def update_presence(self, async def update_presence(
activity: BotActivity = None, self, activity: BotActivity = None, status: str = None
status: str = None) -> Shard: ) -> Shard:
""" """
|coro| |coro|

View file

@ -4,17 +4,17 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum, Enum from enum import IntEnum
from typing import List, Any from typing import List, Any, Optional, AsyncIterator, Union, Dict
from ...utils import Snowflake from ...utils import Snowflake
from ...utils import APIModelBase from ...utils import APIModelBase
from ...utils.types import APINullable from ...utils.types import APINullable
class ChannelTypes(IntEnum): class ChannelType(IntEnum):
"""Channel Types """Channel Type
NOTE: Type 10, 11 and 12 are only available in API v9. NOTE: Type 10, 11 and 12 are only available in API v9 and older.
Attributes Attributes
---------- ----------
@ -109,3 +109,77 @@ class Channel(APIModelBase):
member: APINullable[List[Any]] = None member: APINullable[List[Any]] = None
default_auto_archive_duration: APINullable[int] = None default_auto_archive_duration: APINullable[int] = None
permissions: APINullable[str] = None permissions: APINullable[str] = None
@property
def mention(self):
return f"<#{self.id}>"
class TextChannel(Channel):
"""A subclass of ``Channel`` representing text channels with all the same attributes."""
async def history(
self,
limit: int = 50,
*,
before: Optional[Union[int, str, Snowflake]] = None,
after: Optional[Union[int, str, Snowflake]] = None,
around: Optional[Union[int, str, Snowflake]] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""|coro|
Returns a list of messages in this channel.
Examples
---------
Flattening messages into a list: ::
messages = [message async for message in channel.history(limit=111)]
All parameters are optional.
Parameters
----------
around : Optional[Union[:class:`int`, :class:`str`, :class:`Snowflake`]]
Get messages around this message ID.
before : Optional[Union[:class:`int`, :class:`str`, :class:`Snowflake`]]
Get messages before this message ID.
after : Optional[Union[:class:`int`, :class:`str`, :class:`Snowflake`]]
Get messages after this message ID.
limit : Optional[Union[:class:`int`, :class:`str`, :class:`Snowflake`]]
Max number of messages to return (1-100).
Returns
-------
AsyncIterator[Dict[:class:`str`, Any]]
An iterator of messages.
"""
if limit is None:
limit = 100
while limit > 0:
search_limit = min(limit, 100)
raw_messages = await self._http.get(
f"/channels/{self.id}/messages",
params={
"limit": search_limit,
"before": before,
"after": after,
"around": around,
},
)
if not raw_messages:
break
for message_data in raw_messages:
yield message_data
before = raw_messages[-1]["id"]
limit -= search_limit
# noinspection PyTypeChecker
channel_types_for_converting: Dict[ChannelType, Any] = {
ChannelType.GUILD_TEXT: TextChannel
}

View file

@ -7,5 +7,6 @@ from .snowflake import Snowflake
from .api_model import APIModelBase from .api_model import APIModelBase
from .conversion import remove_none
__all__ = ("Coro", "Snowflake", "APIModelBase") __all__ = ("Coro", "Snowflake", "APIModelBase", "remove_none")

View file

@ -18,12 +18,12 @@ from typing import (
T = TypeVar("T") T = TypeVar("T")
def _to_dict_without_none(model): def to_dict_without_none(model):
if _is_dataclass_instance(model): if _is_dataclass_instance(model):
result = [] result = []
for field in fields(model): for field in fields(model):
value = _to_dict_without_none(getattr(model, field.name)) value = to_dict_without_none(getattr(model, field.name))
if isinstance(value, Enum): if isinstance(value, Enum):
result.append((field.name, value.value)) result.append((field.name, value.value))
@ -33,15 +33,14 @@ def _to_dict_without_none(model):
return dict(result) return dict(result)
elif isinstance(model, tuple) and hasattr(model, "_fields"): elif isinstance(model, tuple) and hasattr(model, "_fields"):
return type(model)(*[_to_dict_without_none(v) for v in model]) return type(model)(*[to_dict_without_none(v) for v in model])
elif isinstance(model, (list, tuple)): elif isinstance(model, (list, tuple)):
return type(model)(_to_dict_without_none(v) for v in model) return type(model)(to_dict_without_none(v) for v in model)
elif isinstance(model, dict): elif isinstance(model, dict):
return type(model)( return type(model)(
(_to_dict_without_none(k), _to_dict_without_none(v)) (to_dict_without_none(k), to_dict_without_none(v)) for k, v in model.items()
for k, v in model.items()
) )
else: else:
return copy.deepcopy(model) return copy.deepcopy(model)
@ -90,4 +89,4 @@ class APIModelBase:
) )
def to_dict(self) -> Dict: def to_dict(self) -> Dict:
return _to_dict_without_none(self) return to_dict_without_none(self)

View file

@ -0,0 +1,15 @@
# Copyright MelisaDev 2022 - Present
# Full MIT License can be found in `LICENSE.txt` at the project root.
from __future__ import annotations
def remove_none(obj):
if isinstance(obj, list):
return [i for i in obj if i is not None]
elif isinstance(obj, tuple):
return tuple(i for i in obj if i is not None)
elif isinstance(obj, set):
return obj - {None}
elif isinstance(obj, dict):
return {k: v for k, v in obj.items() if None not in (k, v)}