mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-11-11 19:07:28 +03:00
rewrite gateway
This commit is contained in:
parent
4bbaff8d57
commit
c85783026c
4 changed files with 115 additions and 93 deletions
|
@ -1,14 +1,15 @@
|
|||
import json
|
||||
import asyncio
|
||||
import sys
|
||||
import zlib
|
||||
import time
|
||||
from asyncio import ensure_future
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
|
||||
import websockets
|
||||
import aiohttp
|
||||
|
||||
from ..exceptions import InvalidPayload, GatewayError, PrivilegedIntentsRequired, LoginFailure
|
||||
from ..exceptions import GatewayError, PrivilegedIntentsRequired, LoginFailure
|
||||
from ..listeners import listeners
|
||||
from ..models.user import BotActivity
|
||||
from ..utils import APIObjectBase
|
||||
|
@ -46,22 +47,15 @@ class Gateway:
|
|||
self.interval = None
|
||||
self.intents = client.intents
|
||||
self.sequence = None
|
||||
self.__session = aiohttp.ClientSession()
|
||||
self.session_id = None
|
||||
self.client = client
|
||||
self.shard_id = shard_id
|
||||
self.latency = float('inf')
|
||||
self.connected = False
|
||||
self.ws = None
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
self.__close_codes: Dict[int, Any] = {
|
||||
4001: GatewayError("Invalid opcode was sent"),
|
||||
4002: InvalidPayload("Invalid payload was sent."),
|
||||
4003: GatewayError("Payload was sent prior to identifying"),
|
||||
self.__raise_close_codes: Dict[int, Any] = {
|
||||
4004: LoginFailure("Token is not valid"),
|
||||
4005: GatewayError(
|
||||
"Authentication was sent after client already authenticated"
|
||||
),
|
||||
4007: GatewayError("Invalid sequence sent when starting new session"),
|
||||
4008: GatewayError("Client was rate limited"),
|
||||
4010: GatewayError("Invalid shard"),
|
||||
4011: GatewayError("Sharding required"),
|
||||
4012: GatewayError("Invalid API version"),
|
||||
|
@ -77,7 +71,7 @@ class Gateway:
|
|||
"token": self.client._token,
|
||||
"intents": self.intents,
|
||||
"properties": {
|
||||
"$os": "windows",
|
||||
"$os": sys.platform,
|
||||
"$browser": "melisa",
|
||||
"$device": "melisa"
|
||||
},
|
||||
|
@ -89,7 +83,30 @@ class Gateway:
|
|||
self._zlib: zlib._Decompress = zlib.decompressobj()
|
||||
self._buffer: bytearray = bytearray()
|
||||
|
||||
async def websocket_message(self, msg):
|
||||
async def connect(self) -> None:
|
||||
self.ws = await self.__session.ws_connect(
|
||||
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream')
|
||||
|
||||
if self.session_id is None:
|
||||
await self.send_identify()
|
||||
self.loop.create_task(self.receive())
|
||||
await self.check_heartbeating()
|
||||
else:
|
||||
await self.resume()
|
||||
|
||||
async def check_heartbeating(self):
|
||||
await asyncio.sleep(20)
|
||||
|
||||
if self._last_send + 60.0 < time.perf_counter():
|
||||
await self.ws.close(code=4000)
|
||||
await self.handle_close(4000)
|
||||
|
||||
await self.check_heartbeating()
|
||||
|
||||
async def send(self, payload: str) -> None:
|
||||
await self.ws.send_str(payload)
|
||||
|
||||
async def parse_websocket_message(self, msg):
|
||||
if type(msg) is bytes:
|
||||
self._buffer.extend(msg)
|
||||
|
||||
|
@ -101,79 +118,87 @@ class Gateway:
|
|||
|
||||
return json.loads(msg)
|
||||
|
||||
async def start_loop(self):
|
||||
async with websockets.connect(
|
||||
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream') \
|
||||
as self.websocket:
|
||||
await self.hello()
|
||||
if self.interval is None:
|
||||
return
|
||||
self.connected = True
|
||||
await asyncio.gather(self.heartbeat(), self.receive())
|
||||
async def handle_data(self, data):
|
||||
if data['op'] == self.DISPATCH:
|
||||
self.sequence = int(data["s"])
|
||||
event_type = data["t"].lower()
|
||||
|
||||
async def close(self, code: int = 1000):
|
||||
await self.websocket.close(code=code)
|
||||
event_to_call = self.listeners.get(event_type)
|
||||
|
||||
async def resume(self):
|
||||
resume_data = {
|
||||
"seq": self.sequence,
|
||||
"session_id": self.session_id,
|
||||
"token": self.client._token
|
||||
}
|
||||
if event_to_call is not None:
|
||||
ensure_future(event_to_call(self.client, self, data["d"]))
|
||||
|
||||
await self.send(self.RESUME, resume_data)
|
||||
elif data['op'] == self.INVALID_SESSION:
|
||||
await self.ws.close(code=4000)
|
||||
await self.handle_close(4000)
|
||||
elif data['op'] == self.HELLO:
|
||||
await self.send_hello(data)
|
||||
elif data['op'] == self.HEARTBEAT_ACK:
|
||||
self.latency = time.perf_counter() - self._last_send
|
||||
|
||||
async def receive(self):
|
||||
async for msg in self.websocket:
|
||||
msg = await self.websocket_message(msg)
|
||||
async def receive(self) -> None:
|
||||
async for msg in self.ws:
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
data = await self.parse_websocket_message(msg.data)
|
||||
if data:
|
||||
await self.handle_data(data)
|
||||
elif msg.type == aiohttp.WSMsgType.TEXT:
|
||||
await self.handle_data(msg.data)
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
if msg is None:
|
||||
return None
|
||||
|
||||
if msg["op"] == self.HEARTBEAT_ACK:
|
||||
self.latency = time.time() - self._last_send
|
||||
|
||||
if msg["op"] == self.DISPATCH:
|
||||
self.sequence = int(msg["s"])
|
||||
event_type = msg["t"].lower()
|
||||
|
||||
event_to_call = self.listeners.get(event_type)
|
||||
|
||||
if event_to_call is not None:
|
||||
await event_to_call(self.client, self, msg["d"])
|
||||
|
||||
if msg["op"] != self.DISPATCH:
|
||||
if msg["op"] == self.RECONNECT:
|
||||
await self.websocket.close()
|
||||
await self.resume()
|
||||
|
||||
async def send(self, opcode, payload):
|
||||
data = self.opcode(opcode, payload)
|
||||
|
||||
if opcode == self.HEARTBEAT:
|
||||
self._last_send = time.time()
|
||||
|
||||
await self.websocket.send(data)
|
||||
|
||||
async def heartbeat(self):
|
||||
while self.interval is not None:
|
||||
await self.send(self.HEARTBEAT, self.sequence)
|
||||
self.connected = True
|
||||
await asyncio.sleep(self.interval)
|
||||
|
||||
async def hello(self):
|
||||
await self.send(self.IDENTIFY, self.auth)
|
||||
|
||||
ret = await self.websocket.recv()
|
||||
|
||||
data = await self.websocket_message(ret)
|
||||
|
||||
opcode = data["op"]
|
||||
|
||||
if opcode != 10:
|
||||
close_code = self.ws.close_code
|
||||
if close_code is None:
|
||||
return
|
||||
else:
|
||||
await self.handle_close(close_code)
|
||||
|
||||
self.interval = (data["d"]["heartbeat_interval"] - 2000) / 1000
|
||||
async def handle_close(self, code: int) -> None:
|
||||
if code == 4009:
|
||||
await self.resume()
|
||||
return
|
||||
else:
|
||||
err = self.__raise_close_codes.get(code)
|
||||
|
||||
if err:
|
||||
raise err
|
||||
|
||||
await self.connect()
|
||||
|
||||
async def send_heartbeat(self, interval: float) -> None:
|
||||
if not self.ws.closed:
|
||||
await self.send(self.opcode(self.HEARTBEAT, self.sequence))
|
||||
self._last_send = time.perf_counter()
|
||||
await asyncio.sleep(interval)
|
||||
self.loop.create_task(self.send_heartbeat(interval))
|
||||
|
||||
async def close(self, code: int = 4000) -> None:
|
||||
if self.ws:
|
||||
await self.ws.close(code=code)
|
||||
self._buffer.clear()
|
||||
|
||||
async def send_hello(self, data: Dict) -> None:
|
||||
interval = data['d']['heartbeat_interval'] / 1000
|
||||
await asyncio.sleep((interval - 2000) / 1000)
|
||||
self.loop.create_task(self.send_heartbeat(interval))
|
||||
|
||||
async def send_identify(self) -> None:
|
||||
await self.send(self.opcode(
|
||||
self.IDENTIFY,
|
||||
self.auth
|
||||
))
|
||||
|
||||
async def resume(self) -> None:
|
||||
await self.send(
|
||||
self.opcode(
|
||||
self.RESUME,
|
||||
{
|
||||
'token': self.client._token,
|
||||
'session_id': self.session_id,
|
||||
'seq': self.sequence,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_presence(activity: BotActivity = None, status: str = None):
|
||||
|
@ -199,7 +224,7 @@ class Gateway:
|
|||
return data
|
||||
|
||||
async def update_presence(self, data: dict):
|
||||
await self.send(self.PRESENCE_UPDATE, data)
|
||||
await self.send(self.opcode(self.PRESENCE_UPDATE, data))
|
||||
|
||||
@staticmethod
|
||||
def opcode(opcode: int, payload) -> str:
|
||||
|
|
|
@ -8,11 +8,6 @@ class ClientException(MelisaException):
|
|||
pass
|
||||
|
||||
|
||||
class InvalidPayload(MelisaException):
|
||||
"""This exception means invalid payload"""
|
||||
pass
|
||||
|
||||
|
||||
class LoginFailure(ClientException):
|
||||
"""Fails to log you in from improper credentials or some other misc."""
|
||||
pass
|
||||
|
|
|
@ -45,14 +45,15 @@ class Shard:
|
|||
|
||||
self.disconnected = False
|
||||
|
||||
create_task(self._gateway.start_loop())
|
||||
create_task(self._gateway.connect())
|
||||
|
||||
return self
|
||||
|
||||
async def _try_close(self) -> None:
|
||||
if self._gateway.connected:
|
||||
self._gateway.connected = False
|
||||
await self._gateway.close(code=1000)
|
||||
async def close(self):
|
||||
"""|coro|
|
||||
Disconnect shard
|
||||
"""
|
||||
create_task(self._gateway.close())
|
||||
|
||||
async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard:
|
||||
"""
|
||||
|
@ -89,6 +90,6 @@ class Shard:
|
|||
wait_time: :class:`int`
|
||||
Reconnect after
|
||||
"""
|
||||
await self._try_close()
|
||||
await self.close()
|
||||
await sleep(wait_time)
|
||||
await self.launch()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from ...utils import Snowflake
|
||||
|
||||
|
||||
class Guild:
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in a new issue