formated files with the black formatter

This commit is contained in:
grey-cat-1908 2022-03-25 20:18:42 +03:00
parent 3d66eb79c5
commit 2b2f5a2d85
12 changed files with 152 additions and 199 deletions

View file

@ -83,24 +83,26 @@ class Client:
def run(self) -> None:
"""
Run Bot without shards (only 0 shard)
Run Bot without shards (only 0 shard)
"""
inited_shard = Shard(self, 0, 1)
asyncio.ensure_future(inited_shard.launch(activity=self._activity,
status=self._status), loop=self._loop)
asyncio.ensure_future(
inited_shard.launch(activity=self._activity, status=self._status),
loop=self._loop,
)
self._loop.run_forever()
def run_shards(self, num_shards: int, *, shard_ids: List[int] = None):
"""
Run Bot with shards specified by the user.
Run Bot with shards specified by the user.
Parameters
----------
num_shards : :class:`int`
The endpoint to send the request to.
shard_ids: Optional[:class:`List[int]`]
List of Ids of shards to start.
Parameters
----------
num_shards : :class:`int`
The endpoint to send the request to.
shard_ids: Optional[:class:`List[int]`]
List of Ids of shards to start.
"""
if not shard_ids:
shard_ids = range(num_shards)
@ -108,13 +110,15 @@ class Client:
for shard_id in shard_ids:
inited_shard = Shard(self, shard_id, num_shards)
asyncio.ensure_future(inited_shard.launch(activity=self._activity,
status=self._status), loop=self._loop)
asyncio.ensure_future(
inited_shard.launch(activity=self._activity, status=self._status),
loop=self._loop,
)
self._loop.run_forever()
def run_autosharded(self):
"""
Runs the bot with the amount of shards specified by the Discord gateway.
Runs the bot with the amount of shards specified by the Discord gateway.
"""
num_shards = self._gateway_info.shards
shard_ids = range(num_shards)
@ -122,18 +126,20 @@ class Client:
for shard_id in shard_ids:
inited_shard = Shard(self, shard_id, num_shards)
asyncio.ensure_future(inited_shard.launch(activity=self._activity,
status=self._status), loop=self._loop)
asyncio.ensure_future(
inited_shard.launch(activity=self._activity, status=self._status),
loop=self._loop,
)
self._loop.run_forever()
async def fetch_user(self, user_id: Union[Snowflake, str, int]):
"""
Fetch User from the Discord API (by id).
Fetch User from the Discord API (by id).
Parameters
----------
user_id : :class:`Union[Snowflake, str, int]`
Id of user to fetch
Parameters
----------
user_id : :class:`Union[Snowflake, str, int]`
Id of user to fetch
"""
# ToDo: Update cache if USER_CACHING enabled.
@ -144,12 +150,12 @@ class Client:
async def fetch_guild(self, guild_id: Union[Snowflake, str, int]):
"""
Fetch Guild from the Discord API (by id).
Fetch Guild from the Discord API (by id).
Parameters
----------
guild_id : :class:`Union[Snowflake, str, int]`
Id of guild to fetch
Parameters
----------
guild_id : :class:`Union[Snowflake, str, int]`
Id of guild to fetch
"""
# ToDo: Update cache if GUILD_CACHE enabled.

View file

@ -20,6 +20,7 @@ from ..utils import APIModelBase, json
@dataclass
class GatewayBotInfo(APIModelBase):
"""Gateway info from the `gateway/bot` endpoint"""
url: str
shards: int
session_start_limit: dict
@ -39,11 +40,7 @@ class Gateway:
HELLO = 10
HEARTBEAT_ACK = 11
def __init__(self,
client,
shard_id: int = 0,
num_shards: int = 1,
**kwargs):
def __init__(self, client, shard_id: int = 0, num_shards: int = 1, **kwargs):
self.GATEWAY_VERSION = "9"
self.interval = None
@ -52,7 +49,7 @@ class Gateway:
self.__session = aiohttp.ClientSession()
self.session_id = None
self.client = client
self.latency = float('inf')
self.latency = float("inf")
self.ws = None
self.loop = asyncio.get_event_loop()
self.shard_id = shard_id
@ -63,7 +60,7 @@ class Gateway:
4011: GatewayError("Sharding required"),
4012: GatewayError("Invalid API version"),
4013: GatewayError("Invalid intents"),
4014: PrivilegedIntentsRequired("Disallowed intents")
4014: PrivilegedIntentsRequired("Disallowed intents"),
}
self.listeners = listeners
@ -76,19 +73,21 @@ class Gateway:
"properties": {
"$os": sys.platform,
"$browser": "Melisa Python Library",
"$device": "Melisa Python Library"
"$device": "Melisa Python Library",
},
"compress": True,
"shard": [shard_id, num_shards],
"presence": self.generate_presence(kwargs.get("start_activity"),
kwargs.get("start_status"))}
"presence": self.generate_presence(
kwargs.get("start_activity"), kwargs.get("start_status")
),
}
self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer: bytearray = bytearray()
async def connect(self) -> None:
self.ws = await self.__session.ws_connect(
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream'
f"wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream"
)
if self.session_id is None:
@ -114,10 +113,10 @@ class Gateway:
if type(msg) is bytes:
self._buffer.extend(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
if len(msg) < 4 or msg[-4:] != b"\x00\x00\xff\xff":
return None
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
msg = msg.decode("utf-8")
self._buffer = bytearray()
return json.loads(msg)
@ -125,7 +124,7 @@ class Gateway:
return None
async def handle_data(self, data):
if data['op'] == self.DISPATCH:
if data["op"] == self.DISPATCH:
self.sequence = int(data["s"])
event_type = data["t"].lower()
@ -134,12 +133,12 @@ class Gateway:
if event_to_call is not None:
ensure_future(event_to_call(self.client, self, data["d"]))
elif data['op'] == self.INVALID_SESSION:
elif data["op"] == self.INVALID_SESSION:
await self.ws.close(code=4000)
await self.handle_close(4000)
elif data['op'] == self.HELLO:
elif data["op"] == self.HELLO:
await self.send_hello(data)
elif data['op'] == self.HEARTBEAT_ACK:
elif data["op"] == self.HEARTBEAT_ACK:
self.latency = time.perf_counter() - self._last_send
async def receive(self) -> None:
@ -184,34 +183,28 @@ class Gateway:
self._buffer.clear()
async def send_hello(self, data: Dict) -> None:
interval = data['d']['heartbeat_interval'] / 1000
interval = data["d"]["heartbeat_interval"] / 1000
await asyncio.sleep((interval - 2000) / 1000)
self.loop.create_task(self.send_heartbeat(interval))
async def send_identify(self) -> None:
await self.send(self.opcode(
self.IDENTIFY,
self.auth
))
await self.send(self.opcode(self.IDENTIFY, self.auth))
async def resume(self) -> None:
await self.send(
self.opcode(
self.RESUME,
{
'token': self.client._token,
'session_id': self.session_id,
'seq': self.sequence,
}
"token": self.client._token,
"session_id": self.session_id,
"seq": self.sequence,
},
)
)
@staticmethod
def generate_presence(activity: BotActivity = None, status: str = None):
data = {
"since": time.time() * 1000,
"afk": False
}
data = {"since": time.time() * 1000, "afk": False}
if activity is not None:
data["activities"] = activity.to_dict()
@ -226,8 +219,5 @@ class Gateway:
@staticmethod
def opcode(opcode: int, payload) -> str:
data = {
"op": opcode,
"d": payload
}
data = {"op": opcode, "d": payload}
return json.dumps(data)

View file

@ -8,15 +8,17 @@ from typing import Dict, Optional
from aiohttp import ClientSession, ClientResponse
from melisa.exceptions import (NotModifiedError,
BadRequestError,
ForbiddenError,
UnauthorizedError,
HTTPException,
NotFoundError,
MethodNotAllowedError,
ServerError,
RateLimitError)
from melisa.exceptions import (
NotModifiedError,
BadRequestError,
ForbiddenError,
UnauthorizedError,
HTTPException,
NotFoundError,
MethodNotAllowedError,
ServerError,
RateLimitError,
)
from .ratelimiter import RateLimiter
@ -31,7 +33,7 @@ class HTTPClient:
headers: Dict[str, str] = {
"Content-Type": "application/json",
"Authorization": f"Bot {token}",
"User-Agent": "Melisa Python Library"
"User-Agent": "Melisa Python Library",
}
self.__http_exceptions: Dict[int, HTTPException] = {
@ -41,7 +43,7 @@ class HTTPClient:
403: ForbiddenError(),
404: NotFoundError(),
405: MethodNotAllowedError(),
429: RateLimitError()
429: RateLimitError(),
}
self.__aiohttp_session: ClientSession = ClientSession(headers=headers)
@ -57,12 +59,7 @@ class HTTPClient:
await self.__aiohttp_session.close()
async def __send(
self,
method: str,
endpoint: str,
*,
_ttl: int = None,
**kwargs
self, method: str, endpoint: str, *, _ttl: int = None, **kwargs
) -> Optional[Dict]:
"""Send an API request to the Discord API."""
@ -71,30 +68,27 @@ class HTTPClient:
if ttl == 0:
raise ServerError(f"Maximum amount of retries for `{endpoint}`.")
await self.__rate_limiter.wait_until_not_ratelimited(
endpoint,
method
)
await self.__rate_limiter.wait_until_not_ratelimited(endpoint, method)
url = f"{self.url}/{endpoint}"
async with self.__aiohttp_session.request(method, url, **kwargs) as response:
return await self.__handle_response(response, method, endpoint, _ttl=ttl, **kwargs)
return await self.__handle_response(
response, method, endpoint, _ttl=ttl, **kwargs
)
async def __handle_response(
self,
res: ClientResponse,
method: str,
endpoint: str,
*,
_ttl: int = None,
**kwargs
self,
res: ClientResponse,
method: str,
endpoint: str,
*,
_ttl: int = None,
**kwargs,
) -> Optional[Dict]:
"""Handle responses from the Discord API."""
self.__rate_limiter.save_response_bucket(
endpoint, method, res.headers
)
self.__rate_limiter.save_response_bucket(endpoint, method, res.headers)
if res.ok:
return await res.json()
@ -106,11 +100,7 @@ class HTTPClient:
timeout = (await res.json()).get("retry_after", 40)
await asyncio.sleep(timeout)
return await self.__send(
method,
endpoint,
**kwargs
)
return await self.__send(method, endpoint, **kwargs)
exception.__init__(res.reason)
raise exception
@ -119,18 +109,9 @@ class HTTPClient:
await asyncio.sleep(retry_in)
return await self.__send(
method,
endpoint,
_ttl=_ttl - 1,
**kwargs
)
return await self.__send(method, endpoint, _ttl=_ttl - 1, **kwargs)
async def get(
self,
route: str,
params: Optional[Dict] = None
) -> Optional[Dict]:
async def get(self, route: str, params: Optional[Dict] = None) -> Optional[Dict]:
"""|coro|
Sends a GET request to a Discord REST API endpoint.
@ -146,17 +127,9 @@ class HTTPClient:
Optional[:class:`Dict`]
The response from Discord.
"""
return await self.__send(
"GET",
route,
params=params
)
return await self.__send("GET", route, params=params)
async def post(
self,
route: str,
data: Optional[Dict] = None
) -> Optional[Dict]:
async def post(self, route: str, data: Optional[Dict] = None) -> Optional[Dict]:
"""|coro|
Sends a POST request to a Discord REST API endpoint.
@ -178,11 +151,7 @@ class HTTPClient:
json=data,
)
async def delete(
self,
route: str,
headers: dict = None
) -> Optional[Dict]:
async def delete(self, route: str, headers: dict = None) -> Optional[Dict]:
"""|coro|
Sends a DELETE request to a Discord REST API endpoint.
@ -198,8 +167,4 @@ class HTTPClient:
Optional[:class:`Dict`]
JSON response from the Discord API.
"""
return await self.__send(
"DELETE",
route,
headers=headers
)
return await self.__send("DELETE", route, headers=headers)

View file

@ -12,6 +12,7 @@ from typing import Dict, Tuple, Any
@dataclass
class RateLimitBucket:
"""Represents a rate limit bucket"""
limit: int
remaining: int
reset: float
@ -23,15 +24,12 @@ class RateLimiter:
"""Prevents ``user`` rate limits"""
def __init__(self) -> None:
self.bucket_map: Dict[Tuple[str, str], str] = {} # Dict[Tuple[endpoint, method], bucket_id]
self.bucket_map: Dict[
Tuple[str, str], str
] = {} # Dict[Tuple[endpoint, method], bucket_id]
self.buckets: Dict[str, RateLimitBucket] = {}
def save_response_bucket(
self,
endpoint: str,
method: str,
header: Any
):
def save_response_bucket(self, endpoint: str, method: str, header: Any):
ratelimit_bucket_id = header.get("X-RateLimit-Bucket")
if not ratelimit_bucket_id:
@ -44,14 +42,10 @@ class RateLimiter:
remaining=int(header["X-RateLimit-Remaining"]),
reset=float(header["X-RateLimit-Reset"]),
reset_after_timestamp=float(header["X-RateLimit-Reset-After"]),
since_timestamp=time()
since_timestamp=time(),
)
async def wait_until_not_ratelimited(
self,
endpoint: str,
method: str
):
async def wait_until_not_ratelimited(self, endpoint: str, method: str):
bucket_id = self.bucket_map.get((endpoint, method))
if not bucket_id:

View file

@ -1,25 +1,29 @@
# Copyright MelisaDev 2022 - Present
# Full MIT License can be found in `LICENSE.txt` at the project root.
class MelisaException(Exception):
"""Base exception"""
pass
class ClientException(MelisaException):
"""Handling user errors"""
pass
class LoginFailure(ClientException):
"""Fails to log you in from improper credentials or some other misc."""
pass
class ConnectionClosed(ClientException):
"""Exception that's thrown when the gateway connection is closed
for reasons that could not be handled
internally. """
internally."""
def __init__(self, socket, *, shard_id, code=None):
message = "Websocket with shard ID {} closed with code {}"
@ -38,9 +42,11 @@ class PrivilegedIntentsRequired(ClientException):
def __init__(self, shard_id):
self.shard_id = shard_id
message = "Shard ID {} is requesting privileged intents " \
"that have not been explicitly enabled in the " \
"developer portal. Please visit to https://discord.com/developers/applications/ "
message = (
"Shard ID {} is requesting privileged intents "
"that have not been explicitly enabled in the "
"developer portal. Please visit to https://discord.com/developers/applications/ "
)
super().__init__(message.format(self.shard_id))

View file

@ -10,10 +10,7 @@ from ..user import BotActivity
class Shard:
def __init__(self,
client,
shard_id: int,
num_shards: int):
def __init__(self, client, shard_id: int, num_shards: int):
self._client = client
self._shard_id: int = shard_id
@ -38,11 +35,13 @@ class Shard:
"""|coro|
Launches new shard"""
self._gateway = Gateway(self._client,
self._shard_id,
self._num_shards,
start_activity=kwargs.get("activity"),
start_status=kwargs.get("status"))
self._gateway = Gateway(
self._client,
self._shard_id,
self._num_shards,
start_activity=kwargs.get("activity"),
start_status=kwargs.get("status"),
)
self._client.shards[self._shard_id] = self
@ -54,11 +53,13 @@ class Shard:
async def close(self):
"""|coro|
Disconnect shard
Disconnect shard
"""
create_task(self._gateway.close())
async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard:
async def update_presence(
self, activity: BotActivity = None, status: str = None
) -> Shard:
"""
|coro|

View file

@ -12,10 +12,10 @@ from ...utils.types import APINullable
class BasePresence:
"""
All the information about activities here is from the Discord API docs.
Read more here: https://discord.com/developers/docs/topics/gateway#activity-object
All the information about activities here is from the Discord API docs.
Read more here: https://discord.com/developers/docs/topics/gateway#activity-object
Unknown data will be returned as None.
Unknown data will be returned as None.
"""
@ -39,6 +39,7 @@ class ActivityType(IntEnum):
COMPETING:
Competing in {name} (Competing in Arena World Champions)
"""
GAME = 0
STREAMING = 1
LISTENING = 2
@ -61,6 +62,7 @@ class ActivityTimestamp(BasePresence, APIModelBase):
end: Optional[:class:`int`]
Unix time (in milliseconds) of when the activity ends
"""
start: APINullable[int] = None
end: APINullable[int] = None
@ -78,6 +80,7 @@ class ActivityEmoji(BasePresence, APIModelBase):
animated: Optional[:class:`bool`]
Whether this emoji is animated
"""
name: str
id: APINullable[Snowflake] = None
animated: APINullable[bool] = None
@ -94,6 +97,7 @@ class ActivityParty(BasePresence, APIModelBase):
size: Optional[Tuple[:class:`int`, :class:`int`]]
Array of two integers (current_size, max_size)
"""
id: APINullable[str] = None
size: APINullable[Tuple[int, int]] = None
@ -115,6 +119,7 @@ class ActivityAssets(BasePresence, APIModelBase):
small_text: Optional[:class:`str`]
text displayed when hovering over the small image of the activity
"""
large_image: APINullable[str] = None
large_text: APINullable[str] = None
small_image: APINullable[str] = None
@ -134,6 +139,7 @@ class ActivitySecrets(BasePresence, APIModelBase):
match: Optional[:class:`str`]
The secret for a specific instanced match
"""
join: APINullable[str] = None
spectate: APINullable[str] = None
match_: APINullable[str] = None
@ -141,9 +147,9 @@ class ActivitySecrets(BasePresence, APIModelBase):
class ActivityFlags(BasePresence, APIModelBase):
"""
Just Activity Flags (From Discord API).
Just Activity Flags (From Discord API).
Everything returns :class:`bool` value.
Everything returns :class:`bool` value.
"""
def __init__(self, flags) -> None:
@ -172,6 +178,7 @@ class ActivityButton(BasePresence, APIModelBase):
url: :class:`str`
The url opened when clicking the button (1-512 characters)
"""
label: str
url: str
@ -251,11 +258,11 @@ class BotActivity(BasePresence, APIModelBase):
class StatusType(Enum):
ONLINE = 'online'
OFFLINE = 'offline'
IDLE = 'idle'
DND = 'dnd'
INVISIBLE = 'invisible'
ONLINE = "online"
OFFLINE = "offline"
IDLE = "idle"
DND = "dnd"
INVISIBLE = "invisible"
def __str__(self):
return self.value

View file

@ -169,20 +169,12 @@ class User(APIModelBase):
"""APINullable[:class:`~melisa.models.user.user.PremiumTypes`]: The
user their premium type in a usable enum.
"""
return (
None
if self.premium_type is None
else PremiumTypes(self.premium_type)
)
return None if self.premium_type is None else PremiumTypes(self.premium_type)
@property
def flags(self) -> Optional[UserFlags]:
"""Flags of user"""
return(
None
if self.flags is None
else UserFlags(self.flags)
)
return None if self.flags is None else UserFlags(self.flags)
def __str__(self):
"""String representation of the User object"""
@ -195,10 +187,13 @@ class User(APIModelBase):
def avatar_url(self) -> str:
"""Avatar url (from the Discord CDN server)"""
return "https://cdn.discordapp.com/avatars/{}/{}.png?size=1024".format(self.id, self.avatar)
return "https://cdn.discordapp.com/avatars/{}/{}.png?size=1024".format(
self.id, self.avatar
)
async def create_dm_channel(self):
# ToDo: Add docstrings
# ToDo: Add checking this channel in cache
return await self._http.post(
"/users/@me/channels", data={"recipient_id": self.id})
"/users/@me/channels", data={"recipient_id": self.id}
)

View file

@ -1,17 +1,11 @@
# Copyright MelisaDev 2022 - Present
# Full MIT License can be found in `LICENSE.txt` at the project root.
from .types import (
Coro
)
from .types import Coro
from .snowflake import Snowflake
from .api_model import APIModelBase
__all__ = (
"Coro",
"Snowflake",
"APIModelBase"
)
__all__ = ("Coro", "Snowflake", "APIModelBase")

View file

@ -66,9 +66,7 @@ class APIModelBase:
cls._client = client
@classmethod
def from_dict(
cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]]
) -> T:
def from_dict(cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]]) -> T:
"""
Parse an API object from a dictionary.
"""
@ -81,13 +79,10 @@ class APIModelBase:
map(
lambda key: (
key,
data[key].value
if isinstance(data[key], Enum)
else data[key],
data[key].value if isinstance(data[key], Enum) else data[key],
),
filter(
lambda object_argument: data.get(object_argument)
is not None,
lambda object_argument: data.get(object_argument) is not None,
getfullargspec(cls.__init__).args,
),
)

View file

@ -15,12 +15,14 @@ else:
HAS_ORJSON = True
if HAS_ORJSON:
def dumps(obj: Any) -> str:
return orjson.dumps(obj).decode('utf-8')
return orjson.dumps(obj).decode("utf-8")
loads = orjson.loads
else:
def dumps(obj: Any) -> str:
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)
loads = json.loads

View file

@ -24,9 +24,7 @@ class Snowflake(int):
super().__init__()
if self < self._MIN_VALUE:
raise ValueError(
"snowflake value should be greater than or equal to 0."
)
raise ValueError("snowflake value should be greater than or equal to 0.")
if self > self._MAX_VALUE:
raise ValueError(
@ -66,7 +64,7 @@ class Snowflake(int):
@property
def increment(self) -> int:
""" For every ID that is generated on that process, this number is incremented"""
"""For every ID that is generated on that process, this number is incremented"""
return self % 2048
@property