diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index 0eb2bc8dff..03e6626035 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -95,8 +95,9 @@ async def send_with_gzip(self, message: Message) -> None: more_body = message.get("more_body", False) self.gzip_file.write(body) - self.gzip_file.flush() - if not more_body: + if more_body: + self.gzip_file.flush() + else: self.gzip_file.close() message["body"] = self.gzip_buffer.getvalue() diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index b20a7cb84b..57d558dcbd 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,9 +1,12 @@ +from __future__ import annotations + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse from starlette.routing import Route +from starlette.types import ASGIApp, Message, Receive, Scope, Send from tests.types import TestClientFactory @@ -61,6 +64,24 @@ def homepage(request: Request) -> PlainTextResponse: def test_gzip_streaming_response(test_client_factory: TestClientFactory) -> None: + class VerifyingMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.send = send + self.received_chunks: list[bytes] = [] + await self.app(scope, receive, self.sent_from_gzip) + assert all(chunk != b"" for chunk in self.received_chunks) + assert len(self.received_chunks) == 11 + + async def sent_from_gzip(self, message: Message) -> None: + message_type = message["type"] + if message_type == "http.response.body": + body = message.get("body", b"") + self.received_chunks.append(body) + await self.send(message) + def homepage(request: Request) -> StreamingResponse: async def generator(bytes: bytes, count: int) -> ContentStream: for index in range(count): @@ -71,11 +92,15 @@ async def generator(bytes: bytes, count: int) -> ContentStream: app = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(GZipMiddleware)], + middleware=[ + Middleware(VerifyingMiddleware), + Middleware(GZipMiddleware), + ], ) client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) + assert response.status_code == 200 assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "gzip"