-
-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
cleanup patches after test in pytest plugin (#1148) #1164
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,17 @@ | ||
import inspect | ||
import sys | ||
from contextlib import contextmanager | ||
from contextlib import ExitStack, contextmanager | ||
from functools import partial, wraps | ||
from types import FrameType | ||
from types import FrameType, MappingProxyType | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, TypeVar, Union | ||
from unittest import mock | ||
|
||
import pytest | ||
from typing_extensions import Final, final | ||
|
||
if TYPE_CHECKING: | ||
from returns.interfaces.specific.result import ResultLikeN | ||
|
||
_ERROR_HANDLERS: Final = ( | ||
'lash', | ||
) | ||
_ERRORS_COPIERS: Final = ( | ||
'map', | ||
'alt', | ||
) | ||
|
||
# We keep track of errors handled by keeping a mapping of <object id>: object. | ||
# If an error is handled, it is in the mapping. | ||
# If it isn't in the mapping, the error is not handled. | ||
|
@@ -28,7 +21,7 @@ | |
# Also, the object itself cannot be (in) the key because | ||
# (1) we cannot always assume hashability and | ||
# (2) we need to track the object identity, not its value | ||
_ERRORS_HANDLED: Final[Dict[int, Any]] = {} # noqa: WPS407 | ||
_ErrorsHandled = Dict[int, Any] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, this is now a type alias. Let's leave a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was doubting as well whether to use it. Based on the PEP it seemed optional -- necessary in case of forward references, for example. On the other hand, if this is looking to be the default way of declaring type aliases (explicit!) then we can already import from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm something strange in
mypy doesn't support them yet at all yet? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there other features coming in 0.920 that returns finds useful? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure about |
||
|
||
_FunctionType = TypeVar('_FunctionType', bound=Callable) | ||
_ReturnsResultType = TypeVar( | ||
|
@@ -41,7 +34,11 @@ | |
class ReturnsAsserts(object): | ||
"""Class with helpers assertions to check containers.""" | ||
|
||
__slots__ = () | ||
__slots__ = ('_errors_handled', ) | ||
|
||
def __init__(self, errors_handled: _ErrorsHandled) -> None: | ||
"""Constructor for this type.""" | ||
self._errors_handled = errors_handled | ||
|
||
@staticmethod # noqa: WPS602 | ||
def assert_equal( # noqa: WPS602 | ||
|
@@ -55,10 +52,9 @@ def assert_equal( # noqa: WPS602 | |
from returns.primitives.asserts import assert_equal | ||
assert_equal(first, second, deps=deps, backend=backend) | ||
|
||
@staticmethod # noqa: WPS602 | ||
def is_error_handled(container) -> bool: # noqa: WPS602 | ||
def is_error_handled(self, container) -> bool: | ||
"""Ensures that container has its error handled in the end.""" | ||
return id(container) in _ERRORS_HANDLED | ||
return id(container) in self._errors_handled | ||
|
||
@staticmethod # noqa: WPS602 | ||
@contextmanager | ||
|
@@ -86,59 +82,6 @@ def assert_trace( # noqa: WPS602 | |
sys.settrace(old_tracer) | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def returns(_patch_containers) -> ReturnsAsserts: | ||
"""Returns our own class with helpers assertions to check containers.""" | ||
return ReturnsAsserts() | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def _clear_errors_handled(): | ||
"""Ensures the 'errors handled' registry doesn't leak memory.""" | ||
yield | ||
_ERRORS_HANDLED.clear() | ||
|
||
|
||
def pytest_configure(config) -> None: | ||
""" | ||
Hook to be executed on import. | ||
|
||
We use it define custom markers. | ||
""" | ||
config.addinivalue_line( | ||
'markers', | ||
( | ||
'returns_lawful: all tests under `check_all_laws` ' + | ||
'is marked this way, ' + | ||
'use `-m "not returns_lawful"` to skip them.' | ||
), | ||
) | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def _patch_containers() -> None: | ||
""" | ||
Fixture to add test specifics into our containers. | ||
|
||
Currently we inject: | ||
|
||
- Error handling state, this is required to test that ``Result``-based | ||
containers do handle errors | ||
|
||
Even more things to come! | ||
""" | ||
_patch_error_handling(_ERROR_HANDLERS, _PatchedContainer.error_handler) | ||
_patch_error_handling(_ERRORS_COPIERS, _PatchedContainer.copy_handler) | ||
|
||
|
||
def _patch_error_handling(methods, patch_handler) -> None: | ||
for container in _PatchedContainer.containers_to_patch(): | ||
for method in methods: | ||
original = getattr(container, method, None) | ||
if original: | ||
setattr(container, method, patch_handler(original)) | ||
|
||
|
||
def _trace_function( | ||
trace_type: _ReturnsResultType, | ||
function_to_search: _FunctionType, | ||
|
@@ -166,65 +109,107 @@ def _trace_function( | |
raise _DesiredFunctionFound() | ||
|
||
|
||
@final | ||
class _PatchedContainer(object): | ||
"""Class with helper methods to patched containers.""" | ||
|
||
__slots__ = () | ||
|
||
@classmethod | ||
def containers_to_patch(cls) -> tuple: | ||
"""We need this method so coverage will work correctly.""" | ||
from returns.context import ( | ||
RequiresContextFutureResult, | ||
RequiresContextIOResult, | ||
RequiresContextResult, | ||
) | ||
from returns.future import FutureResult | ||
from returns.io import IOFailure, IOSuccess | ||
from returns.result import Failure, Success | ||
|
||
return ( | ||
Success, | ||
Failure, | ||
IOSuccess, | ||
IOFailure, | ||
RequiresContextResult, | ||
RequiresContextIOResult, | ||
RequiresContextFutureResult, | ||
FutureResult, | ||
) | ||
class _DesiredFunctionFound(BaseException): # noqa: WPS418 | ||
"""Exception to raise when expected function is found.""" | ||
|
||
@classmethod | ||
def error_handler(cls, original): | ||
if inspect.iscoroutinefunction(original): | ||
async def factory(self, *args, **kwargs): | ||
original_result = await original(self, *args, **kwargs) | ||
_ERRORS_HANDLED[id(original_result)] = original_result | ||
return original_result | ||
else: | ||
def factory(self, *args, **kwargs): | ||
original_result = original(self, *args, **kwargs) | ||
_ERRORS_HANDLED[id(original_result)] = original_result | ||
return original_result | ||
return wraps(original)(factory) | ||
|
||
@classmethod | ||
def copy_handler(cls, original): | ||
if inspect.iscoroutinefunction(original): | ||
async def factory(self, *args, **kwargs): | ||
original_result = await original(self, *args, **kwargs) | ||
if id(self) in _ERRORS_HANDLED: | ||
_ERRORS_HANDLED[id(original_result)] = original_result | ||
return original_result | ||
else: | ||
def factory(self, *args, **kwargs): | ||
original_result = original(self, *args, **kwargs) | ||
if id(self) in _ERRORS_HANDLED: | ||
_ERRORS_HANDLED[id(original_result)] = original_result | ||
return original_result | ||
return wraps(original)(factory) | ||
|
||
def pytest_configure(config) -> None: | ||
""" | ||
Hook to be executed on import. | ||
|
||
class _DesiredFunctionFound(BaseException): # noqa: WPS418 | ||
"""Exception to raise when expected function is found.""" | ||
We use it define custom markers. | ||
""" | ||
config.addinivalue_line( | ||
'markers', | ||
( | ||
'returns_lawful: all tests under `check_all_laws` ' + | ||
'is marked this way, ' + | ||
'use `-m "not returns_lawful"` to skip them.' | ||
), | ||
) | ||
|
||
|
||
@pytest.fixture() | ||
def returns() -> Iterator[ReturnsAsserts]: | ||
"""Returns class with helpers assertions to check containers.""" | ||
with _spy_error_handling() as errors_handled: | ||
yield ReturnsAsserts(errors_handled) | ||
|
||
|
||
@contextmanager | ||
def _spy_error_handling() -> Iterator[_ErrorsHandled]: | ||
"""Track error handling of containers.""" | ||
errs: _ErrorsHandled = {} | ||
with ExitStack() as cleanup: | ||
for container in _containers_to_patch(): | ||
for method, patch in _ERROR_HANDLING_PATCHERS.items(): | ||
cleanup.enter_context(mock.patch.object( | ||
container, | ||
method, | ||
patch(getattr(container, method), errs=errs), | ||
)) | ||
yield errs | ||
|
||
|
||
# delayed imports are needed to prevent messing up coverage | ||
def _containers_to_patch() -> tuple: | ||
from returns.context import ( | ||
RequiresContextFutureResult, | ||
RequiresContextIOResult, | ||
RequiresContextResult, | ||
) | ||
from returns.future import FutureResult | ||
from returns.io import IOFailure, IOSuccess | ||
from returns.result import Failure, Success | ||
|
||
return ( | ||
Success, | ||
Failure, | ||
IOSuccess, | ||
IOFailure, | ||
RequiresContextResult, | ||
RequiresContextIOResult, | ||
RequiresContextFutureResult, | ||
FutureResult, | ||
) | ||
|
||
|
||
def _patched_error_handler( | ||
original: _FunctionType, errs: _ErrorsHandled, | ||
) -> _FunctionType: | ||
if inspect.iscoroutinefunction(original): | ||
async def wrapper(self, *args, **kwargs): | ||
original_result = await original(self, *args, **kwargs) | ||
errs[id(original_result)] = original_result | ||
return original_result | ||
else: | ||
def wrapper(self, *args, **kwargs): | ||
original_result = original(self, *args, **kwargs) | ||
errs[id(original_result)] = original_result | ||
return original_result | ||
return wraps(original)(wrapper) # type: ignore | ||
|
||
|
||
def _patched_error_copier( | ||
original: _FunctionType, errs: _ErrorsHandled, | ||
) -> _FunctionType: | ||
if inspect.iscoroutinefunction(original): | ||
async def wrapper(self, *args, **kwargs): | ||
original_result = await original(self, *args, **kwargs) | ||
if id(self) in errs: | ||
errs[id(original_result)] = original_result | ||
return original_result | ||
else: | ||
def wrapper(self, *args, **kwargs): | ||
original_result = original(self, *args, **kwargs) | ||
if id(self) in errs: | ||
errs[id(original_result)] = original_result | ||
return original_result | ||
return wraps(original)(wrapper) # type: ignore | ||
|
||
|
||
_ERROR_HANDLING_PATCHERS: Final = MappingProxyType({ | ||
'lash': _patched_error_handler, | ||
'map': _patched_error_copier, | ||
'alt': _patched_error_copier, | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Extra thanks for this! 👍