mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-11-11 19:07:28 +03:00
add base ratelimiter
This commit is contained in:
parent
bc17864e3a
commit
3d66eb79c5
3 changed files with 102 additions and 6 deletions
|
@ -14,7 +14,10 @@ from melisa.exceptions import (NotModifiedError,
|
|||
UnauthorizedError,
|
||||
HTTPException,
|
||||
NotFoundError,
|
||||
MethodNotAllowedError)
|
||||
MethodNotAllowedError,
|
||||
ServerError,
|
||||
RateLimitError)
|
||||
from .ratelimiter import RateLimiter
|
||||
|
||||
|
||||
class HTTPClient:
|
||||
|
@ -23,6 +26,7 @@ class HTTPClient:
|
|||
def __init__(self, token: str, *, ttl: int = 5):
|
||||
self.url: str = f"https://discord.com/api/v{self.API_VERSION}"
|
||||
self.max_ttl: int = ttl
|
||||
self.__rate_limiter = RateLimiter()
|
||||
|
||||
headers: Dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
|
@ -36,11 +40,18 @@ class HTTPClient:
|
|||
401: UnauthorizedError(),
|
||||
403: ForbiddenError(),
|
||||
404: NotFoundError(),
|
||||
405: MethodNotAllowedError()
|
||||
405: MethodNotAllowedError(),
|
||||
429: RateLimitError()
|
||||
}
|
||||
|
||||
self.__aiohttp_session: ClientSession = ClientSession(headers=headers)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.close()
|
||||
|
||||
async def close(self):
|
||||
"""Close the aiohttp session"""
|
||||
await self.__aiohttp_session.close()
|
||||
|
@ -57,11 +68,18 @@ class HTTPClient:
|
|||
|
||||
ttl = _ttl or self.max_ttl
|
||||
|
||||
if ttl == 0:
|
||||
raise ServerError(f"Maximum amount of retries for `{endpoint}`.")
|
||||
|
||||
await self.__rate_limiter.wait_until_not_ratelimited(
|
||||
endpoint,
|
||||
method
|
||||
)
|
||||
|
||||
url = f"{self.url}/{endpoint}"
|
||||
|
||||
if ttl != 0:
|
||||
async with self.__aiohttp_session.request(method, url, **kwargs) as response:
|
||||
return await self.__handle_response(response, method, endpoint, _ttl=ttl, **kwargs)
|
||||
async with self.__aiohttp_session.request(method, url, **kwargs) as response:
|
||||
return await self.__handle_response(response, method, endpoint, _ttl=ttl, **kwargs)
|
||||
|
||||
async def __handle_response(
|
||||
self,
|
||||
|
@ -74,12 +92,26 @@ class HTTPClient:
|
|||
) -> Optional[Dict]:
|
||||
"""Handle responses from the Discord API."""
|
||||
|
||||
self.__rate_limiter.save_response_bucket(
|
||||
endpoint, method, res.headers
|
||||
)
|
||||
|
||||
if res.ok:
|
||||
return await res.json()
|
||||
|
||||
exception = self.__http_exceptions.get(res.status)
|
||||
|
||||
if exception:
|
||||
if isinstance(exception, RateLimitError):
|
||||
timeout = (await res.json()).get("retry_after", 40)
|
||||
|
||||
await asyncio.sleep(timeout)
|
||||
return await self.__send(
|
||||
method,
|
||||
endpoint,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
exception.__init__(res.reason)
|
||||
raise exception
|
||||
|
||||
|
|
64
melisa/core/ratelimiter.py
Normal file
64
melisa/core/ratelimiter.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
# Copyright MelisaDev 2022 - Present
|
||||
# Full MIT License can be found in `LICENSE.txt` at the project root.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import sleep
|
||||
from dataclasses import dataclass
|
||||
from time import time
|
||||
from typing import Dict, Tuple, Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitBucket:
|
||||
"""Represents a rate limit bucket"""
|
||||
limit: int
|
||||
remaining: int
|
||||
reset: float
|
||||
reset_after_timestamp: float
|
||||
since_timestamp: float
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Prevents ``user`` rate limits"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.bucket_map: Dict[Tuple[str, str], str] = {} # Dict[Tuple[endpoint, method], bucket_id]
|
||||
self.buckets: Dict[str, RateLimitBucket] = {}
|
||||
|
||||
def save_response_bucket(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str,
|
||||
header: Any
|
||||
):
|
||||
ratelimit_bucket_id = header.get("X-RateLimit-Bucket")
|
||||
|
||||
if not ratelimit_bucket_id:
|
||||
return
|
||||
|
||||
self.bucket_map[endpoint, method] = ratelimit_bucket_id
|
||||
|
||||
self.buckets[ratelimit_bucket_id] = RateLimitBucket(
|
||||
limit=int(header["X-RateLimit-Limit"]),
|
||||
remaining=int(header["X-RateLimit-Remaining"]),
|
||||
reset=float(header["X-RateLimit-Reset"]),
|
||||
reset_after_timestamp=float(header["X-RateLimit-Reset-After"]),
|
||||
since_timestamp=time()
|
||||
)
|
||||
|
||||
async def wait_until_not_ratelimited(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str
|
||||
):
|
||||
bucket_id = self.bucket_map.get((endpoint, method))
|
||||
|
||||
if not bucket_id:
|
||||
return
|
||||
|
||||
bucket = self.buckets[bucket_id]
|
||||
|
||||
if bucket.remaining == 0:
|
||||
sleep_time = time() - bucket.since_timestamp + bucket.reset_after_timestamp
|
||||
await sleep(sleep_time)
|
|
@ -1 +1 @@
|
|||
aiohttp~=3.8.1
|
||||
aiohttp
|
Loading…
Reference in a new issue