mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-11-14 12:27:28 +03:00
rewrite gateway
This commit is contained in:
parent
4bbaff8d57
commit
c85783026c
4 changed files with 115 additions and 93 deletions
|
@ -1,14 +1,15 @@
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
import zlib
|
import zlib
|
||||||
import time
|
import time
|
||||||
from asyncio import ensure_future
|
from asyncio import ensure_future
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
import websockets
|
import aiohttp
|
||||||
|
|
||||||
from ..exceptions import InvalidPayload, GatewayError, PrivilegedIntentsRequired, LoginFailure
|
from ..exceptions import GatewayError, PrivilegedIntentsRequired, LoginFailure
|
||||||
from ..listeners import listeners
|
from ..listeners import listeners
|
||||||
from ..models.user import BotActivity
|
from ..models.user import BotActivity
|
||||||
from ..utils import APIObjectBase
|
from ..utils import APIObjectBase
|
||||||
|
@ -46,22 +47,15 @@ class Gateway:
|
||||||
self.interval = None
|
self.interval = None
|
||||||
self.intents = client.intents
|
self.intents = client.intents
|
||||||
self.sequence = None
|
self.sequence = None
|
||||||
|
self.__session = aiohttp.ClientSession()
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
self.client = client
|
self.client = client
|
||||||
self.shard_id = shard_id
|
|
||||||
self.latency = float('inf')
|
self.latency = float('inf')
|
||||||
self.connected = False
|
self.ws = None
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
self.__close_codes: Dict[int, Any] = {
|
self.__raise_close_codes: Dict[int, Any] = {
|
||||||
4001: GatewayError("Invalid opcode was sent"),
|
|
||||||
4002: InvalidPayload("Invalid payload was sent."),
|
|
||||||
4003: GatewayError("Payload was sent prior to identifying"),
|
|
||||||
4004: LoginFailure("Token is not valid"),
|
4004: LoginFailure("Token is not valid"),
|
||||||
4005: GatewayError(
|
|
||||||
"Authentication was sent after client already authenticated"
|
|
||||||
),
|
|
||||||
4007: GatewayError("Invalid sequence sent when starting new session"),
|
|
||||||
4008: GatewayError("Client was rate limited"),
|
|
||||||
4010: GatewayError("Invalid shard"),
|
4010: GatewayError("Invalid shard"),
|
||||||
4011: GatewayError("Sharding required"),
|
4011: GatewayError("Sharding required"),
|
||||||
4012: GatewayError("Invalid API version"),
|
4012: GatewayError("Invalid API version"),
|
||||||
|
@ -77,7 +71,7 @@ class Gateway:
|
||||||
"token": self.client._token,
|
"token": self.client._token,
|
||||||
"intents": self.intents,
|
"intents": self.intents,
|
||||||
"properties": {
|
"properties": {
|
||||||
"$os": "windows",
|
"$os": sys.platform,
|
||||||
"$browser": "melisa",
|
"$browser": "melisa",
|
||||||
"$device": "melisa"
|
"$device": "melisa"
|
||||||
},
|
},
|
||||||
|
@ -89,7 +83,30 @@ class Gateway:
|
||||||
self._zlib: zlib._Decompress = zlib.decompressobj()
|
self._zlib: zlib._Decompress = zlib.decompressobj()
|
||||||
self._buffer: bytearray = bytearray()
|
self._buffer: bytearray = bytearray()
|
||||||
|
|
||||||
async def websocket_message(self, msg):
|
async def connect(self) -> None:
|
||||||
|
self.ws = await self.__session.ws_connect(
|
||||||
|
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream')
|
||||||
|
|
||||||
|
if self.session_id is None:
|
||||||
|
await self.send_identify()
|
||||||
|
self.loop.create_task(self.receive())
|
||||||
|
await self.check_heartbeating()
|
||||||
|
else:
|
||||||
|
await self.resume()
|
||||||
|
|
||||||
|
async def check_heartbeating(self):
|
||||||
|
await asyncio.sleep(20)
|
||||||
|
|
||||||
|
if self._last_send + 60.0 < time.perf_counter():
|
||||||
|
await self.ws.close(code=4000)
|
||||||
|
await self.handle_close(4000)
|
||||||
|
|
||||||
|
await self.check_heartbeating()
|
||||||
|
|
||||||
|
async def send(self, payload: str) -> None:
|
||||||
|
await self.ws.send_str(payload)
|
||||||
|
|
||||||
|
async def parse_websocket_message(self, msg):
|
||||||
if type(msg) is bytes:
|
if type(msg) is bytes:
|
||||||
self._buffer.extend(msg)
|
self._buffer.extend(msg)
|
||||||
|
|
||||||
|
@ -101,79 +118,87 @@ class Gateway:
|
||||||
|
|
||||||
return json.loads(msg)
|
return json.loads(msg)
|
||||||
|
|
||||||
async def start_loop(self):
|
async def handle_data(self, data):
|
||||||
async with websockets.connect(
|
if data['op'] == self.DISPATCH:
|
||||||
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream') \
|
self.sequence = int(data["s"])
|
||||||
as self.websocket:
|
event_type = data["t"].lower()
|
||||||
await self.hello()
|
|
||||||
if self.interval is None:
|
|
||||||
return
|
|
||||||
self.connected = True
|
|
||||||
await asyncio.gather(self.heartbeat(), self.receive())
|
|
||||||
|
|
||||||
async def close(self, code: int = 1000):
|
|
||||||
await self.websocket.close(code=code)
|
|
||||||
|
|
||||||
async def resume(self):
|
|
||||||
resume_data = {
|
|
||||||
"seq": self.sequence,
|
|
||||||
"session_id": self.session_id,
|
|
||||||
"token": self.client._token
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.send(self.RESUME, resume_data)
|
|
||||||
|
|
||||||
async def receive(self):
|
|
||||||
async for msg in self.websocket:
|
|
||||||
msg = await self.websocket_message(msg)
|
|
||||||
|
|
||||||
if msg is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if msg["op"] == self.HEARTBEAT_ACK:
|
|
||||||
self.latency = time.time() - self._last_send
|
|
||||||
|
|
||||||
if msg["op"] == self.DISPATCH:
|
|
||||||
self.sequence = int(msg["s"])
|
|
||||||
event_type = msg["t"].lower()
|
|
||||||
|
|
||||||
event_to_call = self.listeners.get(event_type)
|
event_to_call = self.listeners.get(event_type)
|
||||||
|
|
||||||
if event_to_call is not None:
|
if event_to_call is not None:
|
||||||
await event_to_call(self.client, self, msg["d"])
|
ensure_future(event_to_call(self.client, self, data["d"]))
|
||||||
|
|
||||||
if msg["op"] != self.DISPATCH:
|
elif data['op'] == self.INVALID_SESSION:
|
||||||
if msg["op"] == self.RECONNECT:
|
await self.ws.close(code=4000)
|
||||||
await self.websocket.close()
|
await self.handle_close(4000)
|
||||||
await self.resume()
|
elif data['op'] == self.HELLO:
|
||||||
|
await self.send_hello(data)
|
||||||
|
elif data['op'] == self.HEARTBEAT_ACK:
|
||||||
|
self.latency = time.perf_counter() - self._last_send
|
||||||
|
|
||||||
async def send(self, opcode, payload):
|
async def receive(self) -> None:
|
||||||
data = self.opcode(opcode, payload)
|
async for msg in self.ws:
|
||||||
|
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||||
|
data = await self.parse_websocket_message(msg.data)
|
||||||
|
if data:
|
||||||
|
await self.handle_data(data)
|
||||||
|
elif msg.type == aiohttp.WSMsgType.TEXT:
|
||||||
|
await self.handle_data(msg.data)
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
if opcode == self.HEARTBEAT:
|
close_code = self.ws.close_code
|
||||||
self._last_send = time.time()
|
if close_code is None:
|
||||||
|
|
||||||
await self.websocket.send(data)
|
|
||||||
|
|
||||||
async def heartbeat(self):
|
|
||||||
while self.interval is not None:
|
|
||||||
await self.send(self.HEARTBEAT, self.sequence)
|
|
||||||
self.connected = True
|
|
||||||
await asyncio.sleep(self.interval)
|
|
||||||
|
|
||||||
async def hello(self):
|
|
||||||
await self.send(self.IDENTIFY, self.auth)
|
|
||||||
|
|
||||||
ret = await self.websocket.recv()
|
|
||||||
|
|
||||||
data = await self.websocket_message(ret)
|
|
||||||
|
|
||||||
opcode = data["op"]
|
|
||||||
|
|
||||||
if opcode != 10:
|
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
await self.handle_close(close_code)
|
||||||
|
|
||||||
self.interval = (data["d"]["heartbeat_interval"] - 2000) / 1000
|
async def handle_close(self, code: int) -> None:
|
||||||
|
if code == 4009:
|
||||||
|
await self.resume()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
err = self.__raise_close_codes.get(code)
|
||||||
|
|
||||||
|
if err:
|
||||||
|
raise err
|
||||||
|
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
async def send_heartbeat(self, interval: float) -> None:
|
||||||
|
if not self.ws.closed:
|
||||||
|
await self.send(self.opcode(self.HEARTBEAT, self.sequence))
|
||||||
|
self._last_send = time.perf_counter()
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
self.loop.create_task(self.send_heartbeat(interval))
|
||||||
|
|
||||||
|
async def close(self, code: int = 4000) -> None:
|
||||||
|
if self.ws:
|
||||||
|
await self.ws.close(code=code)
|
||||||
|
self._buffer.clear()
|
||||||
|
|
||||||
|
async def send_hello(self, data: Dict) -> None:
|
||||||
|
interval = data['d']['heartbeat_interval'] / 1000
|
||||||
|
await asyncio.sleep((interval - 2000) / 1000)
|
||||||
|
self.loop.create_task(self.send_heartbeat(interval))
|
||||||
|
|
||||||
|
async def send_identify(self) -> None:
|
||||||
|
await self.send(self.opcode(
|
||||||
|
self.IDENTIFY,
|
||||||
|
self.auth
|
||||||
|
))
|
||||||
|
|
||||||
|
async def resume(self) -> None:
|
||||||
|
await self.send(
|
||||||
|
self.opcode(
|
||||||
|
self.RESUME,
|
||||||
|
{
|
||||||
|
'token': self.client._token,
|
||||||
|
'session_id': self.session_id,
|
||||||
|
'seq': self.sequence,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_presence(activity: BotActivity = None, status: str = None):
|
def generate_presence(activity: BotActivity = None, status: str = None):
|
||||||
|
@ -199,7 +224,7 @@ class Gateway:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def update_presence(self, data: dict):
|
async def update_presence(self, data: dict):
|
||||||
await self.send(self.PRESENCE_UPDATE, data)
|
await self.send(self.opcode(self.PRESENCE_UPDATE, data))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def opcode(opcode: int, payload) -> str:
|
def opcode(opcode: int, payload) -> str:
|
||||||
|
|
|
@ -8,11 +8,6 @@ class ClientException(MelisaException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidPayload(MelisaException):
|
|
||||||
"""This exception means invalid payload"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LoginFailure(ClientException):
|
class LoginFailure(ClientException):
|
||||||
"""Fails to log you in from improper credentials or some other misc."""
|
"""Fails to log you in from improper credentials or some other misc."""
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -45,14 +45,15 @@ class Shard:
|
||||||
|
|
||||||
self.disconnected = False
|
self.disconnected = False
|
||||||
|
|
||||||
create_task(self._gateway.start_loop())
|
create_task(self._gateway.connect())
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def _try_close(self) -> None:
|
async def close(self):
|
||||||
if self._gateway.connected:
|
"""|coro|
|
||||||
self._gateway.connected = False
|
Disconnect shard
|
||||||
await self._gateway.close(code=1000)
|
"""
|
||||||
|
create_task(self._gateway.close())
|
||||||
|
|
||||||
async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard:
|
async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard:
|
||||||
"""
|
"""
|
||||||
|
@ -89,6 +90,6 @@ class Shard:
|
||||||
wait_time: :class:`int`
|
wait_time: :class:`int`
|
||||||
Reconnect after
|
Reconnect after
|
||||||
"""
|
"""
|
||||||
await self._try_close()
|
await self.close()
|
||||||
await sleep(wait_time)
|
await sleep(wait_time)
|
||||||
await self.launch()
|
await self.launch()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from ...utils import Snowflake
|
from ...utils import Snowflake
|
||||||
|
|
||||||
|
|
||||||
class Guild:
|
class Guild:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue