rewrite gateway

This commit is contained in:
grey-cat-1908 2022-03-17 13:58:58 +03:00
parent 4bbaff8d57
commit c85783026c
4 changed files with 115 additions and 93 deletions

View file

@ -1,14 +1,15 @@
import json
import asyncio
import sys
import zlib
import time
from asyncio import ensure_future
from dataclasses import dataclass
from typing import Dict, Any
import websockets
import aiohttp
from ..exceptions import InvalidPayload, GatewayError, PrivilegedIntentsRequired, LoginFailure
from ..exceptions import GatewayError, PrivilegedIntentsRequired, LoginFailure
from ..listeners import listeners
from ..models.user import BotActivity
from ..utils import APIObjectBase
@ -46,22 +47,15 @@ class Gateway:
self.interval = None
self.intents = client.intents
self.sequence = None
self.__session = aiohttp.ClientSession()
self.session_id = None
self.client = client
self.shard_id = shard_id
self.latency = float('inf')
self.connected = False
self.ws = None
self.loop = asyncio.get_event_loop()
self.__close_codes: Dict[int, Any] = {
4001: GatewayError("Invalid opcode was sent"),
4002: InvalidPayload("Invalid payload was sent."),
4003: GatewayError("Payload was sent prior to identifying"),
self.__raise_close_codes: Dict[int, Any] = {
4004: LoginFailure("Token is not valid"),
4005: GatewayError(
"Authentication was sent after client already authenticated"
),
4007: GatewayError("Invalid sequence sent when starting new session"),
4008: GatewayError("Client was rate limited"),
4010: GatewayError("Invalid shard"),
4011: GatewayError("Sharding required"),
4012: GatewayError("Invalid API version"),
@ -77,7 +71,7 @@ class Gateway:
"token": self.client._token,
"intents": self.intents,
"properties": {
"$os": "windows",
"$os": sys.platform,
"$browser": "melisa",
"$device": "melisa"
},
@ -89,7 +83,30 @@ class Gateway:
self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer: bytearray = bytearray()
async def websocket_message(self, msg):
async def connect(self) -> None:
self.ws = await self.__session.ws_connect(
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream')
if self.session_id is None:
await self.send_identify()
self.loop.create_task(self.receive())
await self.check_heartbeating()
else:
await self.resume()
async def check_heartbeating(self):
await asyncio.sleep(20)
if self._last_send + 60.0 < time.perf_counter():
await self.ws.close(code=4000)
await self.handle_close(4000)
await self.check_heartbeating()
async def send(self, payload: str) -> None:
await self.ws.send_str(payload)
async def parse_websocket_message(self, msg):
if type(msg) is bytes:
self._buffer.extend(msg)
@ -101,79 +118,87 @@ class Gateway:
return json.loads(msg)
async def start_loop(self):
async with websockets.connect(
f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream') \
as self.websocket:
await self.hello()
if self.interval is None:
return
self.connected = True
await asyncio.gather(self.heartbeat(), self.receive())
async def handle_data(self, data):
if data['op'] == self.DISPATCH:
self.sequence = int(data["s"])
event_type = data["t"].lower()
async def close(self, code: int = 1000):
await self.websocket.close(code=code)
event_to_call = self.listeners.get(event_type)
async def resume(self):
resume_data = {
"seq": self.sequence,
"session_id": self.session_id,
"token": self.client._token
}
if event_to_call is not None:
ensure_future(event_to_call(self.client, self, data["d"]))
await self.send(self.RESUME, resume_data)
elif data['op'] == self.INVALID_SESSION:
await self.ws.close(code=4000)
await self.handle_close(4000)
elif data['op'] == self.HELLO:
await self.send_hello(data)
elif data['op'] == self.HEARTBEAT_ACK:
self.latency = time.perf_counter() - self._last_send
async def receive(self):
async for msg in self.websocket:
msg = await self.websocket_message(msg)
async def receive(self) -> None:
async for msg in self.ws:
if msg.type == aiohttp.WSMsgType.BINARY:
data = await self.parse_websocket_message(msg.data)
if data:
await self.handle_data(data)
elif msg.type == aiohttp.WSMsgType.TEXT:
await self.handle_data(msg.data)
else:
raise RuntimeError
if msg is None:
return None
if msg["op"] == self.HEARTBEAT_ACK:
self.latency = time.time() - self._last_send
if msg["op"] == self.DISPATCH:
self.sequence = int(msg["s"])
event_type = msg["t"].lower()
event_to_call = self.listeners.get(event_type)
if event_to_call is not None:
await event_to_call(self.client, self, msg["d"])
if msg["op"] != self.DISPATCH:
if msg["op"] == self.RECONNECT:
await self.websocket.close()
await self.resume()
async def send(self, opcode, payload):
data = self.opcode(opcode, payload)
if opcode == self.HEARTBEAT:
self._last_send = time.time()
await self.websocket.send(data)
async def heartbeat(self):
while self.interval is not None:
await self.send(self.HEARTBEAT, self.sequence)
self.connected = True
await asyncio.sleep(self.interval)
async def hello(self):
await self.send(self.IDENTIFY, self.auth)
ret = await self.websocket.recv()
data = await self.websocket_message(ret)
opcode = data["op"]
if opcode != 10:
close_code = self.ws.close_code
if close_code is None:
return
else:
await self.handle_close(close_code)
self.interval = (data["d"]["heartbeat_interval"] - 2000) / 1000
async def handle_close(self, code: int) -> None:
if code == 4009:
await self.resume()
return
else:
err = self.__raise_close_codes.get(code)
if err:
raise err
await self.connect()
async def send_heartbeat(self, interval: float) -> None:
if not self.ws.closed:
await self.send(self.opcode(self.HEARTBEAT, self.sequence))
self._last_send = time.perf_counter()
await asyncio.sleep(interval)
self.loop.create_task(self.send_heartbeat(interval))
async def close(self, code: int = 4000) -> None:
if self.ws:
await self.ws.close(code=code)
self._buffer.clear()
async def send_hello(self, data: Dict) -> None:
interval = data['d']['heartbeat_interval'] / 1000
await asyncio.sleep((interval - 2000) / 1000)
self.loop.create_task(self.send_heartbeat(interval))
async def send_identify(self) -> None:
await self.send(self.opcode(
self.IDENTIFY,
self.auth
))
async def resume(self) -> None:
await self.send(
self.opcode(
self.RESUME,
{
'token': self.client._token,
'session_id': self.session_id,
'seq': self.sequence,
}
)
)
@staticmethod
def generate_presence(activity: BotActivity = None, status: str = None):
@ -199,7 +224,7 @@ class Gateway:
return data
async def update_presence(self, data: dict):
await self.send(self.PRESENCE_UPDATE, data)
await self.send(self.opcode(self.PRESENCE_UPDATE, data))
@staticmethod
def opcode(opcode: int, payload) -> str:

View file

@ -8,11 +8,6 @@ class ClientException(MelisaException):
pass
class InvalidPayload(MelisaException):
"""This exception means invalid payload"""
pass
class LoginFailure(ClientException):
"""Fails to log you in from improper credentials or some other misc."""
pass

View file

@ -45,14 +45,15 @@ class Shard:
self.disconnected = False
create_task(self._gateway.start_loop())
create_task(self._gateway.connect())
return self
async def _try_close(self) -> None:
if self._gateway.connected:
self._gateway.connected = False
await self._gateway.close(code=1000)
async def close(self):
"""|coro|
Disconnect shard
"""
create_task(self._gateway.close())
async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard:
"""
@ -89,6 +90,6 @@ class Shard:
wait_time: :class:`int`
Reconnect after
"""
await self._try_close()
await self.close()
await sleep(wait_time)
await self.launch()

View file

@ -1,5 +1,6 @@
from ...utils import Snowflake
class Guild:
pass