run shards, run autosharded methods

This commit is contained in:
grey-cat-1908 2022-03-12 11:15:08 +03:00
parent b1fa3626f9
commit 897154cdc8
5 changed files with 125 additions and 6 deletions

View file

@ -0,0 +1 @@
from .client import Client

View file

@ -2,9 +2,10 @@ from .models.app import Shard
from .utils.types import Coro
from .core.http import HTTPClient
from .core.gateway import GatewayBotInfo
import asyncio
from typing import Dict
from typing import Dict, List
class Client:
@ -17,17 +18,25 @@ class Client:
self.loop = asyncio.get_event_loop()
self._gateway_info = self.loop.run_until_complete(self._get_gateway())
self.intents = intents
self._token = token
self._activity = kwargs.get("activity")
self._status = kwargs.get("status")
async def _get_gateway(self):
"""Get Gateway information"""
return GatewayBotInfo.from_dict(await self.http.get("gateway/bot"))
def listen(self, callback: Coro):
"""Method to set the listener.
Args:
callback (:obj:`function`)
Coroutine Callback Function
Parameters
----------
callback (:obj:`function`)
Coroutine Callback Function
"""
if not asyncio.iscoroutinefunction(callback):
raise TypeError(f"<{callback.__qualname__}> must be a coroutine function")
@ -43,3 +52,41 @@ class Client:
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
self.loop.run_forever()
def run_shards(self, num_shards: int, *, shard_ids: List[int] = None):
"""
Run Bot with shards specified by the user.
Parameters
----------
num_shards : :class:`int`
The endpoint to send the request to.
Keyword Arguments:
shard_ids: Optional[:class:`List[int]`]
List of Ids of shards to start.
"""
if not shard_ids:
shard_ids = range(num_shards)
for shard_id in shard_ids:
inited_shard = Shard(self, shard_id, num_shards)
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
self.loop.run_forever()
def run_autosharded(self):
"""
Runs the bot with the amount of shards specified by the Discord gateway.
"""
num_shards = self._gateway_info.shards
shard_ids = range(num_shards)
for shard_id in shard_ids:
inited_shard = Shard(self, shard_id, num_shards)
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
self.loop.run_forever()

View file

@ -2,11 +2,21 @@ import json
import asyncio
import zlib
import time
from dataclasses import dataclass
import websockets
from ..listeners import listeners
from ..models.user import BotActivity
from ..utils import APIObjectBase
@dataclass
class GatewayBotInfo(APIObjectBase):
"""Gateway info from the `gateway/bot` endpoint"""
url: str
shards: int
session_start_limit: dict
class Gateway:
@ -79,7 +89,7 @@ class Gateway:
return
await asyncio.gather(self.heartbeat(), self.receive())
async def close(self, code: int = 4000):
async def close(self, code: int = 1000):
await self.websocket.close(code=code)
async def resume(self):

View file

@ -4,7 +4,11 @@ from .types import (
from .snowflake import Snowflake
from .api_object import APIObjectBase
__all__ = (
"Coro",
"Snowflake"
"Snowflake",
"APIObjectBase"
)

View file

@ -0,0 +1,57 @@
from __future__ import annotations
from enum import Enum
from inspect import getfullargspec
from typing import (
Dict,
Union,
Generic,
TypeVar,
Any,
)
T = TypeVar("T")
class APIObjectBase:
"""
Represents an object which has been fetched from the Discord API.
"""
_client = None
@property
def _http(self):
if not self._client:
raise AttributeError("Object is not yet linked to a main client")
return self._client.http
@classmethod
def from_dict(
cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]]
) -> T:
"""
Parse an API object from a dictionary.
"""
if isinstance(data, cls):
return data
# noinspection PyArgumentList
return cls(
**dict(
map(
lambda key: (
key,
data[key].value
if isinstance(data[key], Enum)
else data[key],
),
filter(
lambda object_argument: data.get(object_argument)
is not None,
getfullargspec(cls.__init__).args,
),
)
)
)