Skip to content

Commit

Permalink
Fetch original path from aiohttp.response (#92)
Browse files Browse the repository at this point in the history
Fixes backward-compatibility issue with aiohttp >= 3.8.2
  • Loading branch information
vitek authored Apr 20, 2024
1 parent f6c8164 commit 2fb5261
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 5 deletions.
16 changes: 16 additions & 0 deletions tests/plugins/mockserver/test_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,19 @@ def test(request: fixture_types.MockserverRequest):
response = await mockserver_client.get('test123')
assert response.status_code == 200
assert test.next_call()


async def test_request_encoding(
mockserver: fixture_types.MockserverFixture,
mockserver_client: service_client.Client,
):
@mockserver.handler('/test', prefix=True)
def mock(request: fixture_types.MockserverRequest):
return mockserver.make_response(
'test', 200, content_type='text/csv', charset='utf-16le'
)

response = await mockserver_client.get('test')
assert response.status_code == 200
assert response.encoding == 'utf-16-le'
assert response.content_type == 'text/csv'
54 changes: 54 additions & 0 deletions tests/plugins/mockserver/test_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import asyncio

import pytest

from testsuite._internal import fixture_types


@pytest.fixture
def simple_client(mockserver_info):
async def client(path: str):
reader, writer = await asyncio.open_connection(
mockserver_info.host, mockserver_info.port
)

request = f'GET {path} HTTP/1.0\r\n' f'\r\n'
writer.write(request.encode())
await writer.drain()

data = await reader.read()
writer.close()
await writer.wait_closed()

lines = data.splitlines()
assert len(lines) > 1
assert lines[0].decode('utf-8') == ('HTTP/1.0 200 OK'), data

return client


@pytest.mark.parametrize(
'mock_url, request_path, expected_path',
[
('/foo', '/foo/abc', '/foo/abc'),
('/foo', '/foo/abc?a=b', '/foo/abc'),
('/foo', '/foo/bar%20maurice?a=b', '/foo/bar maurice'),
('http://foo/', 'http://foo/abc', 'http://foo/abc'),
('http://foo/', 'http://foo/abc?a=b', 'http://foo/abc'),
('http://foo/', 'http://foo/bar%20maurice', 'http://foo/bar maurice'),
],
)
async def test_path_basic(
mockserver: fixture_types.MockserverFixture,
simple_client,
mock_url,
request_path,
expected_path,
):
@mockserver.aiohttp_json_handler(mock_url, prefix=True)
def mock(request):
pass

await simple_client(request_path)
mock_request = mock.next_call()['request']
assert mock_request.original_path == expected_path
8 changes: 7 additions & 1 deletion tests/plugins/mockserver/test_tracing_disabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def test_mockserver_raises_on_unhandled_request_from_other_sources(
tracing_enabled=False,
)
with mockserver.new_session() as session:
request = aiohttp.test_utils.make_mocked_request(
request = _make_mocked_request(
'POST',
'/arbitrary/path',
headers=http_headers,
Expand All @@ -56,3 +56,9 @@ async def test_mockserver_raises_on_unhandled_request_from_other_sources(
assert len(session._errors) == 1
error = session._errors.pop()
assert isinstance(error, exceptions.HandlerNotFoundError)


def _make_mocked_request(*args, **kwargs):
request = aiohttp.test_utils.make_mocked_request(*args, **kwargs)
request.original_path = request.path
return request
33 changes: 29 additions & 4 deletions testsuite/mockserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ssl
import time
import typing
import urllib.parse
import uuid
import warnings
import yarl
Expand Down Expand Up @@ -45,6 +46,13 @@
RouteParams = typing.Dict[str, str]


class MockserverRequest(aiohttp.web.BaseRequest):
# We need original path including scheme and hostname
def __init__(self, message, *args, **kwargs):
super().__init__(message, *args, **kwargs)
self.original_path = _path_from_message(message)


class Handler:
def __init__(self, func, *, raw_request=False, json_response=False):
self.raw_request = raw_request
Expand Down Expand Up @@ -201,13 +209,14 @@ def register_handler(

def _get_handler_for_request(
self,
request: aiohttp.web.BaseRequest,
request: MockserverRequest,
) -> typing.Tuple[Handler, RouteParams]:
path = request.original_path
if self.http_proxy_enabled:
host = request.headers.get('host')
if host and host != self.mockserver_host:
return self.get_handler(f'http://{host}{request.path}')
return self.get_handler(request.path)
return self.get_handler(f'http://{host}{path}')
return self.get_handler(path)


# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -616,7 +625,15 @@ def _create_server_obj(mockserver_info, pytestconfig) -> Server:


def _create_web_server(server: Server, loop) -> aiohttp.web.Server:
return aiohttp.web.Server(server.handle_request, loop=loop, access_log=None)
def request_factory(*args):
return MockserverRequest(*args, loop=loop)

return aiohttp.web.Server(
server.handle_request,
request_factory=request_factory,
loop=loop,
access_log=None,
)


@compat.asynccontextmanager
Expand Down Expand Up @@ -707,3 +724,11 @@ def _is_from_client_fixture(trace_id: str) -> bool:

def _is_other_test(trace_id: str, current_trace_id: str) -> bool:
return trace_id != current_trace_id and _is_from_client_fixture(trace_id)


def _path_from_message(message):
"""Returns original HTTP path without query."""
path = str(message.url)
path = path.split('?')[0]
path = urllib.parse.unquote(path)
return path

0 comments on commit 2fb5261

Please sign in to comment.