Skip to content

Commit

Permalink
Add new tests for OAuth and add Security
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Nov 15, 2024
1 parent daa8983 commit 216d783
Show file tree
Hide file tree
Showing 10 changed files with 546 additions and 21 deletions.
3 changes: 2 additions & 1 deletion esmerald/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ValidationErrorException,
)
from .interceptors.interceptor import EsmeraldInterceptor
from .param_functions import DirectInjects
from .param_functions import DirectInjects, Security
from .params import Body, Cookie, File, Form, Header, Injects, Param, Path, Query
from .permissions import AllowAny, BasePermission, DenyAll
from .pluggables import Extension, Pluggable
Expand Down Expand Up @@ -93,6 +93,7 @@
"Request",
"Response",
"Router",
"Security",
"ServiceUnavailable",
"SessionConfig",
"SimpleAPIView",
Expand Down
1 change: 0 additions & 1 deletion esmerald/openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def get_openapi_security_schemes(schemes: Any) -> Tuple[dict, list]:
security_name = security_requirement.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement})

return security_definitions, operation_security


Expand Down
12 changes: 1 addition & 11 deletions esmerald/openapi/schemas/v3_1_0/security_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,6 @@ class SecurityScheme(BaseModel):
An optional model to be used for the security scheme.
"""

scheme_name: Optional[str] = None
"""
An optional name for the security scheme.
"""

auto_error: bool = False
"""
A flag to indicate if automatic error handling should be enabled.
"""

model_config = ConfigDict(
extra="ignore",
populate_by_name=True,
Expand All @@ -110,7 +100,7 @@ class SecurityScheme(BaseModel):
{
"type": "openIdConnect",
"openIdConnectUrl": "openIdConnect",
}, # issue #5: allow relative path
},
]
},
)
15 changes: 14 additions & 1 deletion esmerald/param_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Sequence, Union

from pydantic.fields import AliasChoices, AliasPath

Expand All @@ -10,6 +10,19 @@
_PyUndefined: Any = Undefined


def Security(
dependency: Optional[Callable[..., Any]] = None,
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
) -> Any:
"""
This function should be only called if Inject/Injects is not used in the dependencies.
This is a simple wrapper of the classic Inject().
"""
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)


def DirectInjects(
dependency: Optional[Callable[..., Any]] = None,
*,
Expand Down
25 changes: 24 additions & 1 deletion esmerald/params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from pydantic.dataclasses import dataclass
from pydantic.fields import AliasChoices, AliasPath, FieldInfo
Expand Down Expand Up @@ -650,3 +650,26 @@ def __hash__(self) -> int:
else:
values[key] = value
return hash((type(self),) + tuple(values))


class Depends:
def __init__(self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True):
self.dependency = dependency
self.use_cache = use_cache

def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"


class Security(Depends):
def __init__(
self,
dependency: Optional[Callable[..., Any]] = None,
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
self.scopes = scopes or []
22 changes: 16 additions & 6 deletions esmerald/security/oauth2/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from esmerald.security.utils import get_authorization_scheme_param


class SecurityBase(SecurityScheme): ...
class SecurityBase(SecurityScheme):
scheme_name: Optional[str] = None
"""
An optional name for the security scheme.
"""
__auto_error__: bool = False
"""
A flag to indicate if automatic error handling should be enabled.
"""


class OAuth2PasswordRequestForm:
Expand Down Expand Up @@ -393,19 +401,21 @@ def __init__(
),
] = True,
) -> None:
model = OAuth2Model(flows=cast(OAuthFlowsModel, flows), description=description)
model = OAuth2Model(
flows=cast(OAuthFlowsModel, flows), scheme=scheme_name, description=description
)
model_dump = model.model_dump()
super().__init__(**model_dump)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
self.__auto_error__ = auto_error

async def __call__(self, request: Request) -> Any:
authorization = request.headers.get("Authorization")

if authorization:
return authorization

if self.auto_error:
if self.__auto_error__:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")

return None
Expand Down Expand Up @@ -526,7 +536,7 @@ async def __call__(self, request: Request) -> Optional[str]:
if authorization and scheme.lower() == "bearer":
return param

if self.auto_error:
if self.__auto_error__:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
Expand Down Expand Up @@ -652,7 +662,7 @@ async def __call__(self, request: Request) -> Optional[str]:
if authorization and scheme.lower() == "bearer":
return param

if self.auto_error:
if self.__auto_error__:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
Expand Down
Empty file.
165 changes: 165 additions & 0 deletions tests/security/oauth/test_oauth_code_bearer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from typing import Any, Optional

from esmerald import Gateway, Inject, Injects, get
from esmerald.security.oauth2 import OAuth2AuthorizationCodeBearer
from esmerald.testclient import create_client

oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl="authorize", tokenUrl="token", auto_error=True
)


@get("/items", dependencies={"token": Inject(oauth2_scheme)}, security=[oauth2_scheme])
async def read_items(token: Optional[str] = Injects()) -> dict[str, Any]:
return {"token": token}


def test_no_token():
with create_client(
routes=[
Gateway(handler=read_items),
],
) as client:
response = client.get("/items")
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"}


def test_incorrect_token():
with create_client(
routes=[
Gateway(handler=read_items),
],
) as client:
response = client.get("/items", headers={"Authorization": "Non-existent testtoken"})
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"}


def test_token():
with create_client(
routes=[
Gateway(handler=read_items),
],
) as client:
response = client.get("/items", headers={"Authorization": "Bearer testtoken"})
assert response.status_code == 200, response.text
assert response.json() == {"token": "testtoken"}


def test_openapi_schema():
with create_client(
routes=[
Gateway(handler=read_items),
],
) as client:
response = client.get("/openapi.json")
assert response.status_code == 200, response.text

assert response.json() == {
"openapi": "3.1.0",
"info": {
"title": "Esmerald",
"summary": "Esmerald application",
"description": "Highly scalable, performant, easy to learn and for every application.",
"contact": {"name": "admin", "email": "[email protected]"},
"version": client.app.version,
},
"servers": [{"url": "/"}],
"paths": {
"/items": {
"get": {
"summary": "Read Items",
"description": "",
"operationId": "read_items_items_get",
"deprecated": False,
"security": [
{
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "authorize",
"tokenUrl": "token",
"scopes": {},
}
},
"scheme_name": "OAuth2AuthorizationCodeBearer",
}
}
],
"parameters": [
{
"name": "token",
"in": "query",
"required": True,
"deprecated": False,
"allowEmptyValue": False,
"allowReserved": False,
"schema": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Token",
},
}
],
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "string"}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
}
}
},
"components": {
"schemas": {
"HTTPValidationError": {
"properties": {
"detail": {
"items": {"$ref": "#/components/schemas/ValidationError"},
"type": "array",
"title": "Detail",
}
},
"type": "object",
"title": "HTTPValidationError",
},
"ValidationError": {
"properties": {
"loc": {
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
"type": "array",
"title": "Location",
},
"msg": {"type": "string", "title": "Message"},
"type": {"type": "string", "title": "Error Type"},
},
"type": "object",
"required": ["loc", "msg", "type"],
"title": "ValidationError",
},
},
"securitySchemes": {
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "authorize",
"tokenUrl": "token",
"scopes": {},
}
},
}
},
},
}
Loading

0 comments on commit 216d783

Please sign in to comment.