From 78f24e4fff7c89e52a54278704bd6c0d2e4c0831 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 30 Jul 2024 18:27:27 +0000 Subject: [PATCH 1/3] fix: typo in docstring --- httpx_oauth/oauth2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/httpx_oauth/oauth2.py b/httpx_oauth/oauth2.py index b4d9e7b..fc42b93 100644 --- a/httpx_oauth/oauth2.py +++ b/httpx_oauth/oauth2.py @@ -55,7 +55,7 @@ def __init__(self): class RevokeTokenNotSupportedError(OAuth2Error): """ - Error raised when trying to revole a token + Error raised when trying to revoke a token on a provider that does not support it. """ From 6449af9ec8dac4b2b760c8ee7f0cd59840b10bfa Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 31 Jul 2024 02:13:55 +0000 Subject: [PATCH 2/3] feat: add `litestar` integration --- docs/litestar.md | 54 ++++++ .../httpx_oauth.integrations.litestar.md | 6 + httpx_oauth/integrations/litestar.py | 120 +++++++++++++ pyproject.toml | 1 + tests/test_integrations_litestar.py | 157 ++++++++++++++++++ 5 files changed, 338 insertions(+) create mode 100644 docs/litestar.md create mode 100644 docs/reference/httpx_oauth.integrations.litestar.md create mode 100644 httpx_oauth/integrations/litestar.py create mode 100644 tests/test_integrations_litestar.py diff --git a/docs/litestar.md b/docs/litestar.md new file mode 100644 index 0000000..230d2c6 --- /dev/null +++ b/docs/litestar.md @@ -0,0 +1,54 @@ +# Litestar + +Utilities are provided to ease the integration of an OAuth2 process in [Litestar](https://litestar.dev/). + +## `OAuth2AuthorizeCallback` + +Dependency callable to handle the authorization callback. It reads the query parameters and returns the access token and the state. + +```py +from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallback, AccessTokenState +from httpx_oauth.oauth2 import OAuth2 +from litestar import Litestar, get +from litestar.di import Provide +from litestar.params import Dependency + +client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT") +oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback") + +@get("/oauth-callback", name="oauth-callback") +async def oauth_callback( + access_token_state: AccessTokenState = Dependency(skip_validation=True), +) -> AccessTokenState: + token, state = access_token_state + # Do something useful + +app = Litestar(route_handlers=[oauth_callback],dependencies={"access_token_state": Provide(oauth2_authorize_callback)}) + + +``` + +[Reference](./reference/httpx_oauth.integrations.litestar.md){ .md-button } +{ .buttons } + +### Custom exception handler + +If an error occurs inside the callback logic (the user denied access, the authorization code is invalid...), the dependency will raise [OAuth2AuthorizeCallbackError][httpx_oauth.integrations.litestar.OAuth2AuthorizeCallbackError]. + +It inherits from Litestar's [HTTPException][litestar.exceptions.HTTPException], so it's automatically handled by the default Litestar exception handler. You can customize this behavior by implementing your own exception handler for `OAuth2AuthorizeCallbackError`. + +```py +from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallbackError +from litestar import Litestar +from litestar.response import Response + +async def oauth2_authorize_callback_error_handler(request: Request, exc: OAuth2AuthorizeCallbackError) -> Response: + detail = exc.detail + status_code = exc.status_code + return Response( + status_code=status_code, + content={"message": "The OAuth2 callback failed", "detail": detail}, + ) + +app = Litestar(exception_handlers={OAuth2AuthorizeCallbackError: oauth2_authorize_callback_error_handler}) +``` diff --git a/docs/reference/httpx_oauth.integrations.litestar.md b/docs/reference/httpx_oauth.integrations.litestar.md new file mode 100644 index 0000000..1779349 --- /dev/null +++ b/docs/reference/httpx_oauth.integrations.litestar.md @@ -0,0 +1,6 @@ +# Reference - Integrations - Litestar + +::: httpx_oauth.integrations.litestar + options: + show_root_heading: false + show_source: false diff --git a/httpx_oauth/integrations/litestar.py b/httpx_oauth/integrations/litestar.py new file mode 100644 index 0000000..b67c9e6 --- /dev/null +++ b/httpx_oauth/integrations/litestar.py @@ -0,0 +1,120 @@ +# pylint: disable=[invalid-name,import-outside-toplevel] +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, TypeAlias, Union # noqa: UP035 + +from litestar import status_codes as status +from litestar.exceptions import HTTPException +from litestar.params import Parameter + +from httpx_oauth.oauth2 import BaseOAuth2, GetAccessTokenError, OAuth2Error, OAuth2Token + +if TYPE_CHECKING: + import httpx + from litestar import Request + + +AccessTokenState: TypeAlias = tuple[OAuth2Token, str | None] + + +class OAuth2AuthorizeCallbackError(OAuth2Error, HTTPException): + """Error raised when an error occurs during the OAuth2 authorization callback. + + It inherits from [HTTPException][litestar.exceptions.HTTPException], so you can either keep + the default Litestar error handling or implement something dedicated. + + !!! Note + Due to the way the base `LitestarException` handles the `detail` argument, + the `OAuth2Error` is ordered first here + """ + + def __init__( + self, + status_code: int, + detail: Any = None, + headers: Union[Dict[str, str], None] = None, # noqa: UP007, UP006 + response: Union[httpx.Response, None] = None, # noqa: UP007 + extra: Union[Dict[str, Any], List[Any]] | None = None, # noqa: UP007, UP006 + ) -> None: + super().__init__(message=detail) + HTTPException.__init__( + self, detail=detail, status_code=status_code, extra=extra, headers=headers + ) + self.response = response + + +class OAuth2AuthorizeCallback: + """Dependency callable to handle the authorization callback. It reads the query parameters and returns the access token and the state. + + Examples: + ```py + from litestar import get + from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallback + from httpx_oauth.oauth2 import OAuth2 + + client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT") + oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback") + + @get("/oauth-callback", name="oauth-callback", dependencies={"access_token_state": Provide(oauth2_authorize_callback)}) + async def oauth_callback(access_token_state: AccessTokenState)) -> Response: + token, state = access_token_state + # Do something useful + ``` + """ + + client: BaseOAuth2 + route_name: str | None + redirect_url: str | None + + def __init__( + self, + client: BaseOAuth2, + route_name: str | None = None, + redirect_url: str | None = None, + ) -> None: + """Args: + client: An [OAuth2][httpx_oauth.oauth2.BaseOAuth2] client. + route_name: Name of the callback route, as defined in the `name` parameter of the route decorator. + redirect_url: Full URL to the callback route. + """ + assert (route_name is not None and redirect_url is None) or ( + route_name is None and redirect_url is not None + ), "You should either set route_name or redirect_url" + self.client = client + self.route_name = route_name + self.redirect_url = redirect_url + + async def __call__( + self, + request: Request, + code: str | None = Parameter(query="code", required=False), + code_verifier: str | None = Parameter(query="code_verifier", required=False), + callback_state: str | None = Parameter(query="state", required=False), + error: str | None = Parameter(query="error", required=False), + ) -> AccessTokenState: + if code is None or error is not None: + raise OAuth2AuthorizeCallbackError( + status_code=status.HTTP_400_BAD_REQUEST, + detail=error if error is not None else None, + ) + + if self.route_name: + redirect_url = str(request.url_for(self.route_name)) + elif self.redirect_url: + redirect_url = self.redirect_url + + try: + access_token = await self.client.get_access_token( + code, + redirect_url, + code_verifier, + ) + except GetAccessTokenError as e: + raise OAuth2AuthorizeCallbackError( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=e.message, + response=e.response, + extra={"message": e.message}, + ) from e + + return access_token, callback_state diff --git a/pyproject.toml b/pyproject.toml index 3829a19..3fd770c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "pytest-asyncio", "respx", "fastapi", + "litestar" ] [tool.hatch.envs.default.scripts] diff --git a/tests/test_integrations_litestar.py b/tests/test_integrations_litestar.py new file mode 100644 index 0000000..c05664f --- /dev/null +++ b/tests/test_integrations_litestar.py @@ -0,0 +1,157 @@ +import pytest +from litestar import Litestar, get +from litestar import status_codes as status +from litestar.di import Provide +from litestar.params import Dependency +from litestar.testing import TestClient +from pytest_mock import MockerFixture + +from httpx_oauth.integrations.litestar import AccessTokenState, OAuth2AuthorizeCallback +from httpx_oauth.oauth2 import GetAccessTokenError, OAuth2 + +CLIENT_ID = "CLIENT_ID" +CLIENT_SECRET = "CLIENT_SECRET" +AUTHORIZE_ENDPOINT = "https://www.camelot.bt/authorize" +ACCESS_TOKEN_ENDPOINT = "https://www.camelot.bt/access-token" +REDIRECT_URL = "https://www.tintagel.bt/callback" +ROUTE_NAME = "callback" + +client = OAuth2(CLIENT_ID, CLIENT_SECRET, AUTHORIZE_ENDPOINT, ACCESS_TOKEN_ENDPOINT) +oauth2_authorize_callback_route_name = OAuth2AuthorizeCallback( + client, route_name=ROUTE_NAME +) +oauth2_authorize_callback_redirect_url = OAuth2AuthorizeCallback( + client, redirect_url=REDIRECT_URL +) + + +@get( + "/authorize-route-name", + dependencies={"access_token_state": Provide(oauth2_authorize_callback_route_name)}, +) +async def authorize_route_name( + access_token_state: AccessTokenState = Dependency(skip_validation=True), +) -> AccessTokenState: + return access_token_state + + +@get( + "/authorize-redirect-url", + dependencies={ + "access_token_state": Provide(oauth2_authorize_callback_redirect_url) + }, +) +async def authorize_redirect_url( + access_token_state: AccessTokenState = Dependency(skip_validation=True), +) -> AccessTokenState: + return access_token_state + + +@get("/callback", name="callback") +async def callback() -> dict: + return {} + + +app = Litestar(route_handlers=[authorize_route_name, authorize_redirect_url, callback]) + +test_client = TestClient(app=app) + + +@pytest.mark.parametrize( + "route,expected_redirect_url", + [ + ("/authorize-route-name", "http://testserver.local/callback"), + ("/authorize-redirect-url", "https://www.tintagel.bt/callback"), + ], +) +class TestOAuth2AuthorizeCallback: + def test_oauth2_authorize_missing_code(self, route, expected_redirect_url): + response = test_client.get(route) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_oauth2_authorize_error(self, route, expected_redirect_url): + response = test_client.get(route, params={"error": "access_denied"}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json() == {"status_code": 400, "detail": "access_denied"} + + def test_oauth2_authorize_get_access_token_error( + self, mocker: MockerFixture, route, expected_redirect_url + ): + get_access_token_mock = mocker.patch.object( + client, "get_access_token", side_effect=GetAccessTokenError("ERROR") + ) + + response = test_client.get(route, params={"code": "CODE"}) + + get_access_token_mock.assert_called_once_with( + "CODE", expected_redirect_url, None + ) + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + # by default, litestar will only return `Internal Server Error` as the detail on a response. + # we are adding the ERROR to the `extra` payload + assert response.json() == { + "status_code": 500, + "detail": "Internal Server Error", + "extra": {"message": "ERROR"}, + } + + def test_oauth2_authorize_without_state( + self, patch_async_method, route, expected_redirect_url + ): + patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") + + response = test_client.get(route, params={"code": "CODE"}) + + client.get_access_token.assert_called() + client.get_access_token.assert_called_once_with( + "CODE", expected_redirect_url, None + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() == ["ACCESS_TOKEN", None] + + def test_oauth2_authorize_code_verifier_without_state( + self, patch_async_method, route, expected_redirect_url + ): + patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") + + response = test_client.get( + route, params={"code": "CODE", "code_verifier": "CODE_VERIFIER"} + ) + + client.get_access_token.assert_called() + client.get_access_token.assert_called_once_with( + "CODE", expected_redirect_url, "CODE_VERIFIER" + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() == ["ACCESS_TOKEN", None] + + def test_oauth2_authorize_with_state( + self, patch_async_method, route, expected_redirect_url + ): + patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") + + response = test_client.get(route, params={"code": "CODE", "state": "STATE"}) + + client.get_access_token.assert_called() + client.get_access_token.assert_called_once_with( + "CODE", expected_redirect_url, None + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() == ["ACCESS_TOKEN", "STATE"] + + def test_oauth2_authorize_with_state_and_code_verifier( + self, patch_async_method, route, expected_redirect_url + ): + patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") + + response = test_client.get( + route, + params={"code": "CODE", "state": "STATE", "code_verifier": "CODE_VERIFIER"}, + ) + + client.get_access_token.assert_called() + client.get_access_token.assert_called_once_with( + "CODE", expected_redirect_url, "CODE_VERIFIER" + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() == ["ACCESS_TOKEN", "STATE"] From b89d944f151fbb93dde0141c68191ce97b4d4a00 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 8 Aug 2024 14:41:29 +0000 Subject: [PATCH 3/3] fix: linting changes --- httpx_oauth/integrations/litestar.py | 12 +++--- tests/test_integrations_litestar.py | 56 +++++++--------------------- 2 files changed, 19 insertions(+), 49 deletions(-) diff --git a/httpx_oauth/integrations/litestar.py b/httpx_oauth/integrations/litestar.py index b67c9e6..df2b7ce 100644 --- a/httpx_oauth/integrations/litestar.py +++ b/httpx_oauth/integrations/litestar.py @@ -1,7 +1,7 @@ # pylint: disable=[invalid-name,import-outside-toplevel] from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, TypeAlias, Union # noqa: UP035 +from typing import TYPE_CHECKING, Any, Dict, List, TypeAlias, Union from litestar import status_codes as status from litestar.exceptions import HTTPException @@ -32,14 +32,12 @@ def __init__( self, status_code: int, detail: Any = None, - headers: Union[Dict[str, str], None] = None, # noqa: UP007, UP006 - response: Union[httpx.Response, None] = None, # noqa: UP007 - extra: Union[Dict[str, Any], List[Any]] | None = None, # noqa: UP007, UP006 + headers: Union[Dict[str, str], None] = None, + response: Union[httpx.Response, None] = None, + extra: Union[Dict[str, Any], List[Any]] | None = None, ) -> None: super().__init__(message=detail) - HTTPException.__init__( - self, detail=detail, status_code=status_code, extra=extra, headers=headers - ) + HTTPException.__init__(self, detail=detail, status_code=status_code, extra=extra, headers=headers) self.response = response diff --git a/tests/test_integrations_litestar.py b/tests/test_integrations_litestar.py index c05664f..8c4e45f 100644 --- a/tests/test_integrations_litestar.py +++ b/tests/test_integrations_litestar.py @@ -17,12 +17,8 @@ ROUTE_NAME = "callback" client = OAuth2(CLIENT_ID, CLIENT_SECRET, AUTHORIZE_ENDPOINT, ACCESS_TOKEN_ENDPOINT) -oauth2_authorize_callback_route_name = OAuth2AuthorizeCallback( - client, route_name=ROUTE_NAME -) -oauth2_authorize_callback_redirect_url = OAuth2AuthorizeCallback( - client, redirect_url=REDIRECT_URL -) +oauth2_authorize_callback_route_name = OAuth2AuthorizeCallback(client, route_name=ROUTE_NAME) +oauth2_authorize_callback_redirect_url = OAuth2AuthorizeCallback(client, redirect_url=REDIRECT_URL) @get( @@ -37,9 +33,7 @@ async def authorize_route_name( @get( "/authorize-redirect-url", - dependencies={ - "access_token_state": Provide(oauth2_authorize_callback_redirect_url) - }, + dependencies={"access_token_state": Provide(oauth2_authorize_callback_redirect_url)}, ) async def authorize_redirect_url( access_token_state: AccessTokenState = Dependency(skip_validation=True), @@ -74,18 +68,14 @@ def test_oauth2_authorize_error(self, route, expected_redirect_url): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json() == {"status_code": 400, "detail": "access_denied"} - def test_oauth2_authorize_get_access_token_error( - self, mocker: MockerFixture, route, expected_redirect_url - ): + def test_oauth2_authorize_get_access_token_error(self, mocker: MockerFixture, route, expected_redirect_url): get_access_token_mock = mocker.patch.object( client, "get_access_token", side_effect=GetAccessTokenError("ERROR") ) response = test_client.get(route, params={"code": "CODE"}) - get_access_token_mock.assert_called_once_with( - "CODE", expected_redirect_url, None - ) + get_access_token_mock.assert_called_once_with("CODE", expected_redirect_url, None) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR # by default, litestar will only return `Internal Server Error` as the detail on a response. # we are adding the ERROR to the `extra` payload @@ -95,53 +85,37 @@ def test_oauth2_authorize_get_access_token_error( "extra": {"message": "ERROR"}, } - def test_oauth2_authorize_without_state( - self, patch_async_method, route, expected_redirect_url - ): + def test_oauth2_authorize_without_state(self, patch_async_method, route, expected_redirect_url): patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") response = test_client.get(route, params={"code": "CODE"}) client.get_access_token.assert_called() - client.get_access_token.assert_called_once_with( - "CODE", expected_redirect_url, None - ) + client.get_access_token.assert_called_once_with("CODE", expected_redirect_url, None) assert response.status_code == status.HTTP_200_OK assert response.json() == ["ACCESS_TOKEN", None] - def test_oauth2_authorize_code_verifier_without_state( - self, patch_async_method, route, expected_redirect_url - ): + def test_oauth2_authorize_code_verifier_without_state(self, patch_async_method, route, expected_redirect_url): patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") - response = test_client.get( - route, params={"code": "CODE", "code_verifier": "CODE_VERIFIER"} - ) + response = test_client.get(route, params={"code": "CODE", "code_verifier": "CODE_VERIFIER"}) client.get_access_token.assert_called() - client.get_access_token.assert_called_once_with( - "CODE", expected_redirect_url, "CODE_VERIFIER" - ) + client.get_access_token.assert_called_once_with("CODE", expected_redirect_url, "CODE_VERIFIER") assert response.status_code == status.HTTP_200_OK assert response.json() == ["ACCESS_TOKEN", None] - def test_oauth2_authorize_with_state( - self, patch_async_method, route, expected_redirect_url - ): + def test_oauth2_authorize_with_state(self, patch_async_method, route, expected_redirect_url): patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") response = test_client.get(route, params={"code": "CODE", "state": "STATE"}) client.get_access_token.assert_called() - client.get_access_token.assert_called_once_with( - "CODE", expected_redirect_url, None - ) + client.get_access_token.assert_called_once_with("CODE", expected_redirect_url, None) assert response.status_code == status.HTTP_200_OK assert response.json() == ["ACCESS_TOKEN", "STATE"] - def test_oauth2_authorize_with_state_and_code_verifier( - self, patch_async_method, route, expected_redirect_url - ): + def test_oauth2_authorize_with_state_and_code_verifier(self, patch_async_method, route, expected_redirect_url): patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") response = test_client.get( @@ -150,8 +124,6 @@ def test_oauth2_authorize_with_state_and_code_verifier( ) client.get_access_token.assert_called() - client.get_access_token.assert_called_once_with( - "CODE", expected_redirect_url, "CODE_VERIFIER" - ) + client.get_access_token.assert_called_once_with("CODE", expected_redirect_url, "CODE_VERIFIER") assert response.status_code == status.HTTP_200_OK assert response.json() == ["ACCESS_TOKEN", "STATE"]