Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add litestar integration #332

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions docs/litestar.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Litestar

Utilities are provided to ease the integration of an OAuth2 process in [Litestar](https://litestar.dev/).

## `OAuth2AuthorizeCallback`

Dependency callable to handle the authorization callback. It reads the query parameters and returns the access token and the state.

```py
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallback, AccessTokenState
from httpx_oauth.oauth2 import OAuth2
from litestar import Litestar, get
from litestar.di import Provide
from litestar.params import Dependency

client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT")
oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback")

@get("/oauth-callback", name="oauth-callback")
async def oauth_callback(
access_token_state: AccessTokenState = Dependency(skip_validation=True),
) -> AccessTokenState:
token, state = access_token_state
# Do something useful

app = Litestar(route_handlers=[oauth_callback],dependencies={"access_token_state": Provide(oauth2_authorize_callback)})


```

[Reference](./reference/httpx_oauth.integrations.litestar.md){ .md-button }
{ .buttons }

### Custom exception handler

If an error occurs inside the callback logic (the user denied access, the authorization code is invalid...), the dependency will raise [OAuth2AuthorizeCallbackError][httpx_oauth.integrations.litestar.OAuth2AuthorizeCallbackError].

It inherits from Litestar's [HTTPException][litestar.exceptions.HTTPException], so it's automatically handled by the default Litestar exception handler. You can customize this behavior by implementing your own exception handler for `OAuth2AuthorizeCallbackError`.

```py
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallbackError
from litestar import Litestar
from litestar.response import Response

async def oauth2_authorize_callback_error_handler(request: Request, exc: OAuth2AuthorizeCallbackError) -> Response:
detail = exc.detail
status_code = exc.status_code
return Response(
status_code=status_code,
content={"message": "The OAuth2 callback failed", "detail": detail},
)

app = Litestar(exception_handlers={OAuth2AuthorizeCallbackError: oauth2_authorize_callback_error_handler})
```
6 changes: 6 additions & 0 deletions docs/reference/httpx_oauth.integrations.litestar.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Reference - Integrations - Litestar

::: httpx_oauth.integrations.litestar
options:
show_root_heading: false
show_source: false
118 changes: 118 additions & 0 deletions httpx_oauth/integrations/litestar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# pylint: disable=[invalid-name,import-outside-toplevel]
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, TypeAlias, Union

from litestar import status_codes as status
from litestar.exceptions import HTTPException
from litestar.params import Parameter

from httpx_oauth.oauth2 import BaseOAuth2, GetAccessTokenError, OAuth2Error, OAuth2Token

if TYPE_CHECKING:
import httpx
from litestar import Request


AccessTokenState: TypeAlias = tuple[OAuth2Token, str | None]


class OAuth2AuthorizeCallbackError(OAuth2Error, HTTPException):
"""Error raised when an error occurs during the OAuth2 authorization callback.

It inherits from [HTTPException][litestar.exceptions.HTTPException], so you can either keep
the default Litestar error handling or implement something dedicated.

!!! Note
Due to the way the base `LitestarException` handles the `detail` argument,
the `OAuth2Error` is ordered first here
"""

def __init__(
self,
status_code: int,
detail: Any = None,
headers: Union[Dict[str, str], None] = None,
response: Union[httpx.Response, None] = None,
extra: Union[Dict[str, Any], List[Any]] | None = None,
) -> None:
super().__init__(message=detail)
HTTPException.__init__(self, detail=detail, status_code=status_code, extra=extra, headers=headers)
self.response = response


class OAuth2AuthorizeCallback:
"""Dependency callable to handle the authorization callback. It reads the query parameters and returns the access token and the state.

Examples:
```py
from litestar import get
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import OAuth2

client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT")
oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback")

@get("/oauth-callback", name="oauth-callback", dependencies={"access_token_state": Provide(oauth2_authorize_callback)})
async def oauth_callback(access_token_state: AccessTokenState)) -> Response:
token, state = access_token_state
# Do something useful
```
"""

client: BaseOAuth2
route_name: str | None
redirect_url: str | None

def __init__(
self,
client: BaseOAuth2,
route_name: str | None = None,
redirect_url: str | None = None,
) -> None:
"""Args:
client: An [OAuth2][httpx_oauth.oauth2.BaseOAuth2] client.
route_name: Name of the callback route, as defined in the `name` parameter of the route decorator.
redirect_url: Full URL to the callback route.
"""
assert (route_name is not None and redirect_url is None) or (
route_name is None and redirect_url is not None
), "You should either set route_name or redirect_url"
self.client = client
self.route_name = route_name
self.redirect_url = redirect_url

async def __call__(
self,
request: Request,
code: str | None = Parameter(query="code", required=False),
code_verifier: str | None = Parameter(query="code_verifier", required=False),
callback_state: str | None = Parameter(query="state", required=False),
error: str | None = Parameter(query="error", required=False),
) -> AccessTokenState:
if code is None or error is not None:
raise OAuth2AuthorizeCallbackError(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error if error is not None else None,
)

if self.route_name:
redirect_url = str(request.url_for(self.route_name))
elif self.redirect_url:
redirect_url = self.redirect_url

try:
access_token = await self.client.get_access_token(
code,
redirect_url,
code_verifier,
)
except GetAccessTokenError as e:
raise OAuth2AuthorizeCallbackError(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e.message,
response=e.response,
extra={"message": e.message},
) from e

return access_token, callback_state
2 changes: 1 addition & 1 deletion httpx_oauth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self):

class RevokeTokenNotSupportedError(OAuth2Error):
"""
Error raised when trying to revole a token
Error raised when trying to revoke a token
on a provider that does not support it.
"""

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"pytest-asyncio",
"respx",
"fastapi",
"litestar"
]

[tool.hatch.envs.default.scripts]
Expand Down
129 changes: 129 additions & 0 deletions tests/test_integrations_litestar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import pytest
from litestar import Litestar, get
from litestar import status_codes as status
from litestar.di import Provide
from litestar.params import Dependency
from litestar.testing import TestClient
from pytest_mock import MockerFixture

from httpx_oauth.integrations.litestar import AccessTokenState, OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import GetAccessTokenError, OAuth2

CLIENT_ID = "CLIENT_ID"
CLIENT_SECRET = "CLIENT_SECRET"
AUTHORIZE_ENDPOINT = "https://www.camelot.bt/authorize"
ACCESS_TOKEN_ENDPOINT = "https://www.camelot.bt/access-token"
REDIRECT_URL = "https://www.tintagel.bt/callback"
ROUTE_NAME = "callback"

client = OAuth2(CLIENT_ID, CLIENT_SECRET, AUTHORIZE_ENDPOINT, ACCESS_TOKEN_ENDPOINT)
oauth2_authorize_callback_route_name = OAuth2AuthorizeCallback(client, route_name=ROUTE_NAME)
oauth2_authorize_callback_redirect_url = OAuth2AuthorizeCallback(client, redirect_url=REDIRECT_URL)


@get(
"/authorize-route-name",
dependencies={"access_token_state": Provide(oauth2_authorize_callback_route_name)},
)
async def authorize_route_name(
access_token_state: AccessTokenState = Dependency(skip_validation=True),
) -> AccessTokenState:
return access_token_state


@get(
"/authorize-redirect-url",
dependencies={"access_token_state": Provide(oauth2_authorize_callback_redirect_url)},
)
async def authorize_redirect_url(
access_token_state: AccessTokenState = Dependency(skip_validation=True),
) -> AccessTokenState:
return access_token_state


@get("/callback", name="callback")
async def callback() -> dict:
return {}


app = Litestar(route_handlers=[authorize_route_name, authorize_redirect_url, callback])

test_client = TestClient(app=app)


@pytest.mark.parametrize(
"route,expected_redirect_url",
[
("/authorize-route-name", "http://testserver.local/callback"),
("/authorize-redirect-url", "https://www.tintagel.bt/callback"),
],
)
class TestOAuth2AuthorizeCallback:
def test_oauth2_authorize_missing_code(self, route, expected_redirect_url):
response = test_client.get(route)
assert response.status_code == status.HTTP_400_BAD_REQUEST

def test_oauth2_authorize_error(self, route, expected_redirect_url):
response = test_client.get(route, params={"error": "access_denied"})
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json() == {"status_code": 400, "detail": "access_denied"}

def test_oauth2_authorize_get_access_token_error(self, mocker: MockerFixture, route, expected_redirect_url):
get_access_token_mock = mocker.patch.object(
client, "get_access_token", side_effect=GetAccessTokenError("ERROR")
)

response = test_client.get(route, params={"code": "CODE"})

get_access_token_mock.assert_called_once_with("CODE", expected_redirect_url, None)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# by default, litestar will only return `Internal Server Error` as the detail on a response.
# we are adding the ERROR to the `extra` payload
assert response.json() == {
"status_code": 500,
"detail": "Internal Server Error",
"extra": {"message": "ERROR"},
}

def test_oauth2_authorize_without_state(self, patch_async_method, route, expected_redirect_url):
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN")

response = test_client.get(route, params={"code": "CODE"})

client.get_access_token.assert_called()
client.get_access_token.assert_called_once_with("CODE", expected_redirect_url, None)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ["ACCESS_TOKEN", None]

def test_oauth2_authorize_code_verifier_without_state(self, patch_async_method, route, expected_redirect_url):
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN")

response = test_client.get(route, params={"code": "CODE", "code_verifier": "CODE_VERIFIER"})

client.get_access_token.assert_called()
client.get_access_token.assert_called_once_with("CODE", expected_redirect_url, "CODE_VERIFIER")
assert response.status_code == status.HTTP_200_OK
assert response.json() == ["ACCESS_TOKEN", None]

def test_oauth2_authorize_with_state(self, patch_async_method, route, expected_redirect_url):
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN")

response = test_client.get(route, params={"code": "CODE", "state": "STATE"})

client.get_access_token.assert_called()
client.get_access_token.assert_called_once_with("CODE", expected_redirect_url, None)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ["ACCESS_TOKEN", "STATE"]

def test_oauth2_authorize_with_state_and_code_verifier(self, patch_async_method, route, expected_redirect_url):
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN")

response = test_client.get(
route,
params={"code": "CODE", "state": "STATE", "code_verifier": "CODE_VERIFIER"},
)

client.get_access_token.assert_called()
client.get_access_token.assert_called_once_with("CODE", expected_redirect_url, "CODE_VERIFIER")
assert response.status_code == status.HTTP_200_OK
assert response.json() == ["ACCESS_TOKEN", "STATE"]