diff --git a/melisa/core/gateway.py b/melisa/core/gateway.py index 73bb6f4..08eb760 100644 --- a/melisa/core/gateway.py +++ b/melisa/core/gateway.py @@ -1,14 +1,15 @@ import json import asyncio +import sys import zlib import time from asyncio import ensure_future from dataclasses import dataclass from typing import Dict, Any -import websockets +import aiohttp -from ..exceptions import InvalidPayload, GatewayError, PrivilegedIntentsRequired, LoginFailure +from ..exceptions import GatewayError, PrivilegedIntentsRequired, LoginFailure from ..listeners import listeners from ..models.user import BotActivity from ..utils import APIObjectBase @@ -46,22 +47,15 @@ class Gateway: self.interval = None self.intents = client.intents self.sequence = None + self.__session = aiohttp.ClientSession() self.session_id = None self.client = client - self.shard_id = shard_id self.latency = float('inf') - self.connected = False + self.ws = None + self.loop = asyncio.get_event_loop() - self.__close_codes: Dict[int, Any] = { - 4001: GatewayError("Invalid opcode was sent"), - 4002: InvalidPayload("Invalid payload was sent."), - 4003: GatewayError("Payload was sent prior to identifying"), + self.__raise_close_codes: Dict[int, Any] = { 4004: LoginFailure("Token is not valid"), - 4005: GatewayError( - "Authentication was sent after client already authenticated" - ), - 4007: GatewayError("Invalid sequence sent when starting new session"), - 4008: GatewayError("Client was rate limited"), 4010: GatewayError("Invalid shard"), 4011: GatewayError("Sharding required"), 4012: GatewayError("Invalid API version"), @@ -77,7 +71,7 @@ class Gateway: "token": self.client._token, "intents": self.intents, "properties": { - "$os": "windows", + "$os": sys.platform, "$browser": "melisa", "$device": "melisa" }, @@ -89,7 +83,30 @@ class Gateway: self._zlib: zlib._Decompress = zlib.decompressobj() self._buffer: bytearray = bytearray() - async def websocket_message(self, msg): + async def connect(self) -> None: + self.ws = await self.__session.ws_connect( + f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream') + + if self.session_id is None: + await self.send_identify() + self.loop.create_task(self.receive()) + await self.check_heartbeating() + else: + await self.resume() + + async def check_heartbeating(self): + await asyncio.sleep(20) + + if self._last_send + 60.0 < time.perf_counter(): + await self.ws.close(code=4000) + await self.handle_close(4000) + + await self.check_heartbeating() + + async def send(self, payload: str) -> None: + await self.ws.send_str(payload) + + async def parse_websocket_message(self, msg): if type(msg) is bytes: self._buffer.extend(msg) @@ -101,79 +118,87 @@ class Gateway: return json.loads(msg) - async def start_loop(self): - async with websockets.connect( - f'wss://gateway.discord.gg/?v={self.GATEWAY_VERSION}&encoding=json&compress=zlib-stream') \ - as self.websocket: - await self.hello() - if self.interval is None: - return - self.connected = True - await asyncio.gather(self.heartbeat(), self.receive()) + async def handle_data(self, data): + if data['op'] == self.DISPATCH: + self.sequence = int(data["s"]) + event_type = data["t"].lower() - async def close(self, code: int = 1000): - await self.websocket.close(code=code) + event_to_call = self.listeners.get(event_type) - async def resume(self): - resume_data = { - "seq": self.sequence, - "session_id": self.session_id, - "token": self.client._token - } + if event_to_call is not None: + ensure_future(event_to_call(self.client, self, data["d"])) - await self.send(self.RESUME, resume_data) + elif data['op'] == self.INVALID_SESSION: + await self.ws.close(code=4000) + await self.handle_close(4000) + elif data['op'] == self.HELLO: + await self.send_hello(data) + elif data['op'] == self.HEARTBEAT_ACK: + self.latency = time.perf_counter() - self._last_send - async def receive(self): - async for msg in self.websocket: - msg = await self.websocket_message(msg) + async def receive(self) -> None: + async for msg in self.ws: + if msg.type == aiohttp.WSMsgType.BINARY: + data = await self.parse_websocket_message(msg.data) + if data: + await self.handle_data(data) + elif msg.type == aiohttp.WSMsgType.TEXT: + await self.handle_data(msg.data) + else: + raise RuntimeError - if msg is None: - return None - - if msg["op"] == self.HEARTBEAT_ACK: - self.latency = time.time() - self._last_send - - if msg["op"] == self.DISPATCH: - self.sequence = int(msg["s"]) - event_type = msg["t"].lower() - - event_to_call = self.listeners.get(event_type) - - if event_to_call is not None: - await event_to_call(self.client, self, msg["d"]) - - if msg["op"] != self.DISPATCH: - if msg["op"] == self.RECONNECT: - await self.websocket.close() - await self.resume() - - async def send(self, opcode, payload): - data = self.opcode(opcode, payload) - - if opcode == self.HEARTBEAT: - self._last_send = time.time() - - await self.websocket.send(data) - - async def heartbeat(self): - while self.interval is not None: - await self.send(self.HEARTBEAT, self.sequence) - self.connected = True - await asyncio.sleep(self.interval) - - async def hello(self): - await self.send(self.IDENTIFY, self.auth) - - ret = await self.websocket.recv() - - data = await self.websocket_message(ret) - - opcode = data["op"] - - if opcode != 10: + close_code = self.ws.close_code + if close_code is None: return + else: + await self.handle_close(close_code) - self.interval = (data["d"]["heartbeat_interval"] - 2000) / 1000 + async def handle_close(self, code: int) -> None: + if code == 4009: + await self.resume() + return + else: + err = self.__raise_close_codes.get(code) + + if err: + raise err + + await self.connect() + + async def send_heartbeat(self, interval: float) -> None: + if not self.ws.closed: + await self.send(self.opcode(self.HEARTBEAT, self.sequence)) + self._last_send = time.perf_counter() + await asyncio.sleep(interval) + self.loop.create_task(self.send_heartbeat(interval)) + + async def close(self, code: int = 4000) -> None: + if self.ws: + await self.ws.close(code=code) + self._buffer.clear() + + async def send_hello(self, data: Dict) -> None: + interval = data['d']['heartbeat_interval'] / 1000 + await asyncio.sleep((interval - 2000) / 1000) + self.loop.create_task(self.send_heartbeat(interval)) + + async def send_identify(self) -> None: + await self.send(self.opcode( + self.IDENTIFY, + self.auth + )) + + async def resume(self) -> None: + await self.send( + self.opcode( + self.RESUME, + { + 'token': self.client._token, + 'session_id': self.session_id, + 'seq': self.sequence, + } + ) + ) @staticmethod def generate_presence(activity: BotActivity = None, status: str = None): @@ -199,7 +224,7 @@ class Gateway: return data async def update_presence(self, data: dict): - await self.send(self.PRESENCE_UPDATE, data) + await self.send(self.opcode(self.PRESENCE_UPDATE, data)) @staticmethod def opcode(opcode: int, payload) -> str: diff --git a/melisa/exceptions.py b/melisa/exceptions.py index 8481fbb..3985905 100644 --- a/melisa/exceptions.py +++ b/melisa/exceptions.py @@ -8,11 +8,6 @@ class ClientException(MelisaException): pass -class InvalidPayload(MelisaException): - """This exception means invalid payload""" - pass - - class LoginFailure(ClientException): """Fails to log you in from improper credentials or some other misc.""" pass diff --git a/melisa/models/app/shard.py b/melisa/models/app/shard.py index 85475dc..1cf37ec 100644 --- a/melisa/models/app/shard.py +++ b/melisa/models/app/shard.py @@ -45,14 +45,15 @@ class Shard: self.disconnected = False - create_task(self._gateway.start_loop()) + create_task(self._gateway.connect()) return self - async def _try_close(self) -> None: - if self._gateway.connected: - self._gateway.connected = False - await self._gateway.close(code=1000) + async def close(self): + """|coro| + Disconnect shard + """ + create_task(self._gateway.close()) async def update_presence(self, activity: BotActivity = None, status: str = None) -> Shard: """ @@ -89,6 +90,6 @@ class Shard: wait_time: :class:`int` Reconnect after """ - await self._try_close() + await self.close() await sleep(wait_time) await self.launch() diff --git a/melisa/models/guild/guild.py b/melisa/models/guild/guild.py index 8f0b3bc..712e9c6 100644 --- a/melisa/models/guild/guild.py +++ b/melisa/models/guild/guild.py @@ -1,5 +1,6 @@ from ...utils import Snowflake + class Guild: pass