From 8eab239aff194d216f7ec90a2f3fbb31db156404 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 10 Jun 2023 10:28:00 -0500 Subject: [PATCH 1/2] Move BackgroundTask execution outside of request/response cycle --- starlette/applications.py | 2 + starlette/middleware/background.py | 37 ++++++++ starlette/responses.py | 19 ++++ tests/middleware/test_base.py | 142 ++++++++++++++++++++++++++++- tests/test_background.py | 124 ++++++++++++++++++------- tests/test_responses.py | 10 +- 6 files changed, 298 insertions(+), 36 deletions(-) create mode 100644 starlette/middleware/background.py diff --git a/starlette/applications.py b/starlette/applications.py index 913fd4c9d..076b4d25f 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -11,6 +11,7 @@ from starlette.datastructures import State, URLPath from starlette.middleware import Middleware, _MiddlewareClass +from starlette.middleware.background import BackgroundTaskMiddleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware @@ -96,6 +97,7 @@ def build_middleware_stack(self) -> ASGIApp: middleware = ( [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + + [Middleware(BackgroundTaskMiddleware)] + self.user_middleware + [ Middleware( diff --git a/starlette/middleware/background.py b/starlette/middleware/background.py new file mode 100644 index 000000000..13e18f049 --- /dev/null +++ b/starlette/middleware/background.py @@ -0,0 +1,37 @@ +from typing import List, cast + +from starlette.background import BackgroundTask +from starlette.types import ASGIApp, Receive, Scope, Send + +# consider this a private implementation detail subject to change +# do not rely on this key +_SCOPE_KEY = "starlette._background" + + +_BackgroundTaskList = List[BackgroundTask] + + +def is_background_task_middleware_installed(scope: Scope) -> bool: + return _SCOPE_KEY in scope + + +def add_tasks(scope: Scope, task: BackgroundTask, /) -> None: + if _SCOPE_KEY not in scope: # pragma: no cover + raise RuntimeError( + "`add_tasks` can only be used if `BackgroundTaskMIddleware is installed" + ) + cast(_BackgroundTaskList, scope[_SCOPE_KEY]).append(task) + + +class BackgroundTaskMiddleware: + def __init__(self, app: ASGIApp) -> None: + self._app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + tasks: _BackgroundTaskList + scope[_SCOPE_KEY] = tasks = [] + try: + await self._app(scope, receive, send) + finally: + for task in tasks: + await task() diff --git a/starlette/responses.py b/starlette/responses.py index a6975747b..ce09d69e4 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -16,6 +16,7 @@ import anyio.to_thread from starlette._compat import md5_hexdigest +from starlette.middleware import background from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders @@ -148,6 +149,12 @@ def delete_cookie( ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if ( + self.background is not None + and background.is_background_task_middleware_installed(scope) + ): + background.add_tasks(scope, self.background) + self.background = None prefix = "websocket." if scope["type"] == "websocket" else "" await send( { @@ -255,6 +262,12 @@ async def stream_response(self, send: Send) -> None: await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if ( + self.background is not None + and background.is_background_task_middleware_installed(scope) + ): + background.add_tasks(scope, self.background) + self.background = None async with anyio.create_task_group() as task_group: async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: @@ -322,6 +335,12 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: self.headers.setdefault("etag", etag) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if ( + self.background is not None + and background.is_background_task_middleware_installed(scope) + ): + background.add_tasks(scope, self.background) + self.background = None if self.stat_result is None: try: stat_result = await anyio.to_thread.run_sync(os.stat, self.path) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 2176404d8..fd8bb2268 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -14,8 +14,9 @@ from anyio.abc import TaskStatus from starlette.applications import Starlette -from starlette.background import BackgroundTask +from starlette.background import BackgroundTask, BackgroundTasks from starlette.middleware import Middleware, _MiddlewareClass +from starlette.middleware.background import BackgroundTaskMiddleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import PlainTextResponse, Response, StreamingResponse @@ -1035,3 +1036,142 @@ async def endpoint(request: Request) -> Response: resp.raise_for_status() assert bodies == [b"Hello, World!-foo"] + +@pytest.mark.anyio +async def test_background_tasks_client_disconnect() -> None: + # test for https://github.com/encode/starlette/issues/1438 + container: list[str] = [] + + disconnected = anyio.Event() + + async def slow_background() -> None: + # small delay to give BaseHTTPMiddleware a chance to cancel us + # this is required to make the test fail prior to fixing the issue + # so do not be surprised if you remove it and the test still passes + await anyio.sleep(0.1) + container.append("called") + + app: ASGIApp + app = PlainTextResponse("hi!", background=BackgroundTask(slow_background)) + + async def dispatch( + request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = BaseHTTPMiddleware(app, dispatch=dispatch) + + app = BackgroundTaskMiddleware(app) + + async def recv_gen() -> AsyncGenerator[Message, None]: + yield {"type": "http.request"} + await disconnected.wait() + while True: + yield {"type": "http.disconnect"} + + async def send_gen() -> AsyncGenerator[None, Message]: + while True: + msg = yield + if msg["type"] == "http.response.body" and not msg.get("more_body", False): + disconnected.set() + + scope = {"type": "http", "method": "GET", "path": "/"} + + async with AsyncExitStack() as stack: + recv = recv_gen() + stack.push_async_callback(recv.aclose) + send = send_gen() + stack.push_async_callback(send.aclose) + await send.__anext__() + await app(scope, recv.__aiter__().__anext__, send.asend) + + assert container == ["called"] + +@pytest.mark.anyio +async def test_background_tasks_client_disconnect() -> None: + # test for https://github.com/encode/starlette/issues/1438 + container: list[str] = [] + + disconnected = anyio.Event() + + async def slow_background() -> None: + # small delay to give BaseHTTPMiddleware a chance to cancel us + # this is required to make the test fail prior to fixing the issue + # so do not be surprised if you remove it and the test still passes + await anyio.sleep(0.1) + container.append("called") + + app: ASGIApp + app = PlainTextResponse("hi!", background=BackgroundTask(slow_background)) + + async def dispatch( + request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = BaseHTTPMiddleware(app, dispatch=dispatch) + + app = BackgroundTaskMiddleware(app) + + async def recv_gen() -> AsyncGenerator[Message, None]: + yield {"type": "http.request"} + await disconnected.wait() + while True: + yield {"type": "http.disconnect"} + + async def send_gen() -> AsyncGenerator[None, Message]: + while True: + msg = yield + if msg["type"] == "http.response.body" and not msg.get("more_body", False): + disconnected.set() + + scope = {"type": "http", "method": "GET", "path": "/"} + + async with AsyncExitStack() as stack: + recv = recv_gen() + stack.push_async_callback(recv.aclose) + send = send_gen() + stack.push_async_callback(send.aclose) + await send.__anext__() + await app(scope, recv.__aiter__().__anext__, send.asend) + + assert container == ["called"] + + +@pytest.mark.anyio +async def test_background_tasks_failure( + test_client_factory: TestClientFactory, +) -> None: + # test for https://github.com/encode/starlette/discussions/2640 + container: list[str] = [] + + def task1() -> None: + container.append("task1 called") + raise ValueError("task1 failed") + + def task2() -> None: + container.append("task2 called") + + async def endpoint(request: Request) -> Response: + background = BackgroundTasks() + background.add_task(task1) + background.add_task(task2) + return PlainTextResponse("hi!", background=background) + + async def dispatch( + request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = Starlette( + routes=[Route("/", endpoint)], + middleware=[Middleware(BaseHTTPMiddleware, dispatch=dispatch)], + ) + + client = test_client_factory(app) + + response = client.get("/") + assert response.status_code == 200 + assert response.text == "hi!" + + assert container == ["task1 called"] diff --git a/tests/test_background.py b/tests/test_background.py index 846deecfd..51504ad35 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,40 +1,100 @@ -from typing import Callable +from __future__ import annotations + +from tempfile import NamedTemporaryFile +from typing import Any, AsyncIterable, Callable import pytest from starlette.background import BackgroundTask, BackgroundTasks -from starlette.responses import Response +from starlette.middleware.background import BackgroundTaskMiddleware +from starlette.responses import FileResponse, Response, StreamingResponse from starlette.testclient import TestClient -from starlette.types import Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send -TestClientFactory = Callable[..., TestClient] +TestClientFactory = Callable[[ASGIApp], TestClient] -def test_async_task(test_client_factory: TestClientFactory) -> None: - TASK_COMPLETE = False +@pytest.fixture( + params=[[], [BackgroundTaskMiddleware]], + ids=["without BackgroundTaskMiddleware", "with BackgroundTaskMiddleware"], +) +def test_client_factory_mw( + test_client_factory: TestClientFactory, request: Any +) -> TestClientFactory: + mw_stack: list[Callable[[ASGIApp], ASGIApp]] = request.param - async def async_task() -> None: - nonlocal TASK_COMPLETE - TASK_COMPLETE = True + def client_factory(app: ASGIApp) -> TestClient: + for mw in mw_stack: + app = mw(app) + return test_client_factory(app) - task = BackgroundTask(async_task) + return client_factory - async def app(scope: Scope, receive: Receive, send: Send) -> None: - response = Response("task initiated", media_type="text/plain", background=task) + +def response_app_factory(task: BackgroundTask) -> ASGIApp: + async def app(scope: Scope, receive: Receive, send: Send): + response = Response(b"task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = test_client_factory(app) + return app + + +def file_response_app_factory(task: BackgroundTask) -> ASGIApp: + async def app(scope: Scope, receive: Receive, send: Send): + with NamedTemporaryFile("wb+") as f: + f.write(b"task initiated") + f.seek(0) + response = FileResponse(f.name, media_type="text/plain", background=task) + await response(scope, receive, send) + + return app + + +def streaming_response_app_factory(task: BackgroundTask) -> ASGIApp: + async def app(scope: Scope, receive: Receive, send: Send): + async def stream() -> AsyncIterable[bytes]: + yield b"task initiated" + + response = StreamingResponse(stream(), media_type="text/plain", background=task) + await response(scope, receive, send) + + return app + + +@pytest.mark.parametrize( + "app_factory", + [ + response_app_factory, + streaming_response_app_factory, + file_response_app_factory, + ], +) +def test_async_task( + test_client_factory_mw: TestClientFactory, + app_factory: Callable[[BackgroundTask], ASGIApp], +): + task_complete = False + + async def async_task(): + nonlocal task_complete + task_complete = True + + task = BackgroundTask(async_task) + + app = app_factory(task) + + client = test_client_factory_mw(app) response = client.get("/") assert response.text == "task initiated" - assert TASK_COMPLETE + assert task_complete -def test_sync_task(test_client_factory: TestClientFactory) -> None: - TASK_COMPLETE = False +def test_sync_task(test_client_factory: TestClientFactory): + task_complete = False - def sync_task() -> None: - nonlocal TASK_COMPLETE - TASK_COMPLETE = True + def sync_task(): + nonlocal task_complete + task_complete = True task = BackgroundTask(sync_task) @@ -45,15 +105,15 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" - assert TASK_COMPLETE + assert task_complete -def test_multiple_tasks(test_client_factory: TestClientFactory) -> None: - TASK_COUNTER = 0 +def test_multiple_tasks(test_client_factory: TestClientFactory): + task_counter = 0 - def increment(amount: int) -> None: - nonlocal TASK_COUNTER - TASK_COUNTER += amount + def increment(amount: int): + nonlocal task_counter + task_counter += amount async def app(scope: Scope, receive: Receive, send: Send) -> None: tasks = BackgroundTasks() @@ -68,18 +128,18 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: client = test_client_factory(app) response = client.get("/") assert response.text == "tasks initiated" - assert TASK_COUNTER == 1 + 2 + 3 + assert task_counter == 1 + 2 + 3 def test_multi_tasks_failure_avoids_next_execution( test_client_factory: TestClientFactory, ) -> None: - TASK_COUNTER = 0 + task_counter = 0 - def increment() -> None: - nonlocal TASK_COUNTER - TASK_COUNTER += 1 - if TASK_COUNTER == 1: + def increment(): + nonlocal task_counter + task_counter += 1 + if task_counter == 1: raise Exception("task failed") async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -94,4 +154,4 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: client = test_client_factory(app) with pytest.raises(Exception): client.get("/") - assert TASK_COUNTER == 1 + assert task_counter == 1 diff --git a/tests/test_responses.py b/tests/test_responses.py index 434cc5a22..9aa415c83 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -12,7 +12,11 @@ from starlette import status from starlette.background import BackgroundTask +<<<<<<< HEAD from starlette.datastructures import Headers +======= +from starlette.middleware.background import BackgroundTaskMiddleware +>>>>>>> c19dd7c (Move BackgroundTask execution outside of request/response cycle) from starlette.requests import Request from starlette.responses import ( FileResponse, @@ -126,7 +130,7 @@ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: await response(scope, receive, send) assert filled_by_bg_task == "" - client = test_client_factory(app) + client = test_client_factory(BackgroundTaskMiddleware(app)) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" assert filled_by_bg_task == "6, 7, 8, 9" @@ -152,7 +156,7 @@ async def __anext__(self) -> str: response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") await response(scope, receive, send) - client = test_client_factory(app) + client = test_client_factory(BackgroundTaskMiddleware(app)) response = client.get("/") assert response.text == "12345" @@ -245,7 +249,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) assert filled_by_bg_task == "" - client = test_client_factory(app) + client = test_client_factory(BackgroundTaskMiddleware(app)) response = client.get("/") expected_disposition = 'attachment; filename="example.png"' assert response.status_code == status.HTTP_200_OK From 33c20a5d1318e1db5cc0967c8216eba72101a5e5 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 16 Jul 2024 07:36:25 -0700 Subject: [PATCH 2/2] Fix --- starlette/responses.py | 2 +- tests/middleware/test_base.py | 68 ++++++----------------------------- tests/test_background.py | 20 +++++------ tests/test_responses.py | 3 -- 4 files changed, 22 insertions(+), 71 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index ce09d69e4..911c604e5 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -16,10 +16,10 @@ import anyio.to_thread from starlette._compat import md5_hexdigest -from starlette.middleware import background from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders +from starlette.middleware import background from starlette.types import Receive, Scope, Send diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index fd8bb2268..3860c1fcb 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -5,8 +5,8 @@ from typing import ( Any, AsyncGenerator, - Callable, Generator, + Literal, ) import anyio @@ -24,8 +24,7 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket - -TestClientFactory = Callable[[ASGIApp], TestClient] +from tests.conftest import TestClientFactory class CustomMiddleware(BaseHTTPMiddleware): @@ -373,8 +372,8 @@ async def send(message: Message) -> None: {"body": b"Hello", "more_body": True, "type": "http.response.body"}, {"body": b"", "more_body": False, "type": "http.response.body"}, "Background task started", - "Background task started", "Background task finished", + "Background task started", "Background task finished", ] @@ -1037,55 +1036,6 @@ async def endpoint(request: Request) -> Response: assert bodies == [b"Hello, World!-foo"] -@pytest.mark.anyio -async def test_background_tasks_client_disconnect() -> None: - # test for https://github.com/encode/starlette/issues/1438 - container: list[str] = [] - - disconnected = anyio.Event() - - async def slow_background() -> None: - # small delay to give BaseHTTPMiddleware a chance to cancel us - # this is required to make the test fail prior to fixing the issue - # so do not be surprised if you remove it and the test still passes - await anyio.sleep(0.1) - container.append("called") - - app: ASGIApp - app = PlainTextResponse("hi!", background=BackgroundTask(slow_background)) - - async def dispatch( - request: Request, call_next: RequestResponseEndpoint - ) -> Response: - return await call_next(request) - - app = BaseHTTPMiddleware(app, dispatch=dispatch) - - app = BackgroundTaskMiddleware(app) - - async def recv_gen() -> AsyncGenerator[Message, None]: - yield {"type": "http.request"} - await disconnected.wait() - while True: - yield {"type": "http.disconnect"} - - async def send_gen() -> AsyncGenerator[None, Message]: - while True: - msg = yield - if msg["type"] == "http.response.body" and not msg.get("more_body", False): - disconnected.set() - - scope = {"type": "http", "method": "GET", "path": "/"} - - async with AsyncExitStack() as stack: - recv = recv_gen() - stack.push_async_callback(recv.aclose) - send = send_gen() - stack.push_async_callback(send.aclose) - await send.__anext__() - await app(scope, recv.__aiter__().__anext__, send.asend) - - assert container == ["called"] @pytest.mark.anyio async def test_background_tasks_client_disconnect() -> None: @@ -1141,16 +1091,20 @@ async def send_gen() -> AsyncGenerator[None, Message]: @pytest.mark.anyio async def test_background_tasks_failure( test_client_factory: TestClientFactory, + anyio_backend_name: Literal["asyncio", "trio"], ) -> None: + if anyio_backend_name == "trio": + pytest.skip("this test hangs with trio") + # test for https://github.com/encode/starlette/discussions/2640 container: list[str] = [] - def task1() -> None: + async def task1() -> None: container.append("task1 called") raise ValueError("task1 failed") - def task2() -> None: - container.append("task2 called") + async def task2() -> None: + container.append("task2 called") # pragma: no cover async def endpoint(request: Request) -> Response: background = BackgroundTasks() @@ -1168,7 +1122,7 @@ async def dispatch( middleware=[Middleware(BaseHTTPMiddleware, dispatch=dispatch)], ) - client = test_client_factory(app) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 200 diff --git a/tests/test_background.py b/tests/test_background.py index 51504ad35..2ebe75a8a 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -32,7 +32,7 @@ def client_factory(app: ASGIApp) -> TestClient: def response_app_factory(task: BackgroundTask) -> ASGIApp: - async def app(scope: Scope, receive: Receive, send: Send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(b"task initiated", media_type="text/plain", background=task) await response(scope, receive, send) @@ -40,7 +40,7 @@ async def app(scope: Scope, receive: Receive, send: Send): def file_response_app_factory(task: BackgroundTask) -> ASGIApp: - async def app(scope: Scope, receive: Receive, send: Send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: with NamedTemporaryFile("wb+") as f: f.write(b"task initiated") f.seek(0) @@ -51,7 +51,7 @@ async def app(scope: Scope, receive: Receive, send: Send): def streaming_response_app_factory(task: BackgroundTask) -> ASGIApp: - async def app(scope: Scope, receive: Receive, send: Send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: async def stream() -> AsyncIterable[bytes]: yield b"task initiated" @@ -72,10 +72,10 @@ async def stream() -> AsyncIterable[bytes]: def test_async_task( test_client_factory_mw: TestClientFactory, app_factory: Callable[[BackgroundTask], ASGIApp], -): +) -> None: task_complete = False - async def async_task(): + async def async_task() -> None: nonlocal task_complete task_complete = True @@ -89,10 +89,10 @@ async def async_task(): assert task_complete -def test_sync_task(test_client_factory: TestClientFactory): +def test_sync_task(test_client_factory: TestClientFactory) -> None: task_complete = False - def sync_task(): + def sync_task() -> None: nonlocal task_complete task_complete = True @@ -108,10 +108,10 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert task_complete -def test_multiple_tasks(test_client_factory: TestClientFactory): +def test_multiple_tasks(test_client_factory: TestClientFactory) -> None: task_counter = 0 - def increment(amount: int): + def increment(amount: int) -> None: nonlocal task_counter task_counter += amount @@ -136,7 +136,7 @@ def test_multi_tasks_failure_avoids_next_execution( ) -> None: task_counter = 0 - def increment(): + def increment() -> None: nonlocal task_counter task_counter += 1 if task_counter == 1: diff --git a/tests/test_responses.py b/tests/test_responses.py index 9aa415c83..a1df737b8 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -12,11 +12,8 @@ from starlette import status from starlette.background import BackgroundTask -<<<<<<< HEAD from starlette.datastructures import Headers -======= from starlette.middleware.background import BackgroundTaskMiddleware ->>>>>>> c19dd7c (Move BackgroundTask execution outside of request/response cycle) from starlette.requests import Request from starlette.responses import ( FileResponse,