376 lines
12 KiB
Python
376 lines
12 KiB
Python
"""prawcore.sessions: Provides prawcore.Session and prawcore.session."""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import random
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from copy import deepcopy
|
|
from pprint import pformat
|
|
from typing import TYPE_CHECKING, Any
|
|
from urllib.parse import urljoin
|
|
|
|
from requests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout
|
|
from requests.status_codes import codes
|
|
|
|
from .auth import BaseAuthorizer
|
|
from .const import TIMEOUT, WINDOW_SIZE
|
|
from .exceptions import (
|
|
BadJSON,
|
|
BadRequest,
|
|
Conflict,
|
|
InvalidInvocation,
|
|
NotFound,
|
|
Redirect,
|
|
RequestException,
|
|
ServerError,
|
|
SpecialError,
|
|
TooLarge,
|
|
TooManyRequests,
|
|
UnavailableForLegalReasons,
|
|
URITooLong,
|
|
)
|
|
from .rate_limit import RateLimiter
|
|
from .util import authorization_error_class
|
|
|
|
if TYPE_CHECKING:
|
|
from io import BufferedReader
|
|
|
|
from requests.models import Response
|
|
|
|
from .auth import Authorizer
|
|
from .requestor import Requestor
|
|
|
|
log = logging.getLogger(__package__)
|
|
|
|
|
|
class RetryStrategy(ABC):
    """Base class that decides how long to pause before a request attempt.

    Subclasses supply the pause length through :meth:`_sleep_seconds`; together
    with their own bookkeeping this controls both the number and the frequency of
    retry attempts.

    Instances of this class are immutable.

    """

    @abstractmethod
    def _sleep_seconds(self) -> float | None:
        """Return the pause in seconds before the next attempt, or ``None``."""

    def sleep(self) -> None:
        """Pause for the strategy's delay, if there is one, before the request."""
        delay = self._sleep_seconds()
        if delay is None:
            # No delay requested: attempt the request immediately.
            return
        log.debug(f"Sleeping: {delay:0.2f} seconds prior to retry")
        time.sleep(delay)
|
|
|
|
|
|
class Session:
    """The low-level connection interface to Reddit's API."""

    # Transport-level exceptions that may be retried while retries remain.
    RETRY_EXCEPTIONS = (ChunkedEncodingError, ConnectionError, ReadTimeout)
    # Response status codes that may be retried while retries remain.
    RETRY_STATUSES = {
        520,
        522,
        codes["bad_gateway"],
        codes["gateway_timeout"],
        codes["internal_server_error"],
        codes["request_timeout"],
        codes["service_unavailable"],
    }
    # Status code -> callable invoked with the response to build the exception
    # that is raised once the response is not (or no longer) retried.
    STATUS_EXCEPTIONS = {
        codes["bad_gateway"]: ServerError,
        codes["bad_request"]: BadRequest,
        codes["conflict"]: Conflict,
        codes["found"]: Redirect,
        codes["forbidden"]: authorization_error_class,
        codes["gateway_timeout"]: ServerError,
        codes["internal_server_error"]: ServerError,
        codes["media_type"]: SpecialError,
        codes["moved_permanently"]: Redirect,
        codes["not_found"]: NotFound,
        codes["request_entity_too_large"]: TooLarge,
        codes["request_uri_too_large"]: URITooLong,
        codes["service_unavailable"]: ServerError,
        codes["too_many_requests"]: TooManyRequests,
        codes["unauthorized"]: authorization_error_class,
        codes["unavailable_for_legal_reasons"]: UnavailableForLegalReasons,
        # Cloudflare's statuses (not named in requests)
        520: ServerError,
        522: ServerError,
    }
    # Status codes whose body is parsed and returned to the caller.
    SUCCESS_STATUSES = {codes["accepted"], codes["created"], codes["ok"]}

    @staticmethod
    def _log_request(
        data: list[tuple[str, str]] | None,
        method: str,
        params: dict[str, int],
        url: str,
    ) -> None:
        """Emit debug log records describing the request about to be sent."""
        log.debug("Fetching: %s %s at %s", method, url, time.time())
        log.debug("Data: %s", pformat(data))
        log.debug("Params: %s", pformat(params))

    @property
    def _requestor(self) -> Requestor:
        """Return the requestor reached through the authorizer's authenticator."""
        return self._authorizer._authenticator._requestor

    def __enter__(self) -> Session:  # noqa: PYI034
        """Allow this object to be used as a context manager."""
        return self

    def __exit__(self, *_args: object) -> None:
        """Allow this object to be used as a context manager."""
        self.close()

    def __init__(
        self,
        authorizer: BaseAuthorizer | None,
        window_size: int = WINDOW_SIZE,
    ) -> None:
        """Prepare the connection to Reddit's API.

        :param authorizer: An instance of :class:`.Authorizer`.
        :param window_size: The size of the rate limit reset window in seconds.

        :raises: :class:`.InvalidInvocation` if ``authorizer`` is not a
            ``BaseAuthorizer`` instance (which includes ``None``).

        """
        if not isinstance(authorizer, BaseAuthorizer):
            msg = f"invalid Authorizer: {authorizer}"
            raise InvalidInvocation(msg)
        self._authorizer = authorizer
        self._rate_limiter = RateLimiter(window_size=window_size)
        self._retry_strategy_class = FiniteRetryStrategy

    def _do_retry(
        self,
        data: list[tuple[str, Any]],
        files: dict[str, BufferedReader],
        json: dict[str, Any],
        method: str,
        params: dict[str, int],
        response: Response | None,
        retry_strategy_state: FiniteRetryStrategy,
        saved_exception: Exception | None,
        timeout: float,
        url: str,
    ) -> dict[str, Any] | str | None:
        """Log why the request is retried and issue it again with one fewer retry.

        Exactly one of ``response`` and ``saved_exception`` is expected to be set:
        the exception is reported when present, otherwise the response's status
        code.

        """
        # Compare against ``None`` explicitly so the choice does not depend on
        # the truthiness of the exception object.
        status = (
            repr(saved_exception)
            if saved_exception is not None
            else response.status_code
        )
        log.warning("Retrying due to %s status: %s %s", status, method, url)
        return self._request_with_retries(
            data=data,
            files=files,
            json=json,
            method=method,
            params=params,
            timeout=timeout,
            url=url,
            retry_strategy_state=retry_strategy_state.consume_available_retry(),
        )

    def _make_request(
        self,
        data: list[tuple[str, Any]],
        files: dict[str, BufferedReader],
        json: dict[str, Any],
        method: str,
        params: dict[str, Any],
        retry_strategy_state: FiniteRetryStrategy,
        timeout: float,
        url: str,
    ) -> tuple[Response, None] | tuple[None, Exception]:
        """Send one request through the rate limiter.

        :returns: ``(response, None)`` when a response was received, or ``(None,
            original_exception)`` when the request failed with one of
            ``RETRY_EXCEPTIONS`` and the strategy still allows a retry.

        :raises: :class:`.RequestException` when the failure is not retryable or
            no retry remains.

        """
        try:
            response = self._rate_limiter.call(
                self._requestor.request,
                self._set_header_callback,
                method,
                url,
                allow_redirects=False,
                data=data,
                files=files,
                json=json,
                params=params,
                timeout=timeout,
            )
            log.debug(
                "Response: %s (%s bytes) (rst-%s:rem-%s:used-%s ratelimit) at %s",
                response.status_code,
                response.headers.get("content-length"),
                response.headers.get("x-ratelimit-reset"),
                response.headers.get("x-ratelimit-remaining"),
                response.headers.get("x-ratelimit-used"),
                time.time(),
            )
            return response, None
        except RequestException as exception:
            if not retry_strategy_state.should_retry_on_failure() or not isinstance(
                exception.original_exception, self.RETRY_EXCEPTIONS
            ):
                raise
            # Hand the underlying exception back so the caller can retry.
            return None, exception.original_exception

    def _request_with_retries(
        self,
        data: list[tuple[str, Any]],
        files: dict[str, BufferedReader],
        json: dict[str, Any],
        method: str,
        params: dict[str, int],
        timeout: float,
        url: str,
        retry_strategy_state: FiniteRetryStrategy | None = None,
    ) -> dict[str, Any] | str | None:
        """Issue the request, retrying as the retry strategy permits.

        :param retry_strategy_state: The strategy state for this attempt; a fresh
            instance of the session's retry strategy class is created when
            ``None``.

        :returns: ``None`` for a "no content" response, ``""`` when the response
            reports a ``content-length`` of ``"0"``, otherwise the decoded JSON
            body.

        :raises: The exception mapped in ``STATUS_EXCEPTIONS`` for the response's
            status code, :class:`.BadJSON` if the body cannot be decoded, or
            ``AssertionError`` for a status code that is neither mapped nor a
            success status.

        """
        if retry_strategy_state is None:
            retry_strategy_state = self._retry_strategy_class()

        retry_strategy_state.sleep()
        self._log_request(data, method, params, url)
        response, saved_exception = self._make_request(
            data,
            files,
            json,
            method,
            params,
            retry_strategy_state,
            timeout,
            url,
        )

        do_retry = False
        if response is not None and response.status_code == codes["unauthorized"]:
            # Drop the rejected token; retry only when the authorizer is able to
            # obtain a new one.
            self._authorizer._clear_access_token()
            if hasattr(self._authorizer, "refresh"):
                do_retry = True

        if retry_strategy_state.should_retry_on_failure() and (
            do_retry or response is None or response.status_code in self.RETRY_STATUSES
        ):
            return self._do_retry(
                data,
                files,
                json,
                method,
                params,
                response,
                retry_strategy_state,
                saved_exception,
                timeout,
                url,
            )
        if response.status_code in self.STATUS_EXCEPTIONS:
            raise self.STATUS_EXCEPTIONS[response.status_code](response)
        if response.status_code == codes["no_content"]:
            return None
        if response.status_code not in self.SUCCESS_STATUSES:
            # Raised explicitly instead of using ``assert`` so the check is not
            # stripped when Python runs with ``-O``.
            msg = f"Unexpected status code: {response.status_code}"
            raise AssertionError(msg)
        if response.headers.get("content-length") == "0":
            return ""
        try:
            return response.json()
        except ValueError:
            raise BadJSON(response) from None

    def _set_header_callback(self) -> dict[str, str]:
        """Return the ``Authorization`` header, refreshing the token if possible.

        The authorizer is refreshed first when it reports that it is not valid
        and it provides a ``refresh`` method.

        """
        if not self._authorizer.is_valid() and hasattr(self._authorizer, "refresh"):
            self._authorizer.refresh()
        return {"Authorization": f"bearer {self._authorizer.access_token}"}

    def close(self) -> None:
        """Close the session and perform any clean up."""
        self._requestor.close()

    def request(
        self,
        method: str,
        path: str,
        data: dict[str, Any] | None = None,
        files: dict[str, BufferedReader] | None = None,
        json: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
        timeout: float = TIMEOUT,
    ) -> dict[str, Any] | str | None:
        """Return the json content from the resource at ``path``.

        :param method: The request verb. E.g., ``"GET"``, ``"POST"``, ``"PUT"``.
        :param path: The path of the request. This path will be combined with the
            ``oauth_url`` of the Requestor.
        :param data: Dictionary, bytes, or file-like object to send in the body of the
            request.
        :param files: Dictionary, mapping ``filename`` to file-like object.
        :param json: Object to be serialized to JSON in the body of the request.
        :param params: The query parameters to send with the request.
        :param timeout: Specifies a particular timeout, in seconds.

        Automatically refreshes the access token if it becomes invalid and a refresh
        token is available.

        The caller's ``params``, ``data`` and ``json`` objects are never mutated:
        they are deep-copied before ``raw_json`` / ``api_type`` are added.

        :raises: :class:`.InvalidInvocation` in such a case if a refresh token is not
            available.

        """
        # ``deepcopy(None)`` is ``None``, so a missing mapping becomes a new dict.
        params = deepcopy(params) or {}
        params["raw_json"] = 1
        if isinstance(data, dict):
            data = deepcopy(data)
            data["api_type"] = "json"
            # Sent as a list of pairs in a deterministic (sorted) order.
            data = sorted(data.items())
        if isinstance(json, dict):
            json = deepcopy(json)
            json["api_type"] = "json"
        url = urljoin(self._requestor.oauth_url, path)
        return self._request_with_retries(
            data=data,
            files=files,
            json=json,
            method=method,
            params=params,
            timeout=timeout,
            url=url,
        )
|
|
|
|
|
|
def session(
    authorizer: Authorizer | None = None,
    window_size: int = WINDOW_SIZE,
) -> Session:
    """Return a :class:`.Session` instance.

    :param authorizer: An instance of :class:`.Authorizer`. Although the parameter
        defaults to ``None``, :class:`.Session` rejects any value that is not a
        ``BaseAuthorizer`` by raising :class:`.InvalidInvocation`.
    :param window_size: The size of the rate limit reset window in seconds.

    """
    return Session(authorizer=authorizer, window_size=window_size)
|
|
|
|
|
|
class FiniteRetryStrategy(RetryStrategy):
    """A ``RetryStrategy`` that retries requests a finite number of times."""

    def __init__(self, retries: int = 3) -> None:
        """Initialize the strategy.

        :param retries: Number of times to attempt a request (default: ``3``).

        """
        self._retries = retries

    def _sleep_seconds(self) -> float | None:
        """Return the randomized pause before the next attempt, if any.

        No pause is requested while three or more attempts remain. With exactly
        two remaining the pause is drawn from ``[0, 2)`` seconds; with fewer it is
        drawn from ``[2, 4)`` seconds.

        """
        if self._retries >= 3:
            return None
        jitter = 2 * random.random()  # noqa: S311
        if self._retries == 2:
            return jitter
        return 2 + jitter

    def consume_available_retry(self) -> FiniteRetryStrategy:
        """Allow one fewer retry."""
        remaining = self._retries - 1
        # Build a new instance rather than mutating this one.
        return type(self)(remaining)

    def should_retry_on_failure(self) -> bool:
        """Return ``True`` if and only if the strategy will allow another retry."""
        return self._retries > 1
|