Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move BackgroundTask execution outside of request/response cycle #2176

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand Down Expand Up @@ -96,6 +97,7 @@ def build_middleware_stack(self) -> ASGIApp:

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ [Middleware(BackgroundTaskMiddleware)]
+ self.user_middleware
+ [
Middleware(
Expand Down
37 changes: 37 additions & 0 deletions starlette/middleware/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List, cast

from starlette.background import BackgroundTask
from starlette.types import ASGIApp, Receive, Scope, Send

# consider this a private implementation detail subject to change
# do not rely on this key
_SCOPE_KEY = "starlette._background"


_BackgroundTaskList = List[BackgroundTask]


def is_background_task_middleware_installed(scope: Scope) -> bool:
return _SCOPE_KEY in scope


def add_tasks(scope: Scope, task: BackgroundTask, /) -> None:
if _SCOPE_KEY not in scope: # pragma: no cover
raise RuntimeError(
"`add_tasks` can only be used if `BackgroundTaskMIddleware is installed"
)
cast(_BackgroundTaskList, scope[_SCOPE_KEY]).append(task)


class BackgroundTaskMiddleware:
def __init__(self, app: ASGIApp) -> None:
self._app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
tasks: _BackgroundTaskList
scope[_SCOPE_KEY] = tasks = []
try:
await self._app(scope, receive, send)
finally:
for task in tasks:
await task()
19 changes: 19 additions & 0 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.middleware import background
from starlette.types import Receive, Scope, Send


Expand Down Expand Up @@ -148,6 +149,12 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
Expand Down Expand Up @@ -255,6 +262,12 @@ async def stream_response(self, send: Send) -> None:
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
Expand Down Expand Up @@ -322,6 +335,12 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:
self.headers.setdefault("etag", etag)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
Expand Down
104 changes: 99 additions & 5 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@
from typing import (
Any,
AsyncGenerator,
Callable,
Generator,
Literal,
)

import anyio
import pytest
from anyio.abc import TaskStatus

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.background import BackgroundTask, BackgroundTasks
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

TestClientFactory = Callable[[ASGIApp], TestClient]
from tests.conftest import TestClientFactory


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -372,8 +372,8 @@ async def send(message: Message) -> None:
{"body": b"Hello", "more_body": True, "type": "http.response.body"},
{"body": b"", "more_body": False, "type": "http.response.body"},
"Background task started",
"Background task started",
"Background task finished",
"Background task started",
"Background task finished",
]

Expand Down Expand Up @@ -1035,3 +1035,97 @@ async def endpoint(request: Request) -> Response:
resp.raise_for_status()

assert bodies == [b"Hello, World!-foo"]


@pytest.mark.anyio
async def test_background_tasks_client_disconnect() -> None:
# test for https://github.com/encode/starlette/issues/1438
container: list[str] = []

disconnected = anyio.Event()

async def slow_background() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
container.append("called")

app: ASGIApp
app = PlainTextResponse("hi!", background=BackgroundTask(slow_background))

async def dispatch(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = BaseHTTPMiddleware(app, dispatch=dispatch)

app = BackgroundTaskMiddleware(app)

async def recv_gen() -> AsyncGenerator[Message, None]:
yield {"type": "http.request"}
await disconnected.wait()
while True:
yield {"type": "http.disconnect"}

async def send_gen() -> AsyncGenerator[None, Message]:
while True:
msg = yield
if msg["type"] == "http.response.body" and not msg.get("more_body", False):
disconnected.set()

scope = {"type": "http", "method": "GET", "path": "/"}

async with AsyncExitStack() as stack:
recv = recv_gen()
stack.push_async_callback(recv.aclose)
send = send_gen()
stack.push_async_callback(send.aclose)
await send.__anext__()
await app(scope, recv.__aiter__().__anext__, send.asend)

assert container == ["called"]


@pytest.mark.anyio
async def test_background_tasks_failure(
test_client_factory: TestClientFactory,
anyio_backend_name: Literal["asyncio", "trio"],
) -> None:
if anyio_backend_name == "trio":
pytest.skip("this test hangs with trio")

# test for https://github.com/encode/starlette/discussions/2640
container: list[str] = []

async def task1() -> None:
container.append("task1 called")
raise ValueError("task1 failed")

async def task2() -> None:
container.append("task2 called") # pragma: no cover

async def endpoint(request: Request) -> Response:
background = BackgroundTasks()
background.add_task(task1)
background.add_task(task2)
return PlainTextResponse("hi!", background=background)

async def dispatch(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = Starlette(
routes=[Route("/", endpoint)],
middleware=[Middleware(BaseHTTPMiddleware, dispatch=dispatch)],
)

client = test_client_factory(app, raise_server_exceptions=False)

response = client.get("/")
assert response.status_code == 200
assert response.text == "hi!"

assert container == ["task1 called"]
Loading