rewrite gateway

This commit is contained in:
grey-cat-1908 2022-03-17 13:58:58 +03:00
parent 4bbaff8d57
commit c85783026c
4 changed files with 115 additions and 93 deletions

View file

@ -1,14 +1,15 @@
import json import json
import asyncio import asyncio
import sys
import zlib import zlib
import time import time
from asyncio import ensure_future from asyncio import ensure_future
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Any from typing import Dict, Any
import websockets import aiohttp
from ..exceptions import InvalidPayload, GatewayError, PrivilegedIntentsRequired, LoginFailure from ..exceptions import GatewayError, PrivilegedIntentsRequired, LoginFailure
from ..listeners import listeners from ..listeners import listeners
from ..models.user import BotActivity from ..models.user import BotActivity
from ..utils import APIObjectBase from ..utils import APIObjectBase
@ -46,22 +47,15 @@ class Gateway:
self.interval = None self.interval = None
self.intents = client.intents self.intents = client.intents
self.sequence = None self.sequence = None
self.__session = aiohttp.ClientSession()
self.session_id = None self.session_id = None
self.client = client self.client = client
self.shard_id = shard_id
self.latency = float('inf') self.latency = float('inf')
self.connected = False self.ws = None
self.loop = asyncio.get_event_loop()
self.__close_codes: Dict[int, Any] = { self.__raise_close_codes: Dict[int, Any] = {
4001: GatewayError("Invalid opcode was sent"),
4002: InvalidPayload("Invalid payload was sent."),
4003: GatewayError("Payload was sent prior to identifying"),
4004: LoginFailure("Token is not valid"), 4004: LoginFailure("Token is not valid"),
4005: GatewayError(
"Authentication was sent after client already authenticated"
),
4007: GatewayError("Invalid sequence sent when starting new session"),
4008: GatewayError("Client was rate limited"),
4010: GatewayError("Invalid shard"), 4010: GatewayError("Invalid shard"),
4011: GatewayError("Sharding required"), 4011: GatewayError("Sharding required"),
4012: GatewayError("Invalid API version"), 4012: GatewayError("Invalid API version"),
@ -77,7 +71,7 @@ class Gateway:
"token": self.client._token, "token": self.client._token,
"intents": self.intents, "intents": self.intents,
"properties": { "properties": {
"$os": "windows", "$os": sys.platform,
"$browser": "melisa", "$browser": "melisa",
"$device": "melisa" "$device": "melisa"
}, },
@ -89,7 +83,30 @@ class Gateway:
self._zlib: zlib._Decompress = zlib.decompressobj() self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer: bytearray = bytearray() self._buffer: bytearray = bytearray()
async def websocket_message(self, msg): async def connect(self) -> None:
self.ws = await self.__session.ws_connect(
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream')
if self.session_id is None:
await self.send_identify()
self.loop.create_task(self.receive())
await self.check_heartbeating()
else:
await self.resume()
async def check_heartbeating(self):
await asyncio.sleep(20)
if self._last_send + 60.0 < time.perf_counter():
await self.ws.close(code=4000)
await self.handle_close(4000)
await self.check_heartbeating()
async def send(self, payload: str) -> None:
await self.ws.send_str(payload)
async def parse_websocket_message(self, msg):
if type(msg) is bytes: if type(msg) is bytes:
self._buffer.extend(msg) self._buffer.extend(msg)
@ -101,79 +118,87 @@ class Gateway:
return json.loads(msg) return json.loads(msg)
async def start_loop(self): async def handle_data(self, data):
async with websockets.connect( if data['op'] == self.DISPATCH:
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream') \ self.sequence = int(data["s"])
as self.websocket: event_type = data["t"].lower()
await self.hello()
if self.interval is None:
return
self.connected = True
await asyncio.gather(self.heartbeat(), self.receive())
async def close(self, code: int = 1000): event_to_call = self.listeners.get(event_type)
await self.websocket.close(code=code)
async def resume(self): if event_to_call is not None:
resume_data = { ensure_future(event_to_call(self.client, self, data["d"]))
"seq": self.sequence,
"session_id": self.session_id,
"token": self.client._token
}
await self.send(self.RESUME, resume_data) elif data['op'] == self.INVALID_SESSION:
await self.ws.close(code=4000)
await self.handle_close(4000)
elif data['op'] == self.HELLO:
await self.send_hello(data)
elif data['op'] == self.HEARTBEAT_ACK:
self.latency = time.perf_counter() - self._last_send
async def receive(self): async def receive(self) -> None:
async for msg in self.websocket: async for msg in self.ws:
msg = await self.websocket_message(msg) if msg.type == aiohttp.WSMsgType.BINARY:
data = await self.parse_websocket_message(msg.data)
if data:
await self.handle_data(data)
elif msg.type == aiohttp.WSMsgType.TEXT:
await self.handle_data(msg.data)
else:
raise RuntimeError
if msg is None: close_code = self.ws.close_code
return None if close_code is None:
if msg["op"] == self.HEARTBEAT_ACK:
self.latency = time.time() - self._last_send
if msg["op"] == self.DISPATCH:
self.sequence = int(msg["s"])
event_type = msg["t"].lower()
event_to_call = self.listeners.get(event_type)
if event_to_call is not None:
await event_to_call(self.client, self, msg["d"])
if msg["op"] != self.DISPATCH:
if msg["op"] == self.RECONNECT:
await self.websocket.close()
await self.resume()
async def send(self, opcode, payload):
data = self.opcode(opcode, payload)
if opcode == self.HEARTBEAT:
self._last_send = time.time()
await self.websocket.send(data)
async def heartbeat(self):
while self.interval is not None:
await self.send(self.HEARTBEAT, self.sequence)
self.connected = True
await asyncio.sleep(self.interval)
async def hello(self):
await self.send(self.IDENTIFY, self.auth)
ret = await self.websocket.recv()
data = await self.websocket_message(ret)
opcode = data["op"]
if opcode != 10:
return return
else:
await self.handle_close(close_code)
self.interval = (data["d"]["heartbeat_interval"] - 2000) / 1000 async def handle_close(self, code: int) -> None:
if code == 4009:
await self.resume()
return
else:
err = self.__raise_close_codes.get(code)
if err:
raise err
await self.connect()
async def send_heartbeat(self, interval: float) -> None:
if not self.ws.closed:
await self.send(self.opcode(self.HEARTBEAT, self.sequence))
self._last_send = time.perf_counter()
await asyncio.sleep(interval)
self.loop.create_task(self.send_heartbeat(interval))
async def close(self, code: int = 4000) -> None:
if self.ws:
await self.ws.close(code=code)
self._buffer.clear()
async def send_hello(self, data: Dict) -> None:
interval = data['d']['heartbeat_interval'] / 1000
await asyncio.sleep((interval - 2000) / 1000)
self.loop.create_task(self.send_heartbeat(interval))
async def send_identify(self) -> None:
await self.send(self.opcode(
self.IDENTIFY,
self.auth
))
async def resume(self) -> None:
await self.send(
self.opcode(
self.RESUME,
{
'token': self.client._token,
'session_id': self.session_id,
'seq': self.sequence,
}
)
)
@staticmethod @staticmethod
def generate_presence(activity: BotActivity = None, status: str = None): def generate_presence(activity: BotActivity = None, status: str = None):
@ -199,7 +224,7 @@ class Gateway:
return data return data
async def update_presence(self, data: dict): async def update_presence(self, data: dict):
await self.send(self.PRESENCE_UPDATE, data) await self.send(self.opcode(self.PRESENCE_UPDATE, data))
@staticmethod @staticmethod
def opcode(opcode: int, payload) -> str: def opcode(opcode: int, payload) -> str:

View file

@ -8,11 +8,6 @@ class ClientException(MelisaException):
pass pass
class InvalidPayload(MelisaException):
"""This exception means invalid payload"""
pass
class LoginFailure(ClientException): class LoginFailure(ClientException):
"""Fails to log you in from improper credentials or some other misc.""" """Fails to log you in from improper credentials or some other misc."""
pass pass

View file

@ -45,14 +45,15 @@ class Shard:
self.disconnected = False self.disconnected = False
create_task(self._gateway.start_loop()) create_task(self._gateway.connect())
return self return self
async def _try_close(self) -> None: async def close(self):
if self._gateway.connected: """|coro|
self._gateway.connected = False Disconnect shard
await self._gateway.close(code=1000) """
create_task(self._gateway.close())
async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard: async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard:
""" """
@ -89,6 +90,6 @@ class Shard:
wait_time: :class:`int` wait_time: :class:`int`
Reconnect after Reconnect after
""" """
await self._try_close() await self.close()
await sleep(wait_time) await sleep(wait_time)
await self.launch() await self.launch()

View file

@ -1,5 +1,6 @@
from ...utils import Snowflake from ...utils import Snowflake
class Guild: class Guild:
pass pass