Initial commit

This commit is contained in:
2026-02-01 09:31:38 +01:00
commit e02db93960
4396 changed files with 1511612 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
""""Low-level communication layer for PRAW 4+."""
import logging
from .auth import (
Authorizer,
DeviceIDAuthorizer,
ImplicitAuthorizer,
ReadOnlyAuthorizer,
ScriptAuthorizer,
TrustedAuthenticator,
UntrustedAuthenticator,
)
from .const import __version__
from .exceptions import * # noqa: F403
from .requestor import Requestor
from .sessions import Session, session
logging.getLogger(__package__).addHandler(logging.NullHandler())

View File

@@ -0,0 +1,477 @@
"""Provides Authentication and Authorization classes."""
from __future__ import annotations
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable
from requests import Request
from requests.status_codes import codes
from . import const
from .exceptions import InvalidInvocation, OAuthException, ResponseException
if TYPE_CHECKING:
from requests.models import Response
from prawcore.requestor import Requestor
class BaseAuthenticator(ABC):
"""Provide the base authenticator object that stores OAuth2 credentials."""
@abstractmethod
def _auth(self):
pass
def __init__(
self,
requestor: Requestor,
client_id: str,
redirect_uri: str | None = None,
) -> None:
"""Represent a single authentication to Reddit's API.
:param requestor: An instance of :class:`.Requestor`.
:param client_id: The OAuth2 client ID to use with the session.
:param redirect_uri: The redirect URI exactly as specified in your OAuth
application settings on Reddit. This parameter is required if you want to
use the :meth:`~.Authorizer.authorize_url` method, or the
:meth:`~.Authorizer.authorize` method of the :class:`.Authorizer` class
(default: ``None``).
"""
self._requestor = requestor
self.client_id = client_id
self.redirect_uri = redirect_uri
def _post(
self, url: str, success_status: int = codes["ok"], **data: Any
) -> Response:
response = self._requestor.request(
"post",
url,
auth=self._auth(),
data=sorted(data.items()),
headers={"Connection": "close"},
)
if response.status_code != success_status:
raise ResponseException(response)
return response
def authorize_url(
self,
duration: str,
scopes: list[str],
state: str,
implicit: bool = False,
) -> str:
"""Return the URL used out-of-band to grant access to your application.
:param duration: Either ``"permanent"`` or ``"temporary"``. ``"temporary"``
authorizations generate access tokens that last only 1 hour. ``"permanent"``
authorizations additionally generate a refresh token that can be
indefinitely used to generate new hour-long access tokens. Only
``"temporary"`` can be specified if ``implicit`` is set to ``True``.
:param scopes: A list of OAuth scopes to request authorization for.
:param state: A string that will be reflected in the callback to
``redirect_uri``. Elements must be printable ASCII characters in the range
``0x20`` through ``0x7E`` inclusive. This value should be temporarily unique
to the client for whom the URL was generated.
:param implicit: Use the implicit grant flow (default: ``False``). This flow is
only available for ``UntrustedAuthenticators``.
:returns: URL to be used out-of-band for granting access to your application.
:raises: :class:`.InvalidInvocation` if ``redirect_uri`` is not provided, if
``implicit`` is ``True`` and an authenticator other than
:class:`.UntrustedAuthenticator` is used, or ``implicit`` is ``True`` and
``duration`` is ``"permanent"``.
"""
if self.redirect_uri is None:
msg = "redirect URI not provided"
raise InvalidInvocation(msg)
if implicit and not isinstance(self, UntrustedAuthenticator):
msg = (
"Only UntrustedAuthenticator instances can use the implicit grant flow."
)
raise InvalidInvocation(msg)
if implicit and duration != "temporary":
msg = "The implicit grant flow only supports temporary access tokens."
raise InvalidInvocation(msg)
params = {
"client_id": self.client_id,
"duration": duration,
"redirect_uri": self.redirect_uri,
"response_type": "token" if implicit else "code",
"scope": " ".join(scopes),
"state": state,
}
url = self._requestor.reddit_url + const.AUTHORIZATION_PATH
request = Request("GET", url, params=params)
return request.prepare().url
def revoke_token(self, token: str, token_type: str | None = None) -> None:
"""Ask Reddit to revoke the provided token.
:param token: The access or refresh token to revoke.
:param token_type: When provided, hint to Reddit what the token type is for a
possible efficiency gain. The value can be either ``"access_token"`` or
``"refresh_token"``.
"""
data = {"token": token}
if token_type is not None:
data["token_type_hint"] = token_type
url = self._requestor.reddit_url + const.REVOKE_TOKEN_PATH
self._post(url, **data)
class BaseAuthorizer(ABC):
"""Superclass for OAuth2 authorization tokens and scopes."""
AUTHENTICATOR_CLASS: tuple | type = BaseAuthenticator
def __init__(self, authenticator: BaseAuthenticator) -> None:
"""Represent a single authorization to Reddit's API.
:param authenticator: An instance of :class:`.BaseAuthenticator`.
"""
self._authenticator = authenticator
self._clear_access_token()
self._validate_authenticator()
def _clear_access_token(self) -> None:
self._expiration_timestamp: float
self.access_token: str | None = None
self.scopes: set[str] | None = None
def _request_token(self, **data: Any) -> None:
url = self._authenticator._requestor.reddit_url + const.ACCESS_TOKEN_PATH
pre_request_time = time.time()
response = self._authenticator._post(url=url, **data)
payload = response.json()
if "error" in payload: # Why are these OKAY responses?
raise OAuthException(
response, payload["error"], payload.get("error_description")
)
self._expiration_timestamp = pre_request_time - 10 + payload["expires_in"]
self.access_token = payload["access_token"]
if "refresh_token" in payload:
self.refresh_token = payload["refresh_token"]
self.scopes = set(payload["scope"].split(" "))
def _validate_authenticator(self) -> None:
if not isinstance(self._authenticator, self.AUTHENTICATOR_CLASS):
msg = "Must use an authenticator of type"
if isinstance(self.AUTHENTICATOR_CLASS, type):
msg += f" {self.AUTHENTICATOR_CLASS.__name__}."
else:
msg += (
f" {' or '.join([i.__name__ for i in self.AUTHENTICATOR_CLASS])}."
)
raise InvalidInvocation(msg)
def is_valid(self) -> bool:
"""Return whether the :class`.Authorizer` is ready to authorize requests.
A ``True`` return value does not guarantee that the ``access_token`` is actually
valid on the server side.
"""
return (
self.access_token is not None and time.time() < self._expiration_timestamp
)
def revoke(self) -> None:
"""Revoke the current Authorization."""
if self.access_token is None:
msg = "no token available to revoke"
raise InvalidInvocation(msg)
self._authenticator.revoke_token(self.access_token, "access_token")
self._clear_access_token()
class TrustedAuthenticator(BaseAuthenticator):
"""Store OAuth2 authentication credentials for web, or script type apps."""
RESPONSE_TYPE: str = "code"
def __init__(
self,
requestor: Requestor,
client_id: str,
client_secret: str,
redirect_uri: str | None = None,
) -> None:
"""Represent a single authentication to Reddit's API.
:param requestor: An instance of :class:`.Requestor`.
:param client_id: The OAuth2 client ID to use with the session.
:param client_secret: The OAuth2 client secret to use with the session.
:param redirect_uri: The redirect URI exactly as specified in your OAuth
application settings on Reddit. This parameter is required if you want to
use the :meth:`~.Authorizer.authorize_url` method, or the
:meth:`~.Authorizer.authorize` method of the :class:`.Authorizer` class
(default: ``None``).
"""
super().__init__(requestor, client_id, redirect_uri)
self.client_secret = client_secret
def _auth(self) -> tuple[str, str]:
return self.client_id, self.client_secret
class UntrustedAuthenticator(BaseAuthenticator):
"""Store OAuth2 authentication credentials for installed applications."""
def _auth(self) -> tuple[str, str]:
return self.client_id, ""
class Authorizer(BaseAuthorizer):
"""Manages OAuth2 authorization tokens and scopes."""
def __init__(
self,
authenticator: BaseAuthenticator,
*,
post_refresh_callback: Callable[[Authorizer], None] | None = None,
pre_refresh_callback: Callable[[Authorizer], None] | None = None,
refresh_token: str | None = None,
) -> None:
"""Represent a single authorization to Reddit's API.
:param authenticator: An instance of a subclass of :class:`.BaseAuthenticator`.
:param post_refresh_callback: When a single-argument function is passed, the
function will be called prior to refreshing the access and refresh tokens.
The argument to the callback is the :class:`.Authorizer` instance. This
callback can be used to inspect and modify the attributes of the
:class:`.Authorizer`.
:param pre_refresh_callback: When a single-argument function is passed, the
function will be called after refreshing the access and refresh tokens. The
argument to the callback is the :class:`.Authorizer` instance. This callback
can be used to inspect and modify the attributes of the
:class:`.Authorizer`.
:param refresh_token: Enables the ability to refresh the authorization.
"""
super().__init__(authenticator)
self._post_refresh_callback = post_refresh_callback
self._pre_refresh_callback = pre_refresh_callback
self.refresh_token = refresh_token
def authorize(self, code: str) -> None:
"""Obtain and set authorization tokens based on ``code``.
:param code: The code obtained by an out-of-band authorization request to
Reddit.
"""
if self._authenticator.redirect_uri is None:
msg = "redirect URI not provided"
raise InvalidInvocation(msg)
self._request_token(
code=code,
grant_type="authorization_code",
redirect_uri=self._authenticator.redirect_uri,
)
def refresh(self) -> None:
"""Obtain a new access token from the refresh_token."""
if self._pre_refresh_callback:
self._pre_refresh_callback(self)
if self.refresh_token is None:
msg = "refresh token not provided"
raise InvalidInvocation(msg)
self._request_token(
grant_type="refresh_token", refresh_token=self.refresh_token
)
if self._post_refresh_callback:
self._post_refresh_callback(self)
def revoke(self, only_access: bool = False) -> None:
"""Revoke the current Authorization.
:param only_access: When explicitly set to ``True``, do not evict the refresh
token if one is set.
Revoking a refresh token will in-turn revoke all access tokens associated with
that authorization.
"""
if only_access or self.refresh_token is None:
super().revoke()
else:
self._authenticator.revoke_token(self.refresh_token, "refresh_token")
self._clear_access_token()
self.refresh_token = None
class ImplicitAuthorizer(BaseAuthorizer):
"""Manages implicit installed-app type authorizations."""
AUTHENTICATOR_CLASS = UntrustedAuthenticator
def __init__(
self,
authenticator: UntrustedAuthenticator,
access_token: str,
expires_in: int,
scope: str,
) -> None:
"""Represent a single implicit authorization to Reddit's API.
:param authenticator: An instance of :class:`.UntrustedAuthenticator`.
:param access_token: The access_token obtained from Reddit via callback to the
authenticator's ``redirect_uri``.
:param expires_in: The number of seconds the ``access_token`` is valid for. The
origin of this value was returned from Reddit via callback to the
authenticator's redirect uri. Note, you may need to subtract an offset
before passing in this number to account for a delay between when Reddit
prepared the response, and when you make this function call.
:param scope: A space-delimited string of Reddit OAuth2 scope names as returned
from Reddit in the callback to the authenticator's redirect uri.
"""
super().__init__(authenticator)
self._expiration_timestamp = time.time() + expires_in
self.access_token = access_token
self.scopes = set(scope.split(" "))
class ReadOnlyAuthorizer(Authorizer):
"""Manages authorizations that are not associated with a Reddit account.
While the ``"*"`` scope will be available, some endpoints simply will not work due
to the lack of an associated Reddit account.
"""
AUTHENTICATOR_CLASS = TrustedAuthenticator
def __init__(
self,
authenticator: BaseAuthenticator,
scopes: list[str] | None = None,
) -> None:
"""Represent a ReadOnly authorization to Reddit's API.
:param scopes: A list of OAuth scopes to request authorization for (default:
``None``). The scope ``"*"`` is requested when the default argument is used.
"""
super().__init__(authenticator)
self._scopes = scopes
def refresh(self) -> None:
"""Obtain a new ReadOnly access token."""
additional_kwargs = {}
if self._scopes:
additional_kwargs["scope"] = " ".join(self._scopes)
self._request_token(grant_type="client_credentials", **additional_kwargs)
class ScriptAuthorizer(Authorizer):
"""Manages personal-use script type authorizations.
Only users who are listed as developers for the application will be granted access
tokens.
"""
AUTHENTICATOR_CLASS = TrustedAuthenticator
def __init__(
self,
authenticator: BaseAuthenticator,
username: str | None,
password: str | None,
two_factor_callback: Callable | None = None,
scopes: list[str] | None = None,
) -> None:
"""Represent a single personal-use authorization to Reddit's API.
:param authenticator: An instance of :class:`.TrustedAuthenticator`.
:param username: The Reddit username of one of the application's developers.
:param password: The password associated with ``username``.
:param two_factor_callback: A function that returns OTPs (One-Time Passcodes),
also known as 2FA auth codes. If this function is provided, prawcore will
call it when authenticating.
:param scopes: A list of OAuth scopes to request authorization for (default:
``None``). The scope ``"*"`` is requested when the default argument is used.
"""
super().__init__(authenticator)
self._password = password
self._scopes = scopes
self._two_factor_callback = two_factor_callback
self._username = username
def refresh(self) -> None:
"""Obtain a new personal-use script type access token."""
additional_kwargs = {}
if self._scopes:
additional_kwargs["scope"] = " ".join(self._scopes)
two_factor_code = self._two_factor_callback and self._two_factor_callback()
if two_factor_code:
additional_kwargs["otp"] = two_factor_code
self._request_token(
grant_type="password",
username=self._username,
password=self._password,
**additional_kwargs,
)
class DeviceIDAuthorizer(BaseAuthorizer):
"""Manages app-only OAuth2 for 'installed' applications.
While the ``"*"`` scope will be available, some endpoints simply will not work due
to the lack of an associated Reddit account.
"""
AUTHENTICATOR_CLASS = (TrustedAuthenticator, UntrustedAuthenticator)
def __init__(
self,
authenticator: BaseAuthenticator,
device_id: str | None = None,
scopes: list[str] | None = None,
) -> None:
"""Represent an app-only OAuth2 authorization for 'installed' apps.
:param authenticator: An instance of :class:`.UntrustedAuthenticator` or
:class:`.TrustedAuthenticator`.
:param device_id: A unique ID (20-30 character ASCII string) (default:
``None``). ``device_id`` is set to ``"DO_NOT_TRACK_THIS_DEVICE"`` when the
default argument is used. For more information about this parameter, see:
https://github.com/reddit/reddit/wiki/OAuth2#application-only-oauth
:param scopes: A list of OAuth scopes to request authorization for (default:
``None``). The scope ``"*"`` is requested when the default argument is used.
"""
if device_id is None:
device_id = "DO_NOT_TRACK_THIS_DEVICE"
super().__init__(authenticator)
self._device_id = device_id
self._scopes = scopes
def refresh(self) -> None:
"""Obtain a new access token."""
additional_kwargs = {}
if self._scopes:
additional_kwargs["scope"] = " ".join(self._scopes)
grant_type = "https://oauth.reddit.com/grants/installed_client"
self._request_token(
grant_type=grant_type,
device_id=self._device_id,
**additional_kwargs,
)

View File

@@ -0,0 +1,14 @@
"""Constants for the prawcore package."""
import os
__version__ = "2.4.0"
ACCESS_TOKEN_PATH = "/api/v1/access_token" # noqa: S105
AUTHORIZATION_PATH = "/api/v1/authorize" # noqa: S105
REVOKE_TOKEN_PATH = "/api/v1/revoke_token" # noqa: S105
TIMEOUT = float(
os.environ.get(
"PRAWCORE_TIMEOUT", os.environ.get("prawcore_timeout", 16) # noqa: SIM112
)
)
WINDOW_SIZE = 600

View File

@@ -0,0 +1,188 @@
"""Provide exception classes for the prawcore package."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
if TYPE_CHECKING:
from requests.models import Response
class PrawcoreException(Exception): # noqa: N818
"""Base exception class for exceptions that occur within this package."""
class InvalidInvocation(PrawcoreException):
"""Indicate that the code to execute cannot be completed."""
class OAuthException(PrawcoreException):
"""Indicate that there was an OAuth2 related error with the request."""
def __init__(
self, response: Response, error: str, description: str | None = None
) -> None:
"""Initialize a OAuthException instance.
:param response: A ``requests.response`` instance.
:param error: The error type returned by Reddit.
:param description: A description of the error when provided.
"""
self.error = error
self.description = description
self.response = response
message = f"{error} error processing request"
if description:
message += f" ({description})"
PrawcoreException.__init__(self, message)
class RequestException(PrawcoreException):
"""Indicate that there was an error with the incomplete HTTP request."""
def __init__(
self,
original_exception: Exception,
request_args: tuple[Any, ...],
request_kwargs: dict[
str, bool | (dict[str, int] | (dict[str, str] | str)) | None
],
) -> None:
"""Initialize a RequestException instance.
:param original_exception: The original exception that occurred.
:param request_args: The arguments to the request function.
:param request_kwargs: The keyword arguments to the request function.
"""
self.original_exception = original_exception
self.request_args = request_args
self.request_kwargs = request_kwargs
super().__init__(f"error with request {original_exception}")
class ResponseException(PrawcoreException):
"""Indicate that there was an error with the completed HTTP request."""
def __init__(self, response: Response) -> None:
"""Initialize a ResponseException instance.
:param response: A ``requests.response`` instance.
"""
self.response = response
super().__init__(f"received {response.status_code} HTTP response")
class BadJSON(ResponseException):
"""Indicate the response did not contain valid JSON."""
class BadRequest(ResponseException):
"""Indicate invalid parameters for the request."""
class Conflict(ResponseException):
"""Indicate a conflicting change in the target resource."""
class Forbidden(ResponseException):
"""Indicate the authentication is not permitted for the request."""
class InsufficientScope(ResponseException):
"""Indicate that the request requires a different scope."""
class InvalidToken(ResponseException):
"""Indicate that the request used an invalid access token."""
class NotFound(ResponseException):
"""Indicate that the requested URL was not found."""
class Redirect(ResponseException):
"""Indicate the request resulted in a redirect.
This class adds the attribute ``path``, which is the path to which the response
redirects.
"""
def __init__(self, response: Response) -> None:
"""Initialize a Redirect exception instance.
:param response: A ``requests.response`` instance containing a location header.
"""
path = urlparse(response.headers["location"]).path
self.path = path[:-5] if path.endswith(".json") else path
self.response = response
msg = f"Redirect to {self.path}"
msg += (
" (You may be trying to perform a non-read-only action via a "
"read-only instance.)"
if "/login/" in self.path
else ""
)
PrawcoreException.__init__(self, msg)
class ServerError(ResponseException):
"""Indicate issues on the server end preventing request fulfillment."""
class SpecialError(ResponseException):
"""Indicate syntax or spam-prevention issues."""
def __init__(self, response: Response) -> None:
"""Initialize a SpecialError exception instance.
:param response: A ``requests.response`` instance containing a message and a
list of special errors.
"""
self.response = response
resp_dict = self.response.json() # assumes valid JSON
self.message = resp_dict.get("message", "")
self.reason = resp_dict.get("reason", "")
self.special_errors = resp_dict.get("special_errors", [])
PrawcoreException.__init__(self, f"Special error {self.message!r}")
class TooLarge(ResponseException):
"""Indicate that the request data exceeds the allowed limit."""
class TooManyRequests(ResponseException):
"""Indicate that the user has sent too many requests in a given amount of time."""
def __init__(self, response: Response) -> None:
"""Initialize a TooManyRequests exception instance.
:param response: A ``requests.response`` instance that may contain a retry-after
header and a message.
"""
self.response = response
self.retry_after = response.headers.get("retry-after")
self.message = response.text # Not all response bodies are valid JSON
msg = f"received {response.status_code} HTTP response"
if self.retry_after:
msg += (
f". Please wait at least {float(self.retry_after)} seconds before"
f" re-trying this request."
)
PrawcoreException.__init__(self, msg)
class URITooLong(ResponseException):
"""Indicate that the length of the request URI exceeds the allowed limit."""
class UnavailableForLegalReasons(ResponseException):
"""Indicate that the requested URL is unavailable due to legal reasons."""

View File

@@ -0,0 +1,103 @@
"""Provide the RateLimiter class."""
from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING, Any, Callable, Mapping
if TYPE_CHECKING:
from requests.models import Response
log = logging.getLogger(__package__)
class RateLimiter:
"""Facilitates the rate limiting of requests to Reddit.
Rate limits are controlled based on feedback from requests to Reddit.
"""
def __init__(self, *, window_size: int) -> None:
"""Create an instance of the RateLimit class."""
self.remaining: float | None = None
self.next_request_timestamp: float | None = None
self.reset_timestamp: float | None = None
self.used: int | None = None
self.window_size: int = window_size
def call(
self,
request_function: Callable[[Any], Response],
set_header_callback: Callable[[], dict[str, str]],
*args: Any,
**kwargs: Any,
) -> Response:
"""Rate limit the call to ``request_function``.
:param request_function: A function call that returns an HTTP response object.
:param set_header_callback: A callback function used to set the request headers.
This callback is called after any necessary sleep time occurs.
:param args: The positional arguments to ``request_function``.
:param kwargs: The keyword arguments to ``request_function``.
"""
self.delay()
kwargs["headers"] = set_header_callback()
response = request_function(*args, **kwargs)
self.update(response.headers)
return response
def delay(self) -> None:
"""Sleep for an amount of time to remain under the rate limit."""
if self.next_request_timestamp is None:
return
sleep_seconds = self.next_request_timestamp - time.time()
if sleep_seconds <= 0:
return
message = f"Sleeping: {sleep_seconds:0.2f} seconds prior to call"
log.debug(message)
time.sleep(sleep_seconds)
def update(self, response_headers: Mapping[str, str]) -> None:
"""Update the state of the rate limiter based on the response headers.
This method should only be called following an HTTP request to Reddit.
Response headers that do not contain ``x-ratelimit`` fields will be treated as a
single request. This behavior is to error on the safe-side as such responses
should trigger exceptions that indicate invalid behavior.
"""
if "x-ratelimit-remaining" not in response_headers:
if self.remaining is not None:
self.remaining -= 1
self.used += 1
return
now = time.time()
seconds_to_reset = int(response_headers["x-ratelimit-reset"])
self.remaining = float(response_headers["x-ratelimit-remaining"])
self.used = int(response_headers["x-ratelimit-used"])
self.reset_timestamp = now + seconds_to_reset
if self.remaining <= 0:
self.next_request_timestamp = self.reset_timestamp
return
self.next_request_timestamp = min(
self.reset_timestamp,
now
+ min(
max(
seconds_to_reset
- (
self.window_size
- (self.window_size / (self.remaining + self.used) * self.used)
),
0,
),
10,
),
)

View File

@@ -0,0 +1,70 @@
"""Provides the HTTP request handling interface."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import requests
from .const import TIMEOUT, __version__
from .exceptions import InvalidInvocation, RequestException
if TYPE_CHECKING:
from requests.models import Response
from .sessions import Session
class Requestor:
"""Requestor provides an interface to HTTP requests."""
def __getattr__(self, attribute: str) -> Any:
"""Pass all undefined attributes to the ``_http`` attribute."""
if attribute.startswith("__"):
raise AttributeError
return getattr(self._http, attribute)
def __init__(
self,
user_agent: str,
oauth_url: str = "https://oauth.reddit.com",
reddit_url: str = "https://www.reddit.com",
session: Session | None = None,
timeout: float = TIMEOUT,
) -> None:
"""Create an instance of the Requestor class.
:param user_agent: The user-agent for your application. Please follow Reddit's
user-agent guidelines: https://github.com/reddit/reddit/wiki/API#rules
:param oauth_url: The URL used to make OAuth requests to the Reddit site
(default: ``"https://oauth.reddit.com"``).
:param reddit_url: The URL used when obtaining access tokens (default:
``"https://www.reddit.com"``).
:param session: A session to handle requests, compatible with
``requests.Session()`` (default: ``None``).
:param timeout: How many seconds to wait for the server to send data before
giving up (default: ``prawcore.const.TIMEOUT``).
"""
if user_agent is None or len(user_agent) < 7:
msg = "user_agent is not descriptive"
raise InvalidInvocation(msg)
self._http = session or requests.Session()
self._http.headers["User-Agent"] = f"{user_agent} prawcore/{__version__}"
self.oauth_url = oauth_url
self.reddit_url = reddit_url
self.timeout = timeout
def close(self) -> None:
"""Call close on the underlying session."""
return self._http.close()
def request(
self, *args: Any, timeout: float | None = None, **kwargs: Any
) -> Response:
"""Issue the HTTP request capturing any errors that may occur."""
try:
return self._http.request(*args, timeout=timeout or self.timeout, **kwargs)
except Exception as exc: # noqa: BLE001
raise RequestException(exc, args, kwargs) from None

View File

@@ -0,0 +1,375 @@
"""prawcore.sessions: Provides prawcore.Session and prawcore.session."""
from __future__ import annotations
import logging
import random
import time
from abc import ABC, abstractmethod
from copy import deepcopy
from pprint import pformat
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
from requests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout
from requests.status_codes import codes
from .auth import BaseAuthorizer
from .const import TIMEOUT, WINDOW_SIZE
from .exceptions import (
BadJSON,
BadRequest,
Conflict,
InvalidInvocation,
NotFound,
Redirect,
RequestException,
ServerError,
SpecialError,
TooLarge,
TooManyRequests,
UnavailableForLegalReasons,
URITooLong,
)
from .rate_limit import RateLimiter
from .util import authorization_error_class
if TYPE_CHECKING:
from io import BufferedReader
from requests.models import Response
from .auth import Authorizer
from .requestor import Requestor
log = logging.getLogger(__package__)
class RetryStrategy(ABC):
"""An abstract class for scheduling request retries.
The strategy controls both the number and frequency of retry attempts.
Instances of this class are immutable.
"""
@abstractmethod
def _sleep_seconds(self) -> float | None:
pass
def sleep(self) -> None:
"""Sleep until we are ready to attempt the request."""
sleep_seconds = self._sleep_seconds()
if sleep_seconds is not None:
message = f"Sleeping: {sleep_seconds:0.2f} seconds prior to retry"
log.debug(message)
time.sleep(sleep_seconds)
class Session:
"""The low-level connection interface to Reddit's API."""
RETRY_EXCEPTIONS = (ChunkedEncodingError, ConnectionError, ReadTimeout)
RETRY_STATUSES = {
520,
522,
codes["bad_gateway"],
codes["gateway_timeout"],
codes["internal_server_error"],
codes["request_timeout"],
codes["service_unavailable"],
}
STATUS_EXCEPTIONS = {
codes["bad_gateway"]: ServerError,
codes["bad_request"]: BadRequest,
codes["conflict"]: Conflict,
codes["found"]: Redirect,
codes["forbidden"]: authorization_error_class,
codes["gateway_timeout"]: ServerError,
codes["internal_server_error"]: ServerError,
codes["media_type"]: SpecialError,
codes["moved_permanently"]: Redirect,
codes["not_found"]: NotFound,
codes["request_entity_too_large"]: TooLarge,
codes["request_uri_too_large"]: URITooLong,
codes["service_unavailable"]: ServerError,
codes["too_many_requests"]: TooManyRequests,
codes["unauthorized"]: authorization_error_class,
codes[
"unavailable_for_legal_reasons"
]: UnavailableForLegalReasons, # Cloudflare's status (not named in requests)
520: ServerError,
522: ServerError,
}
SUCCESS_STATUSES = {codes["accepted"], codes["created"], codes["ok"]}
@staticmethod
def _log_request(
data: list[tuple[str, str]] | None,
method: str,
params: dict[str, int],
url: str,
) -> None:
log.debug("Fetching: %s %s at %s", method, url, time.time())
log.debug("Data: %s", pformat(data))
log.debug("Params: %s", pformat(params))
@property
def _requestor(self) -> Requestor:
return self._authorizer._authenticator._requestor
def __enter__(self) -> Session: # noqa: PYI034
"""Allow this object to be used as a context manager."""
return self
def __exit__(self, *_args) -> None:
"""Allow this object to be used as a context manager."""
self.close()
def __init__(
self,
authorizer: BaseAuthorizer | None,
window_size: int = WINDOW_SIZE,
) -> None:
"""Prepare the connection to Reddit's API.
:param authorizer: An instance of :class:`.Authorizer`.
:param window_size: The size of the rate limit reset window in seconds.
"""
if not isinstance(authorizer, BaseAuthorizer):
msg = f"invalid Authorizer: {authorizer}"
raise InvalidInvocation(msg)
self._authorizer = authorizer
self._rate_limiter = RateLimiter(window_size=window_size)
self._retry_strategy_class = FiniteRetryStrategy
def _do_retry(
self,
data: list[tuple[str, Any]],
files: dict[str, BufferedReader],
json: dict[str, Any],
method: str,
params: dict[str, int],
response: Response | None,
retry_strategy_state: FiniteRetryStrategy,
saved_exception: Exception | None,
timeout: float,
url: str,
) -> dict[str, Any] | str | None:
status = repr(saved_exception) if saved_exception else response.status_code
log.warning("Retrying due to %s status: %s %s", status, method, url)
return self._request_with_retries(
data=data,
files=files,
json=json,
method=method,
params=params,
timeout=timeout,
url=url,
retry_strategy_state=retry_strategy_state.consume_available_retry(),
# noqa: E501
)
def _make_request(
self,
data: list[tuple[str, Any]],
files: dict[str, BufferedReader],
json: dict[str, Any],
method: str,
params: dict[str, Any],
retry_strategy_state: FiniteRetryStrategy,
timeout: float,
url: str,
) -> tuple[Response, None] | tuple[None, Exception]:
try:
response = self._rate_limiter.call(
self._requestor.request,
self._set_header_callback,
method,
url,
allow_redirects=False,
data=data,
files=files,
json=json,
params=params,
timeout=timeout,
)
log.debug(
"Response: %s (%s bytes) (rst-%s:rem-%s:used-%s ratelimit) at %s",
response.status_code,
response.headers.get("content-length"),
response.headers.get("x-ratelimit-reset"),
response.headers.get("x-ratelimit-remaining"),
response.headers.get("x-ratelimit-used"),
time.time(),
)
return response, None
except RequestException as exception:
if (
not retry_strategy_state.should_retry_on_failure()
or not isinstance( # noqa: E501
exception.original_exception, self.RETRY_EXCEPTIONS
)
):
raise
return None, exception.original_exception
def _request_with_retries(
self,
data: list[tuple[str, Any]],
files: dict[str, BufferedReader],
json: dict[str, Any],
method: str,
params: dict[str, int],
timeout: float,
url: str,
retry_strategy_state: FiniteRetryStrategy | None = None,
) -> dict[str, Any] | str | None:
if retry_strategy_state is None:
retry_strategy_state = self._retry_strategy_class()
retry_strategy_state.sleep()
self._log_request(data, method, params, url)
response, saved_exception = self._make_request(
data,
files,
json,
method,
params,
retry_strategy_state,
timeout,
url,
)
do_retry = False
if response is not None and response.status_code == codes["unauthorized"]:
self._authorizer._clear_access_token()
if hasattr(self._authorizer, "refresh"):
do_retry = True
if retry_strategy_state.should_retry_on_failure() and (
do_retry or response is None or response.status_code in self.RETRY_STATUSES
):
return self._do_retry(
data,
files,
json,
method,
params,
response,
retry_strategy_state,
saved_exception,
timeout,
url,
)
if response.status_code in self.STATUS_EXCEPTIONS:
raise self.STATUS_EXCEPTIONS[response.status_code](response)
if response.status_code == codes["no_content"]:
return None
assert (
response.status_code in self.SUCCESS_STATUSES
), f"Unexpected status code: {response.status_code}"
if response.headers.get("content-length") == "0":
return ""
try:
return response.json()
except ValueError:
raise BadJSON(response) from None
def _set_header_callback(self) -> dict[str, str]:
if not self._authorizer.is_valid() and hasattr(self._authorizer, "refresh"):
self._authorizer.refresh()
return {"Authorization": f"bearer {self._authorizer.access_token}"}
def close(self) -> None:
"""Close the session and perform any clean up."""
self._requestor.close()
def request(
self,
method: str,
path: str,
data: dict[str, Any] | None = None,
files: dict[str, BufferedReader] | None = None,
json: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: float = TIMEOUT,
) -> dict[str, Any] | str | None:
"""Return the json content from the resource at ``path``.
:param method: The request verb. E.g., ``"GET"``, ``"POST"``, ``"PUT"``.
:param path: The path of the request. This path will be combined with the
``oauth_url`` of the Requestor.
:param data: Dictionary, bytes, or file-like object to send in the body of the
request.
:param files: Dictionary, mapping ``filename`` to file-like object.
:param json: Object to be serialized to JSON in the body of the request.
:param params: The query parameters to send with the request.
:param timeout: Specifies a particular timeout, in seconds.
Automatically refreshes the access token if it becomes invalid and a refresh
token is available.
:raises: :class:`.InvalidInvocation` in such a case if a refresh token is not
available.
"""
params = deepcopy(params) or {}
params["raw_json"] = 1
if isinstance(data, dict):
data = deepcopy(data)
data["api_type"] = "json"
data = sorted(data.items())
if isinstance(json, dict):
json = deepcopy(json)
json["api_type"] = "json"
url = urljoin(self._requestor.oauth_url, path)
return self._request_with_retries(
data=data,
files=files,
json=json,
method=method,
params=params,
timeout=timeout,
url=url,
)
def session(
authorizer: Authorizer = None,
window_size: int = WINDOW_SIZE,
) -> Session:
"""Return a :class:`.Session` instance.
:param authorizer: An instance of :class:`.Authorizer`.
:param window_size: The size of the rate limit reset window in seconds.
"""
return Session(authorizer=authorizer, window_size=window_size)
class FiniteRetryStrategy(RetryStrategy):
"""A ``RetryStrategy`` that retries requests a finite number of times."""
def __init__(self, retries: int = 3) -> None:
"""Initialize the strategy.
:param retries: Number of times to attempt a request (default: ``3``).
"""
self._retries = retries
def _sleep_seconds(self) -> float | None:
if self._retries < 3:
base = 0 if self._retries == 2 else 2
return base + 2 * random.random() # noqa: S311
return None
def consume_available_retry(self) -> FiniteRetryStrategy:
"""Allow one fewer retry."""
return type(self)(self._retries - 1)
def should_retry_on_failure(self) -> bool:
"""Return ``True`` if and only if the strategy will allow another retry."""
return self._retries > 1

View File

@@ -0,0 +1,32 @@
"""Provide utility for the prawcore package."""
from __future__ import annotations
from typing import TYPE_CHECKING
from .exceptions import Forbidden, InsufficientScope, InvalidToken
if TYPE_CHECKING:
from requests.models import Response
_auth_error_mapping = {
403: Forbidden,
"insufficient_scope": InsufficientScope,
"invalid_token": InvalidToken,
}
def authorization_error_class(
response: Response,
) -> InvalidToken | (Forbidden | InsufficientScope):
"""Return an exception instance that maps to the OAuth Error.
:param response: The HTTP response containing a www-authenticate error.
"""
message = response.headers.get("www-authenticate")
error: int | str
if message:
error = message.replace('"', "").rsplit("=", 1)[1]
else:
error = response.status_code
return _auth_error_mapping[error](response)