add base ratelimiter

This commit is contained in:
grey-cat-1908 2022-03-25 16:24:20 +03:00
parent bc17864e3a
commit 3d66eb79c5
3 changed files with 102 additions and 6 deletions

View file

@ -14,7 +14,10 @@ from melisa.exceptions import (NotModifiedError,
UnauthorizedError, UnauthorizedError,
HTTPException, HTTPException,
NotFoundError, NotFoundError,
MethodNotAllowedError) MethodNotAllowedError,
ServerError,
RateLimitError)
from .ratelimiter import RateLimiter
class HTTPClient: class HTTPClient:
@ -23,6 +26,7 @@ class HTTPClient:
def __init__(self, token: str, *, ttl: int = 5): def __init__(self, token: str, *, ttl: int = 5):
self.url: str = f"https://discord.com/api/v{self.API_VERSION}" self.url: str = f"https://discord.com/api/v{self.API_VERSION}"
self.max_ttl: int = ttl self.max_ttl: int = ttl
self.__rate_limiter = RateLimiter()
headers: Dict[str, str] = { headers: Dict[str, str] = {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -36,11 +40,18 @@ class HTTPClient:
401: UnauthorizedError(), 401: UnauthorizedError(),
403: ForbiddenError(), 403: ForbiddenError(),
404: NotFoundError(), 404: NotFoundError(),
405: MethodNotAllowedError() 405: MethodNotAllowedError(),
429: RateLimitError()
} }
self.__aiohttp_session: ClientSession = ClientSession(headers=headers) self.__aiohttp_session: ClientSession = ClientSession(headers=headers)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
await self.close()
async def close(self): async def close(self):
"""Close the aiohttp session""" """Close the aiohttp session"""
await self.__aiohttp_session.close() await self.__aiohttp_session.close()
@ -57,11 +68,18 @@ class HTTPClient:
ttl = _ttl or self.max_ttl ttl = _ttl or self.max_ttl
if ttl == 0:
raise ServerError(f"Maximum amount of retries for `{endpoint}`.")
await self.__rate_limiter.wait_until_not_ratelimited(
endpoint,
method
)
url = f"{self.url}/{endpoint}" url = f"{self.url}/{endpoint}"
if ttl != 0: async with self.__aiohttp_session.request(method, url, **kwargs) as response:
async with self.__aiohttp_session.request(method, url, **kwargs) as response: return await self.__handle_response(response, method, endpoint, _ttl=ttl, **kwargs)
return await self.__handle_response(response, method, endpoint, _ttl=ttl, **kwargs)
async def __handle_response( async def __handle_response(
self, self,
@ -74,12 +92,26 @@ class HTTPClient:
) -> Optional[Dict]: ) -> Optional[Dict]:
"""Handle responses from the Discord API.""" """Handle responses from the Discord API."""
self.__rate_limiter.save_response_bucket(
endpoint, method, res.headers
)
if res.ok: if res.ok:
return await res.json() return await res.json()
exception = self.__http_exceptions.get(res.status) exception = self.__http_exceptions.get(res.status)
if exception: if exception:
if isinstance(exception, RateLimitError):
timeout = (await res.json()).get("retry_after", 40)
await asyncio.sleep(timeout)
return await self.__send(
method,
endpoint,
**kwargs
)
exception.__init__(res.reason) exception.__init__(res.reason)
raise exception raise exception

View file

@ -0,0 +1,64 @@
# Copyright MelisaDev 2022 - Present
# Full MIT License can be found in `LICENSE.txt` at the project root.
from __future__ import annotations
from asyncio import sleep
from dataclasses import dataclass
from time import time
from typing import Dict, Tuple, Any
@dataclass
class RateLimitBucket:
"""Represents a rate limit bucket"""
limit: int
remaining: int
reset: float
reset_after_timestamp: float
since_timestamp: float
class RateLimiter:
"""Prevents ``user`` rate limits"""
def __init__(self) -> None:
self.bucket_map: Dict[Tuple[str, str], str] = {} # Dict[Tuple[endpoint, method], bucket_id]
self.buckets: Dict[str, RateLimitBucket] = {}
def save_response_bucket(
self,
endpoint: str,
method: str,
header: Any
):
ratelimit_bucket_id = header.get("X-RateLimit-Bucket")
if not ratelimit_bucket_id:
return
self.bucket_map[endpoint, method] = ratelimit_bucket_id
self.buckets[ratelimit_bucket_id] = RateLimitBucket(
limit=int(header["X-RateLimit-Limit"]),
remaining=int(header["X-RateLimit-Remaining"]),
reset=float(header["X-RateLimit-Reset"]),
reset_after_timestamp=float(header["X-RateLimit-Reset-After"]),
since_timestamp=time()
)
async def wait_until_not_ratelimited(
self,
endpoint: str,
method: str
):
bucket_id = self.bucket_map.get((endpoint, method))
if not bucket_id:
return
bucket = self.buckets[bucket_id]
if bucket.remaining == 0:
sleep_time = time() - bucket.since_timestamp + bucket.reset_after_timestamp
await sleep(sleep_time)

View file

@ -1 +1 @@
aiohttp~=3.8.1 aiohttp