mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-09-22 19:22:01 +03:00
run shards, run autosharded methods
This commit is contained in:
parent
b1fa3626f9
commit
897154cdc8
5 changed files with 125 additions and 6 deletions
|
@ -0,0 +1 @@
|
||||||
|
from .client import Client
|
|
@ -2,9 +2,10 @@ from .models.app import Shard
|
||||||
from .utils.types import Coro
|
from .utils.types import Coro
|
||||||
|
|
||||||
from .core.http import HTTPClient
|
from .core.http import HTTPClient
|
||||||
|
from .core.gateway import GatewayBotInfo
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
|
@ -17,15 +18,23 @@ class Client:
|
||||||
|
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
self._gateway_info = self.loop.run_until_complete(self._get_gateway())
|
||||||
|
|
||||||
self.intents = intents
|
self.intents = intents
|
||||||
self._token = token
|
self._token = token
|
||||||
|
|
||||||
self._activity = kwargs.get("activity")
|
self._activity = kwargs.get("activity")
|
||||||
self._status = kwargs.get("status")
|
self._status = kwargs.get("status")
|
||||||
|
|
||||||
|
async def _get_gateway(self):
|
||||||
|
"""Get Gateway information"""
|
||||||
|
return GatewayBotInfo.from_dict(await self.http.get("gateway/bot"))
|
||||||
|
|
||||||
def listen(self, callback: Coro):
|
def listen(self, callback: Coro):
|
||||||
"""Method to set the listener.
|
"""Method to set the listener.
|
||||||
Args:
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
callback (:obj:`function`)
|
callback (:obj:`function`)
|
||||||
Coroutine Callback Function
|
Coroutine Callback Function
|
||||||
"""
|
"""
|
||||||
|
@ -43,3 +52,41 @@ class Client:
|
||||||
|
|
||||||
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
|
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
|
||||||
self.loop.run_forever()
|
self.loop.run_forever()
|
||||||
|
|
||||||
|
def run_shards(self, num_shards: int, *, shard_ids: List[int] = None):
|
||||||
|
"""
|
||||||
|
Run Bot with shards specified by the user.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
num_shards : :class:`int`
|
||||||
|
The endpoint to send the request to.
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
|
||||||
|
shard_ids: Optional[:class:`List[int]`]
|
||||||
|
List of Ids of shards to start.
|
||||||
|
"""
|
||||||
|
if not shard_ids:
|
||||||
|
shard_ids = range(num_shards)
|
||||||
|
|
||||||
|
for shard_id in shard_ids:
|
||||||
|
inited_shard = Shard(self, shard_id, num_shards)
|
||||||
|
|
||||||
|
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
|
||||||
|
self.loop.run_forever()
|
||||||
|
|
||||||
|
def run_autosharded(self):
|
||||||
|
"""
|
||||||
|
Runs the bot with the amount of shards specified by the Discord gateway.
|
||||||
|
"""
|
||||||
|
num_shards = self._gateway_info.shards
|
||||||
|
shard_ids = range(num_shards)
|
||||||
|
|
||||||
|
for shard_id in shard_ids:
|
||||||
|
inited_shard = Shard(self, shard_id, num_shards)
|
||||||
|
|
||||||
|
asyncio.ensure_future(inited_shard.launch(activity=self._activity, status=self._status), loop=self.loop)
|
||||||
|
self.loop.run_forever()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,11 +2,21 @@ import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import zlib
|
import zlib
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
from ..listeners import listeners
|
from ..listeners import listeners
|
||||||
from ..models.user import BotActivity
|
from ..models.user import BotActivity
|
||||||
|
from ..utils import APIObjectBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GatewayBotInfo(APIObjectBase):
|
||||||
|
"""Gateway info from the `gateway/bot` endpoint"""
|
||||||
|
url: str
|
||||||
|
shards: int
|
||||||
|
session_start_limit: dict
|
||||||
|
|
||||||
|
|
||||||
class Gateway:
|
class Gateway:
|
||||||
|
@ -79,7 +89,7 @@ class Gateway:
|
||||||
return
|
return
|
||||||
await asyncio.gather(self.heartbeat(), self.receive())
|
await asyncio.gather(self.heartbeat(), self.receive())
|
||||||
|
|
||||||
async def close(self, code: int = 4000):
|
async def close(self, code: int = 1000):
|
||||||
await self.websocket.close(code=code)
|
await self.websocket.close(code=code)
|
||||||
|
|
||||||
async def resume(self):
|
async def resume(self):
|
||||||
|
|
|
@ -4,7 +4,11 @@ from .types import (
|
||||||
|
|
||||||
from .snowflake import Snowflake
|
from .snowflake import Snowflake
|
||||||
|
|
||||||
|
|
||||||
|
from .api_object import APIObjectBase
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"Coro",
|
"Coro",
|
||||||
"Snowflake"
|
"Snowflake",
|
||||||
|
"APIObjectBase"
|
||||||
)
|
)
|
||||||
|
|
57
melisa/utils/api_object.py
Normal file
57
melisa/utils/api_object.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from inspect import getfullargspec
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Union,
|
||||||
|
Generic,
|
||||||
|
TypeVar,
|
||||||
|
Any,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class APIObjectBase:
|
||||||
|
"""
|
||||||
|
Represents an object which has been fetched from the Discord API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_client = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _http(self):
|
||||||
|
if not self._client:
|
||||||
|
raise AttributeError("Object is not yet linked to a main client")
|
||||||
|
|
||||||
|
return self._client.http
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(
|
||||||
|
cls: Generic[T], data: Dict[str, Union[str, bool, int, Any]]
|
||||||
|
) -> T:
|
||||||
|
"""
|
||||||
|
Parse an API object from a dictionary.
|
||||||
|
"""
|
||||||
|
if isinstance(data, cls):
|
||||||
|
return data
|
||||||
|
|
||||||
|
# noinspection PyArgumentList
|
||||||
|
return cls(
|
||||||
|
**dict(
|
||||||
|
map(
|
||||||
|
lambda key: (
|
||||||
|
key,
|
||||||
|
data[key].value
|
||||||
|
if isinstance(data[key], Enum)
|
||||||
|
else data[key],
|
||||||
|
),
|
||||||
|
filter(
|
||||||
|
lambda object_argument: data.get(object_argument)
|
||||||
|
is not None,
|
||||||
|
getfullargspec(cls.__init__).args,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
Loading…
Reference in a new issue