mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-09-22 19:22:01 +03:00
run shards, run autosharded methods
This commit is contained in:
parent
b1fa3626f9
commit
897154cdc8
5 changed files with 125 additions and 6 deletions
|
@ -0,0 +1 @@
|
|||
from .client import Client
|
|
@ -2,9 +2,10 @@ from .models.app import Shard
|
|||
from .utils.types import Coro
|
||||
|
||||
from .core.http import HTTPClient
|
||||
from .core.gateway import GatewayBotInfo
|
||||
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class Client:
|
||||
|
@ -17,15 +18,23 @@ class Client:
|
|||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
self._gateway_info = self.loop.run_until_complete(self._get_gateway())
|
||||
|
||||
self.intents = intents
|
||||
self._token = token
|
||||
|
||||
self._activity = kwargs.get("activity")
|
||||
self._status = kwargs.get("status")
|
||||
|
||||
async def _get_gateway(self):
|
||||
"""Get Gateway information"""
|
||||
return GatewayBotInfo.from_dict(await self.http.get("gateway/bot"))
|
||||
|
||||
def listen(self, callback: Coro):
|
||||
"""Method to set the listener.
|
||||
Args:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
callback (:obj:`function`)
|
||||
Coroutine Callback Function
|
||||
"""
|
||||
|
@ -43,3 +52,41 @@ class Client:
|
|||
|
||||
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
|
||||
self.loop.run_forever()
|
||||
|
||||
def run_shards(self, num_shards: int, *, shard_ids: List[int] = None):
|
||||
"""
|
||||
Run Bot with shards specified by the user.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_shards : :class:`int`
|
||||
The endpoint to send the request to.
|
||||
|
||||
Keyword Arguments:
|
||||
|
||||
shard_ids: Optional[:class:`List[int]`]
|
||||
List of Ids of shards to start.
|
||||
"""
|
||||
if not shard_ids:
|
||||
shard_ids = range(num_shards)
|
||||
|
||||
for shard_id in shard_ids:
|
||||
inited_shard = Shard(self, shard_id, num_shards)
|
||||
|
||||
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
|
||||
self.loop.run_forever()
|
||||
|
||||
def run_autosharded(self):
|
||||
"""
|
||||
Runs the bot with the amount of shards specified by the Discord gateway.
|
||||
"""
|
||||
num_shards = self._gateway_info.shards
|
||||
shard_ids = range(num_shards)
|
||||
|
||||
for shard_id in shard_ids:
|
||||
inited_shard = Shard(self, shard_id, num_shards)
|
||||
|
||||
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
|
||||
self.loop.run_forever()
|
||||
|
||||
|
||||
|
|
|
@ -2,11 +2,21 @@ import json
|
|||
import asyncio
|
||||
import zlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import websockets
|
||||
|
||||
from ..listeners import listeners
|
||||
from ..models.user import BotActivity
|
||||
from ..utils import APIObjectBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatewayBotInfo(APIObjectBase):
|
||||
"""Gateway info from the `gateway/bot` endpoint"""
|
||||
url: str
|
||||
shards: int
|
||||
session_start_limit: dict
|
||||
|
||||
|
||||
class Gateway:
|
||||
|
@ -79,7 +89,7 @@ class Gateway:
|
|||
return
|
||||
await asyncio.gather(self.heartbeat(), self.receive())
|
||||
|
||||
async def close(self, code: int = 4000):
|
||||
async def close(self, code: int = 1000):
|
||||
await self.websocket.close(code=code)
|
||||
|
||||
async def resume(self):
|
||||
|
|
|
@ -4,7 +4,11 @@ from .types import (
|
|||
|
||||
from .snowflake import Snowflake
|
||||
|
||||
|
||||
from .api_object import APIObjectBase
|
||||
|
||||
__all__ = (
|
||||
"Coro",
|
||||
"Snowflake"
|
||||
"Snowflake",
|
||||
"APIObjectBase"
|
||||
)
|
||||
|
|
57
melisa/utils/api_object.py
Normal file
57
melisa/utils/api_object.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from inspect import getfullargspec
|
||||
from typing import (
|
||||
Dict,
|
||||
Union,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Any,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class APIObjectBase:
|
||||
"""
|
||||
Represents an object which has been fetched from the Discord API.
|
||||
"""
|
||||
|
||||
_client = None
|
||||
|
||||
@property
|
||||
def _http(self):
|
||||
if not self._client:
|
||||
raise AttributeError("Object is not yet linked to a main client")
|
||||
|
||||
return self._client.http
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]]
|
||||
) -> T:
|
||||
"""
|
||||
Parse an API object from a dictionary.
|
||||
"""
|
||||
if isinstance(data, cls):
|
||||
return data
|
||||
|
||||
# noinspection PyArgumentList
|
||||
return cls(
|
||||
**dict(
|
||||
map(
|
||||
lambda key: (
|
||||
key,
|
||||
data[key].value
|
||||
if isinstance(data[key], Enum)
|
||||
else data[key],
|
||||
),
|
||||
filter(
|
||||
lambda object_argument: data.get(object_argument)
|
||||
is not None,
|
||||
getfullargspec(cls.__init__).args,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
Loading…
Reference in a new issue