Skip to content

Commit

Permalink
Do not create FastAPI app until config loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Nov 7, 2024
1 parent 1c50a4b commit e581c27
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 50 deletions.
95 changes: 55 additions & 40 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import asynccontextmanager

from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Depends,
Expand Down Expand Up @@ -80,25 +81,34 @@ async def lifespan(app: FastAPI):
teardown_runner()


app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
lifespan=lifespan,
version=REST_API_VERSION,
)
router = APIRouter()


def get_app():
app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
lifespan=lifespan,
version=REST_API_VERSION,
)
app.include_router(router)
app.add_exception_handler(KeyError, on_key_error_404)
add_api_version_header(app)
inject_propagated_observability_context(app)
return app


TRACER = get_tracer("interface")


@app.exception_handler(KeyError)
async def on_key_error_404(_: Request, __: KeyError):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={"detail": "Item not found"},
)


@app.get("/environment", response_model=EnvironmentResponse)
@router.get("/environment", response_model=EnvironmentResponse)
@start_as_current_span(TRACER, "runner")
def get_environment(
runner: WorkerDispatcher = Depends(_runner),
Expand All @@ -107,7 +117,7 @@ def get_environment(
return runner.state


@app.delete("/environment", response_model=EnvironmentResponse)
@router.delete("/environment", response_model=EnvironmentResponse)
async def delete_environment(
background_tasks: BackgroundTasks,
runner: WorkerDispatcher = Depends(_runner),
Expand All @@ -119,14 +129,14 @@ async def delete_environment(
return EnvironmentResponse(initialized=False)


@app.get("/plans", response_model=PlanResponse)
@router.get("/plans", response_model=PlanResponse)
@start_as_current_span(TRACER)
def get_plans(runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about all available plans."""
return PlanResponse(plans=runner.run(interface.get_plans))


@app.get(
@router.get(
"/plans/{name}",
response_model=PlanModel,
)
Expand All @@ -136,14 +146,14 @@ def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
return runner.run(interface.get_plan, name)


@app.get("/devices", response_model=DeviceResponse)
@router.get("/devices", response_model=DeviceResponse)
@start_as_current_span(TRACER)
def get_devices(runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about all available devices."""
return DeviceResponse(devices=runner.run(interface.get_devices))


@app.get(
@router.get(
"/devices/{name}",
response_model=DeviceModel,
)
Expand All @@ -156,7 +166,7 @@ def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
example_task = Task(name="count", params={"detectors": ["x"]})


@app.post(
@router.post(
"/tasks",
response_model=TaskResponse,
status_code=status.HTTP_201_CREATED,
Expand Down Expand Up @@ -190,7 +200,7 @@ def submit_task(
) from e


@app.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK)
@router.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK)
@start_as_current_span(TRACER, "task_id")
def delete_submitted_task(
task_id: str,
Expand All @@ -207,7 +217,7 @@ def validate_task_status(v: str) -> TaskStatusEnum:
return TaskStatusEnum(v_upper)


@app.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK)
@router.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK)
@start_as_current_span(TRACER)
def get_tasks(
task_status: str | None = None,
Expand All @@ -234,7 +244,7 @@ def get_tasks(
return TasksListResponse(tasks=tasks)


@app.put(
@router.put(
"/worker/task",
response_model=WorkerTask,
responses={status.HTTP_409_CONFLICT: {"worker": "already active"}},
Expand All @@ -255,7 +265,7 @@ def set_active_task(
return task


@app.get(
@router.get(
"/tasks/{task_id}",
response_model=TrackableTask,
)
Expand All @@ -271,7 +281,7 @@ def get_task(
return task


@app.get("/worker/task")
@router.get("/worker/task")
@start_as_current_span(TRACER)
def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask:
active = runner.run(interface.get_active_task)
Expand All @@ -281,7 +291,7 @@ def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask:
return WorkerTask(task_id=None)


@app.get("/worker/state")
@router.get("/worker/state")
@start_as_current_span(TRACER)
def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState:
"""Get the State of the Worker"""
Expand All @@ -303,7 +313,7 @@ def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState:
}


@app.put(
@router.put(
"/worker/state",
status_code=status.HTTP_202_ACCEPTED,
responses={
Expand Down Expand Up @@ -372,34 +382,39 @@ def start(config: ApplicationConfig):
"%(asctime)s %(levelprefix)s %(client_addr)s"
+ " - '%(request_line)s' %(status_code)s"
)
app = get_app()

Check warning on line 385 in src/blueapi/service/main.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/main.py#L385

Added line #L385 was not covered by tests

FastAPIInstrumentor().instrument_app(
app,
tracer_provider=get_tracer_provider(),
http_capture_headers_server_request=[",*"],
http_capture_headers_server_response=[",*"],
)
app.state.config = config

uvicorn.run(app, host=config.api.host, port=config.api.port)


@app.middleware("http")
async def add_api_version_header(request: Request, call_next):
response = await call_next(request)
response.headers["X-API-Version"] = REST_API_VERSION
return response
def add_api_version_header(app: FastAPI):
@app.middleware("http")
async def add_api_version_header(request: Request, call_next):
response = await call_next(request)
response.headers["X-API-Version"] = REST_API_VERSION
return response


@app.middleware("http")
async def inject_propagated_observability_context(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Middleware to extract the any prorpagated observability context from the
HTTP headers and attatch it to the local one.
"""
if CONTEXT_HEADER in request.headers:
ctx = get_global_textmap().extract(
{CONTEXT_HEADER: request.headers[CONTEXT_HEADER]}
)
attach(ctx)
response = await call_next(request)
return response
def inject_propagated_observability_context(app: FastAPI):
@app.middleware("http")
async def inject_propagated_observability_context(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Middleware to extract the any propagated observability context from the
HTTP headers and attach it to the local one.
"""
if CONTEXT_HEADER in request.headers:
ctx = get_global_textmap().extract(

Check warning on line 415 in src/blueapi/service/main.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/main.py#L415

Added line #L415 was not covered by tests
{CONTEXT_HEADER: request.headers[CONTEXT_HEADER]}
)
attach(ctx)

Check warning on line 418 in src/blueapi/service/main.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/main.py#L418

Added line #L418 was not covered by tests
response = await call_next(request)
return response
3 changes: 2 additions & 1 deletion src/blueapi/service/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from fastapi.openapi.utils import get_openapi
from pyparsing import Any

from blueapi.service.main import app
from blueapi.service.main import get_app

DOCS_SCHEMA_LOCATION = Path(__file__).parents[3] / "docs" / "reference" / "openapi.yaml"


def generate_schema() -> Mapping[str, Any]:
app = get_app()
return get_openapi(
title=app.title,
version=app.version,
Expand Down
12 changes: 7 additions & 5 deletions tests/unit_tests/service/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from blueapi.service.openapi import DOCS_SCHEMA_LOCATION, generate_schema


@mock.patch("blueapi.service.openapi.app")
def test_generate_schema(mock_app: Mock) -> None:
from blueapi.service.main import app
@mock.patch("blueapi.service.openapi.get_app")
def test_generate_schema(mock_get_app: Mock) -> None:
mock_app = mock_get_app()

from blueapi.service.main import get_app

app = get_app()

title = PropertyMock(return_value="title")
version = PropertyMock(return_value=app.version)
Expand All @@ -23,8 +27,6 @@ def test_generate_schema(mock_app: Mock) -> None:
type(mock_app).description = description
type(mock_app).routes = routes

# from blueapi.service.openapi import generate_schema

assert generate_schema() == {
"openapi": openapi_version(),
"info": {
Expand Down
6 changes: 2 additions & 4 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@

@pytest.fixture
def client() -> Iterator[TestClient]:
with (
patch("blueapi.service.interface.worker"),
):
with patch("blueapi.service.interface.worker"):
main.setup_runner(use_subprocess=False)
yield TestClient(main.app)
yield TestClient(main.get_app())
main.teardown_runner()


Expand Down

0 comments on commit e581c27

Please sign in to comment.