run shards, run autosharded methods

This commit is contained in:
grey-cat-1908 2022-03-12 11:15:08 +03:00
parent b1fa3626f9
commit 897154cdc8
5 changed files with 125 additions and 6 deletions

View file

@ -0,0 +1 @@
from .client import Client

View file

@ -2,9 +2,10 @@ from .models.app import Shard
from .utils.types import Coro from .utils.types import Coro
from .core.http import HTTPClient from .core.http import HTTPClient
from .core.gateway import GatewayBotInfo
import asyncio import asyncio
from typing import Dict from typing import Dict, List
class Client: class Client:
@ -17,17 +18,25 @@ class Client:
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._gateway_info = self.loop.run_until_complete(self._get_gateway())
self.intents = intents self.intents = intents
self._token = token self._token = token
self._activity = kwargs.get("activity") self._activity = kwargs.get("activity")
self._status = kwargs.get("status") self._status = kwargs.get("status")
async def _get_gateway(self):
"""Get Gateway information"""
return GatewayBotInfo.from_dict(await self.http.get("gateway/bot"))
def listen(self, callback: Coro): def listen(self, callback: Coro):
"""Method to set the listener. """Method to set the listener.
Args:
callback (:obj:`function`) Parameters
Coroutine Callback Function ----------
callback (:obj:`function`)
Coroutine Callback Function
""" """
if not asyncio.iscoroutinefunction(callback): if not asyncio.iscoroutinefunction(callback):
raise TypeError(f"<{callback.__qualname__}> must be a coroutine function") raise TypeError(f"<{callback.__qualname__}> must be a coroutine function")
@ -43,3 +52,41 @@ class Client:
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop) asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
self.loop.run_forever() self.loop.run_forever()
def run_shards(self, num_shards: int, *, shard_ids: List[int] = None):
"""
Run Bot with shards specified by the user.
Parameters
----------
num_shards : :class:`int`
The endpoint to send the request to.
Keyword Arguments:
shard_ids: Optional[:class:`List[int]`]
List of Ids of shards to start.
"""
if not shard_ids:
shard_ids = range(num_shards)
for shard_id in shard_ids:
inited_shard = Shard(self, shard_id, num_shards)
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
self.loop.run_forever()
def run_autosharded(self):
"""
Runs the bot with the amount of shards specified by the Discord gateway.
"""
num_shards = self._gateway_info.shards
shard_ids = range(num_shards)
for shard_id in shard_ids:
inited_shard = Shard(self, shard_id, num_shards)
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
self.loop.run_forever()

View file

@ -2,11 +2,21 @@ import json
import asyncio import asyncio
import zlib import zlib
import time import time
from dataclasses import dataclass
import websockets import websockets
from ..listeners import listeners from ..listeners import listeners
from ..models.user import BotActivity from ..models.user import BotActivity
from ..utils import APIObjectBase
@dataclass
class GatewayBotInfo(APIObjectBase):
"""Gateway info from the `gateway/bot` endpoint"""
url: str
shards: int
session_start_limit: dict
class Gateway: class Gateway:
@ -79,7 +89,7 @@ class Gateway:
return return
await asyncio.gather(self.heartbeat(), self.receive()) await asyncio.gather(self.heartbeat(), self.receive())
async def close(self, code: int = 4000): async def close(self, code: int = 1000):
await self.websocket.close(code=code) await self.websocket.close(code=code)
async def resume(self): async def resume(self):

View file

@ -4,7 +4,11 @@ from .types import (
from .snowflake import Snowflake from .snowflake import Snowflake
from .api_object import APIObjectBase
__all__ = ( __all__ = (
"Coro", "Coro",
"Snowflake" "Snowflake",
"APIObjectBase"
) )

View file

@ -0,0 +1,57 @@
from __future__ import annotations
from enum import Enum
from inspect import getfullargspec
from typing import (
Dict,
Union,
Generic,
TypeVar,
Any,
)
T = TypeVar("T")
class APIObjectBase:
"""
Represents an object which has been fetched from the Discord API.
"""
_client = None
@property
def _http(self):
if not self._client:
raise AttributeError("Object is not yet linked to a main client")
return self._client.http
@classmethod
def from_dict(
cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]]
) -> T:
"""
Parse an API object from a dictionary.
"""
if isinstance(data, cls):
return data
# noinspection PyArgumentList
return cls(
**dict(
map(
lambda key: (
key,
data[key].value
if isinstance(data[key], Enum)
else data[key],
),
filter(
lambda object_argument: data.get(object_argument)
is not None,
getfullargspec(cls.__init__).args,
),
)
)
)