Skip to content

Commit

Permalink
feat: api_key and jwt request headers are hidden behind SecretStr (ep…
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Jul 10, 2024
1 parent efd3725 commit 054afa5
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
54 changes: 43 additions & 11 deletions aidial_sdk/deployment/from_request_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from json import JSONDecodeError
from typing import Any, Mapping, Optional, Type, TypeVar

from fastapi import Request
import fastapi

from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.pydantic_v1 import StrictStr
from aidial_sdk.pydantic_v1 import SecretStr, StrictStr, root_validator
from aidial_sdk.utils.logging import log_debug
from aidial_sdk.utils.pydantic import ExtraForbidModel

Expand All @@ -15,25 +15,52 @@
class FromRequestMixin(ABC, ExtraForbidModel):
@classmethod
@abstractmethod
async def from_request(cls: Type[T], request: Request) -> T:
async def from_request(cls: Type[T], request: fastapi.Request) -> T:
pass


class FromRequestBasicMixin(FromRequestMixin):
@classmethod
async def from_request(cls, request: Request):
async def from_request(cls, request: fastapi.Request):
return cls(**(await _get_request_body(request)))


class FromRequestDeploymentMixin(FromRequestMixin):
api_key: StrictStr
jwt: Optional[StrictStr] = None
api_key_secret: SecretStr
jwt_secret: Optional[SecretStr] = None

deployment_id: StrictStr
api_version: Optional[StrictStr] = None
headers: Mapping[StrictStr, StrictStr]

@root_validator(pre=True)
def create_secrets(cls, values: dict):
if "api_key" in values:
if "api_key_secret" not in values:
values["api_key_secret"] = SecretStr(values.pop("api_key"))
else:
raise ValueError(
"api_key and api_key_secret cannot be both provided"
)

if "jwt" in values:
if "jwt_secret" not in values:
values["jwt_secret"] = SecretStr(values.pop("jwt"))
else:
raise ValueError("jwt and jwt_secret cannot be both provided")

return values

@property
def api_key(self) -> str:
return self.api_key_secret.get_secret_value()

@property
def jwt(self) -> Optional[str]:
return self.jwt_secret.get_secret_value() if self.jwt_secret else None

@classmethod
async def from_request(cls, request: Request):
async def from_request(cls, request: fastapi.Request):
deployment_id = request.path_params.get("deployment_id")
if deployment_id is None or not isinstance(deployment_id, str):
raise DIALException(
Expand All @@ -42,26 +69,31 @@ async def from_request(cls, request: Request):
message="Invalid path",
)

headers = request.headers
headers = request.headers.mutablecopy()

api_key = headers.get("Api-Key")
if api_key is None:
raise DIALException(
status_code=400,
type="invalid_request_error",
message="Api-Key header is required",
)
del headers["Api-Key"]

jwt = headers.get("Authorization")
del headers["Authorization"]

return cls(
**(await _get_request_body(request)),
api_key=api_key,
jwt=headers.get("Authorization"),
api_key_secret=SecretStr(api_key),
jwt_secret=SecretStr(jwt) if jwt else None,
deployment_id=deployment_id,
api_version=request.query_params.get("api-version"),
headers=headers,
)


async def _get_request_body(request: Request) -> Any:
async def _get_request_body(request: fastapi.Request) -> Any:
try:
body = await request.json()
log_debug(f"request: {body}")
Expand Down
4 changes: 3 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def format(session: nox.Session):


@nox.session(python=["3.8", "3.9", "3.10", "3.11"])
def test(session: nox.Session) -> None:
@nox.parametrize("pydantic", ["1.10.17", "2.8.2"])
def test(session: nox.Session, pydantic: str) -> None:
"""Runs tests"""
session.run("poetry", "install", external=True)
session.install(f"pydantic=={pydantic}")
session.run("pytest")
16 changes: 14 additions & 2 deletions tests/test_discarded_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aidial_sdk import DIALApp, HTTPException
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from aidial_sdk.pydantic_v1 import SecretStr

DISCARDED_MESSAGES = list(range(0, 12))

Expand Down Expand Up @@ -112,7 +113,13 @@ def identity(data: str):


def test_discarded_messages_is_set_twice():
request = Request(headers={}, api_key="", deployment_id="", messages=[])
request = Request(
headers={},
api_key_secret=SecretStr("dummy_key"),
deployment_id="",
messages=[],
)

response = Response(request)

with response.create_single_choice():
Expand All @@ -125,7 +132,12 @@ def test_discarded_messages_is_set_twice():


def test_discarded_messages_is_set_before_choice():
request = Request(headers={}, api_key="", deployment_id="", messages=[])
request = Request(
headers={},
api_key_secret=SecretStr("dummy_key"),
deployment_id="",
messages=[],
)
response = Response(request)

with pytest.raises(HTTPException):
Expand Down

0 comments on commit 054afa5

Please sign in to comment.