diff --git a/openapi_core/contrib/requests/requests.py b/openapi_core/contrib/requests/requests.py index 00a462f5..c8409c59 100644 --- a/openapi_core/contrib/requests/requests.py +++ b/openapi_core/contrib/requests/requests.py @@ -1,4 +1,5 @@ """OpenAPI core contrib requests requests module""" +from typing import Mapping from typing import Optional from typing import Union from urllib.parse import parse_qs @@ -7,6 +8,7 @@ from requests import PreparedRequest from requests import Request from requests.cookies import RequestsCookieJar +from requests.utils import rewind_body from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict @@ -28,7 +30,9 @@ def __init__(self, request: Union[Request, PreparedRequest]): "'request' argument is not type of " f"{Request} or {PreparedRequest}" ) + self._request = None if isinstance(request, Request): + self._request = request request = request.prepare() self.request = request @@ -65,13 +69,28 @@ def method(self) -> str: @property def body(self) -> Optional[str]: - if self.request.body is None: + import ipdb; ipdb.set_trace() + body = self.request.body + if body is None: return None - if isinstance(self.request.body, bytes): - return self.request.body.decode("utf-8") - assert isinstance(self.request.body, str) + chunks = None + is_stream = all( + [ + hasattr(body, "__iter__"), + not isinstance(body, (bytes, str, list, tuple, Mapping)), + ] + ) + if is_stream: + chunks = list(body) + body = b"".join(chunks) + if isinstance(body, bytes): + body = body.decode("utf-8") + assert isinstance(body, str) + # recreate request stream from evaluated chunks + if chunks is not None: + self.request.body = (x for x in chunks) # TODO: figure out if request._body_position is relevant - return self.request.body + return body @property def mimetype(self) -> str: diff --git a/tests/integration/contrib/requests/test_requests_validation.py b/tests/integration/contrib/requests/test_requests_validation.py index 2e8aee8c..bf34beb6 100644 --- a/tests/integration/contrib/requests/test_requests_validation.py +++ b/tests/integration/contrib/requests/test_requests_validation.py @@ -1,3 +1,5 @@ +from types import GeneratorType + import pytest import requests import responses @@ -9,6 +11,8 @@ from openapi_core.contrib.requests import RequestsOpenAPIRequest from openapi_core.contrib.requests import RequestsOpenAPIResponse from openapi_core.contrib.requests import RequestsOpenAPIWebhookRequest +from openapi_core.datatypes import Parameters +from openapi_core.datatypes import RequestParameters class TestRequestsOpenAPIValidation: @@ -72,6 +76,36 @@ def test_request_validator_path_pattern(self, request_unmarshaller): result = request_unmarshaller.unmarshal(openapi_request) assert not result.errors + def test_request_validator_encoded_chunks(self, request_unmarshaller): + request_chunks = [ + b'{', + b'"param1": 1', + b'}', + ] + def gen(): + for chunk in request_chunks: + yield chunk + + request = requests.Request( + "POST", + "http://localhost/browse/12/", + params={"q": "string"}, + headers={"content-type": "application/json"}, + data=gen(), + ) + openapi_request = RequestsOpenAPIRequest(request) + result = request_unmarshaller.unmarshal(openapi_request) + assert not result.errors + assert result.body == {"param1": 1} + assert result.parameters == Parameters( + query={"q": "string"}, + header={}, + cookie={}, + path={"id": 12}, + ) + assert isinstance(request.data, GeneratorType) + assert list(request.data) == request_chunks + def test_request_validator_prepared_request(self, request_unmarshaller): request = requests.Request( "POST",