Skip to content

Commit

Permalink
added authorization header name as ClassVar in BearerToken, for #70
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-pujol committed Sep 4, 2024
1 parent 10f813a commit f10ef79
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
10 changes: 5 additions & 5 deletions requests_oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class OAuth2Client:
extra_metadata: dict[str, Any] = field(factory=dict)
testing: bool = False

bearer_token_class: type[BearerToken] = BearerToken
token_class: type[BearerToken] = BearerToken

exception_classes: ClassVar[dict[str, type[EndpointError]]] = {
"server_error": ServerError,
Expand Down Expand Up @@ -348,7 +348,7 @@ def __init__( # noqa: PLR0913
id_token_decryption_key: Jwk | dict[str, Any] | None = None,
code_challenge_method: str = CodeChallengeMethods.S256,
authorization_response_iss_parameter_supported: bool = False,
bearer_token_class: type[BearerToken] = BearerToken,
token_class: type[BearerToken] = BearerToken,
session: requests.Session | None = None,
testing: bool = False,
**extra_metadata: Any,
Expand Down Expand Up @@ -402,8 +402,8 @@ def __init__( # noqa: PLR0913
id_token_decryption_key=id_token_decryption_key,
code_challenge_method=code_challenge_method,
authorization_response_iss_parameter_supported=authorization_response_iss_parameter_supported,
bearer_token_class=bearer_token_class,
extra_metadata=extra_metadata,
token_class=token_class,
)

@token_endpoint.validator
Expand Down Expand Up @@ -565,7 +565,7 @@ def parse_token_response(self, response: requests.Response) -> BearerToken:
"""
try:
token_response = self.bearer_token_class(**response.json())
token_response = self.token_class(**response.json())
except Exception: # noqa: BLE001
return self.on_token_error(response)
else:
Expand Down Expand Up @@ -623,7 +623,7 @@ def client_credentials(
**token_kwargs: additional parameters for the token endpoint, alongside `grant_type`. Common parameters
Returns:
a TokenResponse
a BearerToken
Raises:
InvalidScopeParam: if the `scope` parameter is not suitable
Expand Down
7 changes: 2 additions & 5 deletions requests_oauth2client/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class BearerToken(TokenResponse, requests.auth.AuthBase):
"""

TOKEN_TYPE: ClassVar[str] = AccessTokenType.BEARER.value
AUTHORIZATION_HEADER: ClassVar[str] = "Authorization"

access_token: str
expires_at: datetime | None = None
Expand Down Expand Up @@ -540,7 +541,7 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques
return request # pragma: no cover
if self.is_expired():
raise ExpiredAccessToken(self)
request.headers["Authorization"] = self.authorization_header()
request.headers[self.AUTHORIZATION_HEADER] = self.authorization_header()
return request


Expand Down Expand Up @@ -626,7 +627,3 @@ def loads(self, serialized: str) -> BearerToken:
"""
return self.loader(serialized)


class DPoPToken(TokenResponse):
"""Represents a DPoP Token."""
16 changes: 15 additions & 1 deletion tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ def test_custom_token_type(requests_mock: RequestsMocker, token_endpoint: str) -
class CustomBearerToken(BearerToken):
TOKEN_TYPE = "CustomBearerToken"

client = OAuth2Client(token_endpoint, ("client_id", "client_secret"), bearer_token_class=CustomBearerToken)
client = OAuth2Client(token_endpoint, ("client_id", "client_secret"), token_class=CustomBearerToken)

requests_mock.post(
token_endpoint,
Expand Down Expand Up @@ -1532,3 +1532,17 @@ def test_testing_oauth2client() -> None:

assert test_client.token_endpoint == token_endpoint
assert test_client.issuer == issuer


def test_proxy_authorization(requests_mock: RequestsMocker, target_api: str) -> None:
access_token = "my_proxy_auth_token"
auth_header = "Proxy-Authorization"

class ProxyAuthorizationBearerToken(BearerToken):
AUTHORIZATION_HEADER = auth_header

requests_mock.post(target_api)

requests.post(target_api, auth=ProxyAuthorizationBearerToken(access_token))
assert requests_mock.last_request is not None
assert requests_mock.last_request.headers[auth_header] == f"Bearer {access_token}"

0 comments on commit f10ef79

Please sign in to comment.