Skip to content

Commit

Permalink
Drop trio from test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jan 1, 2025
1 parent 7c0d1e6 commit d5ede70
Show file tree
Hide file tree
Showing 27 changed files with 693 additions and 854 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ types-contextvars==2.4.7.3
types-PyYAML==6.0.12.20240917
types-dataclasses==0.6.6
pytest==8.3.4
trio==0.27.0

# Documentation
black==24.10.0
Expand Down
19 changes: 3 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,10 @@
from __future__ import annotations

import functools
from typing import Any, Literal
from typing import Literal

import pytest

from starlette.testclient import TestClient
from tests.types import TestClientFactory


@pytest.fixture
def test_client_factory(
anyio_backend_name: Literal["asyncio", "trio"],
anyio_backend_options: dict[str, Any],
) -> TestClientFactory:
# anyio_backend_name defined by:
# https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
return functools.partial(
TestClient,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)
def anyio_backend() -> Literal["asyncio"]:
return "asyncio"
92 changes: 32 additions & 60 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
from tests.types import TestClientFactory


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -84,8 +83,8 @@ async def websocket_endpoint(session: WebSocket) -> None:
)


def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
def test_custom_middleware() -> None:
client = TestClient(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

Expand All @@ -105,9 +104,7 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
assert text == "Hello, world!"


def test_state_data_across_multiple_middlewares(
test_client_factory: TestClientFactory,
) -> None:
def test_state_data_across_multiple_middlewares() -> None:
expected_value1 = "foo"
expected_value2 = "bar"

Expand Down Expand Up @@ -154,25 +151,25 @@ def homepage(request: Request) -> PlainTextResponse:
],
)

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/")
assert response.text == "OK"
assert response.headers["X-State-Foo"] == expected_value1
assert response.headers["X-State-Bar"] == expected_value2


def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None:
def test_app_middleware_argument() -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")

app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)])

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"


def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None:
def test_fully_evaluated_response() -> None:
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(
Expand All @@ -185,7 +182,7 @@ async def dispatch(

app = Starlette(middleware=[Middleware(CustomMiddleware)])

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"

Expand Down Expand Up @@ -231,10 +228,7 @@ async def dispatch(
),
],
)
def test_contextvars(
test_client_factory: TestClientFactory,
middleware_cls: _MiddlewareFactory[Any],
) -> None:
def test_contextvars(middleware_cls: _MiddlewareFactory[Any]) -> None:
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
Expand All @@ -245,7 +239,7 @@ async def homepage(request: Request) -> PlainTextResponse:

app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)])

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200, response.content

Expand Down Expand Up @@ -411,9 +405,7 @@ async def send(message: Message) -> None:
assert context_manager_exited.is_set()


def test_app_receives_http_disconnect_while_sending_if_discarded(
test_client_factory: TestClientFactory,
) -> None:
def test_app_receives_http_disconnect_while_sending_if_discarded() -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
Expand Down Expand Up @@ -482,14 +474,12 @@ async def cancel_on_disconnect(

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_app_receives_http_disconnect_after_sending_if_discarded(
test_client_factory: TestClientFactory,
) -> None:
def test_app_receives_http_disconnect_after_sending_if_discarded() -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
Expand Down Expand Up @@ -532,14 +522,12 @@ async def downstream_app(

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
client = TestClient(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_read_request_stream_in_app_after_middleware_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_app_after_middleware_calls_stream() -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b""]
async for chunk in request.stream():
Expand All @@ -564,14 +552,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_app_after_middleware_calls_body(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_app_after_middleware_calls_body() -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
async for chunk in request.stream():
Expand All @@ -593,14 +579,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_body_in_app_after_middleware_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_body_in_app_after_middleware_calls_stream() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b""
return PlainTextResponse("Homepage")
Expand All @@ -622,14 +606,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_body_in_app_after_middleware_calls_body(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_body_in_app_after_middleware_calls_body() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
Expand All @@ -648,14 +630,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_dispatch_after_app_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_dispatch_after_app_calls_stream() -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
async for chunk in request.stream():
Expand All @@ -680,14 +660,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_dispatch_after_app_calls_body(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_dispatch_after_app_calls_body() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
Expand All @@ -709,7 +687,7 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200

Expand Down Expand Up @@ -773,9 +751,7 @@ async def send(msg: Message) -> None:
await rcv_stream.aclose()


def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
Expand All @@ -798,14 +774,12 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next() -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
Expand All @@ -826,7 +800,7 @@ async def dispatch(
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
client: TestClient = TestClient(app)
response = client.post("/", content=b"a")
assert response.status_code == 200

Expand Down Expand Up @@ -917,9 +891,7 @@ async def send(msg: Message) -> None:
await rcv.aclose()


def test_downstream_middleware_modifies_receive(
test_client_factory: TestClientFactory,
) -> None:
def test_downstream_middleware_modifies_receive() -> None:
"""If a downstream middleware modifies receive() the final ASGI app
should see the modified version.
"""
Expand Down Expand Up @@ -952,7 +924,7 @@ async def wrapped_receive() -> Message:

return wrapped_app

client = test_client_factory(ConsumingMiddleware(modifying_middleware(endpoint)))
client = TestClient(ConsumingMiddleware(modifying_middleware(endpoint)))

resp = client.post("/", content=b"foo ")
assert resp.status_code == 200
Expand Down
Loading

0 comments on commit d5ede70

Please sign in to comment.