Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client parameter to TestClient #2810

Merged
merged 12 commits into from
Dec 28, 2024
5 changes: 5 additions & 0 deletions docs/testclient.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ case you should use `client = TestClient(app, raise_server_exceptions=False)`.
not be triggered when the `TestClient` is instantiated. You can learn more about it
[here](lifespan.md#running-lifespan-in-tests).

!!! note

By default the `TestClient` uses `testclient:50000` as the host and port.
To use a different client, you can pass the host and port as a tuple: `TestClient(client=(<host>, <port>)`.

Kludex marked this conversation as resolved.
Show resolved Hide resolved
### Selecting the Async backend

`TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary).
Expand Down
8 changes: 6 additions & 2 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,15 @@ def __init__(
raise_server_exceptions: bool = True,
root_path: str = "",
*,
client: tuple[str, int],
app_state: dict[str, typing.Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state
self.client = client

def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
Expand Down Expand Up @@ -285,7 +287,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"client": self.client,
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
Expand All @@ -304,7 +306,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": ["testclient", 50000],
"client": self.client,
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
Expand Down Expand Up @@ -409,6 +411,7 @@ def __init__(
cookies: httpx._types.CookieTypes | None = None,
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
client: tuple[str, int] = ("testclient", 50000),
) -> None:
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
if _is_asgi3(app):
Expand All @@ -424,6 +427,7 @@ def __init__(
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
client=client,
)
if headers is None:
headers = {}
Expand Down
13 changes: 13 additions & 0 deletions tests/test_testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,19 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert response.json() == {"host": "testclient", "port": 50000}


def test_client_custom_client(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
client = scope.get("client")
assert client is not None
host, port = client
response = JSONResponse({"host": host, "port": port})
await response(scope, receive, send)

client = test_client_factory(app, client=("192.168.0.1", 3000))
response = client.get("/")
assert response.json() == {"host": "192.168.0.1", "port": 3000}


@pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà"))
def test_query_params(test_client_factory: TestClientFactory, param: str) -> None:
def homepage(request: Request) -> Response:
Expand Down
1 change: 1 addition & 0 deletions tests/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __call__(
cookies: httpx._types.CookieTypes | None = None,
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
client: tuple[str, int] = ("testclient", 50000),
) -> TestClient: ...
else: # pragma: no cover

Expand Down
Loading