diff --git a/requests_oauth2client/client.py b/requests_oauth2client/client.py index e9cca21..cdaba8f 100644 --- a/requests_oauth2client/client.py +++ b/requests_oauth2client/client.py @@ -305,7 +305,7 @@ class OAuth2Client: extra_metadata: dict[str, Any] = field(factory=dict) testing: bool = False - bearer_token_class: type[BearerToken] = BearerToken + token_class: type[BearerToken] = BearerToken exception_classes: ClassVar[dict[str, type[EndpointError]]] = { "server_error": ServerError, @@ -348,7 +348,7 @@ def __init__( # noqa: PLR0913 id_token_decryption_key: Jwk | dict[str, Any] | None = None, code_challenge_method: str = CodeChallengeMethods.S256, authorization_response_iss_parameter_supported: bool = False, - bearer_token_class: type[BearerToken] = BearerToken, + token_class: type[BearerToken] = BearerToken, session: requests.Session | None = None, testing: bool = False, **extra_metadata: Any, @@ -402,8 +402,8 @@ def __init__( # noqa: PLR0913 id_token_decryption_key=id_token_decryption_key, code_challenge_method=code_challenge_method, authorization_response_iss_parameter_supported=authorization_response_iss_parameter_supported, - bearer_token_class=bearer_token_class, extra_metadata=extra_metadata, + token_class=token_class, ) @token_endpoint.validator @@ -565,7 +565,7 @@ def parse_token_response(self, response: requests.Response) -> BearerToken: """ try: - token_response = self.bearer_token_class(**response.json()) + token_response = self.token_class(**response.json()) except Exception: # noqa: BLE001 return self.on_token_error(response) else: @@ -623,7 +623,7 @@ def client_credentials( **token_kwargs: additional parameters for the token endpoint, alongside `grant_type`. Common parameters Returns: - a TokenResponse + a BearerToken Raises: InvalidScopeParam: if the `scope` parameter is not suitable diff --git a/requests_oauth2client/tokens.py b/requests_oauth2client/tokens.py index ceed139..517edc6 100644 --- a/requests_oauth2client/tokens.py +++ b/requests_oauth2client/tokens.py @@ -229,6 +229,7 @@ class BearerToken(TokenResponse, requests.auth.AuthBase): """ TOKEN_TYPE: ClassVar[str] = AccessTokenType.BEARER.value + AUTHORIZATION_HEADER: ClassVar[str] = "Authorization" access_token: str expires_at: datetime | None = None @@ -540,7 +541,7 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques return request # pragma: no cover if self.is_expired(): raise ExpiredAccessToken(self) - request.headers["Authorization"] = self.authorization_header() + request.headers[self.AUTHORIZATION_HEADER] = self.authorization_header() return request @@ -626,7 +627,3 @@ def loads(self, serialized: str) -> BearerToken: """ return self.loader(serialized) - - -class DPoPToken(TokenResponse): - """Represents a DPoP Token.""" diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index e2b7efb..5f56854 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -1423,7 +1423,7 @@ def test_custom_token_type(requests_mock: RequestsMocker, token_endpoint: str) - class CustomBearerToken(BearerToken): TOKEN_TYPE = "CustomBearerToken" - client = OAuth2Client(token_endpoint, ("client_id", "client_secret"), bearer_token_class=CustomBearerToken) + client = OAuth2Client(token_endpoint, ("client_id", "client_secret"), token_class=CustomBearerToken) requests_mock.post( token_endpoint, @@ -1532,3 +1532,17 @@ def test_testing_oauth2client() -> None: assert test_client.token_endpoint == token_endpoint assert test_client.issuer == issuer + + +def test_proxy_authorization(requests_mock: RequestsMocker, target_api: str) -> None: + access_token = "my_proxy_auth_token" + auth_header = "Proxy-Authorization" + + class ProxyAuthorizationBearerToken(BearerToken): + AUTHORIZATION_HEADER = auth_header + + requests_mock.post(target_api) + + requests.post(target_api, auth=ProxyAuthorizationBearerToken(access_token)) + assert requests_mock.last_request is not None + assert requests_mock.last_request.headers[auth_header] == f"Bearer {access_token}"