mirror of
https://github.com/MelisaDev/melisa.git
synced 2024-11-11 19:07:28 +03:00
add base ratelimiter
This commit is contained in:
parent
bc17864e3a
commit
3d66eb79c5
3 changed files with 102 additions and 6 deletions
|
@ -14,7 +14,10 @@ from melisa.exceptions import (NotModifiedError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
HTTPException,
|
HTTPException,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
MethodNotAllowedError)
|
MethodNotAllowedError,
|
||||||
|
ServerError,
|
||||||
|
RateLimitError)
|
||||||
|
from .ratelimiter import RateLimiter
|
||||||
|
|
||||||
|
|
||||||
class HTTPClient:
|
class HTTPClient:
|
||||||
|
@ -23,6 +26,7 @@ class HTTPClient:
|
||||||
def __init__(self, token: str, *, ttl: int = 5):
|
def __init__(self, token: str, *, ttl: int = 5):
|
||||||
self.url: str = f"https://discord.com/api/v{self.API_VERSION}"
|
self.url: str = f"https://discord.com/api/v{self.API_VERSION}"
|
||||||
self.max_ttl: int = ttl
|
self.max_ttl: int = ttl
|
||||||
|
self.__rate_limiter = RateLimiter()
|
||||||
|
|
||||||
headers: Dict[str, str] = {
|
headers: Dict[str, str] = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
@ -36,11 +40,18 @@ class HTTPClient:
|
||||||
401: UnauthorizedError(),
|
401: UnauthorizedError(),
|
||||||
403: ForbiddenError(),
|
403: ForbiddenError(),
|
||||||
404: NotFoundError(),
|
404: NotFoundError(),
|
||||||
405: MethodNotAllowedError()
|
405: MethodNotAllowedError(),
|
||||||
|
429: RateLimitError()
|
||||||
}
|
}
|
||||||
|
|
||||||
self.__aiohttp_session: ClientSession = ClientSession(headers=headers)
|
self.__aiohttp_session: ClientSession = ClientSession(headers=headers)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
await self.close()
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Close the aiohttp session"""
|
"""Close the aiohttp session"""
|
||||||
await self.__aiohttp_session.close()
|
await self.__aiohttp_session.close()
|
||||||
|
@ -57,11 +68,18 @@ class HTTPClient:
|
||||||
|
|
||||||
ttl = _ttl or self.max_ttl
|
ttl = _ttl or self.max_ttl
|
||||||
|
|
||||||
|
if ttl == 0:
|
||||||
|
raise ServerError(f"Maximum amount of retries for `{endpoint}`.")
|
||||||
|
|
||||||
|
await self.__rate_limiter.wait_until_not_ratelimited(
|
||||||
|
endpoint,
|
||||||
|
method
|
||||||
|
)
|
||||||
|
|
||||||
url = f"{self.url}/{endpoint}"
|
url = f"{self.url}/{endpoint}"
|
||||||
|
|
||||||
if ttl != 0:
|
async with self.__aiohttp_session.request(method, url, **kwargs) as response:
|
||||||
async with self.__aiohttp_session.request(method, url, **kwargs) as response:
|
return await self.__handle_response(response, method, endpoint, _ttl=ttl, **kwargs)
|
||||||
return await self.__handle_response(response, method, endpoint, _ttl=ttl, **kwargs)
|
|
||||||
|
|
||||||
async def __handle_response(
|
async def __handle_response(
|
||||||
self,
|
self,
|
||||||
|
@ -74,12 +92,26 @@ class HTTPClient:
|
||||||
) -> Optional[Dict]:
|
) -> Optional[Dict]:
|
||||||
"""Handle responses from the Discord API."""
|
"""Handle responses from the Discord API."""
|
||||||
|
|
||||||
|
self.__rate_limiter.save_response_bucket(
|
||||||
|
endpoint, method, res.headers
|
||||||
|
)
|
||||||
|
|
||||||
if res.ok:
|
if res.ok:
|
||||||
return await res.json()
|
return await res.json()
|
||||||
|
|
||||||
exception = self.__http_exceptions.get(res.status)
|
exception = self.__http_exceptions.get(res.status)
|
||||||
|
|
||||||
if exception:
|
if exception:
|
||||||
|
if isinstance(exception, RateLimitError):
|
||||||
|
timeout = (await res.json()).get("retry_after", 40)
|
||||||
|
|
||||||
|
await asyncio.sleep(timeout)
|
||||||
|
return await self.__send(
|
||||||
|
method,
|
||||||
|
endpoint,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
exception.__init__(res.reason)
|
exception.__init__(res.reason)
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
|
|
64
melisa/core/ratelimiter.py
Normal file
64
melisa/core/ratelimiter.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright MelisaDev 2022 - Present
|
||||||
|
# Full MIT License can be found in `LICENSE.txt` at the project root.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from asyncio import sleep
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from time import time
|
||||||
|
from typing import Dict, Tuple, Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RateLimitBucket:
|
||||||
|
"""Represents a rate limit bucket"""
|
||||||
|
limit: int
|
||||||
|
remaining: int
|
||||||
|
reset: float
|
||||||
|
reset_after_timestamp: float
|
||||||
|
since_timestamp: float
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Prevents ``user`` rate limits"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.bucket_map: Dict[Tuple[str, str], str] = {} # Dict[Tuple[endpoint, method], bucket_id]
|
||||||
|
self.buckets: Dict[str, RateLimitBucket] = {}
|
||||||
|
|
||||||
|
def save_response_bucket(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
method: str,
|
||||||
|
header: Any
|
||||||
|
):
|
||||||
|
ratelimit_bucket_id = header.get("X-RateLimit-Bucket")
|
||||||
|
|
||||||
|
if not ratelimit_bucket_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.bucket_map[endpoint, method] = ratelimit_bucket_id
|
||||||
|
|
||||||
|
self.buckets[ratelimit_bucket_id] = RateLimitBucket(
|
||||||
|
limit=int(header["X-RateLimit-Limit"]),
|
||||||
|
remaining=int(header["X-RateLimit-Remaining"]),
|
||||||
|
reset=float(header["X-RateLimit-Reset"]),
|
||||||
|
reset_after_timestamp=float(header["X-RateLimit-Reset-After"]),
|
||||||
|
since_timestamp=time()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def wait_until_not_ratelimited(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
method: str
|
||||||
|
):
|
||||||
|
bucket_id = self.bucket_map.get((endpoint, method))
|
||||||
|
|
||||||
|
if not bucket_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
bucket = self.buckets[bucket_id]
|
||||||
|
|
||||||
|
if bucket.remaining == 0:
|
||||||
|
sleep_time = time() - bucket.since_timestamp + bucket.reset_after_timestamp
|
||||||
|
await sleep(sleep_time)
|
|
@ -1 +1 @@
|
||||||
aiohttp~=3.8.1
|
aiohttp
|
Loading…
Reference in a new issue